bt2: Adapt test_stream.py and make it pass
[babeltrace.git] / bindings / python / bt2 / bt2 / field_class.py
1 # The MIT License (MIT)
2 #
3 # Copyright (c) 2017 Philippe Proulx <pproulx@efficios.com>
4 #
5 # Permission is hereby granted, free of charge, to any person obtaining a copy
6 # of this software and associated documentation files (the "Software"), to deal
7 # in the Software without restriction, including without limitation the rights
8 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 # copies of the Software, and to permit persons to whom the Software is
10 # furnished to do so, subject to the following conditions:
11 #
12 # The above copyright notice and this permission notice shall be included in
13 # all copies or substantial portions of the Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21 # THE SOFTWARE.
22
23 from bt2 import native_bt, object, utils
24 import collections.abc
25 import bt2.field
26 import abc
27 import bt2
28
29
def _create_field_class_from_ptr_and_get_ref(ptr):
    """Wrap a native field class pointer, acquiring a new reference on it.

    The concrete wrapper class is chosen from the pointer's native type via
    ``_FIELD_CLASS_TYPE_TO_OBJ``; an unmapped type raises ``KeyError``.
    Use this when the caller does not already own a reference to ``ptr``.
    """
    typeid = native_bt.field_class_get_type(ptr)
    return _FIELD_CLASS_TYPE_TO_OBJ[typeid]._create_from_ptr_and_get_ref(ptr)


def _create_from_ptr(ptr):
    """Wrap a native field class pointer without acquiring a new reference.

    Several wrappers in this module (``_FieldClass.__copy__``,
    ``EnumerationFieldClass.integer_field_class``,
    ``_FieldContainer.__getitem__``, the ``_at`` methods,
    ``element_field_class`` properties, ...) call ``_create_from_ptr()``.
    The module did not define it, so each of those calls raised
    ``NameError``.

    This dispatches exactly like ``_create_field_class_from_ptr_and_get_ref()``
    but hands the pointer to the wrapper class's ``_create_from_ptr()``, so no
    extra reference is taken.

    NOTE(review): assumes the caller already owns the reference carried by
    ``ptr`` (e.g. a pointer just returned by a native copy/getter). Confirm
    the ownership semantics of each native call site.
    """
    typeid = native_bt.field_class_get_type(ptr)
    return _FIELD_CLASS_TYPE_TO_OBJ[typeid]._create_from_ptr(ptr)
33
34
class _FieldClass(object._SharedObject, metaclass=abc.ABCMeta):
    """Abstract base of every field class wrapper.

    Wraps a shared native field class pointer and provides equality,
    shallow/deep copying, and field instantiation: calling a field class
    object creates a field of that class.
    """

    _get_ref = staticmethod(native_bt.field_class_get_ref)
    _put_ref = staticmethod(native_bt.field_class_put_ref)

    def __init__(self, ptr):
        super().__init__(ptr)

    def __eq__(self, other):
        # Objects of a different wrapper class are never equal.
        if not isinstance(other, self.__class__):
            return False

        # Same underlying native object: equal without asking the library.
        if self.addr == other.addr:
            return True

        status = native_bt.field_class_compare(self._ptr, other._ptr)
        utils._handle_ret(status, "cannot compare field classes")
        return status == 0

    def _check_create_status(self, ptr):
        """Raise ``bt2.CreationError`` when a native constructor returned ``None``."""
        if ptr is not None:
            return

        raise bt2.CreationError('cannot create {} field class object'.format(self._NAME.lower()))

    def __copy__(self):
        copy_ptr = native_bt.field_class_copy(self._ptr)
        err_msg = 'cannot copy {} field class object'.format(self._NAME.lower())
        utils._handle_ptr(copy_ptr, err_msg)
        return _create_from_ptr(copy_ptr)

    def __deepcopy__(self, memo):
        # A deep copy is the native copy; register it in `memo` as the
        # `copy` protocol expects.
        duplicate = self.__copy__()
        memo[id(self)] = duplicate
        return duplicate

    def __call__(self, value=None):
        """Create a field of this class, optionally assigning ``value``.

        Only integer, floating point number and string fields accept an
        initial value; other field kinds raise ``bt2.Error`` when ``value``
        is given.
        """
        field_ptr = native_bt.field_create(self._ptr)

        if field_ptr is None:
            raise bt2.CreationError('cannot create {} field object'.format(self._NAME.lower()))

        field = bt2.field._create_from_ptr(field_ptr)

        if value is None:
            return field

        assignable_types = (bt2.field._IntegerField,
                            bt2.field._FloatingPointNumberField,
                            bt2.field._StringField)

        if not isinstance(field, assignable_types):
            raise bt2.Error('cannot assign an initial value to a {} field object'.format(field._NAME))

        field.value = value
        return field
