X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=tests%2Fbindings%2Fpython%2Fbt2%2Ftest_message_iterator.py;h=f780cf2c75a9a3052c00c92903253d717fabff37;hb=ca02df0ad8ae9a1a3640956d91ca31059d0b203a;hp=0f9182f26b8ea9fb95e4758c8395f64c6913fcb3;hpb=51f97fcc8c58d17aa224c150735f7e256e796f4e;p=babeltrace.git diff --git a/tests/bindings/python/bt2/test_message_iterator.py b/tests/bindings/python/bt2/test_message_iterator.py index 0f9182f2..f780cf2c 100644 --- a/tests/bindings/python/bt2/test_message_iterator.py +++ b/tests/bindings/python/bt2/test_message_iterator.py @@ -25,7 +25,7 @@ import bt2 class UserMessageIteratorTestCase(unittest.TestCase): @staticmethod - def _create_graph(src_comp_cls): + def _create_graph(src_comp_cls, flt_comp_cls=None): class MySink(bt2._UserSinkComponent): def __init__(self, params): self._add_input_port('in') @@ -34,12 +34,28 @@ class UserMessageIteratorTestCase(unittest.TestCase): next(self._msg_iter) def _graph_is_configured(self): - self._msg_iter = self._input_ports['in'].create_message_iterator() + self._msg_iter = self._create_input_port_message_iterator( + self._input_ports['in'] + ) graph = bt2.Graph() src_comp = graph.add_component(src_comp_cls, 'src') + + if flt_comp_cls is not None: + flt_comp = graph.add_component(flt_comp_cls, 'flt') + sink_comp = graph.add_component(MySink, 'sink') - graph.connect_ports(src_comp.output_ports['out'], sink_comp.input_ports['in']) + + if flt_comp_cls is not None: + assert flt_comp is not None + graph.connect_ports( + src_comp.output_ports['out'], flt_comp.input_ports['in'] + ) + out_port = flt_comp.output_ports['out'] + else: + out_port = src_comp.output_ports['out'] + + graph.connect_ports(out_port, sink_comp.input_ports['in']) return graph def test_init(self): @@ -67,6 +83,39 @@ class UserMessageIteratorTestCase(unittest.TestCase): ) self.assertEqual(the_output_port_from_iter.user_data, 'user data') + def test_create_from_message_iterator(self): + class MySourceIter(bt2._UserMessageIterator): + def __init__(self, self_port_output): + nonlocal src_iter_initialized + src_iter_initialized = True + + class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter): + def __init__(self, params): + self._add_output_port('out') + + class MyFilterIter(bt2._UserMessageIterator): + def __init__(self, self_port_output): + nonlocal flt_iter_initialized + flt_iter_initialized = True + self._up_iter = self._create_input_port_message_iterator( + self._component._input_ports['in'] + ) + + def __next__(self): + return next(self._up_iter) + + class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter): + def __init__(self, params): + self._add_input_port('in') + self._add_output_port('out') + + src_iter_initialized = False + flt_iter_initialized = False + graph = self._create_graph(MySource, MyFilter) + graph.run() + self.assertTrue(src_iter_initialized) + self.assertTrue(flt_iter_initialized) + def test_finalize(self): class MyIter(bt2._UserMessageIterator): def _finalize(self): @@ -208,7 +257,9 @@ class UserMessageIteratorTestCase(unittest.TestCase): class MyFilterIter(bt2._UserMessageIterator): def __init__(self, port): input_port = port.user_data - self._upstream_iter = input_port.create_message_iterator() + self._upstream_iter = self._create_input_port_message_iterator( + input_port + ) def __next__(self): return next(self._upstream_iter)