cgen.py: add type hints
authorPhilippe Proulx <eeppeliteloop@gmail.com>
Thu, 3 Sep 2020 17:16:42 +0000 (13:16 -0400)
committerPhilippe Proulx <eeppeliteloop@gmail.com>
Thu, 3 Sep 2020 17:23:14 +0000 (13:23 -0400)
Signed-off-by: Philippe Proulx <eeppeliteloop@gmail.com>
barectf/cgen.py
barectf/template.py

index ebc85e131843ef5cf1bd8a1d30e404de6e1b4dc6..6c9cf4a05369c59cfa78398e895e8d4bcb358a5d 100644 (file)
@@ -25,6 +25,9 @@ import barectf.template as barectf_template
 import barectf.config as barectf_config
 import collections
 import copy
+from typing import List, Optional, Mapping, Callable, Any, Set, Tuple
+import typing
+from barectf.typing import Count, Alignment
 
 
 # A tuple containing serialization and size computation function
@@ -46,7 +49,8 @@ _OpTemplates = collections.namedtuple('_OpTemplates', ['serialize', 'size'])
 # * Serialization and size computation templates to generate the
 #   operation's source code for those functions.
 class _Op:
-    def __init__(self, offset_in_byte, ft, names, templates):
+    def __init__(self, offset_in_byte: Count, ft: barectf_config._FieldType, names: List[str],
+                 templates: _OpTemplates):
         assert(offset_in_byte >= 0 and offset_in_byte < 8)
         self._offset_in_byte = offset_in_byte
         self._ft = ft
@@ -54,40 +58,41 @@ class _Op:
         self._templates = templates
 
     @property
-    def offset_in_byte(self):
+    def offset_in_byte(self) -> Count:
         return self._offset_in_byte
 
     @property
-    def ft(self):
+    def ft(self) -> barectf_config._FieldType:
         return self._ft
 
     @property
-    def names(self):
+    def names(self) -> List[str]:
         return self._names
 
     @property
-    def top_name(self):
+    def top_name(self) -> str:
         return self._names[-1]
 
-    def _render_template(self, templ, **kwargs):
+    def _render_template(self, templ: barectf_template._Template, **kwargs) -> str:
         return templ.render(op=self, root_ft_prefixes=_RootFtPrefixes,
                             root_ft_prefix_names=_ROOT_FT_PREFIX_NAMES, **kwargs)
 
-    def serialize_str(self, **kwargs):
+    def serialize_str(self, **kwargs) -> str:
         return self._render_template(self._templates.serialize, **kwargs)
 
-    def size_str(self, **kwargs):
+    def size_str(self, **kwargs) -> str:
         return self._render_template(self._templates.size, **kwargs)
 
 
 # An "align" operation.
 class _AlignOp(_Op):
-    def __init__(self, offset_in_byte, ft, names, templates, value):
+    def __init__(self, offset_in_byte: Count, ft: barectf_config._FieldType, names: List[str],
+                 templates: _OpTemplates, value: Alignment):
         super().__init__(offset_in_byte, ft, names, templates)
         self._value = value
 
     @property
-    def value(self):
+    def value(self) -> Alignment:
         return self._value
 
 
@@ -96,6 +101,10 @@ class _WriteOp(_Op):
     pass
 
 
+_SpecSerializeWriteTemplates = Mapping[str, barectf_template._Template]
+_Ops = List[_Op]
+
+
 # A builder of a chain of operations.
 #
 # Such a builder is closely connected to a `_CodeGen` object using it to
@@ -106,16 +115,16 @@ class _WriteOp(_Op):
 #
 # Get an operation builder's operations with its `ops` property.
 class _OpsBuilder:
-    def __init__(self, cg):
-        self._last_alignment = None
-        self._last_bit_array_size = None
-        self._ops = []
-        self._names = []
-        self._offset_in_byte = 0
+    def __init__(self, cg: _CodeGen):
+        self._last_alignment: Optional[Alignment] = None
+        self._last_bit_array_size: Optional[Count] = None
+        self._ops: _Ops = []
+        self._names: List[str] = []
+        self._offset_in_byte = Count(0)
         self._cg = cg
 
     @property
-    def ops(self):
+    def ops(self) -> _Ops:
         return self._ops
 
     # Creates and appends the operations for the members, recursively,
@@ -123,7 +132,8 @@ class _OpsBuilder:
     #
     # `spec_serialize_write_templates` is a mapping of first level
     # member names to specialized serialization "write" templates.
-    def append_root_ft(self, ft, name, spec_serialize_write_templates=None):
+    def append_root_ft(self, ft: barectf_config._FieldType, name: str,
+                       spec_serialize_write_templates: Optional[_SpecSerializeWriteTemplates] = None):
         if ft is None:
             return
 