83
84
85 class _AlignmentProp:
86 @property
87 def alignment(self):
88 alignment = native_bt.field_class_get_alignment(self._ptr)
89 assert(alignment >= 0)
90 return alignment
91
92 @alignment.setter
93 def alignment(self, alignment):
94 utils._check_alignment(alignment)
95 ret = native_bt.field_class_set_alignment(self._ptr, alignment)
96 utils._handle_ret(ret, "cannot set field class object's alignment")
97
98
99 class _ByteOrderProp:
100 @property
101 def byte_order(self):
102 bo = native_bt.field_class_get_byte_order(self._ptr)
103 assert(bo >= 0)
104 return bo
105
106 @byte_order.setter
107 def byte_order(self, byte_order):
108 utils._check_int(byte_order)
109 ret = native_bt.field_class_set_byte_order(self._ptr, byte_order)
110 utils._handle_ret(ret, "cannot set field class object's byte order")
111
112
class _IntegerFieldClass(_FieldClass, _AlignmentProp, _ByteOrderProp):
    """Base class of integer field class wrappers.

    Creates a native integer field class of ``size`` bits. Each optional
    constructor argument that is not ``None`` is then applied through the
    matching property setter.

    :raises ValueError: if ``size`` is 0.
    :raises bt2.CreationError: if the native object cannot be created.
    """

    # Display name used in error messages (`_check_create_status()`, field
    # creation). Concrete subclasses override it. Without a default, a
    # failed creation raised AttributeError instead of bt2.CreationError.
    _NAME = 'Integer'

    def __init__(self, size, alignment=None, byte_order=None, is_signed=None,
                 base=None, encoding=None, mapped_clock_class=None):
        utils._check_uint64(size)

        if size == 0:
            raise ValueError('size is 0 bits')

        ptr = native_bt.field_class_integer_create(size)
        self._check_create_status(ptr)
        super().__init__(ptr)

        # `alignment` and `byte_order` rely on the `_AlignmentProp` and
        # `_ByteOrderProp` mixins (as `FloatingPointNumberFieldClass` does).
        # Without them these assignments were not routed to the native
        # field class.
        if alignment is not None:
            self.alignment = alignment

        if byte_order is not None:
            self.byte_order = byte_order

        if is_signed is not None:
            self.is_signed = is_signed

        if base is not None:
            self.base = base

        if encoding is not None:
            self.encoding = encoding

        if mapped_clock_class is not None:
            self.mapped_clock_class = mapped_clock_class

    @property
    def size(self):
        """Size in bits (at least 1)."""
        size = native_bt.field_class_integer_get_size(self._ptr)
        assert size >= 1
        return size

    @property
    def is_signed(self):
        """Whether the integer field class is signed."""
        is_signed = native_bt.field_class_integer_is_signed(self._ptr)
        assert is_signed >= 0
        return is_signed > 0

    @is_signed.setter
    def is_signed(self, is_signed):
        utils._check_bool(is_signed)
        ret = native_bt.field_class_integer_set_is_signed(self._ptr, int(is_signed))
        utils._handle_ret(ret, "cannot set integer field class object's signedness")

    @property
    def base(self):
        """Display base reported by the native field class."""
        base = native_bt.field_class_integer_get_base(self._ptr)
        assert base >= 0
        return base

    @base.setter
    def base(self, base):
        utils._check_int(base)
        ret = native_bt.field_class_integer_set_base(self._ptr, base)
        utils._handle_ret(ret, "cannot set integer field class object's base")

    @property
    def encoding(self):
        """Encoding reported by the native field class."""
        encoding = native_bt.field_class_integer_get_encoding(self._ptr)
        assert encoding >= 0
        return encoding

    @encoding.setter
    def encoding(self, encoding):
        utils._check_int(encoding)
        ret = native_bt.field_class_integer_set_encoding(self._ptr, encoding)
        utils._handle_ret(ret, "cannot set integer field class object's encoding")

    @property
    def mapped_clock_class(self):
        """Mapped ``bt2.ClockClass``, or ``None`` when there is none."""
        ptr = native_bt.field_class_integer_get_mapped_clock_class(self._ptr)

        if ptr is None:
            return

        return bt2.ClockClass._create_from_ptr(ptr)

    @mapped_clock_class.setter
    def mapped_clock_class(self, clock_class):
        utils._check_type(clock_class, bt2.ClockClass)
        ret = native_bt.field_class_integer_set_mapped_clock_class(self._ptr, clock_class._ptr)
        utils._handle_ret(ret, "cannot set integer field class object's mapped clock class")
