X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=tests%2Fbindings%2Fpython%2Fbt2%2Ftest_message_iterator.py;h=9635672fe792b409e53bf81cb8bc68036be81439;hb=57081273d1191fc79edc101af619fab96b72460d;hp=c8d2cf16806d22566149995ee8c97c9c15906f34;hpb=3fb99a226ccb40c79de6b55b5a249d93b9c5262e;p=babeltrace.git diff --git a/tests/bindings/python/bt2/test_message_iterator.py b/tests/bindings/python/bt2/test_message_iterator.py index c8d2cf16..9635672f 100644 --- a/tests/bindings/python/bt2/test_message_iterator.py +++ b/tests/bindings/python/bt2/test_message_iterator.py @@ -16,24 +16,23 @@ # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. # -from bt2 import value -import collections import unittest -import copy import bt2 +import sys +from utils import TestOutputPortMessageIterator class UserMessageIteratorTestCase(unittest.TestCase): @staticmethod def _create_graph(src_comp_cls, flt_comp_cls=None): class MySink(bt2._UserSinkComponent): - def __init__(self, params): + def __init__(self, params, obj): self._add_input_port('in') - def _consume(self): + def _user_consume(self): next(self._msg_iter) - def _graph_is_configured(self): + def _user_graph_is_configured(self): self._msg_iter = self._create_input_port_message_iterator( self._input_ports['in'] ) @@ -70,7 +69,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): the_output_port_from_iter = self_port_output class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): + def __init__(self, params, obj): nonlocal the_output_port_from_source the_output_port_from_source = self._add_output_port('out', 'user data') @@ -90,7 +89,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): src_iter_initialized = True class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter): - def __init__(self, params): + def __init__(self, params, obj): self._add_output_port('out') class MyFilterIter(bt2._UserMessageIterator): @@ -105,7 +104,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): return next(self._up_iter) class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter): - def __init__(self, params): + def __init__(self, params, obj): self._add_input_port('in') self._add_output_port('out') @@ -118,12 +117,12 @@ class UserMessageIteratorTestCase(unittest.TestCase): def test_finalize(self): class MyIter(bt2._UserMessageIterator): - def _finalize(self): + def _user_finalize(self): nonlocal finalized finalized = True class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): + def __init__(self, params, obj): self._add_output_port('out') finalized = False @@ -139,7 +138,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): salut = self._component._salut class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): + def __init__(self, params, obj): self._add_output_port('out') self._salut = 23 @@ -155,7 +154,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): addr = self.addr class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): + def __init__(self, params, obj): self._add_output_port('out') addr = None @@ -188,7 +187,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): return self._msgs.pop(0) class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): + def __init__(self, params, obj): tc = self._create_trace_class() sc = tc.create_stream_class(supports_packets=True) ec = sc.create_event_class() @@ -196,7 +195,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): graph = bt2.Graph() src = graph.add_component(MySource, 'src') - it = graph.create_output_port_message_iterator(src.output_ports['out']) + it = TestOutputPortMessageIterator(graph, src.output_ports['out']) # Skip beginning messages. msg = next(it) @@ -212,7 +211,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): self.assertEqual(msg_ev1.addr, msg_ev2.addr) @staticmethod - def _setup_seek_beginning_test(): + def _setup_seek_beginning_test(sink_cls): # Use a source, a filter and an output port iterator. This allows us # to test calling `seek_beginning` on both a _OutputPortMessageIterator # and a _UserComponentInputPortMessageIterator, on top of checking that @@ -235,7 +234,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): ] self._at = 0 - def _seek_beginning(self): + def _user_seek_beginning(self): self._at = 0 def __next__(self): @@ -247,7 +246,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): raise StopIteration class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter): - def __init__(self, params): + def __init__(self, params, obj): tc = self._create_trace_class() sc = tc.create_stream_class(supports_packets=True) ec = sc.create_event_class() @@ -264,40 +263,57 @@ class UserMessageIteratorTestCase(unittest.TestCase): def __next__(self): return next(self._upstream_iter) - def _seek_beginning(self): + def _user_seek_beginning(self): self._upstream_iter.seek_beginning() @property - def _can_seek_beginning(self): + def _user_can_seek_beginning(self): return self._upstream_iter.can_seek_beginning class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter): - def __init__(self, params): + def __init__(self, params, obj): input_port = self._add_input_port('in') self._add_output_port('out', input_port) graph = bt2.Graph() src = graph.add_component(MySource, 'src') flt = graph.add_component(MyFilter, 'flt') + sink = graph.add_component(sink_cls, 'sink') graph.connect_ports(src.output_ports['out'], flt.input_ports['in']) - it = graph.create_output_port_message_iterator(flt.output_ports['out']) - - return it, MySourceIter + graph.connect_ports(flt.output_ports['out'], sink.input_ports['in']) + return MySourceIter, graph def test_can_seek_beginning(self): - it, MySourceIter = self._setup_seek_beginning_test() + class MySink(bt2._UserSinkComponent): + def __init__(self, params, obj): + self._add_input_port('in') + + def _user_graph_is_configured(self): + self._msg_iter = self._create_input_port_message_iterator( + self._input_ports['in'] + ) - def _can_seek_beginning(self): - nonlocal can_seek_beginning - return can_seek_beginning + def _user_consume(self): + nonlocal can_seek_beginning + can_seek_beginning = self._msg_iter.can_seek_beginning - MySourceIter._can_seek_beginning = property(_can_seek_beginning) + MySourceIter, graph = self._setup_seek_beginning_test(MySink) - can_seek_beginning = True - self.assertTrue(it.can_seek_beginning) + def _user_can_seek_beginning(self): + nonlocal input_port_iter_can_seek_beginning + return input_port_iter_can_seek_beginning - can_seek_beginning = False - self.assertFalse(it.can_seek_beginning) + MySourceIter._user_can_seek_beginning = property(_user_can_seek_beginning) + + input_port_iter_can_seek_beginning = True + can_seek_beginning = None + graph.run_once() + self.assertTrue(can_seek_beginning) + + input_port_iter_can_seek_beginning = False + can_seek_beginning = None + graph.run_once() + self.assertFalse(can_seek_beginning) # Once can_seek_beginning returns an error, verify that it raises when # _can_seek_beginning has/returns the wrong type. @@ -305,51 +321,71 @@ class UserMessageIteratorTestCase(unittest.TestCase): # Remove the _can_seek_beginning method, we now rely on the presence of # a _seek_beginning method to know whether the iterator can seek to # beginning or not. - del MySourceIter._can_seek_beginning - self.assertTrue(it.can_seek_beginning) + del MySourceIter._user_can_seek_beginning + can_seek_beginning = None + graph.run_once() + self.assertTrue(can_seek_beginning) - del MySourceIter._seek_beginning - self.assertFalse(it.can_seek_beginning) + del MySourceIter._user_seek_beginning + can_seek_beginning = None + graph.run_once() + self.assertFalse(can_seek_beginning) def test_seek_beginning(self): - it, MySourceIter = self._setup_seek_beginning_test() + class MySink(bt2._UserSinkComponent): + def __init__(self, params, obj): + self._add_input_port('in') - msg = next(it) + def _user_graph_is_configured(self): + self._msg_iter = self._create_input_port_message_iterator( + self._input_ports['in'] + ) + + def _user_consume(self): + nonlocal do_seek_beginning + nonlocal msg + + if do_seek_beginning: + self._msg_iter.seek_beginning() + return + + msg = next(self._msg_iter) + + do_seek_beginning = False + msg = None + MySourceIter, graph = self._setup_seek_beginning_test(MySink) + graph.run_once() self.assertIsInstance(msg, bt2._StreamBeginningMessage) - msg = next(it) + graph.run_once() self.assertIsInstance(msg, bt2._PacketBeginningMessage) + do_seek_beginning = True + graph.run_once() + do_seek_beginning = False + graph.run_once() + self.assertIsInstance(msg, bt2._StreamBeginningMessage) - it.seek_beginning() + def test_seek_beginning_user_error(self): + class MySink(bt2._UserSinkComponent): + def __init__(self, params, obj): + self._add_input_port('in') - msg = next(it) - self.assertIsInstance(msg, bt2._StreamBeginningMessage) + def _user_graph_is_configured(self): + self._msg_iter = self._create_input_port_message_iterator( + self._input_ports['in'] + ) - # Verify that we can seek beginning after having reached the end. - # - # It currently does not work to seek an output port message iterator - # once it's ended, but we should eventually make it work and uncomment - # the following snippet. - # - # try: - # while True: - # next(it) - # except bt2.Stop: - # pass - # - # it.seek_beginning() - # msg = next(it) - # self.assertIsInstance(msg, bt2._StreamBeginningMessage) + def _user_consume(self): + self._msg_iter.seek_beginning() - def test_seek_beginning_user_error(self): - it, MySourceIter = self._setup_seek_beginning_test() + MySourceIter, graph = self._setup_seek_beginning_test(MySink) - def _seek_beginning_error(self): + def _user_seek_beginning_error(self): raise ValueError('ouch') - MySourceIter._seek_beginning = _seek_beginning_error + MySourceIter._user_seek_beginning = _user_seek_beginning_error with self.assertRaises(bt2._Error): - it.seek_beginning() + graph.run_once() # Try consuming many times from an iterator that always returns TryAgain. # This verifies that we are not missing an incref of Py_None, making the @@ -360,92 +396,22 @@ class UserMessageIteratorTestCase(unittest.TestCase): raise bt2.TryAgain class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): + def __init__(self, params, obj): self._add_output_port('out') graph = bt2.Graph() src = graph.add_component(MySource, 'src') - it = graph.create_output_port_message_iterator(src.output_ports['out']) + it = TestOutputPortMessageIterator(graph, src.output_ports['out']) - # The initial refcount of Py_None was in the 7000, so 100000 iterations - # should be enough to catch the bug even if there are small differences + # Three times the initial ref count of `None` iterations should + # be enough to catch the bug even if there are small differences # between configurations. - for i in range(100000): + none_ref_count = sys.getrefcount(None) * 3 + + for i in range(none_ref_count): with self.assertRaises(bt2.TryAgain): next(it) -class OutputPortMessageIteratorTestCase(unittest.TestCase): - def test_component(self): - class MyIter(bt2._UserMessageIterator): - def __init__(self, self_port_output): - self._at = 0 - - def __next__(self): - if self._at == 7: - raise bt2.Stop - - if self._at == 0: - msg = self._create_stream_beginning_message(test_obj._stream) - elif self._at == 1: - msg = self._create_packet_beginning_message(test_obj._packet) - elif self._at == 5: - msg = self._create_packet_end_message(test_obj._packet) - elif self._at == 6: - msg = self._create_stream_end_message(test_obj._stream) - else: - msg = self._create_event_message( - test_obj._event_class, test_obj._packet - ) - msg.event.payload_field['my_int'] = self._at * 3 - - self._at += 1 - return msg - - class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): - def __init__(self, params): - self._add_output_port('out') - - trace_class = self._create_trace_class() - stream_class = trace_class.create_stream_class(supports_packets=True) - - # Create payload field class - my_int_ft = trace_class.create_signed_integer_field_class(32) - payload_ft = trace_class.create_structure_field_class() - payload_ft += [('my_int', my_int_ft)] - - event_class = stream_class.create_event_class( - name='salut', payload_field_class=payload_ft - ) - - trace = trace_class() - stream = trace.create_stream(stream_class) - packet = stream.create_packet() - - test_obj._event_class = event_class - test_obj._stream = stream - test_obj._packet = packet - - test_obj = self - graph = bt2.Graph() - src = graph.add_component(MySource, 'src') - msg_iter = graph.create_output_port_message_iterator(src.output_ports['out']) - - for at, msg in enumerate(msg_iter): - if at == 0: - self.assertIsInstance(msg, bt2._StreamBeginningMessage) - elif at == 1: - self.assertIsInstance(msg, bt2._PacketBeginningMessage) - elif at == 5: - self.assertIsInstance(msg, bt2._PacketEndMessage) - elif at == 6: - self.assertIsInstance(msg, bt2._StreamEndMessage) - else: - self.assertIsInstance(msg, bt2._EventMessage) - self.assertEqual(msg.event.cls.name, 'salut') - field = msg.event.payload_field['my_int'] - self.assertEqual(field, at * 3) - - if __name__ == '__main__': unittest.main()