bt2: field.py: refactor field comparison
authorPhilippe Proulx <eeppeliteloop@gmail.com>
Fri, 28 Jun 2019 03:54:08 +0000 (23:54 -0400)
committerPhilippe Proulx <eeppeliteloop@gmail.com>
Wed, 3 Jul 2019 01:31:27 +0000 (21:31 -0400)
Changed:

* In _NumericField._extract_value(), do not check the `False`, and
  `True` special cases: check that the parameter is an instance of the
  `bool` type to return a boolean object.

* In _ArrayField.__eq__(), be more strict: expect that the parameter is
  a sequence object, not just an iterable object. Before this, it was
  possible to compare an array field to an ordered dict with keys equal
  to the array field content, and this seems wrong as:

      collections.OrderedDict((('A', 23), ('B', 42))) == ['A', 'B']

  is false. An ordered dict is not a sequence.

* In _StructureField.__eq__(), be more strict: expect that the parameter
  is a mapping object, not just an iterable and indexable object. The
  reason is similar to the _ArrayField.__eq__() case above. This should
  be enough to compare to another structure field or to a dict (or
  ordered dict).

Signed-off-by: Philippe Proulx <eeppeliteloop@gmail.com>
Change-Id: I17f33c24e9dea526e59b5058235d57facb51cfbf
Reviewed-on: https://review.lttng.org/c/babeltrace/+/1564
Tested-by: jenkins <jenkins@lttng.org>
Reviewed-by: Francis Deslauriers <francis.deslauriers@efficios.com>
src/bindings/python/bt2/bt2/field.py

index 3a735ec3ca4293c50088f7ba2c9132f48271cfbe..ee28315790cf35851170b2f5686458611a4e5474 100644 (file)
@@ -71,7 +71,7 @@ class _Field(object._UniqueObject):
 class _NumericField(_Field):
     @staticmethod
     def _extract_value(other):
-        if other is True or other is False:
+        if isinstance(other, bool):
             return other
 
         if isinstance(other, numbers.Integral):
@@ -102,10 +102,10 @@ class _NumericField(_Field):
         return self._value < self._extract_value(other)
 
     def _spec_eq(self, other):
-        if not isinstance(other, numbers.Number):
-            return NotImplemented
-
-        return self._value == self._extract_value(other)
+        try:
+            return self._value == self._extract_value(other)
+        except:
+            return False
 
     def __rmod__(self, other):
         return self._extract_value(other) % self._value
@@ -369,12 +369,10 @@ class _StringField(_Field):
 
     def _spec_eq(self, other):
         try:
-            other = self._value_to_str(other)
-        except Exception:
+            return self._value == self._value_to_str(other)
+        except:
             return False
 
-        return self._value == other
-
     def __lt__(self, other):
         return self._value < self._value_to_str(other)
 
@@ -432,22 +430,21 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping):
         return iter(self.field_class)
 
     def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                return False
+        if not isinstance(other, collections.abc.Mapping):
+            return False
 
-            for self_key, self_value in self.items():
-                if self_key not in other:
-                    return False
+        if len(self) != len(other):
+            # early mismatch
+            return False
 
-                other_value = other[self_key]
+        for self_key in self:
+            if self_key not in other:
+                return False
 
-                if self_value != other_value:
-                    return False
+            if self[self_key] != other[self_key]:
+                return False
 
-            return True
-        except Exception:
-            return False
+        return True
 
     def _set_value(self, values):
         try:
@@ -507,8 +504,7 @@ class _VariantField(_ContainerField, _Field):
                                       self._owner_put_ref)
 
     def _spec_eq(self, other):
-        new_self = _get_leaf_field(self)
-        return new_self == other
+        return _get_leaf_field(self) == other
 
     def __bool__(self):
         raise NotImplementedError
@@ -565,18 +561,19 @@ class _ArrayField(_ContainerField, _Field, collections.abc.MutableSequence):
         raise NotImplementedError
 
     def _spec_eq(self, other):
-        try:
-            if len(self) != len(other):
-                return False
-
-            for self_field, other_field in zip(self, other):
-                if self_field != other_field:
-                    return False
+        if not isinstance(other, collections.abc.Sequence):
+            return False
 
-            return True
-        except Exception:
+        if len(self) != len(other):
+            # early mismatch
             return False
 
+        for self_elem, other_elem in zip(self, other):
+            if self_elem != other_elem:
+                return False
+
+        return True
+
     def _repr(self):
         return '[{}]'.format(', '.join([repr(v) for v in self]))
 
This page took 0.026851 seconds and 4 git commands to generate.