200
201
class _SignedIntegerFieldClass(_IntegerFieldClass):
    """Common base of signed integer field classes; adds no behavior yet."""
204
class SignedIntegerFieldClass(_SignedIntegerFieldClass):
    """Concrete signed integer field class."""

    # Display name used in error messages.
    _NAME = 'SignedInteger'
207
class FloatingPointNumberFieldClass(_FieldClass, _AlignmentProp, _ByteOrderProp):
    """Floating point number field class.

    Every constructor argument is optional. Each one that is not ``None``
    is applied through the property of the same name, in declaration order.
    """

    _NAME = 'Floating point number'

    def __init__(self, alignment=None, byte_order=None, exponent_size=None,
                 mantissa_size=None):
        fp_ptr = native_bt.field_class_floating_point_create()
        self._check_create_status(fp_ptr)
        super().__init__(fp_ptr)

        requested = (('alignment', alignment),
                     ('byte_order', byte_order),
                     ('exponent_size', exponent_size),
                     ('mantissa_size', mantissa_size))

        for prop_name, prop_value in requested:
            if prop_value is not None:
                setattr(self, prop_name, prop_value)

    @property
    def exponent_size(self):
        """Number of exponent digits (never negative)."""
        digits = native_bt.field_class_floating_point_get_exponent_digits(self._ptr)
        assert digits >= 0
        return digits

    @exponent_size.setter
    def exponent_size(self, exponent_size):
        utils._check_uint64(exponent_size)
        status = native_bt.field_class_floating_point_set_exponent_digits(self._ptr, exponent_size)
        utils._handle_ret(status, "cannot set floating point number field class object's exponent size")

    @property
    def mantissa_size(self):
        """Number of mantissa digits (never negative)."""
        digits = native_bt.field_class_floating_point_get_mantissa_digits(self._ptr)
        assert digits >= 0
        return digits

    @mantissa_size.setter
    def mantissa_size(self, mantissa_size):
        utils._check_uint64(mantissa_size)
        status = native_bt.field_class_floating_point_set_mantissa_digits(self._ptr, mantissa_size)
        utils._handle_ret(status, "cannot set floating point number field class object's mantissa size")
