lib: run most of bt_self_component_port_input_message_iterator_try_finalize when...
[babeltrace.git] / tests / bindings / python / bt2 / test_message_iterator.py
index b331e4a40996188dd88cc6f0872d2451370d9549..bc8078a6922242798ac770ebc9b30b29dc36ef3a 100644 (file)
@@ -379,6 +379,63 @@ class UserMessageIteratorTestCase(unittest.TestCase):
             with self.assertRaises(bt2.TryAgain):
                 next(it)
 
+    def test_error_in_iterator_with_cycle_after_having_created_upstream_iterator(self):
+        # Test a failure that triggered an abort in libbabeltrace2, in this situation:
+        #
+        #   - The filter iterator creates an upstream iterator.
+        #   - The filter iterator creates a reference cycle, including itself.
+        #   - An exception is raised, causing the filter iterator's
+        #     initialization method to fail.
+        class MySourceIter(bt2._UserMessageIterator):
+            pass
+
+        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
+            def __init__(self, config, params, obj):
+                self._add_output_port('out')
+
+        class MyFilterIter(bt2._UserMessageIterator):
+            def __init__(self, config, port):
+                # First, create an upstream iterator.
+                self._upstream_iter = self._create_input_port_message_iterator(
+                    self._component._input_ports['in']
+                )
+
+                # Then, voluntarily make a reference cycle that will keep this
+                # Python object alive, which will keep the upstream iterator
+                # Babeltrace object alive.
+                self._self = self
+
+                # Finally, raise an exception to make __init__ fail.
+                raise ValueError('woops')
+
+        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
+            def __init__(self, config, params, obj):
+                self._in = self._add_input_port('in')
+                self._out = self._add_output_port('out')
+
+        class MySink(bt2._UserSinkComponent):
+            def __init__(self, config, params, obj):
+                self._input_port = self._add_input_port('in')
+
+            def _user_graph_is_configured(self):
+                self._upstream_iter = self._create_input_port_message_iterator(
+                    self._input_port
+                )
+
+            def _user_consume(self):
+                # We should not reach this.
+                assert False
+
+        g = bt2.Graph()
+        src = g.add_component(MySource, 'src')
+        flt = g.add_component(MyFilter, 'flt')
+        snk = g.add_component(MySink, 'snk')
+        g.connect_ports(src.output_ports['out'], flt.input_ports['in'])
+        g.connect_ports(flt.output_ports['out'], snk.input_ports['in'])
+
+        with self.assertRaisesRegex(bt2._Error, 'ValueError: woops'):
+            g.run()
+
 
 def _setup_seek_test(
     sink_cls,
@@ -386,6 +443,7 @@ def _setup_seek_test(
     user_can_seek_beginning=None,
     user_seek_ns_from_origin=None,
     user_can_seek_ns_from_origin=None,
+    can_seek_forward=False,
 ):
     class MySourceIter(bt2._UserMessageIterator):
         def __init__(self, config, port):
@@ -403,6 +461,7 @@ def _setup_seek_test(
                 self._create_stream_end_message(stream),
             ]
             self._at = 0
+            config.can_seek_forward = can_seek_forward
 
         def __next__(self):
             if self._at < len(self._msgs):
@@ -437,6 +496,7 @@ def _setup_seek_test(
             self._upstream_iter = self._create_input_port_message_iterator(
                 self._component._input_ports['in']
             )
+            config.can_seek_forward = self._upstream_iter.can_seek_forward
 
         def __next__(self):
             return next(self._upstream_iter)
@@ -462,6 +522,13 @@ def _setup_seek_test(
 
 
 class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
+    def test_can_seek_beginning_without_seek_beginning(self):
+        with self.assertRaisesRegex(
+            bt2._IncompleteUserClass,
+            "cannot create component class 'MySource': message iterator class implements _user_can_seek_beginning but not _user_seek_beginning",
+        ):
+            _setup_seek_test(SimpleSink, user_can_seek_beginning=lambda: None)
+
     def test_can_seek_beginning(self):
         class MySink(bt2._UserSinkComponent):
             def __init__(self, config, params, obj):
@@ -481,7 +548,9 @@ class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
             return input_port_iter_can_seek_beginning
 
         graph = _setup_seek_test(
-            MySink, user_can_seek_beginning=_user_can_seek_beginning
+            MySink,
+            user_can_seek_beginning=_user_can_seek_beginning,
+            user_seek_beginning=lambda: None,
         )
 
         input_port_iter_can_seek_beginning = True
@@ -557,7 +626,9 @@ class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
             raise ValueError('moustiquaire')
 
         graph = _setup_seek_test(
-            MySink, user_can_seek_beginning=_user_can_seek_beginning
+            MySink,
+            user_can_seek_beginning=_user_can_seek_beginning,
+            user_seek_beginning=lambda: None,
         )
 
         with self.assertRaises(bt2._Error) as ctx:
@@ -584,7 +655,9 @@ class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
             return 'Amqui'
 
         graph = _setup_seek_test(
-            MySink, user_can_seek_beginning=_user_can_seek_beginning
+            MySink,
+            user_can_seek_beginning=_user_can_seek_beginning,
+            user_seek_beginning=lambda: None,
         )
 
         with self.assertRaises(bt2._Error) as ctx:
@@ -660,48 +733,227 @@ class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
 
 
 class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
-    def test_can_seek_ns_from_origin(self):
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, config, params, obj):
-                self._add_input_port('in')
-
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
+    def test_can_seek_ns_from_origin_without_seek_ns_from_origin(self):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Returns True (don't really care, as long
+        #     as it's provided)
+        #   - seek_ns_from_origin provided: No
+        #   - can the iterator seek beginning: Don't care
+        #   - can the iterator seek forward: Don't care
+        for can_seek_ns_from_origin in (False, True):
+            for iter_can_seek_beginning in (False, True):
+                for iter_can_seek_forward in (False, True):
+                    with self.assertRaisesRegex(
+                        bt2._IncompleteUserClass,
+                        "cannot create component class 'MySource': message iterator class implements _user_can_seek_ns_from_origin but not _user_seek_ns_from_origin",
+                    ):
+                        self._can_seek_ns_from_origin_test(
+                            None,
+                            user_can_seek_ns_from_origin_ret_val=True,
+                            user_seek_ns_from_origin_provided=False,
+                            iter_can_seek_beginning=iter_can_seek_beginning,
+                            iter_can_seek_forward=iter_can_seek_forward,
+                        )
+
+    def test_can_seek_ns_from_origin_returns_true(self):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns True
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: Don't care
+        #   - can the iterator seek forward: Don't care
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        for iter_can_seek_beginning in (False, True):
+            for iter_can_seek_forward in (False, True):
+                self._can_seek_ns_from_origin_test(
+                    expected_outcome=True,
+                    user_can_seek_ns_from_origin_ret_val=True,
+                    user_seek_ns_from_origin_provided=True,
+                    iter_can_seek_beginning=iter_can_seek_beginning,
+                    iter_can_seek_forward=iter_can_seek_forward,
                 )
 
-            def _user_consume(self):
-                nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
-                can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
+    def test_can_seek_ns_from_origin_returns_false_can_seek_beginning_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=True,
+            user_can_seek_ns_from_origin_ret_val=False,
+            user_seek_ns_from_origin_provided=True,
+            iter_can_seek_beginning=True,
+            iter_can_seek_forward=True,
+        )
+
+    def test_can_seek_ns_from_origin_returns_false_can_seek_beginning_not_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=False,
+            user_seek_ns_from_origin_provided=True,
+            iter_can_seek_beginning=True,
+            iter_can_seek_forward=False,
+        )
+
+    def test_can_seek_ns_from_origin_returns_false_cant_seek_beginning_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=False,
+            user_seek_ns_from_origin_provided=True,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=True,
+        )
+
+    def test_can_seek_ns_from_origin_returns_false_cant_seek_beginning_not_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=False,
+            user_seek_ns_from_origin_provided=True,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=False,
+        )
+
+    def test_no_can_seek_ns_from_origin_seek_ns_from_origin(self):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: Don't care
+        #   - can the iterator seek forward: Don't care
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        for iter_can_seek_beginning in (False, True):
+            for iter_can_seek_forward in (False, True):
+                self._can_seek_ns_from_origin_test(
+                    expected_outcome=True,
+                    user_can_seek_ns_from_origin_ret_val=None,
+                    user_seek_ns_from_origin_provided=True,
+                    iter_can_seek_beginning=iter_can_seek_beginning,
+                    iter_can_seek_forward=iter_can_seek_forward,
                 )
 
