From aca7de769c495565a693605d1f372f1b477cdb70 Mon Sep 17 00:00:00 2001 From: Simon Marchi Date: Mon, 18 Sep 2023 15:07:30 -0400 Subject: [PATCH] tests: add typing annotations to lttng_live_server.py Add typing annotations throughout lttng_live_server.py, such that pyright passes cleanly in strict mode. Specify in a comment in lttng_live_server.py to run in strict mode, but disable these two warnings: - reportTypeCommentUsage: we need to use the comment form to support Python 3.4 - reportMissingTypeStubs: it would complain about not having a stub file for the utils module (although it can inspect utils.py just fine) Note that the sessions_filename variable near the bottom is marked as a string, even though it's currently an optional argument (so it should be `str | None`). But the rest of the code treats it as non-optional. This is corrected in a subsequent patch that makes the sessions filename mandatory. Change the way lttng_live_server.py is launched to use run_python, to have access to Python modules in tests/utils/python, to have access to jsonw.py and possible the typing module shim for Python 3.4. Change-Id: I842f4c696eea7c99932876af26c195bf99d5cfff Signed-off-by: Simon Marchi Reviewed-on: https://review.lttng.org/c/babeltrace/+/10869 Tested-by: jenkins Reviewed-by: Philippe Proulx --- .../src.ctf.lttng-live/lttng_live_server.py | 591 +++++++++++------- tests/plugins/src.ctf.lttng-live/test_live | 2 +- 2 files changed, 372 insertions(+), 221 deletions(-) diff --git a/tests/data/plugins/src.ctf.lttng-live/lttng_live_server.py b/tests/data/plugins/src.ctf.lttng-live/lttng_live_server.py index 96879294..45246b9b 100644 --- a/tests/data/plugins/src.ctf.lttng-live/lttng_live_server.py +++ b/tests/data/plugins/src.ctf.lttng-live/lttng_live_server.py @@ -3,18 +3,25 @@ # Copyright (C) 2019 Philippe Proulx # +# pyright: strict, reportTypeCommentUsage=false, reportMissingTypeStubs=false + import os import re import sys -import json import socket import struct import logging import os.path import argparse import tempfile -import collections.abc -from collections import namedtuple +from typing import Dict, Union, Iterable, Optional, Sequence, overload + +import tjson + +# isort: off +from typing import Any, Callable # noqa: F401 + +# isort: on class UnexpectedInput(RuntimeError): @@ -25,13 +32,13 @@ class UnexpectedInput(RuntimeError): class _LttngDataStreamIndexEntry: def __init__( self, - offset_bytes, - total_size_bits, - content_size_bits, - timestamp_begin, - timestamp_end, - events_discarded, - stream_class_id, + offset_bytes: int, + total_size_bits: int, + content_size_bits: int, + timestamp_begin: int, + timestamp_end: int, + events_discarded: int, + stream_class_id: int, ): self._offset_bytes = offset_bytes self._total_size_bits = total_size_bits @@ -81,8 +88,8 @@ class _LttngDataStreamIndexEntry: # An entry within the index of an LTTng data stream. While a stream beacon entry # is conceptually unrelated to an index, it is sent as a reply to a # LttngLiveViewerGetNextDataStreamIndexEntryCommand -class _LttngDataStreamBeaconEntry: - def __init__(self, stream_class_id, timestamp): +class _LttngDataStreamBeaconIndexEntry: + def __init__(self, stream_class_id: int, timestamp: int): self._stream_class_id = stream_class_id self._timestamp = timestamp @@ -95,8 +102,11 @@ class _LttngDataStreamBeaconEntry: return self._stream_class_id +_LttngIndexEntryT = Union[_LttngDataStreamIndexEntry, _LttngDataStreamBeaconIndexEntry] + + class _LttngLiveViewerCommand: - def __init__(self, version): + def __init__(self, version: int): self._version = version @property @@ -105,7 +115,7 @@ class _LttngLiveViewerCommand: class _LttngLiveViewerConnectCommand(_LttngLiveViewerCommand): - def __init__(self, version, viewer_session_id, major, minor): + def __init__(self, version: int, viewer_session_id: int, major: int, minor: int): super().__init__(version) self._viewer_session_id = viewer_session_id self._major = major @@ -124,8 +134,12 @@ class _LttngLiveViewerConnectCommand(_LttngLiveViewerCommand): return self._minor -class _LttngLiveViewerConnectReply: - def __init__(self, viewer_session_id, major, minor): +class _LttngLiveViewerReply: + pass + + +class _LttngLiveViewerConnectReply(_LttngLiveViewerReply): + def __init__(self, viewer_session_id: int, major: int, minor: int): self._viewer_session_id = viewer_session_id self._major = major self._minor = minor @@ -150,12 +164,12 @@ class _LttngLiveViewerGetTracingSessionInfosCommand(_LttngLiveViewerCommand): class _LttngLiveViewerTracingSessionInfo: def __init__( self, - tracing_session_id, - live_timer_freq, - client_count, - stream_count, - hostname, - name, + tracing_session_id: int, + live_timer_freq: int, + client_count: int, + stream_count: int, + hostname: str, + name: str, ): self._tracing_session_id = tracing_session_id self._live_timer_freq = live_timer_freq @@ -189,8 +203,10 @@ class _LttngLiveViewerTracingSessionInfo: return self._name -class _LttngLiveViewerGetTracingSessionInfosReply: - def __init__(self, tracing_session_infos): +class _LttngLiveViewerGetTracingSessionInfosReply(_LttngLiveViewerReply): + def __init__( + self, tracing_session_infos: Sequence[_LttngLiveViewerTracingSessionInfo] + ): self._tracing_session_infos = tracing_session_infos @property @@ -203,7 +219,9 @@ class _LttngLiveViewerAttachToTracingSessionCommand(_LttngLiveViewerCommand): BEGINNING = 1 LAST = 2 - def __init__(self, version, tracing_session_id, offset, seek_type): + def __init__( + self, version: int, tracing_session_id: int, offset: int, seek_type: int + ): super().__init__(version) self._tracing_session_id = tracing_session_id self._offset = offset @@ -223,7 +241,9 @@ class _LttngLiveViewerAttachToTracingSessionCommand(_LttngLiveViewerCommand): class _LttngLiveViewerStreamInfo: - def __init__(self, id, trace_id, is_metadata, path, channel_name): + def __init__( + self, id: int, trace_id: int, is_metadata: bool, path: str, channel_name: str + ): self._id = id self._trace_id = trace_id self._is_metadata = is_metadata @@ -251,7 +271,7 @@ class _LttngLiveViewerStreamInfo: return self._channel_name -class _LttngLiveViewerAttachToTracingSessionReply: +class _LttngLiveViewerAttachToTracingSessionReply(_LttngLiveViewerReply): class Status: OK = 1 ALREADY = 2 @@ -260,7 +280,7 @@ class _LttngLiveViewerAttachToTracingSessionReply: SEEK_ERROR = 5 NO_SESSION = 6 - def __init__(self, status, stream_infos): + def __init__(self, status: int, stream_infos: Sequence[_LttngLiveViewerStreamInfo]): self._status = status self._stream_infos = stream_infos @@ -274,7 +294,7 @@ class _LttngLiveViewerAttachToTracingSessionReply: class _LttngLiveViewerGetNextDataStreamIndexEntryCommand(_LttngLiveViewerCommand): - def __init__(self, version, stream_id): + def __init__(self, version: int, stream_id: int): super().__init__(version) self._stream_id = stream_id @@ -283,7 +303,7 @@ class _LttngLiveViewerGetNextDataStreamIndexEntryCommand(_LttngLiveViewerCommand return self._stream_id -class _LttngLiveViewerGetNextDataStreamIndexEntryReply: +class _LttngLiveViewerGetNextDataStreamIndexEntryReply(_LttngLiveViewerReply): class Status: OK = 1 RETRY = 2 @@ -292,7 +312,13 @@ class _LttngLiveViewerGetNextDataStreamIndexEntryReply: INACTIVE = 5 EOF = 6 - def __init__(self, status, index_entry, has_new_metadata, has_new_data_stream): + def __init__( + self, + status: int, + index_entry: _LttngIndexEntryT, + has_new_metadata: bool, + has_new_data_stream: bool, + ): self._status = status self._index_entry = index_entry self._has_new_metadata = has_new_metadata @@ -316,7 +342,7 @@ class _LttngLiveViewerGetNextDataStreamIndexEntryReply: class _LttngLiveViewerGetDataStreamPacketDataCommand(_LttngLiveViewerCommand): - def __init__(self, version, stream_id, offset, req_length): + def __init__(self, version: int, stream_id: int, offset: int, req_length: int): super().__init__(version) self._stream_id = stream_id self._offset = offset @@ -335,14 +361,20 @@ class _LttngLiveViewerGetDataStreamPacketDataCommand(_LttngLiveViewerCommand): return self._req_length -class _LttngLiveViewerGetDataStreamPacketDataReply: +class _LttngLiveViewerGetDataStreamPacketDataReply(_LttngLiveViewerReply): class Status: OK = 1 RETRY = 2 ERROR = 3 EOF = 4 - def __init__(self, status, data, has_new_metadata, has_new_data_stream): + def __init__( + self, + status: int, + data: bytes, + has_new_metadata: bool, + has_new_data_stream: bool, + ): self._status = status self._data = data self._has_new_metadata = has_new_metadata @@ -366,7 +398,7 @@ class _LttngLiveViewerGetDataStreamPacketDataReply: class _LttngLiveViewerGetMetadataStreamDataCommand(_LttngLiveViewerCommand): - def __init__(self, version, stream_id): + def __init__(self, version: int, stream_id: int): super().__init__(version) self._stream_id = stream_id @@ -375,13 +407,13 @@ class _LttngLiveViewerGetMetadataStreamDataCommand(_LttngLiveViewerCommand): return self._stream_id -class _LttngLiveViewerGetMetadataStreamDataContentReply: +class _LttngLiveViewerGetMetadataStreamDataContentReply(_LttngLiveViewerReply): class Status: OK = 1 NO_NEW = 2 ERROR = 3 - def __init__(self, status, data): + def __init__(self, status: int, data: bytes): self._status = status self._data = data @@ -395,7 +427,7 @@ class _LttngLiveViewerGetMetadataStreamDataContentReply: class _LttngLiveViewerGetNewStreamInfosCommand(_LttngLiveViewerCommand): - def __init__(self, version, tracing_session_id): + def __init__(self, version: int, tracing_session_id: int): super().__init__(version) self._tracing_session_id = tracing_session_id @@ -404,14 +436,14 @@ class _LttngLiveViewerGetNewStreamInfosCommand(_LttngLiveViewerCommand): return self._tracing_session_id -class _LttngLiveViewerGetNewStreamInfosReply: +class _LttngLiveViewerGetNewStreamInfosReply(_LttngLiveViewerReply): class Status: OK = 1 NO_NEW = 2 ERROR = 3 HUP = 4 - def __init__(self, status, stream_infos): + def __init__(self, status: int, stream_infos: Sequence[_LttngLiveViewerStreamInfo]): self._status = status self._stream_infos = stream_infos @@ -428,12 +460,12 @@ class _LttngLiveViewerCreateViewerSessionCommand(_LttngLiveViewerCommand): pass -class _LttngLiveViewerCreateViewerSessionReply: +class _LttngLiveViewerCreateViewerSessionReply(_LttngLiveViewerReply): class Status: OK = 1 ERROR = 2 - def __init__(self, status): + def __init__(self, status: int): self._status = status @property @@ -442,7 +474,7 @@ class _LttngLiveViewerCreateViewerSessionReply: class _LttngLiveViewerDetachFromTracingSessionCommand(_LttngLiveViewerCommand): - def __init__(self, version, tracing_session_id): + def __init__(self, version: int, tracing_session_id: int): super().__init__(version) self._tracing_session_id = tracing_session_id @@ -451,13 +483,13 @@ class _LttngLiveViewerDetachFromTracingSessionCommand(_LttngLiveViewerCommand): return self._tracing_session_id -class _LttngLiveViewerDetachFromTracingSessionReply: +class _LttngLiveViewerDetachFromTracingSessionReply(_LttngLiveViewerReply): class Status: OK = 1 UNKNOWN = 2 ERROR = 3 - def __init__(self, status): + def __init__(self, status: int): self._status = status @property @@ -474,16 +506,16 @@ class _LttngLiveViewerProtocolCodec: def __init__(self): pass - def _unpack(self, fmt, data, offset=0): + def _unpack(self, fmt: str, data: bytes, offset: int = 0): fmt = "!" + fmt return struct.unpack_from(fmt, data, offset) - def _unpack_payload(self, fmt, data): + def _unpack_payload(self, fmt: str, data: bytes): return self._unpack( fmt, data, _LttngLiveViewerProtocolCodec._COMMAND_HEADER_SIZE_BYTES ) - def decode(self, data): + def decode(self, data: bytes): if len(data) < self._COMMAND_HEADER_SIZE_BYTES: # Not enough data to read the command header return @@ -502,9 +534,7 @@ class _LttngLiveViewerProtocolCodec: return if cmd_type == 1: - viewer_session_id, major, minor, conn_type = self._unpack_payload( - "QIII", data - ) + viewer_session_id, major, minor, _ = self._unpack_payload("QIII", data) return _LttngLiveViewerConnectCommand( version, viewer_session_id, major, minor ) @@ -541,21 +571,23 @@ class _LttngLiveViewerProtocolCodec: else: raise UnexpectedInput("Unknown command type {}".format(cmd_type)) - def _pack(self, fmt, *args): + def _pack(self, fmt: str, *args: Any): # Force network byte order return struct.pack("!" + fmt, *args) - def _encode_zero_padded_str(self, string, length): + def _encode_zero_padded_str(self, string: str, length: int): data = string.encode() return data.ljust(length, b"\x00") - def _encode_stream_info(self, info): + def _encode_stream_info(self, info: _LttngLiveViewerStreamInfo): data = self._pack("QQI", info.id, info.trace_id, int(info.is_metadata)) data += self._encode_zero_padded_str(info.path, 4096) data += self._encode_zero_padded_str(info.channel_name, 255) return data - def _get_has_new_stuff_flags(self, has_new_metadata, has_new_data_streams): + def _get_has_new_stuff_flags( + self, has_new_metadata: bool, has_new_data_streams: bool + ): flags = 0 if has_new_metadata: @@ -566,7 +598,10 @@ class _LttngLiveViewerProtocolCodec: return flags - def encode(self, reply): + def encode( + self, + reply: _LttngLiveViewerReply, + ) -> bytes: if type(reply) is _LttngLiveViewerConnectReply: data = self._pack( "QIII", reply.viewer_session_id, reply.major, reply.minor, 2 @@ -596,7 +631,7 @@ class _LttngLiveViewerProtocolCodec: reply.has_new_metadata, reply.has_new_data_stream ) - if type(entry) is _LttngDataStreamIndexEntry: + if isinstance(entry, _LttngDataStreamIndexEntry): data = self._pack( index_format, entry.offset_bytes, @@ -610,7 +645,6 @@ class _LttngLiveViewerProtocolCodec: flags, ) else: - assert type(entry) is _LttngDataStreamBeaconEntry data = self._pack( index_format, 0, @@ -649,26 +683,31 @@ class _LttngLiveViewerProtocolCodec: return data -def _get_entry_timestamp_begin(entry): - if type(entry) is _LttngDataStreamBeaconEntry: +def _get_entry_timestamp_begin( + entry: _LttngIndexEntryT, +): + if isinstance(entry, _LttngDataStreamBeaconIndexEntry): return entry.timestamp else: - assert type(entry) is _LttngDataStreamIndexEntry return entry.timestamp_begin # The index of an LTTng data stream, a sequence of index entries. -class _LttngDataStreamIndex(collections.abc.Sequence): - def __init__(self, path, beacons): +class _LttngDataStreamIndex(Sequence[_LttngIndexEntryT]): + def __init__(self, path: str, beacons: Optional[tjson.ArrayVal]): self._path = path self._build() if beacons: stream_class_id = self._entries[0].stream_class_id - beacons = [ - _LttngDataStreamBeaconEntry(stream_class_id, ts) for ts in beacons - ] - self._add_beacons(beacons) + + beacons_list = [] # type: list[_LttngDataStreamBeaconIndexEntry] + for ts in beacons.iter(tjson.IntVal): + beacons_list.append( + _LttngDataStreamBeaconIndexEntry(stream_class_id, ts.val) + ) + + self._add_beacons(beacons_list) logging.info( 'Built data stream index entries: path="{}", count={}'.format( @@ -677,7 +716,7 @@ class _LttngDataStreamIndex(collections.abc.Sequence): ) def _build(self): - self._entries = [] + self._entries = [] # type: list[_LttngIndexEntryT] assert os.path.isfile(self._path) with open(self._path, "rb") as f: @@ -686,9 +725,7 @@ class _LttngDataStreamIndex(collections.abc.Sequence): size = struct.calcsize(fmt) data = f.read(size) assert len(data) == size - magic, index_major, index_minor, index_entry_length = struct.unpack( - fmt, data - ) + magic, _, _, index_entry_length = struct.unpack(fmt, data) assert magic == 0xC1F1DCC1 # Read index entries @@ -733,10 +770,12 @@ class _LttngDataStreamIndex(collections.abc.Sequence): # Skip anything else before the next entry f.seek(index_entry_length - size, os.SEEK_CUR) - def _add_beacons(self, beacons): + def _add_beacons(self, beacons: Iterable[_LttngDataStreamBeaconIndexEntry]): # Assumes entries[n + 1].timestamp_end >= entries[n].timestamp_begin - def sort_key(entry): - if type(entry) is _LttngDataStreamBeaconEntry: + def sort_key( + entry: Union[_LttngDataStreamIndexEntry, _LttngDataStreamBeaconIndexEntry], + ) -> int: + if isinstance(entry, _LttngDataStreamBeaconIndexEntry): return entry.timestamp else: return entry.timestamp_end @@ -744,7 +783,17 @@ class _LttngDataStreamIndex(collections.abc.Sequence): self._entries += beacons self._entries.sort(key=sort_key) - def __getitem__(self, index): + @overload + def __getitem__(self, index: int) -> _LttngIndexEntryT: + ... + + @overload + def __getitem__(self, index: slice) -> Sequence[_LttngIndexEntryT]: # noqa: F811 + ... + + def __getitem__( # noqa: F811 + self, index: Union[int, slice] + ) -> Union[_LttngIndexEntryT, Sequence[_LttngIndexEntryT],]: return self._entries[index] def __len__(self): @@ -757,7 +806,7 @@ class _LttngDataStreamIndex(collections.abc.Sequence): # An LTTng data stream. class _LttngDataStream: - def __init__(self, path, beacons): + def __init__(self, path: str, beacons_json: Optional[tjson.ArrayVal]): self._path = path filename = os.path.basename(path) match = re.match(r"(.*)_\d+", filename) @@ -769,7 +818,7 @@ class _LttngDataStream: self._channel_name = match.group(1) trace_dir = os.path.dirname(path) index_path = os.path.join(trace_dir, "index", filename + ".idx") - self._index = _LttngDataStreamIndex(index_path, beacons) + self._index = _LttngDataStreamIndex(index_path, beacons_json) assert os.path.isfile(path) self._file = open(path, "rb") logging.info( @@ -790,13 +839,13 @@ class _LttngDataStream: def index(self): return self._index - def get_data(self, offset_bytes, len_bytes): + def get_data(self, offset_bytes: int, len_bytes: int): self._file.seek(offset_bytes) return self._file.read(len_bytes) class _LttngMetadataStreamSection: - def __init__(self, timestamp, data): + def __init__(self, timestamp: int, data: Optional[bytes]): self._timestamp = timestamp if data is None: self._data = bytes() @@ -819,7 +868,11 @@ class _LttngMetadataStreamSection: # An LTTng metadata stream. class _LttngMetadataStream: - def __init__(self, metadata_file_path, config_sections): + def __init__( + self, + metadata_file_path: str, + config_sections: Sequence[_LttngMetadataStreamSection], + ): self._path = metadata_file_path self._sections = config_sections logging.info( @@ -837,61 +890,73 @@ class _LttngMetadataStream: return self._sections -LttngMetadataConfigSection = namedtuple( - "LttngMetadataConfigSection", ["line", "timestamp", "is_empty"] -) +class LttngMetadataConfigSection: + def __init__(self, line: int, timestamp: int, is_empty: bool): + self._line = line + self._timestamp = timestamp + self._is_empty = is_empty + + @property + def line(self): + return self._line + @property + def timestamp(self): + return self._timestamp -def _parse_metadata_sections_config(config_sections): - assert config_sections is not None - config_metadata_sections = [] + @property + def is_empty(self): + return self._is_empty + + +def _parse_metadata_sections_config(metadata_sections_json: tjson.ArrayVal): + metadata_sections = [] # type: list[LttngMetadataConfigSection] append_empty_section = False last_timestamp = 0 last_line = 0 - for config_section in config_sections: - if config_section == "empty": - # Found a empty section marker. Actually append the section at the - # timestamp of the next concrete section. - append_empty_section = True - else: - assert type(config_section) is dict - line = config_section.get("line") - ts = config_section.get("timestamp") - - if type(line) is not int: - raise RuntimeError("`line` is not an integer") - - if type(ts) is not int: - raise RuntimeError("`timestamp` is not an integer") + for section in metadata_sections_json: + if isinstance(section, tjson.StrVal): + if section.val == "empty": + # Found a empty section marker. Actually append the section at the + # timestamp of the next concrete section. + append_empty_section = True + else: + raise ValueError("Invalid string value at {}.".format(section.path)) + elif isinstance(section, tjson.ObjVal): + line = section.at("line", tjson.IntVal).val + ts = section.at("timestamp", tjson.IntVal).val # Sections' timestamps and lines must both be increasing. assert ts > last_timestamp last_timestamp = ts + assert line > last_line last_line = line if append_empty_section: - config_metadata_sections.append( - LttngMetadataConfigSection(line, ts, True) - ) + metadata_sections.append(LttngMetadataConfigSection(line, ts, True)) append_empty_section = False - config_metadata_sections.append(LttngMetadataConfigSection(line, ts, False)) - - return config_metadata_sections + metadata_sections.append(LttngMetadataConfigSection(line, ts, False)) + else: + raise TypeError( + "`{}`: expecting a string or object value".format(section.path) + ) + return metadata_sections -def _split_metadata_sections(metadata_file_path, raw_config_sections): - assert isinstance(raw_config_sections, collections.abc.Sequence) - parsed_sections = _parse_metadata_sections_config(raw_config_sections) +def _split_metadata_sections( + metadata_file_path: str, metadata_sections_json: tjson.ArrayVal +): + metadata_sections = _parse_metadata_sections_config(metadata_sections_json) - sections = [] + sections = [] # type: list[_LttngMetadataStreamSection] with open(metadata_file_path, "r") as metadata_file: metadata_lines = [line for line in metadata_file] - config_metadata_sections_idx = 0 + metadata_section_idx = 0 curr_metadata_section = bytearray() for idx, line_content in enumerate(metadata_lines): @@ -901,13 +966,11 @@ def _split_metadata_sections(metadata_file_path, raw_config_sections): curr_line_number = idx + 1 # If there are no more sections, simply append the line. - if config_metadata_sections_idx + 1 >= len(parsed_sections): + if metadata_section_idx + 1 >= len(metadata_sections): curr_metadata_section += bytearray(line_content, "utf8") continue - next_section_line_number = parsed_sections[ - config_metadata_sections_idx + 1 - ].line + next_section_line_number = metadata_sections[metadata_section_idx + 1].line # If the next section begins at the current line, create a # section with the metadata we gathered so far. @@ -915,26 +978,26 @@ def _split_metadata_sections(metadata_file_path, raw_config_sections): # Flushing the metadata of the current section. sections.append( _LttngMetadataStreamSection( - parsed_sections[config_metadata_sections_idx].timestamp, + metadata_sections[metadata_section_idx].timestamp, bytes(curr_metadata_section), ) ) # Move to the next section. - config_metadata_sections_idx += 1 + metadata_section_idx += 1 # Clear old content and append current line for the next section. curr_metadata_section.clear() curr_metadata_section += bytearray(line_content, "utf8") # Append any empty sections. - while parsed_sections[config_metadata_sections_idx].is_empty: + while metadata_sections[metadata_section_idx].is_empty: sections.append( _LttngMetadataStreamSection( - parsed_sections[config_metadata_sections_idx].timestamp, None + metadata_sections[metadata_section_idx].timestamp, None ) ) - config_metadata_sections_idx += 1 + metadata_section_idx += 1 else: # Append line_content to the current metadata section. curr_metadata_section += bytearray(line_content, "utf8") @@ -942,7 +1005,7 @@ def _split_metadata_sections(metadata_file_path, raw_config_sections): # We iterated over all the lines of the metadata file. Close the current section. sections.append( _LttngMetadataStreamSection( - parsed_sections[config_metadata_sections_idx].timestamp, + metadata_sections[metadata_section_idx].timestamp, bytes(curr_metadata_section), ) ) @@ -950,17 +1013,27 @@ def _split_metadata_sections(metadata_file_path, raw_config_sections): return sections +_StreamBeaconsT = Dict[str, Iterable[int]] + + # An LTTng trace, a sequence of LTTng data streams. -class LttngTrace(collections.abc.Sequence): - def __init__(self, trace_dir, metadata_sections, beacons): +class LttngTrace(Sequence[_LttngDataStream]): + def __init__( + self, + trace_dir: str, + metadata_sections_json: Optional[tjson.ArrayVal], + beacons_json: Optional[tjson.ObjVal], + ): assert os.path.isdir(trace_dir) self._path = trace_dir - self._create_metadata_stream(trace_dir, metadata_sections) - self._create_data_streams(trace_dir, beacons) + self._create_metadata_stream(trace_dir, metadata_sections_json) + self._create_data_streams(trace_dir, beacons_json) logging.info('Built trace: path="{}"'.format(trace_dir)) - def _create_data_streams(self, trace_dir, beacons): - data_stream_paths = [] + def _create_data_streams( + self, trace_dir: str, beacons_json: Optional[tjson.ObjVal] + ): + data_stream_paths = [] # type: list[str] for filename in os.listdir(trace_dir): path = os.path.join(trace_dir, filename) @@ -977,31 +1050,32 @@ class LttngTrace(collections.abc.Sequence): data_stream_paths.append(path) data_stream_paths.sort() - self._data_streams = [] + self._data_streams = [] # type: list[_LttngDataStream] for data_stream_path in data_stream_paths: stream_name = os.path.basename(data_stream_path) - this_stream_beacons = None - - if beacons is not None and stream_name in beacons: - this_stream_beacons = beacons[stream_name] + this_beacons_json = None + if beacons_json is not None and stream_name in beacons_json: + this_beacons_json = beacons_json.at(stream_name, tjson.ArrayVal) self._data_streams.append( - _LttngDataStream(data_stream_path, this_stream_beacons) + _LttngDataStream(data_stream_path, this_beacons_json) ) - def _create_metadata_stream(self, trace_dir, config_metadata_sections): + def _create_metadata_stream( + self, trace_dir: str, metadata_sections_json: Optional[tjson.ArrayVal] + ): metadata_path = os.path.join(trace_dir, "metadata") - metadata_sections = [] + metadata_sections = [] # type: list[_LttngMetadataStreamSection] - if config_metadata_sections is None: + if metadata_sections_json is None: with open(metadata_path, "rb") as metadata_file: metadata_sections.append( _LttngMetadataStreamSection(0, metadata_file.read()) ) else: metadata_sections = _split_metadata_sections( - metadata_path, config_metadata_sections + metadata_path, metadata_sections_json ) self._metadata_stream = _LttngMetadataStream(metadata_path, metadata_sections) @@ -1014,7 +1088,17 @@ class LttngTrace(collections.abc.Sequence): def metadata_stream(self): return self._metadata_stream - def __getitem__(self, index): + @overload + def __getitem__(self, index: int) -> _LttngDataStream: + ... + + @overload + def __getitem__(self, index: slice) -> Sequence[_LttngDataStream]: # noqa: F811 + ... + + def __getitem__( # noqa: F811 + self, index: Union[int, slice] + ) -> Union[_LttngDataStream, Sequence[_LttngDataStream]]: return self._data_streams[index] def __len__(self): @@ -1023,7 +1107,13 @@ class LttngTrace(collections.abc.Sequence): # The state of a single data stream. class _LttngLiveViewerSessionDataStreamState: - def __init__(self, ts_state, info, data_stream, metadata_stream_id): + def __init__( + self, + ts_state: "_LttngLiveViewerSessionTracingSessionState", + info: _LttngLiveViewerStreamInfo, + data_stream: _LttngDataStream, + metadata_stream_id: int, + ): self._ts_state = ts_state self._info = info self._data_stream = data_stream @@ -1058,13 +1148,22 @@ class _LttngLiveViewerSessionDataStreamState: return self._data_stream.index[self._cur_index_entry_index] + @property + def metadata_stream_id(self): + return self._metadata_stream_id + def goto_next_index_entry(self): self._cur_index_entry_index += 1 # The state of a single metadata stream. class _LttngLiveViewerSessionMetadataStreamState: - def __init__(self, ts_state, info, metadata_stream): + def __init__( + self, + ts_state: "_LttngLiveViewerSessionTracingSessionState", + info: _LttngLiveViewerStreamInfo, + metadata_stream: _LttngMetadataStream, + ): self._ts_state = ts_state self._info = info self._metadata_stream = metadata_stream @@ -1100,7 +1199,7 @@ class _LttngLiveViewerSessionMetadataStreamState: return self._is_sent @is_sent.setter - def is_sent(self, value): + def is_sent(self, value: bool): self._is_sent = value @property @@ -1132,7 +1231,13 @@ class _LttngLiveViewerSessionMetadataStreamState: # objects). class LttngTracingSessionDescriptor: def __init__( - self, name, tracing_session_id, hostname, live_timer_freq, client_count, traces + self, + name: str, + tracing_session_id: int, + hostname: str, + live_timer_freq: int, + client_count: int, + traces: Iterable[LttngTrace], ): for trace in traces: if name not in trace.path: @@ -1161,11 +1266,13 @@ class LttngTracingSessionDescriptor: # The state of a tracing session. class _LttngLiveViewerSessionTracingSessionState: - def __init__(self, tc_descr, base_stream_id): + def __init__(self, tc_descr: LttngTracingSessionDescriptor, base_stream_id: int): self._tc_descr = tc_descr - self._stream_infos = [] - self._ds_states = {} - self._ms_states = {} + self._stream_infos = [] # type: list[_LttngLiveViewerStreamInfo] + self._ds_states = {} # type: dict[int, _LttngLiveViewerSessionDataStreamState] + self._ms_states = ( + {} + ) # type: dict[int, _LttngLiveViewerSessionMetadataStreamState] stream_id = base_stream_id for trace in tc_descr.traces: @@ -1226,11 +1333,14 @@ class _LttngLiveViewerSessionTracingSessionState: return self._is_attached @is_attached.setter - def is_attached(self, value): + def is_attached(self, value: bool): self._is_attached = value -def needs_new_metadata_section(metadata_stream_state, latest_timestamp): +def needs_new_metadata_section( + metadata_stream_state: _LttngLiveViewerSessionMetadataStreamState, + latest_timestamp: int, +): if metadata_stream_state.next_section_timestamp is None: return False @@ -1245,13 +1355,17 @@ def needs_new_metadata_section(metadata_stream_state, latest_timestamp): class _LttngLiveViewerSession: def __init__( self, - viewer_session_id, - tracing_session_descriptors, - max_query_data_response_size, + viewer_session_id: int, + tracing_session_descriptors: Iterable[LttngTracingSessionDescriptor], + max_query_data_response_size: Optional[int], ): self._viewer_session_id = viewer_session_id - self._ts_states = {} - self._stream_states = {} + self._ts_states = ( + {} + ) # type: dict[int, _LttngLiveViewerSessionTracingSessionState] + self._stream_states = ( + {} + ) # type: dict[int, _LttngLiveViewerSessionDataStreamState | _LttngLiveViewerSessionMetadataStreamState] self._max_query_data_response_size = max_query_data_response_size total_stream_infos = 0 @@ -1276,13 +1390,13 @@ class _LttngLiveViewerSession: _LttngLiveViewerGetNewStreamInfosCommand: self._handle_get_new_stream_infos_command, _LttngLiveViewerGetNextDataStreamIndexEntryCommand: self._handle_get_next_data_stream_index_entry_command, _LttngLiveViewerGetTracingSessionInfosCommand: self._handle_get_tracing_session_infos_command, - } + } # type: dict[type[_LttngLiveViewerCommand], Callable[[Any], _LttngLiveViewerReply]] @property def viewer_session_id(self): return self._viewer_session_id - def _get_tracing_session_state(self, tracing_session_id): + def _get_tracing_session_state(self, tracing_session_id: int): if tracing_session_id not in self._ts_states: raise UnexpectedInput( "Unknown tracing session ID {}".format(tracing_session_id) @@ -1290,13 +1404,27 @@ class _LttngLiveViewerSession: return self._ts_states[tracing_session_id] - def _get_stream_state(self, stream_id): + def _get_data_stream_state(self, stream_id: int): if stream_id not in self._stream_states: - UnexpectedInput("Unknown stream ID {}".format(stream_id)) + RuntimeError("Unknown stream ID {}".format(stream_id)) - return self._stream_states[stream_id] + stream = self._stream_states[stream_id] + if type(stream) is not _LttngLiveViewerSessionDataStreamState: + raise RuntimeError("Stream is not a data stream") - def handle_command(self, cmd): + return stream + + def _get_metadata_stream_state(self, stream_id: int): + if stream_id not in self._stream_states: + RuntimeError("Unknown stream ID {}".format(stream_id)) + + stream = self._stream_states[stream_id] + if type(stream) is not _LttngLiveViewerSessionMetadataStreamState: + raise RuntimeError("Stream is not a metadata stream") + + return stream + + def handle_command(self, cmd: _LttngLiveViewerCommand): logging.info( "Handling command in viewer session: cmd-cls-name={}".format( cmd.__class__.__name__ @@ -1311,7 +1439,9 @@ class _LttngLiveViewerSession: return self._command_handlers[cmd_type](cmd) - def _handle_attach_to_tracing_session_command(self, cmd): + def _handle_attach_to_tracing_session_command( + self, cmd: _LttngLiveViewerAttachToTracingSessionCommand + ): fmt = 'Handling "attach to tracing session" command: ts-id={}, offset={}, seek-type={}' logging.info(fmt.format(cmd.tracing_session_id, cmd.offset, cmd.seek_type)) ts_state = self._get_tracing_session_state(cmd.tracing_session_id) @@ -1330,7 +1460,9 @@ class _LttngLiveViewerSession: status, ts_state.stream_infos ) - def _handle_detach_from_tracing_session_command(self, cmd): + def _handle_detach_from_tracing_session_command( + self, cmd: _LttngLiveViewerDetachFromTracingSessionCommand + ): fmt = 'Handling "detach from tracing session" command: ts-id={}' logging.info(fmt.format(cmd.tracing_session_id)) ts_state = self._get_tracing_session_state(cmd.tracing_session_id) @@ -1347,16 +1479,15 @@ class _LttngLiveViewerSession: status = _LttngLiveViewerDetachFromTracingSessionReply.Status.OK return _LttngLiveViewerDetachFromTracingSessionReply(status) - def _handle_get_next_data_stream_index_entry_command(self, cmd): + def _handle_get_next_data_stream_index_entry_command( + self, cmd: _LttngLiveViewerGetNextDataStreamIndexEntryCommand + ): fmt = 'Handling "get next data stream index entry" command: stream-id={}' logging.info(fmt.format(cmd.stream_id)) - stream_state = self._get_stream_state(cmd.stream_id) - metadata_stream_state = self._get_stream_state(stream_state._metadata_stream_id) - - if type(stream_state) is not _LttngLiveViewerSessionDataStreamState: - raise UnexpectedInput( - "Stream with ID {} is not a data stream".format(cmd.stream_id) - ) + stream_state = self._get_data_stream_state(cmd.stream_id) + metadata_stream_state = self._get_metadata_stream_state( + stream_state.metadata_stream_id + ) if stream_state.cur_index_entry is None: # The viewer is done reading this stream @@ -1379,10 +1510,9 @@ class _LttngLiveViewerSession: # The viewer only checks the `has_new_metadata` flag if the # reply's status is `OK`, so we need to provide an index here has_new_metadata = stream_state.tracing_session_state.has_new_metadata - if type(stream_state.cur_index_entry) is _LttngDataStreamIndexEntry: + if isinstance(stream_state.cur_index_entry, _LttngDataStreamIndexEntry): status = _LttngLiveViewerGetNextDataStreamIndexEntryReply.Status.OK else: - assert type(stream_state.cur_index_entry) is _LttngDataStreamBeaconEntry status = _LttngLiveViewerGetNextDataStreamIndexEntryReply.Status.INACTIVE reply = _LttngLiveViewerGetNextDataStreamIndexEntryReply( @@ -1391,17 +1521,14 @@ class _LttngLiveViewerSession: stream_state.goto_next_index_entry() return reply - def _handle_get_data_stream_packet_data_command(self, cmd): + def _handle_get_data_stream_packet_data_command( + self, cmd: _LttngLiveViewerGetDataStreamPacketDataCommand + ): fmt = 'Handling "get data stream packet data" command: stream-id={}, offset={}, req-length={}' logging.info(fmt.format(cmd.stream_id, cmd.offset, cmd.req_length)) - stream_state = self._get_stream_state(cmd.stream_id) + stream_state = self._get_data_stream_state(cmd.stream_id) data_response_length = cmd.req_length - if type(stream_state) is not _LttngLiveViewerSessionDataStreamState: - raise UnexpectedInput( - "Stream with ID {} is not a data stream".format(cmd.stream_id) - ) - if stream_state.tracing_session_state.has_new_metadata: status = _LttngLiveViewerGetDataStreamPacketDataReply.Status.ERROR return _LttngLiveViewerGetDataStreamPacketDataReply( @@ -1422,18 +1549,12 @@ class _LttngLiveViewerSession: status = _LttngLiveViewerGetDataStreamPacketDataReply.Status.OK return _LttngLiveViewerGetDataStreamPacketDataReply(status, data, False, False) - def _handle_get_metadata_stream_data_command(self, cmd): + def _handle_get_metadata_stream_data_command( + self, cmd: _LttngLiveViewerGetMetadataStreamDataCommand + ): fmt = 'Handling "get metadata stream data" command: stream-id={}' logging.info(fmt.format(cmd.stream_id)) - metadata_stream_state = self._get_stream_state(cmd.stream_id) - - if ( - type(metadata_stream_state) - is not _LttngLiveViewerSessionMetadataStreamState - ): - raise UnexpectedInput( - "Stream with ID {} is not a metadata stream".format(cmd.stream_id) - ) + metadata_stream_state = self._get_metadata_stream_state(cmd.stream_id) if metadata_stream_state.is_sent: status = _LttngLiveViewerGetMetadataStreamDataContentReply.Status.NO_NEW @@ -1455,7 +1576,9 @@ class _LttngLiveViewerSession: status, metadata_section.data ) - def _handle_get_new_stream_infos_command(self, cmd): + def _handle_get_new_stream_infos_command( + self, cmd: _LttngLiveViewerGetNewStreamInfosCommand + ): fmt = 'Handling "get new stream infos" command: ts-id={}' logging.info(fmt.format(cmd.tracing_session_id)) @@ -1467,7 +1590,9 @@ class _LttngLiveViewerSession: status = _LttngLiveViewerGetNewStreamInfosReply.Status.HUP return _LttngLiveViewerGetNewStreamInfosReply(status, []) - def _handle_get_tracing_session_infos_command(self, cmd): + def _handle_get_tracing_session_infos_command( + self, cmd: _LttngLiveViewerGetTracingSessionInfosCommand + ): logging.info('Handling "get tracing session infos" command.') infos = [ tss.tracing_session_descriptor.info for tss in self._ts_states.values() @@ -1475,7 +1600,9 @@ class _LttngLiveViewerSession: infos.sort(key=lambda info: info.name) return _LttngLiveViewerGetTracingSessionInfosReply(infos) - def _handle_create_viewer_session_command(self, cmd): + def _handle_create_viewer_session_command( + self, cmd: _LttngLiveViewerCreateViewerSessionCommand + ): logging.info('Handling "create viewer session" command.') status = _LttngLiveViewerCreateViewerSessionReply.Status.OK @@ -1501,10 +1628,10 @@ class _LttngLiveViewerSession: class LttngLiveServer: def __init__( self, - port, - port_filename, - tracing_session_descriptors, - max_query_data_response_size, + port: Optional[int], + port_filename: str, + tracing_session_descriptors: Iterable[LttngTracingSessionDescriptor], + max_query_data_response_size: Optional[int], ): logging.info("Server configuration:") @@ -1590,7 +1717,7 @@ class LttngLiveServer: ) return cmd - def _send_reply(self, reply): + def _send_reply(self, reply: _LttngLiveViewerReply): data = self._codec.encode(reply) logging.info( "Sending reply to viewer: reply-cls-name={}, length={}".format( @@ -1649,7 +1776,7 @@ class LttngLiveServer: finally: self._conn.close() - def _write_port_to_file(self, port_filename): + def _write_port_to_file(self, port_filename: str): # Write the port number to a temporary file. with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_port_file: print(self._server_port, end="", file=tmp_port_file) @@ -1663,7 +1790,9 @@ class LttngLiveServer: ) -def _session_descriptors_from_path(sessions_filename, trace_path_prefix): +def _session_descriptors_from_path( + sessions_filename: str, trace_path_prefix: Optional[str] +): # File format is: # # [ @@ -1693,27 +1822,43 @@ def _session_descriptors_from_path(sessions_filename, trace_path_prefix): # } # ] with open(sessions_filename, "r") as sessions_file: - params = json.load(sessions_file) + sessions_json = tjson.load(sessions_file, tjson.ArrayVal) + + sessions = [] # type: list[LttngTracingSessionDescriptor] - sessions = [] + for session_json in sessions_json.iter(tjson.ObjVal): + name = session_json.at("name", tjson.StrVal).val + tracing_session_id = session_json.at("id", tjson.IntVal).val + hostname = session_json.at("hostname", tjson.StrVal).val + live_timer_freq = session_json.at("live-timer-freq", tjson.IntVal).val + client_count = session_json.at("client-count", tjson.IntVal).val + traces_json = session_json.at("traces", tjson.ArrayVal) - for session in params: - name = session["name"] - tracing_session_id = session["id"] - hostname = session["hostname"] - live_timer_freq = session["live-timer-freq"] - client_count = session["client-count"] - traces = [] + traces = [] # type: list[LttngTrace] - for trace in session["traces"]: - metadata_sections = trace.get("metadata-sections") - beacons = trace.get("beacons") - path = trace["path"] + for trace_json in traces_json.iter(tjson.ObjVal): + metadata_sections = ( + trace_json.at("metadata-sections", tjson.ArrayVal) + if "metadata-sections" in trace_json + else None + ) + beacons = ( + trace_json.at("beacons", tjson.ObjVal) + if "beacons" in trace_json + else None + ) + path = trace_json.at("path", tjson.StrVal).val - if not os.path.isabs(path): + if not os.path.isabs(path) and trace_path_prefix: path = os.path.join(trace_path_prefix, path) - traces.append(LttngTrace(path, metadata_sections, beacons)) + traces.append( + LttngTrace( + path, + metadata_sections, + beacons, + ) + ) sessions.append( LttngTracingSessionDescriptor( @@ -1729,7 +1874,7 @@ def _session_descriptors_from_path(sessions_filename, trace_path_prefix): return sessions -def _loglevel_parser(string): +def _loglevel_parser(string: str): loglevels = {"info": logging.INFO, "warning": logging.WARNING} if string not in loglevels: msg = "{} is not a valid loglevel".format(string) @@ -1787,13 +1932,19 @@ if __name__ == "__main__": args = parser.parse_args(args=remaining_args) try: + sessions_filename = args.sessions_filename # type: str + trace_path_prefix = args.trace_path_prefix # type: str | None sessions = _session_descriptors_from_path( - args.sessions_filename, - args.trace_path_prefix, - ) - LttngLiveServer( - args.port, args.port_filename, sessions, args.max_query_data_response_size + sessions_filename, + trace_path_prefix, ) + + port = args.port # type: int | None + port_filename = args.port_filename # type: str + max_query_data_response_size = ( + args.max_query_data_response_size + ) # type: int | None + LttngLiveServer(port, port_filename, sessions, max_query_data_response_size) except UnexpectedInput as exc: logging.error(str(exc)) print(exc, file=sys.stderr) diff --git a/tests/plugins/src.ctf.lttng-live/test_live b/tests/plugins/src.ctf.lttng-live/test_live index 9bbba432..f694d799 100755 --- a/tests/plugins/src.ctf.lttng-live/test_live +++ b/tests/plugins/src.ctf.lttng-live/test_live @@ -56,7 +56,7 @@ lttng_live_server() { # start server diag "$BT_TESTS_PYTHON_BIN $server_script --port-file $port_file --trace-path-prefix $trace_dir_native $server_args" - echo "$server_args" | xargs "$BT_TESTS_PYTHON_BIN" "$server_script" \ + echo "$server_args" | run_python xargs "$BT_TESTS_PYTHON_BIN" "$server_script" \ --port-file "$port_file" \ --trace-path-prefix "$trace_dir_native" & -- 2.34.1