bt2: add user attributes property support
authorPhilippe Proulx <eeppeliteloop@gmail.com>
Wed, 14 Aug 2019 07:18:56 +0000 (03:18 -0400)
committerPhilippe Proulx <eeppeliteloop@gmail.com>
Thu, 15 Aug 2019 15:41:44 +0000 (11:41 -0400)
This patch adds user attributes property support to the appropriate
`bt2` objects to wrap the equivalent library functions.

The following objects now have a read-only `user_attributes` property
which returns a map value object:

* `_ClockClass`
* `_EventClass`
* `_FieldClass`
* `_Stream`
* `_StreamClass`
* `_StructureFieldClassMember`
* `_Trace`
* `_TraceClass`
* `_VariantFieldClassOption`

You can set the user attributes of those objects at creation time, when
calling the creation method, or when appending a structure field class
member or a variant field class option for `_StructureFieldClassMember`
and `_VariantFieldClassOption`.

To make the `_user_attributes` property of `_StructureFieldClassMember`
and `_VariantFieldClassOption` above possible, I made them wrap a real
BT pointer and keep a reference on the owning field class, making all
the properties make a call to the library instead of the object just
containing them.

In `test_field_class.py`, because there's now a common property to test
for each field class, the new `_TestFieldClass` mixin does that,
requiring a _create_default_field_class() method which accepts custom
positional and keyword arguments to create a specific default field
class to test.

Signed-off-by: Philippe Proulx <eeppeliteloop@gmail.com>
Change-Id: I1d58f1b386e4a367f038f1dda9c2e58738794a66
Reviewed-on: https://review.lttng.org/c/babeltrace/+/1923
Tested-by: jenkins <jenkins@lttng.org>
15 files changed:
src/bindings/python/bt2/bt2/clock_class.py
src/bindings/python/bt2/bt2/component.py
src/bindings/python/bt2/bt2/event_class.py
src/bindings/python/bt2/bt2/field_class.py
src/bindings/python/bt2/bt2/stream.py
src/bindings/python/bt2/bt2/stream_class.py
src/bindings/python/bt2/bt2/trace.py
src/bindings/python/bt2/bt2/trace_class.py
tests/bindings/python/bt2/test_clock_class.py
tests/bindings/python/bt2/test_event_class.py
tests/bindings/python/bt2/test_field_class.py
tests/bindings/python/bt2/test_stream.py
tests/bindings/python/bt2/test_stream_class.py
tests/bindings/python/bt2/test_trace.py
tests/bindings/python/bt2/test_trace_class.py

index a8247dfbf4244ea72b53fbed26d480b4549b67d3..953a2185b2d8c3926d1b57873c959237379d4b2c 100644 (file)
@@ -21,6 +21,7 @@
 # THE SOFTWARE.
 
 from bt2 import native_bt, object, utils
+from bt2 import value as bt2_value
 import uuid as uuidp
 
 
@@ -51,6 +52,19 @@ class _ClockClass(object._SharedObject):
     _get_ref = staticmethod(native_bt.clock_class_get_ref)
     _put_ref = staticmethod(native_bt.clock_class_put_ref)
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.clock_class_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.clock_class_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
     @property
     def name(self):
         return native_bt.clock_class_get_name(self._ptr)
index a68da615f1b542f187cc0637c83f4b7f325d94f1..a0d29e272f43d5286cad50f47ff400c68eb33a2f 100644 (file)
@@ -706,7 +706,9 @@ class _UserComponent(metaclass=_UserComponentType):
         )
         self._user_port_connected(port, other_port)
 
-    def _create_trace_class(self, assigns_automatic_stream_class_id=True):
+    def _create_trace_class(
+        self, user_attributes=None, assigns_automatic_stream_class_id=True
+    ):
         ptr = self._bt_as_self_component_ptr(self._bt_ptr)
         tc_ptr = native_bt.trace_class_create(ptr)
 
@@ -716,12 +718,16 @@ class _UserComponent(metaclass=_UserComponentType):
         tc = bt2_trace_class._TraceClass._create_from_ptr(tc_ptr)
         tc._assigns_automatic_stream_class_id = assigns_automatic_stream_class_id
 
+        if user_attributes is not None:
+            tc._user_attributes = user_attributes
+
         return tc
 
     def _create_clock_class(
         self,
         frequency=None,
         name=None,
+        user_attributes=None,
         description=None,
         precision=None,
         offset=None,
@@ -742,6 +748,9 @@ class _UserComponent(metaclass=_UserComponentType):
         if name is not None:
             cc._name = name
 
+        if user_attributes is not None:
+            cc._user_attributes = user_attributes
+
         if description is not None:
             cc._description = description
 
index dfef62137f3315e10b195d73e87a1455256b1c3e..8d9d54dcc8fe7fae316a23a7d6b4343b23cfa6e5 100644 (file)
@@ -57,6 +57,19 @@ class _EventClass(object._SharedObject):
         if sc_ptr is not None:
             return bt2_stream_class._StreamClass._create_from_ptr_and_get_ref(sc_ptr)
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.event_class_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.event_class_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
     @property
     def name(self):
         return native_bt.event_class_get_name(self._ptr)
index 75d7cc5a785bf3f75c3119e153d3fb2c73f7e370..d467bcc4129fa8a986fb3dc47054e3fd25e1f0ef 100644 (file)
@@ -24,6 +24,7 @@ from bt2 import native_bt, object, utils
 import collections.abc
 from bt2 import field_path as bt2_field_path
 from bt2 import integer_range_set as bt2_integer_range_set
+from bt2 import value as bt2_value
 import bt2
 
 
@@ -49,6 +50,19 @@ class _FieldClass(object._SharedObject):
                 'cannot create {} field class object'.format(self._NAME.lower())
             )
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.field_class_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.field_class_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
 
 class _BoolFieldClass(_FieldClass):
     _NAME = 'Boolean'
@@ -283,29 +297,58 @@ class _StringFieldClass(_FieldClass):
 
 
 class _StructureFieldClassMember:
-    def __init__(self, name, field_class):
-        self._name = name
-        self._field_class = field_class
+    def __init__(self, owning_struct_fc, member_ptr):
+        # this field class owns the member; keeping it here maintains
+        # the member alive as members are not shared objects
+        self._owning_struct_fc = owning_struct_fc
+        self._ptr = member_ptr
 
     @property
     def name(self):
-        return self._name
+        name = native_bt.field_class_structure_member_get_name(self._ptr)
+        assert name is not None
+        return name
 
     @property
     def field_class(self):
-        return self._field_class
+        fc_ptr = native_bt.field_class_structure_member_borrow_field_class_const(
+            self._ptr
+        )
+        assert fc_ptr is not None
+        return _create_field_class_from_ptr_and_get_ref(fc_ptr)
+
+    @property
+    def user_attributes(self):
+        ptr = native_bt.field_class_structure_member_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.field_class_structure_member_set_user_attributes(
+            self._ptr, value._ptr
+        )
+
+    _user_attributes = property(fset=_user_attributes)
 
 
 class _StructureFieldClass(_FieldClass, collections.abc.Mapping):
     _NAME = 'Structure'
 
