bt2: refactor test_message_iterator
authorSimon Marchi <simon.marchi@efficios.com>
Sun, 8 Sep 2019 21:26:09 +0000 (17:26 -0400)
committerSimon Marchi <simon.marchi@efficios.com>
Tue, 17 Sep 2019 19:05:53 +0000 (15:05 -0400)
This patch tries to make test_message_iterator a bit easier to
understand and extend, since the following patches add various tests in
it.

- Split the test class in two, with "seek beginning" tests in their own
  class.
- Rename _setup_seek_beginning_test to _setup_seek_test, as we'll use it
  for seek ns from origin tests.
- Move _create_graph and _setup_seek_test to be top-level
  functions, they don't really benefit from being static methods.  Plus,
  we'll want to share them between the different test classes.
- Make _setup_seek_test take parameters for various optional methods to
  add to the message iterator class (for now, _user_seek_beginning and
  _user_can_seek_beginning).
- Make _setup_seek_test use _create_graph.
- Split test `test_can_seek_beginning` in three: one for when we have
  a _user_can_seek_beginning method, one for when we don't but have a
  _user_seek_beginning method, and the other one for when we have none
  of those.

Change-Id: Ib99965d60a406acde7c0b16f5fde30a268c84673
Signed-off-by: Simon Marchi <simon.marchi@efficios.com>
Reviewed-on: https://review.lttng.org/c/babeltrace/+/2014
Tested-by: jenkins <jenkins@lttng.org>
Reviewed-by: Francis Deslauriers <francis.deslauriers@efficios.com>
tests/bindings/python/bt2/test_message_iterator.py

index acd091f872bdca240753a8bd26f64d1e2df4b612..2d4ab43850774131ffa9982355639079aa3f518b 100644 (file)
@@ -23,41 +23,39 @@ from utils import TestOutputPortMessageIterator
 from bt2 import port as bt2_port
 
 
-class UserMessageIteratorTestCase(unittest.TestCase):
-    @staticmethod
-    def _create_graph(src_comp_cls, flt_comp_cls=None):
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, params, obj):
-                self._add_input_port('in')
+class SimpleSink(bt2._UserSinkComponent):
+    # Straightforward sink that creates one input port (`in`) and consumes from
+    # it.
 
-            def _user_consume(self):
-                next(self._msg_iter)
+    def __init__(self, params, obj):
+        self._add_input_port('in')
 
-            def _user_graph_is_configured(self):
-                self._msg_iter = self._create_input_port_message_iterator(
-                    self._input_ports['in']
-                )
+    def _user_consume(self):
+        next(self._msg_iter)
 
-        graph = bt2.Graph()
-        src_comp = graph.add_component(src_comp_cls, 'src')
+    def _user_graph_is_configured(self):
+        self._msg_iter = self._create_input_port_message_iterator(
+            self._input_ports['in']
+        )
 
-        if flt_comp_cls is not None:
-            flt_comp = graph.add_component(flt_comp_cls, 'flt')
 
-        sink_comp = graph.add_component(MySink, 'sink')
+def _create_graph(src_comp_cls, sink_comp_cls, flt_comp_cls=None):
+    graph = bt2.Graph()
 
-        if flt_comp_cls is not None:
-            assert flt_comp is not None
-            graph.connect_ports(
-                src_comp.output_ports['out'], flt_comp.input_ports['in']
-            )
-            out_port = flt_comp.output_ports['out']
-        else:
-            out_port = src_comp.output_ports['out']
+    src_comp = graph.add_component(src_comp_cls, 'src')
+    sink_comp = graph.add_component(sink_comp_cls, 'sink')
 
-        graph.connect_ports(out_port, sink_comp.input_ports['in'])
-        return graph
+    if flt_comp_cls is not None:
+        flt_comp = graph.add_component(flt_comp_cls, 'flt')
+        graph.connect_ports(src_comp.output_ports['out'], flt_comp.input_ports['in'])
+        graph.connect_ports(flt_comp.output_ports['out'], sink_comp.input_ports['in'])
+    else:
+        graph.connect_ports(src_comp.output_ports['out'], sink_comp.input_ports['in'])
 