-        def _user_can_seek_ns_from_origin(iter_self, ns_from_origin):
-            nonlocal input_port_iter_can_seek_ns_from_origin
-            nonlocal test_ns_from_origin
-            self.assertEqual(ns_from_origin, test_ns_from_origin)
-            return input_port_iter_can_seek_ns_from_origin
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_can_seek_beginning_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=True,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=True,
+            iter_can_seek_forward=True,
+        )
 
-        graph = _setup_seek_test(
-            MySink, user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_can_seek_beginning_not_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=True,
+            iter_can_seek_forward=False,
         )
 
-        input_port_iter_can_seek_ns_from_origin = True
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 1
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, True)
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_cant_seek_beginning_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=True,
+        )
 
-        input_port_iter_can_seek_ns_from_origin = False
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, False)
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_cant_seek_beginning_not_forward_seekable(
+        self,
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=False,
+        )
 
-    def test_no_can_seek_ns_from_origin_with_seek_ns_from_origin(self):
-        # Test an iterator without a _user_can_seek_ns_from_origin method, but
-        # with a _user_seek_ns_from_origin method.
+    def _can_seek_ns_from_origin_test(
+        self,
+        expected_outcome,
+        user_can_seek_ns_from_origin_ret_val,
+        user_seek_ns_from_origin_provided,
+        iter_can_seek_beginning,
+        iter_can_seek_forward,
+    ):
         class MySink(bt2._UserSinkComponent):
             def __init__(self, config, params, obj):
                 self._add_input_port('in')