-    def append_member(self, name, field_class):
+    def append_member(self, name, field_class, user_attributes=None):
         utils._check_str(name)
         utils._check_type(field_class, _FieldClass)
 
         if name in self:
             raise ValueError("duplicate member name '{}'".format(name))
 
+        user_attributes_value = None
+
+        if user_attributes is not None:
+            # check now that user attributes are valid
+            user_attributes_value = bt2.create_value(user_attributes)
+
         status = native_bt.field_class_structure_append_member(
             self._ptr, name, field_class._ptr
         )
@@ -313,21 +356,16 @@ class _StructureFieldClass(_FieldClass, collections.abc.Mapping):
             status, 'cannot append member to structure field class object'
         )
 
+        if user_attributes is not None:
+            self[name]._user_attributes = user_attributes_value
+
     def __len__(self):
         count = native_bt.field_class_structure_get_member_count(self._ptr)
         assert count >= 0
         return count
 
-    @staticmethod
-    def _create_member_from_ptr(member_ptr):
-        name = native_bt.field_class_structure_member_get_name(member_ptr)
-        assert name is not None
-        fc_ptr = native_bt.field_class_structure_member_borrow_field_class_const(
-            member_ptr
-        )
-        assert fc_ptr is not None
-        fc = _create_field_class_from_ptr_and_get_ref(fc_ptr)
-        return _StructureFieldClassMember(name, fc)
+    def _create_member_from_ptr(self, member_ptr):
+        return _StructureFieldClassMember(self, member_ptr)
 
     def __getitem__(self, key):
         if not isinstance(key, str):
@@ -387,27 +425,38 @@ class _OptionFieldClass(_FieldClass):
 
 
 class _VariantFieldClassOption:
-    def __init__(self, name, field_class):
-        self._name = name
-        self._field_class = field_class
+    def __init__(self, owning_var_fc, option_ptr):
+        # this field class owns the option; keeping it here maintains
+        # the option alive as options are not shared objects
+        self._owning_var_fc = owning_var_fc
+        self._ptr = option_ptr
 
     @property
     def name(self):
-        return self._name
+        name = native_bt.field_class_variant_option_get_name(self._ptr)
+        assert name is not None
+        return name
 
     @property
     def field_class(self):
-        return self._field_class
+        fc_ptr = native_bt.field_class_variant_option_borrow_field_class_const(
+            self._ptr
+        )
+        assert fc_ptr is not None
+        return _create_field_class_from_ptr_and_get_ref(fc_ptr)
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.field_class_variant_option_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
 
-class _VariantFieldClassWithSelectorOption(_VariantFieldClassOption):
-    def __init__(self, name, field_class, ranges):
-        super().__init__(name, field_class)
-        self._ranges = ranges
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.field_class_variant_option_set_user_attributes(self._ptr, value._ptr)
 
-    @property
-    def ranges(self):
-        return self._ranges
+    _user_attributes = property(fset=_user_attributes)
 
 
 class _VariantFieldClass(_FieldClass, collections.abc.Mapping):
@@ -424,12 +473,7 @@ class _VariantFieldClass(_FieldClass, collections.abc.Mapping):
         return opt_ptr
 
     def _create_option_from_ptr(self, opt_ptr):
-        name = native_bt.field_class_variant_option_get_name(opt_ptr)
-        assert name is not None
-        fc_ptr = native_bt.field_class_variant_option_borrow_field_class_const(opt_ptr)
-        assert fc_ptr is not None
-        fc = _create_field_class_from_ptr_and_get_ref(fc_ptr)
-        return _VariantFieldClassOption(name, fc)
+        return _VariantFieldClassOption(self, opt_ptr)
 
     def __len__(self):
         count = native_bt.field_class_variant_get_option_count(self._ptr)
@@ -470,13 +514,19 @@ class _VariantFieldClass(_FieldClass, collections.abc.Mapping):
 class _VariantFieldClassWithoutSelector(_VariantFieldClass):
     _NAME = 'Variant (without selector)'
 
-    def append_option(self, name, field_class):
+    def append_option(self, name, field_class, user_attributes=None):
         utils._check_str(name)
         utils._check_type(field_class, _FieldClass)
 
         if name in self:
             raise ValueError("duplicate option name '{}'".format(name))
 
+        user_attributes_value = None
+
+        if user_attributes is not None:
+            # check now that user attributes are valid
+            user_attributes_value = bt2.create_value(user_attributes)
+
         status = native_bt.field_class_variant_without_selector_append_option(
             self._ptr, name, field_class._ptr
         )
@@ -484,6 +534,9 @@ class _VariantFieldClassWithoutSelector(_VariantFieldClass):
             status, 'cannot append option to variant field class object'
         )
 
+        if user_attributes is not None:
+            self[name]._user_attributes = user_attributes_value
+
     def __iadd__(self, options):
         for name, field_class in options:
             self.append_option(name, field_class)
@@ -491,22 +544,45 @@ class _VariantFieldClassWithoutSelector(_VariantFieldClass):
         return self
 
 
+class _VariantFieldClassWithSelectorOption(_VariantFieldClassOption):
+    def __init__(self, owning_var_fc, spec_opt_ptr):
+        self._spec_ptr = spec_opt_ptr
+        super().__init__(owning_var_fc, self._as_option_ptr(spec_opt_ptr))
+
+    @property
+    def ranges(self):
+        range_set_ptr = self._borrow_ranges_ptr(self._spec_ptr)
+        assert range_set_ptr is not None
+        return self._range_set_type._create_from_ptr_and_get_ref(range_set_ptr)
+
+
+class _VariantFieldClassWithSignedSelectorOption(_VariantFieldClassWithSelectorOption):
+    _as_option_ptr = staticmethod(
+        native_bt.field_class_variant_with_selector_signed_option_as_option_const
+    )
+    _borrow_ranges_ptr = staticmethod(
+        native_bt.field_class_variant_with_selector_signed_option_borrow_ranges_const
+    )
+    _range_set_type = bt2_integer_range_set.SignedIntegerRangeSet
+
+
+class _VariantFieldClassWithUnsignedSelectorOption(
+    _VariantFieldClassWithSelectorOption
+):
+    _as_option_ptr = staticmethod(
+        native_bt.field_class_variant_with_selector_unsigned_option_as_option_const
+    )
+    _borrow_ranges_ptr = staticmethod(
+        native_bt.field_class_variant_with_selector_unsigned_option_borrow_ranges_const
+    )
+    _range_set_type = bt2_integer_range_set.UnsignedIntegerRangeSet
+
+
 class _VariantFieldClassWithSelector(_VariantFieldClass):
     _NAME = 'Variant (with selector)'
 
     def _create_option_from_ptr(self, opt_ptr):
