X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=src%2Fbindings%2Fpython%2Fbt2%2Fbt2%2Ftrace_collection_message_iterator.py;h=74417776d07d15170258b931f075cb1516779efc;hb=ed47ebcd99fd2fb58770dc733d9318b57858de68;hp=56b48da9ff8061fe4a96e70870495ec711c313a2;hpb=f3c9a159782f70dbd0e5dedb37e4a1ef8a6d304e;p=babeltrace.git diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index 56b48da9..74417776 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -24,7 +24,6 @@ from bt2 import utils, native_bt import bt2 import itertools from bt2 import message_iterator as bt2_message_iterator -from bt2 import logging as bt2_logging from bt2 import port as bt2_port from bt2 import component as bt2_component from bt2 import value as bt2_value @@ -39,6 +38,8 @@ _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) class _BaseComponentSpec: + # Base for any component spec that can be passed to + # TraceCollectionMessageIterator. def __init__(self, params, obj, logging_level): if logging_level is not None: utils._check_log_level(logging_level) @@ -61,35 +62,71 @@ class _BaseComponentSpec: class ComponentSpec(_BaseComponentSpec): + # A component spec with a specific component class. def __init__( self, - plugin_name, - class_name, + component_class, params=None, obj=None, - logging_level=bt2_logging.LoggingLevel.NONE, + logging_level=bt2.LoggingLevel.NONE, ): if type(params) is str: params = {'inputs': [params]} super().__init__(params, obj, logging_level) - utils._check_str(plugin_name) - utils._check_str(class_name) + is_cc_object = isinstance( + component_class, (bt2._SourceComponentClass, bt2._FilterComponentClass) + ) + is_user_cc_type = isinstance( + component_class, bt2_component._UserComponentType + ) and issubclass( + component_class, (bt2._UserSourceComponent, bt2._UserFilterComponent) + ) - self._plugin_name = plugin_name - self._class_name = class_name + if not is_cc_object and not is_user_cc_type: + raise TypeError( + "'{}' is not a source or filter component class".format( + component_class.__class__.__name__ + ) + ) - @property - def plugin_name(self): - return self._plugin_name + self._component_class = component_class @property - def class_name(self): - return self._class_name + def component_class(self): + return self._component_class + + @classmethod + def from_named_plugin_and_component_class( + cls, + plugin_name, + component_class_name, + params=None, + obj=None, + logging_level=bt2.LoggingLevel.NONE, + ): + plugin = bt2.find_plugin(plugin_name) + + if plugin is None: + raise ValueError('no such plugin: {}'.format(plugin_name)) + + if component_class_name in plugin.source_component_classes: + comp_class = plugin.source_component_classes[component_class_name] + elif component_class_name in plugin.filter_component_classes: + comp_class = plugin.filter_component_classes[component_class_name] + else: + raise KeyError( + 'source or filter component class `{}` not found in plugin `{}`'.format( + component_class_name, plugin_name + ) + ) + + return cls(comp_class, params, obj, logging_level) class AutoSourceComponentSpec(_BaseComponentSpec): + # A component spec that does automatic source discovery. _no_obj = object() def __init__(self, input, params=None, obj=_no_obj, logging_level=None): @@ -130,6 +167,8 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): comp_specs_raw = res['results'] assert type(comp_specs_raw) == bt2.ArrayValue + used_input_indices = set() + for comp_spec_raw in comp_specs_raw: assert type(comp_spec_raw) == bt2.ArrayValue assert len(comp_spec_raw) == 4 @@ -171,10 +210,12 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): if orig_spec.obj is not AutoSourceComponentSpec._no_obj: obj = orig_spec.obj + used_input_indices.add(int(idx)) + params['inputs'] = comp_inputs comp_specs.append( - ComponentSpec( + ComponentSpec.from_named_plugin_and_component_class( plugin_name, class_name, params=params, @@ -183,6 +224,17 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): ) ) + if len(used_input_indices) != len(inputs): + unused_input_indices = set(range(len(inputs))) - used_input_indices + unused_input_indices = sorted(unused_input_indices) + unused_inputs = [str(inputs[x]) for x in unused_input_indices] + + msg = ( + 'Some auto source component specs did not produce any component: ' + + ', '.join(unused_inputs) + ) + raise RuntimeError(msg) + return comp_specs @@ -205,11 +257,6 @@ def _get_ns(obj): return int(s * 1e9) -class _CompClsType: - SOURCE = 0 - FILTER = 1 - - class _TraceCollectionMessageIteratorProxySink(bt2_component._UserSinkComponent): def __init__(self, params, msg_list): assert type(msg_list) is list @@ -406,8 +453,8 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): comp_cls = plugin.filter_component_classes['trimmer'] return self._graph.add_component(comp_cls, name, params) - def _get_unique_comp_name(self, comp_spec): - name = '{}-{}'.format(comp_spec.plugin_name, comp_spec.class_name) + def _get_unique_comp_name(self, comp_cls): + name = comp_cls.name comps_and_specs = itertools.chain( self._src_comps_and_specs, self._flt_comps_and_specs ) @@ -418,27 +465,9 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): return name - def _create_comp(self, comp_spec, comp_cls_type): - plugin = bt2.find_plugin(comp_spec.plugin_name) - - if plugin is None: - raise ValueError('no such plugin: {}'.format(comp_spec.plugin_name)) - - if comp_cls_type == _CompClsType.SOURCE: - comp_classes = plugin.source_component_classes - else: - comp_classes = plugin.filter_component_classes - - if comp_spec.class_name not in comp_classes: - cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter' - raise ValueError( - 'no such {} component class in "{}" plugin: {}'.format( - cc_type, comp_spec.plugin_name, comp_spec.class_name - ) - ) - - comp_cls = comp_classes[comp_spec.class_name] - name = self._get_unique_comp_name(comp_spec) + def _create_comp(self, comp_spec): + comp_cls = comp_spec.component_class + name = self._get_unique_comp_name(comp_cls) comp = self._graph.add_component( comp_cls, name, comp_spec.params, comp_spec.obj, comp_spec.logging_level ) @@ -480,8 +509,36 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._connect_src_comp_port(component, port) + def _get_greatest_operative_mip_version(self): + def append_comp_specs_descriptors(descriptors, comp_specs): + for comp_spec in comp_specs: + descriptors.append( + bt2.ComponentDescriptor( + comp_spec.component_class, comp_spec.params, comp_spec.obj + ) + ) + + descriptors = [] + append_comp_specs_descriptors(descriptors, self._src_comp_specs) + append_comp_specs_descriptors(descriptors, self._flt_comp_specs) + + if self._stream_intersection_mode: + # we also need at least one `flt.utils.trimmer` component + comp_spec = ComponentSpec.from_named_plugin_and_component_class( + 'utils', 'trimmer' + ) + append_comp_specs_descriptors(descriptors, [comp_spec]) + + mip_version = bt2.get_greatest_operative_mip_version(descriptors) + + if mip_version is None: + msg = 'failed to find an operative message interchange protocol version (components are not interoperable)' + raise RuntimeError(msg) + + return mip_version + def _build_graph(self): - self._graph = bt2.Graph() + self._graph = bt2.Graph(self._get_greatest_operative_mip_version()) self._graph.add_port_added_listener(self._graph_port_added) self._muxer_comp = self._create_muxer() @@ -496,7 +553,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # create extra filter components (chained) for comp_spec in self._flt_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.FILTER) + comp = self._create_comp(comp_spec) self._flt_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # connect the extra filter chain @@ -514,7 +571,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # it does not exist yet (it needs the created component to # exist). for comp_spec in self._src_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.SOURCE) + comp = self._create_comp(comp_spec) self._src_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # Now we connect the ports which exist at this point. We allow