c59fd8d57d27d41c8b3aa41672f909c21132d35e
[babeltrace.git] / tests / bindings / python / bt2 / utils.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # Copyright (C) 2019 EfficiOS Inc.
4 #
5
6 import bt2
7 import collections.abc
8
9
def run_in_component_init(func):
    """Call `func` from inside a sink component's ``__init__`` method.

    A throwaway graph is created and a sink component is added to it;
    while that component is being initialized, `func` is called with the
    component instance as its only argument.

    Returns whatever `func` returned.
    """

    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            nonlocal captured
            captured = func(self)

        def _user_consume(self):
            pass

    graph = bt2.Graph()
    captured = None
    graph.add_component(MySink, "comp")

    # `captured` is a closure cell shared with `MySink.__init__`, and that
    # context stays alive until the end of the program. If the cell still
    # pointed to the result when we hand it back, the caller would see an
    # extra, unexpected reference in the result's refcount, which breaks
    # tests that check what happens when a Python object's refcount reaches
    # 0 (destruction tests, for example). So copy the result into a plain
    # local and delete the cell binding before returning.
    result = captured
    del captured
    return result
39
40
def get_default_trace_class():
    """Create and return an empty trace class with default values.

    The trace class is created from within a component's ``__init__``
    (see `run_in_component_init`).
    """
    return run_in_component_init(lambda comp: comp._create_trace_class())
47
48
def _get_all_message_types(with_packet=True):
    """Create one message of each type and return them as a pair of lists.

    The messages are, in order: stream beginning, packet beginning (only
    when `with_packet` is true), event, packet end (only when
    `with_packet` is true) and stream end.

    Returns a tuple ``(msgs, const_msgs)``: `msgs` holds the non-const
    messages exactly as the source's message iterator created them, and
    `const_msgs` holds the const messages received downstream through a
    `TestOutputPortMessageIterator`.
    """
    non_const_msgs = None

    class MyIter(bt2._UserMessageIterator):
        def __init__(self, config, self_output_port):
            nonlocal non_const_msgs

            self._at = 0
            user_data = self_output_port.user_data
            stream = user_data["stream"]
            packet = user_data["packet"]

            msgs = [self._create_stream_beginning_message(stream)]

            # The event is attached to the packet when packets are
            # supported, otherwise directly to the stream.
            if with_packet:
                assert packet
                msgs.append(self._create_packet_beginning_message(packet))
                ev_parent = packet
            else:
                assert stream
                ev_parent = stream

            default_clock_snapshot = 789
            ev_msg = self._create_event_message(
                user_data["event_class"], ev_parent, default_clock_snapshot
            )
            ev_msg.event.payload_field["giraffe"] = 1
            ev_msg.event.specific_context_field["ant"] = -1
            ev_msg.event.common_context_field["cpu_id"] = 1
            msgs.append(ev_msg)

            if with_packet:
                msgs.append(self._create_packet_end_message(packet))

            msgs.append(self._create_stream_end_message(stream))

            # Expose the very same list object to the enclosing function.
            self._msgs = msgs
            non_const_msgs = msgs

        def __next__(self):
            pos = self._at
            if pos == len(self._msgs):
                raise bt2.Stop

            self._at = pos + 1
            return self._msgs[pos]

    class MySrc(bt2._UserSourceComponent, message_iterator_class=MyIter):
        def __init__(self, config, params, obj):
            tc = self._create_trace_class()
            clock_class = self._create_clock_class(frequency=1000)

            # Event common context (defined by the stream class).
            common_ctx_fc = tc.create_structure_field_class()
            common_ctx_fc += [("cpu_id", tc.create_signed_integer_field_class(8))]

            # Packet context (defined by the stream class), only when
            # packets are supported.
            packet_ctx_fc = None
            if with_packet:
                packet_ctx_fc = tc.create_structure_field_class()
                packet_ctx_fc += [
                    ("something", tc.create_unsigned_integer_field_class(8))
                ]

            stream_class = tc.create_stream_class(
                default_clock_class=clock_class,
                event_common_context_field_class=common_ctx_fc,
                packet_context_field_class=packet_ctx_fc,
                supports_packets=with_packet,
            )

            # Event specific context (defined by the event class).
            specific_ctx_fc = tc.create_structure_field_class()
            specific_ctx_fc += [("ant", tc.create_signed_integer_field_class(16))]

            # Event payload.
            payload_fc = tc.create_structure_field_class()
            payload_fc += [("giraffe", tc.create_signed_integer_field_class(32))]

            event_class = stream_class.create_event_class(
                name="garou",
                specific_context_field_class=specific_ctx_fc,
                payload_field_class=payload_fc,
            )

            trace = tc(environment={"patate": 12})
            stream = trace.create_stream(stream_class, user_attributes={"salut": 23})

            packet = None
            if with_packet:
                packet = stream.create_packet()
                packet.context_field["something"] = 154

            port_data = {
                "tc": tc,
                "stream": stream,
                "event_class": event_class,
                "trace": trace,
                "packet": packet,
            }
            self._add_output_port("out", port_data)

    graph = bt2.Graph()
    src_comp = graph.add_component(MySrc, "my_source")
    port_iter = TestOutputPortMessageIterator(graph, src_comp.output_ports["out"])
    const_msgs = list(port_iter)

    return non_const_msgs, const_msgs