-        base_opt_ptr = self._as_option_ptr(opt_ptr)
-        name = native_bt.field_class_variant_option_get_name(base_opt_ptr)
-        assert name is not None
-        fc_ptr = native_bt.field_class_variant_option_borrow_field_class_const(
-            base_opt_ptr
-        )
-        assert fc_ptr is not None
-        fc = _create_field_class_from_ptr_and_get_ref(fc_ptr)
-        range_set_ptr = self._option_borrow_ranges_ptr(opt_ptr)
-        assert range_set_ptr is not None
-        range_set = self._range_set_type._create_from_ptr_and_get_ref(range_set_ptr)
-        return _VariantFieldClassWithSelectorOption(name, fc, range_set)
+        return self._option_type(self, opt_ptr)
 
     @property
     def selector_field_path(self):
@@ -519,10 +595,10 @@ class _VariantFieldClassWithSelector(_VariantFieldClass):
 
         return bt2_field_path._FieldPath._create_from_ptr_and_get_ref(ptr)
 
-    def append_option(self, name, field_class, ranges):
+    def append_option(self, name, field_class, ranges, user_attributes=None):
         utils._check_str(name)
         utils._check_type(field_class, _FieldClass)
-        utils._check_type(ranges, self._range_set_type)
+        utils._check_type(ranges, self._option_type._range_set_type)
 
         if name in self:
             raise ValueError("duplicate option name '{}'".format(name))
@@ -530,6 +606,12 @@ class _VariantFieldClassWithSelector(_VariantFieldClass):
         if len(ranges) == 0:
             raise ValueError('range set is empty')
 
+        user_attributes_value = None
+
+        if user_attributes is not None:
+            # check now that user attributes are valid
+            user_attributes_value = bt2.create_value(user_attributes)
+
         # TODO: check overlaps (precondition of self._append_option())
 
         status = self._append_option(self._ptr, name, field_class._ptr, ranges._ptr)
@@ -537,6 +619,9 @@ class _VariantFieldClassWithSelector(_VariantFieldClass):
             status, 'cannot append option to variant field class object'
         )
 
+        if user_attributes is not None:
+            self[name]._user_attributes = user_attributes_value
+
     def __iadd__(self, options):
         for name, field_class, ranges in options:
             self.append_option(name, field_class, ranges)
@@ -552,16 +637,11 @@ class _VariantFieldClassWithUnsignedSelector(_VariantFieldClassWithSelector):
     _borrow_member_by_index_ptr = staticmethod(
         native_bt.field_class_variant_with_selector_unsigned_borrow_option_by_index_const
     )
-    _as_option_ptr = staticmethod(
-        native_bt.field_class_variant_with_selector_unsigned_option_as_option_const
-    )
     _append_option = staticmethod(
         native_bt.field_class_variant_with_selector_unsigned_append_option
     )
-    _option_borrow_ranges_ptr = staticmethod(
-        native_bt.field_class_variant_with_selector_unsigned_option_borrow_ranges_const
-    )
-    _range_set_type = bt2_integer_range_set.UnsignedIntegerRangeSet
+    _option_type = _VariantFieldClassWithUnsignedSelectorOption
+    _as_option_ptr = staticmethod(_option_type._as_option_ptr)
 
 
 class _VariantFieldClassWithSignedSelector(_VariantFieldClassWithSelector):
@@ -572,16 +652,11 @@ class _VariantFieldClassWithSignedSelector(_VariantFieldClassWithSelector):
     _borrow_member_by_index_ptr = staticmethod(
         native_bt.field_class_variant_with_selector_signed_borrow_option_by_index_const
     )
-    _as_option_ptr = staticmethod(
-        native_bt.field_class_variant_with_selector_signed_option_as_option_const
-    )
     _append_option = staticmethod(
         native_bt.field_class_variant_with_selector_signed_append_option
     )
-    _option_borrow_ranges_ptr = staticmethod(
-        native_bt.field_class_variant_with_selector_signed_option_borrow_ranges_const
-    )
-    _range_set_type = bt2_integer_range_set.SignedIntegerRangeSet
+    _option_type = _VariantFieldClassWithSignedSelectorOption
+    _as_option_ptr = staticmethod(_option_type._as_option_ptr)
 
 
 class _ArrayFieldClass(_FieldClass):
index 4710aeaaa0cb51e215279694d25588f3d3491c82..264d5477f45fc948253db6abb702f3d071003c18 100644 (file)
@@ -25,6 +25,7 @@ from bt2 import object as bt2_object
 from bt2 import packet as bt2_packet
 from bt2 import trace as bt2_trace
 from bt2 import stream_class as bt2_stream_class
+from bt2 import value as bt2_value
 import bt2
 
 
@@ -50,6 +51,19 @@ class _Stream(bt2_object._SharedObject):
 
     _name = property(fset=_name)
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.stream_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.stream_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
     @property
     def id(self):
         id = native_bt.stream_get_id(self._ptr)
index af8d4e53afaa8b13a7e6569ba1952bb033b02729..2c33dae8ddfa8d8a11ded5eedaadce48bd646139 100644 (file)
@@ -25,6 +25,7 @@ from bt2 import field_class as bt2_field_class
 from bt2 import event_class as bt2_event_class
 from bt2 import trace_class as bt2_trace_class
 from bt2 import clock_class as bt2_clock_class
+from bt2 import value as bt2_value
 import collections.abc
 
 
@@ -62,6 +63,7 @@ class _StreamClass(object._SharedObject, collections.abc.Mapping):
         self,
         id=None,
         name=None,
+        user_attributes=None,
         log_level=None,
         emf_uri=None,
         specific_context_field_class=None,
@@ -88,6 +90,9 @@ class _StreamClass(object._SharedObject, collections.abc.Mapping):
         if name is not None:
             event_class._name = name
 
+        if user_attributes is not None:
+            event_class._user_attributes = user_attributes
+
         if log_level is not None:
             event_class._log_level = log_level
 
@@ -109,6 +114,19 @@ class _StreamClass(object._SharedObject, collections.abc.Mapping):
         if tc_ptr is not None:
             return bt2_trace_class._TraceClass._create_from_ptr_and_get_ref(tc_ptr)
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.stream_class_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.stream_class_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
     @property
     def name(self):
         return native_bt.stream_class_get_name(self._ptr)
index e893b2d28bfabb24906b8735857b822b0bc728a0..7f98ac8698f39a7f9aeb615e33c901a28b941152 100644 (file)
@@ -115,6 +115,19 @@ class _Trace(object._SharedObject, collections.abc.Mapping):
         assert trace_class_ptr is not None
         return bt2_trace_class._TraceClass._create_from_ptr_and_get_ref(trace_class_ptr)
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.trace_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.trace_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
     @property
     def name(self):
         return native_bt.trace_get_name(self._ptr)
@@ -144,7 +157,7 @@ class _Trace(object._SharedObject, collections.abc.Mapping):
     def env(self):
         return _TraceEnv(self)
 
-    def create_stream(self, stream_class, id=None, name=None):
+    def create_stream(self, stream_class, id=None, name=None, user_attributes=None):
         utils._check_type(stream_class, bt2_stream_class._StreamClass)
 
         if stream_class.assigns_automatic_stream_id:
