X-Git-Url: http://git.efficios.com/?a=blobdiff_plain;ds=sidebyside;f=src%2Fbindings%2Fpython%2Fbt2%2Fbt2%2Ftrace_collection_message_iterator.py;h=56b48da9ff8061fe4a96e70870495ec711c313a2;hb=f3c9a159782f70dbd0e5dedb37e4a1ef8a6d304e;hp=9c931fe2d4be568a5362cf1c020be87f8e76e00b;hpb=a83410cd12c9acfe79443782e7a5311a6dd6ca59;p=babeltrace.git diff --git a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py index 9c931fe2..56b48da9 100644 --- a/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py +++ b/src/bindings/python/bt2/bt2/trace_collection_message_iterator.py @@ -20,13 +20,15 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from bt2 import utils +from bt2 import utils, native_bt import bt2 import itertools from bt2 import message_iterator as bt2_message_iterator from bt2 import logging as bt2_logging from bt2 import port as bt2_port from bt2 import component as bt2_component +from bt2 import value as bt2_value +from bt2 import plugin as bt2_plugin import datetime from collections import namedtuple import numbers @@ -36,7 +38,29 @@ import numbers _ComponentAndSpec = namedtuple('_ComponentAndSpec', ['comp', 'spec']) -class ComponentSpec: +class _BaseComponentSpec: + def __init__(self, params, obj, logging_level): + if logging_level is not None: + utils._check_log_level(logging_level) + + self._params = bt2.create_value(params) + self._obj = obj + self._logging_level = logging_level + + @property + def params(self): + return self._params + + @property + def obj(self): + return self._obj + + @property + def logging_level(self): + return self._logging_level + + +class ComponentSpec(_BaseComponentSpec): def __init__( self, plugin_name, @@ -45,18 +69,16 @@ class ComponentSpec: obj=None, logging_level=bt2_logging.LoggingLevel.NONE, ): + if type(params) is str: + params = {'inputs': [params]} + + super().__init__(params, obj, logging_level) + utils._check_str(plugin_name) utils._check_str(class_name) - utils._check_log_level(logging_level) + self._plugin_name = plugin_name self._class_name = class_name - self._logging_level = logging_level - self._obj = obj - - if type(params) is str: - self._params = bt2.create_value({'inputs': [params]}) - else: - self._params = bt2.create_value(params) @property def plugin_name(self): @@ -66,17 +88,102 @@ class ComponentSpec: def class_name(self): return self._class_name - @property - def logging_level(self): - return self._logging_level - @property - def params(self): - return self._params +class AutoSourceComponentSpec(_BaseComponentSpec): + _no_obj = object() + + def __init__(self, input, params=None, obj=_no_obj, logging_level=None): + super().__init__(params, obj, logging_level) + self._input = input @property - def obj(self): - return self._obj + def input(self): + return self._input + + +def _auto_discover_source_component_specs(auto_source_comp_specs, plugin_set): + # Transform a list of `AutoSourceComponentSpec` in a list of `ComponentSpec` + # using the automatic source discovery mechanism. + inputs = bt2.ArrayValue([spec.input for spec in auto_source_comp_specs]) + + if plugin_set is None: + plugin_set = bt2.find_plugins() + else: + utils._check_type(plugin_set, bt2_plugin._PluginSet) + + res_ptr = native_bt.bt2_auto_discover_source_components( + inputs._ptr, plugin_set._ptr + ) + + if res_ptr is None: + raise bt2._MemoryError('cannot auto discover source components') + + res = bt2_value._create_from_ptr(res_ptr) + + assert type(res) == bt2.MapValue + assert 'status' in res + + status = res['status'] + utils._handle_func_status(status, 'cannot auto-discover source components') + + comp_specs = [] + comp_specs_raw = res['results'] + assert type(comp_specs_raw) == bt2.ArrayValue + + for comp_spec_raw in comp_specs_raw: + assert type(comp_spec_raw) == bt2.ArrayValue + assert len(comp_spec_raw) == 4 + + plugin_name = comp_spec_raw[0] + assert type(plugin_name) == bt2.StringValue + plugin_name = str(plugin_name) + + class_name = comp_spec_raw[1] + assert type(class_name) == bt2.StringValue + class_name = str(class_name) + + comp_inputs = comp_spec_raw[2] + assert type(comp_inputs) == bt2.ArrayValue + + comp_orig_indices = comp_spec_raw[3] + assert type(comp_orig_indices) + + params = bt2.MapValue() + logging_level = bt2.LoggingLevel.NONE + obj = None + + # Compute `params` for this component by piling up params given to all + # AutoSourceComponentSpec objects that contributed in the instantiation + # of this component. + # + # The effective log level for a component is the last one specified + # across the AutoSourceComponentSpec that contributed in its + # instantiation. + for idx in comp_orig_indices: + orig_spec = auto_source_comp_specs[idx] + + if orig_spec.params is not None: + params.update(orig_spec.params) + + if orig_spec.logging_level is not None: + logging_level = orig_spec.logging_level + + if orig_spec.obj is not AutoSourceComponentSpec._no_obj: + obj = orig_spec.obj + + params['inputs'] = comp_inputs + + comp_specs.append( + ComponentSpec( + plugin_name, + class_name, + params=params, + obj=obj, + logging_level=logging_level, + ) + ) + + return comp_specs # datetime.datetime or integral to nanoseconds @@ -127,6 +234,7 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): stream_intersection_mode=False, begin=None, end=None, + plugin_set=None, ): utils._check_bool(stream_intersection_mode) self._stream_intersection_mode = stream_intersection_mode @@ -134,15 +242,46 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._end_ns = _get_ns(end) self._msg_list = [None] - if type(source_component_specs) is ComponentSpec: + # If a single item is provided, convert to a list. + if type(source_component_specs) in ( + ComponentSpec, + AutoSourceComponentSpec, + str, + ): source_component_specs = [source_component_specs] + # Convert any string to an AutoSourceComponentSpec. + def str_to_auto(item): + if type(item) is str: + item = AutoSourceComponentSpec(item) + + return item + + source_component_specs = [str_to_auto(s) for s in source_component_specs] + if type(filter_component_specs) is ComponentSpec: filter_component_specs = [filter_component_specs] elif filter_component_specs is None: filter_component_specs = [] - self._src_comp_specs = source_component_specs + self._validate_source_component_specs(source_component_specs) + self._validate_filter_component_specs(filter_component_specs) + + # Pass any `ComponentSpec` instance as-is. + self._src_comp_specs = [ + spec for spec in source_component_specs if type(spec) is ComponentSpec + ] + + # Convert any `AutoSourceComponentSpec` in concrete `ComponentSpec` instances. + auto_src_comp_specs = [ + spec + for spec in source_component_specs + if type(spec) is AutoSourceComponentSpec + ] + self._src_comp_specs += _auto_discover_source_component_specs( + auto_src_comp_specs, plugin_set + ) + self._flt_comp_specs = filter_component_specs self._next_suffix = 1 self._connect_ports = False @@ -151,11 +290,21 @@ class TraceCollectionMessageIterator(bt2_message_iterator._MessageIterator): self._src_comps_and_specs = [] self._flt_comps_and_specs = [] - self._validate_component_specs(source_component_specs) - self._validate_component_specs(filter_component_specs) self._build_graph() - def _validate_component_specs(self, comp_specs): + def _validate_source_component_specs(self, comp_specs): + for comp_spec in comp_specs: + if ( + type(comp_spec) is not ComponentSpec + and type(comp_spec) is not AutoSourceComponentSpec + ): + raise TypeError( + '"{}" object is not a ComponentSpec or AutoSourceComponentSpec'.format( + type(comp_spec) + ) + ) + + def _validate_filter_component_specs(self, comp_specs): for comp_spec in comp_specs: if type(comp_spec) is not ComponentSpec: raise TypeError(