bt2: add auto source discovery support to TraceCollectionMessageIterator
[babeltrace.git] / src / bindings / python / bt2 / bt2 / trace_collection_message_iterator.py
index 9c931fe2d4be568a5362cf1c020be87f8e76e00b..56b48da9ff8061fe4a96e70870495ec711c313a2 100644 (file)
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 # THE SOFTWARE.
 
-from bt2 import utils
+from bt2 import utils, native_bt
 import bt2
 import itertools
 from bt2 import message_iterator as bt2_message_iterator
 from bt2 import logging as bt2_logging
 from bt2 import port as bt2_port
 from bt2 import component as bt2_component
+from bt2 import value as bt2_value
+from bt2 import plugin as bt2_plugin
 import datetime
 from collections import namedtuple
 import numbers
@@ -36,7 +38,29 @@ import numbers
 _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec'])
 
 
-class ComponentSpec:
+class _BaseComponentSpec:
+    def __init__(self, params, obj, logging_level):
+        if logging_level is not None:
+            utils._check_log_level(logging_level)
+
+        self._params = bt2.create_value(params)
+        self._obj = obj
+        self._logging_level = logging_level
+
+    @property
+    def params(self):
+        return self._params
+
+    @property
+    def obj(self):
+        return self._obj
+
+    @property
+    def logging_level(self):
+        return self._logging_level
+
+
+class ComponentSpec(_BaseComponentSpec):
     def __init__(
         self,
         plugin_name,
@@ -45,18 +69,16 @@ class ComponentSpec:
         obj=None,
         logging_level=bt2_logging.LoggingLevel.NONE,
     ):
+        if type(params) is str:
+            params = {'inputs': [params]}
+
+        super().__init__(params, obj, logging_level)
+
         utils._check_str(plugin_name)
         utils._check_str(class_name)
-        utils._check_log_level(logging_level)
+
         self._plugin_name = plugin_name
         self._class_name = class_name
-        self._logging_level = logging_level
-        self._obj = obj
-
-        if type(params) is str:
-            self._params = bt2.create_value({'inputs': [params]})
-        else:
-            self._params = bt2.create_value(params)
 
     @property
     def plugin_name(self):
@@ -66,17 +88,102 @@ class ComponentSpec:
     def class_name(self):
         return self._class_name
 
-    @property
-    def logging_level(self):
-        return self._logging_level
 
-    @property
-    def params(self):
-        return self._params
+class AutoSourceComponentSpec(_BaseComponentSpec):
+    _no_obj = object()
+
+    def __init__(self, input, params=None, obj=_no_obj, logging_level=None):
+        super().__init__(params, obj, logging_level)
+        self._input = input
 
     @property
-    def obj(self):
-        return self._obj
+    def input(self):
+        return self._input
+
+
+def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set):
+    # Transform a list of `AutoSourceComponentSpec` in a list of `ComponentSpec`
+    # using the automatic source discovery mechanism.
+    inputs = bt2.ArrayValue([spec.input for spec in auto_source_comp_specs])
+
+    if plugin_set is None:
+        plugin_set = bt2.find_plugins()
+    else:
+        utils._check_type(plugin_set, bt2_plugin._PluginSet)
+
+    res_ptr = native_bt.bt2_auto_discover_source_components(
+        inputs._ptr, plugin_set._ptr
+    )
+
+    if res_ptr is None:
+        raise bt2._MemoryError('cannot auto discover source components')
+
+    res = bt2_value._create_from_ptr(res_ptr)
+
+    assert type(res) == bt2.MapValue
+    assert 'status' in res
+
+    status = res['status']
+    utils._handle_func_status(status, 'cannot auto-discover source components')
+
+    comp_specs = []
+    comp_specs_raw = res['results']
+    assert type(comp_specs_raw) == bt2.ArrayValue
+
+    for comp_spec_raw in comp_specs_raw:
+        assert type(comp_spec_raw) == bt2.ArrayValue
+        assert len(comp_spec_raw) == 4
+
+        plugin_name = comp_spec_raw[0]
+        assert type(plugin_name) == bt2.StringValue
+        plugin_name = str(plugin_name)
+
+        class_name = comp_spec_raw[1]
+        assert type(class_name) == bt2.StringValue
+        class_name = str(class_name)
+
+        comp_inputs = comp_spec_raw[2]
+        assert type(comp_inputs) == bt2.ArrayValue
+
+        comp_orig_indices = comp_spec_raw[3]
+        assert type(comp_orig_indices)
+
+        params = bt2.MapValue()
+        logging_level = bt2.LoggingLevel.NONE
+        obj = None
+
+        # Compute `params` for this component by piling up params given to all
+        # AutoSourceComponentSpec objects that contributed in the instantiation
+        # of this component.
+        #
+        # The effective log level for a component is the last one specified
+        # across the AutoSourceComponentSpec that contributed in its
+        # instantiation.
+        for idx in comp_orig_indices:
+            orig_spec = auto_source_comp_specs[idx]
+
+            if orig_spec.params is not None:
+                params.update(orig_spec.params)
+
+            if orig_spec.logging_level is not None:
+                logging_level = orig_spec.logging_level
+
+            if orig_spec.obj is not AutoSourceComponentSpec._no_obj:
+                obj = orig_spec.obj
+
+        params['inputs'] = comp_inputs
+
+        comp_specs.append(
+            ComponentSpec(
+                plugin_name,
+                class_name,
+                params=params,
+                obj=obj,
+                logging_level=logging_level,
+            )
+        )
+
+    return comp_specs
 
 
 # datetime.datetime or integral to nanoseconds