@@ -173,6 +186,9 @@ class _Trace(object._SharedObject, collections.abc.Mapping):
         if name is not None:
             stream._name = name
 
+        if user_attributes is not None:
+            stream._user_attributes = user_attributes
+
         return stream
 
     def add_destruction_listener(self, listener):
index b9ac172fc0a19de3bd6268e485e4d8313cc73038..1ec2a5055e120f9fafac87029667a8038615f245 100644 (file)
@@ -30,6 +30,7 @@ from bt2 import stream_class as bt2_stream_class
 from bt2 import field_class as bt2_field_class
 from bt2 import trace as bt2_trace
 from bt2 import trace_class as bt2_trace_class
+from bt2 import value as bt2_value
 import collections.abc
 import functools
 
@@ -47,7 +48,7 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
 
     # Instantiate a trace of this class.
 
-    def __call__(self, name=None, uuid=None, env=None):
+    def __call__(self, name=None, user_attributes=None, uuid=None, env=None):
         trace_ptr = native_bt.trace_create(self._ptr)
 
         if trace_ptr is None:
@@ -58,6 +59,9 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
         if name is not None:
             trace._name = name
 
+        if user_attributes is not None:
+            trace._user_attributes = user_attributes
+
         if uuid is not None:
             trace._uuid = uuid
 
@@ -101,6 +105,7 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
         self,
         id=None,
         name=None,
+        user_attributes=None,
         packet_context_field_class=None,
         event_common_context_field_class=None,
         default_clock_class=None,
@@ -136,6 +141,9 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
         if name is not None:
             sc._name = name
 
+        if user_attributes is not None:
+            sc._user_attributes = user_attributes
+
         if event_common_context_field_class is not None:
             sc._event_common_context_field_class = event_common_context_field_class
 
@@ -168,6 +176,19 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
         )
         return sc
 
+    @property
+    def user_attributes(self):
+        ptr = native_bt.trace_class_borrow_user_attributes(self._ptr)
+        assert ptr is not None
+        return bt2_value._create_from_ptr_and_get_ref(ptr)
+
+    def _user_attributes(self, user_attributes):
+        value = bt2_value.create_value(user_attributes)
+        utils._check_type(value, bt2_value.MapValue)
+        native_bt.trace_class_set_user_attributes(self._ptr, value._ptr)
+
+    _user_attributes = property(fset=_user_attributes)
+
     @property
     def assigns_automatic_stream_class_id(self):
         return native_bt.trace_class_assigns_automatic_stream_class_id(self._ptr)
@@ -188,13 +209,19 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
         if ptr is None:
             raise bt2._MemoryError('cannot create {} field class'.format(type_name))
 
-    def create_bool_field_class(self):
+    @staticmethod
+    def _set_field_class_user_attrs(fc, user_attributes):
+        if user_attributes is not None:
+            fc._user_attributes = user_attributes
+
+    def create_bool_field_class(self, user_attributes=None):
         field_class_ptr = native_bt.field_class_bool_create(self._ptr)
         self._check_field_class_create_status(field_class_ptr, 'boolean')
