lib, bt2: add precondition check for port name unicity
[babeltrace.git] / tests / bindings / python / bt2 / test_port.py
index c404d8fac8cd1cf9286383f7c717edce3fb1d648..e6db06ea53c0736b3cf4facaa86e869179789483 100644 (file)
@@ -1,20 +1,7 @@
+# SPDX-License-Identifier: GPL-2.0-only
 #
 # Copyright (C) 2019 EfficiOS Inc.
 #
-# This program is free software; you can redistribute it and/or
-# modify it under the terms of the GNU General Public License
-# as published by the Free Software Foundation; only version 2
-# of the License.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
-# GNU General Public License for more details.
-#
-# You should have received a copy of the GNU General Public License
-# along with this program; if not, write to the Free Software
-# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
-#
 
 import unittest
 import bt2
@@ -32,11 +19,9 @@ class PortTestCase(unittest.TestCase):
         return graph.add_component(comp_cls, name)
 
     def test_src_add_output_port(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port = comp_self._add_output_port('out')
                 self.assertEqual(port.name, 'out')
@@ -45,12 +30,31 @@ class PortTestCase(unittest.TestCase):
         self.assertEqual(len(comp.output_ports), 1)
         self.assertIs(type(comp.output_ports['out']), bt2_port._OutputPortConst)
 
-    def test_flt_add_output_port(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
+    # Test adding output port with duplicate name to source.
+    def test_src_add_output_port_dup_name_raises(self):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
+            def __init__(comp_self, config, params, obj):
+                comp_self._add_output_port('out')
+
+                with self.assertRaisesRegex(
+                    ValueError,
+                    "source component `comp` already contains an output port named `out`",
+                ):
+                    comp_self._add_output_port('out')
+
+                nonlocal seen
+                seen = True
+
+        seen = False
+        self._create_comp(MySource)
+        self.assertTrue(seen)
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+    def test_flt_add_output_port(self):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port = comp_self._add_output_port('out')
                 self.assertEqual(port.name, 'out')
@@ -58,12 +62,31 @@ class PortTestCase(unittest.TestCase):
         comp = self._create_comp(MyFilter)
         self.assertEqual(len(comp.output_ports), 1)
 
-    def test_flt_add_input_port(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
+    # Test adding output port with duplicate name to filter.
+    def test_flt_add_output_port_dup_name_raises(self):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
+            def __init__(comp_self, config, params, obj):
+                comp_self._add_output_port('out')
+
+                with self.assertRaisesRegex(
+                    ValueError,
+                    "filter component `comp` already contains an output port named `out`",
+                ):
+                    comp_self._add_output_port('out')
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+                nonlocal seen
+                seen = True
+
+        seen = False
+        self._create_comp(MyFilter)
+        self.assertTrue(seen)
+
+    def test_flt_add_input_port(self):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port = comp_self._add_input_port('in')
                 self.assertEqual(port.name, 'in')
@@ -72,6 +95,27 @@ class PortTestCase(unittest.TestCase):
         self.assertEqual(len(comp.input_ports), 1)
         self.assertIs(type(comp.input_ports['in']), bt2_port._InputPortConst)
 
+    # Test adding input port with duplicate name to filter.
+    def test_flt_add_input_port_dup_name_raises(self):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
+            def __init__(comp_self, config, params, obj):
+                comp_self._add_input_port('in')
+
+                with self.assertRaisesRegex(
+                    ValueError,
+                    "filter component `comp` already contains an input port named `in`",
+                ):
+                    comp_self._add_input_port('in')
+
+                nonlocal seen
+                seen = True
+
+        seen = False
+        self._create_comp(MyFilter)
+        self.assertTrue(seen)
+
     def test_sink_add_input_port(self):
         class MySink(bt2._UserSinkComponent):
             def __init__(comp_self, config, params, obj):
@@ -84,12 +128,32 @@ class PortTestCase(unittest.TestCase):
         comp = self._create_comp(MySink)
         self.assertEqual(len(comp.input_ports), 1)
 
-    def test_user_src_output_ports_getitem(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
+    # Test adding input port with duplicate name to sink.
+    def test_sink_add_input_port_dup_name_raises(self):
+        class MySink(bt2._UserSinkComponent):
+            def __init__(comp_self, config, params, obj):
+                comp_self._add_input_port('in')
 
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+                with self.assertRaisesRegex(
+                    ValueError,
+                    "sink component `comp` already contains an input port named `in`",
+                ):
+                    comp_self._add_input_port('in')
+
+                nonlocal seen
+                seen = True
+
+            def _user_consume(self):
+                pass
+
+        seen = False
+        self._create_comp(MySink)
+        self.assertTrue(seen)
+
+    def test_user_src_output_ports_getitem(self):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port1 = comp_self._add_output_port('clear')
                 port2 = comp_self._add_output_port('print')
@@ -101,11 +165,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySource)
 
     def test_user_flt_output_ports_getitem(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port1 = comp_self._add_output_port('clear')
                 port2 = comp_self._add_output_port('print')
@@ -117,11 +179,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MyFilter)
 
     def test_user_flt_input_ports_getitem(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port1 = comp_self._add_input_port('clear')
                 port2 = comp_self._add_input_port('print')
@@ -148,11 +208,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySink)
 
     def test_user_src_output_ports_getitem_invalid_key(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -164,11 +222,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySource)
 
     def test_user_flt_output_ports_getitem_invalid_key(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -180,11 +236,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MyFilter)
 
     def test_user_flt_input_ports_getitem_invalid_key(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_input_port('clear')
                 comp_self._add_input_port('print')
@@ -211,11 +265,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySink)
 
     def test_user_src_output_ports_len(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -225,11 +277,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySource)
 
     def test_user_flt_output_ports_len(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -239,11 +289,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MyFilter)
 
     def test_user_flt_input_ports_len(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_input_port('clear')
                 comp_self._add_input_port('print')
@@ -266,11 +314,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySink)
 
     def test_user_src_output_ports_iter(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port1 = comp_self._add_output_port('clear')
                 port2 = comp_self._add_output_port('print')
@@ -290,11 +336,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySource)
 
     def test_user_flt_output_ports_iter(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port1 = comp_self._add_output_port('clear')
                 port2 = comp_self._add_output_port('print')
@@ -314,11 +358,9 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MyFilter)
 
     def test_user_flt_input_ports_iter(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 port1 = comp_self._add_input_port('clear')
                 port2 = comp_self._add_input_port('print')
@@ -361,15 +403,13 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySink)
 
     def test_gen_src_output_ports_getitem(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
         port1 = None
         port2 = None
         port3 = None
 
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal port1, port2, port3
                 port1 = comp_self._add_output_port('clear')
@@ -385,15 +425,13 @@ class PortTestCase(unittest.TestCase):
         del port3
 
     def test_gen_flt_output_ports_getitem(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
         port1 = None
         port2 = None
         port3 = None
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal port1, port2, port3
                 port1 = comp_self._add_output_port('clear')
@@ -409,15 +447,13 @@ class PortTestCase(unittest.TestCase):
         del port3
 
     def test_gen_flt_input_ports_getitem(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
         port1 = None
         port2 = None
         port3 = None
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal port1, port2, port3
                 port1 = comp_self._add_input_port('clear')
@@ -456,11 +492,9 @@ class PortTestCase(unittest.TestCase):
         del port3
 
     def test_gen_src_output_ports_getitem_invalid_key(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -472,11 +506,9 @@ class PortTestCase(unittest.TestCase):
             comp.output_ports['hello']
 
     def test_gen_flt_output_ports_getitem_invalid_key(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -488,11 +520,9 @@ class PortTestCase(unittest.TestCase):
             comp.output_ports['hello']
 
     def test_gen_flt_input_ports_getitem_invalid_key(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_input_port('clear')
                 comp_self._add_input_port('print')
@@ -522,11 +552,9 @@ class PortTestCase(unittest.TestCase):
             comp.input_ports['hello']
 
     def test_gen_src_output_ports_len(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -536,11 +564,9 @@ class PortTestCase(unittest.TestCase):
         self.assertEqual(len(comp.output_ports), 3)
 
     def test_gen_flt_output_ports_len(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_output_port('clear')
                 comp_self._add_output_port('print')
@@ -550,11 +576,9 @@ class PortTestCase(unittest.TestCase):
         self.assertEqual(len(comp.output_ports), 3)
 
     def test_gen_flt_input_ports_len(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 comp_self._add_input_port('clear')
                 comp_self._add_input_port('print')
@@ -577,15 +601,13 @@ class PortTestCase(unittest.TestCase):
         self.assertEqual(len(comp.input_ports), 3)
 
     def test_gen_src_output_ports_iter(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
         port1 = None
         port2 = None
         port3 = None
 
-        class MySource(bt2._UserSourceComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserSourceComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal port1, port2, port3
                 port1 = comp_self._add_output_port('clear')
@@ -609,15 +631,13 @@ class PortTestCase(unittest.TestCase):
         del port3
 
     def test_gen_flt_output_ports_iter(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
         port1 = None
         port2 = None
         port3 = None
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal port1, port2, port3
                 port1 = comp_self._add_output_port('clear')
@@ -641,15 +661,13 @@ class PortTestCase(unittest.TestCase):
         del port3
 
     def test_gen_flt_input_ports_iter(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
-
         port1 = None
         port2 = None
         port3 = None
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal port1, port2, port3
                 port1 = comp_self._add_input_port('clear')
@@ -770,11 +788,14 @@ class PortTestCase(unittest.TestCase):
         self._create_comp(MySink)
 
     def test_source_self_port_user_data(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
+        class MyUserData:
+            def __del__(self):
+                nonlocal objects_deleted
+                objects_deleted += 1
 
-        class MySource(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MySource(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal user_datas
 
@@ -782,18 +803,33 @@ class PortTestCase(unittest.TestCase):
                 user_datas.append(p.user_data)
                 p = comp_self._add_output_port('port2', 2)
                 user_datas.append(p.user_data)
+                p = comp_self._add_output_port('port3', MyUserData())
+                user_datas.append(p.user_data)
 
         user_datas = []
+        objects_deleted = 0
 
-        self._create_comp(MySource)
-        self.assertEqual(user_datas, [None, 2])
+        comp = self._create_comp(MySource)
+        self.assertEqual(len(user_datas), 3)
+        self.assertIs(user_datas[0], None)
+        self.assertEqual(user_datas[1], 2)
+        self.assertIs(type(user_datas[2]), MyUserData)
+
+        # Verify that the user data gets freed.
+        self.assertEqual(objects_deleted, 0)
+        del user_datas
+        del comp
+        self.assertEqual(objects_deleted, 1)
 
     def test_filter_self_port_user_data(self):
-        class MyIter(bt2._UserMessageIterator):
-            def __next__(self):
-                raise bt2.Stop
+        class MyUserData:
+            def __del__(self):
+                nonlocal objects_deleted
+                objects_deleted += 1
 
-        class MyFilter(bt2._UserFilterComponent, message_iterator_class=MyIter):
+        class MyFilter(
+            bt2._UserFilterComponent, message_iterator_class=bt2._UserMessageIterator
+        ):
             def __init__(comp_self, config, params, obj):
                 nonlocal user_datas
 
@@ -801,36 +837,65 @@ class PortTestCase(unittest.TestCase):
                 user_datas.append(p.user_data)
                 p = comp_self._add_output_port('port2', 'user data string')
                 user_datas.append(p.user_data)
+                p = comp_self._add_output_port('port3', MyUserData())
+                user_datas.append(p.user_data)
 
-                p = comp_self._add_input_port('port3')
+                p = comp_self._add_input_port('port4')
                 user_datas.append(p.user_data)
-                p = comp_self._add_input_port('port4', user_data={'user data': 'dict'})
+                p = comp_self._add_input_port('port5', user_data={'user data': 'dict'})
+                user_datas.append(p.user_data)
+                p = comp_self._add_input_port('port6', MyUserData())
                 user_datas.append(p.user_data)
 
         user_datas = []
+        objects_deleted = 0
 
-        self._create_comp(MyFilter)
-        self.assertEqual(
-            user_datas, [None, 'user data string', None, {'user data': 'dict'}]
-        )
+        comp = self._create_comp(MyFilter)
+        self.assertEqual(len(user_datas), 6)
+        self.assertIs(user_datas[0], None)
+        self.assertEqual(user_datas[1], 'user data string')
+        self.assertIs(type(user_datas[2]), MyUserData)
+        self.assertIs(user_datas[3], None)
+        self.assertEqual(user_datas[4], {'user data': 'dict'})
+        self.assertIs(type(user_datas[5]), MyUserData)
+
+        # Verify that the user data gets freed.
+        self.assertEqual(objects_deleted, 0)
+        del user_datas
+        del comp
+        self.assertEqual(objects_deleted, 2)
 
     def test_sink_self_port_user_data(self):
+        class MyUserData:
+            def __del__(self):
+                nonlocal objects_deleted
+                objects_deleted += 1
+
         class MySink(bt2._UserSinkComponent):
             def __init__(comp_self, config, params, obj):
                 nonlocal user_datas
 
                 p = comp_self._add_input_port('port1')
                 user_datas.append(p.user_data)
-                p = comp_self._add_input_port('port2', set())
+                p = comp_self._add_input_port('port2', MyUserData())
                 user_datas.append(p.user_data)
 
             def _user_consume(self):
                 pass
 
         user_datas = []
+        objects_deleted = 0
 
-        self._create_comp(MySink)
-        self.assertEqual(user_datas, [None, set()])
+        comp = self._create_comp(MySink)
+        self.assertEqual(len(user_datas), 2)
+        self.assertIs(user_datas[0], None)
+        self.assertIs(type(user_datas[1]), MyUserData)
+
+        # Verify that the user data gets freed.
+        self.assertEqual(objects_deleted, 0)
+        del user_datas
+        del comp
+        self.assertEqual(objects_deleted, 1)
 
 
 if __name__ == '__main__':
This page took 0.031613 seconds and 4 git commands to generate.