From c5f330cd909f5dfbdb519546e875b4427434ba4f Mon Sep 17 00:00:00 2001 From: Simon Marchi Date: Mon, 3 Jun 2019 16:48:45 -0400 Subject: [PATCH] bt2: Adapt test_message_iterator.py and make it pass Update test_message_iterator.py to work with the current Babeltrace API. Most changes are related to the fact that it is not possible anymore to create most objects in isolation, everything is derived from a trace class. Message iterators are created in the _graph_is_configured callback rather than _port_connected. Because of this, we now need to call graph.run() in the tests to get _graph_is_configured called. I noticed that the component method of _SelfPortInputMessageIterator was a bit problematic, as we don't know the type of the component at that point. Returning the right type of component (the right Python class, wrapping the right type of Swig pointer) would require a bit of dirty work. I am not sure this is even useful, so I've just removed it and marked the tests as skipped / commented them out (they can be removed if we rule this is indeed unnecessary). An important change is the addition the self_output_port parameter in the constructor of _UserMessageIterator (which the user is expected to override). This parameter is how the iterator is supposed to know what it's supposed to iterate on (if the component has multiple ports). There are no other changes in message_iterator.py, since everything had to be done already for test_message.py to pass. Change-Id: I574d812ed81d9628a8e4a2a2dd5c593c4d730c95 Signed-off-by: Simon Marchi Signed-off-by: Francis Deslauriers Reviewed-on: https://review.lttng.org/c/babeltrace/+/1326 Tested-by: jenkins Reviewed-by: Philippe Proulx --- bindings/python/bt2/bt2/message_iterator.py | 25 +++- .../bt2/bt2/native_bt_component_class.i | 15 +- tests/bindings/python/bt2/test_clock_class.py | 2 +- tests/bindings/python/bt2/test_event.py | 2 +- tests/bindings/python/bt2/test_graph.py | 5 +- tests/bindings/python/bt2/test_message.py | 2 +- .../python/bt2/test_message_iterator.py | 137 +++++++++--------- 7 files changed, 103 insertions(+), 85 deletions(-) diff --git a/bindings/python/bt2/bt2/message_iterator.py b/bindings/python/bt2/bt2/message_iterator.py index 4cf362ad..7ac98f3d 100644 --- a/bindings/python/bt2/bt2/message_iterator.py +++ b/bindings/python/bt2/bt2/message_iterator.py @@ -60,22 +60,28 @@ class _GenericMessageIterator(object._SharedObject, _MessageIterator): return bt2.message._create_from_ptr(msg_ptr) +# This is created when a component wants to iterate on one of its input ports. class _UserComponentInputPortMessageIterator(_GenericMessageIterator): _get_msg_range = staticmethod(native_bt.py3_self_component_port_input_get_msg_range) - - @property - def component(self): - comp_ptr = native_bt.private_connection_message_iterator_get_component(self._ptr) - assert(comp_ptr) - return bt2.component._create_generic_component_from_ptr(comp_ptr) + _get_ref = staticmethod(native_bt.self_component_port_input_message_iterator_get_ref) + _put_ref = staticmethod(native_bt.self_component_port_input_message_iterator_put_ref) +# This is created when the user wants to iterate on a component's output port, +# from outside the graph. class _OutputPortMessageIterator(_GenericMessageIterator): _get_msg_range = staticmethod(native_bt.py3_port_output_get_msg_range) _get_ref = staticmethod(native_bt.port_output_message_iterator_get_ref) _put_ref = staticmethod(native_bt.port_output_message_iterator_put_ref) +# This is extended by the user to implement component classes in Python. It +# is created for a given output port when an input port message iterator is +# created on the input port on the other side of the connection. It is also +# created when an output port message iterator is created on this output port. +# +# Its purpose is to feed the messages that should go out through this output +# port. class _UserMessageIterator(_MessageIterator): def __new__(cls, ptr): # User iterator objects are always created by the native side, @@ -92,7 +98,12 @@ class _UserMessageIterator(_MessageIterator): self._ptr = ptr return self - def __init__(self): + def _init_from_native(self, self_output_port_ptr): + self_output_port = bt2.port._create_self_from_ptr_and_get_ref( + self_output_port_ptr, native_bt.PORT_TYPE_OUTPUT) + self.__init__(self_output_port) + + def __init__(self, output_port): pass @property diff --git a/bindings/python/bt2/bt2/native_bt_component_class.i b/bindings/python/bt2/bt2/native_bt_component_class.i index deca5a38..1dfa0110 100644 --- a/bindings/python/bt2/bt2/native_bt_component_class.i +++ b/bindings/python/bt2/bt2/native_bt_component_class.i @@ -1340,6 +1340,7 @@ bt_py3_component_class_message_iterator_init( PyObject *py_comp_cls = NULL; PyObject *py_iter_cls = NULL; PyObject *py_iter_ptr = NULL; + PyObject *py_component_port_output_ptr = NULL; PyObject *py_init_method_result = NULL; PyObject *py_iter = NULL; PyObject *py_comp; @@ -1384,14 +1385,23 @@ bt_py3_component_class_message_iterator_init( /* * Initialize object: * - * py_iter.__init__() + * py_iter.__init__(self_output_port) + * + * through the _init_for_native helper static method. * * At this point, py_iter._ptr is set, so this initialization * function has access to self._component (which gives it the * user Python component object from which the iterator was * created). */ - py_init_method_result = PyObject_CallMethod(py_iter, "__init__", NULL); + py_component_port_output_ptr = SWIG_NewPointerObj(SWIG_as_voidptr(self_component_port_output), + SWIGTYPE_p_bt_self_component_port_output, 0); + if (!py_component_port_output_ptr) { + BT_LOGE_STR("Failed to create a SWIG pointer object."); + goto error; + } + + py_init_method_result = PyObject_CallMethod(py_iter, "_init_from_native", "O", py_component_port_output_ptr); if (!py_init_method_result) { BT_LOGE_STR("User's __init__() method failed."); bt2_py_loge_exception(); @@ -1441,6 +1451,7 @@ end: Py_XDECREF(py_comp_cls); Py_XDECREF(py_iter_cls); Py_XDECREF(py_iter_ptr); + Py_XDECREF(py_component_port_output_ptr); Py_XDECREF(py_init_method_result); Py_XDECREF(py_iter); return status; diff --git a/tests/bindings/python/bt2/test_clock_class.py b/tests/bindings/python/bt2/test_clock_class.py index de0c001c..8be327bd 100644 --- a/tests/bindings/python/bt2/test_clock_class.py +++ b/tests/bindings/python/bt2/test_clock_class.py @@ -201,7 +201,7 @@ class ClockSnapshotTestCase(unittest.TestCase): self._cc = _cc class MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_port_output): self._at = 0 def __next__(self): diff --git a/tests/bindings/python/bt2/test_event.py b/tests/bindings/python/bt2/test_event.py index 6f818cec..4c21ee7f 100644 --- a/tests/bindings/python/bt2/test_event.py +++ b/tests/bindings/python/bt2/test_event.py @@ -11,7 +11,7 @@ class EventTestCase(unittest.TestCase): with_ep=False): class MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_output_port): self._at = 0 def __next__(self): diff --git a/tests/bindings/python/bt2/test_graph.py b/tests/bindings/python/bt2/test_graph.py index 8ed585a6..48e02873 100644 --- a/tests/bindings/python/bt2/test_graph.py +++ b/tests/bindings/python/bt2/test_graph.py @@ -6,7 +6,7 @@ import bt2 class _MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_output_port): self._build_meta() self._at = 0 @@ -496,7 +496,8 @@ class GraphTestCase(unittest.TestCase): def _consume(self): raise bt2.Stop - def ports_connected_listener(upstream_port, downstream_port): + def ports_connected_listener(upstream_component, upstream_port, + downstream_component, downstream_port): raise ValueError('oh noes!') graph = bt2.Graph() diff --git a/tests/bindings/python/bt2/test_message.py b/tests/bindings/python/bt2/test_message.py index 521d4e4d..3cc55042 100644 --- a/tests/bindings/python/bt2/test_message.py +++ b/tests/bindings/python/bt2/test_message.py @@ -7,7 +7,7 @@ class AllMessagesTestCase(unittest.TestCase): def setUp(self): class MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_port_output): self._at = 0 def __next__(self): diff --git a/tests/bindings/python/bt2/test_message_iterator.py b/tests/bindings/python/bt2/test_message_iterator.py index 455fca55..0ad56329 100644 --- a/tests/bindings/python/bt2/test_message_iterator.py +++ b/tests/bindings/python/bt2/test_message_iterator.py @@ -5,7 +5,6 @@ import copy import bt2 -@unittest.skip("this is broken") class UserMessageIteratorTestCase(unittest.TestCase): @staticmethod def _create_graph(src_comp_cls): @@ -16,8 +15,8 @@ class UserMessageIteratorTestCase(unittest.TestCase): def _consume(self): next(self._msg_iter) - def _port_connected(self, port, other_port): - self._msg_iter = port.connection.create_message_iterator() + def _graph_is_configured(self): + self._msg_iter = self._input_ports['in'].create_message_iterator() graph = bt2.Graph() src_comp = graph.add_component(src_comp_cls, 'src') @@ -27,19 +26,27 @@ class UserMessageIteratorTestCase(unittest.TestCase): return graph def test_init(self): + the_output_port_from_source = None + the_output_port_from_iter = None + class MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_port_output): nonlocal initialized + nonlocal the_output_port_from_iter initialized = True + the_output_port_from_iter = self_port_output class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): def __init__(self, params): - self._add_output_port('out') + nonlocal the_output_port_from_source + the_output_port_from_source = self._add_output_port('out') initialized = False graph = self._create_graph(MySource) + graph.run() self.assertTrue(initialized) + self.assertEqual(the_output_port_from_source.addr, the_output_port_from_iter.addr) def test_finalize(self): class MyIter(bt2._UserMessageIterator): @@ -54,12 +61,13 @@ class UserMessageIteratorTestCase(unittest.TestCase): finalized = False graph = self._create_graph(MySource) + graph.run() del graph self.assertTrue(finalized) def test_component(self): class MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_port_output): nonlocal salut salut = self._component._salut @@ -71,11 +79,12 @@ class UserMessageIteratorTestCase(unittest.TestCase): salut = None graph = self._create_graph(MySource) + graph.run() self.assertEqual(salut, 23) def test_addr(self): class MyIter(bt2._UserMessageIterator): - def __init__(self): + def __init__(self, self_port_output): nonlocal addr addr = self.addr @@ -86,76 +95,33 @@ class UserMessageIteratorTestCase(unittest.TestCase): addr = None graph = self._create_graph(MySource) + graph.run() self.assertIsNotNone(addr) self.assertNotEqual(addr, 0) -@unittest.skip("this is broken") -class PrivateConnectionMessageIteratorTestCase(unittest.TestCase): - def test_component(self): - class MyIter(bt2._UserMessageIterator): - pass - - class MySource(bt2._UserSourceComponent, - message_iterator_class=MyIter): - def __init__(self, params): - self._add_output_port('out') - - class MySink(bt2._UserSinkComponent): - def __init__(self, params): - self._add_input_port('in') - - def _consume(self): - next(self._msg_iter) - - def _port_connected(self, port, other_port): - nonlocal upstream_comp - self._msg_iter = port.connection.create_message_iterator() - upstream_comp = self._msg_iter.component - - upstream_comp = None - graph = bt2.Graph() - src_comp = graph.add_component(MySource, 'src') - sink_comp = graph.add_component(MySink, 'sink') - graph.connect_ports(src_comp.output_ports['out'], - sink_comp.input_ports['in']) - self.assertEqual(src_comp, upstream_comp) - del upstream_comp - - -@unittest.skip("this is broken") class OutputPortMessageIteratorTestCase(unittest.TestCase): def test_component(self): class MyIter(bt2._UserMessageIterator): - def __init__(self): - self._build_meta() + def __init__(self, self_port_output): self._at = 0 - def _build_meta(self): - self._trace = bt2.Trace() - self._sc = bt2.StreamClass() - self._ec = bt2.EventClass('salut') - self._my_int_fc = bt2.IntegerFieldClass(32) - self._ec.payload_field_class = bt2.StructureFieldClass() - self._ec.payload_field_class += collections.OrderedDict([ - ('my_int', self._my_int_fc), - ]) - self._sc.add_event_class(self._ec) - self._trace.add_stream_class(self._sc) - self._stream = self._sc() - self._packet = self._stream.create_packet() - - def _create_event(self, value): - ev = self._ec() - ev.payload_field['my_int'] = value - ev.packet = self._packet - return ev - def __next__(self): - if self._at == 5: + if self._at == 7: raise bt2.Stop - msg = bt2.EventMessage(self._create_event(self._at * 3)) + if self._at == 0: + msg = self._create_stream_beginning_message(test_obj._stream) + elif self._at == 1: + msg = self._create_packet_beginning_message(test_obj._packet) + elif self._at == 5: + msg = self._create_packet_end_message(test_obj._packet) + elif self._at == 6: + msg = self._create_stream_end_message(test_obj._stream) + else: + msg = self._create_event_message(test_obj._event_class, test_obj._packet) + msg.event.payload_field['my_int'] = self._at * 3 + self._at += 1 return msg @@ -164,13 +130,42 @@ class OutputPortMessageIteratorTestCase(unittest.TestCase): def __init__(self, params): self._add_output_port('out') + trace_class = self._create_trace_class() + stream_class = trace_class.create_stream_class() + + # Create payload field class + my_int_ft = trace_class.create_signed_integer_field_class(32) + payload_ft = trace_class.create_structure_field_class() + payload_ft += collections.OrderedDict([ + ('my_int', my_int_ft), + ]) + + event_class = stream_class.create_event_class(name='salut', payload_field_class=payload_ft) + + trace = trace_class() + stream = trace.create_stream(stream_class) + packet = stream.create_packet() + + test_obj._event_class = event_class + test_obj._stream = stream + test_obj._packet = packet + + test_obj = self graph = bt2.Graph() src = graph.add_component(MySource, 'src') - types = [bt2.EventMessage] - msg_iter = src.output_ports['out'].create_message_iterator(types) + msg_iter = graph.create_output_port_message_iterator(src.output_ports['out']) for at, msg in enumerate(msg_iter): - self.assertIsInstance(msg, bt2.EventMessage) - self.assertEqual(msg.event.event_class.name, 'salut') - field = msg.event.payload_field['my_int'] - self.assertEqual(field, at * 3) + if at == 0: + self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + elif at == 1: + self.assertIsInstance(msg, bt2.message._PacketBeginningMessage) + elif at == 5: + self.assertIsInstance(msg, bt2.message._PacketEndMessage) + elif at == 6: + self.assertIsInstance(msg, bt2.message._StreamEndMessage) + else: + self.assertIsInstance(msg, bt2.message._EventMessage) + self.assertEqual(msg.event.event_class.name, 'salut') + field = msg.event.payload_field['my_int'] + self.assertEqual(field, at * 3) -- 2.34.1