From f11ed06245a17724c8c9dd9627be19fab069fca4 Mon Sep 17 00:00:00 2001 From: Philippe Proulx Date: Thu, 27 Jun 2019 23:54:08 -0400 Subject: [PATCH] bt2: field.py: refactor field comparison Changed: * In _NumericField._extract_value(), do not check the `False`, and `True` special cases: check that the parameter is an instance of the `bool` type to return a boolean object. * In _ArrayField.__eq__(), be more strict: expect that the parameter is a sequence object, not just an iterable object. Before this, it was possible to compare an array field to an ordered dict with keys equal to the array field content, and this seems wrong as: collections.OrderedDict((('A', 23), ('B', 42))) == ['A', 'B'] is false. An ordered dict is not a sequence. * In _StructureField.__eq__(), be more strict: expect that the parameter is a mapping object, not just an iterable and indexable object. The reason is similar to the _ArrayField.__eq__() case above. This should be enough to compare to another structure field or to a dict (or ordered dict). Signed-off-by: Philippe Proulx Change-Id: I17f33c24e9dea526e59b5058235d57facb51cfbf Reviewed-on: https://review.lttng.org/c/babeltrace/+/1564 Tested-by: jenkins Reviewed-by: Francis Deslauriers --- src/bindings/python/bt2/bt2/field.py | 61 +++++++++++++--------------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/src/bindings/python/bt2/bt2/field.py b/src/bindings/python/bt2/bt2/field.py index 3a735ec3..ee283157 100644 --- a/src/bindings/python/bt2/bt2/field.py +++ b/src/bindings/python/bt2/bt2/field.py @@ -71,7 +71,7 @@ class _Field(object._UniqueObject): class _NumericField(_Field): @staticmethod def _extract_value(other): - if other is True or other is False: + if isinstance(other, bool): return other if isinstance(other, numbers.Integral): @@ -102,10 +102,10 @@ class _NumericField(_Field): return self._value < self._extract_value(other) def _spec_eq(self, other): - if not isinstance(other, numbers.Number): - return NotImplemented - - return self._value == self._extract_value(other) + try: + return self._value == self._extract_value(other) + except: + return False def __rmod__(self, other): return self._extract_value(other) % self._value @@ -369,12 +369,10 @@ class _StringField(_Field): def _spec_eq(self, other): try: - other = self._value_to_str(other) - except Exception: + return self._value == self._value_to_str(other) + except: return False - return self._value == other - def __lt__(self, other): return self._value < self._value_to_str(other) @@ -432,22 +430,21 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping): return iter(self.field_class) def _spec_eq(self, other): - try: - if len(self) != len(other): - return False + if not isinstance(other, collections.abc.Mapping): + return False - for self_key, self_value in self.items(): - if self_key not in other: - return False + if len(self) != len(other): + # early mismatch + return False - other_value = other[self_key] + for self_key in self: + if self_key not in other: + return False - if self_value != other_value: - return False + if self[self_key] != other[self_key]: + return False - return True - except Exception: - return False + return True def _set_value(self, values): try: @@ -507,8 +504,7 @@ class _VariantField(_ContainerField, _Field): self._owner_put_ref) def _spec_eq(self, other): - new_self = _get_leaf_field(self) - return new_self == other + return _get_leaf_field(self) == other def __bool__(self): raise NotImplementedError @@ -565,18 +561,19 @@ class _ArrayField(_ContainerField, _Field, collections.abc.MutableSequence): raise NotImplementedError def _spec_eq(self, other): - try: - if len(self) != len(other): - return False - - for self_field, other_field in zip(self, other): - if self_field != other_field: - return False + if not isinstance(other, collections.abc.Sequence): + return False - return True - except Exception: + if len(self) != len(other): + # early mismatch return False + for self_elem, other_elem in zip(self, other): + if self_elem != other_elem: + return False + + return True + def _repr(self): return '[{}]'.format(', '.join([repr(v) for v in self])) -- 2.34.1