bt2: update bindings to make test_plugins pass
[babeltrace.git] / bindings / python / bt2 / bt2 / plugin.py
index a7ecf9b35064cb20337cbe06aad18e13ebbc9e43..51f9a2ed85f1b7591a693ac02e0d48bf4a0fad10 100644 (file)
@@ -33,9 +33,9 @@ def find_plugins(path, recurse=True):
     plugin_set_ptr = None
 
     if os.path.isfile(path):
-        plugin_set_ptr = native_bt.plugin_create_all_from_file(path)
+        plugin_set_ptr = native_bt.plugin_find_all_from_file(path)
     elif os.path.isdir(path):
-        plugin_set_ptr = native_bt.plugin_create_all_from_dir(path, int(recurse))
+        plugin_set_ptr = native_bt.plugin_find_all_from_dir(path, int(recurse))
 
     if plugin_set_ptr is None:
         return
@@ -54,6 +54,9 @@ def find_plugin(name):
 
 
 class _PluginSet(object._SharedObject, collections.abc.Sequence):
+    _put_ref = native_bt.plugin_set_put_ref
+    _get_ref = native_bt.plugin_set_get_ref
+
     def __len__(self):
         count = native_bt.plugin_set_get_plugin_count(self._ptr)
         assert(count >= 0)
@@ -65,9 +68,9 @@ class _PluginSet(object._SharedObject, collections.abc.Sequence):
         if index >= len(self):
             raise IndexError
 
-        plugin_ptr = native_bt.plugin_set_get_plugin(self._ptr, index)
-        assert(plugin_ptr)
-        return _Plugin._create_from_ptr(plugin_ptr)
+        plugin_ptr = native_bt.plugin_set_borrow_plugin_by_index_const(self._ptr, index)
+        assert plugin_ptr is not None
+        return _Plugin._create_from_ptr_and_get_ref(plugin_ptr)
 
 
 class _PluginVersion:
@@ -109,68 +112,68 @@ class _PluginComponentClassesIterator(collections.abc.Iterator):
 
     def __next__(self):
         plugin_ptr = self._plugin_comp_cls._plugin._ptr
-        comp_cls_type = self._plugin_comp_cls._comp_cls_type
-        total = native_bt.plugin_get_component_class_count(plugin_ptr)
-
-        while True:
-            if self._at == total:
-                raise StopIteration
-
-            comp_cls_ptr = native_bt.plugin_get_component_class_by_index(plugin_ptr,
-                                                                         self._at)
-            assert(comp_cls_ptr)
-            cc_type = native_bt.component_class_get_type(comp_cls_ptr)
-            self._at += 1
+        total = self._plugin_comp_cls._component_class_count(plugin_ptr)
 
-            if cc_type == comp_cls_type:
-                break
+        if self._at == total:
+            raise StopIteration
 
-            native_bt.put(comp_cls_ptr)
+        comp_cls_ptr = self._plugin_comp_cls._borrow_component_class_by_index(plugin_ptr, self._at)
+        assert comp_cls_ptr is not None
+        self._at += 1
 
+        comp_cls_type = self._plugin_comp_cls._comp_cls_type
+        comp_cls_pycls = bt2.component._COMP_CLS_TYPE_TO_GENERIC_COMP_CLS_PYCLS[comp_cls_type]
+        comp_cls_ptr = comp_cls_pycls._as_component_class_ptr(comp_cls_ptr)
         name = native_bt.component_class_get_name(comp_cls_ptr)
-        native_bt.put(comp_cls_ptr)
-        assert(name is not None)
+        assert name is not None
         return name
 
 
 class _PluginComponentClasses(collections.abc.Mapping):
-    def __init__(self, plugin, comp_cls_type):
+    def __init__(self, plugin):
         self._plugin = plugin
-        self._comp_cls_type = comp_cls_type
 
     def __getitem__(self, key):
         utils._check_str(key)
