bt2: let Python message iterators implement seek beginning
authorSimon Marchi <simon.marchi@efficios.com>
Thu, 27 Jun 2019 03:33:58 +0000 (23:33 -0400)
committerPhilippe Proulx <eeppeliteloop@gmail.com>
Tue, 2 Jul 2019 15:57:01 +0000 (11:57 -0400)
This patch adds the possibility for message iterators implemented in
Python to implement the "can seek beginning" and "seek beginning"
operations.

A message iterator can support "seek beginning" by defining a
_seek_beginning method:

  class MyIter(bt2._UserMessageIterator):
    def _seek_beginning(self):
      # Do the needful.

It can support the "can seek beginning" operation by defining a
_can_seek_beginning property, which must evaluate to a bool:

  class MyIter(bt2._UserMessageIterator):
    @property
    def _can_seek_beginning(self):
      # Do the needful, including returning a bool.

The behavior of the "can seek beginning" operation is made to mimic
the C API:

- If the iterator has a _can_seek_beginning attribute, it is used to
  determine whether the iterator can seek beginning.
- Otherwise, the presence or absence of a _seek_beginning method
  indicates whether it can.

Change-Id: I68c48b5cd30090bff833529cda9fa918c7e72b0b
Signed-off-by: Simon Marchi <simon.marchi@efficios.com>
Reviewed-on: https://review.lttng.org/c/babeltrace/+/1555
Reviewed-by: Philippe Proulx <eeppeliteloop@gmail.com>
src/bindings/python/bt2/bt2/message_iterator.py
src/bindings/python/bt2/bt2/native_bt_component_class.i
tests/bindings/python/bt2/test_message_iterator.py

index b07e2450afc97b46f7b5e46c1198d6f07aeb2853..b56b17a26c325a8d41a2a6e30f8d73ebff17fcdb 100644 (file)
@@ -28,14 +28,6 @@ import bt2
 
 
 class _MessageIterator(collections.abc.Iterator):
-    def _handle_status(self, status, gen_error_msg):
-        if status == native_bt.MESSAGE_ITERATOR_STATUS_AGAIN:
-            raise bt2.TryAgain
-        elif status == native_bt.MESSAGE_ITERATOR_STATUS_END:
-            raise bt2.Stop
-        elif status < 0:
-            raise bt2.Error(gen_error_msg)
-
     def __next__(self):
         raise NotImplementedError
 
@@ -46,6 +38,14 @@ class _GenericMessageIterator(object._SharedObject, _MessageIterator):
             self._at = 0
             super().__init__(ptr)
 
+    def _handle_status(self, status, gen_error_msg):
+        if status == native_bt.MESSAGE_ITERATOR_STATUS_AGAIN:
+            raise bt2.TryAgain
+        elif status == native_bt.MESSAGE_ITERATOR_STATUS_END:
+            raise bt2.Stop
+        elif status < 0:
+            raise bt2.Error(gen_error_msg)
+
     def __next__(self):
         if len(self._current_msgs) == self._at:
             status, msgs = self._get_msg_range(self._ptr)
@@ -59,12 +59,27 @@ class _GenericMessageIterator(object._SharedObject, _MessageIterator):
 
         return bt2.message._create_from_ptr(msg_ptr)
 
+    @property
+    def can_seek_beginning(self):
+        res = self._can_seek_beginning(self._ptr)
+        return res != 0
+
+    def seek_beginning(self):
+        # Forget about buffered messages, they won't be valid after seeking..
+        self._current_msgs.clear()
+        self._at = 0
+
+        status = self._seek_beginning(self._ptr)
+        self._handle_status(status, 'cannot seek message iterator beginning')
+
 
 # This is created when a component wants to iterate on one of its input ports.
 class _UserComponentInputPortMessageIterator(_GenericMessageIterator):
     _get_msg_range = staticmethod(native_bt.py3_self_component_port_input_get_msg_range)
     _get_ref = staticmethod(native_bt.self_component_port_input_message_iterator_get_ref)
     _put_ref = staticmethod(native_bt.self_component_port_input_message_iterator_put_ref)