+        fc = bt2_field_class._BoolFieldClass._create_from_ptr(field_class_ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
-        return bt2_field_class._BoolFieldClass._create_from_ptr(field_class_ptr)
-
-    def create_bit_array_field_class(self, length):
+    def create_bit_array_field_class(self, length, user_attributes=None):
         utils._check_uint64(length)
 
         if length < 1 or length > 64:
@@ -206,11 +233,18 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
 
         field_class_ptr = native_bt.field_class_bit_array_create(self._ptr, length)
         self._check_field_class_create_status(field_class_ptr, 'bit array')
-
-        return bt2_field_class._BitArrayFieldClass._create_from_ptr(field_class_ptr)
+        fc = bt2_field_class._BitArrayFieldClass._create_from_ptr(field_class_ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
     def _create_integer_field_class(
-        self, create_func, py_cls, type_name, field_value_range, preferred_display_base
+        self,
+        create_func,
+        py_cls,
+        type_name,
+        field_value_range,
+        preferred_display_base,
+        user_attributes,
     ):
         field_class_ptr = create_func(self._ptr)
         self._check_field_class_create_status(field_class_ptr, type_name)
@@ -223,10 +257,11 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
         if preferred_display_base is not None:
             field_class._preferred_display_base = preferred_display_base
 
+        self._set_field_class_user_attrs(field_class, user_attributes)
         return field_class
 
     def create_signed_integer_field_class(
-        self, field_value_range=None, preferred_display_base=None
+        self, field_value_range=None, preferred_display_base=None, user_attributes=None
     ):
         return self._create_integer_field_class(
             native_bt.field_class_integer_signed_create,
@@ -234,10 +269,11 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
             'signed integer',
             field_value_range,
             preferred_display_base,
+            user_attributes,
         )
 
     def create_unsigned_integer_field_class(
-        self, field_value_range=None, preferred_display_base=None
+        self, field_value_range=None, preferred_display_base=None, user_attributes=None
     ):
         return self._create_integer_field_class(
             native_bt.field_class_integer_unsigned_create,
@@ -245,10 +281,11 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
             'unsigned integer',
             field_value_range,
             preferred_display_base,
+            user_attributes,
         )
 
     def create_signed_enumeration_field_class(
-        self, field_value_range=None, preferred_display_base=None
+        self, field_value_range=None, preferred_display_base=None, user_attributes=None
     ):
         return self._create_integer_field_class(
             native_bt.field_class_enumeration_signed_create,
@@ -256,10 +293,11 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
             'signed enumeration',
             field_value_range,
             preferred_display_base,
+            user_attributes,
         )
 
     def create_unsigned_enumeration_field_class(
-        self, field_value_range=None, preferred_display_base=None
+        self, field_value_range=None, preferred_display_base=None, user_attributes=None
     ):
         return self._create_integer_field_class(
             native_bt.field_class_enumeration_unsigned_create,
@@ -267,39 +305,46 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
             'unsigned enumeration',
             field_value_range,
             preferred_display_base,
+            user_attributes,
         )
 
-    def create_real_field_class(self, is_single_precision=False):
+    def create_real_field_class(self, is_single_precision=False, user_attributes=None):
         field_class_ptr = native_bt.field_class_real_create(self._ptr)
         self._check_field_class_create_status(field_class_ptr, 'real')
 
         field_class = bt2_field_class._RealFieldClass._create_from_ptr(field_class_ptr)
 
         field_class._is_single_precision = is_single_precision
+        self._set_field_class_user_attrs(field_class, user_attributes)
 
         return field_class
 
-    def create_structure_field_class(self):
+    def create_structure_field_class(self, user_attributes=None):
         field_class_ptr = native_bt.field_class_structure_create(self._ptr)
         self._check_field_class_create_status(field_class_ptr, 'structure')
+        fc = bt2_field_class._StructureFieldClass._create_from_ptr(field_class_ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
-        return bt2_field_class._StructureFieldClass._create_from_ptr(field_class_ptr)
-
-    def create_string_field_class(self):
+    def create_string_field_class(self, user_attributes=None):
         field_class_ptr = native_bt.field_class_string_create(self._ptr)
         self._check_field_class_create_status(field_class_ptr, 'string')
+        fc = bt2_field_class._StringFieldClass._create_from_ptr(field_class_ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
-        return bt2_field_class._StringFieldClass._create_from_ptr(field_class_ptr)
-
-    def create_static_array_field_class(self, elem_fc, length):
+    def create_static_array_field_class(self, elem_fc, length, user_attributes=None):
         utils._check_type(elem_fc, bt2_field_class._FieldClass)
         utils._check_uint64(length)
         ptr = native_bt.field_class_array_static_create(self._ptr, elem_fc._ptr, length)
         self._check_field_class_create_status(ptr, 'static array')
+        fc = bt2_field_class._StaticArrayFieldClass._create_from_ptr_and_get_ref(ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
-        return bt2_field_class._StaticArrayFieldClass._create_from_ptr_and_get_ref(ptr)
-
-    def create_dynamic_array_field_class(self, elem_fc, length_fc=None):
+    def create_dynamic_array_field_class(
+        self, elem_fc, length_fc=None, user_attributes=None
+    ):
         utils._check_type(elem_fc, bt2_field_class._FieldClass)
         length_fc_ptr = None
 
@@ -311,9 +356,13 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
             self._ptr, elem_fc._ptr, length_fc_ptr
         )
         self._check_field_class_create_status(ptr, 'dynamic array')
-        return bt2_field_class._DynamicArrayFieldClass._create_from_ptr(ptr)
+        fc = bt2_field_class._DynamicArrayFieldClass._create_from_ptr(ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
-    def create_option_field_class(self, content_fc, selector_fc=None):
+    def create_option_field_class(
+        self, content_fc, selector_fc=None, user_attributes=None
+    ):
         utils._check_type(content_fc, bt2_field_class._FieldClass)
 
         selector_fc_ptr = None
@@ -326,9 +375,11 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
             self._ptr, content_fc._ptr, selector_fc_ptr
         )
         self._check_field_class_create_status(ptr, 'option')
-        return bt2_field_class._create_field_class_from_ptr_and_get_ref(ptr)
+        fc = bt2_field_class._create_field_class_from_ptr_and_get_ref(ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
-    def create_variant_field_class(self, selector_fc=None):
+    def create_variant_field_class(self, selector_fc=None, user_attributes=None):
         selector_fc_ptr = None
 
         if selector_fc is not None:
@@ -337,7 +388,9 @@ class _TraceClass(object._SharedObject, collections.abc.Mapping):
 
         ptr = native_bt.field_class_variant_create(self._ptr, selector_fc_ptr)
         self._check_field_class_create_status(ptr, 'variant')
-        return bt2_field_class._create_field_class_from_ptr_and_get_ref(ptr)
+        fc = bt2_field_class._create_field_class_from_ptr_and_get_ref(ptr)
+        self._set_field_class_user_attrs(fc, user_attributes)
+        return fc
 
     # Add a listener to be called when the trace class is destroyed.
 
index 9bc6bae7497593be6b15e7da9ccd941268b19f30..5788e80cb1cc77a46fb159d22563843b2f551093 100644 (file)
@@ -87,6 +87,7 @@ class ClockClassTestCase(unittest.TestCase):
         self.assertEqual(cc.offset, bt2.ClockClassOffset())
         self.assertTrue(cc.origin_is_unix_epoch)
         self.assertIsNone(cc.uuid)
+        self.assertEqual(len(cc.user_attributes), 0)
 
     def test_create_name(self):
         def f(comp_self):
@@ -198,6 +199,25 @@ class ClockClassTestCase(unittest.TestCase):
 
         self.assertRaisesInComponentInit(TypeError, f)
 
+    def test_create_user_attributes(self):
+        def f(comp_self):
+            return comp_self._create_clock_class(user_attributes={'salut': 23})
+
+        cc = run_in_component_init(f)
+        self.assertEqual(cc.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        def f(comp_self):
+            return comp_self._create_clock_class(user_attributes=object())
+
+        self.assertRaisesInComponentInit(TypeError, f)
+
+    def test_create_invalid_user_attributes_value_type(self):
+        def f(comp_self):
+            return comp_self._create_clock_class(user_attributes=23)
+
+        self.assertRaisesInComponentInit(TypeError, f)
+
 
 class ClockSnapshotTestCase(unittest.TestCase):
     def setUp(self):
index 10d58a7fd03852005e77493c488ac41e3f614662..ffc10bc4bdac703635f5c24f442844f78906ec51 100644 (file)
@@ -47,6 +47,7 @@ class EventClassTestCase(unittest.TestCase):
         self.assertIsNone(ec.payload_field_class)
         self.assertIsNone(ec.emf_uri)
         self.assertIsNone(ec.log_level)
+        self.assertEqual(len(ec.user_attributes), 0)
 
     def test_create_invalid_id(self):
         sc = self._tc.create_stream_class(assigns_automatic_event_class_id=False)
@@ -97,6 +98,18 @@ class EventClassTestCase(unittest.TestCase):
         with self.assertRaises(ValueError):
             self._stream_class.create_event_class(log_level='zoom')
 
+    def test_create_user_attributes(self):
+        ec = self._stream_class.create_event_class(user_attributes={'salut': 23})
+        self.assertEqual(ec.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._stream_class.create_event_class(user_attributes=object())
+
+    def test_create_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._stream_class.create_event_class(user_attributes=23)
+
     def test_stream_class(self):
         ec = self._stream_class.create_event_class()
         self.assertEqual(ec.stream_class.addr, self._stream_class.addr)
index e564cac5f2a9c8a6533b1360c6a057186fa40de5..45255e6710c845d63726ade8c247d92f41fbda40 100644 (file)
@@ -21,34 +21,59 @@ import bt2
 from utils import get_default_trace_class
 
 
-class BoolFieldClassTestCase(unittest.TestCase):
-    def setUp(self):
+class _TestFieldClass:
+    def test_create_user_attributes(self):
+        fc = self._create_default_field_class(user_attributes={'salut': 23})
+        self.assertEqual(fc.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._create_default_field_class(user_attributes=object())
+
+    def test_create_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._create_default_field_class(user_attributes=23)
+
+
+class BoolFieldClassTestCase(_TestFieldClass, unittest.TestCase):
+    def _create_default_field_class(self, **kwargs):
         tc = get_default_trace_class()
-        self._fc = tc.create_bool_field_class()
+        return tc.create_bool_field_class(**kwargs)
+
+    def setUp(self):
+        self._fc = self._create_default_field_class()
 
     def test_create_default(self):
         self.assertIsNotNone(self._fc)
+        self.assertEqual(len(self._fc.user_attributes), 0)
 
 
-class BitArrayFieldClassTestCase(unittest.TestCase):
+class BitArrayFieldClassTestCase(_TestFieldClass, unittest.TestCase):
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_bit_array_field_class(*args, **kwargs)
+
+    def _create_default_field_class(self, **kwargs):
+        return self._create_field_class(17, **kwargs)
+
     def setUp(self):
-        self._tc = get_default_trace_class()
-        self._fc = self._tc.create_bit_array_field_class(17)
+        self._fc = self._create_default_field_class()
 
     def test_create_default(self):
         self.assertIsNotNone(self._fc)
+        self.assertEqual(len(self._fc.user_attributes), 0)
 
     def test_create_length_out_of_range(self):
         with self.assertRaises(ValueError):
-            self._tc.create_bit_array_field_class(65)
+            self._create_field_class(65)
 
     def test_create_length_zero(self):
         with self.assertRaises(ValueError):
-            self._tc.create_bit_array_field_class(0)
+            self._create_field_class(0)
 
     def test_create_length_invalid_type(self):
         with self.assertRaises(TypeError):
-            self._tc.create_bit_array_field_class('lel')
+            self._create_field_class('lel')
 
     def test_length_prop(self):
         self.assertEqual(self._fc.length, 17)
@@ -56,71 +81,92 @@ class BitArrayFieldClassTestCase(unittest.TestCase):
 
 class _TestIntegerFieldClassProps:
     def test_create_default(self):
-        fc = self._create_func()
+        fc = self._create_default_field_class()
         self.assertEqual(fc.field_value_range, 64)
         self.assertEqual(fc.preferred_display_base, bt2.IntegerDisplayBase.DECIMAL)
+        self.assertEqual(len(fc.user_attributes), 0)
 
     def test_create_range(self):
-        fc = self._create_func(field_value_range=35)
+        fc = self._create_field_class(field_value_range=35)
         self.assertEqual(fc.field_value_range, 35)
 
-        fc = self._create_func(36)
+        fc = self._create_field_class(36)
         self.assertEqual(fc.field_value_range, 36)
 
     def test_create_invalid_range(self):
         with self.assertRaises(TypeError):
-            self._create_func('yes')
+            self._create_field_class('yes')
 
         with self.assertRaises(TypeError):
-            self._create_func(field_value_range='yes')
+            self._create_field_class(field_value_range='yes')
 
         with self.assertRaises(ValueError):
-            self._create_func(field_value_range=-2)
+            self._create_field_class(field_value_range=-2)
 
         with self.assertRaises(ValueError):
-            self._create_func(field_value_range=0)
+            self._create_field_class(field_value_range=0)
 
     def test_create_base(self):
-        fc = self._create_func(
+        fc = self._create_field_class(
             preferred_display_base=bt2.IntegerDisplayBase.HEXADECIMAL
         )
         self.assertEqual(fc.preferred_display_base, bt2.IntegerDisplayBase.HEXADECIMAL)
 
     def test_create_invalid_base_type(self):
         with self.assertRaises(TypeError):
-            self._create_func(preferred_display_base='yes')
+            self._create_field_class(preferred_display_base='yes')
 
     def test_create_invalid_base_value(self):
         with self.assertRaises(ValueError):
-            self._create_func(preferred_display_base=444)
+            self._create_field_class(preferred_display_base=444)
 
     def test_create_full(self):
-        fc = self._create_func(24, preferred_display_base=bt2.IntegerDisplayBase.OCTAL)
+        fc = self._create_field_class(
+            24, preferred_display_base=bt2.IntegerDisplayBase.OCTAL
+        )
         self.assertEqual(fc.field_value_range, 24)
         self.assertEqual(fc.preferred_display_base, bt2.IntegerDisplayBase.OCTAL)
 
 
-class IntegerFieldClassTestCase(_TestIntegerFieldClassProps, unittest.TestCase):
-    def setUp(self):
-        self._tc = get_default_trace_class()
-        self._create_func = self._tc.create_signed_integer_field_class
+class SignedIntegerFieldClassTestCase(
+    _TestIntegerFieldClassProps, _TestFieldClass, unittest.TestCase
+):
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_signed_integer_field_class(*args, **kwargs)
 
+    _create_default_field_class = _create_field_class
+
+
+class UnsignedIntegerFieldClassTestCase(
+    _TestIntegerFieldClassProps, _TestFieldClass, unittest.TestCase
+):
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_unsigned_integer_field_class(*args, **kwargs)
+
+    _create_default_field_class = _create_field_class
 
-class RealFieldClassTestCase(unittest.TestCase):
-    def setUp(self):
-        self._tc = get_default_trace_class()
+
+class RealFieldClassTestCase(_TestFieldClass, unittest.TestCase):
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_real_field_class(*args, **kwargs)
+
+    _create_default_field_class = _create_field_class
 
     def test_create_default(self):
-        fc = self._tc.create_real_field_class()
+        fc = self._create_field_class()
         self.assertFalse(fc.is_single_precision)
+        self.assertEqual(len(fc.user_attributes), 0)
 
     def test_create_is_single_precision(self):
-        fc = self._tc.create_real_field_class(is_single_precision=True)
+        fc = self._create_field_class(is_single_precision=True)
         self.assertTrue(fc.is_single_precision)
 
     def test_create_invalid_is_single_precision(self):
         with self.assertRaises(TypeError):
-            self._tc.create_real_field_class(is_single_precision='hohoho')
+            self._create_field_class(is_single_precision='hohoho')
 
 
 # Converts an _EnumerationFieldClassMapping to a list of ranges:
@@ -134,13 +180,12 @@ def enum_mapping_to_set(mapping):
 
 class _EnumerationFieldClassTestCase(_TestIntegerFieldClassProps):
     def setUp(self):
-        self._tc = get_default_trace_class()
         self._spec_set_up()
-        self._fc = self._create_func()
+        self._fc = self._create_default_field_class()
 
     def test_create_from_invalid_type(self):
         with self.assertRaises(TypeError):
-            self._create_func('coucou')
+            self._create_field_class('coucou')
 
     def test_add_mapping_simple(self):
         self._fc.add_mapping('hello', self._ranges1)
@@ -228,7 +273,7 @@ class _EnumerationFieldClassTestCase(_TestIntegerFieldClassProps):
 
 
 class UnsignedEnumerationFieldClassTestCase(
-    _EnumerationFieldClassTestCase, unittest.TestCase
+    _EnumerationFieldClassTestCase, _TestFieldClass, unittest.TestCase
 ):
     def _spec_set_up(self):
         self._ranges1 = bt2.UnsignedIntegerRangeSet([(1, 4), (18, 47)])
@@ -236,11 +281,16 @@ class UnsignedEnumerationFieldClassTestCase(
         self._ranges3 = bt2.UnsignedIntegerRangeSet([(8, 22), (48, 99)])
         self._inval_ranges = bt2.SignedIntegerRangeSet([(-8, -5), (48, 1928)])
         self._value_in_range_1_and_3 = 20
-        self._create_func = self._tc.create_unsigned_enumeration_field_class
+
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_unsigned_enumeration_field_class(*args, **kwargs)
+
+    _create_default_field_class = _create_field_class
 
 
 class SignedEnumerationFieldClassTestCase(
-    _EnumerationFieldClassTestCase, unittest.TestCase
+    _EnumerationFieldClassTestCase, _TestFieldClass, unittest.TestCase
 ):
     def _spec_set_up(self):
         self._ranges1 = bt2.SignedIntegerRangeSet([(-10, -4), (18, 47)])
@@ -248,25 +298,37 @@ class SignedEnumerationFieldClassTestCase(
         self._ranges3 = bt2.SignedIntegerRangeSet([(-100, -1), (8, 16), (48, 99)])
         self._inval_ranges = bt2.UnsignedIntegerRangeSet([(8, 16), (48, 99)])
         self._value_in_range_1_and_3 = -7
-        self._create_func = self._tc.create_signed_enumeration_field_class
 
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_signed_enumeration_field_class(*args, **kwargs)
+
+    _create_default_field_class = _create_field_class
 
-class StringFieldClassTestCase(unittest.TestCase):
-    def setUp(self):
+
+class StringFieldClassTestCase(_TestFieldClass, unittest.TestCase):
+    def _create_field_class(self, *args, **kwargs):
         tc = get_default_trace_class()
-        self._fc = tc.create_string_field_class()
+        return tc.create_string_field_class(*args, **kwargs)
+
+    _create_default_field_class = _create_field_class
+
+    def setUp(self):
+        self._fc = self._create_default_field_class()
 
     def test_create_default(self):
         self.assertIsNotNone(self._fc)
+        self.assertEqual(len(self._fc.user_attributes), 0)
 
 
 class _TestElementContainer:
     def setUp(self):
         self._tc = get_default_trace_class()
-        self._fc = self._create_default_fc()
+        self._fc = self._create_default_field_class()
 
     def test_create_default(self):
         self.assertIsNotNone(self._fc)
+        self.assertEqual(len(self._fc.user_attributes), 0)
 
     def test_append_element(self):
         int_field_class = self._tc.create_signed_integer_field_class(32)
@@ -299,7 +361,7 @@ class _TestElementContainer:
             self._append_element_method(self._fc, 'yes', sub_fc2)
 
     def test_iadd(self):
-        other_fc = self._create_default_fc()
+        other_fc = self._create_default_field_class()
         a_field_class = self._tc.create_real_field_class()
         b_field_class = self._tc.create_signed_integer_field_class(17)
         self._append_element_method(self._fc, 'a_float', a_field_class)
@@ -372,6 +434,7 @@ class _TestElementContainer:
             self.assertEqual(element.name, test_elem[0])
             self.assertEqual(name, element.name)
             self.assertEqual(element.field_class.addr, test_elem[1].addr)
+            self.assertEqual(len(element.user_attributes), 0)
 
     def test_at_index(self):
         a_fc = self._tc.create_signed_integer_field_class(32)
@@ -400,28 +463,61 @@ class _TestElementContainer:
         with self.assertRaises(IndexError):
             self._at_index_method(self._fc, len(self._fc))
 
+    def test_user_attributes(self):
+        self._append_element_method(
+            self._fc,
+            'c',
+            self._tc.create_string_field_class(),
+            user_attributes={'salut': 23},
+        )
+        self.assertEqual(self._fc['c'].user_attributes, {'salut': 23})
 
-class StructureFieldClassTestCase(_TestElementContainer, unittest.TestCase):
+    def test_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._append_element_method(
+                self._fc,
+                'c',
+                self._tc.create_string_field_class(),
+                user_attributes=object(),
+            )
+
+    def test_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._append_element_method(
+                self._fc, 'c', self._tc.create_string_field_class(), user_attributes=23
+            )
+
+
+class StructureFieldClassTestCase(
+    _TestFieldClass, _TestElementContainer, unittest.TestCase
+):
     _append_element_method = staticmethod(bt2._StructureFieldClass.append_member)
     _at_index_method = staticmethod(bt2._StructureFieldClass.member_at_index)
 
-    def _create_default_fc(self):
-        return self._tc.create_structure_field_class()
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_structure_field_class(*args, **kwargs)
 
+    _create_default_field_class = _create_field_class
+
+
+class OptionFieldClassTestCase(_TestFieldClass, unittest.TestCase):
+    def _create_default_field_class(self, *args, **kwargs):
+        return self._tc.create_option_field_class(self._content_fc, **kwargs)
 
-class OptionFieldClassTestCase(unittest.TestCase):
     def setUp(self):
         self._tc = get_default_trace_class()
         self._content_fc = self._tc.create_signed_integer_field_class(23)
         self._tag_fc = self._tc.create_bool_field_class()
 
     def test_create_default(self):
-        fc = self._tc.create_option_field_class(self._content_fc)
+        fc = self._create_default_field_class()
         self.assertEqual(fc.field_class.addr, self._content_fc.addr)
         self.assertIsNone(fc.selector_field_path, None)
+        self.assertEqual(len(fc.user_attributes), 0)
 
     def _create_field_class_for_field_path_test(self):
-        fc = self._tc.create_option_field_class(self._content_fc, self._tag_fc)
+        fc = self._create_default_field_class(selector_fc=self._tag_fc)
 
         foo_fc = self._tc.create_real_field_class()
         bar_fc = self._tc.create_string_field_class()
@@ -482,7 +578,7 @@ class OptionFieldClassTestCase(unittest.TestCase):
 
 
 class VariantFieldClassWithoutSelectorTestCase(
-    _TestElementContainer, unittest.TestCase
+    _TestFieldClass, _TestElementContainer, unittest.TestCase
 ):
     _append_element_method = staticmethod(
         bt2._VariantFieldClassWithoutSelector.append_option
@@ -491,21 +587,27 @@ class VariantFieldClassWithoutSelectorTestCase(
         bt2._VariantFieldClassWithoutSelector.option_at_index
     )
 
-    def _create_default_fc(self):
-        return self._tc.create_variant_field_class()
+    def _create_field_class(self, *args, **kwargs):
+        tc = get_default_trace_class()
+        return tc.create_variant_field_class(*args, **kwargs)
+
+    _create_default_field_class = _create_field_class
 
 
 class _VariantFieldClassWithSelectorTestCase:
+    def _create_default_field_class(self, *args, **kwargs):
+        return self._tc.create_variant_field_class(
+            *args, selector_fc=self._selector_fc, **kwargs
+        )
+
     def setUp(self):
         self._tc = get_default_trace_class()
         self._spec_set_up()
-        self._fc = self._create_default_fc()
-
-    def _create_default_fc(self):
-        return self._tc.create_variant_field_class(self._selector_fc)
+        self._fc = self._create_default_field_class()
 
     def test_create_default(self):
         self.assertIsNotNone(self._fc)
+        self.assertEqual(len(self._fc.user_attributes), 0)
 
     def test_append_element(self):
         str_field_class = self._tc.create_string_field_class()
@@ -555,8 +657,35 @@ class _VariantFieldClassWithSelectorTestCase:
         with self.assertRaises(TypeError):
             self._fc.append_option(self._fc, sub_fc, self._inval_ranges)
 
+    def test_user_attributes(self):
+        self._fc.append_option(
+            'c',
+            self._tc.create_string_field_class(),
+            self._ranges1,
+            user_attributes={'salut': 23},
+        )
+        self.assertEqual(self._fc['c'].user_attributes, {'salut': 23})
+
+    def test_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._fc.append_option(
+                'c',
+                self._tc.create_string_field_class(),
+                self._ranges1,
+                user_attributes=object(),
+            )
+
+    def test_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._fc.append_option(
+                'c',
+                self._tc.create_string_field_class(),
+                self._ranges1,
+                user_attributes=23,
+            )
+
     def test_iadd(self):
-        other_fc = self._create_default_fc()
+        other_fc = self._create_default_field_class()
         a_field_class = self._tc.create_real_field_class()
         self._fc.append_option('a_float', a_field_class, self._ranges1)
         c_field_class = self._tc.create_string_field_class()
@@ -762,6 +891,7 @@ class StaticArrayFieldClassTestCase(unittest.TestCase):
         fc = self._tc.create_static_array_field_class(self._elem_fc, 45)
         self.assertEqual(fc.element_field_class.addr, self._elem_fc.addr)
         self.assertEqual(fc.length, 45)
+        self.assertEqual(len(fc.user_attributes), 0)
 
     def test_create_invalid_elem_field_class(self):
         with self.assertRaises(TypeError):
@@ -790,6 +920,7 @@ class DynamicArrayFieldClassTestCase(unittest.TestCase):
         fc = self._tc.create_dynamic_array_field_class(self._elem_fc)
         self.assertEqual(fc.element_field_class.addr, self._elem_fc.addr)
         self.assertIsNone(fc.length_field_path, None)
+        self.assertEqual(len(fc.user_attributes), 0)
 
     def _create_field_class_for_field_path_test(self):
         # Create something a field class that is equivalent to:
index 7f80f8dfb67a97afe5df82ab227c0494746960b0..79e9b0238169e7d65c1190cb427dedab964f0868 100644 (file)
@@ -32,6 +32,7 @@ class StreamTestCase(unittest.TestCase):
     def test_create_default(self):
         stream = self._tr.create_stream(self._sc)
         self.assertIsNone(stream.name)
+        self.assertEqual(len(stream.user_attributes), 0)
 
     def test_name(self):
         stream = self._tr.create_stream(self._sc, name='équidistant')
@@ -41,6 +42,18 @@ class StreamTestCase(unittest.TestCase):
         with self.assertRaises(TypeError):
             self._tr.create_stream(self._sc, name=22)
 
+    def test_create_user_attributes(self):
+        stream = self._tr.create_stream(self._sc, user_attributes={'salut': 23})
+        self.assertEqual(stream.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._tr.create_stream(self._sc, user_attributes=object())
+
+    def test_create_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._tr.create_stream(self._sc, user_attributes=23)
+
     def test_stream_class(self):
         stream = self._tr.create_stream(self._sc)
         self.assertEqual(stream.cls, self._sc)
index 473efca7923aa1468195cfa4ffdcc3983a0cb2d6..60b86883439319d6fb30ce78607ddb839d05410e 100644 (file)
@@ -46,6 +46,7 @@ class StreamClassTestCase(unittest.TestCase):
         self.assertFalse(sc.discarded_events_have_default_clock_snapshots)
         self.assertFalse(sc.supports_discarded_packets)
         self.assertFalse(sc.discarded_packets_have_default_clock_snapshots)
+        self.assertEqual(len(sc.user_attributes), 0)
 
     def test_create_name(self):
         sc = self._tc.create_stream_class(name='bozo')
@@ -89,6 +90,18 @@ class StreamClassTestCase(unittest.TestCase):
         with self.assertRaises(TypeError):
             self._tc.create_stream_class(default_clock_class=12)
 
+    def test_create_user_attributes(self):
+        sc = self._tc.create_stream_class(user_attributes={'salut': 23})
+        self.assertEqual(sc.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._tc.create_stream_class(user_attributes=object())
+
+    def test_create_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._tc.create_stream_class(user_attributes=23)
+
     def test_automatic_stream_ids(self):
         sc = self._tc.create_stream_class(assigns_automatic_stream_id=True)
         self.assertTrue(sc.assigns_automatic_stream_id)
index 85615b4b515ba7d001e15fc80661af24c848daa3..5a9275628cf7b0ac6018955aa146eea6448f9704 100644 (file)
@@ -30,11 +30,24 @@ class TraceTestCase(unittest.TestCase):
         self.assertIsNone(trace.name)
         self.assertIsNone(trace.uuid)
         self.assertEqual(len(trace.env), 0)
+        self.assertEqual(len(trace.user_attributes), 0)
 
     def test_create_invalid_name(self):
         with self.assertRaises(TypeError):
             self._tc(name=17)
 
+    def test_create_user_attributes(self):
+        trace = self._tc(user_attributes={'salut': 23})
+        self.assertEqual(trace.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        with self.assertRaises(TypeError):
+            self._tc(user_attributes=object())
+
+    def test_create_invalid_user_attributes_value_type(self):
+        with self.assertRaises(TypeError):
+            self._tc(user_attributes=23)
+
     def test_attr_trace_class(self):
         trace = self._tc()
         self.assertEqual(trace.cls.addr, self._tc.addr)
index 97ede58424d8e1499b29dda3206a81b126d34579..d8f60143fb01c18e3f65ebc98d833c8f4e59b2cb 100644 (file)
@@ -21,6 +21,17 @@ from utils import run_in_component_init, get_default_trace_class
 
 
 class TraceClassTestCase(unittest.TestCase):
+    def assertRaisesInComponentInit(self, expected_exc_type, user_code):
+        def f(comp_self):
+            try:
+                user_code(comp_self)
+            except Exception as exc:
+                return type(exc)
+
+        exc_type = run_in_component_init(f)
+        self.assertIsNotNone(exc_type)
+        self.assertEqual(exc_type, expected_exc_type)
+
     def test_create_default(self):
         def f(comp_self):
             return comp_self._create_trace_class()
@@ -29,6 +40,26 @@ class TraceClassTestCase(unittest.TestCase):
 
         self.assertEqual(len(tc), 0)
         self.assertTrue(tc.assigns_automatic_stream_class_id)
+        self.assertEqual(len(tc.user_attributes), 0)
+
+    def test_create_user_attributes(self):
+        def f(comp_self):
+            return comp_self._create_trace_class(user_attributes={'salut': 23})
+
+        tc = run_in_component_init(f)
+        self.assertEqual(tc.user_attributes, {'salut': 23})
+
+    def test_create_invalid_user_attributes(self):
+        def f(comp_self):
+            return comp_self._create_trace_class(user_attributes=object())
+
+        self.assertRaisesInComponentInit(TypeError, f)
+
+    def test_create_invalid_user_attributes_value_type(self):
+        def f(comp_self):
+            return comp_self._create_trace_class(user_attributes=23)
+
+        self.assertRaisesInComponentInit(TypeError, f)
 
     def test_automatic_stream_class_id(self):
         def f(comp_self):
This page took 0.044794 seconds and 4 git commands to generate.