Commit | Line | Data |
---|---|---|
d2d857a8 MJ |
1 | # |
2 | # Copyright (C) 2019 EfficiOS Inc. | |
3 | # | |
4 | # This program is free software; you can redistribute it and/or | |
5 | # modify it under the terms of the GNU General Public License | |
6 | # as published by the Free Software Foundation; only version 2 | |
7 | # of the License. | |
8 | # | |
9 | # This program is distributed in the hope that it will be useful, | |
10 | # but WITHOUT ANY WARRANTY; without even the implied warranty of | |
11 | # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
12 | # GNU General Public License for more details. | |
13 | # | |
14 | # You should have received a copy of the GNU General Public License | |
15 | # along with this program; if not, write to the Free Software | |
16 | # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. | |
17 | # | |
18 | ||
811644b8 | 19 | import unittest |
811644b8 | 20 | import bt2 |
8e97c333 | 21 | import sys |
6c373cc9 | 22 | from utils import TestOutputPortMessageIterator |
14503fb1 | 23 | from bt2 import port as bt2_port |
811644b8 PP |
24 | |
25 | ||
class UserMessageIteratorTestCase(unittest.TestCase):
    """Tests for user-defined message iterators (bt2._UserMessageIterator)."""

    @staticmethod
    def _create_graph(src_comp_cls, flt_comp_cls=None):
        """Build a graph made of a source, an optional filter and a sink.

        The source (named 'src') is an instance of `src_comp_cls`. When
        `flt_comp_cls` is given, a filter named 'flt' is inserted between
        the source and the sink. The sink (named 'sink') pulls one message
        from its input port iterator on every consume step.

        Returns the configured (not yet run) graph.
        """

        class MySink(bt2._UserSinkComponent):
            def __init__(self, params, obj):
                self._add_input_port('in')

            def _user_consume(self):
                next(self._msg_iter)

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_input_port_message_iterator(
                    self._input_ports['in']
                )

        graph = bt2.Graph()
        src_comp = graph.add_component(src_comp_cls, 'src')

        # Components are added in source, filter, sink order, as before;
        # `flt_comp` stays None when no filter class is given.
        flt_comp = None
        if flt_comp_cls is not None:
            flt_comp = graph.add_component(flt_comp_cls, 'flt')

        sink_comp = graph.add_component(MySink, 'sink')

        if flt_comp is None:
            out_port = src_comp.output_ports['out']
        else:
            graph.connect_ports(
                src_comp.output_ports['out'], flt_comp.input_ports['in']
            )
            out_port = flt_comp.output_ports['out']

        graph.connect_ports(out_port, sink_comp.input_ports['in'])
        return graph

    def test_init(self):
        """The iterator's __init__ runs and receives the source's output port."""
        initialized = False
        the_output_port_from_source = None
        the_output_port_from_iter = None

        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal initialized
                nonlocal the_output_port_from_iter
                initialized = True
                the_output_port_from_iter = self_port_output

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                nonlocal the_output_port_from_source
                the_output_port_from_source = self._add_output_port('out', 'user data')

        graph = self._create_graph(MySource)
        graph.run()
        self.assertTrue(initialized)
        # Same underlying port on both sides, user data included.
        self.assertEqual(
            the_output_port_from_source.addr, the_output_port_from_iter.addr
        )
        self.assertEqual(the_output_port_from_iter.user_data, 'user data')

    def test_create_from_message_iterator(self):
        """A filter's iterator can create an iterator on its input port."""

        class MySourceIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal src_iter_initialized
                src_iter_initialized = True

        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
            def __init__(self, params, obj):
                self._add_output_port('out')

        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal flt_iter_initialized
                flt_iter_initialized = True
                self._up_iter = self._create_input_port_message_iterator(
                    self._component._input_ports['in']
                )

            def __next__(self):
                return next(self._up_iter)

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(self, params, obj):
                self._add_input_port('in')
                self._add_output_port('out')

        src_iter_initialized = False
        flt_iter_initialized = False
        graph = self._create_graph(MySource, MyFilter)
        graph.run()
        self.assertTrue(src_iter_initialized)
        self.assertTrue(flt_iter_initialized)

    def test_create_user_error(self):
        """An exception in a source iterator's __init__ becomes a bt2._Error."""
        # This tests both error handling by
        # _UserSinkComponent._create_input_port_message_iterator
        # and _UserMessageIterator._create_input_port_message_iterator, as they
        # are both used in the graph.
        class MySourceIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                raise ValueError('Very bad error')

        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
            def __init__(self, params, obj):
                self._add_output_port('out')

        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                # This is expected to raise because of the error in
                # MySourceIter.__init__.
                self._create_input_port_message_iterator(
                    self._component._input_ports['in']
                )

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(self, params, obj):
                self._add_input_port('in')
                self._add_output_port('out')

        graph = self._create_graph(MySource, MyFilter)

        with self.assertRaises(bt2._Error) as ctx:
            graph.run()

        exc = ctx.exception
        cause = exc[0]

        # The first cause identifies the failing source iterator and
        # carries the original Python exception text.
        self.assertIsInstance(cause, bt2._MessageIteratorErrorCause)
        self.assertEqual(cause.component_name, 'src')
        self.assertEqual(cause.component_output_port_name, 'out')
        self.assertIn('ValueError: Very bad error', cause.message)

    def test_finalize(self):
        """_user_finalize has been called once the graph is deleted."""

        class MyIter(bt2._UserMessageIterator):
            def _user_finalize(self):
                nonlocal finalized
                finalized = True

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                self._add_output_port('out')

        finalized = False
        graph = self._create_graph(MySource)
        graph.run()
        # Dropping the last reference to the graph tears down its iterators.
        del graph
        self.assertTrue(finalized)

    def test_component(self):
        """The iterator reaches its component's attributes via self._component."""

        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal salut
                salut = self._component._salut

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                self._add_output_port('out')
                self._salut = 23

        salut = None
        graph = self._create_graph(MySource)
        graph.run()
        self.assertEqual(salut, 23)

    def test_port(self):
        """The port passed to __init__ and self._port are the same port."""

        class MyIter(bt2._UserMessageIterator):
            # `self_iter` is the iterator; `self` is the enclosing test case.
            def __init__(self_iter, self_port_output):
                nonlocal called
                called = True
                port = self_iter._port
                self.assertIs(type(self_port_output), bt2_port._UserComponentOutputPort)
                self.assertIs(type(port), bt2_port._UserComponentOutputPort)
                self.assertEqual(self_port_output.addr, port.addr)

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                self._add_output_port('out')

        called = False
        graph = self._create_graph(MySource)
        graph.run()
        self.assertTrue(called)

    def test_addr(self):
        """The iterator exposes a non-null address."""

        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal addr
                addr = self.addr

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                self._add_output_port('out')

        addr = None
        graph = self._create_graph(MySource)
        graph.run()
        self.assertIsNotNone(addr)
        self.assertNotEqual(addr, 0)

    def test_reuse_message(self):
        """Messages returned by __next__ remain valid and can be re-used."""

        class MyIter(bt2._UserMessageIterator):
            def __init__(self, port):
                tc, sc, ec = port.user_data
                trace = tc()
                stream = trace.create_stream(sc)
                packet = stream.create_packet()

                # This message will be returned twice by __next__.
                event_message = self._create_event_message(ec, packet)

                self._msgs = [
                    self._create_stream_beginning_message(stream),
                    self._create_packet_beginning_message(packet),
                    event_message,
                    event_message,
                ]

            def __next__(self):
                return self._msgs.pop(0)

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                tc = self._create_trace_class()
                sc = tc.create_stream_class(supports_packets=True)
                ec = sc.create_event_class()
                self._add_output_port('out', (tc, sc, ec))

        graph = bt2.Graph()
        src = graph.add_component(MySource, 'src')
        it = TestOutputPortMessageIterator(graph, src.output_ports['out'])

        # Skip beginning messages.
        msg = next(it)
        self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
        msg = next(it)
        self.assertIs(type(msg), bt2._PacketBeginningMessageConst)

        msg_ev1 = next(it)
        msg_ev2 = next(it)

        self.assertIs(type(msg_ev1), bt2._EventMessageConst)
        self.assertIs(type(msg_ev2), bt2._EventMessageConst)
        self.assertEqual(msg_ev1.addr, msg_ev2.addr)

    @staticmethod
    def _setup_seek_beginning_test(sink_cls):
        """Build a source -> filter -> sink graph for seek-beginning tests.

        Returns `(MySourceIter, graph)` so that callers can patch the
        source iterator class (e.g. its `_user_seek_beginning` or
        `_user_can_seek_beginning` members) before running the graph.
        """
        # Use a source, a filter and the caller-provided sink. This allows
        # calling `seek_beginning` / `can_seek_beginning` on the sink's
        # input port message iterator, which the filter's iterator forwards
        # to its own upstream input port message iterator, on top of
        # checking that MySourceIter._user_seek_beginning is properly called.

        class MySourceIter(bt2._UserMessageIterator):
            def __init__(self, port):
                tc, sc, ec = port.user_data
                trace = tc()
                stream = trace.create_stream(sc)
                packet = stream.create_packet()

                self._msgs = [
                    self._create_stream_beginning_message(stream),
                    self._create_packet_beginning_message(packet),
                    self._create_event_message(ec, packet),
                    self._create_event_message(ec, packet),
                    self._create_packet_end_message(packet),
                    self._create_stream_end_message(stream),
                ]
                # Index of the next message to return.
                self._at = 0

            def _user_seek_beginning(self):
                self._at = 0

            def __next__(self):
                if self._at < len(self._msgs):
                    msg = self._msgs[self._at]
                    self._at += 1
                    return msg
                else:
                    raise StopIteration

        class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
            def __init__(self, params, obj):
                tc = self._create_trace_class()
                sc = tc.create_stream_class(supports_packets=True)
                ec = sc.create_event_class()

                self._add_output_port('out', (tc, sc, ec))

        class MyFilterIter(bt2._UserMessageIterator):
            def __init__(self, port):
                # The filter stores its input port as the output port's
                # user data (see MyFilter.__init__).
                input_port = port.user_data
                self._upstream_iter = self._create_input_port_message_iterator(
                    input_port
                )

            def __next__(self):
                return next(self._upstream_iter)

            def _user_seek_beginning(self):
                self._upstream_iter.seek_beginning()

            @property
            def _user_can_seek_beginning(self):
                return self._upstream_iter.can_seek_beginning

        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
            def __init__(self, params, obj):
                input_port = self._add_input_port('in')
                self._add_output_port('out', input_port)

        graph = bt2.Graph()
        src = graph.add_component(MySource, 'src')
        flt = graph.add_component(MyFilter, 'flt')
        sink = graph.add_component(sink_cls, 'sink')
        graph.connect_ports(src.output_ports['out'], flt.input_ports['in'])
        graph.connect_ports(flt.output_ports['out'], sink.input_ports['in'])
        return MySourceIter, graph

    def test_can_seek_beginning(self):
        """can_seek_beginning follows the source iterator's capabilities.

        It reflects `_user_can_seek_beginning` when defined, and otherwise
        the presence of `_user_seek_beginning`.
        """

        class MySink(bt2._UserSinkComponent):
            def __init__(self, params, obj):
                self._add_input_port('in')

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_input_port_message_iterator(
                    self._input_ports['in']
                )

            def _user_consume(self):
                nonlocal can_seek_beginning
                can_seek_beginning = self._msg_iter.can_seek_beginning

        MySourceIter, graph = self._setup_seek_beginning_test(MySink)

        def _user_can_seek_beginning(self):
            # Read-only access to the enclosing test's flag.
            return input_port_iter_can_seek_beginning

        MySourceIter._user_can_seek_beginning = property(_user_can_seek_beginning)

        input_port_iter_can_seek_beginning = True
        can_seek_beginning = None
        graph.run_once()
        self.assertTrue(can_seek_beginning)

        input_port_iter_can_seek_beginning = False
        can_seek_beginning = None
        graph.run_once()
        self.assertFalse(can_seek_beginning)

        # TODO: once can_seek_beginning reports errors, verify that it raises
        # when _user_can_seek_beginning has/returns the wrong type.

        # Remove the _user_can_seek_beginning property: we now rely on the
        # presence of a _user_seek_beginning method to know whether the
        # iterator can seek to beginning or not.
        del MySourceIter._user_can_seek_beginning
        can_seek_beginning = None
        graph.run_once()
        self.assertTrue(can_seek_beginning)

        del MySourceIter._user_seek_beginning
        can_seek_beginning = None
        graph.run_once()
        self.assertFalse(can_seek_beginning)

    def test_seek_beginning(self):
        """Seeking to the beginning restarts at the stream beginning message."""

        class MySink(bt2._UserSinkComponent):
            def __init__(self, params, obj):
                self._add_input_port('in')

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_input_port_message_iterator(
                    self._input_ports['in']
                )

            def _user_consume(self):
                nonlocal msg

                # Each consume step either seeks or pulls one message,
                # depending on the enclosing test's flag.
                if do_seek_beginning:
                    self._msg_iter.seek_beginning()
                    return

                msg = next(self._msg_iter)

        do_seek_beginning = False
        msg = None
        MySourceIter, graph = self._setup_seek_beginning_test(MySink)
        graph.run_once()
        self.assertIs(type(msg), bt2._StreamBeginningMessageConst)
        graph.run_once()
        self.assertIs(type(msg), bt2._PacketBeginningMessageConst)
        do_seek_beginning = True
        graph.run_once()
        do_seek_beginning = False
        graph.run_once()
        self.assertIs(type(msg), bt2._StreamBeginningMessageConst)

    def test_seek_beginning_user_error(self):
        """An exception in _user_seek_beginning becomes a bt2._Error."""

        class MySink(bt2._UserSinkComponent):
            def __init__(self, params, obj):
                self._add_input_port('in')

            def _user_graph_is_configured(self):
                self._msg_iter = self._create_input_port_message_iterator(
                    self._input_ports['in']
                )

            def _user_consume(self):
                self._msg_iter.seek_beginning()

        MySourceIter, graph = self._setup_seek_beginning_test(MySink)

        def _user_seek_beginning_error(self):
            raise ValueError('ouch')

        MySourceIter._user_seek_beginning = _user_seek_beginning_error

        with self.assertRaises(bt2._Error):
            graph.run_once()

    def test_try_again_many_times(self):
        """Consume many times from an iterator that always raises TryAgain.

        This verifies that we are not missing an incref of Py_None, which
        would make the refcount of Py_None reach 0.
        """

        class MyIter(bt2._UserMessageIterator):
            def __next__(self):
                raise bt2.TryAgain

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params, obj):
                self._add_output_port('out')

        graph = bt2.Graph()
        src = graph.add_component(MySource, 'src')
        it = TestOutputPortMessageIterator(graph, src.output_ports['out'])

        # Iterating three times the initial ref count of `None` should be
        # enough to catch the bug even if there are small differences
        # between configurations.
        none_ref_count = sys.getrefcount(None) * 3

        for _ in range(none_ref_count):
            with self.assertRaises(bt2.TryAgain):
                next(it)


# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()