Update Python bindings and tests to match the latest API
[babeltrace.git] / bindings / python / bt2 / plugin.py
index 920eb63642d71502868c3cc3becbe69e3c41e3af..a37e7631340d206f0b6988af8777f0f058730e14 100644 (file)
 from bt2 import native_bt, object, utils
 import collections.abc
 import bt2.component
+import os.path
 import bt2
 
 
-def _plugin_ptrs_to_plugins(plugin_ptrs):
-    plugins = []
-
-    for plugin_ptr in plugin_ptrs:
-        plugin = _Plugin._create_from_ptr(plugin_ptr)
-        plugins.append(plugin)
+def find_plugins(path, recurse=True):
+    utils._check_str(path)
+    utils._check_bool(recurse)
+    plugin_set_ptr = None
 
-    return plugins
+    if os.path.isfile(path):
+        plugin_set_ptr = native_bt.plugin_create_all_from_file(path)
+    elif os.path.isdir(path):
+        plugin_set_ptr = native_bt.plugin_create_all_from_dir(path, int(recurse))
 
+    if plugin_set_ptr is None:
+        return
 
-def create_plugins_from_file(path):
-    utils._check_str(path)
-    plugin_ptrs = native_bt.py3_plugin_create_all_from_file(path)
+    return _PluginSet._create_from_ptr(plugin_set_ptr)
 
-    if plugin_ptrs is None:
-        raise bt2.Error('cannot get plugin objects from file')
 
-    return _plugin_ptrs_to_plugins(plugin_ptrs)
+def find_plugin(name):
+    utils._check_str(name)
+    ptr = native_bt.plugin_find(name)
 
+    if ptr is None:
+        return
 
-def create_plugins_from_dir(path, recurse=True):
-    utils._check_str(path)
-    plugin_ptrs = native_bt.py3_plugin_create_all_from_dir(path, recurse)
+    return _Plugin._create_from_ptr(ptr)
 
-    if plugin_ptrs is None:
-        raise bt2.Error('cannot get plugin objects from directory')
-
-    return _plugin_ptrs_to_plugins(plugin_ptrs)
 
+class _PluginSet(object._Object, collections.abc.Sequence):
+    def __len__(self):
+        count = native_bt.plugin_set_get_plugin_count(self._ptr)
+        assert(count >= 0)
+        return count
 
-def create_plugin_from_name(name):
-    utils._check_str(name)
-    plugin_ptr = native_bt.plugin_create_from_name(name)
+    def __getitem__(self, index):
+        utils._check_uint64(index)
 
-    if plugin_ptr is None:
-        raise bt2.NoSuchPluginError(name)
+        if index >= len(self):
+            raise IndexError
 
-    return _Plugin._create_from_ptr(plugin_ptr)
+        plugin_ptr = native_bt.plugin_set_get_plugin(self._ptr, index)
+        assert(plugin_ptr)
+        return _Plugin._create_from_ptr(plugin_ptr)
 
 
 class _PluginVersion:
@@ -98,11 +102,79 @@ class _PluginVersion:
         return '{}.{}.{}{}'.format(self._major, self._minor, self._patch, extra)
 
 
