e121c973c4325e46f2398c1a38eed5a9140aa1b8
[babeltrace.git] / tests / bindings / python / bt2 / test_message_iterator.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # Copyright (C) 2019 EfficiOS Inc.
4 #
5
6 import sys
7 import unittest
8
9 import bt2
10 from bt2 import port as bt2_port
11 from bt2 import message_iterator as bt2_message_iterator
12 from utils import TestOutputPortMessageIterator
13
14
class SimpleSink(bt2._UserSinkComponent):
    # Minimal sink component used throughout these tests: it exposes a
    # single input port named `in` and pulls one message from it on each
    # consume call.

    def __init__(self, config, params, obj):
        self._add_input_port("in")

    def _user_graph_is_configured(self):
        # Create the message iterator on the `in` port; `_user_consume`
        # reads from it afterwards.
        in_port = self._input_ports["in"]
        self._msg_iter = self._create_message_iterator(in_port)

    def _user_consume(self):
        # Advance the upstream iterator by one message; the message
        # itself is not used.
        next(self._msg_iter)
27
28
def _create_graph(src_comp_cls, sink_comp_cls, flt_comp_cls=None):
    # Build and return a graph made of a source component named "src"
    # and a sink component named "sink".
    #
    # When `flt_comp_cls` is given, a filter component named "flt" is
    # inserted between them. Connections always go from an output port
    # named `out` to an input port named `in`.
    graph = bt2.Graph()

    source = graph.add_component(src_comp_cls, "src")
    sink = graph.add_component(sink_comp_cls, "sink")

    # Output port which ends up connected to the sink's `in` port.
    upstream_port = source.output_ports["out"]

    if flt_comp_cls is not None:
        filter_comp = graph.add_component(flt_comp_cls, "flt")
        graph.connect_ports(upstream_port, filter_comp.input_ports["in"])
        upstream_port = filter_comp.output_ports["out"]

    graph.connect_ports(upstream_port, sink.input_ports["in"])

    return graph
