Update Python bindings and tests to match the latest API
[babeltrace.git] / bindings / python / bt2 / field_types.py
index 5268836c48236ad8e499b54f1a59c1e861b34bcc..3a4e820f8ec1a0c05f3d792529859f2e2307ab19 100644 (file)
@@ -1,6 +1,6 @@
 # The MIT License (MIT)
 #
-# Copyright (c) 2016 Philippe Proulx <pproulx@efficios.com>
+# Copyright (c) 2017 Philippe Proulx <pproulx@efficios.com>
 #
 # Permission is hereby granted, free of charge, to any person obtaining a copy
 # of this software and associated documentation files (the "Software"), to deal
@@ -103,7 +103,7 @@ class _AlignmentProp:
     @property
     def alignment(self):
         alignment = native_bt.ctf_field_type_get_alignment(self._ptr)
-        utils._handle_ret(alignment, "cannot get field type object's alignment")
+        assert(alignment >= 0)
         return alignment
 
     @alignment.setter
@@ -117,7 +117,7 @@ class _ByteOrderProp:
     @property
     def byte_order(self):
         bo = native_bt.ctf_field_type_get_byte_order(self._ptr)
-        utils._handle_ret(bo, "cannot get field type object's byte order")
+        assert(bo >= 0)
         return bo
 
     @byte_order.setter
@@ -162,26 +162,25 @@ class IntegerFieldType(_FieldType, _AlignmentProp, _ByteOrderProp):
     @property
     def size(self):
         size = native_bt.ctf_field_type_integer_get_size(self._ptr)
-        utils._handle_ret(size, "cannot get integer field type object's size")
+        assert(size >= 1)
         return size
 
     @property
     def is_signed(self):
-        is_signed = native_bt.ctf_field_type_integer_get_signed(self._ptr)
-        utils._handle_ret(is_signed,
-                    "cannot get integer field type object's signedness")
+        is_signed = native_bt.ctf_field_type_integer_is_signed(self._ptr)
+        assert(is_signed >= 0)
         return is_signed > 0
 
     @is_signed.setter
     def is_signed(self, is_signed):
         utils._check_bool(is_signed)
-        ret = native_bt.ctf_field_type_integer_set_signed(self._ptr, int(is_signed))
+        ret = native_bt.ctf_field_type_integer_set_is_signed(self._ptr, int(is_signed))
         utils._handle_ret(ret, "cannot set integer field type object's signedness")
 
     @property
     def base(self):
         base = native_bt.ctf_field_type_integer_get_base(self._ptr)
-        utils._handle_ret(base, "cannot get integer field type object's base")
+        assert(base >= 0)
         return base
 
     @base.setter
@@ -193,7 +192,7 @@ class IntegerFieldType(_FieldType, _AlignmentProp, _ByteOrderProp):
     @property
     def encoding(self):
         encoding = native_bt.ctf_field_type_integer_get_encoding(self._ptr)
-        utils._handle_ret(encoding, "cannot get integer field type object's encoding")
+        assert(encoding >= 0)
         return encoding
 
     @encoding.setter
@@ -205,7 +204,10 @@ class IntegerFieldType(_FieldType, _AlignmentProp, _ByteOrderProp):
     @property
     def mapped_clock_class(self):
         ptr = native_bt.ctf_field_type_integer_get_mapped_clock_class(self._ptr)
-        utils._handle_ptr(ptr, "cannot get integer field type object's mapped clock class")
+
+        if ptr is None:
+            return
+
         return bt2.ClockClass._create_from_ptr(ptr)
 
     @mapped_clock_class.setter
@@ -239,7 +241,7 @@ class FloatingPointNumberFieldType(_FieldType, _AlignmentProp, _ByteOrderProp):
     @property
     def exponent_size(self):
         exp_size = native_bt.ctf_field_type_floating_point_get_exponent_digits(self._ptr)
-        utils._handle_ret(exp_size, "cannot get floating point number field type object's exponent size")
+        assert(exp_size >= 0)
         return exp_size
 
     @exponent_size.setter
@@ -250,9 +252,9 @@ class FloatingPointNumberFieldType(_FieldType, _AlignmentProp, _ByteOrderProp):
 
     @property
     def mantissa_size(self):
-        exp_size = native_bt.ctf_field_type_floating_point_get_mantissa_digits(self._ptr)
-        utils._handle_ret(exp_size, "cannot get floating point number field type object's mantissa size")
-        return exp_size
+        mant_size = native_bt.ctf_field_type_floating_point_get_mantissa_digits(self._ptr)
+        assert(mant_size >= 0)
+        return mant_size
 
     @mantissa_size.setter
     def mantissa_size(self, mantissa_size):
