lib: rename bt_value_array_get_size() -> bt_value_array_get_length()
[babeltrace.git] / src / bindings / python / bt2 / bt2 / value.py
index b6fb6769fbe938fee8a2adcc29cff5227aeac08f..8c8ef60c3de1a74bca8af47eb72ba203f5428c21 100644 (file)
@@ -29,15 +29,15 @@ import abc
 import bt2
 
 
-def _handle_status(status, obj_name):
-    if status >= 0:
+def _create_from_ptr(ptr):
+    if ptr is None:
         return
-    else:
-        raise RuntimeError('unexpected error')
-
 
-def _create_from_ptr(ptr):
-    if ptr is None or ptr == native_bt.value_null:
+    # bt_value_null is translated to None.  However, we are given a reference
+    # to it that we are not going to manage anymore, since we don't create a
+    # Python wrapper for it.  Therefore put that reference immediately.
+    if ptr == native_bt.value_null:
+        _Value._put_ref(ptr)
         return
 
     typeid = native_bt.value_get_type(ptr)
@@ -63,75 +63,46 @@ def create_value(value):
     if isinstance(value, bool):
         return BoolValue(value)
 
-    if isinstance(value, int):
+    if isinstance(value, numbers.Integral):
         return SignedIntegerValue(value)
 
-    if isinstance(value, float):
+    if isinstance(value, numbers.Real):
         return RealValue(value)
 
     if isinstance(value, str):
         return StringValue(value)
 
-    try:
-        return MapValue(value)
-    except:
-        pass
-
-    try:
+    if isinstance(value, collections.abc.Sequence):
         return ArrayValue(value)
-    except:
-        pass
 
-    raise TypeError("cannot create value object from '{}' object".format(value.__class__.__name__))
+    if isinstance(value, collections.abc.Mapping):
+        return MapValue(value)
+
+    raise TypeError(
+        "cannot create value object from '{}' object".format(value.__class__.__name__)
+    )
 
 
 class _Value(object._SharedObject, metaclass=abc.ABCMeta):
     _get_ref = staticmethod(native_bt.value_get_ref)
     _put_ref = staticmethod(native_bt.value_put_ref)
 
-    def __eq__(self, other):
-        if other is None:
-            # self is never the null value object
-            return False
-
-        # try type-specific comparison first
-        spec_eq = self._spec_eq(other)
-
-        if spec_eq is not None:
-            return spec_eq
-
-        if not isinstance(other, _Value):
-            # not comparing apples to apples
-            return False
-
-        # fall back to native comparison function
-        return native_bt.value_compare(self._ptr, other._ptr)
-
     def __ne__(self, other):
         return not (self == other)
 
-    @abc.abstractmethod
-    def _spec_eq(self, other):
-        pass
-
-    def _handle_status(self, status):
-        _handle_status(status, self._NAME)
-
     def _check_create_status(self, ptr):
         if ptr is None:
-            raise bt2.CreationError(
-                'cannot create {} value object'.format(self._NAME.lower()))
+            raise bt2._MemoryError(
+                'cannot create {} value object'.format(self._NAME.lower())
+            )
 
 
 @functools.total_ordering
 class _NumericValue(_Value):
     @staticmethod
     def _extract_value(other):
-        if isinstance(other, _NumericValue):
-            return other._value
-
-        if other is True or other is False:
-            return other
+        if isinstance(other, BoolValue) or isinstance(other, bool):
+            return bool(other)
 
         if isinstance(other, numbers.Integral):
             return int(other)
@@ -142,7 +113,9 @@ class _NumericValue(_Value):
         if isinstance(other, numbers.Complex):
             return complex(other)
 
-        raise TypeError("'{}' object is not a number object".format(other.__class__.__name__))
+        raise TypeError(
+            "'{}' object is not a number object".format(other.__class__.__name__)
+        )
 
     def __int__(self):
         return int(self._value)
@@ -154,21 +127,14 @@ class _NumericValue(_Value):
         return repr(self._value)
 
     def __lt__(self, other):
-        if not isinstance(other, numbers.Number):
-            raise TypeError('unorderable types: {}() < {}()'.format(self.__class__.__name__,
-                                                                    other.__class__.__name__))
-
         return self._value < self._extract_value(other)
 
-    def _spec_eq(self, other):
-        pass
-
     def __eq__(self, other):
