From f00b8d402364ed0b77c45a61467c1d27690222f7 Mon Sep 17 00:00:00 2001 From: Simon Marchi Date: Wed, 26 Jun 2019 23:33:58 -0400 Subject: [PATCH] bt2: let Python message iterators implement seek beginning This patch adds the possibility for message iterators implemented in Python to implement the "can seek beginning" and "seek beginning" operations. A message iterator can support "seek beginning" by defining a _seek_beginning method: class MyIter(bt2._UserMessageIterator): def _seek_beginning(self): # Do the needful. It can support the "can seek beginning" operation by defining a _can_seek_beginning property, which must evaluate to a bool: class MyIter(bt2._UserMessageIterator): @property def _can_seek_beginning(self): # Do the needful, including returning a bool. The behavior of the "can seek beginning" operation is made to mimic the C API: - If the iterator has a _can_seek_beginning attribute, it is used to determine whether the iterator can seek beginning. - Otherwise, the presence or absence of a _seek_beginning method indicates whether it can. Change-Id: I68c48b5cd30090bff833529cda9fa918c7e72b0b Signed-off-by: Simon Marchi Reviewed-on: https://review.lttng.org/c/babeltrace/+/1555 Reviewed-by: Philippe Proulx --- .../python/bt2/bt2/message_iterator.py | 51 +++++- .../bt2/bt2/native_bt_component_class.i | 66 +++++++- .../python/bt2/test_message_iterator.py | 146 ++++++++++++++++++ 3 files changed, 254 insertions(+), 9 deletions(-) diff --git a/src/bindings/python/bt2/bt2/message_iterator.py b/src/bindings/python/bt2/bt2/message_iterator.py index b07e2450..b56b17a2 100644 --- a/src/bindings/python/bt2/bt2/message_iterator.py +++ b/src/bindings/python/bt2/bt2/message_iterator.py @@ -28,14 +28,6 @@ import bt2 class _MessageIterator(collections.abc.Iterator): - def _handle_status(self, status, gen_error_msg): - if status == native_bt.MESSAGE_ITERATOR_STATUS_AGAIN: - raise bt2.TryAgain - elif status == native_bt.MESSAGE_ITERATOR_STATUS_END: - raise bt2.Stop - elif status < 0: - raise bt2.Error(gen_error_msg) - def __next__(self): raise NotImplementedError @@ -46,6 +38,14 @@ class _GenericMessageIterator(object._SharedObject, _MessageIterator): self._at = 0 super().__init__(ptr) + def _handle_status(self, status, gen_error_msg): + if status == native_bt.MESSAGE_ITERATOR_STATUS_AGAIN: + raise bt2.TryAgain + elif status == native_bt.MESSAGE_ITERATOR_STATUS_END: + raise bt2.Stop + elif status < 0: + raise bt2.Error(gen_error_msg) + def __next__(self): if len(self._current_msgs) == self._at: status, msgs = self._get_msg_range(self._ptr) @@ -59,12 +59,27 @@ class _GenericMessageIterator(object._SharedObject, _MessageIterator): return bt2.message._create_from_ptr(msg_ptr) + @property + def can_seek_beginning(self): + res = self._can_seek_beginning(self._ptr) + return res != 0 + + def seek_beginning(self): + # Forget about buffered messages, they won't be valid after seeking.. + self._current_msgs.clear() + self._at = 0 + + status = self._seek_beginning(self._ptr) + self._handle_status(status, 'cannot seek message iterator beginning') + # This is created when a component wants to iterate on one of its input ports. class _UserComponentInputPortMessageIterator(_GenericMessageIterator): _get_msg_range = staticmethod(native_bt.py3_self_component_port_input_get_msg_range) _get_ref = staticmethod(native_bt.self_component_port_input_message_iterator_get_ref) _put_ref = staticmethod(native_bt.self_component_port_input_message_iterator_put_ref) + _can_seek_beginning = staticmethod(native_bt.self_component_port_input_message_iterator_can_seek_beginning) + _seek_beginning = staticmethod(native_bt.self_component_port_input_message_iterator_seek_beginning) # This is created when the user wants to iterate on a component's output port, @@ -73,6 +88,8 @@ class _OutputPortMessageIterator(_GenericMessageIterator): _get_msg_range = staticmethod(native_bt.py3_port_output_get_msg_range) _get_ref = staticmethod(native_bt.port_output_message_iterator_get_ref) _put_ref = staticmethod(native_bt.port_output_message_iterator_put_ref) + _can_seek_beginning = staticmethod(native_bt.port_output_message_iterator_can_seek_beginning) + _seek_beginning = staticmethod(native_bt.port_output_message_iterator_seek_beginning) # This is extended by the user to implement component classes in Python. It @@ -137,6 +154,24 @@ class _UserMessageIterator(_MessageIterator): msg._get_ref(msg._ptr) return int(msg._ptr) + @property + def _can_seek_beginning_from_native(self): + # Here, we mimic the behavior of the C API: + # + # - If the iterator has a _can_seek_beginning attribute, read it and use + # that result. + # - Otherwise, the presence or absence of a `_seek_beginning` + # method indicates whether the iterator can seek beginning. + if hasattr(self, '_can_seek_beginning'): + can_seek_beginning = self._can_seek_beginning + utils._check_bool(can_seek_beginning) + return can_seek_beginning + else: + return hasattr(self, '_seek_beginning') + + def _seek_beginning_from_native(self): + self._seek_beginning() + def _create_event_message(self, event_class, packet, default_clock_snapshot=None): utils._check_type(event_class, bt2.event_class._EventClass) utils._check_type(packet, bt2.packet._Packet) diff --git a/src/bindings/python/bt2/bt2/native_bt_component_class.i b/src/bindings/python/bt2/bt2/native_bt_component_class.i index 18078eea..8219040d 100644 --- a/src/bindings/python/bt2/bt2/native_bt_component_class.i +++ b/src/bindings/python/bt2/bt2/native_bt_component_class.i @@ -452,6 +452,58 @@ void bt_py3_component_class_sink_finalize(bt_self_component_sink *self_component bt_py3_component_class_finalize(self_component); } +static +bt_bool bt_py3_component_class_can_seek_beginning( + bt_self_message_iterator *self_message_iterator) +{ + PyObject *py_iter; + PyObject *py_result = NULL; + bt_bool can_seek_beginning = false; + + py_iter = bt_self_message_iterator_get_data(self_message_iterator); + BT_ASSERT(py_iter); + + py_result = PyObject_GetAttrString(py_iter, "_can_seek_beginning_from_native"); + + BT_ASSERT(!py_result || PyBool_Check(py_result)); + + if (py_result) { + can_seek_beginning = PyObject_IsTrue(py_result); + } else { + /* + * Once can_seek_beginning can report errors, convert the + * exception to a status. For now, log and return false; + */ + bt2_py_loge_exception(); + PyErr_Clear(); + } + + Py_XDECREF(py_result); + + return can_seek_beginning; +} + +static +bt_self_message_iterator_status bt_py3_component_class_seek_beginning( + bt_self_message_iterator *self_message_iterator) +{ + PyObject *py_iter; + PyObject *py_result; + bt_self_message_iterator_status status; + + py_iter = bt_self_message_iterator_get_data(self_message_iterator); + BT_ASSERT(py_iter); + + py_result = PyObject_CallMethod(py_iter, "_seek_beginning_from_native", NULL); + + BT_ASSERT(!py_result || py_result == Py_None); + status = bt_py3_exc_to_self_message_iterator_status(); + + Py_XDECREF(py_result); + + return status; +} + static bt_self_component_status bt_py3_component_class_port_connected( bt_self_component *self_component, @@ -1001,7 +1053,13 @@ bt_component_class_source *bt_py3_component_class_source_create( ret = bt_component_class_source_set_init_method(component_class_source, bt_py3_component_class_source_init); BT_ASSERT(ret == 0); - ret = bt_component_class_source_set_finalize_method (component_class_source, bt_py3_component_class_source_finalize); + ret = bt_component_class_source_set_finalize_method(component_class_source, bt_py3_component_class_source_finalize); + BT_ASSERT(ret == 0); + ret = bt_component_class_source_set_message_iterator_can_seek_beginning_method(component_class_source, + bt_py3_component_class_can_seek_beginning); + BT_ASSERT(ret == 0); + ret = bt_component_class_source_set_message_iterator_seek_beginning_method(component_class_source, + bt_py3_component_class_seek_beginning); BT_ASSERT(ret == 0); ret = bt_component_class_source_set_output_port_connected_method(component_class_source, bt_py3_component_class_source_output_port_connected); @@ -1049,6 +1107,12 @@ bt_component_class_filter *bt_py3_component_class_filter_create( BT_ASSERT(ret == 0); ret = bt_component_class_filter_set_finalize_method (component_class_filter, bt_py3_component_class_filter_finalize); BT_ASSERT(ret == 0); + ret = bt_component_class_filter_set_message_iterator_can_seek_beginning_method(component_class_filter, + bt_py3_component_class_can_seek_beginning); + BT_ASSERT(ret == 0); + ret = bt_component_class_filter_set_message_iterator_seek_beginning_method(component_class_filter, + bt_py3_component_class_seek_beginning); + BT_ASSERT(ret == 0); ret = bt_component_class_filter_set_input_port_connected_method(component_class_filter, bt_py3_component_class_filter_input_port_connected); BT_ASSERT(ret == 0); diff --git a/tests/bindings/python/bt2/test_message_iterator.py b/tests/bindings/python/bt2/test_message_iterator.py index 98ec64ab..a9762976 100644 --- a/tests/bindings/python/bt2/test_message_iterator.py +++ b/tests/bindings/python/bt2/test_message_iterator.py @@ -166,6 +166,149 @@ class UserMessageIteratorTestCase(unittest.TestCase): self.assertIsInstance(msg_ev2, bt2.message._EventMessage) self.assertEqual(msg_ev1.addr, msg_ev2.addr) + @staticmethod + def _setup_seek_beginning_test(): + # Use a source, a filter and an output port iterator. This allows us + # to test calling `seek_beginning` on both a _OutputPortMessageIterator + # and a _UserComponentInputPortMessageIterator, on top of checking that + # _UserMessageIterator._seek_beginning is properly called. + + class MySourceIter(bt2._UserMessageIterator): + def __init__(self, port): + tc, sc, ec = port.user_data + trace = tc() + stream = trace.create_stream(sc) + packet = stream.create_packet() + + self._msgs = [ + self._create_stream_beginning_message(stream), + self._create_stream_activity_beginning_message(stream), + self._create_packet_beginning_message(packet), + self._create_event_message(ec, packet), + self._create_event_message(ec, packet), + self._create_packet_end_message(packet), + self._create_stream_activity_end_message(stream), + self._create_stream_end_message(stream), + ] + self._at = 0 + + def _seek_beginning(self): + self._at = 0 + + def __next__(self): + if self._at < len(self._msgs): + msg = self._msgs[self._at] + self._at += 1 + return msg + else: + raise StopIteration + + class MySource(bt2._UserSourceComponent, + message_iterator_class=MySourceIter): + def __init__(self, params): + tc = self._create_trace_class() + sc = tc.create_stream_class() + ec = sc.create_event_class() + + self._add_output_port('out', (tc, sc, ec)) + + class MyFilterIter(bt2._UserMessageIterator): + def __init__(self, port): + input_port = port.user_data + self._upstream_iter = input_port.create_message_iterator() + + def __next__(self): + return next(self._upstream_iter) + + def _seek_beginning(self): + self._upstream_iter.seek_beginning() + + @property + def _can_seek_beginning(self): + return self._upstream_iter.can_seek_beginning + + class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter): + def __init__(self, params): + input_port = self._add_input_port('in') + self._add_output_port('out', input_port) + + + graph = bt2.Graph() + src = graph.add_component(MySource, 'src') + flt = graph.add_component(MyFilter, 'flt') + graph.connect_ports(src.output_ports['out'], flt.input_ports['in']) + it = graph.create_output_port_message_iterator(flt.output_ports['out']) + + return it, MySourceIter + + def test_can_seek_beginning(self): + it, MySourceIter = self._setup_seek_beginning_test() + + def _can_seek_beginning(self): + nonlocal can_seek_beginning + return can_seek_beginning + + MySourceIter._can_seek_beginning = property(_can_seek_beginning) + + can_seek_beginning = True + self.assertTrue(it.can_seek_beginning) + + can_seek_beginning = False + self.assertFalse(it.can_seek_beginning) + + # Once can_seek_beginning returns an error, verify that it raises when + # _can_seek_beginning has/returns the wrong type. + + # Remove the _can_seek_beginning method, we now rely on the presence of + # a _seek_beginning method to know whether the iterator can seek to + # beginning or not. + del MySourceIter._can_seek_beginning + self.assertTrue(it.can_seek_beginning) + + del MySourceIter._seek_beginning + self.assertFalse(it.can_seek_beginning) + + def test_seek_beginning(self): + it, MySourceIter = self._setup_seek_beginning_test() + + msg = next(it) + self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + msg = next(it) + self.assertIsInstance(msg, bt2.message._StreamActivityBeginningMessage) + + it.seek_beginning() + + msg = next(it) + self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + + # Verify that we can seek beginning after having reached the end. + # + # It currently does not work to seek an output port message iterator + # once it's ended, but we should eventually make it work and uncomment + # the following snippet. + # + # try: + # while True: + # next(it) + # except bt2.Stop: + # pass + # + # it.seek_beginning() + # msg = next(it) + # self.assertIsInstance(msg, bt2.message._StreamBeginningMessage) + + def test_seek_beginning_user_error(self): + it, MySourceIter = self._setup_seek_beginning_test() + + def _seek_beginning_error(self): + raise ValueError('ouch') + + MySourceIter._seek_beginning = _seek_beginning_error + + with self.assertRaises(bt2.Error): + it.seek_beginning() + + class OutputPortMessageIteratorTestCase(unittest.TestCase): def test_component(self): @@ -236,3 +379,6 @@ class OutputPortMessageIteratorTestCase(unittest.TestCase): self.assertEqual(msg.event.cls.name, 'salut') field = msg.event.payload_field['my_int'] self.assertEqual(field, at * 3) + +if __name__ == '__main__': + unittest.main() -- 2.34.1