43
44
class UserMessageIteratorTestCase(unittest.TestCase):
    # Test that a user message iterator's `__init__` is called and that
    # it receives the output port which the source component created
    # (same address, same user data).
    def test_init(self):
        the_output_port_from_source = None
        the_output_port_from_iter = None

        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                nonlocal initialized
                nonlocal the_output_port_from_iter
                initialized = True
                the_output_port_from_iter = self_port_output

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                nonlocal the_output_port_from_source
                the_output_port_from_source = self._add_output_port("out", "user data")

        initialized = False
        graph = _create_graph(MySource, SimpleSink)
        graph.run()
        self.assertTrue(initialized)
        self.assertEqual(
            the_output_port_from_source.addr, the_output_port_from_iter.addr
        )
        self.assertEqual(the_output_port_from_iter.user_data, "user data")

    # Test that a message iterator (here, a filter's) can create an
    # upstream message iterator from within its own `__init__`.
    def test_create_from_message_iterator(self):
        class MySourceIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                nonlocal src_iter_initialized
                src_iter_initialized = True

        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                nonlocal flt_iter_initialized
                flt_iter_initialized = True
                self._up_iter = self._create_message_iterator(
                    self._component._input_ports["in"]
                )

            def __next__(self):
                return next(self._up_iter)

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(self, config, params, obj):
                self._add_input_port("in")
                self._add_output_port("out")

        src_iter_initialized = False
        flt_iter_initialized = False
        graph = _create_graph(MySource, SimpleSink, MyFilter)
        graph.run()
        self.assertTrue(src_iter_initialized)
        self.assertTrue(flt_iter_initialized)

    # Test that creating a message iterator from a sink component on a
    # non-connected input port raises.
    def test_create_from_sink_component_unconnected_port_raises(self):
        class MySink(bt2._UserSinkComponent):
            def __init__(comp_self, config, params, obj):
                comp_self._input_port = comp_self._add_input_port("in")

            def _user_graph_is_configured(comp_self):
                # `self` is the test case here; the component is
                # `comp_self`.
                with self.assertRaisesRegex(ValueError, "input port is not connected"):
                    comp_self._create_message_iterator(comp_self._input_port)

                nonlocal seen
                seen = True

            def _user_consume(self):
                raise bt2.Stop

        seen = False
        graph = bt2.Graph()
        graph.add_component(MySink, "snk")
        graph.run()

        # Make sure the check within `_user_graph_is_configured` ran.
        self.assertTrue(seen)

    # Test that creating a message iterator from a message iterator on a
    # non-connected input port raises.
    def test_create_from_message_iterator_unconnected_port_raises(self):
        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(iter_self, config, port):
                input_port = iter_self._component._input_ports["in"]

                # `self` is the test case here; the iterator is
                # `iter_self`.
                with self.assertRaisesRegex(ValueError, "input port is not connected"):
                    iter_self._create_message_iterator(input_port)

                nonlocal seen
                seen = True

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(comp_self, config, params, obj):
                comp_self._add_input_port("in")
                comp_self._add_output_port("out")

        class MySink(bt2._UserSinkComponent):
            def __init__(comp_self, config, params, obj):
                comp_self._input_port = comp_self._add_input_port("in")

            def _user_graph_is_configured(comp_self):
                comp_self._input_iter = comp_self._create_message_iterator(
                    comp_self._input_port
                )

            def _user_consume(self):
                raise bt2.Stop

        seen = False
        graph = bt2.Graph()

        # The filter's input port is voluntarily left unconnected.
        flt = graph.add_component(MyFilter, "flt")
        snk = graph.add_component(MySink, "snk")
        graph.connect_ports(flt.output_ports["out"], snk.input_ports["in"])
        graph.run()

        # Make sure the check within `MyFilterIter.__init__` ran.
        self.assertTrue(seen)

    def test_create_user_error(self):
        # This tests both error handling by
        # _UserSinkComponent._create_message_iterator
        # and _UserMessageIterator._create_message_iterator, as they
        # are both used in the graph.
        class MySourceIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                raise ValueError("Very bad error")

        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                # This is expected to raise because of the error in
                # MySourceIter.__init__.
                self._create_message_iterator(self._component._input_ports["in"])

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(self, config, params, obj):
                self._add_input_port("in")
                self._add_output_port("out")

        graph = _create_graph(MySource, SimpleSink, MyFilter)

        with self.assertRaises(bt2._Error) as ctx:
            graph.run()

        exc = ctx.exception

        # The first cause is expected to come from the source's message
        # iterator (component "src", port "out").
        cause = exc[0]

        self.assertIsInstance(cause, bt2._MessageIteratorErrorCause)
        self.assertEqual(cause.component_name, "src")
        self.assertEqual(cause.component_output_port_name, "out")
        self.assertIn("ValueError: Very bad error", cause.message)

    # Test that the message iterator's `_user_finalize` method is called
    # once the graph is released.
    def test_finalize(self):
        class MyIter(bt2._UserMessageIterator):
            def _user_finalize(self):
                nonlocal finalized
                finalized = True

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        finalized = False
        graph = _create_graph(MySource, SimpleSink)
        graph.run()

        # Drop the only reference this test holds to the graph.
        del graph
        self.assertTrue(finalized)

    # Test the type of the `config` parameter which the message
    # iterator's `__init__` receives.
    def test_config_parameter(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, port):
                nonlocal config_type
                config_type = type(config)

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        config_type = None
        graph = _create_graph(MySource, SimpleSink)
        graph.run()
        self.assertIs(config_type, bt2_message_iterator._MessageIteratorConfiguration)

    # Common part of the `can_seek_forward` configuration tests.
    #
    # When `set_can_seek_forward` is true, the source's message iterator
    # sets `config.can_seek_forward` to `True`; otherwise it leaves the
    # configuration untouched. In both cases, the sink's view of
    # `can_seek_forward` must be identical to `set_can_seek_forward`.
    def _test_config_can_seek_forward(self, set_can_seek_forward):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, port):
                if set_can_seek_forward:
                    config.can_seek_forward = True

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                nonlocal can_seek_forward
                can_seek_forward = self._msg_iter.can_seek_forward

        can_seek_forward = None
        graph = _create_graph(MySource, MySink)
        graph.run_once()
        self.assertIs(can_seek_forward, set_can_seek_forward)

    def test_config_can_seek_forward_default(self):
        self._test_config_can_seek_forward(False)

    def test_config_can_seek_forward(self):
        self._test_config_can_seek_forward(True)

    # Test that setting `config.can_seek_forward` to something which is
    # not a `bool` makes the iterator initialization fail with a
    # `TypeError` cause.
    def test_config_can_seek_forward_wrong_type(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, port):
                config.can_seek_forward = 1

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        graph = _create_graph(MySource, SimpleSink)
        with self.assertRaises(bt2._Error) as ctx:
            graph.run()

        root_cause = ctx.exception[0]
        self.assertIn("TypeError: 'int' is not a 'bool' object", root_cause.message)

    # Test that the message iterator can reach its component (and the
    # attributes set on it) through `self._component`.
    def test_component(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                nonlocal salut
                salut = self._component._salut

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")
                self._salut = 23

        salut = None
        graph = _create_graph(MySource, SimpleSink)
        graph.run()
        self.assertEqual(salut, 23)

    # Test that the port received by `__init__` and the one available as
    # `self._port` are both user component output ports having the same
    # address.
    def test_port(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self_iter, config, self_port_output):
                nonlocal called
                called = True
                port = self_iter._port

                # `self` is the test case here; the iterator is
                # `self_iter`.
                self.assertIs(type(self_port_output), bt2_port._UserComponentOutputPort)
                self.assertIs(type(port), bt2_port._UserComponentOutputPort)
                self.assertEqual(self_port_output.addr, port.addr)

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        called = False
        graph = _create_graph(MySource, SimpleSink)
        graph.run()

        # Make sure the assertions within `MyIter.__init__` ran.
        self.assertTrue(called)

    # Test that the message iterator's `addr` property is set and is not
    # zero.
    def test_addr(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, self_port_output):
                nonlocal addr
                addr = self.addr

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        addr = None
        graph = _create_graph(MySource, SimpleSink)
        graph.run()
        self.assertIsNotNone(addr)
        self.assertNotEqual(addr, 0)

    # Test that messages returned by _UserMessageIterator.__next__ remain valid
    # and can be re-used.
    def test_reuse_message(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, config, port):
                tc, sc, ec = port.user_data
                trace = tc()
                stream = trace.create_stream(sc)
                packet = stream.create_packet()

                # This message will be returned twice by __next__.
                event_message = self._create_event_message(ec, packet)

                self._msgs = [
                    self._create_stream_beginning_message(stream),
                    self._create_packet_beginning_message(packet),
                    event_message,
                    event_message,
                ]

            def __next__(self):
                return self._msgs.pop(0)

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                # Hand the trace, stream and event classes over to the
                # message iterator through the port's user data.
                tc = self._create_trace_class()
                sc = tc.create_stream_class(supports_packets=True)
                ec = sc.create_event_class()
                self._add_output_port("out", (tc, sc, ec))

        graph = bt2.Graph()
        src = graph.add_component(MySource, "src")
        it = TestOutputPortMessageIterator(graph, src.output_ports["out"])

        # Skip beginning messages.
        msg = next(it)
        self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
        msg = next(it)
        self.assertIs(type(msg), bt2._PacketBeginningMessageConst)

        msg_ev1 = next(it)
        msg_ev2 = next(it)

        # Both event messages must wrap the same underlying object.
        self.assertIs(type(msg_ev1), bt2._EventMessageConst)
        self.assertIs(type(msg_ev2), bt2._EventMessageConst)
        self.assertEqual(msg_ev1.addr, msg_ev2.addr)

    # Try consuming many times from an iterator that always returns TryAgain.
    # This verifies that we are not missing an incref of Py_None, making the
    # refcount of Py_None reach 0.
    def test_try_again_many_times(self):
        # Starting with Python 3.12, `None` is immortal: its reference
        # count operations are no-op. Skip this test in that case.
        before = sys.getrefcount(None)
        dummy = None  # noqa: F841

        if before == sys.getrefcount(None):
            raise unittest.SkipTest("`None` is immortal")

        class MyIter(bt2._UserMessageIterator):
            def __next__(self):
                raise bt2.TryAgain

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        graph = bt2.Graph()
        src = graph.add_component(MySource, "src")
        it = TestOutputPortMessageIterator(graph, src.output_ports["out"])

        # Three times the initial ref count of `None` iterations should
        # be enough to catch the bug even if there are small differences
        # between configurations.
        none_ref_count = sys.getrefcount(None) * 3

        for _ in range(none_ref_count):
            with self.assertRaises(bt2.TryAgain):
                next(it)

    def test_error_in_iterator_with_cycle_after_having_created_upstream_iterator(self):
        # Test a failure that triggered an abort in libbabeltrace2, in this situation:
        #
        # - The filter iterator creates an upstream iterator.
        # - The filter iterator creates a reference cycle, including itself.
        # - An exception is raised, causing the filter iterator's
        #   initialization method to fail.
        class MySourceIter(bt2._UserMessageIterator):
            pass

        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
            def __init__(self, config, params, obj):
                self._add_output_port("out")

        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(self, config, port):
                # First, create an upstream iterator.
                self._upstream_iter = self._create_message_iterator(
                    self._component._input_ports["in"]
                )

                # Then, voluntarily make a reference cycle that will keep this
                # Python object alive, which will keep the upstream iterator
                # Babeltrace object alive.
                self._self = self

                # Finally, raise an exception to make __init__ fail.
                raise ValueError("woops")

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(self, config, params, obj):
                self._in = self._add_input_port("in")
                self._out = self._add_output_port("out")

        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._input_port = self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._upstream_iter = self._create_message_iterator(self._input_port)

            def _user_consume(self):
                # We should not reach this.
                #
                # Raise explicitly instead of using an `assert`
                # statement: the latter is removed when Python runs with
                # `-O`, which would silently disable this guard.
                raise AssertionError("MySink._user_consume() must not be called")

        g = bt2.Graph()
        src = g.add_component(MySource, "src")
        flt = g.add_component(MyFilter, "flt")
        snk = g.add_component(MySink, "snk")
        g.connect_ports(src.output_ports["out"], flt.input_ports["in"])
        g.connect_ports(flt.output_ports["out"], snk.input_ports["in"])

        with self.assertRaisesRegex(bt2._Error, "ValueError: woops"):
            g.run()
466
467
def _setup_seek_test(
    sink_cls,
    user_seek_beginning=None,
    user_can_seek_beginning=None,
    user_seek_ns_from_origin=None,
    user_can_seek_ns_from_origin=None,
    can_seek_forward=False,
):
    # Build and return a `source -> filter -> sink` graph for the seeking
    # tests, using `sink_cls` as the sink component class.
    #
    # Each `user_*` parameter which is not `None` is installed on the
    # source's message iterator class as the user method of the same
    # name (with a leading underscore). `can_seek_forward` is what the
    # source's message iterator sets on its configuration object.
    #
    # The filter's message iterator forwards everything (next, the four
    # seeking operations, and `can_seek_forward`) to its upstream
    # message iterator.
    class MySourceIter(bt2._UserMessageIterator):
        def __init__(self, config, port):
            # The source component passes those three classes through
            # the port's user data.
            trace_cls, stream_cls, event_cls = port.user_data
            stream = trace_cls().create_stream(stream_cls)
            packet = stream.create_packet()

            # Fixed message sequence; `_at` is the index of the next
            # message to return.
            self._msgs = [
                self._create_stream_beginning_message(stream),
                self._create_packet_beginning_message(packet),
                self._create_event_message(event_cls, packet),
                self._create_event_message(event_cls, packet),
                self._create_packet_end_message(packet),
                self._create_stream_end_message(stream),
            ]
            self._at = 0
            config.can_seek_forward = can_seek_forward

        def __next__(self):
            if self._at >= len(self._msgs):
                raise StopIteration

            cur_msg = self._msgs[self._at]
            self._at += 1
            return cur_msg

    # Install the optional user methods before `MySourceIter` is used as
    # a message iterator class below.
    optional_user_methods = (
        ("_user_seek_beginning", user_seek_beginning),
        ("_user_can_seek_beginning", user_can_seek_beginning),
        ("_user_seek_ns_from_origin", user_seek_ns_from_origin),
        ("_user_can_seek_ns_from_origin", user_can_seek_ns_from_origin),
    )

    for method_name, method in optional_user_methods:
        if method is not None:
            setattr(MySourceIter, method_name, method)

    class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
        def __init__(self, config, params, obj):
            trace_cls = self._create_trace_class()
            stream_cls = trace_cls.create_stream_class(supports_packets=True)
            event_cls = stream_cls.create_event_class()

            self._add_output_port("out", (trace_cls, stream_cls, event_cls))

    class MyFilterIter(bt2._UserMessageIterator):
        def __init__(self, config, port):
            in_port = self._component._input_ports["in"]
            self._upstream_iter = self._create_message_iterator(in_port)

            # Mirror the upstream iterator's forward seeking capability.
            config.can_seek_forward = self._upstream_iter.can_seek_forward

        def __next__(self):
            return next(self._upstream_iter)

        def _user_can_seek_beginning(self):
            return self._upstream_iter.can_seek_beginning()

        def _user_seek_beginning(self):
            self._upstream_iter.seek_beginning()

        def _user_can_seek_ns_from_origin(self, ns_from_origin):
            return self._upstream_iter.can_seek_ns_from_origin(ns_from_origin)

        def _user_seek_ns_from_origin(self, ns_from_origin):
            self._upstream_iter.seek_ns_from_origin(ns_from_origin)

    class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
        def __init__(self, config, params, obj):
            self._add_input_port("in")
            self._add_output_port("out")

    return _create_graph(MySource, sink_cls, flt_comp_cls=MyFilter)
550
551
class UserMessageIteratorSeekBeginningTestCase(unittest.TestCase):
    # Test that a message iterator class which implements
    # `_user_can_seek_beginning` without `_user_seek_beginning` is
    # rejected when creating the component class.
    def test_can_seek_beginning_without_seek_beginning(self):
        with self.assertRaisesRegex(
            bt2._IncompleteUserClass,
            "cannot create component class 'MySource': message iterator class implements _user_can_seek_beginning but not _user_seek_beginning",
        ):
            _setup_seek_test(SimpleSink, user_can_seek_beginning=lambda: None)

    # Test that what the user's `_user_can_seek_beginning` method returns
    # is what the downstream `can_seek_beginning()` call returns.
    def test_can_seek_beginning(self):
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                nonlocal can_seek_beginning
                can_seek_beginning = self._msg_iter.can_seek_beginning()

        def _user_can_seek_beginning(self):
            nonlocal input_port_iter_can_seek_beginning
            return input_port_iter_can_seek_beginning

        graph = _setup_seek_test(
            MySink,
            user_can_seek_beginning=_user_can_seek_beginning,
            user_seek_beginning=lambda: None,
        )

        input_port_iter_can_seek_beginning = True
        can_seek_beginning = None
        graph.run_once()
        self.assertIs(can_seek_beginning, True)

        input_port_iter_can_seek_beginning = False
        can_seek_beginning = None
        graph.run_once()
        self.assertIs(can_seek_beginning, False)

    def test_no_can_seek_beginning_with_seek_beginning(self):
        # Test an iterator without a _user_can_seek_beginning method, but with
        # a _user_seek_beginning method.
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                nonlocal can_seek_beginning
                can_seek_beginning = self._msg_iter.can_seek_beginning()

        def _user_seek_beginning(self):
            pass

        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)
        can_seek_beginning = None
        graph.run_once()
        self.assertIs(can_seek_beginning, True)

    def test_no_can_seek_beginning(self):
        # Test an iterator without a _user_can_seek_beginning method, without
        # a _user_seek_beginning method.
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                nonlocal can_seek_beginning
                can_seek_beginning = self._msg_iter.can_seek_beginning()

        graph = _setup_seek_test(MySink)
        can_seek_beginning = None
        graph.run_once()
        self.assertIs(can_seek_beginning, False)

    # Test that an exception raised by the user's
    # `_user_can_seek_beginning` method is reported as an error cause.
    def test_can_seek_beginning_user_error(self):
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                # This is expected to raise.
                self._msg_iter.can_seek_beginning()

        def _user_can_seek_beginning(self):
            raise ValueError("moustiquaire")

        graph = _setup_seek_test(
            MySink,
            user_can_seek_beginning=_user_can_seek_beginning,
            user_seek_beginning=lambda: None,
        )

        with self.assertRaises(bt2._Error) as ctx:
            graph.run_once()

        cause = ctx.exception[0]
        self.assertIn("ValueError: moustiquaire", cause.message)

    # Test that a user's `_user_can_seek_beginning` method which returns
    # something which is not a `bool` leads to a `TypeError` cause.
    def test_can_seek_beginning_wrong_return_value(self):
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                # This is expected to raise.
                self._msg_iter.can_seek_beginning()

        def _user_can_seek_beginning(self):
            return "Amqui"

        graph = _setup_seek_test(
            MySink,
            user_can_seek_beginning=_user_can_seek_beginning,
            user_seek_beginning=lambda: None,
        )

        with self.assertRaises(bt2._Error) as ctx:
            graph.run_once()

        cause = ctx.exception[0]
        self.assertIn("TypeError: 'str' is not a 'bool' object", cause.message)

    # Test that seeking the beginning makes the next consumed message be
    # the first one again.
    def test_seek_beginning(self):
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                nonlocal do_seek_beginning
                nonlocal msg

                # Depending on the flag set by the test, either seek the
                # beginning or consume a single message.
                if do_seek_beginning:
                    self._msg_iter.seek_beginning()
                    return

                msg = next(self._msg_iter)

        def _user_seek_beginning(self):
            # Installed on the source's message iterator class by
            # `_setup_seek_test`: `_at` is that iterator's index of the
            # next message to return.
            self._at = 0

        msg = None
        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)

        # Consume message.
        do_seek_beginning = False
        graph.run_once()
        self.assertIs(type(msg), bt2._StreamBeginningMessageConst)

        # Consume message.
        graph.run_once()
        self.assertIs(type(msg), bt2._PacketBeginningMessageConst)

        # Seek beginning.
        do_seek_beginning = True
        graph.run_once()

        # Consume message.
        do_seek_beginning = False
        graph.run_once()
        self.assertIs(type(msg), bt2._StreamBeginningMessageConst)

    # Test that an exception raised by the user's `_user_seek_beginning`
    # method is reported as an error cause.
    def test_seek_beginning_user_error(self):
        class MySink(bt2._UserSinkComponent):
            def __init__(self, config, params, obj):
                self._add_input_port("in")

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_message_iterator(self._input_ports["in"])

            def _user_consume(self):
                # This is expected to raise.
                self._msg_iter.seek_beginning()

        def _user_seek_beginning(self):
            raise ValueError("ouch")

        graph = _setup_seek_test(MySink, user_seek_beginning=_user_seek_beginning)

        with self.assertRaises(bt2._Error) as ctx:
            graph.run_once()

        # Make sure the error is the one raised by the user method above,
        # not any other `bt2._Error` (same check as in
        # test_can_seek_beginning_user_error()).
        cause = ctx.exception[0]
        self.assertIn("ValueError: ouch", cause.message)
