Fix: bt2: fix reference counting of messages returned by Python components
[babeltrace.git] / tests / bindings / python / bt2 / test_message_iterator.py
1
2 #
3 # Copyright (C) 2019 EfficiOS Inc.
4 #
5 # This program is free software; you can redistribute it and/or
6 # modify it under the terms of the GNU General Public License
7 # as published by the Free Software Foundation; only version 2
8 # of the License.
9 #
10 # This program is distributed in the hope that it will be useful,
11 # but WITHOUT ANY WARRANTY; without even the implied warranty of
12 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 # GNU General Public License for more details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with this program; if not, write to the Free Software
17 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
18 #
19
20 from bt2 import value
21 import collections
22 import unittest
23 import copy
24 import bt2
25
26
class UserMessageIteratorTestCase(unittest.TestCase):
    @staticmethod
    def _create_graph(src_comp_cls):
        """Build a graph connecting a ``src_comp_cls`` source to a test sink.

        The source component must add an output port named ``'out'``; it is
        connected to the sink's ``'in'`` input port.  Each call to the sink's
        ``_consume`` pulls one message from its message iterator.

        Returns the configured ``bt2.Graph``.
        """
        class MySink(bt2._UserSinkComponent):
            def __init__(self, params):
                self._add_input_port('in')

            def _consume(self):
                next(self._msg_iter)

            def _graph_is_configured(self):
                # The iterator can only be created once ports are connected.
                self._msg_iter = self._input_ports['in'].create_message_iterator()

        graph = bt2.Graph()
        src_comp = graph.add_component(src_comp_cls, 'src')
        sink_comp = graph.add_component(MySink, 'sink')
        graph.connect_ports(src_comp.output_ports['out'],
                            sink_comp.input_ports['in'])
        return graph

    # The iterator's __init__ is called and receives the same output port
    # (with its user data) that the source component created.
    def test_init(self):
        # State captured from inside the component/iterator callbacks.
        initialized = False
        the_output_port_from_source = None
        the_output_port_from_iter = None

        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal initialized
                nonlocal the_output_port_from_iter
                initialized = True
                the_output_port_from_iter = self_port_output

        class MySource(bt2._UserSourceComponent,
                       message_iterator_class=MyIter):
            def __init__(self, params):
                nonlocal the_output_port_from_source
                the_output_port_from_source = self._add_output_port('out', 'user data')

        graph = self._create_graph(MySource)
        graph.run()
        self.assertTrue(initialized)
        self.assertEqual(the_output_port_from_source.addr, the_output_port_from_iter.addr)
        self.assertEqual(the_output_port_from_iter.user_data, 'user data')

    # The iterator's _finalize is called once the graph is destroyed.
    def test_finalize(self):
        class MyIter(bt2._UserMessageIterator):
            def _finalize(self):
                nonlocal finalized
                finalized = True

        class MySource(bt2._UserSourceComponent,
                       message_iterator_class=MyIter):
            def __init__(self, params):
                self._add_output_port('out')

        finalized = False
        graph = self._create_graph(MySource)
        graph.run()
        # Dropping the last reference to the graph tears it down.
        del graph
        self.assertTrue(finalized)

    # The iterator can reach its owning component through self._component.
    def test_component(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal salut
                salut = self._component._salut

        class MySource(bt2._UserSourceComponent,
                       message_iterator_class=MyIter):
            def __init__(self, params):
                self._add_output_port('out')
                self._salut = 23

        salut = None
        graph = self._create_graph(MySource)
        graph.run()
        self.assertEqual(salut, 23)

    # The iterator exposes a non-null native address through self.addr.
    def test_addr(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                nonlocal addr
                addr = self.addr

        class MySource(bt2._UserSourceComponent,
                       message_iterator_class=MyIter):
            def __init__(self, params):
                self._add_output_port('out')

        addr = None
        graph = self._create_graph(MySource)
        graph.run()
        self.assertIsNotNone(addr)
        self.assertNotEqual(addr, 0)

    # Test that messages returned by _UserMessageIterator.__next__ remain valid
    # and can be re-used.
    def test_reuse_message(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, port):
                tc, sc, ec = port.user_data
                trace = tc()
                stream = trace.create_stream(sc)
                packet = stream.create_packet()

                # This message will be returned twice by __next__.
                event_message = self._create_event_message(ec, packet)

                # A deque gives O(1) removal from the front.
                self._msgs = collections.deque([
                    self._create_stream_beginning_message(stream),
                    self._create_stream_activity_beginning_message(stream),
                    self._create_packet_beginning_message(packet),
                    event_message,
                    event_message,
                ])

            def __next__(self):
                # End the iteration with bt2.Stop instead of letting an
                # IndexError escape once every queued message is consumed.
                if not self._msgs:
                    raise bt2.Stop

                return self._msgs.popleft()

        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
            def __init__(self, params):
                tc = self._create_trace_class()
                sc = tc.create_stream_class()
                ec = sc.create_event_class()
                self._add_output_port('out', (tc, sc, ec))

        graph = bt2.Graph()
        src = graph.add_component(MySource, 'src')
        it = graph.create_output_port_message_iterator(src.output_ports['out'])

        # Skip beginning messages.
        next(it)
        next(it)
        next(it)

        msg_ev1 = next(it)
        msg_ev2 = next(it)

        self.assertIsInstance(msg_ev1, bt2.message._EventMessage)
        self.assertIsInstance(msg_ev2, bt2.message._EventMessage)
        # Both wrappers must refer to the very same native message.
        self.assertEqual(msg_ev1.addr, msg_ev2.addr)
168
169
class OutputPortMessageIteratorTestCase(unittest.TestCase):
    # An output port message iterator created directly from the graph yields,
    # in order, every message produced by the source's user iterator.
    def test_component(self):
        class MyIter(bt2._UserMessageIterator):
            def __init__(self, self_port_output):
                self._at = 0

            def __next__(self):
                # Seven messages: stream beginning, packet beginning,
                # three events (at 2, 3, 4), packet end, stream end.
                if self._at == 7:
                    raise bt2.Stop

                if self._at == 0:
                    msg = self._create_stream_beginning_message(test_obj._stream)
                elif self._at == 1:
                    msg = self._create_packet_beginning_message(test_obj._packet)
                elif self._at == 5:
                    msg = self._create_packet_end_message(test_obj._packet)
                elif self._at == 6:
                    msg = self._create_stream_end_message(test_obj._stream)
                else:
                    msg = self._create_event_message(test_obj._event_class, test_obj._packet)
                    msg.event.payload_field['my_int'] = self._at * 3

                self._at += 1
                return msg

        class MySource(bt2._UserSourceComponent,
                       message_iterator_class=MyIter):
            def __init__(self, params):
                self._add_output_port('out')

                trace_class = self._create_trace_class()
                stream_class = trace_class.create_stream_class()

                # Create payload field class
                my_int_ft = trace_class.create_signed_integer_field_class(32)
                payload_ft = trace_class.create_structure_field_class()
                payload_ft += collections.OrderedDict([
                    ('my_int', my_int_ft),
                ])

                event_class = stream_class.create_event_class(name='salut', payload_field_class=payload_ft)

                trace = trace_class()
                stream = trace.create_stream(stream_class)
                packet = stream.create_packet()

                # Shared with MyIter through the enclosing test case object.
                test_obj._event_class = event_class
                test_obj._stream = stream
                test_obj._packet = packet

        test_obj = self
        graph = bt2.Graph()
        src = graph.add_component(MySource, 'src')
        msg_iter = graph.create_output_port_message_iterator(src.output_ports['out'])

        msg_count = 0

        for at, msg in enumerate(msg_iter):
            msg_count += 1

            if at == 0:
                self.assertIsInstance(msg, bt2.message._StreamBeginningMessage)
            elif at == 1:
                self.assertIsInstance(msg, bt2.message._PacketBeginningMessage)
            elif at == 5:
                self.assertIsInstance(msg, bt2.message._PacketEndMessage)
            elif at == 6:
                self.assertIsInstance(msg, bt2.message._StreamEndMessage)
            else:
                self.assertIsInstance(msg, bt2.message._EventMessage)
                self.assertEqual(msg.event.cls.name, 'salut')
                field = msg.event.payload_field['my_int']
                self.assertEqual(field, at * 3)

        # Without this, an empty or truncated iteration would pass silently.
        self.assertEqual(msg_count, 7)
This page took 0.034824 seconds and 5 git commands to generate.