@@ -302,7 +304,7 @@ class _EnumerationFieldTypeMappingIterator(object._Object,
         else:
             ret, name, lower, upper = native_bt.ctf_field_type_enumeration_mapping_iterator_get_unsigned(self._ptr)
 
-        utils._handle_ret(ret, "cannot get enumeration field type mapping iterator object's current mapping")
+        assert(ret == 0)
         mapping = _EnumerationFieldTypeMapping(name, lower, upper)
         ret = native_bt.ctf_field_type_enumeration_mapping_iterator_next(self._ptr)
 
@@ -333,7 +335,7 @@ class EnumerationFieldType(IntegerFieldType, collections.abc.Sequence):
     @property
     def integer_field_type(self):
         ptr = native_bt.ctf_field_type_enumeration_get_container_type(self._ptr)
-        utils._handle_ptr(ptr, "cannot get enumeration field type object's integer field type")
+        assert(ptr)
         return _create_from_ptr(ptr)
 
     @property
@@ -390,7 +392,7 @@ class EnumerationFieldType(IntegerFieldType, collections.abc.Sequence):
 
     def __len__(self):
         count = native_bt.ctf_field_type_enumeration_get_mapping_count(self._ptr)
-        utils._handle_ret(count, "cannot get enumeration field type object's mapping count")
+        assert(count >= 0)
         return count
 
     def __getitem__(self, index):
@@ -405,7 +407,7 @@ class EnumerationFieldType(IntegerFieldType, collections.abc.Sequence):
             get_fn = native_bt.ctf_field_type_enumeration_get_mapping_unsigned
 
         ret, name, lower, upper = get_fn(self._ptr, index)
-        utils._handle_ret(ret, "cannot get enumeration field type object's mapping")
+        assert(ret == 0)
         return _EnumerationFieldTypeMapping(name, lower, upper)
 
     def _get_mapping_iter(self, iter_ptr):
@@ -433,7 +435,7 @@ class EnumerationFieldType(IntegerFieldType, collections.abc.Sequence):
             upper = lower
 
         if self.is_signed:
-            add_fn = native_bt.ctf_field_type_enumeration_add_mapping
+            add_fn = native_bt.ctf_field_type_enumeration_add_mapping_signed
             utils._check_int64(lower)
             utils._check_int64(upper)
         else:
@@ -465,7 +467,7 @@ class StringFieldType(_FieldType):
     @property
     def encoding(self):
         encoding = native_bt.ctf_field_type_string_get_encoding(self._ptr)
-        utils._handle_ret(encoding, "cannot get string field type object's encoding")
+        assert(encoding >= 0)
         return encoding
 
     @encoding.setter
@@ -478,7 +480,7 @@ class StringFieldType(_FieldType):
 class _FieldContainer(collections.abc.Mapping):
     def __len__(self):
         count = self._count()
-        utils._handle_ret(count, "cannot get {} field type object's field count".format(self._NAME.lower()))
+        assert(count >= 0)
         return count
 
     def __getitem__(self, key):
@@ -521,8 +523,10 @@ class _StructureFieldTypeFieldIterator(collections.abc.Iterator):
         if self._at == len(self._struct_field_type):
             raise StopIteration
 
-        ret, name, field_type_ptr = native_bt.ctf_field_type_structure_get_field(self._struct_field_type._ptr, self._at)
-        utils._handle_ret(ret, "cannot get structure field type object's field")
+        get_ft_by_index = native_bt.ctf_field_type_structure_get_field_by_index
+        ret, name, field_type_ptr = get_ft_by_index(self._struct_field_type._ptr,
+                                                    self._at)
+        assert(ret == 0)
         native_bt.put(field_type_ptr)
         self._at += 1
         return name
@@ -551,8 +555,11 @@ class StructureFieldType(_FieldType, _FieldContainer, _AlignmentProp):
                                                             name)
 
     def _at(self, index):
-        ret, name, field_type_ptr = native_bt.ctf_field_type_structure_get_field(self._ptr, index)
-        utils._handle_ret(ret, "cannot get structure field type object's field")
+        if index < 0 or index >= len(self):
+            raise IndexError
+
+        ret, name, field_type_ptr = native_bt.ctf_field_type_structure_get_field_by_index(self._ptr, index)
+        assert(ret == 0)
         return _create_from_ptr(field_type_ptr)
 
 
@@ -569,8 +576,9 @@ class _VariantFieldTypeFieldIterator(collections.abc.Iterator):
         if self._at == len(self._variant_field_type):
             raise StopIteration
 
-        ret, name, field_type_ptr = native_bt.ctf_field_type_variant_get_field(self._variant_field_type._ptr, self._at)
-        utils._handle_ret(ret, "cannot get variant field type object's field")
+        ret, name, field_type_ptr = native_bt.ctf_field_type_variant_get_field_by_index(self._variant_field_type._ptr,
+                                                                                        self._at)
+        assert(ret == 0)
         native_bt.put(field_type_ptr)
         self._at += 1
         return name