-        if not isinstance(other, numbers.Number):
+        try:
+            return self._value == self._extract_value(other)
+        except:
             return False
 
-        return self._value == self._extract_value(other)
-
     def __rmod__(self, other):
         return self._extract_value(other) % self._value
 
@@ -229,34 +195,6 @@ class _NumericValue(_Value):
     def __rpow__(self, base):
         return self._extract_value(base) ** self._value
 
-    def __iadd__(self, other):
-        self.value = self + other
-        return self
-
-    def __isub__(self, other):
-        self.value = self - other
-        return self
-
-    def __imul__(self, other):
-        self.value = self * other
-        return self
-
-    def __itruediv__(self, other):
-        self.value = self / other
-        return self
-
-    def __ifloordiv__(self, other):
-        self.value = self // other
-        return self
-
-    def __imod__(self, other):
-        self.value = self % other
-        return self
-
-    def __ipow__(self, other):
-        self.value = self ** other
-        return self
-
 
 class _IntegralValue(_NumericValue, numbers.Integral):
     def __lshift__(self, other):
@@ -292,32 +230,12 @@ class _IntegralValue(_NumericValue, numbers.Integral):
     def __invert__(self):
         return ~self._value
 
-    def __ilshift__(self, other):
-        self.value = self << other
-        return self
-
-    def __irshift__(self, other):
-        self.value = self >> other
-        return self
-
-    def __iand__(self, other):
-        self.value = self & other
-        return self
-
-    def __ixor__(self, other):
-        self.value = self ^ other
-        return self
-
-    def __ior__(self, other):
-        self.value = self | other
-        return self
-
 
 class _RealValue(_NumericValue, numbers.Real):
     pass
 
 
-class BoolValue(_Value):
+class BoolValue(_IntegralValue):
     _NAME = 'Boolean'
 
     def __init__(self, value=None):
@@ -329,10 +247,6 @@ class BoolValue(_Value):
         self._check_create_status(ptr)
         super().__init__(ptr)
 
-    def _spec_eq(self, other):
-        if isinstance(other, numbers.Number):
-            return self._value == bool(other)
-
     def __bool__(self):
         return self._value
 
@@ -344,9 +258,13 @@ class BoolValue(_Value):
             value = value._value
 
         if not isinstance(value, bool):
-            raise TypeError("'{}' object is not a 'bool' or 'BoolValue' object".format(value.__class__))
+            raise TypeError(
+                "'{}' object is not a 'bool' or 'BoolValue' object".format(
+                    value.__class__
+                )
+            )
 
-        return int(value)
+        return value
 
     @property
     def _value(self):
@@ -370,8 +288,8 @@ class _IntegerValue(_IntegralValue):
         super().__init__(ptr)
 
     def _value_to_int(self, value):
-        if not isinstance(value, numbers.Real):
-            raise TypeError('expecting a number object')
+        if not isinstance(value, numbers.Integral):
+            raise TypeError('expecting an integral number object')
 
         value = int(value)
         self._check_int_range(value)
@@ -389,18 +307,18 @@ class _IntegerValue(_IntegralValue):
 
 class UnsignedIntegerValue(_IntegerValue):
     _check_int_range = staticmethod(utils._check_uint64)
-    _create_default_value = staticmethod(native_bt.value_unsigned_integer_create)
-    _create_value = staticmethod(native_bt.value_unsigned_integer_create_init)
-    _set_value = staticmethod(native_bt.value_unsigned_integer_set)
-    _get_value = staticmethod(native_bt.value_unsigned_integer_get)
+    _create_default_value = staticmethod(native_bt.value_integer_unsigned_create)
+    _create_value = staticmethod(native_bt.value_integer_unsigned_create_init)
+    _set_value = staticmethod(native_bt.value_integer_unsigned_set)
+    _get_value = staticmethod(native_bt.value_integer_unsigned_get)
 
 
 class SignedIntegerValue(_IntegerValue):
     _check_int_range = staticmethod(utils._check_int64)
-    _create_default_value = staticmethod(native_bt.value_signed_integer_create)
-    _create_value = staticmethod(native_bt.value_signed_integer_create_init)
-    _set_value = staticmethod(native_bt.value_signed_integer_set)
-    _get_value = staticmethod(native_bt.value_signed_integer_get)
+    _create_default_value = staticmethod(native_bt.value_integer_signed_create)
+    _create_value = staticmethod(native_bt.value_integer_signed_create_init)
+    _set_value = staticmethod(native_bt.value_integer_signed_set)
+    _get_value = staticmethod(native_bt.value_integer_signed_get)
 
 
 class RealValue(_RealValue):
