X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=src%2Fbindings%2Fpython%2Fbt2%2Fbt2%2Ftrace_collection_message_iterator.py;h=4e7347a92a01128a6265f4b2328739540d3044c5;hb=30947af01a064c13fd12c8faf8f13e3d9fd8087f;hp=56b48da9ff8061fe4a96e70870495ec711c313a2;hpb=f3c9a159782f70dbd0e5dedb37e4a1ef8a6d304e;p=babeltrace.git diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index 56b48da9..4e7347a9 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -24,7 +24,6 @@ from bt2 import utils, native_bt import bt2 import itertools from bt2 import message_iterator as bt2_message_iterator -from bt2 import logging as bt2_logging from bt2 import port as bt2_port from bt2 import component as bt2_component from bt2 import value as bt2_value @@ -39,6 +38,8 @@ _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) class _BaseComponentSpec: + # Base for any component spec that can be passed to + # TraceCollectionMessageIterator. def __init__(self, params, obj, logging_level): if logging_level is not None: utils._check_log_level(logging_level) @@ -61,35 +62,71 @@ class _BaseComponentSpec: class ComponentSpec(_BaseComponentSpec): + # A component spec with a specific component class. def __init__( self, - plugin_name, - class_name, + component_class, params=None, obj=None, - logging_level=bt2_logging.LoggingLevel.NONE, + logging_level=bt2.LoggingLevel.NONE, ): if type(params) is str: params = {'inputs': [params]} super().__init__(params, obj, logging_level) - utils._check_str(plugin_name) - utils._check_str(class_name) + is_cc_object = isinstance( + component_class, (bt2._SourceComponentClass, bt2._FilterComponentClass) + ) + is_user_cc_type = isinstance( + component_class, bt2_component._UserComponentType + ) and issubclass( + component_class, (bt2._UserSourceComponent, bt2._UserFilterComponent) + ) + + if not is_cc_object and not is_user_cc_type: + raise TypeError( + "'{}' is not a source or filter component class".format( + component_class.__class__.__name__ + ) + ) - self._plugin_name = plugin_name - self._class_name = class_name + self._component_class = component_class @property - def plugin_name(self): - return self._plugin_name + def component_class(self): + return self._component_class - @property - def class_name(self): - return self._class_name + @classmethod + def from_named_plugin_and_component_class( + cls, + plugin_name, + component_class_name, + params=None, + obj=None, + logging_level=bt2.LoggingLevel.NONE, + ): + plugin = bt2.find_plugin(plugin_name) + + if plugin is None: + raise ValueError('no such plugin: {}'.format(plugin_name)) + + if component_class_name in plugin.source_component_classes: + comp_class = plugin.source_component_classes[component_class_name] + elif component_class_name in plugin.filter_component_classes: + comp_class = plugin.filter_component_classes[component_class_name] + else: + raise KeyError( + 'source or filter component class `{}` not found in plugin `{}`'.format( + component_class_name, plugin_name + ) + ) + + return cls(comp_class, params, obj, logging_level) class AutoSourceComponentSpec(_BaseComponentSpec): + # A component spec that does automatic source discovery. _no_obj = object() def __init__(self, input, params=None, obj=_no_obj, logging_level=None): @@ -130,6 +167,8 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): comp_specs_raw = res['results'] assert type(comp_specs_raw) == bt2.ArrayValue + used_input_indices = set() + for comp_spec_raw in comp_specs_raw: assert type(comp_spec_raw) == bt2.ArrayValue assert len(comp_spec_raw) == 4 @@ -171,10 +210,12 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): if orig_spec.obj is not AutoSourceComponentSpec._no_obj: obj = orig_spec.obj + used_input_indices.add(int(idx)) + params['inputs'] = comp_inputs comp_specs.append( - ComponentSpec( + ComponentSpec.from_named_plugin_and_component_class( plugin_name, class_name, params=params, @@ -183,6 +224,17 @@ def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): ) ) + if len(used_input_indices) != len(inputs): + unused_input_indices = set(range(len(inputs))) - used_input_indices + unused_input_indices = sorted(unused_input_indices) + unused_inputs = [str(inputs[x]) for x in unused_input_indices] + + msg = ( + 'Some auto source component specs did not produce any component: ' + + ', '.join(unused_inputs) + ) + raise RuntimeError(msg) + return comp_specs @@ -205,11 +257,6 @@ def _get_ns(obj): return int(s * 1e9) -class _CompClsType: - SOURCE = 0 - FILTER = 1 - - class _TraceCollectionMessageIteratorProxySink(bt2_component._UserSinkComponent): def __init__(self, params, msg_list): assert type(msg_list) is list @@ -292,6 +339,47 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._build_graph() + def _compute_stream_intersections(self): + # Pre-compute the trimmer range to use for each port in the graph, when + # stream intersection mode is enabled. + self._stream_inter_port_to_range = {} + + for src_comp_and_spec in self._src_comps_and_specs: + try: + inputs = src_comp_and_spec.spec.params['inputs'] + except KeyError as e: + raise ValueError( + 'all source components must be created with an "inputs" parameter in stream intersection mode' + ) from e + + params = {'inputs': inputs} + + # query the port's component for the `babeltrace.trace-info` + # object which contains the range for each stream, from which we can + # compute the intersection of the streams in each trace. + query_exec = bt2.QueryExecutor( + src_comp_and_spec.spec.component_class, 'babeltrace.trace-info', params + ) + trace_infos = query_exec.query() + + for trace_info in trace_infos: + begin = max( + [stream['range-ns']['begin'] for stream in trace_info['streams']] + ) + end = min( + [stream['range-ns']['end'] for stream in trace_info['streams']] + ) + + # Each port associated to this trace will have this computed + # range. + for stream in trace_info['streams']: + # A port name is unique within a component, but not + # necessarily across all components. Use a component + # and port name pair to make it unique across the graph. + port_name = str(stream['port-name']) + key = (src_comp_and_spec.comp.addr, port_name) + self._stream_inter_port_to_range[key] = (begin, end) + def _validate_source_component_specs(self, comp_specs): for comp_spec in comp_specs: if ( @@ -320,49 +408,9 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): return msg def _create_stream_intersection_trimmer(self, component, port): - # find the original parameters specified by the user to create - # this port's component to get the `inputs` parameter - for src_comp_and_spec in self._src_comps_and_specs: - if component == src_comp_and_spec.comp: - break - - try: - inputs = src_comp_and_spec.spec.params['inputs'] - except Exception as e: - raise ValueError( - 'all source components must be created with an "inputs" parameter in stream intersection mode' - ) from e - - params = {'inputs': inputs} - - # query the port's component for the `babeltrace.trace-info` - # object which contains the stream intersection range for each - # exposed trace - query_exec = bt2.QueryExecutor( - src_comp_and_spec.comp.cls, 'babeltrace.trace-info', params - ) - trace_info_res = query_exec.query() - begin = None - end = None - - # find the trace info for this port's trace - try: - for trace_info in trace_info_res: - for stream in trace_info['streams']: - if stream['port-name'] == port.name: - range_ns = trace_info['intersection-range-ns'] - begin = range_ns['begin'] - end = range_ns['end'] - break - except Exception: - pass - - if begin is None or end is None: - raise RuntimeError( - 'cannot find stream intersection range for port "{}"'.format(port.name) - ) - - name = 'trimmer-{}-{}'.format(src_comp_and_spec.comp.name, port.name) + key = (component.addr, port.name) + begin, end = self._stream_inter_port_to_range[key] + name = 'trimmer-{}-{}'.format(component.name, port.name) return self._create_trimmer(begin, end, name) def _create_muxer(self): @@ -406,8 +454,8 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): comp_cls = plugin.filter_component_classes['trimmer'] return self._graph.add_component(comp_cls, name, params) - def _get_unique_comp_name(self, comp_spec): - name = '{}-{}'.format(comp_spec.plugin_name, comp_spec.class_name) + def _get_unique_comp_name(self, comp_cls): + name = comp_cls.name comps_and_specs = itertools.chain( self._src_comps_and_specs, self._flt_comps_and_specs ) @@ -418,27 +466,9 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): return name - def _create_comp(self, comp_spec, comp_cls_type): - plugin = bt2.find_plugin(comp_spec.plugin_name) - - if plugin is None: - raise ValueError('no such plugin: {}'.format(comp_spec.plugin_name)) - - if comp_cls_type == _CompClsType.SOURCE: - comp_classes = plugin.source_component_classes - else: - comp_classes = plugin.filter_component_classes - - if comp_spec.class_name not in comp_classes: - cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter' - raise ValueError( - 'no such {} component class in "{}" plugin: {}'.format( - cc_type, comp_spec.plugin_name, comp_spec.class_name - ) - ) - - comp_cls = comp_classes[comp_spec.class_name] - name = self._get_unique_comp_name(comp_spec) + def _create_comp(self, comp_spec): + comp_cls = comp_spec.component_class + name = self._get_unique_comp_name(comp_cls) comp = self._graph.add_component( comp_cls, name, comp_spec.params, comp_spec.obj, comp_spec.logging_level ) @@ -480,8 +510,36 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._connect_src_comp_port(component, port) + def _get_greatest_operative_mip_version(self): + def append_comp_specs_descriptors(descriptors, comp_specs): + for comp_spec in comp_specs: + descriptors.append( + bt2.ComponentDescriptor( + comp_spec.component_class, comp_spec.params, comp_spec.obj + ) + ) + + descriptors = [] + append_comp_specs_descriptors(descriptors, self._src_comp_specs) + append_comp_specs_descriptors(descriptors, self._flt_comp_specs) + + if self._stream_intersection_mode: + # we also need at least one `flt.utils.trimmer` component + comp_spec = ComponentSpec.from_named_plugin_and_component_class( + 'utils', 'trimmer' + ) + append_comp_specs_descriptors(descriptors, [comp_spec]) + + mip_version = bt2.get_greatest_operative_mip_version(descriptors) + + if mip_version is None: + msg = 'failed to find an operative message interchange protocol version (components are not interoperable)' + raise RuntimeError(msg) + + return mip_version + def _build_graph(self): - self._graph = bt2.Graph() + self._graph = bt2.Graph(self._get_greatest_operative_mip_version()) self._graph.add_port_added_listener(self._graph_port_added) self._muxer_comp = self._create_muxer() @@ -496,7 +554,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # create extra filter components (chained) for comp_spec in self._flt_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.FILTER) + comp = self._create_comp(comp_spec) self._flt_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # connect the extra filter chain @@ -514,9 +572,12 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): # it does not exist yet (it needs the created component to # exist). for comp_spec in self._src_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.SOURCE) + comp = self._create_comp(comp_spec) self._src_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) + if self._stream_intersection_mode: + self._compute_stream_intersections() + # Now we connect the ports which exist at this point. We allow # self._graph_port_added() to automatically connect _new_ ports. self._connect_ports = True