From: Jérémie Galarneau Date: Thu, 14 Sep 2017 15:47:36 +0000 (-0400) Subject: Python bt2: value is a write-only property X-Git-Url: http://git.efficios.com/?p=babeltrace.git;a=commitdiff_plain;h=e1c6bebdd944ff21e358c167ef7bb4cbc68b9c28 Python bt2: value is a write-only property This commit turns 'value' into a write-only property. The value property is needed to distinguish between: my_var = bt2.IntegerFieldType(32)(123456) my_var = 123 my_var = bt2.IntegerFieldType(32)(123456) my_var.value = 123 The value 'getter' is not necessary since the various field types implement the interfaces needed to act as native Python types (collections.abc.Sequence, collections.abc.Mapping, numbers.Integral, etc.) Signed-off-by: Jérémie Galarneau --- diff --git a/bindings/python/bt2/bt2/fields.py b/bindings/python/bt2/bt2/fields.py index e40a86c6..afa02d5d 100644 --- a/bindings/python/bt2/bt2/fields.py +++ b/bindings/python/bt2/bt2/fields.py @@ -30,6 +30,13 @@ import abc import bt2 +def _get_leaf_field(obj): + if type(obj) is not _VariantField: + return obj + + return _get_leaf_field(obj.selected_field) + + def _create_from_ptr(ptr): # recreate the field type wrapper of this field's type (the identity # could be different, but the underlying address should be the @@ -54,6 +61,17 @@ class _Field(object._Object, metaclass=abc.ABCMeta): memo[id(self)] = cpy return cpy + def __eq__(self, other): + # special case: two unset fields with the same field type are equal + if isinstance(other, _Field): + if not self.is_set or not other.is_set: + if not self.is_set and not other.is_set and self.field_type == other.field_type: + return True + return False + + other = _get_leaf_field(other) + return self._spec_eq(other) + @property def field_type(self): return self._field_type @@ -87,93 +105,93 @@ class _NumericField(_Field): raise TypeError("'{}' object is not a number object".format(other.__class__.__name__)) def __int__(self): - return int(self.value) + return int(self._value) def __float__(self): - return float(self.value) + return float(self._value) def __str__(self): - return str(self.value) + return str(self._value) def __lt__(self, other): if not isinstance(other, numbers.Number): raise TypeError('unorderable types: {}() < {}()'.format(self.__class__.__name__, other.__class__.__name__)) - return self.value < float(other) + return self._value < float(other) def __le__(self, other): if not isinstance(other, numbers.Number): raise TypeError('unorderable types: {}() <= {}()'.format(self.__class__.__name__, other.__class__.__name__)) - return self.value <= float(other) + return self._value <= float(other) - def __eq__(self, other): + def _spec_eq(self, other): if not isinstance(other, numbers.Number): return False - return self.value == complex(other) + return self._value == complex(other) def __rmod__(self, other): - return self._extract_value(other) % self.value + return self._extract_value(other) % self._value def __mod__(self, other): - return self.value % self._extract_value(other) + return self._value % self._extract_value(other) def __rfloordiv__(self, other): - return self._extract_value(other) // self.value + return self._extract_value(other) // self._value def __floordiv__(self, other): - return self.value // self._extract_value(other) + return self._value // self._extract_value(other) def __round__(self, ndigits=None): if ndigits is None: - return round(self.value) + return round(self._value) else: - return round(self.value, ndigits) + return round(self._value, ndigits) def __ceil__(self): - return math.ceil(self.value) + return math.ceil(self._value) def __floor__(self): - return math.floor(self.value) + return math.floor(self._value) def __trunc__(self): - return int(self.value) + return int(self._value) def __abs__(self): - return abs(self.value) + return abs(self._value) def __add__(self, other): - return self.value + self._extract_value(other) + return self._value + self._extract_value(other) def __radd__(self, other): return self.__add__(other) def __neg__(self): - return -self.value + return -self._value def __pos__(self): - return +self.value + return +self._value def __mul__(self, other): - return self.value * self._extract_value(other) + return self._value * self._extract_value(other) def __rmul__(self, other): return self.__mul__(other) def __truediv__(self, other): - return self.value / self._extract_value(other) + return self._value / self._extract_value(other) def __rtruediv__(self, other): - return self._extract_value(other) / self.value + return self._extract_value(other) / self._value def __pow__(self, exponent): - return self.value ** self._extract_value(exponent) + return self._value ** self._extract_value(exponent) def __rpow__(self, base): - return self._extract_value(base) ** self.value + return self._extract_value(base) ** self._value def __iadd__(self, other): self.value = self + other @@ -206,37 +224,37 @@ class _NumericField(_Field): class _IntegralField(_NumericField, numbers.Integral): def __lshift__(self, other): - return self.value << self._extract_value(other) + return self._value << self._extract_value(other) def __rlshift__(self, other): - return self._extract_value(other) << self.value + return self._extract_value(other) << self._value def __rshift__(self, other): - return self.value >> self._extract_value(other) + return self._value >> self._extract_value(other) def __rrshift__(self, other): - return self._extract_value(other) >> self.value + return self._extract_value(other) >> self._value def __and__(self, other): - return self.value & self._extract_value(other) + return self._value & self._extract_value(other) def __rand__(self, other): - return self._extract_value(other) & self.value + return self._extract_value(other) & self._value def __xor__(self, other): - return self.value ^ self._extract_value(other) + return self._value ^ self._extract_value(other) def __rxor__(self, other): - return self._extract_value(other) ^ self.value + return self._extract_value(other) ^ self._value def __or__(self, other): - return self.value | self._extract_value(other) + return self._value | self._extract_value(other) def __ror__(self, other): - return self._extract_value(other) | self.value + return self._extract_value(other) | self._value def __invert__(self): - return ~self.value + return ~self._value def __ilshift__(self, other): self.value = self << other @@ -280,20 +298,21 @@ class _IntegerField(_IntegralField): return value @property - def value(self): + def _value(self): if self.field_type.is_signed: ret, value = native_bt.ctf_field_signed_integer_get_value(self._ptr) else: ret, value = native_bt.ctf_field_unsigned_integer_get_value(self._ptr) if ret < 0: - # field is not set - return + if not self.is_set: + return + + utils._handle_ret(ret, "cannot get integer field's value") return value - @value.setter - def value(self, value): + def _set_value(self, value): value = self._value_to_int(value) if self.field_type.is_signed: @@ -303,6 +322,7 @@ class _IntegerField(_IntegralField): utils._handle_ret(ret, "cannot set integer field object's value") + value = property(fset=_set_value) class _FloatingPointNumberField(_RealField): _NAME = 'Floating point number' @@ -314,21 +334,23 @@ class _FloatingPointNumberField(_RealField): return float(value) @property - def value(self): + def _value(self): ret, value = native_bt.ctf_field_floating_point_get_value(self._ptr) if ret < 0: - # field is not set - return + if not self.is_set: + return + + utils._handle_ret(ret, "cannot get floating point number field's value") return value - @value.setter - def value(self, value): + def _set_value(self, value): value = self._value_to_float(value) ret = native_bt.ctf_field_floating_point_set_value(self._ptr, value) utils._handle_ret(ret, "cannot set floating point number field object's value") + value = property(fset=_set_value) class _EnumerationField(_IntegerField): _NAME = 'Enumeration' @@ -339,14 +361,15 @@ class _EnumerationField(_IntegerField): assert(int_field_ptr) return _create_from_ptr(int_field_ptr) - @property - def value(self): - return self.integer_field.value - - @value.setter - def value(self, value): + def _set_value(self, value): self.integer_field.value = value + value = property(fset=_set_value) + + @property + def _value(self): + return self.integer_field._value + @property def mappings(self): iter_ptr = native_bt.ctf_field_enumeration_get_mappings(self._ptr) @@ -361,7 +384,7 @@ class _StringField(_Field, collections.abc.Sequence): def _value_to_str(self, value): if isinstance(value, self.__class__): - value = value.value + value = value._value if not isinstance(value, str): raise TypeError("expecting a 'str' object") @@ -369,46 +392,42 @@ class _StringField(_Field, collections.abc.Sequence): return value @property - def value(self): + def _value(self): value = native_bt.ctf_field_string_get_value(self._ptr) - - if value is None: - # field is not set - return - return value - @value.setter - def value(self, value): + def _set_value(self, value): value = self._value_to_str(value) ret = native_bt.ctf_field_string_set_value(self._ptr, value) utils._handle_ret(ret, "cannot set string field object's value") - def __eq__(self, other): + value = property(fset=_set_value) + + def _spec_eq(self, other): try: other = self._value_to_str(other) except: return False - return self.value == other + return self._value == other def __le__(self, other): - return self.value <= self._value_to_str(other) + return self._value <= self._value_to_str(other) def __lt__(self, other): - return self.value < self._value_to_str(other) + return self._value < self._value_to_str(other) def __bool__(self): - return bool(self.value) + return bool(self._value) def __str__(self): - return self.value + return self._value def __getitem__(self, index): - return self.value[index] + return self._value[index] def __len__(self): - return len(self.value) + return len(self._value) def __iadd__(self, value): value = self._value_to_str(value) @@ -474,36 +493,39 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping): # same name iterator return iter(self.field_type) - def __eq__(self, other): - if not isinstance(other, collections.abc.Mapping): - return False - - if len(self) != len(other): - return False - - for self_key, self_value in self.items(): - if self_key not in other: + def _spec_eq(self, other): + try: + if len(self) != len(other): return False - other_value = other[self_key] + for self_key, self_value in self.items(): + if self_key not in other: + return False - if self_value != other_value: - return False + other_value = other[self_key] - return True + if self_value != other_value: + return False + + return True + except: + return False @property - def value(self): - return {key: field.value for key, field in self.items()} + def _value(self): + return {key: value._value for key, value in self.items()} - @value.setter - def value(self, values): - if not hasattr(type(values), '__getitem__'): - raise TypeError('expecting a Mapping collection') + def _set_value(self, values): + original_values = self._value - for key, value in values.items(): - self[key].value = value + try: + for key, value in values.items(): + self[key].value = value + except: + self.value = original_values + raise + value = property(fset=_set_value) class _VariantField(_Field): _NAME = 'Variant' @@ -534,18 +556,21 @@ class _VariantField(_Field): return _create_from_ptr(field_ptr) - def __eq__(self, other): - if type(other) is not type(self): - return False - - if self.addr == other.addr: - return True - - return self.selected_field == other.selected_field + def _spec_eq(self, other): + return _get_leaf_field(self) == other def __bool__(self): return bool(self.selected_field) + @property + def _value(self): + if self.selected_field is not None: + return self.selected_field._value + + def _set_value(self, value): + self.selected_field.value = value + + value = property(fset=_set_value) class _ArraySequenceField(_ContainerField, collections.abc.MutableSequence): def __getitem__(self, index): @@ -579,33 +604,22 @@ class _ArraySequenceField(_ContainerField, collections.abc.MutableSequence): def insert(self, index, value): raise NotImplementedError - def __eq__(self, other): - if not isinstance(other, collections.abc.Sequence): - return False - - if len(self) != len(other): - return False - - for self_field, other_field in zip(self, other): - if self_field != other_field: + def _spec_eq(self, other): + try: + if len(self) != len(other): return False - return True - - @property - def value(self): - return [field.value for field in self] - - @value.setter - def value(self, values): - if not hasattr(type(values), '__iter__'): - raise TypeError('expecting an iterable container (Sequence)') + for self_field, other_field in zip(self, other): + if self_field != other_field: + return False - if len(self) != len(values): - raise ValueError('expected length of value and field to match') + return True + except: + return False - for index, value in enumerate(values): - self[index].value = value + @property + def _value(self): + return [field._value for field in self] class _ArrayField(_ArraySequenceField): @@ -617,6 +631,24 @@ class _ArrayField(_ArraySequenceField): def _get_field_ptr_at_index(self, index): return native_bt.ctf_field_array_get_field(self._ptr, index) + def _set_value(self, values): + if len(self) != len(values): + raise ValueError( + 'expected length of value and array field to match') + + original_values = self._value + try: + for index, value in enumerate(values): + if value is not None: + self[index].value = value + else: + self[index].reset() + except: + self.value = original_values + raise + + value = property(fset=_set_value) + class _SequenceField(_ArraySequenceField): _NAME = 'Sequence' @@ -639,6 +671,33 @@ class _SequenceField(_ArraySequenceField): def _get_field_ptr_at_index(self, index): return native_bt.ctf_field_sequence_get_field(self._ptr, index) + def _set_value(self, values): + original_length_field = self.length_field + if original_length_field is not None: + original_values = self._value + + if len(values) != self.length_field: + if self.length_field is not None: + length_ft = self.length_field.field_type + else: + length_ft = bt2.IntegerFieldType(size=64, is_signed=False) + self.length_field = length_ft(len(values)) + + try: + for index, value in enumerate(values): + if value is not None: + self[index].value = value + else: + self[index].reset() + except: + if original_length_field is not None: + self.length_field = original_length_field + self.value = original_values + else: + self.reset() + raise + + value = property(fset=_set_value) _TYPE_ID_TO_OBJ = { native_bt.CTF_FIELD_TYPE_ID_INTEGER: _IntegerField, diff --git a/tests/bindings/python/bt2/test_fields.py b/tests/bindings/python/bt2/test_fields.py index 5a776ea7..3b0b424e 100644 --- a/tests/bindings/python/bt2/test_fields.py +++ b/tests/bindings/python/bt2/test_fields.py @@ -34,7 +34,7 @@ class _TestNumericField(_TestCopySimple): comp_value = rhs if isinstance(rhs, (bt2.fields._IntegerField, bt2.fields._FloatingPointNumberField)): - comp_value = rhs.value + comp_value = copy.copy(rhs) try: r = op(self._def, rhs) @@ -102,9 +102,9 @@ class _TestNumericField(_TestCopySimple): self.assertEqual(self._def.addr, addr_before) def _test_unaryop_value_same(self, op): - value_before = self._def.value + value_before = copy.copy(self._def_value) self._unaryop(op) - self.assertEqual(self._def.value, value_before) + self.assertEqual(self._def, value_before) def _test_binop_type(self, op, rhs): r, rv = self._binop(op, rhs) @@ -132,9 +132,9 @@ class _TestNumericField(_TestCopySimple): self.assertEqual(self._def.addr, addr_before) def _test_binop_lhs_value_same(self, op, rhs): - value_before = self._def.value + value_before = copy.copy(self._def) r, rv = self._binop(op, rhs) - self.assertEqual(self._def.value, value_before) + self.assertEqual(self._def, value_before) def _test_binop_invalid_unknown(self, op): if op in _COMP_BINOPS: @@ -717,25 +717,21 @@ class _TestIntegerFieldCommon(_TestNumericField): raw = True self._def.value = raw self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_false(self): raw = False self._def.value = raw self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_pos_int(self): raw = 477 self._def.value = raw self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_neg_int(self): raw = -13 self._def.value = raw self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_int_field(self): raw = 999 @@ -743,13 +739,11 @@ class _TestIntegerFieldCommon(_TestNumericField): field.value = raw self._def.value = field self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_float(self): raw = 123.456 self._def.value = raw self.assertEqual(self._def, int(raw)) - self.assertEqual(self._def.value, int(raw)) def test_assign_invalid_type(self): with self.assertRaises(TypeError): @@ -761,7 +755,6 @@ class _TestIntegerFieldCommon(_TestNumericField): raw = 1777 field.value = 1777 self.assertEqual(field, raw) - self.assertEqual(field.value, raw) def test_assign_uint_invalid_neg(self): ft = bt2.IntegerFieldType(size=32, is_signed=False) @@ -847,24 +840,20 @@ class FloatingPointNumberFieldTestCase(_TestNumericField, unittest.TestCase): def test_assign_true(self): self._def.value = True self.assertTrue(self._def) - self.assertTrue(self._def.value) def test_assign_false(self): self._def.value = False self.assertFalse(self._def) - self.assertFalse(self._def.value) def test_assign_pos_int(self): raw = 477 self._def.value = raw self.assertEqual(self._def, float(raw)) - self.assertEqual(self._def.value, float(raw)) def test_assign_neg_int(self): raw = -13 self._def.value = raw self.assertEqual(self._def, float(raw)) - self.assertEqual(self._def.value, float(raw)) def test_assign_int_field(self): ft = bt2.IntegerFieldType(32) @@ -873,13 +862,11 @@ class FloatingPointNumberFieldTestCase(_TestNumericField, unittest.TestCase): field.value = raw self._def.value = field self.assertEqual(self._def, float(raw)) - self.assertEqual(self._def.value, float(raw)) def test_assign_float(self): raw = -19.23 self._def.value = raw self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_float_field(self): ft = bt2.FloatingPointNumberFieldType(32) @@ -888,7 +875,6 @@ class FloatingPointNumberFieldTestCase(_TestNumericField, unittest.TestCase): field.value = raw self._def.value = field self.assertEqual(self._def, raw) - self.assertEqual(self._def.value, raw) def test_assign_invalid_type(self): with self.assertRaises(TypeError):