lib: strictly type function return status enumerations
[babeltrace.git] / src / bindings / python / bt2 / bt2 / value.py
index d8b8332eecd87bd06942ac3a839b0a96c42b5eeb..b6a5a4a1efc1c588b491836e162137e61daf4c77 100644 (file)
@@ -29,13 +29,6 @@ import abc
 import bt2
 
 
-def _handle_status(status, obj_name):
-    if status >= 0:
-        return
-    else:
-        raise RuntimeError('unexpected error')
-
-
 def _create_from_ptr(ptr):
     if ptr is None or ptr == native_bt.value_null:
         return
@@ -63,24 +56,20 @@ def create_value(value):
     if isinstance(value, bool):
         return BoolValue(value)
 
-    if isinstance(value, int):
+    if isinstance(value, numbers.Integral):
         return SignedIntegerValue(value)
 
-    if isinstance(value, float):
+    if isinstance(value, numbers.Real):
         return RealValue(value)
 
     if isinstance(value, str):
         return StringValue(value)
 
-    try:
-        return MapValue(value)
-    except:
-        pass
-
-    try:
+    if isinstance(value, collections.abc.Sequence):
         return ArrayValue(value)
-    except:
-        pass
+
+    if isinstance(value, collections.abc.Mapping):
+        return MapValue(value)
 
     raise TypeError("cannot create value object from '{}' object".format(value.__class__.__name__))
 
@@ -89,34 +78,9 @@ class _Value(object._SharedObject, metaclass=abc.ABCMeta):
     _get_ref = staticmethod(native_bt.value_get_ref)
     _put_ref = staticmethod(native_bt.value_put_ref)
 
-    def __eq__(self, other):
-        if other is None:
-            # self is never the null value object
-            return False
-
-        # try type-specific comparison first
-        spec_eq = self._spec_eq(other)
-
-        if spec_eq is not None:
-            return spec_eq
-
-        if not isinstance(other, _Value):
-            # not comparing apples to apples
-            return False
-
-        # fall back to native comparison function
-        return native_bt.value_compare(self._ptr, other._ptr)
-
     def __ne__(self, other):
         return not (self == other)
 
