X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=src%2Fbindings%2Fpython%2Fbt2%2Fbt2%2Ftrace_collection_message_iterator.py;h=821b5eae84f2e002e279e05b8ef35c4dacade036;hb=5f2a1585bf407f3f3aa7e63d9041b75390cf8563;hp=1b68d20523dba054a1f6f70177e622e1023196dc;hpb=ce4923b0c7a2de36eba95725334d251e9aa08aad;p=babeltrace.git diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index 1b68d205..821b5eae 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -20,10 +20,14 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from bt2 import utils +from bt2 import utils, native_bt import bt2 import itertools -import bt2.message_iterator +from bt2 import message_iterator as bt2_message_iterator +from bt2 import port as bt2_port +from bt2 import component as bt2_component +from bt2 import value as bt2_value +from bt2 import plugin as bt2_plugin import datetime from collections import namedtuple import numbers @@ -33,41 +37,205 @@ import numbers _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) -class ComponentSpec: +class _BaseComponentSpec: + # Base for any component spec that can be passed to + # TraceCollectionMessageIterator. + def __init__(self, params, obj, logging_level): + if logging_level is not None: + utils._check_log_level(logging_level) + + self._params = bt2.create_value(params) + self._obj = obj + self._logging_level = logging_level + + @property + def params(self): + return self._params + + @property + def obj(self): + return self._obj + + @property + def logging_level(self): + return self._logging_level + + +class ComponentSpec(_BaseComponentSpec): + # A component spec with a specific component class. def __init__( self, + component_class, + params=None, + obj=None, + logging_level=bt2.LoggingLevel.NONE, + ): + if type(params) is str: + params = {'inputs': [params]} + + super().__init__(params, obj, logging_level) + + is_cc_object = isinstance( + component_class, (bt2._SourceComponentClass, bt2._FilterComponentClass) + ) + is_user_cc_type = isinstance( + component_class, bt2_component._UserComponentType + ) and issubclass( + component_class, (bt2._UserSourceComponent, bt2._UserFilterComponent) + ) + + if not is_cc_object and not is_user_cc_type: + raise TypeError( + "'{}' is not a source or filter component class".format( + component_class.__class__.__name__ + ) + ) + + self._component_class = component_class + + @property + def component_class(self): + return self._component_class + + @classmethod + def from_named_plugin_and_component_class( + cls, plugin_name, - class_name, + component_class_name, params=None, - logging_level=bt2.logging.LoggingLevel.NONE, + obj=None, + logging_level=bt2.LoggingLevel.NONE, ): - utils._check_str(plugin_name) - utils._check_str(class_name) - utils._check_log_level(logging_level) - self._plugin_name = plugin_name - self._class_name = class_name - self._logging_level = logging_level + plugin = bt2.find_plugin(plugin_name) - if type(params) is str: - self._params = bt2.create_value({'inputs': [params]}) + if plugin is None: + raise ValueError('no such plugin: {}'.format(plugin_name)) + + if component_class_name in plugin.source_component_classes: + comp_class = plugin.source_component_classes[component_class_name] + elif component_class_name in plugin.filter_component_classes: + comp_class = plugin.filter_component_classes[component_class_name] else: - self._params = bt2.create_value(params) + raise KeyError( + 'source or filter component class `{}` not found in plugin `{}`'.format( + component_class_name, plugin_name + ) + ) - @property - def plugin_name(self): - return self._plugin_name + return cls(comp_class, params, obj, logging_level) - @property - def class_name(self): - return self._class_name - @property - def logging_level(self): - return self._logging_level +class AutoSourceComponentSpec(_BaseComponentSpec): + # A component spec that does automatic source discovery. + _no_obj = object() + + def __init__(self, input, params=None, obj=_no_obj, logging_level=None): + super().__init__(params, obj, logging_level) + self._input = input @property - def params(self): - return self._params + def input(self): + return self._input + + +def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): + # Transform a list of `AutoSourceComponentSpec` in a list of `ComponentSpec` + # using the automatic source discovery mechanism. + inputs = bt2.ArrayValue([spec.input for spec in auto_source_comp_specs]) + + if plugin_set is None: + plugin_set = bt2.find_plugins() + else: + utils._check_type(plugin_set, bt2_plugin._PluginSet) + + res_ptr = native_bt.bt2_auto_discover_source_components( + inputs._ptr, plugin_set._ptr + ) + + if res_ptr is None: + raise bt2._MemoryError('cannot auto discover source components') + + res = bt2_value._create_from_ptr(res_ptr) + + assert type(res) == bt2.MapValue + assert 'status' in res + + status = res['status'] + utils._handle_func_status(status, 'cannot auto-discover source components') + + comp_specs = [] + comp_specs_raw = res['results'] + assert type(comp_specs_raw) == bt2.ArrayValue + + used_input_indices = set() + + for comp_spec_raw in comp_specs_raw: + assert type(comp_spec_raw) == bt2.ArrayValue + assert len(comp_spec_raw) == 4 + + plugin_name = comp_spec_raw[0] + assert type(plugin_name) == bt2.StringValue + plugin_name = str(plugin_name) + + class_name = comp_spec_raw[1] + assert type(class_name) == bt2.StringValue + class_name = str(class_name) + + comp_inputs = comp_spec_raw[2] + assert type(comp_inputs) == bt2.ArrayValue + + comp_orig_indices = comp_spec_raw[3] + assert type(comp_orig_indices) + + params = bt2.MapValue() + logging_level = bt2.LoggingLevel.NONE + obj = None + + # Compute `params` for this component by piling up params given to all + # AutoSourceComponentSpec objects that contributed in the instantiation + # of this component. + # + # The effective log level for a component is the last one specified + # across the AutoSourceComponentSpec that contributed in its + # instantiation. + for idx in comp_orig_indices: + orig_spec = auto_source_comp_specs[idx] + + if orig_spec.params is not None: + params.update(orig_spec.params) + + if orig_spec.logging_level is not None: + logging_level = orig_spec.logging_level + + if orig_spec.obj is not AutoSourceComponentSpec._no_obj: + obj = orig_spec.obj + + used_input_indices.add(int(idx)) + + params['inputs'] = comp_inputs + + comp_specs.append( + ComponentSpec.from_named_plugin_and_component_class( + plugin_name, + class_name, + params=params, + obj=obj, + logging_level=logging_level, + ) + ) + + if len(used_input_indices) != len(inputs): + unused_input_indices = set(range(len(inputs))) - used_input_indices + unused_input_indices = sorted(unused_input_indices) + unused_inputs = [str(inputs[x]) for x in unused_input_indices] + + msg = ( + 'Some auto source component specs did not produce any component: ' + + ', '.join(unused_inputs) + ) + raise RuntimeError(msg) + + return comp_specs # datetime.datetime or integral to nanoseconds @@ -89,12 +257,23 @@ def _get_ns(obj): return int(s * 1e9) -class _CompClsType: - SOURCE = 0 - FILTER = 1 +class _TraceCollectionMessageIteratorProxySink(bt2_component._UserSinkComponent): + def __init__(self, params, msg_list): + assert type(msg_list) is list + self._msg_list = msg_list + self._add_input_port('in') + + def _user_graph_is_configured(self): + self._msg_iter = self._create_input_port_message_iterator( + self._input_ports['in'] + ) + + def _user_consume(self): + assert self._msg_list[0] is None + self._msg_list[0] = next(self._msg_iter) -class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): +class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): def __init__( self, source_component_specs, @@ -102,21 +281,54 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): stream_intersection_mode=False, begin=None, end=None, + plugin_set=None, ): utils._check_bool(stream_intersection_mode) self._stream_intersection_mode = stream_intersection_mode self._begin_ns = _get_ns(begin) self._end_ns = _get_ns(end) - - if type(source_component_specs) is ComponentSpec: + self._msg_list = [None] + + # If a single item is provided, convert to a list. + if type(source_component_specs) in ( + ComponentSpec, + AutoSourceComponentSpec, + str, + ): source_component_specs = [source_component_specs] + # Convert any string to an AutoSourceComponentSpec. + def str_to_auto(item): + if type(item) is str: + item = AutoSourceComponentSpec(item) + + return item + + source_component_specs = [str_to_auto(s) for s in source_component_specs] + if type(filter_component_specs) is ComponentSpec: filter_component_specs = [filter_component_specs] elif filter_component_specs is None: filter_component_specs = [] - self._src_comp_specs = source_component_specs + self._validate_source_component_specs(source_component_specs) + self._validate_filter_component_specs(filter_component_specs) + + # Pass any `ComponentSpec` instance as-is. + self._src_comp_specs = [ + spec for spec in source_component_specs if type(spec) is ComponentSpec + ] + + # Convert any `AutoSourceComponentSpec` in concrete `ComponentSpec` instances. + auto_src_comp_specs = [ + spec + for spec in source_component_specs + if type(spec) is AutoSourceComponentSpec + ] + self._src_comp_specs += _auto_discover_source_component_specs( + auto_src_comp_specs, plugin_set + ) + self._flt_comp_specs = filter_component_specs self._next_suffix = 1 self._connect_ports = False @@ -125,11 +337,58 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._src_comps_and_specs = [] self._flt_comps_and_specs = [] - self._validate_component_specs(source_component_specs) - self._validate_component_specs(filter_component_specs) self._build_graph() - def _validate_component_specs(self, comp_specs): + def _compute_stream_intersections(self): + # Pre-compute the trimmer range to use for each port in the graph, when + # stream intersection mode is enabled. + self._stream_inter_port_to_range = {} + + for src_comp_and_spec in self._src_comps_and_specs: + # Query the port's component for the `babeltrace.trace-infos` + # object which contains the range for each stream, from which we can + # compute the intersection of the streams in each trace. + query_exec = bt2.QueryExecutor( + src_comp_and_spec.spec.component_class, + 'babeltrace.trace-infos', + src_comp_and_spec.spec.params, + ) + trace_infos = query_exec.query() + + for trace_info in trace_infos: + begin = max( + [ + stream['range-ns']['begin'] + for stream in trace_info['stream-infos'] + ] + ) + end = min( + [stream['range-ns']['end'] for stream in trace_info['stream-infos']] + ) + + # Each port associated to this trace will have this computed + # range. + for stream in trace_info['stream-infos']: + # A port name is unique within a component, but not + # necessarily across all components. Use a component + # and port name pair to make it unique across the graph. + port_name = str(stream['port-name']) + key = (src_comp_and_spec.comp.addr, port_name) + self._stream_inter_port_to_range[key] = (begin, end) + + def _validate_source_component_specs(self, comp_specs): + for comp_spec in comp_specs: + if ( + type(comp_spec) is not ComponentSpec + and type(comp_spec) is not AutoSourceComponentSpec + ): + raise TypeError( + '"{}" object is not a ComponentSpec or AutoSourceComponentSpec'.format( + type(comp_spec) + ) + ) + + def _validate_filter_component_specs(self, comp_specs): for comp_spec in comp_specs: if type(comp_spec) is not ComponentSpec: raise TypeError( @@ -137,52 +396,17 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): ) def __next__(self): - return next(self._msg_iter) + assert self._msg_list[0] is None + self._graph.run_once() + msg = self._msg_list[0] + assert msg is not None + self._msg_list[0] = None + return msg def _create_stream_intersection_trimmer(self, component, port): - # find the original parameters specified by the user to create - # this port's component to get the `inputs` parameter - for src_comp_and_spec in self._src_comps_and_specs: - if component == src_comp_and_spec.comp: - break - - try: - inputs = src_comp_and_spec.spec.params['inputs'] - except Exception as e: - raise ValueError( - 'all source components must be created with an "inputs" parameter in stream intersection mode' - ) from e - - params = {'inputs': inputs} - - # query the port's component for the `trace-info` object which - # contains the stream intersection range for each exposed - # trace - query_exec = bt2.QueryExecutor() - trace_info_res = query_exec.query( - src_comp_and_spec.comp.cls, 'trace-info', params - ) - begin = None - end = None - - # find the trace info for this port's trace - try: - for trace_info in trace_info_res: - for stream in trace_info['streams']: - if stream['port-name'] == port.name: - range_ns = trace_info['intersection-range-ns'] - begin = range_ns['begin'] - end = range_ns['end'] - break - except Exception: - pass - - if begin is None or end is None: - raise RuntimeError( - 'cannot find stream intersection range for port "{}"'.format(port.name) - ) - - name = 'trimmer-{}-{}'.format(src_comp_and_spec.comp.name, port.name) + key = (component.addr, port.name) + begin, end = self._stream_inter_port_to_range[key] + name = 'trimmer-{}-{}'.format(component.name, port.name) return self._create_trimmer(begin, end, name) def _create_muxer(self): @@ -226,8 +450,8 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): comp_cls = plugin.filter_component_classes['trimmer'] return self._graph.add_component(comp_cls, name, params) - def _get_unique_comp_name(self, comp_spec): - name = '{}-{}'.format(comp_spec.plugin_name, comp_spec.class_name) + def _get_unique_comp_name(self, comp_cls): + name = comp_cls.name comps_and_specs = itertools.chain( self._src_comps_and_specs, self._flt_comps_and_specs ) @@ -238,29 +462,11 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): return name - def _create_comp(self, comp_spec, comp_cls_type): - plugin = bt2.find_plugin(comp_spec.plugin_name) - - if plugin is None: - raise ValueError('no such plugin: {}'.format(comp_spec.plugin_name)) - - if comp_cls_type == _CompClsType.SOURCE: - comp_classes = plugin.source_component_classes - else: - comp_classes = plugin.filter_component_classes - - if comp_spec.class_name not in comp_classes: - cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter' - raise ValueError( - 'no such {} component class in "{}" plugin: {}'.format( - cc_type, comp_spec.plugin_name, comp_spec.class_name - ) - ) - - comp_cls = comp_classes[comp_spec.class_name] - name = self._get_unique_comp_name(comp_spec) + def _create_comp(self, comp_spec): + comp_cls = comp_spec.component_class + name = self._get_unique_comp_name(comp_cls) comp = self._graph.add_component( - comp_cls, name, comp_spec.params, comp_spec.logging_level + comp_cls, name, comp_spec.params, comp_spec.obj, comp_spec.logging_level ) return comp @@ -291,7 +497,7 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): if not self._connect_ports: return - if type(port) is bt2.port._InputPort: + if type(port) is bt2_port._InputPort: return if component not in [comp.comp for comp in self._src_comps_and_specs]: @@ -300,8 +506,36 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._connect_src_comp_port(component, port) + def _get_greatest_operative_mip_version(self): + def append_comp_specs_descriptors(descriptors, comp_specs): + for comp_spec in comp_specs: + descriptors.append( + bt2.ComponentDescriptor( + comp_spec.component_class, comp_spec.params, comp_spec.obj + ) + ) + + descriptors = [] + append_comp_specs_descriptors(descriptors, self._src_comp_specs) + append_comp_specs_descriptors(descriptors, self._flt_comp_specs) + + if self._stream_intersection_mode: + # we also need at least one `flt.utils.trimmer` component + comp_spec = ComponentSpec.from_named_plugin_and_component_class( + 'utils', 'trimmer' + ) + append_comp_specs_descriptors(descriptors, [comp_spec]) + + mip_version = bt2.get_greatest_operative_mip_version(descriptors) + + if mip_version is None: + msg = 'failed to find an operative message interchange protocol version (components are not interoperable)' + raise RuntimeError(msg) + + return mip_version + def _build_graph(self): - self._graph = bt2.Graph() + self._graph = bt2.Graph(self._get_greatest_operative_mip_version()) self._graph.add_port_added_listener(self._graph_port_added) self._muxer_comp = self._create_muxer() @@ -310,21 +544,21 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._graph.connect_ports( self._muxer_comp.output_ports['out'], trimmer_comp.input_ports['in'] ) - msg_iter_port = trimmer_comp.output_ports['out'] + last_flt_out_port = trimmer_comp.output_ports['out'] else: - msg_iter_port = self._muxer_comp.output_ports['out'] + last_flt_out_port = self._muxer_comp.output_ports['out'] # create extra filter components (chained) for comp_spec in self._flt_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.FILTER) + comp = self._create_comp(comp_spec) self._flt_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) # connect the extra filter chain for comp_and_spec in self._flt_comps_and_specs: in_port = list(comp_and_spec.comp.input_ports.values())[0] out_port = list(comp_and_spec.comp.output_ports.values())[0] - self._graph.connect_ports(msg_iter_port, in_port) - msg_iter_port = out_port + self._graph.connect_ports(last_flt_out_port, in_port) + last_flt_out_port = out_port # Here we create the components, self._graph_port_added() is # called when they add ports, but the callback returns early @@ -334,9 +568,12 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): # it does not exist yet (it needs the created component to # exist). for comp_spec in self._src_comp_specs: - comp = self._create_comp(comp_spec, _CompClsType.SOURCE) + comp = self._create_comp(comp_spec) self._src_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec)) + if self._stream_intersection_mode: + self._compute_stream_intersections() + # Now we connect the ports which exist at this point. We allow # self._graph_port_added() to automatically connect _new_ ports. self._connect_ports = True @@ -353,5 +590,12 @@ class TraceCollectionMessageIterator(bt2.message_iterator._MessageIterator): self._connect_src_comp_port(comp_and_spec.comp, out_port) - # create this trace collection iterator's message iterator - self._msg_iter = self._graph.create_output_port_message_iterator(msg_iter_port) + # Add the proxy sink, passing our message list to share consumed + # messages with this trace collection message iterator. + sink = self._graph.add_component( + _TraceCollectionMessageIteratorProxySink, 'proxy-sink', obj=self._msg_list + ) + sink_in_port = sink.input_ports['in'] + + # connect last filter to proxy sink + self._graph.connect_ports(last_flt_out_port, sink_in_port)