bt2: Adapt test_message_iterator.py and make it pass
authorSimon Marchi <simon.marchi@efficios.com>
Mon, 3 Jun 2019 20:48:45 +0000 (16:48 -0400)
committerPhilippe Proulx <eeppeliteloop@gmail.com>
Wed, 5 Jun 2019 17:47:34 +0000 (13:47 -0400)
Update test_message_iterator.py to work with the current Babeltrace API.
Most changes are related to the fact that it is not possible anymore to
create most objects in isolation, everything is derived from a trace
class.  Message iterators are created in the _graph_is_configured
callback rather than _port_connected. Because of this, we now need to
call graph.run() in the tests to get _graph_is_configured called.

I noticed that the component method of _SelfPortInputMessageIterator was
a bit problematic, as we don't know the type of the component at that
point.  Returning the right type of component (the right Python class,
wrapping the right type of Swig pointer) would require a bit of dirty
work.  I am not sure this is even useful, so I've just removed it and
marked the tests as skipped / commented them out (they can be removed if
we rule this is indeed unnecessary).

An important change is the addition the self_output_port parameter in
the constructor of _UserMessageIterator (which the user is expected to
override).  This parameter is how the iterator is supposed to know what
it's supposed to iterate on (if the component has multiple ports).

There are no other changes in message_iterator.py, since everything had
to be done already for test_message.py to pass.

Change-Id: I574d812ed81d9628a8e4a2a2dd5c593c4d730c95
Signed-off-by: Simon Marchi <simon.marchi@efficios.com>
Signed-off-by: Francis Deslauriers <francis.deslauriers@efficios.com>
Reviewed-on: https://review.lttng.org/c/babeltrace/+/1326
Tested-by: jenkins
Reviewed-by: Philippe Proulx <eeppeliteloop@gmail.com>
bindings/python/bt2/bt2/message_iterator.py
bindings/python/bt2/bt2/native_bt_component_class.i
tests/bindings/python/bt2/test_clock_class.py
tests/bindings/python/bt2/test_event.py
tests/bindings/python/bt2/test_graph.py
tests/bindings/python/bt2/test_message.py
tests/bindings/python/bt2/test_message_iterator.py

index 4cf362ad490d9b4fe002214b329a1fbfcb3dc0c8..7ac98f3dfcf64c35792972ae675c9a7695c757c6 100644 (file)
@@ -60,22 +60,28 @@ class _GenericMessageIterator(object._SharedObject, _MessageIterator):
         return bt2.message._create_from_ptr(msg_ptr)
 
 
+# This is created when a component wants to iterate on one of its input ports.
 class _UserComponentInputPortMessageIterator(_GenericMessageIterator):
     _get_msg_range = staticmethod(native_bt.py3_self_component_port_input_get_msg_range)
-
-    @property
-    def component(self):
-        comp_ptr = native_bt.private_connection_message_iterator_get_component(self._ptr)
-        assert(comp_ptr)
-        return bt2.component._create_generic_component_from_ptr(comp_ptr)
+    _get_ref = staticmethod(native_bt.self_component_port_input_message_iterator_get_ref)
+    _put_ref = staticmethod(native_bt.self_component_port_input_message_iterator_put_ref)
 
 
+# This is created when the user wants to iterate on a component's output port,
+# from outside the graph.
 class _OutputPortMessageIterator(_GenericMessageIterator):
     _get_msg_range = staticmethod(native_bt.py3_port_output_get_msg_range)
     _get_ref = staticmethod(native_bt.port_output_message_iterator_get_ref)
     _put_ref = staticmethod(native_bt.port_output_message_iterator_put_ref)
 
 
+# This is extended by the user to implement component classes in Python.  It
+# is created for a given output port when an input port message iterator is
+# created on the input port on the other side of the connection.  It is also
+# created when an output port message iterator is created on this output port.
+#
+# Its purpose is to feed the messages that should go out through this output
+# port.
 class _UserMessageIterator(_MessageIterator):
     def __new__(cls, ptr):
         # User iterator objects are always created by the native side,