@@ -138,23 +148,23 @@ class _OpsBuilder:
     # named `name`.
     #
     # See append_root_ft() for `spec_serialize_write_templates`.
-    def _append_ft(self, ft, name, spec_serialize_write_templates):
-        def top_name():
+    def _append_ft(self, ft: barectf_config._FieldType, name: str,
+                   spec_serialize_write_templates: _SpecSerializeWriteTemplates):
+        def top_name() -> str:
             return self._names[-1]
 
         # Appends a "write" operation for the field type `ft`.
         #
         # This function considers `spec_serialize_write_templates` to
         # override generic templates.
-        def append_write_op(ft):
+        def append_write_op(ft: barectf_config._FieldType):
             assert type(ft) is not barectf_config.StructureFieldType
             offset_in_byte = self._offset_in_byte
 
             if isinstance(ft, barectf_config._BitArrayFieldType):
-                self._offset_in_byte += ft.size
-                self._offset_in_byte %= 8
+                self._offset_in_byte = Count((self._offset_in_byte + ft.size) % 8)
 
-            serialize_write_templ = None
+            serialize_write_templ: Optional[barectf_template._Template] = None
 
             if len(self._names) == 2:
                 serialize_write_templ = spec_serialize_write_templates.get(top_name())
@@ -182,12 +192,13 @@ class _OpsBuilder:
         # `ft` if needed.
         #
         # This function updates the builder's state.
-        def try_append_align_op(alignment, do_align, ft):
-            def align(v, alignment):
-                return (v + (alignment - 1)) & -alignment
+        def try_append_align_op(alignment: Alignment, do_align: bool,
+                                ft: barectf_config._FieldType):
+            def align(v: Count, alignment: Alignment) -> Count:
+                return Count((v + (alignment - 1)) & -alignment)
 
             offset_in_byte = self._offset_in_byte
