bt2: Adapt test_graph.py and make it pass
[babeltrace.git] / bindings / python / bt2 / bt2 / graph.py
index c630cf0ac57a95b623bf53ae956b713607f0d36a..998d310b3b268ea8d6640b5a6a06fe2180ee7a52 100644 (file)
@@ -28,62 +28,28 @@ import bt2.port
 import bt2
 
 
-class GraphListenerType:
-    PORT_ADDED = 0
-    PORT_REMOVED = 1
-    PORTS_CONNECTED = 2
-    PORTS_DISCONNECTED = 3
-
-
-def _graph_port_added_listener_from_native(user_listener, port_ptr):
-    try:
-        port = bt2.port._create_from_ptr(port_ptr)
-        port._get()
-        user_listener(port)
-    except:
-        pass
-
-
-def _graph_port_removed_listener_from_native(user_listener, port_ptr):
-    try:
-        port = bt2.port._create_from_ptr(port_ptr)
-        port._get()
-        user_listener(port)
-    except:
-        pass
+def _graph_port_added_listener_from_native(user_listener,
+                                           component_ptr, component_type,
+                                           port_ptr, port_type):
+    component = bt2.component._create_component_from_ptr_and_get_ref(component_ptr, component_type)
+    port = bt2.port._create_from_ptr_and_get_ref(port_ptr, port_type)
+    user_listener(component, port)
 
 
 def _graph_ports_connected_listener_from_native(user_listener,
+                                                upstream_component_ptr, upstream_component_type,
                                                 upstream_port_ptr,
+                                                downstream_component_ptr, downstream_component_type,
                                                 downstream_port_ptr):
-    try:
-        upstream_port = bt2.port._create_from_ptr(upstream_port_ptr)
-        upstream_port._get()
-        downstream_port = bt2.port._create_from_ptr(downstream_port_ptr)
-        downstream_port._get()
-        user_listener(upstream_port, downstream_port)
-    except:
-        pass
-
-
-def _graph_ports_disconnected_listener_from_native(user_listener,
-                                                   upstream_comp_ptr,
-                                                   downstream_comp_ptr,
-                                                   upstream_port_ptr,
-                                                   downstream_port_ptr):
-    try:
-        upstream_comp = bt2.component._create_generic_component_from_ptr(upstream_comp_ptr)
-        upstream_comp._get()
-        downstream_comp = bt2.component._create_generic_component_from_ptr(downstream_comp_ptr)
-        downstream_comp._get()
-        upstream_port = bt2.port._create_from_ptr(upstream_port_ptr)
-        upstream_port._get()
-        downstream_port = bt2.port._create_from_ptr(downstream_port_ptr)
-        downstream_port._get()
-        user_listener(upstream_comp, downstream_comp, upstream_port,
-                      downstream_port)
-    except:
-        pass
+    upstream_component = bt2.component._create_component_from_ptr_and_get_ref(
+        upstream_component_ptr, upstream_component_type)
+    upstream_port = bt2.port._create_from_ptr_and_get_ref(
+        upstream_port_ptr, native_bt.PORT_TYPE_OUTPUT)
+    downstream_component = bt2.component._create_component_from_ptr_and_get_ref(
+        downstream_component_ptr, downstream_component_type)
+    downstream_port = bt2.port._create_from_ptr_and_get_ref(
+        downstream_port_ptr, native_bt.PORT_TYPE_INPUT)
+    user_listener(upstream_component, upstream_port, downstream_component, downstream_port)
 
 
 class Graph(object._SharedObject):
@@ -111,18 +77,30 @@ class Graph(object._SharedObject):
             raise bt2.Error(gen_error_msg)
 
     def add_component(self, component_class, name, params=None):
-        if issubclass(component_class, bt2.component._UserSourceComponent):
-            cc_ptr = component_class._cc_ptr
+        if isinstance(component_class, bt2.component._GenericSourceComponentClass):
+            cc_ptr = component_class._ptr
             add_fn = native_bt.graph_add_source_component
             cc_type = native_bt.COMPONENT_CLASS_TYPE_SOURCE
-        elif issubclass(component_class, bt2.component._UserFilterComponent):
-            cc_ptr = component_class._cc_ptr
+        elif isinstance(component_class, bt2.component._GenericFilterComponentClass):
+            cc_ptr = component_class._ptr
             add_fn = native_bt.graph_add_filter_component
             cc_type = native_bt.COMPONENT_CLASS_TYPE_FILTER