@@ -458,15 +376,15 @@ class StringValue(collections.abc.Sequence, _Value):
 
     def _set_value(self, value):
         status = native_bt.value_string_set(self._ptr, self._value_to_str(value))
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     value = property(fset=_set_value)
 
-    def _spec_eq(self, other):
+    def __eq__(self, other):
         try:
             return self._value == self._value_to_str(other)
         except:
-            return
+            return False
 
     def __lt__(self, other):
         return self._value < self._value_to_str(other)
@@ -515,29 +433,33 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
             for elem in value:
                 self.append(elem)
 
-    def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                # early mismatch
-                return False
+    def __eq__(self, other):
+        if not isinstance(other, collections.abc.Sequence):
+            return False
 
-            for self_elem, other_elem in zip(self, other):
-                if self_elem != other_elem:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-            return True
-        except:
-            return
+        for self_elem, other_elem in zip(self, other):
+            if self_elem != other_elem:
+                return False
+
+        return True
 
     def __len__(self):
-        size = native_bt.value_array_get_size(self._ptr)
-        assert(size >= 0)
+        size = native_bt.value_array_get_length(self._ptr)
+        assert size >= 0
         return size
 
     def _check_index(self, index):
         # TODO: support slices also
         if not isinstance(index, numbers.Integral):
-            raise TypeError("'{}' object is not an integral number object: invalid index".format(index.__class__.__name__))
+            raise TypeError(
+                "'{}' object is not an integral number object: invalid index".format(
+                    index.__class__.__name__
+                )
+            )
 
         index = int(index)
 
@@ -547,7 +469,7 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
     def __getitem__(self, index):
         self._check_index(index)
         ptr = native_bt.value_array_borrow_element_by_index(self._ptr, index)
-        assert(ptr)
+        assert ptr
         return _create_from_ptr_and_get_ref(ptr)
 
     def __setitem__(self, index, value):
@@ -559,9 +481,8 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
         else:
             ptr = value._ptr
 
-        status = native_bt.value_array_set_element_by_index(
-            self._ptr, index, ptr)
-        self._handle_status(status)
+        status = native_bt.value_array_set_element_by_index(self._ptr, index, ptr)
+        utils._handle_func_status(status)
 
     def append(self, value):
         value = create_value(value)
@@ -572,7 +493,7 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
             ptr = value._ptr
 
         status = native_bt.value_array_append_element(self._ptr, ptr)
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     def __iadd__(self, iterable):
         # Python will raise a TypeError if there's anything wrong with
@@ -623,35 +544,29 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value):
             for key, elem in value.items():
                 self[key] = elem
 
-    def __eq__(self, other):
-        return _Value.__eq__(self, other)
-
     def __ne__(self, other):
         return _Value.__ne__(self, other)
 
-    def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                # early mismatch
-                return False
+    def __eq__(self, other):
+        if not isinstance(other, collections.abc.Mapping):
+            return False
 
-            for self_key in self:
-                if self_key not in other:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-                self_value = self[self_key]
-                other_value = other[self_key]
+        for self_key in self:
+            if self_key not in other:
+                return False
 
-                if self_value != other_value:
-                    return False
+            if self[self_key] != other[self_key]:
+                return False
 
-            return True
-        except:
-            return
+        return True
 
     def __len__(self):
         size = native_bt.value_map_get_size(self._ptr)
-        assert(size >= 0)
+        assert size >= 0
         return size
 
     def __contains__(self, key):
@@ -668,7 +583,7 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value):
     def __getitem__(self, key):
         self._check_key(key)
         ptr = native_bt.value_map_borrow_entry_value(self._ptr, key)
-        assert(ptr)
+        assert ptr
         return _create_from_ptr_and_get_ref(ptr)
 
     def __iter__(self):
@@ -684,7 +599,7 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value):
             ptr = value._ptr
 
         status = native_bt.value_map_insert_entry(self._ptr, key, ptr)
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     def __repr__(self):
         items = ['{}: {}'.format(repr(k), repr(v)) for k, v in self.items()]
This page took 0.030828 seconds and 4 git commands to generate.