Move to kernel style SPDX license identifiers
[babeltrace.git] / tests / bindings / python / bt2 / utils.py
1 # SPDX-License-Identifier: GPL-2.0-only
2 #
3 # Copyright (C) 2019 EfficiOS Inc.
4 #
5
6 import bt2
7 import collections.abc
8
9
# Run callable `func` in the context of a component's __init__ method. The
# callable is passed the Component being instantiated.
#
# A throwaway sink component class is added to a fresh graph; adding it
# instantiates the component, whose __init__ calls `func(self)`. This lets
# `func` use methods of the component object being initialized.
#
# The value returned by the callable is returned by run_in_component_init.
def run_in_component_init(func):
    class MySink(bt2._UserSinkComponent):
        def __init__(self, config, params, obj):
            # Hand the result back through the enclosing function's scope.
            nonlocal res_bound
            res_bound = func(self)

        # Sink components must define this; the graph created below is
        # never run by this function, so it is not called here.
        def _user_consume(self):
            pass

    g = bt2.Graph()
    res_bound = None
    # Instantiates MySink, which runs `func` and sets `res_bound`.
    g.add_component(MySink, 'comp')

    # We deliberately use a different variable for returning the result than
    # the variable bound to the MySink.__init__ context and delete res_bound.
    # The MySink.__init__ context stays alive until the end of the program, so
    # if res_bound were to still point to our result, it would contribute an
    # unexpected reference to the refcount of the result, from the point of view
    # of the user of this function. It would then affect destruction tests,
    # for example, which want to test what happens when the refcount of a Python
    # object reaches 0.

    res = res_bound
    del res_bound
    return res
39
40
# Return a new, empty trace class with default values.
#
# The trace class is created by a component while it is being
# initialized (see run_in_component_init()).
def get_default_trace_class():
    return run_in_component_init(lambda comp: comp._create_trace_class())
47
48
# Create a pair of lists of messages: the first holds the non-const
# messages as created by a source's message iterator, the second holds
# the const messages a downstream sink received for those same messages.
#
# The message sequence is: stream beginning, [packet beginning], event,
# [packet end], stream end. The bracketed packet messages are only
# produced (and the stream class only supports packets) when
# `with_packet` is true.
def _get_all_message_types(with_packet=True):
    created = None

    class MyIter(bt2._UserMessageIterator):
        def __init__(self, config, self_output_port):
            nonlocal created
            data = self_output_port.user_data
            stream = data['stream']

            self._msgs = [self._create_stream_beginning_message(stream)]

            # The event belongs to the packet when packets are supported,
            # otherwise directly to the stream.
            if with_packet:
                assert data['packet']
                self._msgs.append(
                    self._create_packet_beginning_message(data['packet'])
                )
                parent = data['packet']
            else:
                assert data['stream']
                parent = data['stream']

            clock_snapshot = 789
            event_msg = self._create_event_message(
                data['event_class'], parent, clock_snapshot
            )
            event_msg.event.payload_field['giraffe'] = 1
            event_msg.event.specific_context_field['ant'] = -1
            event_msg.event.common_context_field['cpu_id'] = 1
            self._msgs.append(event_msg)

            if with_packet:
                self._msgs.append(
                    self._create_packet_end_message(data['packet'])
                )

            self._msgs.append(self._create_stream_end_message(stream))

            # Emit the messages one at a time from __next__().
            self._pending = iter(self._msgs)

            # Expose the non-const messages to the enclosing function.
            created = self._msgs

        def __next__(self):
            msg = next(self._pending, None)

            if msg is None:
                raise bt2.Stop

            return msg

    class MySrc(bt2._UserSourceComponent, message_iterator_class=MyIter):
        def __init__(self, config, params, obj):
            tc = self._create_trace_class()
            clock_class = self._create_clock_class(frequency=1000)

            # Event common context (defined by the stream class).
            common_ctx_fc = tc.create_structure_field_class()
            common_ctx_fc += [('cpu_id', tc.create_signed_integer_field_class(8))]

            # Packet context (defined by the stream class); only when
            # packets are supported.
            packet_ctx_fc = None
            if with_packet:
                packet_ctx_fc = tc.create_structure_field_class()
                packet_ctx_fc += [
                    ('something', tc.create_unsigned_integer_field_class(8))
                ]

            stream_class = tc.create_stream_class(
                default_clock_class=clock_class,
                event_common_context_field_class=common_ctx_fc,
                packet_context_field_class=packet_ctx_fc,
                supports_packets=with_packet,
            )

            # Specific context (defined by the event class).
            specific_ctx_fc = tc.create_structure_field_class()
            specific_ctx_fc += [('ant', tc.create_signed_integer_field_class(16))]

            # Event payload.
            payload_fc = tc.create_structure_field_class()
            payload_fc += [('giraffe', tc.create_signed_integer_field_class(32))]

            event_class = stream_class.create_event_class(
                name='garou',
                specific_context_field_class=specific_ctx_fc,
                payload_field_class=payload_fc,
            )

            trace = tc(environment={'patate': 12})
            stream = trace.create_stream(stream_class, user_attributes={'salut': 23})

            packet = None
            if with_packet:
                packet = stream.create_packet()
                packet.context_field['something'] = 154

            # Everything the message iterator needs travels as the
            # output port's user data.
            port_data = {
                'tc': tc,
                'stream': stream,
                'event_class': event_class,
                'trace': trace,
                'packet': packet,
            }
            self._add_output_port('out', port_data)

    graph = bt2.Graph()
    source = graph.add_component(MySrc, 'my_source')
    msg_iter = TestOutputPortMessageIterator(graph, source.output_ports['out'])

    # Draining the iterator runs the graph, which instantiates MyIter and
    # thereby fills `created`.
    const_msgs = list(msg_iter)

    return created, const_msgs