+    _can_seek_beginning = staticmethod(native_bt.self_component_port_input_message_iterator_can_seek_beginning)
+    _seek_beginning = staticmethod(native_bt.self_component_port_input_message_iterator_seek_beginning)
 
 
 # This is created when the user wants to iterate on a component's output port,
@@ -73,6 +88,8 @@ class _OutputPortMessageIterator(_GenericMessageIterator):
     _get_msg_range = staticmethod(native_bt.py3_port_output_get_msg_range)
     _get_ref = staticmethod(native_bt.port_output_message_iterator_get_ref)
     _put_ref = staticmethod(native_bt.port_output_message_iterator_put_ref)
+    _can_seek_beginning = staticmethod(native_bt.port_output_message_iterator_can_seek_beginning)
+    _seek_beginning = staticmethod(native_bt.port_output_message_iterator_seek_beginning)
 
 
 # This is extended by the user to implement component classes in Python.  It
@@ -137,6 +154,24 @@ class _UserMessageIterator(_MessageIterator):
         msg._get_ref(msg._ptr)
         return int(msg._ptr)
 
+    @property
+    def _can_seek_beginning_from_native(self):
+        # Here, we mimic the behavior of the C API:
+        #
+        # - If the iterator has a _can_seek_beginning attribute, read it and use
+        #   that result.
+        # - Otherwise, the presence or absence of a `_seek_beginning`
+        #   method indicates whether the iterator can seek beginning.
+        if hasattr(self, '_can_seek_beginning'):
+            can_seek_beginning = self._can_seek_beginning
+            utils._check_bool(can_seek_beginning)
+            return can_seek_beginning
+        else:
+            return hasattr(self, '_seek_beginning')
+
+    def _seek_beginning_from_native(self):
+        self._seek_beginning()
+
     def _create_event_message(self, event_class, packet, default_clock_snapshot=None):
         utils._check_type(event_class, bt2.event_class._EventClass)
         utils._check_type(packet, bt2.packet._Packet)
index 18078eead697748619da42c0756ffb0cd4d1d95c..8219040dbc054009a64f99edd287975a9d9874c6 100644 (file)
@@ -452,6 +452,58 @@ void bt_py3_component_class_sink_finalize(bt_self_component_sink *self_component
        bt_py3_component_class_finalize(self_component);
 }
 