@@ -580,16 +588,24 @@ class VariantFieldType(_FieldType, _FieldContainer, _AlignmentProp):
     _NAME = 'Variant'
     _ITER_CLS = _VariantFieldTypeFieldIterator
 
-    def __init__(self, tag_name):
+    def __init__(self, tag_name, tag_field_type=None):
         utils._check_str(tag_name)
-        ptr = native_bt.ctf_field_type_variant_create(None, tag_name)
+
+        if tag_field_type is None:
+            tag_ft_ptr = None
+        else:
+            utils._check_type(tag_field_type, EnumerationFieldType)
+            tag_ft_ptr = tag_field_type._ptr
+
+        ptr = native_bt.ctf_field_type_variant_create(tag_ft_ptr,
+                                                      tag_name)
         self._check_create_status(ptr)
         super().__init__(ptr)
 
     @property
     def tag_name(self):
         tag_name = native_bt.ctf_field_type_variant_get_tag_name(self._ptr)
-        utils._handle_ptr(tag_name, "cannot get variant field type object's tag name")
+        assert(tag_name is not None)
         return tag_name
 
     @tag_name.setter
@@ -598,6 +614,15 @@ class VariantFieldType(_FieldType, _FieldContainer, _AlignmentProp):
         ret = native_bt.ctf_field_type_variant_set_tag_name(self._ptr, tag_name)
         utils._handle_ret(ret, "cannot set variant field type object's tag name")
 
+    @property
+    def tag_field_type(self):
+        ft_ptr = native_bt.ctf_field_type_variant_get_tag_type(self._ptr)
+
+        if ft_ptr is None:
+            return
+
+        return _create_from_ptr(ft_ptr)
+
     def _count(self):
         return native_bt.ctf_field_type_variant_get_field_count(self._ptr)
 
@@ -608,8 +633,11 @@ class VariantFieldType(_FieldType, _FieldContainer, _AlignmentProp):
         return native_bt.ctf_field_type_variant_add_field(self._ptr, ptr, name)
 
     def _at(self, index):
-        ret, name, field_type_ptr = native_bt.ctf_field_type_variant_get_field(self._ptr, index)
-        utils._handle_ret(ret, "cannot get variant field type object's field")
+        if index < 0 or index >= len(self):
+            raise IndexError
+
+        ret, name, field_type_ptr = native_bt.ctf_field_type_variant_get_field_by_index(self._ptr, index)
+        assert(ret == 0)
         return _create_from_ptr(field_type_ptr)
 
 
@@ -626,13 +654,13 @@ class ArrayFieldType(_FieldType):
     @property
     def length(self):
         length = native_bt.ctf_field_type_array_get_length(self._ptr)
-        utils._handle_ret(length, "cannot get array field type object's length")
+        assert(length >= 0)
         return length
 
     @property
     def element_field_type(self):
         ptr = native_bt.ctf_field_type_array_get_element_type(self._ptr)
-        utils._handle_ptr(ptr, "cannot get array field type object's element field type")
+        assert(ptr)
         return _create_from_ptr(ptr)
 
 
@@ -650,24 +678,23 @@ class SequenceFieldType(_FieldType):
     @property
     def length_name(self):
         length_name = native_bt.ctf_field_type_sequence_get_length_field_name(self._ptr)
-        utils._handle_ptr(length_name, "cannot get sequence field type object's length name")
+        assert(length_name is not None)
         return length_name
 
     @property
     def element_field_type(self):
         ptr = native_bt.ctf_field_type_sequence_get_element_type(self._ptr)
-        utils._handle_ptr(ptr, "cannot get sequence field type object's element field type")
+        assert(ptr)
         return _create_from_ptr(ptr)
 
 
 _TYPE_ID_TO_OBJ = {
-    native_bt.CTF_TYPE_ID_INTEGER: IntegerFieldType,
-    native_bt.CTF_TYPE_ID_FLOAT: FloatingPointNumberFieldType,
-    native_bt.CTF_TYPE_ID_ENUM: EnumerationFieldType,
-    native_bt.CTF_TYPE_ID_STRING: StringFieldType,
-    native_bt.CTF_TYPE_ID_STRUCT: StructureFieldType,
-    native_bt.CTF_TYPE_ID_ARRAY: ArrayFieldType,
-    native_bt.CTF_TYPE_ID_SEQUENCE: SequenceFieldType,
-    native_bt.CTF_TYPE_ID_VARIANT: VariantFieldType,
-    native_bt.CTF_TYPE_ID_UNTAGGED_VARIANT: VariantFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_INTEGER: IntegerFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_FLOAT: FloatingPointNumberFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_ENUM: EnumerationFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_STRING: StringFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_STRUCT: StructureFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_ARRAY: ArrayFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_SEQUENCE: SequenceFieldType,
+    native_bt.CTF_FIELD_TYPE_ID_VARIANT: VariantFieldType,
 }
This page took 0.026991 seconds and 4 git commands to generate.