lib: make can_seek_ns_from_origin logic use `can_seek_forward` property of iterator
[babeltrace.git] / tests / bindings / python / bt2 / test_message_iterator.py
index 21f95704b33d072e814f641f1e8c6766d2907699..37f6dc87798dcda7b32b4ef72b8e36b40b56b0f5 100644 (file)
@@ -386,6 +386,7 @@ def _setup_seek_test(
     user_can_seek_beginning=None,
     user_seek_ns_from_origin=None,
     user_can_seek_ns_from_origin=None,
+    can_seek_forward=False,
 ):
     class MySourceIter(bt2._UserMessageIterator):
         def __init__(self, config, port):
@@ -403,6 +404,7 @@ def _setup_seek_test(
                 self._create_stream_end_message(stream),
             ]
             self._at = 0
+            config.can_seek_forward = can_seek_forward
 
         def __next__(self):
             if self._at < len(self._msgs):
@@ -437,6 +439,7 @@ def _setup_seek_test(
             self._upstream_iter = self._create_input_port_message_iterator(
                 self._component._input_ports['in']
             )
+            config.can_seek_forward = self._upstream_iter.can_seek_forward
 
         def __next__(self):
             return next(self._upstream_iter)
@@ -660,48 +663,209 @@ class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
 
 
 class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
-    def test_can_seek_ns_from_origin(self):
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, config, params, obj):
-                self._add_input_port('in')
+    def test_can_seek_ns_from_origin_returns_true(self):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns True
+        #   - seek_ns_from_origin provided: Don't care
+        #   - can the iterator seek beginning: Don't care
+        #   - can the iterator seek forward: Don't care
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        for user_seek_ns_from_origin_provided in (False, True):
+            for iter_can_seek_beginning in (False, True):
+                for iter_can_seek_forward in (False, True):
+                    self._can_seek_ns_from_origin_test(
+                        expected_outcome=True,
+                        user_can_seek_ns_from_origin_ret_val=True,
+                        user_seek_ns_from_origin_provided=user_seek_ns_from_origin_provided,
+                        iter_can_seek_beginning=iter_can_seek_beginning,
+                        iter_can_seek_forward=iter_can_seek_forward,
+                    )
+
+    def test_can_seek_ns_from_origin_returns_false_can_seek_beginning_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Don't care
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        for user_seek_ns_from_origin_provided in (False, True):
+            self._can_seek_ns_from_origin_test(
+                expected_outcome=True,
+                user_can_seek_ns_from_origin_ret_val=False,
+                user_seek_ns_from_origin_provided=user_seek_ns_from_origin_provided,
+                iter_can_seek_beginning=True,
+                iter_can_seek_forward=True,
+            )
 
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
-                )
+    def test_can_seek_ns_from_origin_returns_false_can_seek_beginning_not_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Don't care
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        for user_seek_ns_from_origin_provided in (False, True):
+            self._can_seek_ns_from_origin_test(
+                expected_outcome=False,
+                user_can_seek_ns_from_origin_ret_val=False,
+                user_seek_ns_from_origin_provided=user_seek_ns_from_origin_provided,
+                iter_can_seek_beginning=True,
+                iter_can_seek_forward=False,
+            )
 
-            def _user_consume(self):
-                nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
-                can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
+    def test_can_seek_ns_from_origin_returns_false_cant_seek_beginning_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Don't care
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        # for user_seek_ns_from_origin_provided in (False, True):
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=False,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=True,
+        )
+
+    def test_can_seek_ns_from_origin_returns_false_cant_seek_beginning_not_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: returns False
+        #   - seek_ns_from_origin provided: Don't care
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        for user_seek_ns_from_origin_provided in (False, True):
+            self._can_seek_ns_from_origin_test(
+                expected_outcome=False,
+                user_can_seek_ns_from_origin_ret_val=False,
+                user_seek_ns_from_origin_provided=user_seek_ns_from_origin_provided,
+                iter_can_seek_beginning=False,
+                iter_can_seek_forward=False,
+            )
+
+    def test_no_can_seek_ns_from_origin_seek_ns_from_origin(self):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Yes
+        #   - can the iterator seek beginning: Don't care
+        #   - can the iterator seek forward: Don't care
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        for iter_can_seek_beginning in (False, True):
+            for iter_can_seek_forward in (False, True):
+                self._can_seek_ns_from_origin_test(
+                    expected_outcome=True,
+                    user_can_seek_ns_from_origin_ret_val=None,
+                    user_seek_ns_from_origin_provided=True,
+                    iter_can_seek_beginning=iter_can_seek_beginning,
+                    iter_can_seek_forward=iter_can_seek_forward,
                 )
 
-        def _user_can_seek_ns_from_origin(iter_self, ns_from_origin):
-            nonlocal input_port_iter_can_seek_ns_from_origin
-            nonlocal test_ns_from_origin
-            self.assertEqual(ns_from_origin, test_ns_from_origin)
-            return input_port_iter_can_seek_ns_from_origin
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_can_seek_beginning_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return True.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=True,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=True,
+            iter_can_seek_forward=True,
+        )
 
