Apply black code formatter on all Python code
[babeltrace.git] / tests / bindings / python / bt2 / test_message_iterator.py
1 #
2 # Copyright (C) 2019 EfficiOS Inc.
3 #
4 # This program is free software; you can redistribute it and/or
5 # modify it under the terms of the GNU General Public License
6 # as published by the Free Software Foundation; only version 2
7 # of the License.
8 #
9 # This program is distributed in the hope that it will be useful,
10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 # GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License
15 # along with this program; if not, write to the Free Software
16 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
17 #
18
19 from bt2 import value
20 import collections
21 import unittest
22 import copy
23 import bt2
24
25
26 class UserMessageIteratorTestCase(unittest.TestCase):
27 @staticmethod
28 def _create_graph(src_comp_cls):
29 class MySink(bt2._UserSinkComponent):
30 def __init__(self, params):
31 self._add_input_port('in')
32
33 def _consume(self):
34 next(self._msg_iter)
35
36 def _graph_is_configured(self):
37 self._msg_iter = self._input_ports['in'].create_message_iterator()
38
39 graph = bt2.Graph()
40 src_comp = graph.add_component(src_comp_cls, 'src')
41 sink_comp = graph.add_component(MySink, 'sink')
42 graph.connect_ports(src_comp.output_ports['out'], sink_comp.input_ports['in'])
43 return graph
44
45 def test_init(self):
46 the_output_port_from_source = None
47 the_output_port_from_iter = None
48
49 class MyIter(bt2._UserMessageIterator):
50 def __init__(self, self_port_output):
51 nonlocal initialized
52 nonlocal the_output_port_from_iter
53 initialized = True
54 the_output_port_from_iter = self_port_output
55
56 class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
57 def __init__(self, params):
58 nonlocal the_output_port_from_source
59 the_output_port_from_source = self._add_output_port('out', 'user data')
60
61 initialized = False
62 graph = self._create_graph(MySource)
63 graph.run()
64 self.assertTrue(initialized)
65 self.assertEqual(
66 the_output_port_from_source.addr, the_output_port_from_iter.addr
67 )
68 self.assertEqual(the_output_port_from_iter.user_data, 'user data')
69
70 def test_finalize(self):
71 class MyIter(bt2._UserMessageIterator):
72 def _finalize(self):
73 nonlocal finalized
74 finalized = True
75
76 class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
77 def __init__(self, params):
78 self._add_output_port('out')
79
80 finalized = False
81 graph = self._create_graph(MySource)
82 graph.run()
83 del graph
84 self.assertTrue(finalized)
85
86 def test_component(self):
87 class MyIter(bt2._UserMessageIterator):
88 def __init__(self, self_port_output):
89 nonlocal salut
90 salut = self._component._salut
91
92 class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
93 def __init__(self, params):
94 self._add_output_port('out')
95 self._salut = 23
96
97 salut = None
98 graph = self._create_graph(MySource)
99 graph.run()
100 self.assertEqual(salut, 23)
101
102 def test_addr(self):
103 class MyIter(bt2._UserMessageIterator):
104 def __init__(self, self_port_output):
105 nonlocal addr
106 addr = self.addr
107
108 class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
109 def __init__(self, params):
110 self._add_output_port('out')
111
112 addr = None
113 graph = self._create_graph(MySource)
114 graph.run()
115 self.assertIsNotNone(addr)
116 self.assertNotEqual(addr, 0)
117
118 # Test that messages returned by _UserMessageIterator.__next__ remain valid
119 # and can be re-used.
120 def test_reuse_message(self):
121 class MyIter(bt2._UserMessageIterator):
122 def __init__(self, port):
123 tc, sc, ec = port.user_data
124 trace = tc()
125 stream = trace.create_stream(sc)
126 packet = stream.create_packet()
127
128 # This message will be returned twice by __next__.
129 event_message = self._create_event_message(ec, packet)
130
131 self._msgs = [
132 self._create_stream_beginning_message(stream),
133 self._create_packet_beginning_message(packet),
134 event_message,
135 event_message,
136 ]
137
138 def __next__(self):
139 return self._msgs.pop(0)
140
141 class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
142 def __init__(self, params):
143 tc = self._create_trace_class()
144 sc = tc.create_stream_class(supports_packets=True)
145 ec = sc.create_event_class()
146 self._add_output_port('out', (tc, sc, ec))
147
148 graph = bt2.Graph()
149 src = graph.add_component(MySource, 'src')
150 it = graph.create_output_port_message_iterator(src.output_ports['out'])
151
152 # Skip beginning messages.
153 msg = next(it)
154 self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
155 msg = next(it)
156 self.assertIsInstance(msg, bt2.message._PacketBeginningMessage)
157
158 msg_ev1 = next(it)
159 msg_ev2 = next(it)
160
161 self.assertIsInstance(msg_ev1, bt2.message._EventMessage)
162 self.assertIsInstance(msg_ev2, bt2.message._EventMessage)
163 self.assertEqual(msg_ev1.addr, msg_ev2.addr)
164
165 @staticmethod
166 def _setup_seek_beginning_test():
167 # Use a source, a filter and an output port iterator. This allows us
168 # to test calling `seek_beginning` on both a _OutputPortMessageIterator
169 # and a _UserComponentInputPortMessageIterator, on top of checking that
170 # _UserMessageIterator._seek_beginning is properly called.
171
172 class MySourceIter(bt2._UserMessageIterator):
173 def __init__(self, port):
174 tc, sc, ec = port.user_data
175 trace = tc()
176 stream = trace.create_stream(sc)
177 packet = stream.create_packet()
178
179 self._msgs = [
180 self._create_stream_beginning_message(stream),
181 self._create_packet_beginning_message(packet),
182 self._create_event_message(ec, packet),
183 self._create_event_message(ec, packet),
184 self._create_packet_end_message(packet),
185 self._create_stream_end_message(stream),
186 ]
187 self._at = 0
188
189 def _seek_beginning(self):
190 self._at = 0
191
192 def __next__(self):
193 if self._at < len(self._msgs):
194 msg = self._msgs[self._at]
195 self._at += 1
196 return msg
197 else:
198 raise StopIteration
199
200 class MySource(bt2._UserSourceComponent, message_iterator_class=MySourceIter):
201 def __init__(self, params):
202 tc = self._create_trace_class()
203 sc = tc.create_stream_class(supports_packets=True)
204 ec = sc.create_event_class()
205
206 self._add_output_port('out', (tc, sc, ec))
207
208 class MyFilterIter(bt2._UserMessageIterator):
209 def __init__(self, port):
210 input_port = port.user_data
211 self._upstream_iter = input_port.create_message_iterator()
212
213 def __next__(self):
214 return next(self._upstream_iter)
215
216 def _seek_beginning(self):
217 self._upstream_iter.seek_beginning()
218
219 @property
220 def _can_seek_beginning(self):
221 return self._upstream_iter.can_seek_beginning
222
223 class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyFilterIter):
224 def __init__(self, params):
225 input_port = self._add_input_port('in')
226 self._add_output_port('out', input_port)
227
228 graph = bt2.Graph()
229 src = graph.add_component(MySource, 'src')
230 flt = graph.add_component(MyFilter, 'flt')
231 graph.connect_ports(src.output_ports['out'], flt.input_ports['in'])
232 it = graph.create_output_port_message_iterator(flt.output_ports['out'])
233
234 return it, MySourceIter
235
236 def test_can_seek_beginning(self):
237 it, MySourceIter = self._setup_seek_beginning_test()
238
239 def _can_seek_beginning(self):
240 nonlocal can_seek_beginning
241 return can_seek_beginning
242
243 MySourceIter._can_seek_beginning = property(_can_seek_beginning)
244
245 can_seek_beginning = True
246 self.assertTrue(it.can_seek_beginning)
247
248 can_seek_beginning = False
249 self.assertFalse(it.can_seek_beginning)
250
251 # Once can_seek_beginning returns an error, verify that it raises when
252 # _can_seek_beginning has/returns the wrong type.
253
254 # Remove the _can_seek_beginning method, we now rely on the presence of
255 # a _seek_beginning method to know whether the iterator can seek to
256 # beginning or not.
257 del MySourceIter._can_seek_beginning
258 self.assertTrue(it.can_seek_beginning)
259
260 del MySourceIter._seek_beginning
261 self.assertFalse(it.can_seek_beginning)
262
263 def test_seek_beginning(self):
264 it, MySourceIter = self._setup_seek_beginning_test()
265
266 msg = next(it)
267 self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
268 msg = next(it)
269 self.assertIsInstance(msg, bt2.message._PacketBeginningMessage)
270
271 it.seek_beginning()
272
273 msg = next(it)
274 self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
275
276 # Verify that we can seek beginning after having reached the end.
277 #
278 # It currently does not work to seek an output port message iterator
279 # once it's ended, but we should eventually make it work and uncomment
280 # the following snippet.
281 #
282 # try:
283 # while True:
284 # next(it)
285 # except bt2.Stop:
286 # pass
287 #
288 # it.seek_beginning()
289 # msg = next(it)
290 # self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
291
292 def test_seek_beginning_user_error(self):
293 it, MySourceIter = self._setup_seek_beginning_test()
294
295 def _seek_beginning_error(self):
296 raise ValueError('ouch')
297
298 MySourceIter._seek_beginning = _seek_beginning_error
299
300 with self.assertRaises(bt2.Error):
301 it.seek_beginning()
302
303
304 class OutputPortMessageIteratorTestCase(unittest.TestCase):
305 def test_component(self):
306 class MyIter(bt2._UserMessageIterator):
307 def __init__(self, self_port_output):
308 self._at = 0
309
310 def __next__(self):
311 if self._at == 7:
312 raise bt2.Stop
313
314 if self._at == 0:
315 msg = self._create_stream_beginning_message(test_obj._stream)
316 elif self._at == 1:
317 msg = self._create_packet_beginning_message(test_obj._packet)
318 elif self._at == 5:
319 msg = self._create_packet_end_message(test_obj._packet)
320 elif self._at == 6:
321 msg = self._create_stream_end_message(test_obj._stream)
322 else:
323 msg = self._create_event_message(
324 test_obj._event_class, test_obj._packet
325 )
326 msg.event.payload_field['my_int'] = self._at * 3
327
328 self._at += 1
329 return msg
330
331 class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
332 def __init__(self, params):
333 self._add_output_port('out')
334
335 trace_class = self._create_trace_class()
336 stream_class = trace_class.create_stream_class(supports_packets=True)
337
338 # Create payload field class
339 my_int_ft = trace_class.create_signed_integer_field_class(32)
340 payload_ft = trace_class.create_structure_field_class()
341 payload_ft += [('my_int', my_int_ft)]
342
343 event_class = stream_class.create_event_class(
344 name='salut', payload_field_class=payload_ft
345 )
346
347 trace = trace_class()
348 stream = trace.create_stream(stream_class)
349 packet = stream.create_packet()
350
351 test_obj._event_class = event_class
352 test_obj._stream = stream
353 test_obj._packet = packet
354
355 test_obj = self
356 graph = bt2.Graph()
357 src = graph.add_component(MySource, 'src')
358 msg_iter = graph.create_output_port_message_iterator(src.output_ports['out'])
359
360 for at, msg in enumerate(msg_iter):
361 if at == 0:
362 self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
363 elif at == 1:
364 self.assertIsInstance(msg, bt2.message._PacketBeginningMessage)
365 elif at == 5:
366 self.assertIsInstance(msg, bt2.message._PacketEndMessage)
367 elif at == 6:
368 self.assertIsInstance(msg, bt2.message._StreamEndMessage)
369 else:
370 self.assertIsInstance(msg, bt2.message._EventMessage)
371 self.assertEqual(msg.event.cls.name, 'salut')
372 field = msg.event.payload_field['my_int']
373 self.assertEqual(field, at * 3)
374
375
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
This page took 0.037528 seconds and 4 git commands to generate.