-        cc_ptr = native_bt.plugin_get_component_class_by_name_and_type(self._plugin._ptr,
-                                                                       key,
-                                                                       self._comp_cls_type)
+        cc_ptr = self._borrow_component_class_by_name(self._plugin._ptr, key)
 
         if cc_ptr is None:
             raise KeyError(key)
 
-        return bt2.component._create_generic_component_class_from_ptr(cc_ptr)
+        return bt2.component._create_component_class_from_ptr_and_get_ref(cc_ptr, self._comp_cls_type)
 
     def __len__(self):
-        count = 0
-        total = native_bt.plugin_get_component_class_count(self._plugin._ptr)
+        return self._component_class_count(self._plugin._ptr)
+
+    def __iter__(self):
+        return _PluginComponentClassesIterator(self)
 
-        for at in range(total):
-            comp_cls_ptr = native_bt.plugin_get_component_class_by_index(self._plugin._ptr,
-                                                                         at)
-            assert(comp_cls_ptr)
-            cc_type = native_bt.component_class_get_type(comp_cls_ptr)
 
-            if cc_type == self._comp_cls_type:
-                count += 1
+class _PluginSourceComponentClasses(_PluginComponentClasses):
+    _component_class_count = native_bt.plugin_get_source_component_class_count
+    _borrow_component_class_by_name = native_bt.plugin_borrow_source_component_class_by_name_const
+    _borrow_component_class_by_index = native_bt.plugin_borrow_source_component_class_by_index_const
+    _comp_cls_type = native_bt.COMPONENT_CLASS_TYPE_SOURCE
 
-            native_bt.put(comp_cls_ptr)
 
-        return count
+class _PluginFilterComponentClasses(_PluginComponentClasses):
+    _component_class_count = native_bt.plugin_get_filter_component_class_count
+    _borrow_component_class_by_name = native_bt.plugin_borrow_filter_component_class_by_name_const
+    _borrow_component_class_by_index = native_bt.plugin_borrow_filter_component_class_by_index_const
+    _comp_cls_type = native_bt.COMPONENT_CLASS_TYPE_FILTER
 
-    def __iter__(self):
-        return _PluginComponentClassesIterator(self)
+
+class _PluginSinkComponentClasses(_PluginComponentClasses):
+    _component_class_count = native_bt.plugin_get_sink_component_class_count
+    _borrow_component_class_by_name = native_bt.plugin_borrow_sink_component_class_by_name_const
+    _borrow_component_class_by_index = native_bt.plugin_borrow_sink_component_class_by_index_const
+    _comp_cls_type = native_bt.COMPONENT_CLASS_TYPE_SINK
 
 
 class _Plugin(object._SharedObject):
+    _put_ref = native_bt.plugin_put_ref
+    _get_ref = native_bt.plugin_get_ref
+
     @property
     def name(self):
         name = native_bt.plugin_get_name(self._ptr)
@@ -197,19 +200,19 @@ class _Plugin(object._SharedObject):
     def version(self):
         status, major, minor, patch, extra = native_bt.plugin_get_version(self._ptr)
 
-        if status < 0:
+        if status == native_bt.PROPERTY_AVAILABILITY_NOT_AVAILABLE:
             return
 
         return _PluginVersion(major, minor, patch, extra)
 
     @property
     def source_component_classes(self):
-        return _PluginComponentClasses(self, native_bt.COMPONENT_CLASS_TYPE_SOURCE)
+        return _PluginSourceComponentClasses(self)
 
     @property
     def filter_component_classes(self):
-        return _PluginComponentClasses(self, native_bt.COMPONENT_CLASS_TYPE_FILTER)
+        return _PluginFilterComponentClasses(self)
 
     @property
     def sink_component_classes(self):
-        return _PluginComponentClasses(self, native_bt.COMPONENT_CLASS_TYPE_SINK)
+        return _PluginSinkComponentClasses(self)
This page took 0.025469 seconds and 4 git commands to generate.