+ flt = graph.add_component(MyFilter, 'flt')
+ sink = graph.add_component(sink_cls, 'sink')
+ graph.connect_ports(src.output_ports['out'], flt.input_ports['in'])
+ graph.connect_ports(flt.output_ports['out'], sink.input_ports['in'])
+ return MySourceIter, graph
+
+ def test_can_seek_beginning(self):
+ class MySink(bt2._UserSinkComponent):
+ def __init__(self, params, obj):
+ self._add_input_port('in')
+
+ def _user_graph_is_configured(self):
+ self._msg_iter = self._create_input_port_message_iterator(
+ self._input_ports['in']
+ )
+
+ def _user_consume(self):
+ nonlocal can_seek_beginning
+ can_seek_beginning = self._msg_iter.can_seek_beginning
+
+ MySourceIter, graph = self._setup_seek_beginning_test(MySink)
+
+ def _user_can_seek_beginning(self):
+ nonlocal input_port_iter_can_seek_beginning
+ return input_port_iter_can_seek_beginning
+
+ MySourceIter._user_can_seek_beginning = property(_user_can_seek_beginning)
+
+ input_port_iter_can_seek_beginning = True
+ can_seek_beginning = None
+ graph.run_once()
+ self.assertTrue(can_seek_beginning)
+
+ input_port_iter_can_seek_beginning = False
+ can_seek_beginning = None
+ graph.run_once()
+ self.assertFalse(can_seek_beginning)
+
+ # Once can_seek_beginning returns an error, verify that it raises when
+ # _can_seek_beginning has/returns the wrong type.
+
+ # Remove the _can_seek_beginning method, we now rely on the presence of
+ # a _seek_beginning method to know whether the iterator can seek to
+ # beginning or not.
+ del MySourceIter._user_can_seek_beginning
+ can_seek_beginning = None
+ graph.run_once()
+ self.assertTrue(can_seek_beginning)
+
+ del MySourceIter._user_seek_beginning
+ can_seek_beginning = None
+ graph.run_once()
+ self.assertFalse(can_seek_beginning)
+
+ def test_seek_beginning(self):
+ class MySink(bt2._UserSinkComponent):
+ def __init__(self, params, obj):
+ self._add_input_port('in')
+
+ def _user_graph_is_configured(self):
+ self._msg_iter = self._create_input_port_message_iterator(
+ self._input_ports['in']
+ )
+
+ def _user_consume(self):
+ nonlocal do_seek_beginning
+ nonlocal msg
+
+ if do_seek_beginning:
+ self._msg_iter.seek_beginning()
+ return
+
+ msg = next(self._msg_iter)
+
+ do_seek_beginning = False
+ msg = None
+ MySourceIter, graph = self._setup_seek_beginning_test(MySink)
+ graph.run_once()
+ self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
+ graph.run_once()
+ self.assertIs(type(msg), bt2._PacketBeginningMessageConst)
+ do_seek_beginning = True
+ graph.run_once()
+ do_seek_beginning = False
+ graph.run_once()
+ self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
+
+ def test_seek_beginning_user_error(self):
+ class MySink(bt2._UserSinkComponent):
+ def __init__(self, params, obj):
+ self._add_input_port('in')
+
+ def _user_graph_is_configured(self):
+ self._msg_iter = self._create_input_port_message_iterator(
+ self._input_ports['in']
+ )
+
+ def _user_consume(self):
+ self._msg_iter.seek_beginning()
+
+ MySourceIter, graph = self._setup_seek_beginning_test(MySink)
+
+ def _user_seek_beginning_error(self):
+ raise ValueError('ouch')
+
+ MySourceIter._user_seek_beginning = _user_seek_beginning_error
+
+ with self.assertRaises(bt2._Error):
+ graph.run_once()
+
+ # Try consuming many times from an iterator that always returns TryAgain.
+ # This verifies that we are not missing an incref of Py_None, making the
+ # refcount of Py_None reach 0.
+ def test_try_again_many_times(self):
+ class MyIter(bt2._UserMessageIterator):
+ def __next__(self):
+ raise bt2.TryAgain
+
+ class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+ def __init__(self, params, obj):
+ self._add_output_port('out')
+
+ graph = bt2.Graph()
+ src = graph.add_component(MySource, 'src')
+ it = TestOutputPortMessageIterator(graph, src.output_ports['out'])
+
+ # Three times the initial ref count of `None` iterations should
+ # be enough to catch the bug even if there are small differences
+ # between configurations.
+ none_ref_count = sys.getrefcount(None) * 3
+
+ for i in range(none_ref_count):
+ with self.assertRaises(bt2.TryAgain):
+ next(it)
+
+
+if __name__ == '__main__':
+ unittest.main()