1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-15 19:08:07 +00:00

style(python): upgrade debuglink.py to 3.10 style type annotations

This commit is contained in:
matejcik 2024-11-06 14:28:07 +01:00 committed by matejcik
parent 3a8f92f64d
commit dfac2ae4dd

View File

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import json import json
import logging import logging
import re import re
@ -33,11 +35,8 @@ from typing import (
Generator, Generator,
Iterable, Iterable,
Iterator, Iterator,
List,
Optional,
Sequence, Sequence,
Tuple, Tuple,
Type,
Union, Union,
) )
@ -57,7 +56,7 @@ if TYPE_CHECKING:
from .transport import Transport from .transport import Transport
ExpectedMessage = Union[ ExpectedMessage = Union[
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter" protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
] ]
AnyDict = Dict[str, Any] AnyDict = Dict[str, Any]
@ -65,8 +64,8 @@ if TYPE_CHECKING:
class InputFunc(Protocol): class InputFunc(Protocol):
def __call__( def __call__(
self, self,
hold_ms: Optional[int] = None, hold_ms: int | None = None,
wait: Optional[bool] = None, wait: bool | None = None,
) -> "LayoutContent": ... ) -> "LayoutContent": ...
@ -101,14 +100,14 @@ class UnstructuredJSONReader:
self.json_str = json_str self.json_str = json_str
# We may not receive valid JSON, e.g. from an old model in upgrade tests # We may not receive valid JSON, e.g. from an old model in upgrade tests
try: try:
self.dict: "AnyDict" = json.loads(json_str) self.dict: AnyDict = json.loads(json_str)
except json.JSONDecodeError: except json.JSONDecodeError:
self.dict = {} self.dict = {}
def top_level_value(self, key: str) -> Any: def top_level_value(self, key: str) -> Any:
return self.dict.get(key) return self.dict.get(key)
def find_objects_with_key_and_value(self, key: str, value: Any) -> List["AnyDict"]: def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
def recursively_find(data: Any) -> Iterator[Any]: def recursively_find(data: Any) -> Iterator[Any]:
if isinstance(data, dict): if isinstance(data, dict):
if data.get(key) == value: if data.get(key) == value:
@ -123,16 +122,14 @@ class UnstructuredJSONReader:
def find_unique_object_with_key_and_value( def find_unique_object_with_key_and_value(
self, key: str, value: Any self, key: str, value: Any
) -> Optional["AnyDict"]: ) -> AnyDict | None:
objects = self.find_objects_with_key_and_value(key, value) objects = self.find_objects_with_key_and_value(key, value)
if not objects: if not objects:
return None return None
assert len(objects) == 1 assert len(objects) == 1
return objects[0] return objects[0]
def find_values_by_key( def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
self, key: str, only_type: Optional[type] = None
) -> List[Any]:
def recursively_find(data: Any) -> Iterator[Any]: def recursively_find(data: Any) -> Iterator[Any]:
if isinstance(data, dict): if isinstance(data, dict):
if key in data: if key in data:
@ -151,7 +148,7 @@ class UnstructuredJSONReader:
return values return values
def find_unique_value_by_key( def find_unique_value_by_key(
self, key: str, default: Any, only_type: Optional[type] = None self, key: str, default: Any, only_type: type | None = None
) -> Any: ) -> Any:
values = self.find_values_by_key(key, only_type=only_type) values = self.find_values_by_key(key, only_type=only_type)
if not values: if not values:
@ -171,7 +168,7 @@ class LayoutContent(UnstructuredJSONReader):
"""Getting the main component of the layout.""" """Getting the main component of the layout."""
return self.top_level_value("component") or "no main component" return self.top_level_value("component") or "no main component"
def all_components(self) -> List[str]: def all_components(self) -> list[str]:
"""Getting all components of the layout.""" """Getting all components of the layout."""
return self.find_values_by_key("component", only_type=str) return self.find_values_by_key("component", only_type=str)
@ -209,7 +206,7 @@ class LayoutContent(UnstructuredJSONReader):
def title(self) -> str: def title(self) -> str:
"""Getting text that is displayed as a title and potentially subtitle.""" """Getting text that is displayed as a title and potentially subtitle."""
# There could be possibly subtitle as well # There could be possibly subtitle as well
title_parts: List[str] = [] title_parts: list[str] = []
title = self._get_str_or_dict_text("title") title = self._get_str_or_dict_text("title")
if title: if title:
@ -244,7 +241,7 @@ class LayoutContent(UnstructuredJSONReader):
# Look for paragraphs first (will match most of the time for TT) # Look for paragraphs first (will match most of the time for TT)
paragraphs = self.raw_content_paragraphs() paragraphs = self.raw_content_paragraphs()
if paragraphs: if paragraphs:
main_text_blocks: List[str] = [] main_text_blocks: list[str] = []
for par in paragraphs: for par in paragraphs:
par_content = "" par_content = ""
for line_or_newline in par: for line_or_newline in par:
@ -294,13 +291,13 @@ class LayoutContent(UnstructuredJSONReader):
# Default when not finding anything # Default when not finding anything
return self.main_component() return self.main_component()
def raw_content_paragraphs(self) -> Optional[List[List[str]]]: def raw_content_paragraphs(self) -> list[list[str]] | None:
"""Getting raw paragraphs as sent from Rust.""" """Getting raw paragraphs as sent from Rust."""
return self.find_unique_value_by_key("paragraphs", default=None, only_type=list) return self.find_unique_value_by_key("paragraphs", default=None, only_type=list)
def tt_check_seed_button_contents(self) -> List[str]: def tt_check_seed_button_contents(self) -> list[str]:
"""Getting list of button contents.""" """Getting list of button contents."""
buttons: List[str] = [] buttons: list[str] = []
button_objects = self.find_objects_with_key_and_value("component", "Button") button_objects = self.find_objects_with_key_and_value("component", "Button")
for button in button_objects: for button in button_objects:
if button.get("icon"): if button.get("icon"):
@ -309,7 +306,7 @@ class LayoutContent(UnstructuredJSONReader):
buttons.append(button["text"]) buttons.append(button["text"])
return buttons return buttons
def button_contents(self) -> List[str]: def button_contents(self) -> list[str]:
"""Getting list of button contents.""" """Getting list of button contents."""
buttons = self.find_unique_value_by_key("buttons", default={}, only_type=dict) buttons = self.find_unique_value_by_key("buttons", default={}, only_type=dict)
@ -331,13 +328,13 @@ class LayoutContent(UnstructuredJSONReader):
button_keys = ("left_btn", "middle_btn", "right_btn") button_keys = ("left_btn", "middle_btn", "right_btn")
return [get_button_content(btn_key) for btn_key in button_keys] return [get_button_content(btn_key) for btn_key in button_keys]
def seed_words(self) -> List[str]: def seed_words(self) -> list[str]:
"""Get all the seed words on the screen in order. """Get all the seed words on the screen in order.
Example content: "1. ladybug\n2. acid\n3. academic\n4. afraid" Example content: "1. ladybug\n2. acid\n3. academic\n4. afraid"
-> ["ladybug", "acid", "academic", "afraid"] -> ["ladybug", "acid", "academic", "afraid"]
""" """
words: List[str] = [] words: list[str] = []
for line in self.screen_content().split("\n"): for line in self.screen_content().split("\n"):
# Dot after index is optional (present on TT, not on TR) # Dot after index is optional (present on TT, not on TR)
match = re.match(r"^\s*\d+\.? (\w+)$", line) match = re.match(r"^\s*\d+\.? (\w+)$", line)
@ -377,7 +374,7 @@ class LayoutContent(UnstructuredJSONReader):
"""What is the choice being selected right now.""" """What is the choice being selected right now."""
return self.choice_items()[1] return self.choice_items()[1]
def choice_items(self) -> Tuple[str, str, str]: def choice_items(self) -> tuple[str, str, str]:
"""Getting actions for all three possible buttons.""" """Getting actions for all three possible buttons."""
choice_obj = self.find_unique_value_by_key( choice_obj = self.find_unique_value_by_key(
"choice_page", default={}, only_type=dict "choice_page", default={}, only_type=dict
@ -396,15 +393,15 @@ class LayoutContent(UnstructuredJSONReader):
return footer.get("description", "") + " " + footer.get("instruction", "") return footer.get("description", "") + " " + footer.get("instruction", "")
def multipage_content(layouts: List[LayoutContent]) -> str: def multipage_content(layouts: list[LayoutContent]) -> str:
"""Get overall content from multiple-page layout.""" """Get overall content from multiple-page layout."""
return "".join(layout.text_content() for layout in layouts) return "".join(layout.text_content() for layout in layouts)
def _make_input_func( def _make_input_func(
button: Optional[messages.DebugButton] = None, button: messages.DebugButton | None = None,
physical_button: Optional[messages.DebugPhysicalButton] = None, physical_button: messages.DebugPhysicalButton | None = None,
swipe: Optional[messages.DebugSwipeDirection] = None, swipe: messages.DebugSwipeDirection | None = None,
) -> "InputFunc": ) -> "InputFunc":
decision = messages.DebugLinkDecision( decision = messages.DebugLinkDecision(
button=button, button=button,
@ -414,8 +411,8 @@ def _make_input_func(
def input_func( def input_func(
self: "DebugLink", self: "DebugLink",
hold_ms: Optional[int] = None, hold_ms: int | None = None,
wait: Optional[bool] = None, wait: bool | None = None,
) -> LayoutContent: ) -> LayoutContent:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
decision.hold_ms = hold_ms decision.hold_ms = hold_ms
@ -431,18 +428,18 @@ class DebugLink:
self.mapping = mapping.DEFAULT_MAPPING self.mapping = mapping.DEFAULT_MAPPING
# To be set by TrezorClientDebugLink (is not known during creation time) # To be set by TrezorClientDebugLink (is not known during creation time)
self.model: Optional[models.TrezorModel] = None self.model: models.TrezorModel | None = None
self.version: Tuple[int, int, int] = (0, 0, 0) self.version: tuple[int, int, int] = (0, 0, 0)
# Where screenshots are being saved # Where screenshots are being saved
self.screenshot_recording_dir: Optional[str] = None self.screenshot_recording_dir: str | None = None
# For T1 screenshotting functionality in DebugUI # For T1 screenshotting functionality in DebugUI
self.t1_screenshot_directory: Optional[Path] = None self.t1_screenshot_directory: Path | None = None
self.t1_screenshot_counter = 0 self.t1_screenshot_counter = 0
# Optional file for saving text representation of the screen # Optional file for saving text representation of the screen
self.screen_text_file: Optional[Path] = None self.screen_text_file: Path | None = None
self.last_screen_content = "" self.last_screen_content = ""
self.waiting_for_layout_change = False self.waiting_for_layout_change = False
@ -470,7 +467,7 @@ class DebugLink:
assert self.model is not None assert self.model is not None
return LayoutType.from_model(self.model) return LayoutType.from_model(self.model)
def set_screen_text_file(self, file_path: Optional[Path]) -> None: def set_screen_text_file(self, file_path: Path | None) -> None:
if file_path is not None: if file_path is not None:
file_path.write_bytes(b"") file_path.write_bytes(b"")
self.screen_text_file = file_path self.screen_text_file = file_path
@ -609,7 +606,7 @@ class DebugLink:
""" """
self._call(messages.DebugLinkWatchLayout(watch=watch)) self._call(messages.DebugLinkWatchLayout(watch=watch))
def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str: def encode_pin(self, pin: str, matrix: str | None = None) -> str:
"""Transform correct PIN according to the displayed matrix.""" """Transform correct PIN according to the displayed matrix."""
if matrix is None: if matrix is None:
matrix = self.state().matrix matrix = self.state().matrix
@ -619,7 +616,7 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin]) return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]: def read_recovery_word(self) -> Tuple[str | None, int | None]:
state = self.state() state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos) return (state.recovery_fake_word, state.recovery_word_pos)
@ -628,7 +625,7 @@ class DebugLink:
return state.reset_word return state.reset_word
def _decision( def _decision(
self, decision: messages.DebugLinkDecision, wait: Optional[bool] = None self, decision: messages.DebugLinkDecision, wait: bool | None = None
) -> LayoutContent: ) -> LayoutContent:
"""Send a debuglink decision and returns the resulting layout. """Send a debuglink decision and returns the resulting layout.
@ -691,15 +688,15 @@ class DebugLink:
) )
"""Press right button. See `_decision` for more details.""" """Press right button. See `_decision` for more details."""
def input(self, word: str, wait: Optional[bool] = None) -> LayoutContent: def input(self, word: str, wait: bool | None = None) -> LayoutContent:
"""Send text input to the device. See `_decision` for more details.""" """Send text input to the device. See `_decision` for more details."""
return self._decision(messages.DebugLinkDecision(input=word), wait) return self._decision(messages.DebugLinkDecision(input=word), wait)
def click( def click(
self, self,
click: Tuple[int, int], click: Tuple[int, int],
hold_ms: Optional[int] = None, hold_ms: int | None = None,
wait: Optional[bool] = None, wait: bool | None = None,
) -> LayoutContent: ) -> LayoutContent:
"""Send a click to the device. See `_decision` for more details.""" """Send a click to the device. See `_decision` for more details."""
x, y = click x, y = click
@ -750,9 +747,7 @@ class DebugLink:
def reseed(self, value: int) -> protobuf.MessageType: def reseed(self, value: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkReseedRandom(value=value)) return self._call(messages.DebugLinkReseedRandom(value=value))
def start_recording( def start_recording(self, directory: str, refresh_index: int | None = None) -> None:
self, directory: str, refresh_index: Optional[int] = None
) -> None:
self.screenshot_recording_dir = directory self.screenshot_recording_dir = directory
# Different recording logic between core and legacy # Different recording logic between core and legacy
if self.model is not models.T1B1: if self.model is not models.T1B1:
@ -807,7 +802,7 @@ class DebugLink:
assert len(data) == 128 * 64 // 8 assert len(data) == 128 * 64 // 8
pixels: List[int] = [] pixels: list[int] = []
for byteline in range(64 // 8): for byteline in range(64 // 8):
offset = byteline * 128 offset = byteline * 128
row = data[offset : offset + 128] row = data[offset : offset + 128]
@ -840,7 +835,7 @@ class NullDebugLink(DebugLink):
def _call( def _call(
self, msg: protobuf.MessageType, nowait: bool = False self, msg: protobuf.MessageType, nowait: bool = False
) -> Optional[messages.DebugLinkState]: ) -> messages.DebugLinkState | None:
if not nowait: if not nowait:
if isinstance(msg, messages.DebugLinkGetState): if isinstance(msg, messages.DebugLinkGetState):
return messages.DebugLinkState() return messages.DebugLinkState()
@ -858,7 +853,7 @@ class DebugUI:
self.clear() self.clear()
def clear(self) -> None: def clear(self) -> None:
self.pins: Optional[Iterator[str]] = None self.pins: Iterator[str] | None = None
self.passphrase = "" self.passphrase = ""
self.input_flow: Union[ self.input_flow: Union[
Generator[None, messages.ButtonRequest, None], object, None Generator[None, messages.ButtonRequest, None], object, None
@ -897,7 +892,7 @@ class DebugUI:
except StopIteration: except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str: def get_pin(self, code: PinMatrixRequestType | None = None) -> str:
self.debuglink.snapshot_legacy() self.debuglink.snapshot_legacy()
if self.pins is None: if self.pins is None:
@ -914,7 +909,7 @@ class DebugUI:
class MessageFilter: class MessageFilter:
def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None: def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
self.message_type = message_type self.message_type = message_type
self.fields: Dict[str, Any] = {} self.fields: Dict[str, Any] = {}
self.update_fields(**fields) self.update_fields(**fields)
@ -967,7 +962,7 @@ class MessageFilter:
return True return True
def to_string(self, maxwidth: int = 80) -> str: def to_string(self, maxwidth: int = 80) -> str:
fields: List[Tuple[str, str]] = [] fields: list[Tuple[str, str]] = []
for field in self.message_type.FIELDS.values(): for field in self.message_type.FIELDS.values():
if field.name not in self.fields: if field.name not in self.fields:
continue continue
@ -988,7 +983,7 @@ class MessageFilter:
if len(oneline_str) < maxwidth: if len(oneline_str) < maxwidth:
return f"{self.message_type.__name__}({oneline_str})" return f"{self.message_type.__name__}({oneline_str})"
else: else:
item: List[str] = [] item: list[str] = []
item.append(f"{self.message_type.__name__}(") item.append(f"{self.message_type.__name__}(")
for pair in pairs: for pair in pairs:
item.append(f" {pair}") item.append(f" {pair}")
@ -1052,11 +1047,11 @@ class TrezorClientDebugLink(TrezorClient):
""" """
self.ui: DebugUI = DebugUI(self.debug) self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False self.in_with_statement = False
self.expected_responses: Optional[List[MessageFilter]] = None self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: Optional[List[protobuf.MessageType]] = None self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: Dict[ self.filters: dict[
Type[protobuf.MessageType], type[protobuf.MessageType],
Optional[Callable[[protobuf.MessageType], protobuf.MessageType]], Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {} ] = {}
def ensure_open(self) -> None: def ensure_open(self) -> None:
@ -1076,8 +1071,8 @@ class TrezorClientDebugLink(TrezorClient):
def set_filter( def set_filter(
self, self,
message_type: Type[protobuf.MessageType], message_type: type[protobuf.MessageType],
callback: Optional[Callable[[protobuf.MessageType], protobuf.MessageType]], callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None: ) -> None:
"""Configure a filter function for a specified message type. """Configure a filter function for a specified message type.
@ -1102,7 +1097,7 @@ class TrezorClientDebugLink(TrezorClient):
return msg return msg
def set_input_flow( def set_input_flow(
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None] self, input_flow: Generator[None, messages.ButtonRequest | None, None]
) -> None: ) -> None:
"""Configure a sequence of input events for the current with-block. """Configure a sequence of input events for the current with-block.
@ -1184,7 +1179,7 @@ class TrezorClientDebugLink(TrezorClient):
input_flow.throw(exc_type, value, traceback) input_flow.throw(exc_type, value, traceback)
def set_expected_responses( def set_expected_responses(
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
) -> None: ) -> None:
"""Set a sequence of expected responses to client calls. """Set a sequence of expected responses to client calls.
@ -1251,10 +1246,10 @@ class TrezorClientDebugLink(TrezorClient):
return super()._raw_write(self._filter_message(msg)) return super()._raw_write(self._filter_message(msg))
@staticmethod @staticmethod
def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]: def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: List[str] = [] output: list[str] = []
output.append("Expected responses:") output.append("Expected responses:")
if start_at > 0: if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)") output.append(f" (...{start_at} previous responses omitted)")
@ -1272,8 +1267,8 @@ class TrezorClientDebugLink(TrezorClient):
@classmethod @classmethod
def _verify_responses( def _verify_responses(
cls, cls,
expected: Optional[List[MessageFilter]], expected: list[MessageFilter] | None,
actual: Optional[List[protobuf.MessageType]], actual: list[protobuf.MessageType] | None,
) -> None: ) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612 __tracebackhide__ = True # for pytest # pylint: disable=W0612
@ -1350,9 +1345,9 @@ class TrezorClientDebugLink(TrezorClient):
def load_device( def load_device(
client: "TrezorClient", client: "TrezorClient",
mnemonic: Union[str, Iterable[str]], mnemonic: Union[str, Iterable[str]],
pin: Optional[str], pin: str | None,
passphrase_protection: bool, passphrase_protection: bool,
label: Optional[str], label: str | None,
skip_checksum: bool = False, skip_checksum: bool = False,
needs_backup: bool = False, needs_backup: bool = False,
no_backup: bool = False, no_backup: bool = False,