bt2: Adapt test_event.py and make it pass
[babeltrace.git] / bindings / python / bt2 / bt2 / graph.py
index 56148395deafdea4d657896a7605f586febdf83a..c630cf0ac57a95b623bf53ae956b713607f0d36a 100644 (file)
@@ -86,7 +86,10 @@ def _graph_ports_disconnected_listener_from_native(user_listener,
         pass
 
 
-class Graph(object._Object):
+class Graph(object._SharedObject):
+    _get_ref = staticmethod(native_bt.graph_get_ref)
+    _put_ref = staticmethod(native_bt.graph_put_ref)
+
     def __init__(self):
         ptr = native_bt.graph_create()
 
@@ -104,32 +107,35 @@ class Graph(object._Object):
             raise bt2.Stop
         elif status == native_bt.GRAPH_STATUS_AGAIN:
             raise bt2.TryAgain
-        elif status == native_bt.GRAPH_STATUS_NO_SINK:
-            raise bt2.NoSinkComponent
         elif status < 0:
             raise bt2.Error(gen_error_msg)
 
     def add_component(self, component_class, name, params=None):
-        if isinstance(component_class, bt2.component._GenericComponentClass):
-            cc_ptr = component_class._ptr
-        elif issubclass(component_class, bt2.component._UserComponent):
+        if issubclass(component_class, bt2.component._UserSourceComponent):
+            cc_ptr = component_class._cc_ptr
+            add_fn = native_bt.graph_add_source_component
+            cc_type = native_bt.COMPONENT_CLASS_TYPE_SOURCE
+        elif issubclass(component_class, bt2.component._UserFilterComponent):
             cc_ptr = component_class._cc_ptr
+            add_fn = native_bt.graph_add_filter_component
+            cc_type = native_bt.COMPONENT_CLASS_TYPE_FILTER
+        elif issubclass(component_class, bt2.component._UserSinkComponent):
+            cc_ptr = component_class._cc_ptr
+            add_fn = native_bt.graph_add_sink_component
+            cc_type = native_bt.COMPONENT_CLASS_TYPE_SINK
         else:
-            raise TypeError("'{}' is not a component class".format(component_class.__class__.__name__))
+            raise TypeError("'{}' is not a component class".format(
+                component_class.__class__.__name__))
 
         utils._check_str(name)
         params = bt2.create_value(params)
 
-        if params is None:
-            params_ptr = None
-        else:
-            params_ptr = params._ptr
+        params_ptr = params._ptr if params is not None else None
 
-        status, comp_ptr = native_bt.graph_add_component(self._ptr, cc_ptr,
-                                                         name, params_ptr)
+        status, comp_ptr = add_fn(self._ptr, cc_ptr, name, params_ptr)
         self._handle_status(status, 'cannot add component to graph')
-        assert(comp_ptr)
-        return bt2.component._create_generic_component_from_ptr(comp_ptr)
+        assert comp_ptr
+        return bt2.component._create_component_from_ptr(comp_ptr, cc_type)
 
     def connect_ports(self, upstream_port, downstream_port):
         utils._check_type(upstream_port, bt2.port._OutputPort)
@@ -186,6 +192,15 @@ class Graph(object._Object):
         assert(is_canceled >= 0)
         return is_canceled > 0
 
+    def create_output_port_message_iterator(self, output_port):
+        utils._check_type(output_port, bt2.port._OutputPort)
+        msg_iter_ptr = native_bt.port_output_message_iterator_create(self._ptr, output_port._ptr)
+
+        if msg_iter_ptr is None:
+            raise bt2.CreationError('cannot create output port message iterator')
+
+        return bt2.message_iterator._OutputPortMessageIterator(msg_iter_ptr)
+
     def __eq__(self, other):
         if type(other) is not type(self):
             return False
This page took 0.024389 seconds and 4 git commands to generate.