175
176
def get_stream_beginning_message():
    """Return a non-const stream beginning message, or `None` if none exists."""
    msgs, _ = _get_all_message_types()
    # Exact type comparison: only this precise class is accepted.
    return next(
        (msg for msg in msgs if type(msg) is bt2._StreamBeginningMessage), None
    )
182
183
def get_const_stream_beginning_message():
    """Return a const stream beginning message, or `None` if none exists."""
    _, const_msgs = _get_all_message_types()
    # Exact type comparison: only this precise class is accepted.
    return next(
        (msg for msg in const_msgs if type(msg) is bt2._StreamBeginningMessageConst),
        None,
    )
189
190
def get_stream_end_message():
    """Return a non-const stream end message, or `None` if none exists."""
    msgs, _ = _get_all_message_types()
    # Exact type comparison: only this precise class is accepted.
    return next((msg for msg in msgs if type(msg) is bt2._StreamEndMessage), None)
196
197
def get_packet_beginning_message():
    """Return a non-const packet beginning message, or `None` if none exists."""
    msgs, _ = _get_all_message_types(with_packet=True)
    # Exact type comparison: only this precise class is accepted.
    return next(
        (msg for msg in msgs if type(msg) is bt2._PacketBeginningMessage), None
    )
203
204
def get_const_packet_beginning_message():
    """Return a const packet beginning message, or `None` if none exists."""
    _, const_msgs = _get_all_message_types(with_packet=True)
    # Exact type comparison: only this precise class is accepted.
    return next(
        (msg for msg in const_msgs if type(msg) is bt2._PacketBeginningMessageConst),
        None,
    )
210
211
def get_packet_end_message():
    """Return a non-const packet end message, or `None` if none exists."""
    msgs, _ = _get_all_message_types(with_packet=True)
    # Exact type comparison: only this precise class is accepted.
    return next((msg for msg in msgs if type(msg) is bt2._PacketEndMessage), None)
217
218
def get_event_message():
    """Return a non-const event message, or `None` if none exists."""
    msgs, _ = _get_all_message_types()
    # Exact type comparison: only this precise class is accepted.
    return next((msg for msg in msgs if type(msg) is bt2._EventMessage), None)
224
225
def get_const_event_message():
    """Return a const event message, or `None` if none exists."""
    _, const_msgs = _get_all_message_types()
    # Exact type comparison: only this precise class is accepted.
    return next(
        (msg for msg in const_msgs if type(msg) is bt2._EventMessageConst), None
    )
231
232
class TestProxySink(bt2._UserSinkComponent):
    """Proxy sink component class.

    The initialization object (`msg_list`) must be a list of a single
    item. The sink creates a single input port named ``in``; each time it
    consumes, it takes the next message from that port and stores it as
    the first item of `msg_list`, which must be `None` beforehand.
    """

    def __init__(self, config, params, msg_list):
        assert msg_list is not None
        self._msg_list = msg_list
        self._add_input_port("in")

    def _user_graph_is_configured(self):
        # Create the iterator on the `in` port once the graph is configured.
        in_port = self._input_ports["in"]
        self._msg_iter = self._create_message_iterator(in_port)

    def _user_consume(self):
        slot = self._msg_list
        # The slot must be empty before the next message is stored in it.
        assert slot[0] is None
        slot[0] = next(self._msg_iter)
