X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;f=tests%2Fbindings%2Fpython%2Fbt2%2Ftest_message_iterator.py;h=0f9182f26b8ea9fb95e4758c8395f64c6913fcb3;hb=694c792bc8f078c02acde68a3390acafbb36b2f4;hp=0ad56329da087136891185765fa6f18d5b83f8df;hpb=c5f330cd909f5dfbdb519546e875b4427434ba4f;p=babeltrace.git diff --git a/tests/bindings/python/bt2/test_message_iterator.py b/tests/bindings/python/bt2/test_message_iterator.py index 0ad56329..0f9182f2 100644 --- a/tests/bindings/python/bt2/test_message_iterator.py +++ b/tests/bindings/python/bt2/test_message_iterator.py @@ -1,3 +1,21 @@ +# +# Copyright (C) 2019 EfficiOS Inc. +# +# This program is free software; you can redistribute it and/or +# modify it under the terms of the GNU General Public License +# as published by the Free Software Foundation; only version 2 +# of the License. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. +# + from bt2 import value import collections import unittest @@ -21,8 +39,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): graph = bt2.Graph() src_comp = graph.add_component(src_comp_cls, 'src') sink_comp = graph.add_component(MySink, 'sink') - graph.connect_ports(src_comp.output_ports['out'], - sink_comp.input_ports['in']) + graph.connect_ports(src_comp.output_ports['out'], sink_comp.input_ports['in']) return graph def test_init(self): @@ -36,17 +53,19 @@ class UserMessageIteratorTestCase(unittest.TestCase): initialized = True the_output_port_from_iter = self_port_output - class MySource(bt2._UserSourceComponent, - message_iterator_class=MyIter): + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): def __init__(self, params): nonlocal the_output_port_from_source - the_output_port_from_source = self._add_output_port('out') + the_output_port_from_source = self._add_output_port('out', 'user data') initialized = False graph = self._create_graph(MySource) graph.run() self.assertTrue(initialized) - self.assertEqual(the_output_port_from_source.addr, the_output_port_from_iter.addr) + self.assertEqual( + the_output_port_from_source.addr, the_output_port_from_iter.addr + ) + self.assertEqual(the_output_port_from_iter.user_data, 'user data') def test_finalize(self): class MyIter(bt2._UserMessageIterator): @@ -54,8 +73,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): nonlocal finalized finalized = True - class MySource(bt2._UserSourceComponent, - message_iterator_class=MyIter): + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): def __init__(self, params): self._add_output_port('out') @@ -71,8 +89,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): nonlocal salut salut = self._component._salut - class MySource(bt2._UserSourceComponent, - message_iterator_class=MyIter): + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): def __init__(self, params): self._add_output_port('out') self._salut = 23 @@ -88,8 +105,7 @@ class UserMessageIteratorTestCase(unittest.TestCase): nonlocal addr addr = self.addr - class MySource(bt2._UserSourceComponent, - message_iterator_class=MyIter): + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): def __init__(self, params): self._add_output_port('out') @@ -99,6 +115,214 @@ class UserMessageIteratorTestCase(unittest.TestCase): self.assertIsNotNone(addr) self.assertNotEqual(addr, 0) + # Test that messages returned by _UserMessageIterator.__next__ remain valid + # and can be re-used. + def test_reuse_message(self): + class MyIter(bt2._UserMessageIterator): + def __init__(self, port): + tc, sc, ec = port.user_data + trace = tc() + stream = trace.create_stream(sc) + packet = stream.create_packet() + + # This message will be returned twice by __next__. + event_message = self._create_event_message(ec, packet) + + self._msgs = [ + self._create_stream_beginning_message(stream), + self._create_packet_beginning_message(packet), + event_message, + event_message, + ] + + def __next__(self): + return self._msgs.pop(0) + + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): + def __init__(self, params): + tc = self._create_trace_class() + sc = tc.create_stream_class(supports_packets=True) + ec = sc.create_event_class() + self._add_output_port('out', (tc, sc, ec)) + + graph = bt2.Graph() + src = graph.add_component(MySource, 'src') + it = graph.create_output_port_message_iterator(src.output_ports['out']) + + # Skip beginning messages. + msg = next(it) + self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + msg = next(it) + self.assertIsInstance(msg, bt2.message._PacketBeginningMessage) + + msg_ev1 = next(it) + msg_ev2 = next(it) + + self.assertIsInstance(msg_ev1, bt2.message._EventMessage) + self.assertIsInstance(msg_ev2, bt2.message._EventMessage) + self.assertEqual(msg_ev1.addr, msg_ev2.addr) + + @staticmethod + def _setup_seek_beginning_test(): + # Use a source, a filter and an output port iterator. This allows us + # to test calling `seek_beginning` on both a _OutputPortMessageIterator + # and a _UserComponentInputPortMessageIterator, on top of checking that + # _UserMessageIterator._seek_beginning is properly called. + + class MySourceIter(bt2._UserMessageIterator): + def __init__(self, port): + tc, sc, ec = port.user_data + trace = tc() + stream = trace.create_stream(sc) + packet = stream.create_packet() + + self._msgs = [ + self._create_stream_beginning_message(stream), + self._create_packet_beginning_message(packet), + self._create_event_message(ec, packet), + self._create_event_message(ec, packet), + self._create_packet_end_message(packet), + self._create_stream_end_message(stream), + ] + self._at = 0 + + def _seek_beginning(self): + self._at = 0 + + def __next__(self): + if self._at < len(self._msgs): + msg = self._msgs[self._at] + self._at += 1 + return msg + else: + raise StopIteration + + class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter): + def __init__(self, params): + tc = self._create_trace_class() + sc = tc.create_stream_class(supports_packets=True) + ec = sc.create_event_class() + + self._add_output_port('out', (tc, sc, ec)) + + class MyFilterIter(bt2._UserMessageIterator): + def __init__(self, port): + input_port = port.user_data + self._upstream_iter = input_port.create_message_iterator() + + def __next__(self): + return next(self._upstream_iter) + + def _seek_beginning(self): + self._upstream_iter.seek_beginning() + + @property + def _can_seek_beginning(self): + return self._upstream_iter.can_seek_beginning + + class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter): + def __init__(self, params): + input_port = self._add_input_port('in') + self._add_output_port('out', input_port) + + graph = bt2.Graph() + src = graph.add_component(MySource, 'src') + flt = graph.add_component(MyFilter, 'flt') + graph.connect_ports(src.output_ports['out'], flt.input_ports['in']) + it = graph.create_output_port_message_iterator(flt.output_ports['out']) + + return it, MySourceIter + + def test_can_seek_beginning(self): + it, MySourceIter = self._setup_seek_beginning_test() + + def _can_seek_beginning(self): + nonlocal can_seek_beginning + return can_seek_beginning + + MySourceIter._can_seek_beginning = property(_can_seek_beginning) + + can_seek_beginning = True + self.assertTrue(it.can_seek_beginning) + + can_seek_beginning = False + self.assertFalse(it.can_seek_beginning) + + # Once can_seek_beginning returns an error, verify that it raises when + # _can_seek_beginning has/returns the wrong type. + + # Remove the _can_seek_beginning method, we now rely on the presence of + # a _seek_beginning method to know whether the iterator can seek to + # beginning or not. + del MySourceIter._can_seek_beginning + self.assertTrue(it.can_seek_beginning) + + del MySourceIter._seek_beginning + self.assertFalse(it.can_seek_beginning) + + def test_seek_beginning(self): + it, MySourceIter = self._setup_seek_beginning_test() + + msg = next(it) + self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + msg = next(it) + self.assertIsInstance(msg, bt2.message._PacketBeginningMessage) + + it.seek_beginning() + + msg = next(it) + self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + + # Verify that we can seek beginning after having reached the end. + # + # It currently does not work to seek an output port message iterator + # once it's ended, but we should eventually make it work and uncomment + # the following snippet. + # + # try: + # while True: + # next(it) + # except bt2.Stop: + # pass + # + # it.seek_beginning() + # msg = next(it) + # self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + + def test_seek_beginning_user_error(self): + it, MySourceIter = self._setup_seek_beginning_test() + + def _seek_beginning_error(self): + raise ValueError('ouch') + + MySourceIter._seek_beginning = _seek_beginning_error + + with self.assertRaises(bt2._Error): + it.seek_beginning() + + # Try consuming many times from an iterator that always returns TryAgain. + # This verifies that we are not missing an incref of Py_None, making the + # refcount of Py_None reach 0. + def test_try_again_many_times(self): + class MyIter(bt2._UserMessageIterator): + def __next__(self): + raise bt2.TryAgain + + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): + def __init__(self, params): + self._add_output_port('out') + + graph = bt2.Graph() + src = graph.add_component(MySource, 'src') + it = graph.create_output_port_message_iterator(src.output_ports['out']) + + # The initial refcount of Py_None was in the 7000, so 100000 iterations + # should be enough to catch the bug even if there are small differences + # between configurations. + for i in range(100000): + with self.assertRaises(bt2.TryAgain): + next(it) + class OutputPortMessageIteratorTestCase(unittest.TestCase): def test_component(self): @@ -119,28 +343,29 @@ class OutputPortMessageIteratorTestCase(unittest.TestCase): elif self._at == 6: msg = self._create_stream_end_message(test_obj._stream) else: - msg = self._create_event_message(test_obj._event_class, test_obj._packet) + msg = self._create_event_message( + test_obj._event_class, test_obj._packet + ) msg.event.payload_field['my_int'] = self._at * 3 self._at += 1 return msg - class MySource(bt2._UserSourceComponent, - message_iterator_class=MyIter): + class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter): def __init__(self, params): self._add_output_port('out') trace_class = self._create_trace_class() - stream_class = trace_class.create_stream_class() + stream_class = trace_class.create_stream_class(supports_packets=True) # Create payload field class my_int_ft = trace_class.create_signed_integer_field_class(32) payload_ft = trace_class.create_structure_field_class() - payload_ft += collections.OrderedDict([ - ('my_int', my_int_ft), - ]) + payload_ft += [('my_int', my_int_ft)] - event_class = stream_class.create_event_class(name='salut', payload_field_class=payload_ft) + event_class = stream_class.create_event_class( + name='salut', payload_field_class=payload_ft + ) trace = trace_class() stream = trace.create_stream(stream_class) @@ -166,6 +391,10 @@ class OutputPortMessageIteratorTestCase(unittest.TestCase): self.assertIsInstance(msg, bt2.message._StreamEndMessage) else: self.assertIsInstance(msg, bt2.message._EventMessage) - self.assertEqual(msg.event.event_class.name, 'salut') + self.assertEqual(msg.event.cls.name, 'salut') field = msg.event.payload_field['my_int'] self.assertEqual(field, at * 3) + + +if __name__ == '__main__': + unittest.main()