+        elif isinstance(component_class, bt2.component._GenericSinkComponentClass):
+            cc_ptr = component_class._ptr
+            add_fn = native_bt.graph_add_sink_component
+            cc_type = native_bt.COMPONENT_CLASS_TYPE_SINK
+        elif issubclass(component_class, bt2.component._UserSourceComponent):
+            cc_ptr = component_class._cc_ptr
+            add_fn = native_bt.graph_add_source_component
+            cc_type = native_bt.COMPONENT_CLASS_TYPE_SOURCE
         elif issubclass(component_class, bt2.component._UserSinkComponent):
             cc_ptr = component_class._cc_ptr
             add_fn = native_bt.graph_add_sink_component
             cc_type = native_bt.COMPONENT_CLASS_TYPE_SINK
+        elif issubclass(component_class, bt2.component._UserFilterComponent):
+            cc_ptr = component_class._cc_ptr
+            add_fn = native_bt.graph_add_filter_component
+            cc_type = native_bt.COMPONENT_CLASS_TYPE_FILTER
         else:
             raise TypeError("'{}' is not a component class".format(
                 component_class.__class__.__name__))
@@ -147,32 +125,31 @@ class Graph(object._SharedObject):
         assert(conn_ptr)
         return bt2.connection._Connection._create_from_ptr(conn_ptr)
 
-    def add_listener(self, listener_type, listener):
-        if not hasattr(listener, '__call__'):
+    def add_port_added_listener(self, listener):
+        if not callable(listener):
             raise TypeError("'listener' parameter is not callable")
 
-        if listener_type == GraphListenerType.PORT_ADDED:
-            fn = native_bt.py3_graph_add_port_added_listener
-            listener_from_native = functools.partial(_graph_port_added_listener_from_native,
-                                                     listener)
-        elif listener_type == GraphListenerType.PORT_REMOVED:
-            fn = native_bt.py3_graph_add_port_removed_listener
-            listener_from_native = functools.partial(_graph_port_removed_listener_from_native,
-                                                     listener)
-        elif listener_type == GraphListenerType.PORTS_CONNECTED:
-            fn = native_bt.py3_graph_add_ports_connected_listener
-            listener_from_native = functools.partial(_graph_ports_connected_listener_from_native,
-                                                     listener)
-        elif listener_type == GraphListenerType.PORTS_DISCONNECTED:
-            fn = native_bt.py3_graph_add_ports_disconnected_listener
-            listener_from_native = functools.partial(_graph_ports_disconnected_listener_from_native,
-                                                     listener)
-        else:
-            raise TypeError
+        fn = native_bt.py3_graph_add_port_added_listener
+        listener_from_native = functools.partial(_graph_port_added_listener_from_native,
+                                                 listener)
 
-        listener_id = fn(self._ptr, listener_from_native)
-        utils._handle_ret(listener_id, 'cannot add listener to graph object')
-        return bt2._ListenerHandle(listener_id, self)
+        listener_ids = fn(self._ptr, listener_from_native)
+        if listener_ids is None:
+            utils._raise_bt2_error('cannot add listener to graph object')
+        return bt2._ListenerHandle(listener_ids, self)
+
+    def add_ports_connected_listener(self, listener):
+        if not callable(listener):
+            raise TypeError("'listener' parameter is not callable")
+
+        fn = native_bt.py3_graph_add_ports_connected_listener
+        listener_from_native = functools.partial(_graph_ports_connected_listener_from_native,
+                                                 listener)
+
+        listener_ids = fn(self._ptr, listener_from_native)
+        if listener_ids is None:
+            utils._raise_bt2_error('cannot add listener to graph object')
+        return bt2._ListenerHandle(listener_ids, self)
 
     def run(self):
         status = native_bt.graph_run(self._ptr)
@@ -200,9 +177,3 @@ class Graph(object._SharedObject):
             raise bt2.CreationError('cannot create output port message iterator')
 
         return bt2.message_iterator._OutputPortMessageIterator(msg_iter_ptr)
-
-    def __eq__(self, other):
-        if type(other) is not type(self):
-            return False
-
-        return self.addr == other.addr
This page took 0.024647 seconds and 4 git commands to generate.