747
748
749 class UserMessageIteratorSeekNsFromOriginTestCase(unittest.TestCase):
750 def test_can_seek_ns_from_origin_without_seek_ns_from_origin(self):
751 # Test the case where:
752 #
753 # - can_seek_ns_from_origin: Returns True (don't really care, as long
754 # as it's provided)
755 # - seek_ns_from_origin provided: No
756 # - can the iterator seek beginning: Don't care
757 # - can the iterator seek forward: Don't care
758 for can_seek_ns_from_origin in (False, True):
759 for iter_can_seek_beginning in (False, True):
760 for iter_can_seek_forward in (False, True):
761 with self.assertRaisesRegex(
762 bt2._IncompleteUserClass,
763 "cannot create component class 'MySource': message iterator class implements _user_can_seek_ns_from_origin but not _user_seek_ns_from_origin",
764 ):
765 self._can_seek_ns_from_origin_test(
766 None,
767 user_can_seek_ns_from_origin_ret_val=True,
768 user_seek_ns_from_origin_provided=False,
769 iter_can_seek_beginning=iter_can_seek_beginning,
770 iter_can_seek_forward=iter_can_seek_forward,
771 )
772
773 def test_can_seek_ns_from_origin_returns_true(self):
774 # Test the case where:
775 #
776 # - can_seek_ns_from_origin: returns True
777 # - seek_ns_from_origin provided: Yes
778 # - can the iterator seek beginning: Don't care
779 # - can the iterator seek forward: Don't care
780 #
781 # We expect iter.can_seek_ns_from_origin to return True.
782 for iter_can_seek_beginning in (False, True):
783 for iter_can_seek_forward in (False, True):
784 self._can_seek_ns_from_origin_test(
785 expected_outcome=True,
786 user_can_seek_ns_from_origin_ret_val=True,
787 user_seek_ns_from_origin_provided=True,
788 iter_can_seek_beginning=iter_can_seek_beginning,
789 iter_can_seek_forward=iter_can_seek_forward,
790 )
791
792 def test_can_seek_ns_from_origin_returns_false_can_seek_beginning_forward_seekable(
793 self,
794 ):
795 # Test the case where:
796 #
797 # - can_seek_ns_from_origin: returns False
798 # - seek_ns_from_origin provided: Yes
799 # - can the iterator seek beginning: Yes
800 # - can the iterator seek forward: Yes
801 #
802 # We expect iter.can_seek_ns_from_origin to return True.
803 self._can_seek_ns_from_origin_test(
804 expected_outcome=True,
805 user_can_seek_ns_from_origin_ret_val=False,
806 user_seek_ns_from_origin_provided=True,
807 iter_can_seek_beginning=True,
808 iter_can_seek_forward=True,
809 )
810
811 def test_can_seek_ns_from_origin_returns_false_can_seek_beginning_not_forward_seekable(
812 self,
813 ):
814 # Test the case where:
815 #
816 # - can_seek_ns_from_origin: returns False
817 # - seek_ns_from_origin provided: Yes
818 # - can the iterator seek beginning: Yes
819 # - can the iterator seek forward: No
820 #
821 # We expect iter.can_seek_ns_from_origin to return False.
822 self._can_seek_ns_from_origin_test(
823 expected_outcome=False,
824 user_can_seek_ns_from_origin_ret_val=False,
825 user_seek_ns_from_origin_provided=True,
826 iter_can_seek_beginning=True,
827 iter_can_seek_forward=False,
828 )
829
830 def test_can_seek_ns_from_origin_returns_false_cant_seek_beginning_forward_seekable(
831 self,
832 ):
833 # Test the case where:
834 #
835 # - can_seek_ns_from_origin: returns False
836 # - seek_ns_from_origin provided: Yes
837 # - can the iterator seek beginning: No
838 # - can the iterator seek forward: Yes
839 #
840 # We expect iter.can_seek_ns_from_origin to return False.
841 self._can_seek_ns_from_origin_test(
842 expected_outcome=False,
843 user_can_seek_ns_from_origin_ret_val=False,
844 user_seek_ns_from_origin_provided=True,
845 iter_can_seek_beginning=False,
846 iter_can_seek_forward=True,
847 )
848
849 def test_can_seek_ns_from_origin_returns_false_cant_seek_beginning_not_forward_seekable(
850 self,
851 ):
852 # Test the case where:
853 #
854 # - can_seek_ns_from_origin: returns False
855 # - seek_ns_from_origin provided: Yes
856 # - can the iterator seek beginning: No
857 # - can the iterator seek forward: No
858 #
859 # We expect iter.can_seek_ns_from_origin to return False.
860 self._can_seek_ns_from_origin_test(
861 expected_outcome=False,
862 user_can_seek_ns_from_origin_ret_val=False,
863 user_seek_ns_from_origin_provided=True,
864 iter_can_seek_beginning=False,
865 iter_can_seek_forward=False,
866 )
867
868 def test_no_can_seek_ns_from_origin_seek_ns_from_origin(self):
869 # Test the case where:
870 #
871 # - can_seek_ns_from_origin: Not provided
872 # - seek_ns_from_origin provided: Yes
873 # - can the iterator seek beginning: Don't care
874 # - can the iterator seek forward: Don't care
875 #
876 # We expect iter.can_seek_ns_from_origin to return True.
877 for iter_can_seek_beginning in (False, True):
878 for iter_can_seek_forward in (False, True):
879 self._can_seek_ns_from_origin_test(
880 expected_outcome=True,
881 user_can_seek_ns_from_origin_ret_val=None,
882 user_seek_ns_from_origin_provided=True,
883 iter_can_seek_beginning=iter_can_seek_beginning,
884 iter_can_seek_forward=iter_can_seek_forward,
885 )
886
887 def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_can_seek_beginning_forward_seekable(
888 self,
889 ):
890 # Test the case where:
891 #
892 # - can_seek_ns_from_origin: Not provided
893 # - seek_ns_from_origin provided: Not provided
894 # - can the iterator seek beginning: Yes
895 # - can the iterator seek forward: Yes
896 #
897 # We expect iter.can_seek_ns_from_origin to return True.
898 self._can_seek_ns_from_origin_test(
899 expected_outcome=True,
900 user_can_seek_ns_from_origin_ret_val=None,
901 user_seek_ns_from_origin_provided=False,
902 iter_can_seek_beginning=True,
903 iter_can_seek_forward=True,
904 )
905
906 def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_can_seek_beginning_not_forward_seekable(
907 self,
908 ):
909 # Test the case where:
910 #
911 # - can_seek_ns_from_origin: Not provided
912 # - seek_ns_from_origin provided: Not provided
913 # - can the iterator seek beginning: Yes
914 # - can the iterator seek forward: No
915 #
916 # We expect iter.can_seek_ns_from_origin to return False.
917 self._can_seek_ns_from_origin_test(
918 expected_outcome=False,
919 user_can_seek_ns_from_origin_ret_val=None,
920 user_seek_ns_from_origin_provided=False,
921 iter_can_seek_beginning=True,
922 iter_can_seek_forward=False,
923 )
924
925 def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_cant_seek_beginning_forward_seekable(
926 self,
927 ):
928 # Test the case where:
929 #
930 # - can_seek_ns_from_origin: Not provided
931 # - seek_ns_from_origin provided: Not provided
932 # - can the iterator seek beginning: No
933 # - can the iterator seek forward: Yes
934 #
935 # We expect iter.can_seek_ns_from_origin to return False.
936 self._can_seek_ns_from_origin_test(
937 expected_outcome=False,
938 user_can_seek_ns_from_origin_ret_val=None,
939 user_seek_ns_from_origin_provided=False,
940 iter_can_seek_beginning=False,
941 iter_can_seek_forward=True,
942 )
943
def test_no_can_seek_ns_from_origin_no_seek_ns_from_origin_cant_seek_beginning_not_forward_seekable(
    self,
):
    # Scenario:
    #
    # - user `can_seek_ns_from_origin` method: absent
    # - user `seek_ns_from_origin` method: absent
    # - iterator can seek beginning: no
    # - iterator can seek forward: no
    #
    # Expected result of iter.can_seek_ns_from_origin(): False.
    self._can_seek_ns_from_origin_test(
        iter_can_seek_beginning=False,
        iter_can_seek_forward=False,
        user_seek_ns_from_origin_provided=False,
        user_can_seek_ns_from_origin_ret_val=None,
        expected_outcome=False,
    )