+    return graph
+
+
+class UserMessageIteratorTestCase(unittest.TestCase):
     def test_init(self):
         the_output_port_from_source = None
         the_output_port_from_iter = None
@@ -75,7 +73,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 the_output_port_from_source = self._add_output_port('out', 'user data')
 
         initialized = False
-        graph = self._create_graph(MySource)
+        graph = _create_graph(MySource, SimpleSink)
         graph.run()
         self.assertTrue(initialized)
         self.assertEqual(
@@ -111,7 +109,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
         src_iter_initialized = False
         flt_iter_initialized = False
-        graph = self._create_graph(MySource, MyFilter)
+        graph = _create_graph(MySource, SimpleSink, MyFilter)
         graph.run()
         self.assertTrue(src_iter_initialized)
         self.assertTrue(flt_iter_initialized)
@@ -142,7 +140,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._add_input_port('in')
                 self._add_output_port('out')
 
-        graph = self._create_graph(MySource, MyFilter)
+        graph = _create_graph(MySource, SimpleSink, MyFilter)
 
         with self.assertRaises(bt2._Error) as ctx:
             graph.run()
@@ -166,7 +164,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._add_output_port('out')
 
         finalized = False
-        graph = self._create_graph(MySource)
+        graph = _create_graph(MySource, SimpleSink)
         graph.run()
         del graph
         self.assertTrue(finalized)
@@ -183,7 +181,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._salut = 23
 
         salut = None
-        graph = self._create_graph(MySource)
+        graph = _create_graph(MySource, SimpleSink)
         graph.run()
         self.assertEqual(salut, 23)
 
@@ -202,7 +200,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._add_output_port('out')
 
         called = False
-        graph = self._create_graph(MySource)
+        graph = _create_graph(MySource, SimpleSink)
         graph.run()
         self.assertTrue(called)
 
@@ -217,7 +215,7 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 self._add_output_port('out')
 
         addr = None
-        graph = self._create_graph(MySource)
+        graph = _create_graph(MySource, SimpleSink)
         graph.run()
         self.assertIsNotNone(addr)
         self.assertNotEqual(addr, 0)
@@ -269,48 +267,17 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         self.assertIs(type(msg_ev2), bt2._EventMessageConst)
         self.assertEqual(msg_ev1.addr, msg_ev2.addr)
 
-    @staticmethod
-    def _setup_seek_beginning_test(sink_cls):
-        # Use a source, a filter and an output port iterator.  This allows us
-        # to test calling `seek_beginning` on both a _OutputPortMessageIterator
-        # and a _UserComponentInputPortMessageIterator, on top of checking that
-        # _UserMessageIterator._seek_beginning is properly called.
-
-        class MySourceIter(bt2._UserMessageIterator):
-            def __init__(self, port):
-                tc, sc, ec = port.user_data
-                trace = tc()
-                stream = trace.create_stream(sc)
-                packet = stream.create_packet()
-
-                self._msgs = [
-                    self._create_stream_beginning_message(stream),
-                    self._create_packet_beginning_message(packet),
-                    self._create_event_message(ec, packet),
-                    self._create_event_message(ec, packet),
-                    self._create_packet_end_message(packet),
-                    self._create_stream_end_message(stream),
-                ]
-                self._at = 0
-
-            def _user_seek_beginning(self):
-                self._at = 0
-
+    # Try consuming many times from an iterator that always returns TryAgain.
+    # This verifies that we are not missing an incref of Py_None, making the
+    # refcount of Py_None reach 0.
+    def test_try_again_many_times(self):
+        class MyIter(bt2._UserMessageIterator):
             def __next__(self):
-                if self._at < len(self._msgs):
-                    msg = self._msgs[self._at]
-                    self._at += 1
-                    return msg
-                else:
-                    raise StopIteration
+                raise bt2.TryAgain
 
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
+        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
             def __init__(self, params, obj):
-                tc = self._create_trace_class()
-                sc = tc.create_stream_class(supports_packets=True)
-                ec = sc.create_event_class()
-
-                self._add_output_port('out', (tc, sc, ec))
+                self._add_output_port('out')
 
         class MyFilterIter(bt2._UserMessageIterator):
             def __init__(self, port):
@@ -336,12 +303,83 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
         graph = bt2.Graph()
         src = graph.add_component(MySource, 'src')
-        flt = graph.add_component(MyFilter, 'flt')
-        sink = graph.add_component(sink_cls, 'sink')
-        graph.connect_ports(src.output_ports['out'], flt.input_ports['in'])
-        graph.connect_ports(flt.output_ports['out'], sink.input_ports['in'])
-        return MySourceIter, graph
+        it = TestOutputPortMessageIterator(graph, src.output_ports['out'])
+
+        # Three times the initial ref count of `None` iterations should
+        # be enough to catch the bug even if there are small differences
+        # between configurations.
+        none_ref_count = sys.getrefcount(None) * 3
+
+        for i in range(none_ref_count):
+            with self.assertRaises(bt2.TryAgain):
+                next(it)
+
+
+def _setup_seek_test(sink_cls, user_seek_beginning=None, user_can_seek_beginning=None):
+    class MySourceIter(bt2._UserMessageIterator):
+        def __init__(self, port):
+            tc, sc, ec = port.user_data
+            trace = tc()
+            stream = trace.create_stream(sc)
+            packet = stream.create_packet()
+
+            self._msgs = [
+                self._create_stream_beginning_message(stream),
+                self._create_packet_beginning_message(packet),
+                self._create_event_message(ec, packet),
+                self._create_event_message(ec, packet),
+                self._create_packet_end_message(packet),
+                self._create_stream_end_message(stream),
+            ]
+            self._at = 0
+
+        def __next__(self):
+            if self._at < len(self._msgs):
+                msg = self._msgs[self._at]
+                self._at += 1
+                return msg
+            else:
+                raise StopIteration
+
+    if user_seek_beginning is not None:
+        MySourceIter._user_seek_beginning = user_seek_beginning
+
+    if user_can_seek_beginning is not None:
+        MySourceIter._user_can_seek_beginning = property(user_can_seek_beginning)
+
+    class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
+        def __init__(self, params, obj):
+            tc = self._create_trace_class()
+            sc = tc.create_stream_class(supports_packets=True)
+            ec = sc.create_event_class()
+
+            self._add_output_port('out', (tc, sc, ec))
+
+    class MyFilterIter(bt2._UserMessageIterator):
+        def __init__(self, port):
+            self._upstream_iter = self._create_input_port_message_iterator(
+                self._component._input_ports['in']
+            )
+
+        def __next__(self):
+            return next(self._upstream_iter)
+
+        @property
+        def _user_can_seek_beginning(self):
+            return self._upstream_iter.can_seek_beginning
+
+        def _user_seek_beginning(self):
+            self._upstream_iter.seek_beginning()
+
+    class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
+        def __init__(self, params, obj):
+            self._add_input_port('in')
+            self._add_output_port('out')
+
+    return _create_graph(MySource, sink_cls, flt_comp_cls=MyFilter)
 
+
+class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
     def test_can_seek_beginning(self):
         class MySink(bt2._UserSinkComponent):
             def __init__(self, params, obj):
@@ -356,13 +394,13 @@ class UserMessageIteratorTestCase(unittest.TestCase):
                 nonlocal can_seek_beginning
                 can_seek_beginning = self._msg_iter.can_seek_beginning
 
-        MySourceIter, graph = self._setup_seek_beginning_test(MySink)
-
         def _user_can_seek_beginning(self):
             nonlocal input_port_iter_can_seek_beginning
             return input_port_iter_can_seek_beginning
 
-        MySourceIter._user_can_seek_beginning = property(_user_can_seek_beginning)
+        graph = _setup_seek_test(
+            MySink, user_can_seek_beginning=_user_can_seek_beginning
+        )
 
         input_port_iter_can_seek_beginning = True
         can_seek_beginning = None
@@ -374,18 +412,47 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         graph.run_once()
         self.assertFalse(can_seek_beginning)
 
-        # Once can_seek_beginning returns an error, verify that it raises when
-        # _can_seek_beginning has/returns the wrong type.
+    def test_no_can_seek_beginning_with_seek_beginning(self):
+        # Test an iterator without a _user_can_seek_beginning method, but with
+        # a _user_seek_beginning method.
+        class MySink(bt2._UserSinkComponent):
+            def __init__(self, params, obj):
+                self._add_input_port('in')
+
+            def _user_graph_is_configured(self):
+                self._msg_iter = self._create_input_port_message_iterator(
+                    self._input_ports['in']
+                )
+
+            def _user_consume(self):
+                nonlocal can_seek_beginning
+                can_seek_beginning = self._msg_iter.can_seek_beginning
+
+        def _user_seek_beginning(self):
+            pass
 
-        # Remove the _can_seek_beginning method, we now rely on the presence of
-        # a _seek_beginning method to know whether the iterator can seek to
-        # beginning or not.
-        del MySourceIter._user_can_seek_beginning
+        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)
         can_seek_beginning = None
         graph.run_once()
         self.assertTrue(can_seek_beginning)
 