176
177
# Return the non-const stream beginning message produced by
# _get_all_message_types(), or None if there is none.
def get_stream_beginning_message():
    all_msgs, _ = _get_all_message_types()
    wanted = bt2._StreamBeginningMessage
    return next((msg for msg in all_msgs if type(msg) is wanted), None)
183
184
# Return the const stream beginning message produced by
# _get_all_message_types(), or None if there is none.
def get_const_stream_beginning_message():
    _, all_const_msgs = _get_all_message_types()
    wanted = bt2._StreamBeginningMessageConst
    return next((msg for msg in all_const_msgs if type(msg) is wanted), None)
190
191
# Return the non-const stream end message produced by
# _get_all_message_types(), or None if there is none.
def get_stream_end_message():
    all_msgs, _ = _get_all_message_types()
    wanted = bt2._StreamEndMessage
    return next((msg for msg in all_msgs if type(msg) is wanted), None)
197
198
# Return the non-const packet beginning message produced by
# _get_all_message_types() with packets enabled, or None if there is none.
def get_packet_beginning_message():
    all_msgs, _ = _get_all_message_types(with_packet=True)
    wanted = bt2._PacketBeginningMessage
    return next((msg for msg in all_msgs if type(msg) is wanted), None)
204
205
# Return the const packet beginning message produced by
# _get_all_message_types() with packets enabled, or None if there is none.
def get_const_packet_beginning_message():
    _, all_const_msgs = _get_all_message_types(with_packet=True)
    wanted = bt2._PacketBeginningMessageConst
    return next((msg for msg in all_const_msgs if type(msg) is wanted), None)
211
212
# Return the non-const packet end message produced by
# _get_all_message_types() with packets enabled, or None if there is none.
def get_packet_end_message():
    all_msgs, _ = _get_all_message_types(with_packet=True)
    wanted = bt2._PacketEndMessage
    return next((msg for msg in all_msgs if type(msg) is wanted), None)
218
219
# Return the non-const event message produced by
# _get_all_message_types(), or None if there is none.
def get_event_message():
    all_msgs, _ = _get_all_message_types()
    wanted = bt2._EventMessage
    return next((msg for msg in all_msgs if type(msg) is wanted), None)
