Add support for variant fields in the Python bindings
[babeltrace.git] / bindings / python / babeltrace.i.in
index 4b3439f2a726a3df835746214f80a0d84011fd4c..c39b53ad08866b06fbd9811aac74e3c0e7b6e750 100644 (file)
@@ -568,6 +568,8 @@ struct bt_ctf_event *bt_ctf_iter_read_event(struct bt_ctf_iter *iter);
 %rename("_bt_ctf_get_int64") bt_ctf_get_int64(const struct bt_definition *field);
 %rename("_bt_ctf_get_char_array") bt_ctf_get_char_array(const struct bt_definition *field);
 %rename("_bt_ctf_get_string") bt_ctf_get_string(const struct bt_definition *field);
+%rename("_bt_ctf_get_float") bt_ctf_get_float(const struct bt_definition *field);
+%rename("_bt_ctf_get_variant") bt_ctf_get_variant(const struct bt_definition *field);
 %rename("_bt_ctf_field_get_error") bt_ctf_field_get_error(void);
 %rename("_bt_ctf_get_decl_event_name") bt_ctf_get_decl_event_name(const struct
                bt_ctf_event_decl *event);
@@ -605,6 +607,8 @@ uint64_t bt_ctf_get_uint64(const struct bt_definition *field);
 int64_t bt_ctf_get_int64(const struct bt_definition *field);
 char *bt_ctf_get_char_array(const struct bt_definition *field);
 char *bt_ctf_get_string(const struct bt_definition *field);
+double bt_ctf_get_float(const struct bt_definition *field);
+const struct bt_definition *bt_ctf_get_variant(const struct bt_definition *field);
 int bt_ctf_field_get_error(void);
 const char *bt_ctf_get_decl_event_name(const struct bt_ctf_event_decl *event);
 const char *bt_ctf_get_decl_field_name(const struct bt_ctf_field_decl *field);
@@ -632,6 +636,15 @@ class ctf:
                SEQUENCE = 9
                NR_CTF_TYPES = 10
 
+               def get_type_id_name(id):
+                       name = "UNKNOWN"
+                       constants = [attr for attr in dir(ctf.type_id) if not callable(getattr(ctf.type_id, attr)) and not attr.startswith("__")]
+                       for attr in constants:
+                               if getattr(ctf.type_id, attr) == id:
+                                       name = attr
+                                       break
+                       return name
+
        class scope:
                TRACE_PACKET_HEADER = 0
                STREAM_PACKET_CONTEXT = 1
@@ -884,6 +897,12 @@ class ctf:
                        else:
                                return ctx
 
+       class FieldError(Exception):
+               def __init__(self, value):
+                       self.value = value
+
+               def __str__(self):
+                       return repr(self.value)
 
        class Definition(object):
                """Definition class.  Do not instantiate."""
@@ -1021,35 +1040,61 @@ class ctf:
                        """
                        return _bt_ctf_get_string(self._d)
 
+               def get_float(self):
+                       """
+                       Return the value associated with the field.
+                       If the field does not exist or is not of the type requested,
+                       the value returned is undefined. To check if an error occured,
+                       use the ctf.field_error() function after accessing a field.
+                       """
+                       return _bt_ctf_get_float(self._d)
+
+               def get_variant(self):
+                       """
+                       Return the variant's selected field.
+                       If the field does not exist or is not of the type requested,
+                       the value returned is undefined. To check if an error occured,
+                       use the ctf.field_error() function after accessing a field.
+                       """
+                       return _bt_ctf_get_variant(self._d)
+
                def get_value(self):
                        """
                        Return the value associated with the field according to its type.
                        Return None on error.
                        """
                        id = self.field_type()
+                       value = None
                        if id == ctf.type_id.STRING:
-                               return self.get_str()
-                       if id == ctf.type_id.ARRAY:
-                               array = []
+                               value = self.get_str()
+                       elif id == ctf.type_id.ARRAY:
+                               value = []
                                for i in range(self.get_array_len()):
                                        element = self.get_array_element_at(i)
-                                       array.append(element.get_value())
-                               return array
-                       if id == ctf.type_id.INTEGER:
+                                       value.append(element.get_value())
+                       elif id == ctf.type_id.INTEGER:
                                if self.get_int_signedness() == 0:
-                                       return self.get_uint64()
+                                       value = self.get_uint64()
                                else:
-                                       return self.get_int64()
-                       if id == ctf.type_id.ENUM:
-                               return self.get_enum_str()
-                       if id == ctf.type_id.SEQUENCE:
+                                       value = self.get_int64()
+                       elif id == ctf.type_id.ENUM:
+                               value = self.get_enum_str()
+                       elif id == ctf.type_id.SEQUENCE:
                                seq_len = self.get_sequence_len()
-                               values = []
+                               value = []
                                for i in range(seq_len):
                                        evDef = self.get_sequence_element_at(i)
-                                       values.append(evDef.get_value())
-                               return values
-                       return None
+                                       value.append(evDef.get_value())
+                       elif id == ctf.type_id.FLOAT:
+                               value = self.get_float()
+                       elif id == ctf.type_id.VARIANT:
+                               variant = ctf.Definition.__new__(ctf.Definition)
+                               variant._d = self.get_variant();
+                               value = variant.get_value()
+
+                       if ctf.field_error():
+                               raise ctf.FieldError("Error occured while accessing field {} of type {}".format(self.field_name(), ctf.type_id.get_type_id_name(self.field_type())))
+                       return value
 
                def get_scope(self):
                        """Return the scope of a field or None on error."""
This page took 0.025403 seconds and 4 git commands to generate.