@@ -92,7 +98,12 @@ class _UserMessageIterator(_MessageIterator):
         self._ptr = ptr
         return self
 
-    def __init__(self):
+    def _init_from_native(self, self_output_port_ptr):
+        self_output_port = bt2.port._create_self_from_ptr_and_get_ref(
+            self_output_port_ptr, native_bt.PORT_TYPE_OUTPUT)
+        self.__init__(self_output_port)
+
+    def __init__(self, output_port):
         pass
 
     @property
index deca5a38682820183e3c457be26c359a74a3c777..1dfa0110886fa05a225e1d018155d11b67829142 100644 (file)
@@ -1340,6 +1340,7 @@ bt_py3_component_class_message_iterator_init(
        PyObject *py_comp_cls = NULL;
        PyObject *py_iter_cls = NULL;
        PyObject *py_iter_ptr = NULL;
+       PyObject *py_component_port_output_ptr = NULL;
        PyObject *py_init_method_result = NULL;
        PyObject *py_iter = NULL;
        PyObject *py_comp;
@@ -1384,14 +1385,23 @@ bt_py3_component_class_message_iterator_init(
        /*
         * Initialize object:
         *
-        *     py_iter.__init__()
+        *     py_iter.__init__(self_output_port)
+        *
+         * through the _init_for_native helper static method.
         *
         * At this point, py_iter._ptr is set, so this initialization
         * function has access to self._component (which gives it the
         * user Python component object from which the iterator was
         * created).
         */
-       py_init_method_result = PyObject_CallMethod(py_iter, "__init__", NULL);
+        py_component_port_output_ptr = SWIG_NewPointerObj(SWIG_as_voidptr(self_component_port_output),
+               SWIGTYPE_p_bt_self_component_port_output, 0);
+       if (!py_component_port_output_ptr) {
+               BT_LOGE_STR("Failed to create a SWIG pointer object.");
+               goto error;
+       }
+
+       py_init_method_result = PyObject_CallMethod(py_iter, "_init_from_native", "O", py_component_port_output_ptr);
        if (!py_init_method_result) {
                BT_LOGE_STR("User's __init__() method failed.");
                bt2_py_loge_exception();
@@ -1441,6 +1451,7 @@ end:
        Py_XDECREF(py_comp_cls);
        Py_XDECREF(py_iter_cls);
        Py_XDECREF(py_iter_ptr);
+       Py_XDECREF(py_component_port_output_ptr);
        Py_XDECREF(py_init_method_result);
        Py_XDECREF(py_iter);
        return status;
index de0c001c6d6b11235812b8518f2021ff51464617..8be327bd5c6d1031eb50916caf51eff1aae3e0f8 100644 (file)
@@ -201,7 +201,7 @@ class ClockSnapshotTestCase(unittest.TestCase):
         self._cc = _cc
 
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
+            def __init__(self, self_port_output):
                 self._at = 0
 
             def __next__(self):
index 6f818cec407d03581ab7c74352f94b3def39470b..4c21ee7fe144dd204a963ff9da93f057fd8a4d34 100644 (file)
@@ -11,7 +11,7 @@ class EventTestCase(unittest.TestCase):
                                    with_ep=False):
 
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
+            def __init__(self, self_output_port):
                 self._at = 0
 
             def __next__(self):
index 8ed585a626aa0ff4d0753949d6ca99c03a78cc61..48e02873bb812b76d8c2bd414c4f1fb84b44bd8d 100644 (file)
@@ -6,7 +6,7 @@ import bt2
 
 
 class _MyIter(bt2._UserMessageIterator):
-    def __init__(self):
+    def __init__(self, self_output_port):
         self._build_meta()
         self._at = 0
 
@@ -496,7 +496,8 @@ class GraphTestCase(unittest.TestCase):
             def _consume(self):
                 raise bt2.Stop
 
-        def ports_connected_listener(upstream_port, downstream_port):
+        def ports_connected_listener(upstream_component, upstream_port,
+                                     downstream_component, downstream_port):
             raise ValueError('oh noes!')
 
         graph = bt2.Graph()
index 521d4e4daac1b5e9cb8f6a946a05e09e3ea8cc98..3cc55042ad9e6f73a7130236a4bd93d679ad0b66 100644 (file)
@@ -7,7 +7,7 @@ class AllMessagesTestCase(unittest.TestCase):
     def setUp(self):
 
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
+            def __init__(self, self_port_output):
                 self._at = 0
 
             def __next__(self):
index 455fca55d16f2f270b8233c344c18be7e737c622..0ad56329da087136891185765fa6f18d5b83f8df 100644 (file)
@@ -5,7 +5,6 @@ import copy
 import bt2
 
 
-@unittest.skip("this is broken")
 class UserMessageIteratorTestCase(unittest.TestCase):
     @staticmethod
     def _create_graph(src_comp_cls):
@@ -16,8 +15,8 @@ class UserMessageIteratorTestCase(unittest.TestCase):
             def _consume(self):
                 next(self._msg_iter)
 
-            def _port_connected(self, port, other_port):
-                self._msg_iter = port.connection.create_message_iterator()
+            def _graph_is_configured(self):
+                self._msg_iter = self._input_ports['in'].create_message_iterator()
 
         graph = bt2.Graph()
         src_comp = graph.add_component(src_comp_cls, 'src')
@@ -27,19 +26,27 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         return graph
 
     def test_init(self):
+        the_output_port_from_source = None
+        the_output_port_from_iter = None
+
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
+            def __init__(self, self_port_output):
                 nonlocal initialized
+                nonlocal the_output_port_from_iter
                 initialized = True
+                the_output_port_from_iter = self_port_output
 
         class MySource(bt2._UserSourceComponent,
                        message_iterator_class=MyIter):
             def __init__(self, params):
-                self._add_output_port('out')
+                nonlocal the_output_port_from_source
+                the_output_port_from_source = self._add_output_port('out')
 
         initialized = False
         graph = self._create_graph(MySource)
+        graph.run()
         self.assertTrue(initialized)
+        self.assertEqual(the_output_port_from_source.addr, the_output_port_from_iter.addr)
 
     def test_finalize(self):
         class MyIter(bt2._UserMessageIterator):
@@ -54,12 +61,13 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
         finalized = False
         graph = self._create_graph(MySource)
+        graph.run()
         del graph
         self.assertTrue(finalized)
 
     def test_component(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
+            def __init__(self, self_port_output):
                 nonlocal salut
                 salut = self._component._salut
 
@@ -71,11 +79,12 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
         salut = None
         graph = self._create_graph(MySource)
+        graph.run()
         self.assertEqual(salut, 23)
 
     def test_addr(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
+            def __init__(self, self_port_output):
                 nonlocal addr
                 addr = self.addr
 
@@ -86,76 +95,33 @@ class UserMessageIteratorTestCase(unittest.TestCase):
 
         addr = None
         graph = self._create_graph(MySource)
+        graph.run()
         self.assertIsNotNone(addr)
         self.assertNotEqual(addr, 0)
 
 
-@unittest.skip("this is broken")
-class PrivateConnectionMessageIteratorTestCase(unittest.TestCase):
-    def test_component(self):
-        class MyIter(bt2._UserMessageIterator):
-            pass
-
-        class MySource(bt2._UserSourceComponent,
-                       message_iterator_class=MyIter):
-            def __init__(self, params):
-                self._add_output_port('out')
-
-        class MySink(bt2._UserSinkComponent):
-            def __init__(self, params):
-                self._add_input_port('in')
-
-            def _consume(self):
-                next(self._msg_iter)
-
-            def _port_connected(self, port, other_port):
-                nonlocal upstream_comp
-                self._msg_iter = port.connection.create_message_iterator()
-                upstream_comp = self._msg_iter.component
-
-        upstream_comp = None
-        graph = bt2.Graph()
-        src_comp = graph.add_component(MySource, 'src')
-        sink_comp = graph.add_component(MySink, 'sink')
-        graph.connect_ports(src_comp.output_ports['out'],
-                            sink_comp.input_ports['in'])
-        self.assertEqual(src_comp, upstream_comp)
-        del upstream_comp
-
-
-@unittest.skip("this is broken")
 class OutputPortMessageIteratorTestCase(unittest.TestCase):
     def test_component(self):
         class MyIter(bt2._UserMessageIterator):
-            def __init__(self):
-                self._build_meta()
+            def __init__(self, self_port_output):
                 self._at = 0
 
-            def _build_meta(self):
-                self._trace = bt2.Trace()
-                self._sc = bt2.StreamClass()
-                self._ec = bt2.EventClass('salut')
-                self._my_int_fc = bt2.IntegerFieldClass(32)
-                self._ec.payload_field_class = bt2.StructureFieldClass()
-                self._ec.payload_field_class += collections.OrderedDict([
-                    ('my_int', self._my_int_fc),
-                ])
-                self._sc.add_event_class(self._ec)
-                self._trace.add_stream_class(self._sc)
-                self._stream = self._sc()
-                self._packet = self._stream.create_packet()
-
-            def _create_event(self, value):
-                ev = self._ec()
-                ev.payload_field['my_int'] = value
-                ev.packet = self._packet
-                return ev
-
             def __next__(self):
-                if self._at == 5:
+                if self._at == 7:
                     raise bt2.Stop
 
-                msg = bt2.EventMessage(self._create_event(self._at * 3))
+                if self._at == 0:
+                    msg = self._create_stream_beginning_message(test_obj._stream)
+                elif self._at == 1:
+                    msg = self._create_packet_beginning_message(test_obj._packet)
+                elif self._at == 5:
+                    msg = self._create_packet_end_message(test_obj._packet)
+                elif self._at == 6:
+                    msg = self._create_stream_end_message(test_obj._stream)
+                else:
+                    msg = self._create_event_message(test_obj._event_class, test_obj._packet)
+                    msg.event.payload_field['my_int'] = self._at * 3
+
                 self._at += 1
                 return msg
 
@@ -164,13 +130,42 @@ class OutputPortMessageIteratorTestCase(unittest.TestCase):
             def __init__(self, params):
                 self._add_output_port('out')
 
+                trace_class = self._create_trace_class()
+                stream_class = trace_class.create_stream_class()
+
+                # Create payload field class
+                my_int_ft = trace_class.create_signed_integer_field_class(32)
+                payload_ft = trace_class.create_structure_field_class()
+                payload_ft += collections.OrderedDict([
+                    ('my_int', my_int_ft),
+                ])
+
+                event_class = stream_class.create_event_class(name='salut', payload_field_class=payload_ft)
+
+                trace = trace_class()
+                stream = trace.create_stream(stream_class)
+                packet = stream.create_packet()
+
+                test_obj._event_class = event_class
+                test_obj._stream = stream
+                test_obj._packet = packet
+
+        test_obj = self
         graph = bt2.Graph()
         src = graph.add_component(MySource, 'src')
-        types = [bt2.EventMessage]
-        msg_iter = src.output_ports['out'].create_message_iterator(types)
+        msg_iter = graph.create_output_port_message_iterator(src.output_ports['out'])
 
         for at, msg in enumerate(msg_iter):
-            self.assertIsInstance(msg, bt2.EventMessage)
-            self.assertEqual(msg.event.event_class.name, 'salut')
-            field = msg.event.payload_field['my_int']
-            self.assertEqual(field, at * 3)
+            if at == 0:
+                self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
+            elif at == 1:
+                self.assertIsInstance(msg, bt2.message._PacketBeginningMessage)
+            elif at == 5:
+                self.assertIsInstance(msg, bt2.message._PacketEndMessage)
+            elif at == 6:
+                self.assertIsInstance(msg, bt2.message._StreamEndMessage)
+            else:
+                self.assertIsInstance(msg, bt2.message._EventMessage)
+                self.assertEqual(msg.event.event_class.name, 'salut')
+                field = msg.event.payload_field['my_int']
+                self.assertEqual(field, at * 3)
This page took 0.031736 seconds and 4 git commands to generate.