From 1624d186a967ccd3285a5fcb90ab700e8ffc784c Mon Sep 17 00:00:00 2001 From: Philippe Proulx Date: Thu, 3 Sep 2020 13:16:42 -0400 Subject: [PATCH] cgen.py: add type hints Signed-off-by: Philippe Proulx --- barectf/cgen.py | 163 ++++++++++++++++++++++++++------------------ barectf/template.py | 4 +- 2 files changed, 99 insertions(+), 68 deletions(-) diff --git a/barectf/cgen.py b/barectf/cgen.py index ebc85e1..6c9cf4a 100644 --- a/barectf/cgen.py +++ b/barectf/cgen.py @@ -25,6 +25,9 @@ import barectf.template as barectf_template import barectf.config as barectf_config import collections import copy +from typing import List, Optional, Mapping, Callable, Any, Set, Tuple +import typing +from barectf.typing import Count, Alignment # A tuple containing serialization and size computation function @@ -46,7 +49,8 @@ _OpTemplates = collections.namedtuple('_OpTemplates', ['serialize', 'size']) # * Serialization and size computation templates to generate the # operation's source code for those functions. class _Op: - def __init__(self, offset_in_byte, ft, names, templates): + def __init__(self, offset_in_byte: Count, ft: barectf_config._FieldType, names: List[str], + templates: _OpTemplates): assert(offset_in_byte >= 0 and offset_in_byte < 8) self._offset_in_byte = offset_in_byte self._ft = ft @@ -54,40 +58,41 @@ class _Op: self._templates = templates @property - def offset_in_byte(self): + def offset_in_byte(self) -> Count: return self._offset_in_byte @property - def ft(self): + def ft(self) -> barectf_config._FieldType: return self._ft @property - def names(self): + def names(self) -> List[str]: return self._names @property - def top_name(self): + def top_name(self) -> str: return self._names[-1] - def _render_template(self, templ, **kwargs): + def _render_template(self, templ: barectf_template._Template, **kwargs) -> str: return templ.render(op=self, root_ft_prefixes=_RootFtPrefixes, root_ft_prefix_names=_ROOT_FT_PREFIX_NAMES, **kwargs) - def serialize_str(self, **kwargs): + def serialize_str(self, **kwargs) -> str: return self._render_template(self._templates.serialize, **kwargs) - def size_str(self, **kwargs): + def size_str(self, **kwargs) -> str: return self._render_template(self._templates.size, **kwargs) # An "align" operation. class _AlignOp(_Op): - def __init__(self, offset_in_byte, ft, names, templates, value): + def __init__(self, offset_in_byte: Count, ft: barectf_config._FieldType, names: List[str], + templates: _OpTemplates, value: Alignment): super().__init__(offset_in_byte, ft, names, templates) self._value = value @property - def value(self): + def value(self) -> Alignment: return self._value @@ -96,6 +101,10 @@ class _WriteOp(_Op): pass +_SpecSerializeWriteTemplates = Mapping[str, barectf_template._Template] +_Ops = List[_Op] + + # A builder of a chain of operations. # # Such a builder is closely connected to a `_CodeGen` object using it to @@ -106,16 +115,16 @@ class _WriteOp(_Op): # # Get an operation builder's operations with its `ops` property. class _OpsBuilder: - def __init__(self, cg): - self._last_alignment = None - self._last_bit_array_size = None - self._ops = [] - self._names = [] - self._offset_in_byte = 0 + def __init__(self, cg: _CodeGen): + self._last_alignment: Optional[Alignment] = None + self._last_bit_array_size: Optional[Count] = None + self._ops: _Ops = [] + self._names: List[str] = [] + self._offset_in_byte = Count(0) self._cg = cg @property - def ops(self): + def ops(self) -> _Ops: return self._ops # Creates and appends the operations for the members, recursively, @@ -123,7 +132,8 @@ class _OpsBuilder: # # `spec_serialize_write_templates` is a mapping of first level # member names to specialized serialization "write" templates. - def append_root_ft(self, ft, name, spec_serialize_write_templates=None): + def append_root_ft(self, ft: barectf_config._FieldType, name: str, + spec_serialize_write_templates: Optional[_SpecSerializeWriteTemplates] = None): if ft is None: return @@ -138,23 +148,23 @@ class _OpsBuilder: # named `name`. # # See append_root_ft() for `spec_serialize_write_templates`. - def _append_ft(self, ft, name, spec_serialize_write_templates): - def top_name(): + def _append_ft(self, ft: barectf_config._FieldType, name: str, + spec_serialize_write_templates: _SpecSerializeWriteTemplates): + def top_name() -> str: return self._names[-1] # Appends a "write" operation for the field type `ft`. # # This function considers `spec_serialize_write_templates` to # override generic templates. - def append_write_op(ft): + def append_write_op(ft: barectf_config._FieldType): assert type(ft) is not barectf_config.StructureFieldType offset_in_byte = self._offset_in_byte if isinstance(ft, barectf_config._BitArrayFieldType): - self._offset_in_byte += ft.size - self._offset_in_byte %= 8 + self._offset_in_byte = Count((self._offset_in_byte + ft.size) % 8) - serialize_write_templ = None + serialize_write_templ: Optional[barectf_template._Template] = None if len(self._names) == 2: serialize_write_templ = spec_serialize_write_templates.get(top_name()) @@ -182,12 +192,13 @@ class _OpsBuilder: # `ft` if needed. # # This function updates the builder's state. - def try_append_align_op(alignment, do_align, ft): - def align(v, alignment): - return (v + (alignment - 1)) & -alignment + def try_append_align_op(alignment: Alignment, do_align: bool, + ft: barectf_config._FieldType): + def align(v: Count, alignment: Alignment) -> Count: + return Count((v + (alignment - 1)) & -alignment) offset_in_byte = self._offset_in_byte - self._offset_in_byte = align(self._offset_in_byte, alignment) % 8 + self._offset_in_byte = Count(align(self._offset_in_byte, alignment) % 8) if do_align and alignment > 1: self._ops.append(_AlignOp(offset_in_byte, ft, self._names, @@ -198,8 +209,8 @@ class _OpsBuilder: # Returns whether or not, considering the alignment requirement # `align_req` and the builder's current state, we must create # and append an "align" operation. - def must_align(align_req): - return self._last_alignment != align_req or self._last_bit_array_size % align_req != 0 + def must_align(align_req: Alignment) -> bool: + return self._last_alignment != align_req or typing.cast(Count, self._last_bit_array_size) % align_req != 0 # push field type's name to the builder's name stack initially self._names.append(name) @@ -208,23 +219,26 @@ class _OpsBuilder: assert type(ft) is barectf_config.StringFieldType or top_name() == 'uuid' # strings and arrays are always byte-aligned - do_align = must_align(8) - self._last_alignment = 8 - self._last_bit_array_size = 8 - try_append_align_op(8, do_align, ft) + do_align = must_align(Alignment(8)) + self._last_alignment = Alignment(8) + self._last_bit_array_size = Count(8) + try_append_align_op(Alignment(8), do_align, ft) append_write_op(ft) else: do_align = must_align(ft.alignment) self._last_alignment = ft.alignment if type(ft) is barectf_config.StructureFieldType: - self._last_bit_array_size = ft.alignment + self._last_bit_array_size = typing.cast(Count, ft.alignment) else: + assert isinstance(ft, barectf_config._BitArrayFieldType) + ft = typing.cast(barectf_config._BitArrayFieldType, ft) self._last_bit_array_size = ft.size try_append_align_op(ft.alignment, do_align, ft) if type(ft) is barectf_config.StructureFieldType: + ft = typing.cast(barectf_config.StructureFieldType, ft) for member_name, member in ft.members.items(): self._append_ft(member.field_type, member_name, spec_serialize_write_templates) else: @@ -240,20 +254,23 @@ class _OpsBuilder: # # * Specific context operations. # * Payload operations. -class _EventOps: - def __init__(self, spec_ctx_ops, payload_ops): +class _EvOps: + def __init__(self, spec_ctx_ops: _Ops, payload_ops: _Ops): self._spec_ctx_ops = copy.copy(spec_ctx_ops) self._payload_ops = copy.copy(payload_ops) @property - def spec_ctx_ops(self): + def spec_ctx_ops(self) -> _Ops: return self._spec_ctx_ops @property - def payload_ops(self): + def payload_ops(self) -> _Ops: return self._payload_ops +_EvOpsMap = Mapping[barectf_config.EventType, _EvOps] + + # The operations for a stream. # # The available operations are: @@ -262,10 +279,10 @@ class _EventOps: # * Packet context operations. # * Event header operations. # * Event common context operations. -# * Event operations (`_EventOps`). +# * Event operations (`_EvOps`). class _StreamOps: - def __init__(self, pkt_header_ops, pkt_ctx_ops, ev_header_ops, - ev_common_ctx_ops, ev_ops): + def __init__(self, pkt_header_ops: _Ops, pkt_ctx_ops: _Ops, ev_header_ops: _Ops, + ev_common_ctx_ops: _Ops, ev_ops: _EvOpsMap): self._pkt_header_ops = copy.copy(pkt_header_ops) self._pkt_ctx_ops = copy.copy(pkt_ctx_ops) self._ev_header_ops = copy.copy(ev_header_ops) @@ -273,23 +290,23 @@ class _StreamOps: self._ev_ops = copy.copy(ev_ops) @property - def pkt_header_ops(self): + def pkt_header_ops(self) -> _Ops: return self._pkt_header_ops @property - def pkt_ctx_ops(self): + def pkt_ctx_ops(self) -> _Ops: return self._pkt_ctx_ops @property - def ev_header_ops(self): + def ev_header_ops(self) -> _Ops: return self._ev_header_ops @property - def ev_common_ctx_ops(self): + def ev_common_ctx_ops(self) -> _Ops: return self._ev_common_ctx_ops @property - def ev_ops(self): + def ev_ops(self) -> _EvOpsMap: return self._ev_ops @@ -326,11 +343,10 @@ _FtParam = collections.namedtuple('_FtParam', ['ft', 'name']) # * The public header (gen_header()). # * The source code (gen_src()). class _CodeGen: - def __init__(self, cfg): + def __init__(self, cfg: barectf_config.Configuration): self._cfg = cfg self._iden_prefix = cfg.options.code_generation_options.identifier_prefix - self._saved_serialization_ops = {} - self._templ_filters = { + self._templ_filters: Mapping[str, Callable[..., Any]] = { 'ft_c_type': self._ft_c_type, 'open_func_params_str': self._open_func_params_str, 'trace_func_params_str': self._trace_func_params_str, @@ -360,7 +376,8 @@ class _CodeGen: # # Such a template has the filters custom filters # `self._templ_filters`. - def _create_template_base(self, name: str, is_file_template: bool): + def _create_template_base(self, name: str, + is_file_template: bool) -> barectf_template._Template: return barectf_template._Template(f'c/{name}', is_file_template, self._cfg, self._templ_filters) @@ -378,15 +395,16 @@ class _CodeGen: # Trace type of this code generator's barectf configuration. @property - def _trace_type(self): + def _trace_type(self) -> barectf_config.TraceType: return self._cfg.trace.type # Returns the C type for the field type `ft`, returning a `const` C # type if `is_const` is `True`. - def _ft_c_type(self, ft, is_const=False): + def _ft_c_type(self, ft: barectf_config._FieldType, is_const: bool = False) -> str: const_beg_str = 'const ' if isinstance(ft, barectf_config._IntegerFieldType): + ft = typing.cast(barectf_config._IntegerFieldType, ft) sign_prefix = 'u' if isinstance(ft, barectf_config.UnsignedIntegerFieldType) else '' if ft.size <= 8: @@ -401,6 +419,8 @@ class _CodeGen: return f'{const_beg_str if is_const else ""}{sign_prefix}int{sz}_t' elif type(ft) is barectf_config.RealFieldType: + ft = typing.cast(barectf_config.RealFieldType, ft) + if ft.size == 32 and ft.alignment == 32: c_type = 'float' elif ft.size == 64 and ft.alignment == 64: @@ -419,10 +439,11 @@ class _CodeGen: # Each parameter has the prefix `name_prefix` followed with `_`. # # Members of which the name is in `exclude_set` are excluded. - def _proto_params_str(self, root_ft, name_prefix, const_params, exclude_set=None, - only_dyn=False): + def _proto_params_str(self, root_ft: Optional[barectf_config.StructureFieldType], + name_prefix: str, const_params: bool, + exclude_set: Optional[Set[str]] = None, only_dyn: bool = False) -> str: if root_ft is None: - return + return '' if exclude_set is None: exclude_set = set() @@ -443,7 +464,8 @@ class _CodeGen: # Returns the packet opening function prototype parameters for the # stream type `stream_type`. - def _open_func_params_str(self, stream_type, const_params): + def _open_func_params_str(self, stream_type: barectf_config.StreamType, + const_params: bool) -> str: parts = [] parts.append(self._proto_params_str(self._trace_type._pkt_header_ft, _RootFtPrefixes.PH, const_params, {'magic', 'stream_id', 'uuid'})) @@ -461,7 +483,9 @@ class _CodeGen: # Returns the tracing function prototype parameters for the stream # and event types `stream_ev_types`. - def _trace_func_params_str(self, stream_ev_types, const_params, only_dyn=False): + def _trace_func_params_str(self, stream_ev_types: Tuple[barectf_config.StreamType, + barectf_config.EventType], + const_params: bool, only_dyn: bool = False): stream_type = stream_ev_types[0] ev_type = stream_ev_types[1] parts = [] @@ -489,23 +513,24 @@ class _CodeGen: # Returns the event header serialization function prototype # parameters for the stream type `stream_type`. - def _serialize_ev_common_ctx_func_params_str(self, stream_type, const_params): + def _serialize_ev_common_ctx_func_params_str(self, stream_type: barectf_config.StreamType, + const_params: bool) -> str: return self._proto_params_str(stream_type.event_common_context_field_type, _RootFtPrefixes.ECC, const_params); # Generates the bitfield header file contents. - def gen_bitfield_header(self): + def gen_bitfield_header(self) -> str: return self._create_file_template('bitfield.h.j2').render() # Generates the public header file contents. - def gen_header(self): + def gen_header(self) -> str: return self._create_file_template('barectf.h.j2').render(root_ft_prefixes=_RootFtPrefixes) # Generates the source code file contents. - def gen_src(self, header_file_name, bitfield_header_file_name): + def gen_src(self, header_file_name: str, bitfield_header_file_name: str) -> str: # Creates and returns the operations for all the stream and for # all their events. - def create_stream_ops(): + def create_stream_ops() -> Mapping[barectf_config.StreamType, _StreamOps]: stream_ser_ops = {} for stream_type in self._trace_type.stream_types: @@ -582,7 +607,7 @@ class _CodeGen: ev_builder.append_root_ft(ev_type.payload_field_type, _RootFtPrefixes.P) payload_ser_ops = copy.copy(ev_builder.ops[first_op_index:]) - ev_ser_ops[ev_type] = _EventOps(spec_ctx_ser_ops, payload_ser_ops) + ev_ser_ops[ev_type] = _EvOps(spec_ctx_ser_ops, payload_ser_ops) stream_ser_ops[stream_type] = _StreamOps(pkt_header_ser_ops, pkt_ctx_ser_ops, ev_header_ser_ops, ev_common_ctx_ser_ops, @@ -592,10 +617,16 @@ class _CodeGen: # Returns the "write" operation for the packet context member # named `member_name` within the stream type `stream_type`. - def stream_op_pkt_ctx_op(stream_type, member_name): + def stream_op_pkt_ctx_op(stream_type: barectf_config.StreamType, member_name: str) -> _Op: + ret_op = None + for op in stream_ops[stream_type].pkt_ctx_ops: if op.top_name == member_name and type(op) is _WriteOp: - return op + ret_op = op + break + + assert ret_op is not None + return typing.cast(_Op, ret_op) stream_ops = create_stream_ops() return self._create_file_template('barectf.c.j2').render(header_file_name=header_file_name, diff --git a/barectf/template.py b/barectf/template.py index 27ade4f..0dd4788 100644 --- a/barectf/template.py +++ b/barectf/template.py @@ -49,9 +49,9 @@ def _filt_escape_dq(text: str) -> str: return text.replace('\\', '\\\\').replace('"', '\\"') -_Filter = Callable[[Any], Any] +_Filter = Callable[..., Any] _Filters = Mapping[str, _Filter] -_Test = Callable[[Any], bool] +_Test = Callable[..., bool] _Tests = Mapping[str, _Test] -- 2.34.1