bt2: value.py: refactor value comparison
authorPhilippe Proulx <eeppeliteloop@gmail.com>
Fri, 28 Jun 2019 03:51:24 +0000 (23:51 -0400)
committerPhilippe Proulx <eeppeliteloop@gmail.com>
Wed, 3 Jul 2019 01:31:27 +0000 (21:31 -0400)
Changed:

* Remove _spec_eq() methods: each class implements its own __eq__()
  method directly.

* Do not use native_bt.value_compare(): we never reached that, because
  container value classes (`ArrayValue` and `MapValue`) implement their
  own rich, recursive comparison.

* In _NumericValue._extract_value(), do not check the `_NumericValue`,
  `False`, and `True` special case: check for `BoolValue` and `bool` to
  return a boolean object in those cases.

* In NumericValue.__lt__() and NumericValue.__eq__(), do not check that
  the parameter is a number object: self._extract_value() does this
  already.

* In BoolValue._value_to_bool(), return the boolean value directly,
  not using int(): it's already an integral number.

* In ArrayValue.__eq__(), be more strict: expect that the parameter is
  a sequence object, not just an iterable object. Before this, it was
  possible to compare an array value to an ordered dict with keys equal
  to the array value content, and this seems wrong as:

      collections.OrderedDict((('A', 23), ('B', 42))) == ['A', 'B']

  is false. An ordered dict is not a sequence.

* In MapValue.__eq__(), be more strict: expect that the parameter is a
  mapping object, not just an iterable and indexable object. The reason
  is similar to the ArrayValue.__eq__() case above. This should be
  enough to compare to another map value or to a dict (or ordered dict).

Signed-off-by: Philippe Proulx <eeppeliteloop@gmail.com>
Change-Id: I9941d2d82942e2efa8d5380c8ff5a4a2d2cb3a84
Reviewed-on: https://review.lttng.org/c/babeltrace/+/1563
Tested-by: jenkins <jenkins@lttng.org>
src/bindings/python/bt2/bt2/value.py

index b6fb6769fbe938fee8a2adcc29cff5227aeac08f..7f003c2bba69bf8f7d6678da3b04114168e1b211 100644 (file)
@@ -89,31 +89,9 @@ class _Value(object._SharedObject, metaclass=abc.ABCMeta):
     _get_ref = staticmethod(native_bt.value_get_ref)
     _put_ref = staticmethod(native_bt.value_put_ref)
 
-    def __eq__(self, other):
-        if other is None:
-            # self is never the null value object
-            return False
-
-        # try type-specific comparison first
-        spec_eq = self._spec_eq(other)
-
-        if spec_eq is not None:
-            return spec_eq
-
-        if not isinstance(other, _Value):
-            # not comparing apples to apples
-            return False
-
-        # fall back to native comparison function
-        return native_bt.value_compare(self._ptr, other._ptr)
-
     def __ne__(self, other):
         return not (self == other)
 
-    @abc.abstractmethod
-    def _spec_eq(self, other):
-        pass
-
     def _handle_status(self, status):
         _handle_status(status, self._NAME)
 
@@ -127,11 +105,8 @@ class _Value(object._SharedObject, metaclass=abc.ABCMeta):
 class _NumericValue(_Value):
     @staticmethod
     def _extract_value(other):
-        if isinstance(other, _NumericValue):
-            return other._value
-
-        if other is True or other is False:
-            return other
+        if isinstance(other, BoolValue) or isinstance(other, bool):
+            return bool(other)
 
         if isinstance(other, numbers.Integral):
             return int(other)
@@ -154,21 +129,14 @@ class _NumericValue(_Value):
         return repr(self._value)
 
     def __lt__(self, other):
-        if not isinstance(other, numbers.Number):
-            raise TypeError('unorderable types: {}() < {}()'.format(self.__class__.__name__,
-                                                                    other.__class__.__name__))
-
         return self._value < self._extract_value(other)
 
-    def _spec_eq(self, other):
-        pass
-
     def __eq__(self, other):
-        if not isinstance(other, numbers.Number):
+        try:
+            return self._value == self._extract_value(other)
+        except:
             return False
 
-        return self._value == self._extract_value(other)
-
     def __rmod__(self, other):
         return self._extract_value(other) % self._value
 
@@ -329,9 +297,11 @@ class BoolValue(_Value):
         self._check_create_status(ptr)
         super().__init__(ptr)
 
-    def _spec_eq(self, other):
-        if isinstance(other, numbers.Number):
-            return self._value == bool(other)
+    def __eq__(self, other):
+        try:
+            return self._value == self._value_to_bool(other)
+        except:
+            return False
 
     def __bool__(self):
         return self._value
@@ -346,7 +316,7 @@ class BoolValue(_Value):
         if not isinstance(value, bool):
             raise TypeError("'{}' object is not a 'bool' or 'BoolValue' object".format(value.__class__))
 
-        return int(value)
+        return value
 
     @property
     def _value(self):
@@ -462,11 +432,11 @@ class StringValue(collections.abc.Sequence, _Value):
 
     value = property(fset=_set_value)
 
-    def _spec_eq(self, other):
+    def __eq__(self, other):
         try:
             return self._value == self._value_to_str(other)
         except:
-            return
+            return False
 
     def __lt__(self, other):
         return self._value < self._value_to_str(other)
@@ -515,19 +485,19 @@ class ArrayValue(_Container, collections.abc.MutableSequence, _Value):
             for elem in value:
                 self.append(elem)
 
-    def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                # early mismatch
-                return False
+    def __eq__(self, other):
+        if not isinstance(other, collections.abc.Sequence):
+            return False
 
-            for self_elem, other_elem in zip(self, other):
-                if self_elem != other_elem:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-            return True
-        except:
-            return
+        for self_elem, other_elem in zip(self, other):
+            if self_elem != other_elem:
+                return False
+
+        return True
 
     def __len__(self):
         size = native_bt.value_array_get_size(self._ptr)
@@ -623,31 +593,25 @@ class MapValue(_Container, collections.abc.MutableMapping, _Value):
             for key, elem in value.items():
                 self[key] = elem
 
-    def __eq__(self, other):
-        return _Value.__eq__(self, other)
-
     def __ne__(self, other):
         return _Value.__ne__(self, other)
 
-    def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                # early mismatch
-                return False
+    def __eq__(self, other):
+        if not isinstance(other, collections.abc.Mapping):
+            return False
 
-            for self_key in self:
-                if self_key not in other:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-                self_value = self[self_key]
-                other_value = other[self_key]
+        for self_key in self:
+            if self_key not in other:
+                return False
 
-                if self_value != other_value:
-                    return False
+            if self[self_key] != other[self_key]:
+                return False
 
-            return True
-        except:
-            return
+        return True
 
     def __len__(self):
         size = native_bt.value_map_get_size(self._ptr)
This page took 0.028452 seconds and 4 git commands to generate.