X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=bindings%2Fpython%2Fbt2%2Fbt2%2Ffields.py;h=8d315eb5355af9acf7113f4afe091a7f6c91b1ea;hb=872eeccbe6bb5201177a77f5bedc216750abd293;hp=13c6a68335393c200f6f473b4189ef326f486b35;hpb=7c54e2e7f763bb3c3dca8b37eb3d064550d19789;p=babeltrace.git diff --git a/bindings/python/bt2/bt2/fields.py b/bindings/python/bt2/bt2/fields.py index 13c6a683..8d315eb5 100644 --- a/bindings/python/bt2/bt2/fields.py +++ b/bindings/python/bt2/bt2/fields.py @@ -30,14 +30,21 @@ import abc import bt2 +def _get_leaf_field(obj): + if type(obj) is not _VariantField: + return obj + + return _get_leaf_field(obj.selected_field) + + def _create_from_ptr(ptr): # recreate the field type wrapper of this field's type (the identity # could be different, but the underlying address should be the # same) - field_type_ptr = native_bt.ctf_field_get_type(ptr) + field_type_ptr = native_bt.field_get_type(ptr) utils._handle_ptr(field_type_ptr, "cannot get field object's type") field_type = bt2.field_types._create_from_ptr(field_type_ptr) - typeid = native_bt.ctf_field_type_get_type_id(field_type._ptr) + typeid = native_bt.field_type_get_type_id(field_type._ptr) field = _TYPE_ID_TO_OBJ[typeid]._create_from_ptr(ptr) field._field_type = field_type return field @@ -45,7 +52,7 @@ def _create_from_ptr(ptr): class _Field(object._Object, metaclass=abc.ABCMeta): def __copy__(self): - ptr = native_bt.ctf_field_copy(self._ptr) + ptr = native_bt.field_copy(self._ptr) utils._handle_ptr(ptr, 'cannot copy {} field object'.format(self._NAME.lower())) return _create_from_ptr(ptr) @@ -54,10 +61,30 @@ class _Field(object._Object, metaclass=abc.ABCMeta): memo[id(self)] = cpy return cpy + def __eq__(self, other): + # special case: two unset fields with the same field type are equal + if isinstance(other, _Field): + if not self.is_set or not other.is_set: + if not self.is_set and not other.is_set and self.field_type == other.field_type: + return True + return False + + other = _get_leaf_field(other) + return self._spec_eq(other) + @property def field_type(self): return self._field_type + @property + def is_set(self): + is_set = native_bt.field_is_set(self._ptr) + return is_set > 0 + + def reset(self): + ret = native_bt.field_reset(self._ptr) + utils._handle_ret(ret, "cannot reset field object's value") + @functools.total_ordering class _NumericField(_Field): @@ -78,93 +105,93 @@ class _NumericField(_Field): raise TypeError("'{}' object is not a number object".format(other.__class__.__name__)) def __int__(self): - return int(self.value) + return int(self._value) def __float__(self): - return float(self.value) + return float(self._value) def __str__(self): - return str(self.value) + return str(self._value) def __lt__(self, other): if not isinstance(other, numbers.Number): raise TypeError('unorderable types: {}() < {}()'.format(self.__class__.__name__, other.__class__.__name__)) - return self.value < float(other) + return self._value < float(other) def __le__(self, other): if not isinstance(other, numbers.Number): raise TypeError('unorderable types: {}() <= {}()'.format(self.__class__.__name__, other.__class__.__name__)) - return self.value <= float(other) + return self._value <= float(other) - def __eq__(self, other): + def _spec_eq(self, other): if not isinstance(other, numbers.Number): return False - return self.value == complex(other) + return self._value == complex(other) def __rmod__(self, other): - return self._extract_value(other) % self.value + return self._extract_value(other) % self._value def __mod__(self, other): - return self.value % self._extract_value(other) + return self._value % self._extract_value(other) def __rfloordiv__(self, other): - return self._extract_value(other) // self.value + return self._extract_value(other) // self._value def __floordiv__(self, other): - return self.value // self._extract_value(other) + return self._value // self._extract_value(other) def __round__(self, ndigits=None): if ndigits is None: - return round(self.value) + return round(self._value) else: - return round(self.value, ndigits) + return round(self._value, ndigits) def __ceil__(self): - return math.ceil(self.value) + return math.ceil(self._value) def __floor__(self): - return math.floor(self.value) + return math.floor(self._value) def __trunc__(self): - return int(self.value) + return int(self._value) def __abs__(self): - return abs(self.value) + return abs(self._value) def __add__(self, other): - return self.value + self._extract_value(other) + return self._value + self._extract_value(other) def __radd__(self, other): return self.__add__(other) def __neg__(self): - return -self.value + return -self._value def __pos__(self): - return +self.value + return +self._value def __mul__(self, other): - return self.value * self._extract_value(other) + return self._value * self._extract_value(other) def __rmul__(self, other): return self.__mul__(other) def __truediv__(self, other): - return self.value / self._extract_value(other) + return self._value / self._extract_value(other) def __rtruediv__(self, other): - return self._extract_value(other) / self.value + return self._extract_value(other) / self._value def __pow__(self, exponent): - return self.value ** self._extract_value(exponent) + return self._value ** self._extract_value(exponent) def __rpow__(self, base): - return self._extract_value(base) ** self.value + return self._extract_value(base) ** self._value def __iadd__(self, other): self.value = self + other @@ -197,37 +224,37 @@ class _NumericField(_Field): class _IntegralField(_NumericField, numbers.Integral): def __lshift__(self, other): - return self.value << self._extract_value(other) + return self._value << self._extract_value(other) def __rlshift__(self, other): - return self._extract_value(other) << self.value + return self._extract_value(other) << self._value def __rshift__(self, other): - return self.value >> self._extract_value(other) + return self._value >> self._extract_value(other) def __rrshift__(self, other): - return self._extract_value(other) >> self.value + return self._extract_value(other) >> self._value def __and__(self, other): - return self.value & self._extract_value(other) + return self._value & self._extract_value(other) def __rand__(self, other): - return self._extract_value(other) & self.value + return self._extract_value(other) & self._value def __xor__(self, other): - return self.value ^ self._extract_value(other) + return self._value ^ self._extract_value(other) def __rxor__(self, other): - return self._extract_value(other) ^ self.value + return self._extract_value(other) ^ self._value def __or__(self, other): - return self.value | self._extract_value(other) + return self._value | self._extract_value(other) def __ror__(self, other): - return self._extract_value(other) | self.value + return self._extract_value(other) | self._value def __invert__(self): - return ~self.value + return ~self._value def __ilshift__(self, other): self.value = self << other @@ -271,29 +298,32 @@ class _IntegerField(_IntegralField): return value @property - def value(self): + def _value(self): if self.field_type.is_signed: - ret, value = native_bt.ctf_field_signed_integer_get_value(self._ptr) + ret, value = native_bt.field_signed_integer_get_value(self._ptr) else: - ret, value = native_bt.ctf_field_unsigned_integer_get_value(self._ptr) + ret, value = native_bt.field_unsigned_integer_get_value(self._ptr) if ret < 0: - # field is not set - return + if not self.is_set: + return + + utils._handle_ret(ret, "cannot get integer field's value") return value - @value.setter - def value(self, value): + def _set_value(self, value): value = self._value_to_int(value) if self.field_type.is_signed: - ret = native_bt.ctf_field_signed_integer_set_value(self._ptr, value) + ret = native_bt.field_signed_integer_set_value(self._ptr, value) else: - ret = native_bt.ctf_field_unsigned_integer_set_value(self._ptr, value) + ret = native_bt.field_unsigned_integer_set_value(self._ptr, value) utils._handle_ret(ret, "cannot set integer field object's value") + value = property(fset=_set_value) + class _FloatingPointNumberField(_RealField): _NAME = 'Floating point number' @@ -305,42 +335,50 @@ class _FloatingPointNumberField(_RealField): return float(value) @property - def value(self): - ret, value = native_bt.ctf_field_floating_point_get_value(self._ptr) + def _value(self): + ret, value = native_bt.field_floating_point_get_value(self._ptr) if ret < 0: - # field is not set - return + if not self.is_set: + return + + utils._handle_ret(ret, "cannot get floating point number field's value") return value - @value.setter - def value(self, value): + def _set_value(self, value): value = self._value_to_float(value) - ret = native_bt.ctf_field_floating_point_set_value(self._ptr, value) + ret = native_bt.field_floating_point_set_value(self._ptr, value) utils._handle_ret(ret, "cannot set floating point number field object's value") + value = property(fset=_set_value) + class _EnumerationField(_IntegerField): _NAME = 'Enumeration' @property def integer_field(self): - int_field_ptr = native_bt.ctf_field_enumeration_get_container(self._ptr) + int_field_ptr = native_bt.field_enumeration_get_container(self._ptr) assert(int_field_ptr) return _create_from_ptr(int_field_ptr) - @property - def value(self): - return self.integer_field.value - - @value.setter - def value(self, value): + def _set_value(self, value): self.integer_field.value = value + def __repr__(self): + labels = [repr(v.name) for v in self.mappings] + return '{} ({})'.format(self._value, ', '.join(labels)) + + value = property(fset=_set_value) + + @property + def _value(self): + return self.integer_field._value + @property def mappings(self): - iter_ptr = native_bt.ctf_field_enumeration_get_mappings(self._ptr) + iter_ptr = native_bt.field_enumeration_get_mappings(self._ptr) assert(iter_ptr) return bt2.field_types._EnumerationFieldTypeMappingIterator(iter_ptr, self.field_type.is_signed) @@ -352,7 +390,7 @@ class _StringField(_Field, collections.abc.Sequence): def _value_to_str(self, value): if isinstance(value, self.__class__): - value = value.value + value = value._value if not isinstance(value, str): raise TypeError("expecting a 'str' object") @@ -360,50 +398,46 @@ class _StringField(_Field, collections.abc.Sequence): return value @property - def value(self): - value = native_bt.ctf_field_string_get_value(self._ptr) - - if value is None: - # field is not set - return - + def _value(self): + value = native_bt.field_string_get_value(self._ptr) return value - @value.setter - def value(self, value): + def _set_value(self, value): value = self._value_to_str(value) - ret = native_bt.ctf_field_string_set_value(self._ptr, value) + ret = native_bt.field_string_set_value(self._ptr, value) utils._handle_ret(ret, "cannot set string field object's value") - def __eq__(self, other): + value = property(fset=_set_value) + + def _spec_eq(self, other): try: other = self._value_to_str(other) except: return False - return self.value == other + return self._value == other def __le__(self, other): - return self.value <= self._value_to_str(other) + return self._value <= self._value_to_str(other) def __lt__(self, other): - return self.value < self._value_to_str(other) + return self._value < self._value_to_str(other) def __bool__(self): - return bool(self.value) + return bool(self._value) def __str__(self): - return self.value + return self._value def __getitem__(self, index): - return self.value[index] + return self._value[index] def __len__(self): - return len(self.value) + return len(self._value) def __iadd__(self, value): value = self._value_to_str(value) - ret = native_bt.ctf_field_string_append(self._ptr, value) + ret = native_bt.field_string_append(self._ptr, value) utils._handle_ret(ret, "cannot append to string field object's value") return self @@ -429,7 +463,7 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping): def __getitem__(self, key): utils._check_str(key) - ptr = native_bt.ctf_field_structure_get_field_by_name(self._ptr, key) + ptr = native_bt.field_structure_get_field_by_name(self._ptr, key) if ptr is None: raise KeyError(key) @@ -437,16 +471,9 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping): return _create_from_ptr(ptr) def __setitem__(self, key, value): - # we can only set numbers and strings - if not isinstance(value, (numbers.Number, str)): - raise TypeError('expecting number object or string') - - # raises if index is somehow invalid + # raises if key is somehow invalid field = self[key] - if not isinstance(field, (_NumericField, _StringField)): - raise TypeError('can only set the value of a number or string field') - # the field's property does the appropriate conversion or raises # the appropriate exception field.value = value @@ -457,7 +484,7 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping): if index >= len(self): raise IndexError - field_ptr = native_bt.ctf_field_structure_get_field_by_index(self._ptr, index) + field_ptr = native_bt.field_structure_get_field_by_index(self._ptr, index) assert(field_ptr) return _create_from_ptr(field_ptr) @@ -465,35 +492,43 @@ class _StructureField(_ContainerField, collections.abc.MutableMapping): # same name iterator return iter(self.field_type) - def __eq__(self, other): - if not isinstance(other, collections.abc.Mapping): - return False - - if len(self) != len(other): - return False - - for self_key, self_value in self.items(): - if self_key not in other: + def _spec_eq(self, other): + try: + if len(self) != len(other): return False - other_value = other[self_key] + for self_key, self_value in self.items(): + if self_key not in other: + return False - if self_value != other_value: - return False + other_value = other[self_key] - return True + if self_value != other_value: + return False + + return True + except: + return False @property - def value(self): - return {key: field.value for key, field in self.items()} + def _value(self): + return {key: value._value for key, value in self.items()} + + def _set_value(self, values): + original_values = self._value + + try: + for key, value in values.items(): + self[key].value = value + except: + self.value = original_values + raise - @value.setter - def value(self, values): - if not hasattr(type(values), '__getitem__'): - raise TypeError('expecting a Mapping collection') + value = property(fset=_set_value) - for key, value in values.items(): - self[key].value = value + def __repr__(self): + items = ['{}: {}'.format(repr(k), repr(v)) for k, v in self.items()] + return '{{{}}}'.format(', '.join(items)) class _VariantField(_Field): @@ -501,7 +536,7 @@ class _VariantField(_Field): @property def tag_field(self): - field_ptr = native_bt.ctf_field_variant_get_tag(self._ptr) + field_ptr = native_bt.field_variant_get_tag(self._ptr) if field_ptr is None: return @@ -514,29 +549,36 @@ class _VariantField(_Field): def field(self, tag_field=None): if tag_field is None: - field_ptr = native_bt.ctf_field_variant_get_current_field(self._ptr) + field_ptr = native_bt.field_variant_get_current_field(self._ptr) if field_ptr is None: return else: utils._check_type(tag_field, _EnumerationField) - field_ptr = native_bt.ctf_field_variant_get_field(self._ptr, tag_field._ptr) + field_ptr = native_bt.field_variant_get_field(self._ptr, tag_field._ptr) utils._handle_ptr(field_ptr, "cannot select variant field object's field") return _create_from_ptr(field_ptr) - def __eq__(self, other): - if type(other) is not type(self): - return False - - if self.addr == other.addr: - return True - - return self.selected_field == other.selected_field + def _spec_eq(self, other): + return _get_leaf_field(self) == other def __bool__(self): return bool(self.selected_field) + def __repr__(self): + return repr(self._value) + + @property + def _value(self): + if self.selected_field is not None: + return self.selected_field._value + + def _set_value(self, value): + self.selected_field.value = value + + value = property(fset=_set_value) + class _ArraySequenceField(_ContainerField, collections.abc.MutableSequence): def __getitem__(self, index): @@ -570,33 +612,22 @@ class _ArraySequenceField(_ContainerField, collections.abc.MutableSequence): def insert(self, index, value): raise NotImplementedError - def __eq__(self, other): - if not isinstance(other, collections.abc.Sequence): - return False - - if len(self) != len(other): - return False - - for self_field, other_field in zip(self, other): - if self_field != other_field: + def _spec_eq(self, other): + try: + if len(self) != len(other): return False - return True + for self_field, other_field in zip(self, other): + if self_field != other_field: + return False - @property - def value(self): - return [field.value for field in self] - - @value.setter - def value(self, values): - if not hasattr(type(values), '__iter__'): - raise TypeError('expecting an iterable container (Sequence)') - - if len(self) != len(values): - raise ValueError('expected length of value and field to match') + return True + except: + return False - for index, value in enumerate(values): - self[index].value = value + @property + def _value(self): + return [field._value for field in self] class _ArrayField(_ArraySequenceField): @@ -606,38 +637,85 @@ class _ArrayField(_ArraySequenceField): return self.field_type.length def _get_field_ptr_at_index(self, index): - return native_bt.ctf_field_array_get_field(self._ptr, index) + return native_bt.field_array_get_field(self._ptr, index) + + def _set_value(self, values): + if len(self) != len(values): + raise ValueError( + 'expected length of value and array field to match') + + original_values = self._value + try: + for index, value in enumerate(values): + if value is not None: + self[index].value = value + else: + self[index].reset() + except: + self.value = original_values + raise + + value = property(fset=_set_value) class _SequenceField(_ArraySequenceField): _NAME = 'Sequence' def _count(self): - return self.length_field.value + return int(self.length_field) @property def length_field(self): - field_ptr = native_bt.ctf_field_sequence_get_length(self._ptr) - utils._handle_ptr("cannot get sequence field object's length field") + field_ptr = native_bt.field_sequence_get_length(self._ptr) + if field_ptr is None: + return return _create_from_ptr(field_ptr) @length_field.setter def length_field(self, length_field): utils._check_type(length_field, _IntegerField) - ret = native_bt.ctf_field_sequence_set_length(self._ptr, length_field._ptr) + ret = native_bt.field_sequence_set_length(self._ptr, length_field._ptr) utils._handle_ret(ret, "cannot set sequence field object's length field") def _get_field_ptr_at_index(self, index): - return native_bt.ctf_field_sequence_get_field(self._ptr, index) + return native_bt.field_sequence_get_field(self._ptr, index) + + def _set_value(self, values): + original_length_field = self.length_field + if original_length_field is not None: + original_values = self._value + + if len(values) != self.length_field: + if self.length_field is not None: + length_ft = self.length_field.field_type + else: + length_ft = bt2.IntegerFieldType(size=64, is_signed=False) + self.length_field = length_ft(len(values)) + + try: + for index, value in enumerate(values): + if value is not None: + self[index].value = value + else: + self[index].reset() + except: + if original_length_field is not None: + self.length_field = original_length_field + self.value = original_values + else: + self.reset() + raise + + value = property(fset=_set_value) _TYPE_ID_TO_OBJ = { - native_bt.CTF_FIELD_TYPE_ID_INTEGER: _IntegerField, - native_bt.CTF_FIELD_TYPE_ID_FLOAT: _FloatingPointNumberField, - native_bt.CTF_FIELD_TYPE_ID_ENUM: _EnumerationField, - native_bt.CTF_FIELD_TYPE_ID_STRING: _StringField, - native_bt.CTF_FIELD_TYPE_ID_STRUCT: _StructureField, - native_bt.CTF_FIELD_TYPE_ID_ARRAY: _ArrayField, - native_bt.CTF_FIELD_TYPE_ID_SEQUENCE: _SequenceField, - native_bt.CTF_FIELD_TYPE_ID_VARIANT: _VariantField, + native_bt.FIELD_TYPE_ID_INTEGER: _IntegerField, + native_bt.FIELD_TYPE_ID_FLOAT: _FloatingPointNumberField, + native_bt.FIELD_TYPE_ID_ENUM: _EnumerationField, + native_bt.FIELD_TYPE_ID_STRING: _StringField, + native_bt.FIELD_TYPE_ID_STRUCT: _StructureField, + native_bt.FIELD_TYPE_ID_ARRAY: _ArrayField, + native_bt.FIELD_TYPE_ID_SEQUENCE: _SequenceField, + native_bt.FIELD_TYPE_ID_VARIANT: _VariantField, }