252
253
254 class _EnumerationFieldClassMapping:
255 def __init__(self, name, lower, upper):
256 self._name = name
257 self._lower = lower
258 self._upper = upper
259
260 @property
261 def name(self):
262 return self._name
263
264 @property
265 def lower(self):
266 return self._lower
267
268 @property
269 def upper(self):
270 return self._upper
271
272 def __eq__(self, other):
273 if type(other) is not self.__class__:
274 return False
275
276 return (self.name, self.lower, self.upper) == (other.name, other.lower, other.upper)
277
278
class _EnumerationFieldClassMappingIterator(object._SharedObject,
                                            collections.abc.Iterator):
    """Iterator over the mappings returned by an enumeration lookup.

    A ``None`` iterator pointer yields nothing. Mappings are read with the
    signed or unsigned native getter depending on ``is_signed``.
    """

    def __init__(self, iter_ptr, is_signed):
        super().__init__(iter_ptr)
        self._is_signed = is_signed
        self._done = iter_ptr is None

    def __next__(self):
        if self._done:
            raise StopIteration

        # A negative status from the native "next" call ends the iteration;
        # remember it so later calls stop immediately.
        if native_bt.field_class_enumeration_mapping_iterator_next(self._ptr) < 0:
            self._done = True
            raise StopIteration

        if self._is_signed:
            get_mapping = native_bt.field_class_enumeration_mapping_iterator_get_signed
        else:
            get_mapping = native_bt.field_class_enumeration_mapping_iterator_get_unsigned

        status, name, lower, upper = get_mapping(self._ptr)
        assert status == 0
        return _EnumerationFieldClassMapping(name, lower, upper)
304
305
class EnumerationFieldClass(_IntegerFieldClass, collections.abc.Sequence):
    """Enumeration field class: an integer container plus labelled ranges.

    The integer properties (size, alignment, byte order, signedness, base,
    encoding, mapped clock class) are forwarded to the container integer
    field class. The object is also a sequence of
    :class:`_EnumerationFieldClassMapping` objects.
    """

    _NAME = 'Enumeration'

    def __init__(self, int_field_class=None, size=None, alignment=None,
                 byte_order=None, is_signed=None, base=None, encoding=None,
                 mapped_clock_class=None):
        """Create an enumeration field class.

        :param int_field_class: container integer field class. When ``None``,
            one is built from the remaining keyword arguments.
        :raises bt2.CreationError: if the native object cannot be created.
        """
        if int_field_class is None:
            # `IntegerFieldClass` is not defined in this module (using it
            # raised NameError); the integer base class is `_IntegerFieldClass`.
            int_field_class = _IntegerFieldClass(size=size, alignment=alignment,
                                                 byte_order=byte_order,
                                                 is_signed=is_signed, base=base,
                                                 encoding=encoding,
                                                 mapped_clock_class=mapped_clock_class)

        utils._check_type(int_field_class, _IntegerFieldClass)
        ptr = native_bt.field_class_enumeration_create(int_field_class._ptr)
        self._check_create_status(ptr)
        # Deliberately bypass `_IntegerFieldClass.__init__()`, which would
        # create a separate native integer field class.
        _FieldClass.__init__(self, ptr)

    @property
    def integer_field_class(self):
        """Container integer field class."""
        ptr = native_bt.field_class_enumeration_get_container_type(self._ptr)
        assert ptr
        return _create_from_ptr(ptr)

    @property
    def size(self):
        return self.integer_field_class.size

    @property
    def alignment(self):
        return self.integer_field_class.alignment

    @alignment.setter
    def alignment(self, alignment):
        self.integer_field_class.alignment = alignment

    @property
    def byte_order(self):
        return self.integer_field_class.byte_order

    @byte_order.setter
    def byte_order(self, byte_order):
        self.integer_field_class.byte_order = byte_order

    @property
    def is_signed(self):
        return self.integer_field_class.is_signed

    @is_signed.setter
    def is_signed(self, is_signed):
        self.integer_field_class.is_signed = is_signed

    @property
    def base(self):
        return self.integer_field_class.base

    @base.setter
    def base(self, base):
        self.integer_field_class.base = base

    @property
    def encoding(self):
        return self.integer_field_class.encoding

    @encoding.setter
    def encoding(self, encoding):
        self.integer_field_class.encoding = encoding

    @property
    def mapped_clock_class(self):
        return self.integer_field_class.mapped_clock_class

    @mapped_clock_class.setter
    def mapped_clock_class(self, mapped_clock_class):
        self.integer_field_class.mapped_clock_class = mapped_clock_class

    def __len__(self):
        count = native_bt.field_class_enumeration_get_mapping_count(self._ptr)
        assert count >= 0
        return count

    def __getitem__(self, index):
        """Return the mapping at ``index``.

        ``index`` must be a non-negative 64-bit integer; an out-of-range
        value raises ``IndexError``.
        """
        utils._check_uint64(index)

        if index >= len(self):
            raise IndexError

        if self.is_signed:
            get_fn = native_bt.field_class_enumeration_get_mapping_signed
        else:
            get_fn = native_bt.field_class_enumeration_get_mapping_unsigned

        ret, name, lower, upper = get_fn(self._ptr, index)
        assert ret == 0
        return _EnumerationFieldClassMapping(name, lower, upper)

    def _get_mapping_iter(self, iter_ptr):
        return _EnumerationFieldClassMappingIterator(iter_ptr, self.is_signed)

    def mappings_by_name(self, name):
        """Return an iterator over the mappings labelled ``name``."""
        utils._check_str(name)
        iter_ptr = native_bt.field_class_enumeration_find_mappings_by_name(self._ptr, name)
        # (A leftover debug `print()` of the iterator pointer was removed.)
        return self._get_mapping_iter(iter_ptr)

    def mappings_by_value(self, value):
        """Return an iterator over the mappings whose range contains ``value``.

        ``value`` is checked as int64 or uint64 depending on ``is_signed``.
        """
        if self.is_signed:
            utils._check_int64(value)
            iter_ptr = native_bt.field_class_enumeration_find_mappings_by_signed_value(self._ptr, value)
        else:
            utils._check_uint64(value)
            iter_ptr = native_bt.field_class_enumeration_find_mappings_by_unsigned_value(self._ptr, value)

        return self._get_mapping_iter(iter_ptr)

    def add_mapping(self, name, lower, upper=None):
        """Add mapping ``name`` for ``[lower, upper]``.

        ``upper`` defaults to ``lower``, i.e. a single value.
        """
        utils._check_str(name)

        if upper is None:
            upper = lower

        if self.is_signed:
            add_fn = native_bt.field_class_enumeration_add_mapping_signed
            utils._check_int64(lower)
            utils._check_int64(upper)
        else:
            add_fn = native_bt.field_class_enumeration_add_mapping_unsigned
            utils._check_uint64(lower)
            utils._check_uint64(upper)

        ret = add_fn(self._ptr, name, lower, upper)
        utils._handle_ret(ret, "cannot add mapping to enumeration field class object")

    def __iadd__(self, mappings):
        for mapping in mappings:
            self.add_mapping(mapping.name, mapping.lower, mapping.upper)

        return self
