lib: pass config object to message iterator init method, add can seek forward property
[babeltrace.git] / tests / bindings / python / bt2 / test_message_iterator.py
index 39d808887988a6111ea97a78a6c4b957898fcc29..b331e4a40996188dd88cc6f0872d2451370d9549 100644 (file)
@@ -21,6 +21,7 @@ import bt2
 import sys
 from utils import TestOutputPortMessageIterator
 from bt2 import port as bt2_port
+from bt2 import message_iterator as bt2_message_iterator
 
 
 class SimpleSink(bt2._UserSinkComponent):
@@ -61,7 +62,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         the_output_port_from_iter = None
 
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 nonlocal initialized
                 nonlocal the_output_port_from_iter
                 initialized = True
@@ -83,7 +84,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
     def test_create_from_message_iterator(self):
         class MySourceIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 nonlocal src_iter_initialized
                 src_iter_initialized = True
 
@@ -92,7 +93,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._add_output_port('out')
 
         class MyFilterIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 nonlocal flt_iter_initialized
                 flt_iter_initialized = True
                 self._up_iter = self._create_input_port_message_iterator(
@@ -120,7 +121,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         # and _UserMessageIterator._create_input_port_message_iterator, as they
         # are both used in the graph.
         class MySourceIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 raise ValueError('Very bad error')
 
         class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
@@ -128,7 +129,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._add_output_port('out')
 
         class MyFilterIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 # This is expected to raise because of the error in
                 # MySourceIter.__init__.
                 self._create_input_port_message_iterator(
@@ -169,9 +170,74 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         del graph
         self.assertTrue(finalized)
 
+    def test_config_parameter(self):
+        class MyIter(bt2._UserMessageIterator):
+            def __init__(self, config, port):
+                nonlocal config_type
+                config_type = type(config)
+
+        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+            def __init__(self, config, params, obj):
+                self._add_output_port('out')
+
+        config_type = None
+        graph = _create_graph(MySource, SimpleSink)
+        graph.run()
+        self.assertIs(config_type, bt2_message_iterator._MessageIteratorConfiguration)
+
+    def _test_config_can_seek_forward(self, set_can_seek_forward):
+        class MyIter(bt2._UserMessageIterator):
+            def __init__(self, config, port):
+                if set_can_seek_forward:
+                    config.can_seek_forward = True
+
+        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+            def __init__(self, config, params, obj):
+                self._add_output_port('out')
+
+        class MySink(bt2._UserSinkComponent):
+            def __init__(self, config, params, obj):
+                self._add_input_port('in')
+
+            def _user_graph_is_configured(self):
+                self._msg_iter = self._create_input_port_message_iterator(
+                    self._input_ports['in']
+                )
+
+            def _user_consume(self):
+                nonlocal can_seek_forward
+                can_seek_forward = self._msg_iter.can_seek_forward
+
+        can_seek_forward = None
+        graph = _create_graph(MySource, MySink)
+        graph.run_once()
+        self.assertIs(can_seek_forward, set_can_seek_forward)
+
+    def test_config_can_seek_forward_default(self):
+        self._test_config_can_seek_forward(False)
+
+    def test_config_can_seek_forward(self):
+        self._test_config_can_seek_forward(True)
+
+    def test_config_can_seek_forward_wrong_type(self):
+        class MyIter(bt2._UserMessageIterator):
+            def __init__(self, config, port):
+                config.can_seek_forward = 1
+
+        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+            def __init__(self, config, params, obj):
+                self._add_output_port('out')
+
+        graph = _create_graph(MySource, SimpleSink)
+        with self.assertRaises(bt2._Error) as ctx:
+            graph.run()
+
+        root_cause = ctx.exception[0]
+        self.assertIn("TypeError: 'int' is not a 'bool' object", root_cause.message)
+
     def test_component(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 nonlocal salut
                 salut = self._component._salut
 
@@ -187,7 +253,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
     def test_port(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self_iter, self_port_output):
+            def __init__(self_iter, config, self_port_output):
                 nonlocal called
                 called = True
                 port = self_iter._port
@@ -206,7 +272,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
     def test_addr(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self, self_port_output):
+            def __init__(self, config, self_port_output):
                 nonlocal addr
                 addr = self.addr
 
@@ -224,7 +290,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
     # and can be re-used.
     def test_reuse_message(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self, port):
+            def __init__(self, config, port):
                 tc, sc, ec = port.user_data
                 trace = tc()
                 stream = trace.create_stream(sc)
@@ -322,7 +388,7 @@ def _setup_seek_test(
     user_can_seek_ns_from_origin=None,
 ):
     class MySourceIter(bt2._UserMessageIterator):
-        def __init__(self, port):
+        def __init__(self, config, port):
             tc, sc, ec = port.user_data
             trace = tc()
             stream = trace.create_stream(sc)
@@ -367,7 +433,7 @@ def _setup_seek_test(
             self._add_output_port('out', (tc, sc, ec))
 
     class MyFilterIter(bt2._UserMessageIterator):
-        def __init__(self, port):
+        def __init__(self, config, port):
             self._upstream_iter = self._create_input_port_message_iterator(
                 self._component._input_ports['in']
             )
This page took 0.024808 seconds and 4 git commands to generate.