-        del MySourceIter._user_seek_beginning
+    def test_no_can_seek_beginning(self):
+        # Test an iterator without a _user_can_seek_beginning method, without
+        # a _user_seek_beginning method.
+        class MySink(bt2._UserSinkComponent):
+            def __init__(self, params, obj):
+                self._add_input_port('in')
+
+            def _user_graph_is_configured(self):
+                self._msg_iter = self._create_input_port_message_iterator(
+                    self._input_ports['in']
+                )
+
+            def _user_consume(self):
+                nonlocal can_seek_beginning
+                can_seek_beginning = self._msg_iter.can_seek_beginning
+
+        graph = _setup_seek_test(MySink)
         can_seek_beginning = None
         graph.run_once()
         self.assertFalse(can_seek_beginning)
@@ -410,15 +477,26 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
                 msg = next(self._msg_iter)
 
-        do_seek_beginning = False
+        def _user_seek_beginning(self):
+            self._at = 0
+
         msg = None
-        MySourceIter, graph = self._setup_seek_beginning_test(MySink)
+        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)
+
+        # Consume message.
+        do_seek_beginning = False
         graph.run_once()
         self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
+
+        # Consume message.
         graph.run_once()
         self.assertIs(type(msg), bt2._PacketBeginningMessageConst)
