From dfac2ae4dd6902b9870aee50007cfa7f4aa18239 Mon Sep 17 00:00:00 2001 From: matejcik Date: Wed, 6 Nov 2024 14:28:07 +0100 Subject: [PATCH] style(python): upgrade debuglink.py to 3.10 style type annotations --- python/src/trezorlib/debuglink.py | 127 ++++++++++++++---------------- 1 file changed, 61 insertions(+), 66 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index a6f77b6fdc..aae023c8eb 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -14,6 +14,8 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import json import logging import re @@ -33,11 +35,8 @@ from typing import ( Generator, Iterable, Iterator, - List, - Optional, Sequence, Tuple, - Type, Union, ) @@ -57,7 +56,7 @@ if TYPE_CHECKING: from .transport import Transport ExpectedMessage = Union[ - protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter" + protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" ] AnyDict = Dict[str, Any] @@ -65,8 +64,8 @@ if TYPE_CHECKING: class InputFunc(Protocol): def __call__( self, - hold_ms: Optional[int] = None, - wait: Optional[bool] = None, + hold_ms: int | None = None, + wait: bool | None = None, ) -> "LayoutContent": ... @@ -101,14 +100,14 @@ class UnstructuredJSONReader: self.json_str = json_str # We may not receive valid JSON, e.g. from an old model in upgrade tests try: - self.dict: "AnyDict" = json.loads(json_str) + self.dict: AnyDict = json.loads(json_str) except json.JSONDecodeError: self.dict = {} def top_level_value(self, key: str) -> Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> List["AnyDict"]: + def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: def recursively_find(data: Any) -> Iterator[Any]: if isinstance(data, dict): if data.get(key) == value: @@ -123,16 +122,14 @@ class UnstructuredJSONReader: def find_unique_object_with_key_and_value( self, key: str, value: Any - ) -> Optional["AnyDict"]: + ) -> AnyDict | None: objects = self.find_objects_with_key_and_value(key, value) if not objects: return None assert len(objects) == 1 return objects[0] - def find_values_by_key( - self, key: str, only_type: Optional[type] = None - ) -> List[Any]: + def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: def recursively_find(data: Any) -> Iterator[Any]: if isinstance(data, dict): if key in data: @@ -151,7 +148,7 @@ class UnstructuredJSONReader: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: Optional[type] = None + self, key: str, default: Any, only_type: type | None = None ) -> Any: values = self.find_values_by_key(key, only_type=only_type) if not values: @@ -171,7 +168,7 @@ class LayoutContent(UnstructuredJSONReader): """Getting the main component of the layout.""" return self.top_level_value("component") or "no main component" - def all_components(self) -> List[str]: + def all_components(self) -> list[str]: """Getting all components of the layout.""" return self.find_values_by_key("component", only_type=str) @@ -209,7 +206,7 @@ class LayoutContent(UnstructuredJSONReader): def title(self) -> str: """Getting text that is displayed as a title and potentially subtitle.""" # There could be possibly subtitle as well - title_parts: List[str] = [] + title_parts: list[str] = [] title = self._get_str_or_dict_text("title") if title: @@ -244,7 +241,7 @@ class LayoutContent(UnstructuredJSONReader): # Look for paragraphs first (will match most of the time for TT) paragraphs = self.raw_content_paragraphs() if paragraphs: - main_text_blocks: List[str] = [] + main_text_blocks: list[str] = [] for par in paragraphs: par_content = "" for line_or_newline in par: @@ -294,13 +291,13 @@ class LayoutContent(UnstructuredJSONReader): # Default when not finding anything return self.main_component() - def raw_content_paragraphs(self) -> Optional[List[List[str]]]: + def raw_content_paragraphs(self) -> list[list[str]] | None: """Getting raw paragraphs as sent from Rust.""" return self.find_unique_value_by_key("paragraphs", default=None, only_type=list) - def tt_check_seed_button_contents(self) -> List[str]: + def tt_check_seed_button_contents(self) -> list[str]: """Getting list of button contents.""" - buttons: List[str] = [] + buttons: list[str] = [] button_objects = self.find_objects_with_key_and_value("component", "Button") for button in button_objects: if button.get("icon"): @@ -309,7 +306,7 @@ class LayoutContent(UnstructuredJSONReader): buttons.append(button["text"]) return buttons - def button_contents(self) -> List[str]: + def button_contents(self) -> list[str]: """Getting list of button contents.""" buttons = self.find_unique_value_by_key("buttons", default={}, only_type=dict) @@ -331,13 +328,13 @@ class LayoutContent(UnstructuredJSONReader): button_keys = ("left_btn", "middle_btn", "right_btn") return [get_button_content(btn_key) for btn_key in button_keys] - def seed_words(self) -> List[str]: + def seed_words(self) -> list[str]: """Get all the seed words on the screen in order. Example content: "1. ladybug\n2. acid\n3. academic\n4. afraid" -> ["ladybug", "acid", "academic", "afraid"] """ - words: List[str] = [] + words: list[str] = [] for line in self.screen_content().split("\n"): # Dot after index is optional (present on TT, not on TR) match = re.match(r"^\s*\d+\.? (\w+)$", line) @@ -377,7 +374,7 @@ class LayoutContent(UnstructuredJSONReader): """What is the choice being selected right now.""" return self.choice_items()[1] - def choice_items(self) -> Tuple[str, str, str]: + def choice_items(self) -> tuple[str, str, str]: """Getting actions for all three possible buttons.""" choice_obj = self.find_unique_value_by_key( "choice_page", default={}, only_type=dict @@ -396,15 +393,15 @@ class LayoutContent(UnstructuredJSONReader): return footer.get("description", "") + " " + footer.get("instruction", "") -def multipage_content(layouts: List[LayoutContent]) -> str: +def multipage_content(layouts: list[LayoutContent]) -> str: """Get overall content from multiple-page layout.""" return "".join(layout.text_content() for layout in layouts) def _make_input_func( - button: Optional[messages.DebugButton] = None, - physical_button: Optional[messages.DebugPhysicalButton] = None, - swipe: Optional[messages.DebugSwipeDirection] = None, + button: messages.DebugButton | None = None, + physical_button: messages.DebugPhysicalButton | None = None, + swipe: messages.DebugSwipeDirection | None = None, ) -> "InputFunc": decision = messages.DebugLinkDecision( button=button, @@ -414,8 +411,8 @@ def _make_input_func( def input_func( self: "DebugLink", - hold_ms: Optional[int] = None, - wait: Optional[bool] = None, + hold_ms: int | None = None, + wait: bool | None = None, ) -> LayoutContent: __tracebackhide__ = True # for pytest # pylint: disable=W0612 decision.hold_ms = hold_ms @@ -431,18 +428,18 @@ class DebugLink: self.mapping = mapping.DEFAULT_MAPPING # To be set by TrezorClientDebugLink (is not known during creation time) - self.model: Optional[models.TrezorModel] = None - self.version: Tuple[int, int, int] = (0, 0, 0) + self.model: models.TrezorModel | None = None + self.version: tuple[int, int, int] = (0, 0, 0) # Where screenshots are being saved - self.screenshot_recording_dir: Optional[str] = None + self.screenshot_recording_dir: str | None = None # For T1 screenshotting functionality in DebugUI - self.t1_screenshot_directory: Optional[Path] = None + self.t1_screenshot_directory: Path | None = None self.t1_screenshot_counter = 0 # Optional file for saving text representation of the screen - self.screen_text_file: Optional[Path] = None + self.screen_text_file: Path | None = None self.last_screen_content = "" self.waiting_for_layout_change = False @@ -470,7 +467,7 @@ class DebugLink: assert self.model is not None return LayoutType.from_model(self.model) - def set_screen_text_file(self, file_path: Optional[Path]) -> None: + def set_screen_text_file(self, file_path: Path | None) -> None: if file_path is not None: file_path.write_bytes(b"") self.screen_text_file = file_path @@ -609,7 +606,7 @@ class DebugLink: """ self._call(messages.DebugLinkWatchLayout(watch=watch)) - def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str: + def encode_pin(self, pin: str, matrix: str | None = None) -> str: """Transform correct PIN according to the displayed matrix.""" if matrix is None: matrix = self.state().matrix @@ -619,7 +616,7 @@ class DebugLink: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]: + def read_recovery_word(self) -> Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -628,7 +625,7 @@ class DebugLink: return state.reset_word def _decision( - self, decision: messages.DebugLinkDecision, wait: Optional[bool] = None + self, decision: messages.DebugLinkDecision, wait: bool | None = None ) -> LayoutContent: """Send a debuglink decision and returns the resulting layout. @@ -691,15 +688,15 @@ class DebugLink: ) """Press right button. See `_decision` for more details.""" - def input(self, word: str, wait: Optional[bool] = None) -> LayoutContent: + def input(self, word: str, wait: bool | None = None) -> LayoutContent: """Send text input to the device. See `_decision` for more details.""" return self._decision(messages.DebugLinkDecision(input=word), wait) def click( self, click: Tuple[int, int], - hold_ms: Optional[int] = None, - wait: Optional[bool] = None, + hold_ms: int | None = None, + wait: bool | None = None, ) -> LayoutContent: """Send a click to the device. See `_decision` for more details.""" x, y = click @@ -750,9 +747,7 @@ class DebugLink: def reseed(self, value: int) -> protobuf.MessageType: return self._call(messages.DebugLinkReseedRandom(value=value)) - def start_recording( - self, directory: str, refresh_index: Optional[int] = None - ) -> None: + def start_recording(self, directory: str, refresh_index: int | None = None) -> None: self.screenshot_recording_dir = directory # Different recording logic between core and legacy if self.model is not models.T1B1: @@ -807,7 +802,7 @@ class DebugLink: assert len(data) == 128 * 64 // 8 - pixels: List[int] = [] + pixels: list[int] = [] for byteline in range(64 // 8): offset = byteline * 128 row = data[offset : offset + 128] @@ -840,7 +835,7 @@ class NullDebugLink(DebugLink): def _call( self, msg: protobuf.MessageType, nowait: bool = False - ) -> Optional[messages.DebugLinkState]: + ) -> messages.DebugLinkState | None: if not nowait: if isinstance(msg, messages.DebugLinkGetState): return messages.DebugLinkState() @@ -858,7 +853,7 @@ class DebugUI: self.clear() def clear(self) -> None: - self.pins: Optional[Iterator[str]] = None + self.pins: Iterator[str] | None = None self.passphrase = "" self.input_flow: Union[ Generator[None, messages.ButtonRequest, None], object, None @@ -897,7 +892,7 @@ class DebugUI: except StopIteration: self.input_flow = self.INPUT_FLOW_DONE - def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str: + def get_pin(self, code: PinMatrixRequestType | None = None) -> str: self.debuglink.snapshot_legacy() if self.pins is None: @@ -914,7 +909,7 @@ class DebugUI: class MessageFilter: - def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None: + def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None: self.message_type = message_type self.fields: Dict[str, Any] = {} self.update_fields(**fields) @@ -967,7 +962,7 @@ class MessageFilter: return True def to_string(self, maxwidth: int = 80) -> str: - fields: List[Tuple[str, str]] = [] + fields: list[Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -988,7 +983,7 @@ class MessageFilter: if len(oneline_str) < maxwidth: return f"{self.message_type.__name__}({oneline_str})" else: - item: List[str] = [] + item: list[str] = [] item.append(f"{self.message_type.__name__}(") for pair in pairs: item.append(f" {pair}") @@ -1052,11 +1047,11 @@ class TrezorClientDebugLink(TrezorClient): """ self.ui: DebugUI = DebugUI(self.debug) self.in_with_statement = False - self.expected_responses: Optional[List[MessageFilter]] = None - self.actual_responses: Optional[List[protobuf.MessageType]] = None - self.filters: Dict[ - Type[protobuf.MessageType], - Optional[Callable[[protobuf.MessageType], protobuf.MessageType]], + self.expected_responses: list[MessageFilter] | None = None + self.actual_responses: list[protobuf.MessageType] | None = None + self.filters: dict[ + type[protobuf.MessageType], + Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} def ensure_open(self) -> None: @@ -1076,8 +1071,8 @@ class TrezorClientDebugLink(TrezorClient): def set_filter( self, - message_type: Type[protobuf.MessageType], - callback: Optional[Callable[[protobuf.MessageType], protobuf.MessageType]], + message_type: type[protobuf.MessageType], + callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, ) -> None: """Configure a filter function for a specified message type. @@ -1102,7 +1097,7 @@ class TrezorClientDebugLink(TrezorClient): return msg def set_input_flow( - self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] + self, input_flow: Generator[None, messages.ButtonRequest | None, None] ) -> None: """Configure a sequence of input events for the current with-block. @@ -1184,7 +1179,7 @@ class TrezorClientDebugLink(TrezorClient): input_flow.throw(exc_type, value, traceback) def set_expected_responses( - self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] ) -> None: """Set a sequence of expected responses to client calls. @@ -1251,10 +1246,10 @@ class TrezorClientDebugLink(TrezorClient): return super()._raw_write(self._filter_message(msg)) @staticmethod - def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]: + def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) - output: List[str] = [] + output: list[str] = [] output.append("Expected responses:") if start_at > 0: output.append(f" (...{start_at} previous responses omitted)") @@ -1272,8 +1267,8 @@ class TrezorClientDebugLink(TrezorClient): @classmethod def _verify_responses( cls, - expected: Optional[List[MessageFilter]], - actual: Optional[List[protobuf.MessageType]], + expected: list[MessageFilter] | None, + actual: list[protobuf.MessageType] | None, ) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 @@ -1350,9 +1345,9 @@ class TrezorClientDebugLink(TrezorClient): def load_device( client: "TrezorClient", mnemonic: Union[str, Iterable[str]], - pin: Optional[str], + pin: str | None, passphrase_protection: bool, - label: Optional[str], + label: str | None, skip_checksum: bool = False, needs_backup: bool = False, no_backup: bool = False,