962
def _can_seek_ns_from_origin_test(
    self,
    expected_outcome,
    user_can_seek_ns_from_origin_ret_val,
    user_seek_ns_from_origin_provided,
    iter_can_seek_beginning,
    iter_can_seek_forward,
):
    # Shared driver for the `can_seek_ns_from_origin` scenario tests.
    #
    # Parameters:
    #
    # - expected_outcome: value that the sink's call to
    #   iter.can_seek_ns_from_origin() must return (compared by identity).
    # - user_can_seek_ns_from_origin_ret_val: when not None, a user
    #   "can seek ns from origin" callback returning this value is passed
    #   to _setup_seek_test(); when None, no such callback is passed.
    # - user_seek_ns_from_origin_provided: when true, a do-nothing user
    #   "seek ns from origin" callback is passed.
    # - iter_can_seek_beginning: when true, a do-nothing user
    #   "seek beginning" callback is passed.
    # - iter_can_seek_forward: forwarded as `can_seek_forward`.
    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            self._add_input_port("in")

        def _user_graph_is_configured(self):
            self._msg_iter = self._create_message_iterator(self._input_ports["in"])

        def _user_consume(self):
            # Query the upstream iterator and hand the answer back to the
            # enclosing test through the captured variable.
            nonlocal query_result
            query_result = self._msg_iter.can_seek_ns_from_origin(requested_ns)

    def _on_can_seek_ns(self, ns_from_origin):
        # Record the value received so the test can check that it matches
        # the one passed by the sink.
        nonlocal observed_ns
        observed_ns = ns_from_origin
        return user_can_seek_ns_from_origin_ret_val

    def _on_seek_ns(self, ns_from_origin):
        pass

    def _on_seek_beginning(self):
        pass

    # Only hand over the callbacks that this scenario asks for; the
    # others are passed as None.
    user_can_seek_given = user_can_seek_ns_from_origin_ret_val is not None
    can_seek_ns_cb = _on_can_seek_ns if user_can_seek_given else None
    seek_ns_cb = _on_seek_ns if user_seek_ns_from_origin_provided else None
    seek_beginning_cb = _on_seek_beginning if iter_can_seek_beginning else None

    graph = _setup_seek_test(
        MySink,
        user_can_seek_ns_from_origin=can_seek_ns_cb,
        user_seek_ns_from_origin=seek_ns_cb,
        user_seek_beginning=seek_beginning_cb,
        can_seek_forward=iter_can_seek_forward,
    )

    # These are read/written by the closures above when the graph runs.
    requested_ns = 77
    observed_ns = None
    query_result = None
    graph.run_once()
    self.assertIs(query_result, expected_outcome)

    if user_can_seek_given:
        self.assertEqual(observed_ns, requested_ns)