251
252
class TestOutputPortMessageIterator(collections.abc.Iterator):
    """Helper message iterator for tests.

    Build it from a graph and an output port: it adds a `TestProxySink`
    component to the graph and connects the given output port to the
    proxy sink's ``in`` input port. Each `__next__()` call runs the graph
    once and hands over the message that the proxy sink consumed.

    This message iterator cannot seek.
    """

    def __init__(self, graph, output_port):
        self._graph = graph
        # Single-item list shared with the proxy sink, which writes each
        # consumed message into it.
        self._msg_list = [None]
        proxy_sink = graph.add_component(
            TestProxySink, "test-proxy-sink", obj=self._msg_list
        )
        graph.connect_ports(output_port, proxy_sink.input_ports["in"])

    def __next__(self):
        slot = self._msg_list
        assert slot[0] is None
        self._graph.run_once()

        # Take the message out and empty the slot for the next run.
        msg, slot[0] = slot[0], None
        assert msg is not None
        return msg
277
278
def create_const_field(tc, field_class, field_value_setter_fn):
    """Create and return a const field of class `field_class`.

    The field is the single member (named "const field") of the packet
    context of a dummy packet. That packet belongs to a dummy stream,
    itself part of a dummy trace created from trace class `tc`.

    `field_value_setter_fn` is called with the (non-const) field before
    the packet beginning message is created, so that the caller can set
    its value. The returned field is taken from the packet beginning
    message received through a `TestOutputPortMessageIterator`.
    """
    field_name = "const field"

    class MyIter(bt2._UserMessageIterator):
        def __init__(self, config, self_port_output):
            # `tc`, `field_class` and `field_value_setter_fn` are only read
            # from the enclosing scope, so no `nonlocal` is required.
            trace = tc()
            packet_context_fc = tc.create_structure_field_class()
            packet_context_fc.append_member(field_name, field_class)
            sc = tc.create_stream_class(
                packet_context_field_class=packet_context_fc, supports_packets=True
            )
            stream = trace.create_stream(sc)
            packet = stream.create_packet()

            field_value_setter_fn(packet.context_field[field_name])

            self._msgs = [
                self._create_stream_beginning_message(stream),
                self._create_packet_beginning_message(packet),
            ]

        def __next__(self):
            if not self._msgs:
                # End of iteration: use `bt2.Stop`, like the other message
                # iterators of this module.
                raise bt2.Stop

            return self._msgs.pop(0)

    class MySrc(bt2._UserSourceComponent, message_iterator_class=MyIter):
        def __init__(self, config, params, obj):
            self._add_output_port("out", params)

    graph = bt2.Graph()
    src_comp = graph.add_component(MySrc, "my_source", None)
    msg_iter = TestOutputPortMessageIterator(graph, src_comp.output_ports["out"])

    # Skip the first message (stream beginning).
    next(msg_iter)
    packet_beg_msg = next(msg_iter)

    return packet_beg_msg.packet.context_field[field_name]
325
326
def run_in_message_iterator_next(create_stream_class_func, msg_iter_next_func):
    """Call `msg_iter_next_func` from inside a message iterator's ``__next__``.

    For convenience, a trace and a stream are created first. So that the
    caller can customize the stream class, `create_stream_class_func` is
    called during the source component's initialization with a trace
    class and a clock class; it must return a stream class.

    `msg_iter_next_func` receives two arguments: the message iterator and
    the created stream.

    Returns whatever `msg_iter_next_func` returned.
    """

    class MyIter(bt2._UserMessageIterator):
        def __init__(self, config, port):
            trace_class, stream_class = port.user_data
            trace = trace_class()
            self._stream = trace.create_stream(stream_class)

        def __next__(self):
            nonlocal captured
            captured = msg_iter_next_func(self, self._stream)
            # A single call is all we need: end the iteration right away.
            raise bt2.Stop

    class MySrc(bt2._UserSourceComponent, message_iterator_class=MyIter):
        def __init__(self, config, params, obj):
            trace_class = self._create_trace_class()
            clock_class = self._create_clock_class()
            stream_class = create_stream_class_func(trace_class, clock_class)
            self._add_output_port("out", (trace_class, stream_class))

    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            self._input_port = self._add_input_port("in")

        def _user_graph_is_configured(self):
            self._input_iter = self._create_message_iterator(self._input_port)

        def _user_consume(self):
            next(self._input_iter)

    graph = bt2.Graph()
    captured = None
    source = graph.add_component(MySrc, "ze source")
    sink = graph.add_component(MySink, "ze sink")
    graph.connect_ports(source.output_ports["out"], sink.input_ports["in"])
    graph.run()

    # Hand the result back through a plain local and drop the closure-cell
    # binding shared with `MyIter.__next__`. This keeps that cell from
    # holding an extra reference on the result, which would skew
    # refcount-sensitive (e.g. destruction) tests.
    result = captured
    del captured
    return result
This page took 0.03674 seconds and 3 git commands to generate.