bt2: TraceCollectionNotificationIterator: support custom filter CCs
[babeltrace.git] / bindings / python / bt2 / bt2 / trace_collection_notification_iterator.py
index d7f1ae5eb4a6d342ae25e81eda0388cfeb122b83..a23436736d687c6796fae3c08b95a205148487ed 100644 (file)
@@ -22,6 +22,7 @@
 
 from bt2 import utils
 import bt2
+import itertools
 import bt2.notification_iterator
 import datetime
 import collections.abc
@@ -29,12 +30,11 @@ from collections import namedtuple
 import numbers
 
 
-# a pair of source component and _SourceComponentSpec
-_SourceComponentAndSpec = namedtuple('_SourceComponentAndSpec',
-                                     ['comp', 'spec'])
+# a pair of component and ComponentSpec
+_ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec'])
 
 
-class SourceComponentSpec:
+class ComponentSpec:
     def __init__(self, plugin_name, component_class_name, params=None):
         utils._check_str(plugin_name)
         utils._check_str(component_class_name)
@@ -76,29 +76,46 @@ def _get_ns(obj):
     return int(s * 1e9)
 
 
+class _CompClsType:
+    SOURCE = 0
+    FILTER = 1
+
+
 class TraceCollectionNotificationIterator(bt2.notification_iterator._NotificationIterator):
-    def __init__(self, source_component_specs, notification_types=None,
-                 stream_intersection_mode=False, begin=None,
-                 end=None):
+    def __init__(self, source_component_specs, filter_component_specs=None,
+                 notification_types=None, stream_intersection_mode=False,
+                 begin=None, end=None):
         utils._check_bool(stream_intersection_mode)
         self._stream_intersection_mode = stream_intersection_mode
         self._begin_ns = _get_ns(begin)
         self._end_ns = _get_ns(end)
         self._notification_types = notification_types
+
+        if type(source_component_specs) is ComponentSpec:
+            source_component_specs = [source_component_specs]
+
+        if type(filter_component_specs) is ComponentSpec:
+            filter_component_specs = [filter_component_specs]
+        elif filter_component_specs is None:
+            filter_component_specs = []
+
         self._src_comp_specs = source_component_specs
+        self._flt_comp_specs = filter_component_specs
         self._next_suffix = 1
         self._connect_ports = False
 
-        # set of _SourceComponentAndSpec
+        # lists of _ComponentAndSpec
         self._src_comps_and_specs = []
+        self._flt_comps_and_specs = []
 
-        self._validate_source_component_specs()
+        self._validate_component_specs(source_component_specs)
+        self._validate_component_specs(filter_component_specs)
         self._build_graph()
 
-    def _validate_source_component_specs(self):
-        for source_comp_spec in self._src_comp_specs:
-            if type(source_comp_spec) is not SourceComponentSpec:
-                raise TypeError('"{}" object is not a SourceComponentSpec'.format(type(source_comp_spec)))
+    def _validate_component_specs(self, comp_specs):
+        for comp_spec in comp_specs:
+            if type(comp_spec) is not ComponentSpec:
+                raise TypeError('"{}" object is not a ComponentSpec'.format(type(comp_spec)))
 
     def __next__(self):
         return next(self._notif_iter)
@@ -176,28 +193,37 @@ class TraceCollectionNotificationIterator(bt2.notification_iterator._Notificatio
         comp_cls = plugin.filter_component_classes['trimmer']
         return self._graph.add_component(comp_cls, name, params)
 
-    def _get_unique_src_comp_name(self, comp_spec):
+    def _get_unique_comp_name(self, comp_spec):
         name = '{}-{}'.format(comp_spec.plugin_name,
                               comp_spec.component_class_name)
+        comps_and_specs = itertools.chain(self._src_comps_and_specs,
+                                          self._flt_comps_and_specs)
 
-        if name in [comp_and_spec.comp.name for comp_and_spec in self._src_comps_and_specs]:
+        if name in [comp_and_spec.comp.name for comp_and_spec in comps_and_specs]:
             name += '-{}'.format(self._next_suffix)
             self._next_suffix += 1
 
         return name
 
-    def _create_src_comp(self, comp_spec):
+    def _create_comp(self, comp_spec, comp_cls_type):
         plugin = bt2.find_plugin(comp_spec.plugin_name)
 
         if plugin is None:
             raise bt2.Error('no such plugin: {}'.format(comp_spec.plugin_name))
 
-        if comp_spec.component_class_name not in plugin.source_component_classes:
-            raise bt2.Error('no such source component class in "{}" plugin: {}'.format(comp_spec.plugin_name,
-                                                                                       comp_spec.component_class_name))
+        if comp_cls_type == _CompClsType.SOURCE:
+            comp_classes = plugin.source_component_classes
+        else:
+            comp_classes = plugin.filter_component_classes
+
+        if comp_spec.component_class_name not in comp_classes:
+            cc_type = 'source' if comp_cls_type == _CompClsType.SOURCE else 'filter'
+            raise bt2.Error('no such {} component class in "{}" plugin: {}'.format(cc_type,
+                                                                                   comp_spec.plugin_name,
+                                                                                   comp_spec.component_class_name))
 
-        comp_cls = plugin.source_component_classes[comp_spec.component_class_name]
-        name = self._get_unique_src_comp_name(comp_spec)
+        comp_cls = comp_classes[comp_spec.component_class_name]
+        name = self._get_unique_comp_name(comp_spec)
         comp = self._graph.add_component(comp_cls, name, comp_spec.params)
         return comp
 
@@ -252,6 +278,18 @@ class TraceCollectionNotificationIterator(bt2.notification_iterator._Notificatio
         else:
             notif_iter_port = self._muxer_comp.output_ports['out']
 
+        # create extra filter components (chained)
+        for comp_spec in self._flt_comp_specs:
+            comp = self._create_comp(comp_spec, _CompClsType.FILTER)
+            self._flt_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec))
+
+        # connect the extra filter chain
+        for comp_and_spec in self._flt_comps_and_specs:
+            in_port = list(comp_and_spec.comp.input_ports.values())[0]
+            out_port = list(comp_and_spec.comp.output_ports.values())[0]
+            self._graph.connect_ports(notif_iter_port, in_port)
+            notif_iter_port = out_port
+
         # Here we create the components, self._graph_port_added() is
         # called when they add ports, but the callback returns early
         # because self._connect_ports is False. This is because the
@@ -260,8 +298,8 @@ class TraceCollectionNotificationIterator(bt2.notification_iterator._Notificatio
         # it does not exist yet (it needs the created component to
         # exist).
         for comp_spec in self._src_comp_specs:
-            comp = self._create_src_comp(comp_spec)
-            self._src_comps_and_specs.append(_SourceComponentAndSpec(comp, comp_spec))
+            comp = self._create_comp(comp_spec, _CompClsType.SOURCE)
+            self._src_comps_and_specs.append(_ComponentAndSpec(comp, comp_spec))
 
         # Now we connect the ports which exist at this point. We allow
         # self._graph_port_added() to automatically connect _new_ ports.
This page took 0.029071 seconds and 4 git commands to generate.