@@ -713,74 +965,52 @@ class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
 
             def _user_consume(self):
                 nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
                 can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
+                    passed_ns_from_origin
                 )
 
-        def _user_seek_ns_from_origin(self):
-            pass
+        if user_can_seek_ns_from_origin_ret_val is not None:
 
-        graph = _setup_seek_test(
-            MySink, user_seek_ns_from_origin=_user_seek_ns_from_origin
-        )
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, True)
+            def user_can_seek_ns_from_origin(self, ns_from_origin):
+                nonlocal received_ns_from_origin
+                received_ns_from_origin = ns_from_origin
+                return user_can_seek_ns_from_origin_ret_val
 
-    def test_no_can_seek_ns_from_origin_with_seek_beginning(self):
-        # Test an iterator without a _user_can_seek_ns_from_origin method, but
-        # with a _user_seek_beginning method.
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, config, params, obj):
-                self._add_input_port('in')
+        else:
+            user_can_seek_ns_from_origin = None
 
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
-                )
+        if user_seek_ns_from_origin_provided:
 
-            def _user_consume(self):
-                nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
-                can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
-                )
+            def user_seek_ns_from_origin(self, ns_from_origin):
+                pass
 
-        def _user_seek_beginning(self):
-            pass
+        else:
+            user_seek_ns_from_origin = None
 
-        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, True)
+        if iter_can_seek_beginning:
 
-    def test_no_can_seek_ns_from_origin(self):
-        # Test an iterator without a _user_can_seek_ns_from_origin method
-        # and no other related method.
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, config, params, obj):
-                self._add_input_port('in')
+            def user_seek_beginning(self):
+                pass
 
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
-                )
+        else:
+            user_seek_beginning = None
 
-            def _user_consume(self):
-                nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
-                can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
-                )
+        graph = _setup_seek_test(
+            MySink,
+            user_can_seek_ns_from_origin=user_can_seek_ns_from_origin,
+            user_seek_ns_from_origin=user_seek_ns_from_origin,
+            user_seek_beginning=user_seek_beginning,
+            can_seek_forward=iter_can_seek_forward,
+        )
 
-        graph = _setup_seek_test(MySink)
+        passed_ns_from_origin = 77
+        received_ns_from_origin = None
         can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
         graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, False)
+        self.assertIs(can_seek_ns_from_origin, expected_outcome)
+
+        if user_can_seek_ns_from_origin_ret_val is not None:
+            self.assertEqual(received_ns_from_origin, passed_ns_from_origin)
 
     def test_can_seek_ns_from_origin_user_error(self):
         class MySink(bt2._UserSinkComponent):
@@ -800,7 +1030,9 @@ class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
             raise ValueError('Joutel')
 
         graph = _setup_seek_test(
-            MySink, user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin
+            MySink,
+            user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin,
+            user_seek_ns_from_origin=lambda: None,
         )
 
         with self.assertRaises(bt2._Error) as ctx:
@@ -827,7 +1059,9 @@ class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
             return 'Nitchequon'
 
         graph = _setup_seek_test(
-            MySink, user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin
+            MySink,
+            user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin,
+            user_seek_ns_from_origin=lambda: None,
         )
 
         with self.assertRaises(bt2._Error) as ctx:
@@ -853,7 +1087,6 @@ class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
             nonlocal actual_ns_from_origin
             actual_ns_from_origin = ns_from_origin
 
-        msg = None
         graph = _setup_seek_test(
             MySink, user_seek_ns_from_origin=_user_seek_ns_from_origin
         )
This page took 0.028859 seconds and 4 git commands to generate.