444
445
class StringFieldClass(_FieldClass):
    """String field class, optionally created with an explicit encoding."""

    _NAME = 'String'

    def __init__(self, encoding=None):
        string_ptr = native_bt.field_class_string_create()
        self._check_create_status(string_ptr)
        super().__init__(string_ptr)

        if encoding is not None:
            self.encoding = encoding

    @property
    def encoding(self):
        """Encoding reported by the native field class (never negative)."""
        value = native_bt.field_class_string_get_encoding(self._ptr)
        assert value >= 0
        return value

    @encoding.setter
    def encoding(self, encoding):
        utils._check_int(encoding)
        status = native_bt.field_class_string_set_encoding(self._ptr, encoding)
        utils._handle_ret(status, "cannot set string field class object's encoding")
468
469
470 class _FieldContainer(collections.abc.Mapping):
471 def __len__(self):
472 count = self._count()
473 assert(count >= 0)
474 return count
475
476 def __getitem__(self, key):
477 if not isinstance(key, str):
478 raise TypeError("'{}' is not a 'str' object".format(key.__class__.__name__))
479
480 ptr = self._get_field_by_name(key)
481
482 if ptr is None:
483 raise KeyError(key)
484
485 return _create_from_ptr(ptr)
486
487 def __iter__(self):
488 return self._ITER_CLS(self)
489
490 def append_field(self, name, field_class):
491 utils._check_str(name)
492 utils._check_type(field_class, _FieldClass)
493 ret = self._add_field(name, field_class._ptr)
494 utils._handle_ret(ret, "cannot add field to {} field class object".format(self._NAME.lower()))
495
496 def __iadd__(self, fields):
497 for name, field_class in fields.items():
498 self.append_field(name, field_class)
499
500 return self
501
502 def at_index(self, index):
503 utils._check_uint64(index)
504 return self._at(index)
505
506
507 class _StructureFieldClassFieldIterator(collections.abc.Iterator):
508 def __init__(self, struct_field_class):
509 self._struct_field_class = struct_field_class
510 self._at = 0
511
512 def __next__(self):
513 if self._at == len(self._struct_field_class):
514 raise StopIteration
515
516 get_fc_by_index = native_bt.field_class_structure_get_field_by_index
517 ret, name, field_class_ptr = get_fc_by_index(self._struct_field_class._ptr,
518 self._at)
519 assert(ret == 0)
520 native_bt.put(field_class_ptr)
521 self._at += 1
522 return name
523
524
class _StructureFieldClass(_FieldClass, _FieldContainer, _AlignmentProp):
    """Structure field class: an ordered mapping of member names to field classes."""

    _NAME = 'Structure'
    _ITER_CLS = _StructureFieldClassFieldIterator

    def __init__(self, min_alignment=None):
        struct_ptr = native_bt.field_class_structure_create()
        self._check_create_status(struct_ptr)
        super().__init__(struct_ptr)

        if min_alignment is not None:
            # `min_alignment` is a write-only property installed on this
            # class at module level, after the class definition.
            self.min_alignment = min_alignment

    def _count(self):
        return native_bt.field_class_structure_get_field_count(self._ptr)

    def _get_field_by_name(self, key):
        return native_bt.field_class_structure_get_field_class_by_name(self._ptr, key)

    def _add_field(self, name, ptr):
        return native_bt.field_class_structure_append_member(self._ptr, name, ptr)

    def _at(self, index):
        if not 0 <= index < len(self):
            raise IndexError

        ret, _member_name, member_ptr = native_bt.field_class_structure_get_field_by_index(self._ptr, index)
        assert ret == 0
        return _create_from_ptr(member_ptr)