1026
def test_can_seek_ns_from_origin_user_error(self):
    # Check that an exception raised by the user's "can seek ns from
    # origin" callback makes graph.run_once() raise bt2._Error, and that
    # the cause at index 0 mentions the original exception.
    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            self._add_input_port("in")

        def _user_graph_is_configured(self):
            self._msg_iter = self._create_message_iterator(self._input_ports["in"])

        def _user_consume(self):
            # This is expected to raise.
            self._msg_iter.can_seek_ns_from_origin(2)

    def _user_can_seek_ns_from_origin(self, ns_from_origin):
        raise ValueError("Joutel")

    # Placeholder seek callback; not expected to be called by this test.
    # It takes (self, ns_from_origin) like the other callbacks passed as
    # `user_seek_ns_from_origin` in this file, so that an unexpected call
    # would not fail on an argument-count mismatch.
    def _user_seek_ns_from_origin(self, ns_from_origin):
        pass

    graph = _setup_seek_test(
        MySink,
        user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin,
        user_seek_ns_from_origin=_user_seek_ns_from_origin,
    )

    with self.assertRaises(bt2._Error) as ctx:
        graph.run_once()

    cause = ctx.exception[0]
    self.assertIn("ValueError: Joutel", cause.message)
1053
def test_can_seek_ns_from_origin_wrong_return_value(self):
    # Check that a user "can seek ns from origin" callback returning a
    # `str` instead of a `bool` makes graph.run_once() raise bt2._Error,
    # and that the cause at index 0 carries the expected TypeError text.
    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            self._add_input_port("in")

        def _user_graph_is_configured(self):
            self._msg_iter = self._create_message_iterator(self._input_ports["in"])

        def _user_consume(self):
            # This is expected to raise.
            self._msg_iter.can_seek_ns_from_origin(2)

    def _user_can_seek_ns_from_origin(self, ns_from_origin):
        return "Nitchequon"

    # Placeholder seek callback; not expected to be called by this test.
    # It takes (self, ns_from_origin) like the other callbacks passed as
    # `user_seek_ns_from_origin` in this file, so that an unexpected call
    # would not fail on an argument-count mismatch.
    def _user_seek_ns_from_origin(self, ns_from_origin):
        pass

    graph = _setup_seek_test(
        MySink,
        user_can_seek_ns_from_origin=_user_can_seek_ns_from_origin,
        user_seek_ns_from_origin=_user_seek_ns_from_origin,
    )

    with self.assertRaises(bt2._Error) as ctx:
        graph.run_once()

    cause = ctx.exception[0]
    self.assertIn("TypeError: 'str' is not a 'bool' object", cause.message)
1080
def test_seek_ns_from_origin(self):
    # Check that the value the sink passes to iter.seek_ns_from_origin()
    # is the one received by the user "seek ns from origin" callback.
    requested_ns = 17

    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            self._add_input_port("in")

        def _user_graph_is_configured(self):
            self._msg_iter = self._create_message_iterator(self._input_ports["in"])

        def _user_consume(self):
            self._msg_iter.seek_ns_from_origin(requested_ns)

    def _user_seek_ns_from_origin(self, ns_from_origin):
        # Capture the received value for the assertion below.
        nonlocal seen_ns
        seen_ns = ns_from_origin

    graph = _setup_seek_test(
        MySink,
        user_seek_ns_from_origin=_user_seek_ns_from_origin,
    )

    seen_ns = None
    graph.run_once()
    self.assertEqual(seen_ns, requested_ns)
1103
1104
# Run the test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
This page took 0.052286 seconds and 3 git commands to generate.