@@ -127,6 +234,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator):
         stream_intersection_mode=False,
         begin=None,
         end=None,
+        plugin_set=None,
     ):
         utils._check_bool(stream_intersection_mode)
         self._stream_intersection_mode = stream_intersection_mode
@@ -134,15 +242,46 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator):
         self._end_ns = _get_ns(end)
         self._msg_list = [None]
 
-        if type(source_component_specs) is ComponentSpec:
+        # If a single item is provided, convert to a list.
+        if type(source_component_specs) in (
+            ComponentSpec,
+            AutoSourceComponentSpec,
+            str,
+        ):
             source_component_specs = [source_component_specs]
 
+        # Convert any string to an AutoSourceComponentSpec.
+        def str_to_auto(item):
+            if type(item) is str:
+                item = AutoSourceComponentSpec(item)
+
+            return item
+
+        source_component_specs = [str_to_auto(s) for s in source_component_specs]
+
         if type(filter_component_specs) is ComponentSpec:
             filter_component_specs = [filter_component_specs]
         elif filter_component_specs is None:
             filter_component_specs = []
 
-        self._src_comp_specs = source_component_specs
+        self._validate_source_component_specs(source_component_specs)
+        self._validate_filter_component_specs(filter_component_specs)
+
+        # Pass any `ComponentSpec` instance as-is.
+        self._src_comp_specs = [
+            spec for spec in source_component_specs if type(spec) is ComponentSpec
+        ]
+
+        # Convert any `AutoSourceComponentSpec` in concrete `ComponentSpec` instances.
+        auto_src_comp_specs = [
+            spec
+            for spec in source_component_specs
+            if type(spec) is AutoSourceComponentSpec
+        ]
+        self._src_comp_specs += _auto_discover_source_component_specs(
+            auto_src_comp_specs, plugin_set
+        )
+
         self._flt_comp_specs = filter_component_specs
         self._next_suffix = 1
         self._connect_ports = False
@@ -151,11 +290,21 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator):
         self._src_comps_and_specs = []
         self._flt_comps_and_specs = []
 
-        self._validate_component_specs(source_component_specs)
-        self._validate_component_specs(filter_component_specs)
         self._build_graph()
 
-    def _validate_component_specs(self, comp_specs):
+    def _validate_source_component_specs(self, comp_specs):
+        for comp_spec in comp_specs:
+            if (
+                type(comp_spec) is not ComponentSpec
+                and type(comp_spec) is not AutoSourceComponentSpec
+            ):
+                raise TypeError(
+                    '"{}" object is not a ComponentSpec or AutoSourceComponentSpec'.format(
+                        type(comp_spec)
+                    )
+                )
+
+    def _validate_filter_component_specs(self, comp_specs):
         for comp_spec in comp_specs:
             if type(comp_spec) is not ComponentSpec:
                 raise TypeError(
This page took 0.026485 seconds and 4 git commands to generate.