-        graph = _setup_seek_test(
-            MySink, user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_can_seek_beginning_not_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: Yes
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=True,
+            iter_can_seek_forward=False,
         )
 
-        input_port_iter_can_seek_ns_from_origin = True
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 1
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, True)
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_cant_seek_beginning_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: Yes
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=True,
+        )
 
-        input_port_iter_can_seek_ns_from_origin = False
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, False)
+    def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_cant_seek_beginning_not_forward_seekable(
+        self
+    ):
+        # Test the case where:
+        #
+        #   - can_seek_ns_from_origin: Not provided
+        #   - seek_ns_from_origin provided: Not provided
+        #   - can the iterator seek beginning: No
+        #   - can the iterator seek forward: No
+        #
+        # We expect iter.can_seek_ns_from_origin to return False.
+        self._can_seek_ns_from_origin_test(
+            expected_outcome=False,
+            user_can_seek_ns_from_origin_ret_val=None,
+            user_seek_ns_from_origin_provided=False,
+            iter_can_seek_beginning=False,
+            iter_can_seek_forward=False,
+        )
 
-    def test_no_can_seek_ns_from_origin_with_seek_ns_from_origin(self):
-        # Test an iterator without a _user_can_seek_ns_from_origin method, but
-        # with a _user_seek_ns_from_origin method.
+    def _can_seek_ns_from_origin_test(
+        self,
+        expected_outcome,
+        user_can_seek_ns_from_origin_ret_val,
+        user_seek_ns_from_origin_provided,
+        iter_can_seek_beginning,
+        iter_can_seek_forward,
+    ):
         class MySink(bt2._UserSinkComponent):
             def __init__(self, config, params, obj):
                 self._add_input_port('in')
@@ -713,74 +877,52 @@ class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
 
             def _user_consume(self):
                 nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
                 can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
+                    passed_ns_from_origin
                 )
 
-        def _user_seek_ns_from_origin(self):
-            pass
+        if user_can_seek_ns_from_origin_ret_val is not None:
 
-        graph = _setup_seek_test(
-            MySink, user_seek_ns_from_origin=_user_seek_ns_from_origin
-        )
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, True)
+            def user_can_seek_ns_from_origin(self, ns_from_origin):
+                nonlocal received_ns_from_origin
+                received_ns_from_origin = ns_from_origin
+                return user_can_seek_ns_from_origin_ret_val
 
-    def test_no_can_seek_ns_from_origin_with_seek_beginning(self):
-        # Test an iterator without a _user_can_seek_ns_from_origin method, but
-        # with a _user_seek_beginning method.
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, config, params, obj):
-                self._add_input_port('in')
+        else:
+            user_can_seek_ns_from_origin = None
 
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
-                )
+        if user_seek_ns_from_origin_provided:
 
-            def _user_consume(self):
-                nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
-                can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
-                )
+            def user_seek_ns_from_origin(self, ns_from_origin):
+                pass
 
-        def _user_seek_beginning(self):
-            pass
+        else:
+            user_seek_ns_from_origin = None
 
-        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)
-        can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
-        graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, True)
+        if iter_can_seek_beginning:
 
-    def test_no_can_seek_ns_from_origin(self):
-        # Test an iterator without a _user_can_seek_ns_from_origin method
-        # and no other related method.
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, config, params, obj):
-                self._add_input_port('in')
+            def user_seek_beginning(self):
+                pass
 
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
-                )
+        else:
+            user_seek_beginning = None
 
-            def _user_consume(self):
-                nonlocal can_seek_ns_from_origin
-                nonlocal test_ns_from_origin
-                can_seek_ns_from_origin = self._msg_iter.can_seek_ns_from_origin(
-                    test_ns_from_origin
-                )
+        graph = _setup_seek_test(
+            MySink,
+            user_can_seek_ns_from_origin=user_can_seek_ns_from_origin,
+            user_seek_ns_from_origin=user_seek_ns_from_origin,
+            user_seek_beginning=user_seek_beginning,
+            can_seek_forward=iter_can_seek_forward,
+        )
 
-        graph = _setup_seek_test(MySink)
+        passed_ns_from_origin = 77
+        received_ns_from_origin = None
         can_seek_ns_from_origin = None
-        test_ns_from_origin = 2
         graph.run_once()
-        self.assertIs(can_seek_ns_from_origin, False)
+        self.assertIs(can_seek_ns_from_origin, expected_outcome)
+
+        if user_can_seek_ns_from_origin_ret_val is not None:
+            self.assertEqual(received_ns_from_origin, passed_ns_from_origin)
 
     def test_can_seek_ns_from_origin_user_error(self):
         class MySink(bt2._UserSinkComponent):
This page took 0.027477 seconds and 4 git commands to generate.