-class _Plugin(object._Object, collections.abc.Sequence):
+class _PluginComponentClassesIterator(collections.abc.Iterator):
+    def __init__(self, plugin_comp_cls):
+        self._plugin_comp_cls = plugin_comp_cls
+        self._at = 0
+
+    def __next__(self):
+        plugin_ptr = self._plugin_comp_cls._plugin._ptr
+        comp_cls_type = self._plugin_comp_cls._comp_cls_type
+        total = native_bt.plugin_get_component_class_count(plugin_ptr)
+
+        while True:
+            if self._at == total:
+                raise StopIteration
+
+            comp_cls_ptr = native_bt.plugin_get_component_class_by_index(plugin_ptr,
+                                                                         self._at)
+            assert(comp_cls_ptr)
+            cc_type = native_bt.component_class_get_type(comp_cls_ptr)
+            self._at += 1
+
+            if cc_type == comp_cls_type:
+                break
+
+            native_bt.put(comp_cls_ptr)
+
+        name = native_bt.component_class_get_name(comp_cls_ptr)
+        native_bt.put(comp_cls_ptr)
+        assert(name is not None)
+        return name
+
+
+class _PluginComponentClasses(collections.abc.Mapping):
+    def __init__(self, plugin, comp_cls_type):
+        self._plugin = plugin
+        self._comp_cls_type = comp_cls_type
+
+    def __getitem__(self, key):
+        utils._check_str(key)
+        cc_ptr = native_bt.plugin_get_component_class_by_name_and_type(self._plugin._ptr,
+                                                                       key,
+                                                                       self._comp_cls_type)
+
+        if cc_ptr is None:
+            raise KeyError(key)
+
+        return bt2.component._create_generic_component_class_from_ptr(cc_ptr)
+
+    def __len__(self):
+        count = 0
+        total = native_bt.plugin_get_component_class_count(self._plugin._ptr)
+
+        for at in range(total):
+            comp_cls_ptr = native_bt.plugin_get_component_class_by_index(self._plugin._ptr,
+                                                                         at)
+            assert(comp_cls_ptr)
+            cc_type = native_bt.component_class_get_type(comp_cls_ptr)
+
+            if cc_type == self._comp_cls_type:
+                count += 1
+
+            native_bt.put(comp_cls_ptr)
+
+        return count
+
+    def __iter__(self):
+        return _PluginComponentClassesIterator(self)
+
+
+class _Plugin(object._Object):
     @property
     def name(self):
         name = native_bt.plugin_get_name(self._ptr)
-        utils._handle_ptr(name, "cannot get plugin object's name")
+        assert(name is not None)
         return name
 
     @property
@@ -130,50 +202,14 @@ class _Plugin(object._Object, collections.abc.Sequence):
 
         return _PluginVersion(major, minor, patch, extra)
 
-    def source_component_class(self, name):
-        utils._check_str(name)
-        cc_ptr = native_bt.plugin_get_component_class_by_name_and_type(self._ptr,
-                                                                       name,
-                                                                       native_bt.COMPONENT_CLASS_TYPE_SOURCE)
-
-        if cc_ptr is None:
-            return
-
-        return bt2.component._create_generic_component_class_from_ptr(cc_ptr)
-
-    def filter_component_class(self, name):
-        utils._check_str(name)
-        cc_ptr = native_bt.plugin_get_component_class_by_name_and_type(self._ptr,
-                                                                       name,
-                                                                       native_bt.COMPONENT_CLASS_TYPE_FILTER)
-
-        if cc_ptr is None:
-            return
-
-        return bt2.component._create_generic_component_class_from_ptr(cc_ptr)
-
-    def sink_component_class(self, name):
-        utils._check_str(name)
-        cc_ptr = native_bt.plugin_get_component_class_by_name_and_type(self._ptr,
-                                                                       name,
-                                                                       native_bt.COMPONENT_CLASS_TYPE_SINK)
-
-        if cc_ptr is None:
-            return
-
-        return bt2.component._create_generic_component_class_from_ptr(cc_ptr)
-
-    def __len__(self):
-        count = native_bt.plugin_get_component_class_count(self._ptr)
-        utils._handle_ret(count, "cannot get plugin object's component class count")
-        return count
-
-    def __getitem__(self, index):
-        utils._check_uint64(index)
+    @property
+    def source_component_classes(self):
+        return _PluginComponentClasses(self, native_bt.COMPONENT_CLASS_TYPE_SOURCE)
 
-        if index >= len(self):
-            raise IndexError
+    @property
+    def filter_component_classes(self):
+        return _PluginComponentClasses(self, native_bt.COMPONENT_CLASS_TYPE_FILTER)
 
-        cc_ptr = native_bt.plugin_get_component_class(self._ptr, index)
-        utils._handle_ptr(cc_ptr, "cannot get plugin object's component class object")
-        return bt2.component._create_generic_component_class_from_ptr(cc_ptr)
+    @property
+    def sink_component_classes(self):
+        return _PluginComponentClasses(self, native_bt.COMPONENT_CLASS_TYPE_SINK)
This page took 0.026949 seconds and 4 git commands to generate.