553
554
# Structures expose alignment asymmetrically: the inherited setter becomes
# the write-only `min_alignment` property, and `alignment` becomes read-only.
_struct_alignment_prop = _StructureFieldClass.alignment
_StructureFieldClass.min_alignment = property(fset=_struct_alignment_prop.fset)
_StructureFieldClass.alignment = property(fget=_struct_alignment_prop.fget)
del _struct_alignment_prop
557
558
559 class _VariantFieldClassFieldIterator(collections.abc.Iterator):
560 def __init__(self, variant_field_class):
561 self._variant_field_class = variant_field_class
562 self._at = 0
563
564 def __next__(self):
565 if self._at == len(self._variant_field_class):
566 raise StopIteration
567
568 ret, name, field_class_ptr = native_bt.field_class_variant_get_field_by_index(self._variant_field_class._ptr,
569 self._at)
570 assert(ret == 0)
571 native_bt.put(field_class_ptr)
572 self._at += 1
573 return name
574
575
class VariantFieldClass(_FieldClass, _FieldContainer, _AlignmentProp):
    """Variant field class: named options selected through a tag.

    :param tag_name: name of the tag field.
    :param tag_field_class: optional ``EnumerationFieldClass`` for the tag.
    :raises bt2.CreationError: if the native object cannot be created.
    """

    _NAME = 'Variant'
    _ITER_CLS = _VariantFieldClassFieldIterator

    def __init__(self, tag_name, tag_field_class=None):
        utils._check_str(tag_name)

        if tag_field_class is None:
            tag_fc_ptr = None
        else:
            utils._check_type(tag_field_class, EnumerationFieldClass)
            tag_fc_ptr = tag_field_class._ptr

        ptr = native_bt.field_class_variant_create(tag_fc_ptr,
                                                   tag_name)
        self._check_create_status(ptr)
        super().__init__(ptr)

    @property
    def tag_name(self):
        """Name of the tag field."""
        tag_name = native_bt.field_class_variant_get_tag_name(self._ptr)
        assert tag_name is not None
        return tag_name

    @tag_name.setter
    def tag_name(self, tag_name):
        utils._check_str(tag_name)
        ret = native_bt.field_class_variant_set_tag_name(self._ptr, tag_name)
        utils._handle_ret(ret, "cannot set variant field class object's tag name")

    @property
    def tag_field_class(self):
        """Tag field class, or ``None`` when there is none."""
        fc_ptr = native_bt.field_class_variant_get_tag_type(self._ptr)

        if fc_ptr is None:
            return

        return _create_from_ptr(fc_ptr)

    def _count(self):
        return native_bt.field_class_variant_get_field_count(self._ptr)

    def _get_field_by_name(self, key):
        return native_bt.field_class_variant_get_field_class_by_name(self._ptr, key)

    def _add_field(self, name, ptr):
        # `_FieldContainer.append_field()` calls `self._add_field(name, ptr)`,
        # the same order `_StructureFieldClass._add_field()` declares. The
        # previous `(ptr, name)` signature swapped the two values, so the
        # name reached the native call in the field class pointer position.
        # The native call keeps its (field class, name) argument order.
        return native_bt.field_class_variant_add_field(self._ptr, ptr, name)

    def _at(self, index):
        if index < 0 or index >= len(self):
            raise IndexError

        ret, name, field_class_ptr = native_bt.field_class_variant_get_field_by_index(self._ptr, index)
        assert ret == 0
        return _create_from_ptr(field_class_ptr)