-    @abc.abstractmethod
-    def _spec_eq(self, other):
-        pass
-
-    def _handle_status(self, status):
-        _handle_status(status, self._NAME)
-
     def _check_create_status(self, ptr):
         if ptr is None:
             raise bt2.CreationError(
@@ -127,11 +91,8 @@ class _Value(object._SharedObject, metaclass=abc.ABCMeta):
 class _NumericValue(_Value):
     @staticmethod
     def _extract_value(other):
-        if isinstance(other, _NumericValue):
-            return other._value
-
-        if other is True or other is False:
-            return other
+        if isinstance(other, BoolValue) or isinstance(other, bool):
+            return bool(other)
 
         if isinstance(other, numbers.Integral):
             return int(other)
@@ -154,28 +115,14 @@ class _NumericValue(_Value):
         return repr(self._value)
 
     def __lt__(self, other):
-        if not isinstance(other, numbers.Number):
-            raise TypeError('unorderable types: {}() < {}()'.format(self.__class__.__name__,
-                                                                    other.__class__.__name__))
-
-        return self._value < float(other)
-
-    def __le__(self, other):
-        if not isinstance(other, numbers.Number):
-            raise TypeError('unorderable types: {}() <= {}()'.format(self.__class__.__name__,
-                                                                     other.__class__.__name__))
-
-        return self._value <= float(other)
-
-    def _spec_eq(self, other):
-        pass
+        return self._value < self._extract_value(other)
 
     def __eq__(self, other):
-        if not isinstance(other, numbers.Number):
+        try:
+            return self._value == self._extract_value(other)
+        except:
             return False
 
-        return self._value == complex(other)
-
     def __rmod__(self, other):
         return self._extract_value(other) % self._value
 
@@ -236,34 +183,6 @@ class _NumericValue(_Value):
     def __rpow__(self, base):
         return self._extract_value(base) ** self._value
 
-    def __iadd__(self, other):
-        self.value = self + other
-        return self
-
-    def __isub__(self, other):
-        self.value = self - other
-        return self
-
-    def __imul__(self, other):
-        self.value = self * other
-        return self
-
-    def __itruediv__(self, other):
-        self.value = self / other
-        return self
-
-    def __ifloordiv__(self, other):
-        self.value = self // other
-        return self
-
-    def __imod__(self, other):
-        self.value = self % other
-        return self
-
-    def __ipow__(self, other):
-        self.value = self ** other
-        return self
-
 
 class _IntegralValue(_NumericValue, numbers.Integral):
     def __lshift__(self, other):
@@ -299,32 +218,12 @@ class _IntegralValue(_NumericValue, numbers.Integral):
     def __invert__(self):
         return ~self._value
 
-    def __ilshift__(self, other):
-        self.value = self << other
-        return self
-
-    def __irshift__(self, other):
-        self.value = self >> other
-        return self
-
-    def __iand__(self, other):
-        self.value = self & other
-        return self
-
-    def __ixor__(self, other):
-        self.value = self ^ other
-        return self
-
-    def __ior__(self, other):
-        self.value = self | other
-        return self
-
 
 class _RealValue(_NumericValue, numbers.Real):
     pass
 
 
-class BoolValue(_Value):
+class BoolValue(_IntegralValue):
     _NAME = 'Boolean'
 
     def __init__(self, value=None):
@@ -336,10 +235,6 @@ class BoolValue(_Value):
         self._check_create_status(ptr)
         super().__init__(ptr)
 
-    def _spec_eq(self, other):
-        if isinstance(other, numbers.Number):
-            return self._value == bool(other)
-
     def __bool__(self):
         return self._value
 
@@ -353,7 +248,7 @@ class BoolValue(_Value):
         if not isinstance(value, bool):
             raise TypeError("'{}' object is not a 'bool' or 'BoolValue' object".format(value.__class__))
 
-        return int(value)
+        return value
 
     @property
     def _value(self):
@@ -377,8 +272,8 @@ class _IntegerValue(_IntegralValue):
         super().__init__(ptr)
 
     def _value_to_int(self, value):
-        if not isinstance(value, numbers.Real):
-            raise TypeError('expecting a number object')
+        if not isinstance(value, numbers.Integral):
+            raise TypeError('expecting an integral number object')
 
         value = int(value)
         self._check_int_range(value)
@@ -465,18 +360,15 @@ class StringValue(collections.abc.Sequence, _Value):
 
     def _set_value(self, value):
         status = native_bt.value_string_set(self._ptr, self._value_to_str(value))
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     value = property(fset=_set_value)
 
-    def _spec_eq(self, other):
+    def __eq__(self, other):
         try:
             return self._value == self._value_to_str(other)
         except:
-            return
-
-    def __le__(self, other):
-        return self._value <= self._value_to_str(other)
+            return False
 
     def __lt__(self, other):
         return self._value < self._value_to_str(other)
@@ -525,19 +417,19 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
             for elem in value:
                 self.append(elem)
 
-    def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                # early mismatch
-                return False
+    def __eq__(self, other):
+        if not isinstance(other, collections.abc.Sequence):
+            return False
 
-            for self_elem, other_elem in zip(self, other):
-                if self_elem != other_elem:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-            return True
-        except:
-            return
+        for self_elem, other_elem in zip(self, other):
+            if self_elem != other_elem:
+                return False
+
+        return True
 
     def __len__(self):
         size = native_bt.value_array_get_size(self._ptr)
@@ -571,7 +463,7 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
 
         status = native_bt.value_array_set_element_by_index(
             self._ptr, index, ptr)
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     def append(self, value):
         value = create_value(value)
@@ -582,7 +474,7 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
             ptr = value._ptr
 
         status = native_bt.value_array_append_element(self._ptr, ptr)
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     def __iadd__(self, iterable):
         # Python will raise a TypeError if there's anything wrong with
@@ -633,31 +525,25 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value):
             for key, elem in value.items():
                 self[key] = elem
 
-    def __eq__(self, other):
-        return _Value.__eq__(self, other)
-
     def __ne__(self, other):
         return _Value.__ne__(self, other)
 
-    def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                # early mismatch
-                return False
+    def __eq__(self, other):
+        if not isinstance(other, collections.abc.Mapping):
+            return False
 
-            for self_key in self:
-                if self_key not in other:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-                self_value = self[self_key]
-                other_value = other[self_key]
+        for self_key in self:
+            if self_key not in other:
+                return False
 
-                if self_value != other_value:
-                    return False
+            if self[self_key] != other[self_key]:
+                return False
 
-            return True
-        except:
-            return
+        return True
 
     def __len__(self):
         size = native_bt.value_map_get_size(self._ptr)
@@ -694,7 +580,7 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value):
             ptr = value._ptr
 
         status = native_bt.value_map_insert_entry(self._ptr, key, ptr)
-        self._handle_status(status)
+        utils._handle_func_status(status)
 
     def __repr__(self):
         items = ['{}: {}'.format(repr(k), repr(v)) for k, v in self.items()]
This page took 0.037541 seconds and 4 git commands to generate.