-            self._offset_in_byte = align(self._offset_in_byte, alignment) % 8
+            self._offset_in_byte = Count(align(self._offset_in_byte, alignment) % 8)
 
             if do_align and alignment > 1:
                 self._ops.append(_AlignOp(offset_in_byte, ft, self._names,
@@ -198,8 +209,8 @@ class _OpsBuilder:
         # Returns whether or not, considering the alignment requirement
         # `align_req` and the builder's current state, we must create
         # and append an "align" operation.
-        def must_align(align_req):
-            return self._last_alignment != align_req or self._last_bit_array_size % align_req != 0
+        def must_align(align_req: Alignment) -> bool:
+            return self._last_alignment != align_req or typing.cast(Count, self._last_bit_array_size) % align_req != 0
 
         # push field type's name to the builder's name stack initially
         self._names.append(name)
@@ -208,23 +219,26 @@ class _OpsBuilder:
             assert type(ft) is barectf_config.StringFieldType or top_name() == 'uuid'
 
             # strings and arrays are always byte-aligned
-            do_align = must_align(8)
-            self._last_alignment = 8
-            self._last_bit_array_size = 8
-            try_append_align_op(8, do_align, ft)
+            do_align = must_align(Alignment(8))
+            self._last_alignment = Alignment(8)
+            self._last_bit_array_size = Count(8)
+            try_append_align_op(Alignment(8), do_align, ft)
             append_write_op(ft)
         else:
             do_align = must_align(ft.alignment)
             self._last_alignment = ft.alignment
 
             if type(ft) is barectf_config.StructureFieldType:
-                self._last_bit_array_size = ft.alignment
+                self._last_bit_array_size = typing.cast(Count, ft.alignment)
             else:
+                assert isinstance(ft, barectf_config._BitArrayFieldType)
+                ft = typing.cast(barectf_config._BitArrayFieldType, ft)
                 self._last_bit_array_size = ft.size
 
             try_append_align_op(ft.alignment, do_align, ft)
 
             if type(ft) is barectf_config.StructureFieldType:
+                ft = typing.cast(barectf_config.StructureFieldType, ft)
                 for member_name, member in ft.members.items():
                     self._append_ft(member.field_type, member_name, spec_serialize_write_templates)
             else:
@@ -240,20 +254,23 @@ class _OpsBuilder:
 #
 # * Specific context operations.
 # * Payload operations.
-class _EventOps:
-    def __init__(self, spec_ctx_ops, payload_ops):
+class _EvOps:
+    def __init__(self, spec_ctx_ops: _Ops, payload_ops: _Ops):
         self._spec_ctx_ops = copy.copy(spec_ctx_ops)
         self._payload_ops = copy.copy(payload_ops)
 
     @property
-    def spec_ctx_ops(self):
+    def spec_ctx_ops(self) -> _Ops:
         return self._spec_ctx_ops
 
     @property
-    def payload_ops(self):
+    def payload_ops(self) -> _Ops:
         return self._payload_ops
 
 
+_EvOpsMap = Mapping[barectf_config.EventType, _EvOps]
+
+
 # The operations for a stream.
 #
 # The available operations are:
@@ -262,10 +279,10 @@ class _EventOps:
 # * Packet context operations.
 # * Event header operations.
 # * Event common context operations.
-# * Event operations (`_EventOps`).
+# * Event operations (`_EvOps`).
 class _StreamOps:
-    def __init__(self, pkt_header_ops, pkt_ctx_ops, ev_header_ops,
-                 ev_common_ctx_ops, ev_ops):
+    def __init__(self, pkt_header_ops: _Ops, pkt_ctx_ops: _Ops, ev_header_ops: _Ops,
+                 ev_common_ctx_ops: _Ops, ev_ops: _EvOpsMap):
         self._pkt_header_ops = copy.copy(pkt_header_ops)
         self._pkt_ctx_ops = copy.copy(pkt_ctx_ops)
         self._ev_header_ops = copy.copy(ev_header_ops)
@@ -273,23 +290,23 @@ class _StreamOps:
         self._ev_ops = copy.copy(ev_ops)
 
     @property
-    def pkt_header_ops(self):
+    def pkt_header_ops(self) -> _Ops:
         return self._pkt_header_ops
 
     @property
-    def pkt_ctx_ops(self):
+    def pkt_ctx_ops(self) -> _Ops:
         return self._pkt_ctx_ops
 
     @property
-    def ev_header_ops(self):
+    def ev_header_ops(self) -> _Ops:
         return self._ev_header_ops
 
     @property
-    def ev_common_ctx_ops(self):
+    def ev_common_ctx_ops(self) -> _Ops:
         return self._ev_common_ctx_ops
 
     @property
-    def ev_ops(self):
+    def ev_ops(self) -> _EvOpsMap:
         return self._ev_ops
 
 
@@ -326,11 +343,10 @@ _FtParam = collections.namedtuple('_FtParam', ['ft', 'name'])
 # * The public header (gen_header()).
 # * The source code (gen_src()).
 class _CodeGen:
-    def __init__(self, cfg):
+    def __init__(self, cfg: barectf_config.Configuration):
         self._cfg = cfg
         self._iden_prefix = cfg.options.code_generation_options.identifier_prefix
-        self._saved_serialization_ops = {}
-        self._templ_filters = {
+        self._templ_filters: Mapping[str, Callable[..., Any]] = {
             'ft_c_type': self._ft_c_type,
             'open_func_params_str': self._open_func_params_str,
             'trace_func_params_str': self._trace_func_params_str,
@@ -360,7 +376,8 @@ class _CodeGen:
     #
     # Such a template has the filters custom filters
     # `self._templ_filters`.
-    def _create_template_base(self, name: str, is_file_template: bool):
+    def _create_template_base(self, name: str,
+                              is_file_template: bool) -> barectf_template._Template:
         return barectf_template._Template(f'c/{name}', is_file_template, self._cfg,
                                           self._templ_filters)
 
@@ -378,15 +395,16 @@ class _CodeGen:
 
     # Trace type of this code generator's barectf configuration.
     @property
-    def _trace_type(self):
+    def _trace_type(self) -> barectf_config.TraceType:
         return self._cfg.trace.type
 
     # Returns the C type for the field type `ft`, returning a `const` C
     # type if `is_const` is `True`.
-    def _ft_c_type(self, ft, is_const=False):
+    def _ft_c_type(self, ft: barectf_config._FieldType, is_const: bool = False) -> str:
         const_beg_str = 'const '
 
         if isinstance(ft, barectf_config._IntegerFieldType):
+            ft = typing.cast(barectf_config._IntegerFieldType, ft)
             sign_prefix = 'u' if isinstance(ft, barectf_config.UnsignedIntegerFieldType) else ''
 
             if ft.size <= 8:
@@ -401,6 +419,8 @@ class _CodeGen:
 
             return f'{const_beg_str if is_const else ""}{sign_prefix}int{sz}_t'
         elif type(ft) is barectf_config.RealFieldType:
+            ft = typing.cast(barectf_config.RealFieldType, ft)
+
             if ft.size == 32 and ft.alignment == 32:
                 c_type = 'float'
             elif ft.size == 64 and ft.alignment == 64:
@@ -419,10 +439,11 @@ class _CodeGen:
     # Each parameter has the prefix `name_prefix` followed with `_`.
     #
     # Members of which the name is in `exclude_set` are excluded.
-    def _proto_params_str(self, root_ft, name_prefix, const_params, exclude_set=None,
-                          only_dyn=False):
+    def _proto_params_str(self, root_ft: Optional[barectf_config.StructureFieldType],
+                          name_prefix: str, const_params: bool,
+                          exclude_set: Optional[Set[str]] = None, only_dyn: bool = False) -> str:
         if root_ft is None:
-            return
+            return ''
 
         if exclude_set is None:
             exclude_set = set()
@@ -443,7 +464,8 @@ class _CodeGen:
 
     # Returns the packet opening function prototype parameters for the
     # stream type `stream_type`.
-    def _open_func_params_str(self, stream_type, const_params):
+    def _open_func_params_str(self, stream_type: barectf_config.StreamType,
+                              const_params: bool) -> str:
         parts = []
         parts.append(self._proto_params_str(self._trace_type._pkt_header_ft, _RootFtPrefixes.PH,
                                             const_params, {'magic', 'stream_id', 'uuid'}))
@@ -461,7 +483,9 @@ class _CodeGen:
 
     # Returns the tracing function prototype parameters for the stream
     # and event types `stream_ev_types`.
-    def _trace_func_params_str(self, stream_ev_types, const_params, only_dyn=False):
+    def _trace_func_params_str(self, stream_ev_types: Tuple[barectf_config.StreamType,
+                                                            barectf_config.EventType],
+                               const_params: bool, only_dyn: bool = False):
         stream_type = stream_ev_types[0]
         ev_type = stream_ev_types[1]
         parts = []
@@ -489,23 +513,24 @@ class _CodeGen:
 
     # Returns the event header serialization function prototype
     # parameters for the stream type `stream_type`.
-    def _serialize_ev_common_ctx_func_params_str(self, stream_type, const_params):
+    def _serialize_ev_common_ctx_func_params_str(self, stream_type: barectf_config.StreamType,
+                                                 const_params: bool) -> str:
         return self._proto_params_str(stream_type.event_common_context_field_type,
                                       _RootFtPrefixes.ECC, const_params);
 
     # Generates the bitfield header file contents.
-    def gen_bitfield_header(self):
+    def gen_bitfield_header(self) -> str:
         return self._create_file_template('bitfield.h.j2').render()
 
     # Generates the public header file contents.
-    def gen_header(self):
+    def gen_header(self) -> str:
         return self._create_file_template('barectf.h.j2').render(root_ft_prefixes=_RootFtPrefixes)
 
     # Generates the source code file contents.
-    def gen_src(self, header_file_name, bitfield_header_file_name):
+    def gen_src(self, header_file_name: str, bitfield_header_file_name: str) -> str:
         # Creates and returns the operations for all the stream and for
         # all their events.
-        def create_stream_ops():
+        def create_stream_ops() -> Mapping[barectf_config.StreamType, _StreamOps]:
             stream_ser_ops = {}
 
             for stream_type in self._trace_type.stream_types:
@@ -582,7 +607,7 @@ class _CodeGen:
                         ev_builder.append_root_ft(ev_type.payload_field_type, _RootFtPrefixes.P)
                         payload_ser_ops = copy.copy(ev_builder.ops[first_op_index:])
 
-                    ev_ser_ops[ev_type] = _EventOps(spec_ctx_ser_ops, payload_ser_ops)
+                    ev_ser_ops[ev_type] = _EvOps(spec_ctx_ser_ops, payload_ser_ops)
 
                 stream_ser_ops[stream_type] = _StreamOps(pkt_header_ser_ops, pkt_ctx_ser_ops,
                                                          ev_header_ser_ops, ev_common_ctx_ser_ops,
@@ -592,10 +617,16 @@ class _CodeGen:
 
         # Returns the "write" operation for the packet context member
         # named `member_name` within the stream type `stream_type`.
-        def stream_op_pkt_ctx_op(stream_type, member_name):
+        def stream_op_pkt_ctx_op(stream_type: barectf_config.StreamType, member_name: str) -> _Op:
+            ret_op = None
+
             for op in stream_ops[stream_type].pkt_ctx_ops:
                 if op.top_name == member_name and type(op) is _WriteOp:
-                    return op
+                    ret_op = op
+                    break
+
+            assert ret_op is not None
+            return typing.cast(_Op, ret_op)
 
         stream_ops = create_stream_ops()
         return self._create_file_template('barectf.c.j2').render(header_file_name=header_file_name,
index 27ade4f244a281f95df71540b56c2c1648bc5822..0dd4788333761f8df401a6bb2179056e5bd2bd26 100644 (file)
@@ -49,9 +49,9 @@ def _filt_escape_dq(text: str) -> str:
     return text.replace('\\', '\\\\').replace('"', '\\"')
 
 
-_Filter = Callable[[Any], Any]
+_Filter = Callable[..., Any]
 _Filters = Mapping[str, _Filter]
-_Test = Callable[[Any], bool]
+_Test = Callable[..., bool]
 _Tests = Mapping[str, _Test]
 
 
This page took 0.030678 seconds and 4 git commands to generate.