631
632
class ArrayFieldClass(_FieldClass):
    """Static array field class: ``length`` elements of ``element_field_class``."""

    _NAME = 'Array'

    def __init__(self, element_field_class, length):
        utils._check_type(element_field_class, _FieldClass)
        utils._check_uint64(length)
        array_ptr = native_bt.field_class_array_create(element_field_class._ptr, length)
        self._check_create_status(array_ptr)
        super().__init__(array_ptr)

    @property
    def length(self):
        """Number of elements (never negative)."""
        n = native_bt.field_class_array_get_length(self._ptr)
        assert n >= 0
        return n

    @property
    def element_field_class(self):
        """Field class of each element."""
        elem_ptr = native_bt.field_class_array_get_element_type(self._ptr)
        assert elem_ptr
        return _create_from_ptr(elem_ptr)
654
655
class SequenceFieldClass(_FieldClass):
    """Sequence field class: elements of ``element_field_class`` whose count
    is given by the field named ``length_name``."""

    _NAME = 'Sequence'

    def __init__(self, element_field_class, length_name):
        utils._check_type(element_field_class, _FieldClass)
        utils._check_str(length_name)
        seq_ptr = native_bt.field_class_sequence_create(element_field_class._ptr,
                                                        length_name)
        self._check_create_status(seq_ptr)
        super().__init__(seq_ptr)

    @property
    def length_name(self):
        """Name of the length field."""
        name = native_bt.field_class_sequence_get_length_field_name(self._ptr)
        assert name is not None
        return name

    @property
    def element_field_class(self):
        """Field class of each element."""
        elem_ptr = native_bt.field_class_sequence_get_element_type(self._ptr)
        assert elem_ptr
        return _create_from_ptr(elem_ptr)
678
679
# Dispatch table used by `_create_field_class_from_ptr_and_get_ref()`: maps a
# native field class type ID to the Python wrapper class to instantiate.
# NOTE(review): only structures are mapped so far; other type IDs raise KeyError.
_FIELD_CLASS_TYPE_TO_OBJ = dict([
    (native_bt.FIELD_CLASS_TYPE_STRUCTURE, _StructureFieldClass),
])
This page took 0.04416 seconds and 4 git commands to generate.