+static
+bt_bool bt_py3_component_class_can_seek_beginning(
+               bt_self_message_iterator *self_message_iterator)
+{
+       PyObject *py_iter;
+       PyObject *py_result = NULL;
+       bt_bool can_seek_beginning = false;
+
+       py_iter = bt_self_message_iterator_get_data(self_message_iterator);
+       BT_ASSERT(py_iter);
+
+       py_result = PyObject_GetAttrString(py_iter, "_can_seek_beginning_from_native");
+
+       BT_ASSERT(!py_result || PyBool_Check(py_result));
+
+       if (py_result) {
+               can_seek_beginning = PyObject_IsTrue(py_result);
+       } else {
+               /*
+                * Once can_seek_beginning can report errors, convert the
+                * exception to a status.  For now, log and return false;
+                */
+               bt2_py_loge_exception();
+               PyErr_Clear();
+       }
+
+       Py_XDECREF(py_result);
+
+       return can_seek_beginning;
+}
+
+static
+bt_self_message_iterator_status bt_py3_component_class_seek_beginning(
+               bt_self_message_iterator *self_message_iterator)
+{
+       PyObject *py_iter;
+       PyObject *py_result;
+       bt_self_message_iterator_status status;
+
+       py_iter = bt_self_message_iterator_get_data(self_message_iterator);
+       BT_ASSERT(py_iter);
+
+       py_result = PyObject_CallMethod(py_iter, "_seek_beginning_from_native", NULL);
+
+       BT_ASSERT(!py_result || py_result == Py_None);
+        status = bt_py3_exc_to_self_message_iterator_status();
+
+       Py_XDECREF(py_result);
+
+       return status;
+}
+
 static
 bt_self_component_status bt_py3_component_class_port_connected(
                bt_self_component *self_component,
@@ -1001,7 +1053,13 @@ bt_component_class_source *bt_py3_component_class_source_create(
 
        ret = bt_component_class_source_set_init_method(component_class_source, bt_py3_component_class_source_init);
        BT_ASSERT(ret == 0);
-       ret = bt_component_class_source_set_finalize_method (component_class_source, bt_py3_component_class_source_finalize);
+       ret = bt_component_class_source_set_finalize_method(component_class_source, bt_py3_component_class_source_finalize);
+       BT_ASSERT(ret == 0);
+       ret = bt_component_class_source_set_message_iterator_can_seek_beginning_method(component_class_source,
+               bt_py3_component_class_can_seek_beginning);
+       BT_ASSERT(ret == 0);
+       ret = bt_component_class_source_set_message_iterator_seek_beginning_method(component_class_source,
+               bt_py3_component_class_seek_beginning);
        BT_ASSERT(ret == 0);
        ret = bt_component_class_source_set_output_port_connected_method(component_class_source,
                bt_py3_component_class_source_output_port_connected);
@@ -1049,6 +1107,12 @@ bt_component_class_filter *bt_py3_component_class_filter_create(
        BT_ASSERT(ret == 0);
        ret = bt_component_class_filter_set_finalize_method (component_class_filter, bt_py3_component_class_filter_finalize);
        BT_ASSERT(ret == 0);
+       ret = bt_component_class_filter_set_message_iterator_can_seek_beginning_method(component_class_filter,
+               bt_py3_component_class_can_seek_beginning);
+       BT_ASSERT(ret == 0);
+       ret = bt_component_class_filter_set_message_iterator_seek_beginning_method(component_class_filter,
+               bt_py3_component_class_seek_beginning);
+       BT_ASSERT(ret == 0);
        ret = bt_component_class_filter_set_input_port_connected_method(component_class_filter,
                bt_py3_component_class_filter_input_port_connected);
        BT_ASSERT(ret == 0);
index 98ec64aba7c5796bff0052a793e1bca60c2be013..a9762976bcef8f9adc2a3b354c13dc3ce7b1a80a 100644 (file)
@@ -166,6 +166,149 @@ class UserMessageIteratorTestCase(unittest.TestCase):
         self.assertIsInstance(msg_ev2, bt2.message._EventMessage)
         self.assertEqual(msg_ev1.addr, msg_ev2.addr)
 
+    @staticmethod
+    def _setup_seek_beginning_test():
+        # Use a source, a filter and an output port iterator.  This allows us
+        # to test calling `seek_beginning` on both a _OutputPortMessageIterator
+        # and a _UserComponentInputPortMessageIterator, on top of checking that
+        # _UserMessageIterator._seek_beginning is properly called.
+
+        class MySourceIter(bt2._UserMessageIterator):
+            def __init__(self, port):
+                tc, sc, ec = port.user_data
+                trace = tc()
+                stream = trace.create_stream(sc)
+                packet = stream.create_packet()
+
+                self._msgs = [
+                    self._create_stream_beginning_message(stream),
+                    self._create_stream_activity_beginning_message(stream),
+                    self._create_packet_beginning_message(packet),
+                    self._create_event_message(ec, packet),
+                    self._create_event_message(ec, packet),
+                    self._create_packet_end_message(packet),
+                    self._create_stream_activity_end_message(stream),
+                    self._create_stream_end_message(stream),
+                ]
+                self._at = 0
+
+            def _seek_beginning(self):
+                self._at = 0
+
+            def __next__(self):
+                if self._at < len(self._msgs):
+                    msg = self._msgs[self._at]
+                    self._at += 1
+                    return msg
+                else:
+                    raise StopIteration
+
+        class MySource(bt2._UserSourceComponent,
+                       message_iterator_class=MySourceIter):
+            def __init__(self, params):
+                tc = self._create_trace_class()
+                sc = tc.create_stream_class()
+                ec = sc.create_event_class()
+
+                self._add_output_port('out', (tc, sc, ec))
+
+        class MyFilterIter(bt2._UserMessageIterator):
+            def __init__(self, port):
+                input_port = port.user_data
+                self._upstream_iter = input_port.create_message_iterator()
+
+            def __next__(self):
+                return next(self._upstream_iter)
+
+            def _seek_beginning(self):
+                self._upstream_iter.seek_beginning()
+
+            @property
+            def _can_seek_beginning(self):
+                return self._upstream_iter.can_seek_beginning
+
+        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
+            def __init__(self, params):
+                input_port = self._add_input_port('in')
+                self._add_output_port('out', input_port)
+
+
+        graph = bt2.Graph()
+        src = graph.add_component(MySource, 'src')
+        flt = graph.add_component(MyFilter, 'flt')
+        graph.connect_ports(src.output_ports['out'], flt.input_ports['in'])
+        it = graph.create_output_port_message_iterator(flt.output_ports['out'])
+
+        return it, MySourceIter
+
+    def test_can_seek_beginning(self):
+        it, MySourceIter = self._setup_seek_beginning_test()
+
+        def _can_seek_beginning(self):
+            nonlocal can_seek_beginning
+            return can_seek_beginning
+
+        MySourceIter._can_seek_beginning = property(_can_seek_beginning)
+
+        can_seek_beginning = True
+        self.assertTrue(it.can_seek_beginning)
+
+        can_seek_beginning = False
+        self.assertFalse(it.can_seek_beginning)
+
+        # Once can_seek_beginning returns an error, verify that it raises when
+        # _can_seek_beginning has/returns the wrong type.
+
+        # Remove the _can_seek_beginning method, we now rely on the presence of
+        # a _seek_beginning method to know whether the iterator can seek to
+        # beginning or not.
+        del MySourceIter._can_seek_beginning
+        self.assertTrue(it.can_seek_beginning)
+
+        del MySourceIter._seek_beginning
+        self.assertFalse(it.can_seek_beginning)
+
+    def test_seek_beginning(self):
+        it, MySourceIter = self._setup_seek_beginning_test()
+
+        msg = next(it)
+        self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
+        msg = next(it)
+        self.assertIsInstance(msg, bt2.message._StreamActivityBeginningMessage)
+
+        it.seek_beginning()
+
+        msg = next(it)
+        self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
+
+        # Verify that we can seek beginning after having reached the end.
+        #
+        # It currently does not work to seek an output port message iterator
+        # once it's ended, but we should eventually make it work and uncomment
+        # the following snippet.
+        #
+        # try:
+        #    while True:
+        #        next(it)
+        # except bt2.Stop:
+        #    pass
+        #
+        # it.seek_beginning()
+        # msg = next(it)
+        # self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
+
+    def test_seek_beginning_user_error(self):
+        it, MySourceIter = self._setup_seek_beginning_test()
+
+        def _seek_beginning_error(self):
+           raise ValueError('ouch')
+
+        MySourceIter._seek_beginning = _seek_beginning_error
+
+        with self.assertRaises(bt2.Error):
+            it.seek_beginning()
+
+
 
 class OutputPortMessageIteratorTestCase(unittest.TestCase):
     def test_component(self):
@@ -236,3 +379,6 @@ class OutputPortMessageIteratorTestCase(unittest.TestCase):
                 self.assertEqual(msg.event.cls.name, 'salut')
                 field = msg.event.payload_field['my_int']
                 self.assertEqual(field, at * 3)
+
+if __name__ == '__main__':
+    unittest.main()
This page took 0.029913 seconds and 4 git commands to generate.