From c87f23fa2be424d0e97c4ca677c738a847813acf Mon Sep 17 00:00:00 2001 From: Simon Marchi Date: Mon, 12 Aug 2019 18:31:15 -0400 Subject: [PATCH] bt2: make ComponentSpec take a component class object ... rather than a plugin and component class name. It can be a user component class: class MySource(bt2._UserSourceComponent, ...): ... spec = bt2.ComponentSpec(MySource) or a component class from a plugin: cc = bt2.find_plugin('my_plugin').source_component_classes['cerfeuil'] spec = bt2.ComponentSpec(cc) This is a bit more general compared to taking a plugin name and a component class name, since it allows passing an component class that does not come from a plugin. It also allows the user to use component classes from plugins in non-standard locations, by looking up plugins manually with bt2.find_plugins_in_path. For convenience, there is a new static method bt2.ComponentSpec.from_named_plugin_and_component_class that creates a ComponentSpec from a plugin and component class names. Change-Id: I39e3fec6bbc7e7a9ba375065c9318b61b1791e35 Signed-off-by: Simon Marchi Reviewed-on: https://review.lttng.org/c/babeltrace/+/1890 CI-Build: Francis Deslauriers Tested-by: jenkins Reviewed-by: Francis Deslauriers --- .../bt2/trace_collection_message_iterator.py | 125 +++++----- .../test_trace_collection_message_iterator.py | 233 +++++++++++++++--- 2 files changed, 264 insertions(+), 94 deletions(-) diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index bd00de12..d14efe23 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -39,6 +39,8 @@ _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) class _BaseComponentSpec: + # Base for any component spec that can be passed to + # TraceCollectionMessageIterator. def __init__(self, params, obj, logging_level): if logging_level is not None: utils._check_log_level(logging_level) @@ -61,35 +63,71 @@ class _BaseComponentSpec: class ComponentSpec(_BaseComponentSpec): + # A component spec with a specific component class. def __init__( self, - plugin_name, - class_name, + component_class, params=None, obj=None, - logging_level=bt2_logging.LoggingLevel.NONE, + logging_level=bt2.LoggingLevel.NONE, ): if type(params) is str: params = {'inputs': [params]} super().__init__(params, obj, logging_level) - utils._check_str(plugin_name) - utils._check_str(class_name) + is_cc_object = isinstance( + component_class, (bt2._SourceComponentClass, bt2._FilterComponentClass) + ) + is_user_cc_type = isinstance( + component_class, bt2_component._UserComponentType + ) and issubclass( + component_class, (bt2._UserSourceComponent, bt2._UserFilterComponent) + ) - self._plugin_name = plugin_name - self._class_name = class_name + if not is_cc_object and not is_user_cc_type: + raise TypeError( + "'{}' is not a source or filter component class".format( + component_class.__class__.__name__ + ) + ) - @property - def plugin_name(self): - return self._plugin_name + self._component_class = component_class @property - def class_name(self): - return self._class_name + def component_class(self): + return self._component_class + + @classmethod + def from_named_plugin_and_component_class( + cls, + plugin_name, + component_class_name, + params=None, + obj=None, + logging_level=bt2.LoggingLevel.NONE, + ): + plugin = bt2.find_plugin(plugin_name) + + if plugin is None: + raise ValueError('no such plugin: {}'.format(plugin_name)) + + if component_class_name in plugin.source_component_classes: + comp_class = plugin.source_component_classes[component_class_name] + elif component_class_name in plugin.filter_component_classes: + comp_class = plugin.filter_component_classes[component_class_name] + else: + raise KeyError( + 'source or filter component class `{}` not found in plugin `{}`'.format( + component_class_name, plugin_name + ) + ) + + return cls(comp_class, params, obj, logging_level) class AutoSourceComponentSpec(_BaseComponentSpec): + # A component spec that does automatic source discovery. _no_obj = object() def __init__(self, input, params=None, obj=_no_obj, logging_level=None): @@ -178,7 +216,7 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): params['inputs'] = comp_inputs comp_specs.append( - ComponentSpec( + ComponentSpec.from_named_plugin_and_component_class( plugin_name, class_name, params=params, @@ -220,11 +258,6 @@ def _get_ns(obj): return int(s * 1e9) -class _CompClsType: - SOURCE = 0 - FILTER = 1 - - class _TraceCollectionMessageIteratorProxySink(bt2_component._UserSinkComponent): def __init__(self, params, msg_list): assert type(msg_list) is list @@ -421,8 +454,8 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): comp_cls = plugin.filter_component_classes['trimmer'] return self._graph.add_component(comp_cls, name, params) - def _get_unique_comp_name(self, comp_spec): - name = '{}-{}'.format(comp_spec.plugin_name, comp_spec.class_name) + def _get_unique_comp_name(self, comp_cls): + name = comp_cls.name comps_and_specs = itertools.chain( self._src_comps_and_specs, self._flt_comps_and_specs ) @@ -433,30 +466,9 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): return name - def _component_spec_class(self, comp_spec, comp_cls_type): - plugin = bt2.find_plugin(comp_spec.plugin_name) - - if plugin is None: - raise ValueError('no such plugin: {}'.format(comp_spec.plugin_name)) - - if comp_cls_type == _CompClsType.SOURCE: - comp_classes = plugin.source_component_classes - else: - comp_classes = plugin.filter_component_classes - - if comp_spec.class_name not in comp_classes: - cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter' - raise ValueError( - 'no such {} component class in "{}" plugin: {}'.format( - cc_type, comp_spec.plugin_name, comp_spec.class_name - ) - ) - - return comp_classes[comp_spec.class_name] - - def _create_comp(self, comp_spec, comp_cls_type): - comp_cls = self._component_spec_class(comp_spec, comp_cls_type) - name = self._get_unique_comp_name(comp_spec) + def _create_comp(self, comp_spec): + comp_cls = comp_spec.component_class + name = self._get_unique_comp_name(comp_cls) comp = self._graph.add_component( comp_cls, name, comp_spec.params, comp_spec.obj, comp_spec.logging_level ) @@ -499,25 +511,24 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._connect_src_comp_port(component, port) def _get_greatest_operative_mip_version(self): - def append_comp_specs_descriptors(descriptors, comp_specs, comp_cls_type): + def append_comp_specs_descriptors(descriptors, comp_specs): for comp_spec in comp_specs: - comp_cls = self._component_spec_class(comp_spec, comp_cls_type) descriptors.append( - bt2.ComponentDescriptor(comp_cls, comp_spec.params, comp_spec.obj) + bt2.ComponentDescriptor( + comp_spec.component_class, comp_spec.params, comp_spec.obj + ) ) descriptors = [] - append_comp_specs_descriptors( - descriptors, self._src_comp_specs, _CompClsType.SOURCE - ) - append_comp_specs_descriptors( - descriptors, self._flt_comp_specs, _CompClsType.FILTER - ) + append_comp_specs_descriptors(descriptors, self._src_comp_specs) + append_comp_specs_descriptors(descriptors, self._flt_comp_specs) if self._stream_intersection_mode: # we also need at least one `flt.utils.trimmer` component - comp_spec = ComponentSpec('utils', 'trimmer') - append_comp_specs_descriptors(descriptors, [comp_spec], _CompClsType.FILTER) + comp_spec = ComponentSpec.from_named_plugin_and_component_class( + 'utils', 'trimmer' + ) + append_comp_specs_descriptors(descriptors, [comp_spec]) mip_version = bt2.get_greatest_operative_mip_version(descriptors) @@ -543,7 +554,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # create extra filter components (chained) for comp_spec in self._flt_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.FILTER) + comp = self._create_comp(comp_spec) self._flt_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # connect the extra filter chain @@ -561,7 +572,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # it does not exist yet (it needs the created component to # exist). for comp_spec in self._src_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.SOURCE) + comp = self._create_comp(comp_spec) self._src_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # Now we connect the ports which exist at this point. We allow diff --git a/tests/bindings/python/bt2/test_trace_collection_message_iterator.py b/tests/bindings/python/bt2/test_trace_collection_message_iterator.py index 16261708..3abbe358 100644 --- a/tests/bindings/python/bt2/test_trace_collection_message_iterator.py +++ b/tests/bindings/python/bt2/test_trace_collection_message_iterator.py @@ -40,28 +40,135 @@ _AUTO_SOURCE_DISCOVERY_PARAMS_LOG_LEVEL_PATH = os.path.join( ) +class _SomeSource( + bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator +): + pass + + +class _SomeFilter( + bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator +): + pass + + +class _SomeSink(bt2._UserSinkComponent): + def _user_consume(self): + pass + + class ComponentSpecTestCase(unittest.TestCase): - def test_create_good_no_params(self): - bt2.ComponentSpec('plugin', 'compcls') + def setUp(self): + # A source CC from a plugin. + self._dmesg_cc = bt2.find_plugin('text').source_component_classes['dmesg'] + assert self._dmesg_cc is not None + + # A filter CC from a plugin. + self._muxer_cc = bt2.find_plugin('utils').filter_component_classes['muxer'] + assert self._muxer_cc is not None + + # A sink CC from a plugin. + self._pretty_cc = bt2.find_plugin('text').sink_component_classes['pretty'] + assert self._pretty_cc is not None + + def test_create_source_from_name(self): + spec = bt2.ComponentSpec.from_named_plugin_and_component_class('text', 'dmesg') + self.assertEqual(spec.component_class.name, 'dmesg') - def test_create_good_with_params(self): - bt2.ComponentSpec('plugin', 'compcls', {'salut': 23}) + def test_create_source_from_plugin(self): + spec = bt2.ComponentSpec(self._dmesg_cc) + self.assertEqual(spec.component_class.name, 'dmesg') - def test_create_good_with_path_params(self): - spec = bt2.ComponentSpec('plugin', 'compcls', 'a path') + def test_create_source_from_user(self): + spec = bt2.ComponentSpec(_SomeSource) + self.assertEqual(spec.component_class.name, '_SomeSource') + + def test_create_filter_from_name(self): + spec = bt2.ComponentSpec.from_named_plugin_and_component_class('utils', 'muxer') + self.assertEqual(spec.component_class.name, 'muxer') + + def test_create_filter_from_object(self): + spec = bt2.ComponentSpec(self._muxer_cc) + self.assertEqual(spec.component_class.name, 'muxer') + + def test_create_sink_from_name(self): + with self.assertRaisesRegex( + KeyError, + 'source or filter component class `pretty` not found in plugin `text`', + ): + bt2.ComponentSpec.from_named_plugin_and_component_class('text', 'pretty') + + def test_create_sink_from_object(self): + with self.assertRaisesRegex( + TypeError, "'_SinkComponentClass' is not a source or filter component class" + ): + bt2.ComponentSpec(self._pretty_cc) + + def test_create_from_object_with_params(self): + spec = bt2.ComponentSpec(self._dmesg_cc, {'salut': 23}) + self.assertEqual(spec.params['salut'], 23) + + def test_create_from_name_with_params(self): + spec = bt2.ComponentSpec.from_named_plugin_and_component_class( + 'text', 'dmesg', {'salut': 23} + ) + self.assertEqual(spec.params['salut'], 23) + + def test_create_from_object_with_path_params(self): + spec = spec = bt2.ComponentSpec(self._dmesg_cc, 'a path') self.assertEqual(spec.params['inputs'], ['a path']) - def test_create_wrong_plugin_name_type(self): - with self.assertRaises(TypeError): - bt2.ComponentSpec(23, 'compcls') + def test_create_from_name_with_path_params(self): + spec = spec = bt2.ComponentSpec.from_named_plugin_and_component_class( + 'text', 'dmesg', 'a path' + ) + self.assertEqual(spec.params['inputs'], ['a path']) - def test_create_wrong_component_class_name_type(self): - with self.assertRaises(TypeError): - bt2.ComponentSpec('plugin', 190) + def test_create_wrong_comp_class_type(self): + with self.assertRaisesRegex( + TypeError, "'int' is not a source or filter component class" + ): + bt2.ComponentSpec(18) + + def test_create_from_name_wrong_plugin_name_type(self): + with self.assertRaisesRegex(TypeError, "'int' is not a 'str' object"): + bt2.ComponentSpec.from_named_plugin_and_component_class(23, 'compcls') + + def test_create_from_name_non_existent_plugin(self): + with self.assertRaisesRegex( + ValueError, "no such plugin: this_plugin_does_not_exist" + ): + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'this_plugin_does_not_exist', 'compcls' + ) + + def test_create_from_name_wrong_component_class_name_type(self): + with self.assertRaisesRegex(TypeError, "'int' is not a 'str' object"): + bt2.ComponentSpec.from_named_plugin_and_component_class('utils', 190) def test_create_wrong_params_type(self): - with self.assertRaises(TypeError): - bt2.ComponentSpec('dwdw', 'compcls', datetime.datetime.now()) + with self.assertRaisesRegex( + TypeError, "cannot create value object from 'datetime' object" + ): + bt2.ComponentSpec(self._dmesg_cc, params=datetime.datetime.now()) + + def test_create_from_name_wrong_params_type(self): + with self.assertRaisesRegex( + TypeError, "cannot create value object from 'datetime' object" + ): + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'text', 'dmesg', datetime.datetime.now() + ) + + def test_create_wrong_log_level_type(self): + with self.assertRaisesRegex(TypeError, "'str' is not an 'int' object"): + bt2.ComponentSpec(self._dmesg_cc, logging_level='banane') + + def test_create_from_name_wrong_log_level_type(self): + with self.assertRaisesRegex(TypeError, "'str' is not an 'int' object"): + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'text', 'dmesg', logging_level='banane' + ) # Return a map, msg type -> number of messages of this type. @@ -80,47 +187,73 @@ def _count_msgs_by_type(msgs): class TraceCollectionMessageIteratorTestCase(unittest.TestCase): def test_create_wrong_stream_intersection_mode_type(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] with self.assertRaises(TypeError): bt2.TraceCollectionMessageIterator(specs, stream_intersection_mode=23) def test_create_wrong_begin_type(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] with self.assertRaises(TypeError): bt2.TraceCollectionMessageIterator(specs, begin='hi') def test_create_wrong_end_type(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] with self.assertRaises(TypeError): bt2.TraceCollectionMessageIterator(specs, begin='lel') - def test_create_no_such_plugin(self): - specs = [bt2.ComponentSpec('77', '101', _3EVENTS_INTERSECT_TRACE_PATH)] - - with self.assertRaises(ValueError): - bt2.TraceCollectionMessageIterator(specs) - def test_create_begin_s(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] bt2.TraceCollectionMessageIterator(specs, begin=19457.918232) def test_create_end_s(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] bt2.TraceCollectionMessageIterator(specs, end=123.12312) def test_create_begin_datetime(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] bt2.TraceCollectionMessageIterator(specs, begin=datetime.datetime.now()) def test_create_end_datetime(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] bt2.TraceCollectionMessageIterator(specs, end=datetime.datetime.now()) def test_iter_no_intersection(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] msg_iter = bt2.TraceCollectionMessageIterator(specs) msgs = list(msg_iter) self.assertEqual(len(msgs), 28) @@ -129,7 +262,9 @@ class TraceCollectionMessageIteratorTestCase(unittest.TestCase): # Same as the above, but we pass a single spec instead of a spec list. def test_iter_specs_not_list(self): - spec = bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH) + spec = bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) msg_iter = bt2.TraceCollectionMessageIterator(spec) msgs = list(msg_iter) self.assertEqual(len(msgs), 28) @@ -137,14 +272,22 @@ class TraceCollectionMessageIteratorTestCase(unittest.TestCase): self.assertEqual(hist[bt2._EventMessage], 8) def test_iter_custom_filter(self): - src_spec = bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH) - flt_spec = bt2.ComponentSpec('utils', 'trimmer', {'end': '13515309.000000075'}) + src_spec = bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + flt_spec = bt2.ComponentSpec.from_named_plugin_and_component_class( + 'utils', 'trimmer', {'end': '13515309.000000075'} + ) msg_iter = bt2.TraceCollectionMessageIterator(src_spec, flt_spec) hist = _count_msgs_by_type(msg_iter) self.assertEqual(hist[bt2._EventMessage], 5) def test_iter_intersection(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] msg_iter = bt2.TraceCollectionMessageIterator( specs, stream_intersection_mode=True ) @@ -154,13 +297,19 @@ class TraceCollectionMessageIteratorTestCase(unittest.TestCase): self.assertEqual(hist[bt2._EventMessage], 3) def test_iter_intersection_no_inputs_param(self): - specs = [bt2.ComponentSpec('text', 'dmesg', {'read-from-stdin': True})] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'text', 'dmesg', {'read-from-stdin': True} + ) + ] with self.assertRaises(ValueError): bt2.TraceCollectionMessageIterator(specs, stream_intersection_mode=True) def test_iter_no_intersection_two_traces(self): - spec = bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH) + spec = bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) specs = [spec, spec] msg_iter = bt2.TraceCollectionMessageIterator(specs) msgs = list(msg_iter) @@ -169,13 +318,21 @@ class TraceCollectionMessageIteratorTestCase(unittest.TestCase): self.assertEqual(hist[bt2._EventMessage], 16) def test_iter_no_intersection_begin(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] msg_iter = bt2.TraceCollectionMessageIterator(specs, begin=13515309.000000023) hist = _count_msgs_by_type(msg_iter) self.assertEqual(hist[bt2._EventMessage], 6) def test_iter_no_intersection_end(self): - specs = [bt2.ComponentSpec('ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH)] + specs = [ + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _3EVENTS_INTERSECT_TRACE_PATH + ) + ] msg_iter = bt2.TraceCollectionMessageIterator(specs, end=13515309.000000075) hist = _count_msgs_by_type(msg_iter) self.assertEqual(hist[bt2._EventMessage], 5) @@ -207,7 +364,9 @@ class TraceCollectionMessageIteratorTestCase(unittest.TestCase): [ _3EVENTS_INTERSECT_TRACE_PATH, bt2.AutoSourceComponentSpec(_SEQUENCE_TRACE_PATH), - bt2.ComponentSpec('ctf', 'fs', _NOINTERSECT_TRACE_PATH), + bt2.ComponentSpec.from_named_plugin_and_component_class( + 'ctf', 'fs', _NOINTERSECT_TRACE_PATH + ), ] ) msgs = list(msg_iter) -- 2.34.1