+
+        # Seek beginning.
         do_seek_beginning = True
         graph.run_once()
+
+        # Consume message.
         do_seek_beginning = False
         graph.run_once()
         self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
@@ -436,41 +514,14 @@ class UserMessageIteratorTestCase(unittest.TestCase):
             def _user_consume(self):
                 self._msg_iter.seek_beginning()
 
-        MySourceIter, graph = self._setup_seek_beginning_test(MySink)
-
-        def _user_seek_beginning_error(self):
+        def _user_seek_beginning(self):
             raise ValueError('ouch')
 
-        MySourceIter._user_seek_beginning = _user_seek_beginning_error
+        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)
 
         with self.assertRaises(bt2._Error):
             graph.run_once()
 
-    # Try consuming many times from an iterator that always returns TryAgain.
-    # This verifies that we are not missing an incref of Py_None, making the
-    # refcount of Py_None reach 0.
-    def test_try_again_many_times(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.TryAgain
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
-            def __init__(self, params, obj):
-                self._add_output_port('out')
-
-        graph = bt2.Graph()
-        src = graph.add_component(MySource, 'src')
-        it = TestOutputPortMessageIterator(graph, src.output_ports['out'])
-
-        # Three times the initial ref count of `None` iterations should
-        # be enough to catch the bug even if there are small differences
-        # between configurations.
-        none_ref_count = sys.getrefcount(None) * 3
-
-        for i in range(none_ref_count):
-            with self.assertRaises(bt2.TryAgain):
-                next(it)
-
 
 if __name__ == '__main__':
     unittest.main()
This page took 0.031167 seconds and 4 git commands to generate.