225
226
# Return the const event message produced by
# _get_all_message_types(), or None if there is none.
def get_const_event_message():
    _, all_const_msgs = _get_all_message_types()
    wanted = bt2._EventMessageConst
    return next((msg for msg in all_const_msgs if type(msg) is wanted), None)
232
233
# Proxy sink component class.
#
# The initialization object must be a one-item list. The sink creates a
# single input port named `in`. Each time it consumes, it takes the next
# message from that port's message iterator and stores it as the list's
# only item; that slot must be None beforehand (the reader clears it).
class TestProxySink(bt2._UserSinkComponent):
    def __init__(self, config, params, msg_list):
        assert msg_list is not None
        self._msg_list = msg_list
        self._add_input_port('in')

    def _user_graph_is_configured(self):
        # The message iterator can only be created once the graph is
        # configured, i.e. once the `in` port is connected.
        in_port = self._input_ports['in']
        self._msg_iter = self._create_message_iterator(in_port)

    def _user_consume(self):
        slot = self._msg_list
        assert slot[0] is None
        slot[0] = next(self._msg_iter)
252
253
# Helper message iterator for tests.
#
# Built from a graph and one of its output ports: it adds a TestProxySink
# to the graph and connects `output_port` to the proxy's `in` port. Each
# __next__() call runs the graph once so that the proxy consumes one
# message, then hands that message to the caller.
#
# This message iterator cannot seek.
class TestOutputPortMessageIterator(collections.abc.Iterator):
    def __init__(self, graph, output_port):
        self._graph = graph
        # One-item mailbox shared with the proxy sink.
        self._msg_list = [None]
        proxy = graph.add_component(TestProxySink, 'test-proxy-sink', obj=self._msg_list)
        graph.connect_ports(output_port, proxy.input_ports['in'])

    def __next__(self):
        mailbox = self._msg_list
        assert mailbox[0] is None
        self._graph.run_once()
        received = mailbox[0]
        assert received is not None
        # Empty the mailbox for the next run.
        mailbox[0] = None
        return received
278
279
# Create a const field of the given field class.
#
# The field is the single member (named 'const field') of the packet
# context of a dummy packet, itself part of a dummy stream and trace
# created from trace class `tc`. `field_value_setter_fn` is called with
# the non-const field so that the caller can set its value. The field
# returned is the one read from the packet beginning message that came
# out of the graph.
def create_const_field(tc, field_class, field_value_setter_fn):
    field_name = 'const field'

    class MyIter(bt2._UserMessageIterator):
        def __init__(self, config, self_port_output):
            # `tc`, `field_class` and `field_value_setter_fn` are only read
            # here, so plain closure access suffices (no `nonlocal`).
            trace = tc()
            packet_context_fc = tc.create_structure_field_class()
            packet_context_fc.append_member(field_name, field_class)
            sc = tc.create_stream_class(
                packet_context_field_class=packet_context_fc, supports_packets=True
            )
            stream = trace.create_stream(sc)
            packet = stream.create_packet()

            field_value_setter_fn(packet.context_field[field_name])

            # Only two messages are emitted; the packet beginning message
            # carries the packet whose context holds the field.
            self._msgs = iter(
                [
                    self._create_stream_beginning_message(stream),
                    self._create_packet_beginning_message(packet),
                ]
            )

        def __next__(self):
            msg = next(self._msgs, None)

            if msg is None:
                raise bt2.Stop

            return msg

    class MySrc(bt2._UserSourceComponent, message_iterator_class=MyIter):
        def __init__(self, config, params, obj):
            self._add_output_port('out', params)

    graph = bt2.Graph()
    src_comp = graph.add_component(MySrc, 'my_source', None)
    msg_iter = TestOutputPortMessageIterator(graph, src_comp.output_ports['out'])

    # Skip the first message (stream beginning).
    next(msg_iter)
    packet_beg_msg = next(msg_iter)

    return packet_beg_msg.packet.context_field[field_name]
This page took 0.036248 seconds and 4 git commands to generate.