From a373e036294d14b0abb8394d2636b04216e8ac89 Mon Sep 17 00:00:00 2001 From: Simon Marchi Date: Mon, 28 Oct 2019 17:00:11 -0400 Subject: [PATCH] bt2: validate parameters to _TraceClass.create_stream_class before creating the native object When creating a stream class (in _TraceClass.create_stream_class), we start by creating the native object, then assign each property passed as arguments to create_stream_class. If the value of a parameter is invalid (e.g. wrong type), we will raise an exception. The problem is that at this point, the stream class object has already been created and added to its parent, the trace class. This leaves the user in an uncomfortable position, where a stream class has been created, with the parameters only partially applied, while an exception was raised. Instead, the stream class native object should be created only once we know that all the parameters are valid. This patch makes it so we validate all parameters before calling the native creation function. Given that the setters in _StreamClass are internal and only used in _TraceClass.create_stream_class, we don't need to validate the values of the parameters again in them. I have added tests for two conditions that weren't tested, when passing a wrong parameter type to assigns_automatic_event_class_id and assigns_automatic_stream_id. Otherwise, I added assertions to the test to make sure that when the stream class Python object couldn't be created, no stream class object was added as a child of the trace class. Change-Id: I18cbb2e8128cf49e6a6411a225352f279aec5d02 Signed-off-by: Simon Marchi Reviewed-on: https://review.lttng.org/c/babeltrace/+/2279 Tested-by: jenkins --- src/bindings/python/bt2/bt2/stream_class.py | 178 +++++++++++------- src/bindings/python/bt2/bt2/trace_class.py | 17 ++ .../bindings/python/bt2/test_stream_class.py | 62 +++++- 3 files changed, 184 insertions(+), 73 deletions(-) diff --git a/src/bindings/python/bt2/bt2/stream_class.py b/src/bindings/python/bt2/bt2/stream_class.py index d4baa89a..ecab1df8 100644 --- a/src/bindings/python/bt2/bt2/stream_class.py +++ b/src/bindings/python/bt2/bt2/stream_class.py @@ -259,67 +259,32 @@ class _StreamClass(_StreamClassConst): def _user_attributes(self, user_attributes): value = bt2_value.create_value(user_attributes) - utils._check_type(value, bt2_value.MapValue) native_bt.stream_class_set_user_attributes(self._ptr, value._ptr) _user_attributes = property(fset=_user_attributes) def _name(self, name): - utils._check_str(name) status = native_bt.stream_class_set_name(self._ptr, name) utils._handle_func_status(status, "cannot set stream class object's name") _name = property(fset=_name) def _assigns_automatic_event_class_id(self, auto_id): - utils._check_bool(auto_id) - return native_bt.stream_class_set_assigns_automatic_event_class_id( - self._ptr, auto_id - ) + native_bt.stream_class_set_assigns_automatic_event_class_id(self._ptr, auto_id) _assigns_automatic_event_class_id = property(fset=_assigns_automatic_event_class_id) def _assigns_automatic_stream_id(self, auto_id): - utils._check_bool(auto_id) - return native_bt.stream_class_set_assigns_automatic_stream_id( - self._ptr, auto_id - ) + native_bt.stream_class_set_assigns_automatic_stream_id(self._ptr, auto_id) _assigns_automatic_stream_id = property(fset=_assigns_automatic_stream_id) def _set_supports_packets(self, supports, with_begin_cs=False, with_end_cs=False): - utils._check_bool(supports) - utils._check_bool(with_begin_cs) - utils._check_bool(with_end_cs) - - if not supports: - if with_begin_cs: - raise ValueError( - 'cannot not support packets, but have packet beginning default clock snapshot' - ) - if with_end_cs: - raise ValueError( - 'cannot not support packets, but have packet end default clock snapshots' - ) - - if not supports and self.packet_context_field_class is not None: - raise ValueError( - 'cannot disable support for packets, stream class already has a packet context field class' - ) - native_bt.stream_class_set_supports_packets( self._ptr, supports, with_begin_cs, with_end_cs ) def _set_supports_discarded_events(self, supports, with_cs=False): - utils._check_bool(supports) - utils._check_bool(with_cs) - - if not supports and with_cs: - raise ValueError( - 'cannot not support discarded events, but have default clock snapshots for discarded event messages' - ) - native_bt.stream_class_set_supports_discarded_events( self._ptr, supports, with_cs ) @@ -327,19 +292,6 @@ class _StreamClass(_StreamClassConst): _supports_discarded_events = property(fset=_set_supports_discarded_events) def _set_supports_discarded_packets(self, supports, with_cs): - utils._check_bool(supports) - utils._check_bool(with_cs) - - if supports and not self.supports_packets: - raise ValueError( - 'cannot support discarded packets, but not support packets' - ) - - if not supports and with_cs: - raise ValueError( - 'cannot not support discarded packets, but have default clock snapshots for discarded packet messages' - ) - native_bt.stream_class_set_supports_discarded_packets( self._ptr, supports, with_cs ) @@ -347,41 +299,123 @@ class _StreamClass(_StreamClassConst): _supports_discarded_packets = property(fset=_set_supports_discarded_packets) def _packet_context_field_class(self, packet_context_field_class): - if packet_context_field_class is not None: - utils._check_type( - packet_context_field_class, bt2_field_class._StructureFieldClass - ) + status = native_bt.stream_class_set_packet_context_field_class( + self._ptr, packet_context_field_class._ptr + ) + utils._handle_func_status( + status, "cannot set stream class' packet context field class" + ) + + _packet_context_field_class = property(fset=_packet_context_field_class) + + def _event_common_context_field_class(self, event_common_context_field_class): + set_context_fn = native_bt.stream_class_set_event_common_context_field_class + status = set_context_fn(self._ptr, event_common_context_field_class._ptr) + utils._handle_func_status( + status, "cannot set stream class' event context field type" + ) - if not self.supports_packets: + _event_common_context_field_class = property(fset=_event_common_context_field_class) + + def _default_clock_class(self, clock_class): + native_bt.stream_class_set_default_clock_class(self._ptr, clock_class._ptr) + + _default_clock_class = property(fset=_default_clock_class) + + @classmethod + def _validate_create_params( + cls, + name, + user_attributes, + packet_context_field_class, + event_common_context_field_class, + default_clock_class, + assigns_automatic_event_class_id, + assigns_automatic_stream_id, + supports_packets, + packets_have_beginning_default_clock_snapshot, + packets_have_end_default_clock_snapshot, + supports_discarded_events, + discarded_events_have_default_clock_snapshots, + supports_discarded_packets, + discarded_packets_have_default_clock_snapshots, + ): + # Name + if name is not None: + utils._check_str(name) + + # User attributes + if user_attributes is not None: + value = bt2_value.create_value(user_attributes) + utils._check_type(value, bt2_value.MapValue) + + # Packet context field class + if packet_context_field_class is not None: + if not supports_packets: raise ValueError( 'cannot have a packet context field class without supporting packets' ) - status = native_bt.stream_class_set_packet_context_field_class( - self._ptr, packet_context_field_class._ptr - ) - utils._handle_func_status( - status, "cannot set stream class' packet context field class" + utils._check_type( + packet_context_field_class, bt2_field_class._StructureFieldClass ) - _packet_context_field_class = property(fset=_packet_context_field_class) - - def _event_common_context_field_class(self, event_common_context_field_class): + # Event common context field class if event_common_context_field_class is not None: utils._check_type( event_common_context_field_class, bt2_field_class._StructureFieldClass ) - set_context_fn = native_bt.stream_class_set_event_common_context_field_class - status = set_context_fn(self._ptr, event_common_context_field_class._ptr) - utils._handle_func_status( - status, "cannot set stream class' event context field type" + # Default clock class + if default_clock_class is not None: + utils._check_type(default_clock_class, bt2_clock_class._ClockClass) + + # Assigns automatic event class id + utils._check_bool(assigns_automatic_event_class_id) + + # Assigns automatic stream id + utils._check_bool(assigns_automatic_stream_id) + + # Packets + utils._check_bool(supports_packets) + utils._check_bool(packets_have_beginning_default_clock_snapshot) + utils._check_bool(packets_have_end_default_clock_snapshot) + + if not supports_packets: + if packets_have_beginning_default_clock_snapshot: + raise ValueError( + 'cannot not support packets, but have packet beginning default clock snapshot' + ) + if packets_have_end_default_clock_snapshot: + raise ValueError( + 'cannot not support packets, but have packet end default clock snapshots' + ) + + # Discarded events + utils._check_bool(supports_discarded_events) + utils._check_bool(discarded_events_have_default_clock_snapshots) + + if ( + not supports_discarded_events + and discarded_events_have_default_clock_snapshots + ): + raise ValueError( + 'cannot not support discarded events, but have default clock snapshots for discarded event messages' ) - _event_common_context_field_class = property(fset=_event_common_context_field_class) + # Discarded packets + utils._check_bool(supports_discarded_packets) + utils._check_bool(discarded_packets_have_default_clock_snapshots) - def _default_clock_class(self, clock_class): - utils._check_type(clock_class, bt2_clock_class._ClockClass) - native_bt.stream_class_set_default_clock_class(self._ptr, clock_class._ptr) + if supports_discarded_packets and not supports_packets: + raise ValueError( + 'cannot support discarded packets, but not support packets' + ) - _default_clock_class = property(fset=_default_clock_class) + if ( + not supports_discarded_packets + and discarded_packets_have_default_clock_snapshots + ): + raise ValueError( + 'cannot not support discarded packets, but have default clock snapshots for discarded packet messages' + ) diff --git a/src/bindings/python/bt2/bt2/trace_class.py b/src/bindings/python/bt2/bt2/trace_class.py index 1d05dba9..08d3e91f 100644 --- a/src/bindings/python/bt2/bt2/trace_class.py +++ b/src/bindings/python/bt2/bt2/trace_class.py @@ -190,6 +190,23 @@ class _TraceClass(_TraceClassConst): supports_discarded_packets=False, discarded_packets_have_default_clock_snapshots=False, ): + # Validate parameters before we create the object. + bt2_stream_class._StreamClass._validate_create_params( + name, + user_attributes, + packet_context_field_class, + event_common_context_field_class, + default_clock_class, + assigns_automatic_event_class_id, + assigns_automatic_stream_id, + supports_packets, + packets_have_beginning_default_clock_snapshot, + packets_have_end_default_clock_snapshot, + supports_discarded_events, + discarded_events_have_default_clock_snapshots, + supports_discarded_packets, + discarded_packets_have_default_clock_snapshots, + ) if self.assigns_automatic_stream_class_id: if id is not None: diff --git a/tests/bindings/python/bt2/test_stream_class.py b/tests/bindings/python/bt2/test_stream_class.py index be73d46e..09537ee8 100644 --- a/tests/bindings/python/bt2/test_stream_class.py +++ b/tests/bindings/python/bt2/test_stream_class.py @@ -62,6 +62,8 @@ class StreamClassTestCase(unittest.TestCase): with self.assertRaisesRegex(TypeError, "'int' is not a 'str' object"): self._tc.create_stream_class(name=17) + self.assertEqual(len(self._tc), 0) + def test_create_packet_context_field_class(self): fc = self._tc.create_structure_field_class() sc = self._tc.create_stream_class( @@ -77,7 +79,11 @@ class StreamClassTestCase(unittest.TestCase): TypeError, "'int' is not a '' object", ): - self._tc.create_stream_class(packet_context_field_class=22) + self._tc.create_stream_class( + packet_context_field_class=22, supports_packets=True + ) + + self.assertEqual(len(self._tc), 0) def test_create_invalid_packet_context_field_class_no_packets(self): fc = self._tc.create_structure_field_class() @@ -88,6 +94,8 @@ class StreamClassTestCase(unittest.TestCase): ): self._tc.create_stream_class(packet_context_field_class=fc) + self.assertEqual(len(self._tc), 0) + def test_create_event_common_context_field_class(self): fc = self._tc.create_structure_field_class() sc = self._tc.create_stream_class(event_common_context_field_class=fc) @@ -104,6 +112,8 @@ class StreamClassTestCase(unittest.TestCase): ): self._tc.create_stream_class(event_common_context_field_class=22) + self.assertEqual(len(self._tc), 0) + def test_create_default_clock_class(self): sc = self._tc.create_stream_class(default_clock_class=self._cc) self.assertEqual(sc.default_clock_class.addr, self._cc.addr) @@ -115,6 +125,8 @@ class StreamClassTestCase(unittest.TestCase): ): self._tc.create_stream_class(default_clock_class=12) + self.assertEqual(len(self._tc), 0) + def test_create_user_attributes(self): sc = self._tc.create_stream_class(user_attributes={'salut': 23}) self.assertEqual(sc.user_attributes, {'salut': 23}) @@ -125,6 +137,8 @@ class StreamClassTestCase(unittest.TestCase): ): self._tc.create_stream_class(user_attributes=object()) + self.assertEqual(len(self._tc), 0) + def test_create_invalid_user_attributes_value_type(self): with self.assertRaisesRegex( TypeError, @@ -132,6 +146,8 @@ class StreamClassTestCase(unittest.TestCase): ): self._tc.create_stream_class(user_attributes=23) + self.assertEqual(len(self._tc), 0) + def test_automatic_stream_ids(self): sc = self._tc.create_stream_class(assigns_automatic_stream_id=True) self.assertTrue(sc.assigns_automatic_stream_id) @@ -148,6 +164,14 @@ class StreamClassTestCase(unittest.TestCase): ): self._trace.create_stream(sc, id=123) + self.assertEqual(len(self._trace), 0) + + def test_automatic_stream_ids_wrong_type(self): + with self.assertRaisesRegex(TypeError, "str' is not a 'bool' object"): + self._tc.create_stream_class(assigns_automatic_stream_id='True') + + self.assertEqual(len(self._tc), 0) + def test_no_automatic_stream_ids(self): sc = self._tc.create_stream_class(assigns_automatic_stream_id=False) self.assertFalse(sc.assigns_automatic_stream_id) @@ -165,6 +189,8 @@ class StreamClassTestCase(unittest.TestCase): ): self._trace.create_stream(sc) + self.assertEqual(len(self._trace), 0) + def test_automatic_event_class_ids(self): sc = self._tc.create_stream_class(assigns_automatic_event_class_id=True) self.assertTrue(sc.assigns_automatic_event_class_id) @@ -182,6 +208,14 @@ class StreamClassTestCase(unittest.TestCase): ): sc.create_event_class(id=123) + self.assertEqual(len(sc), 0) + + def test_automatic_event_class_ids_wrong_type(self): + with self.assertRaisesRegex(TypeError, "'str' is not a 'bool' object"): + self._tc.create_stream_class(assigns_automatic_event_class_id='True') + + self.assertEqual(len(self._tc), 0) + def test_no_automatic_event_class_ids(self): sc = self._tc.create_stream_class(assigns_automatic_event_class_id=False) self.assertFalse(sc.assigns_automatic_event_class_id) @@ -199,6 +233,8 @@ class StreamClassTestCase(unittest.TestCase): ): sc.create_event_class() + self.assertEqual(len(sc), 0) + def test_supports_packets_without_cs(self): sc = self._tc.create_stream_class( default_clock_class=self._cc, supports_packets=True @@ -233,6 +269,8 @@ class StreamClassTestCase(unittest.TestCase): default_clock_class=self._cc, supports_packets=23 ) + self.assertEqual(len(self._tc), 0) + def test_packets_have_begin_default_cs_raises_type_error(self): with self.assertRaisesRegex(TypeError, "'int' is not a 'bool' object"): self._tc.create_stream_class( @@ -240,12 +278,16 @@ class StreamClassTestCase(unittest.TestCase): packets_have_beginning_default_clock_snapshot=23, ) + self.assertEqual(len(self._tc), 0) + def test_packets_have_end_default_cs_raises_type_error(self): with self.assertRaisesRegex(TypeError, "'int' is not a 'bool' object"): self._tc.create_stream_class( default_clock_class=self._cc, packets_have_end_default_clock_snapshot=23 ) + self.assertEqual(len(self._tc), 0) + def test_does_not_support_packets_raises_with_begin_cs(self): with self.assertRaisesRegex( ValueError, @@ -256,6 +298,8 @@ class StreamClassTestCase(unittest.TestCase): packets_have_beginning_default_clock_snapshot=True, ) + self.assertEqual(len(self._tc), 0) + def test_does_not_support_packets_raises_with_end_cs(self): with self.assertRaisesRegex( ValueError, @@ -266,6 +310,8 @@ class StreamClassTestCase(unittest.TestCase): packets_have_end_default_clock_snapshot=True, ) + self.assertEqual(len(self._tc), 0) + def test_supports_discarded_events_without_cs(self): sc = self._tc.create_stream_class( default_clock_class=self._cc, supports_discarded_events=True @@ -288,6 +334,8 @@ class StreamClassTestCase(unittest.TestCase): default_clock_class=self._cc, supports_discarded_events=23 ) + self.assertEqual(len(self._tc), 0) + def test_discarded_events_have_default_cs_raises_type_error(self): with self.assertRaisesRegex(TypeError, "'int' is not a 'bool' object"): self._tc.create_stream_class( @@ -295,6 +343,8 @@ class StreamClassTestCase(unittest.TestCase): discarded_events_have_default_clock_snapshots=23, ) + self.assertEqual(len(self._tc), 0) + def test_does_not_support_discarded_events_raises_with_cs(self): with self.assertRaisesRegex( ValueError, @@ -305,6 +355,8 @@ class StreamClassTestCase(unittest.TestCase): discarded_events_have_default_clock_snapshots=True, ) + self.assertEqual(len(self._tc), 0) + def test_supports_discarded_packets_without_cs(self): sc = self._tc.create_stream_class( default_clock_class=self._cc, @@ -332,6 +384,8 @@ class StreamClassTestCase(unittest.TestCase): default_clock_class=self._cc, supports_discarded_packets=True ) + self.assertEqual(len(self._tc), 0) + def test_supports_discarded_packets_raises_type_error(self): with self.assertRaisesRegex(TypeError, "'int' is not a 'bool' object"): self._tc.create_stream_class( @@ -340,6 +394,8 @@ class StreamClassTestCase(unittest.TestCase): supports_packets=True, ) + self.assertEqual(len(self._tc), 0) + def test_discarded_packets_have_default_cs_raises_type_error(self): with self.assertRaisesRegex(TypeError, "'int' is not a 'bool' object"): self._tc.create_stream_class( @@ -348,6 +404,8 @@ class StreamClassTestCase(unittest.TestCase): supports_packets=True, ) + self.assertEqual(len(self._tc), 0) + def test_does_not_support_discarded_packets_raises_with_cs(self): with self.assertRaisesRegex( ValueError, @@ -359,6 +417,8 @@ class StreamClassTestCase(unittest.TestCase): supports_packets=True, ) + self.assertEqual(len(self._tc), 0) + def test_trace_class(self): sc = self._tc.create_stream_class() self.assertEqual(sc.trace_class.addr, self._tc.addr) -- 2.34.1