1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-15 09:50:57 +00:00

style(python): upgrade debuglink.py to 3.10 style type annotations

This commit is contained in:
matejcik 2024-11-06 14:28:07 +01:00 committed by matejcik
parent 3a8f92f64d
commit dfac2ae4dd

View File

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import json
import logging
import re
@ -33,11 +35,8 @@ from typing import (
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
)
@ -57,7 +56,7 @@ if TYPE_CHECKING:
from .transport import Transport
ExpectedMessage = Union[
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
]
AnyDict = Dict[str, Any]
@ -65,8 +64,8 @@ if TYPE_CHECKING:
class InputFunc(Protocol):
def __call__(
self,
hold_ms: Optional[int] = None,
wait: Optional[bool] = None,
hold_ms: int | None = None,
wait: bool | None = None,
) -> "LayoutContent": ...
@ -101,14 +100,14 @@ class UnstructuredJSONReader:
self.json_str = json_str
# We may not receive valid JSON, e.g. from an old model in upgrade tests
try:
self.dict: "AnyDict" = json.loads(json_str)
self.dict: AnyDict = json.loads(json_str)
except json.JSONDecodeError:
self.dict = {}
def top_level_value(self, key: str) -> Any:
return self.dict.get(key)
def find_objects_with_key_and_value(self, key: str, value: Any) -> List["AnyDict"]:
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
def recursively_find(data: Any) -> Iterator[Any]:
if isinstance(data, dict):
if data.get(key) == value:
@ -123,16 +122,14 @@ class UnstructuredJSONReader:
def find_unique_object_with_key_and_value(
self, key: str, value: Any
) -> Optional["AnyDict"]:
) -> AnyDict | None:
objects = self.find_objects_with_key_and_value(key, value)
if not objects:
return None
assert len(objects) == 1
return objects[0]
def find_values_by_key(
self, key: str, only_type: Optional[type] = None
) -> List[Any]:
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
def recursively_find(data: Any) -> Iterator[Any]:
if isinstance(data, dict):
if key in data:
@ -151,7 +148,7 @@ class UnstructuredJSONReader:
return values
def find_unique_value_by_key(
self, key: str, default: Any, only_type: Optional[type] = None
self, key: str, default: Any, only_type: type | None = None
) -> Any:
values = self.find_values_by_key(key, only_type=only_type)
if not values:
@ -171,7 +168,7 @@ class LayoutContent(UnstructuredJSONReader):
"""Getting the main component of the layout."""
return self.top_level_value("component") or "no main component"
def all_components(self) -> List[str]:
def all_components(self) -> list[str]:
"""Getting all components of the layout."""
return self.find_values_by_key("component", only_type=str)
@ -209,7 +206,7 @@ class LayoutContent(UnstructuredJSONReader):
def title(self) -> str:
"""Getting text that is displayed as a title and potentially subtitle."""
# There could be possibly subtitle as well
title_parts: List[str] = []
title_parts: list[str] = []
title = self._get_str_or_dict_text("title")
if title:
@ -244,7 +241,7 @@ class LayoutContent(UnstructuredJSONReader):
# Look for paragraphs first (will match most of the time for TT)
paragraphs = self.raw_content_paragraphs()
if paragraphs:
main_text_blocks: List[str] = []
main_text_blocks: list[str] = []
for par in paragraphs:
par_content = ""
for line_or_newline in par:
@ -294,13 +291,13 @@ class LayoutContent(UnstructuredJSONReader):
# Default when not finding anything
return self.main_component()
def raw_content_paragraphs(self) -> Optional[List[List[str]]]:
def raw_content_paragraphs(self) -> list[list[str]] | None:
"""Getting raw paragraphs as sent from Rust."""
return self.find_unique_value_by_key("paragraphs", default=None, only_type=list)
def tt_check_seed_button_contents(self) -> List[str]:
def tt_check_seed_button_contents(self) -> list[str]:
"""Getting list of button contents."""
buttons: List[str] = []
buttons: list[str] = []
button_objects = self.find_objects_with_key_and_value("component", "Button")
for button in button_objects:
if button.get("icon"):
@ -309,7 +306,7 @@ class LayoutContent(UnstructuredJSONReader):
buttons.append(button["text"])
return buttons
def button_contents(self) -> List[str]:
def button_contents(self) -> list[str]:
"""Getting list of button contents."""
buttons = self.find_unique_value_by_key("buttons", default={}, only_type=dict)
@ -331,13 +328,13 @@ class LayoutContent(UnstructuredJSONReader):
button_keys = ("left_btn", "middle_btn", "right_btn")
return [get_button_content(btn_key) for btn_key in button_keys]
def seed_words(self) -> List[str]:
def seed_words(self) -> list[str]:
"""Get all the seed words on the screen in order.
Example content: "1. ladybug\n2. acid\n3. academic\n4. afraid"
-> ["ladybug", "acid", "academic", "afraid"]
"""
words: List[str] = []
words: list[str] = []
for line in self.screen_content().split("\n"):
# Dot after index is optional (present on TT, not on TR)
match = re.match(r"^\s*\d+\.? (\w+)$", line)
@ -377,7 +374,7 @@ class LayoutContent(UnstructuredJSONReader):
"""What is the choice being selected right now."""
return self.choice_items()[1]
def choice_items(self) -> Tuple[str, str, str]:
def choice_items(self) -> tuple[str, str, str]:
"""Getting actions for all three possible buttons."""
choice_obj = self.find_unique_value_by_key(
"choice_page", default={}, only_type=dict
@ -396,15 +393,15 @@ class LayoutContent(UnstructuredJSONReader):
return footer.get("description", "") + " " + footer.get("instruction", "")
def multipage_content(layouts: List[LayoutContent]) -> str:
def multipage_content(layouts: list[LayoutContent]) -> str:
"""Get overall content from multiple-page layout."""
return "".join(layout.text_content() for layout in layouts)
def _make_input_func(
button: Optional[messages.DebugButton] = None,
physical_button: Optional[messages.DebugPhysicalButton] = None,
swipe: Optional[messages.DebugSwipeDirection] = None,
button: messages.DebugButton | None = None,
physical_button: messages.DebugPhysicalButton | None = None,
swipe: messages.DebugSwipeDirection | None = None,
) -> "InputFunc":
decision = messages.DebugLinkDecision(
button=button,
@ -414,8 +411,8 @@ def _make_input_func(
def input_func(
self: "DebugLink",
hold_ms: Optional[int] = None,
wait: Optional[bool] = None,
hold_ms: int | None = None,
wait: bool | None = None,
) -> LayoutContent:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
decision.hold_ms = hold_ms
@ -431,18 +428,18 @@ class DebugLink:
self.mapping = mapping.DEFAULT_MAPPING
# To be set by TrezorClientDebugLink (is not known during creation time)
self.model: Optional[models.TrezorModel] = None
self.version: Tuple[int, int, int] = (0, 0, 0)
self.model: models.TrezorModel | None = None
self.version: tuple[int, int, int] = (0, 0, 0)
# Where screenshots are being saved
self.screenshot_recording_dir: Optional[str] = None
self.screenshot_recording_dir: str | None = None
# For T1 screenshotting functionality in DebugUI
self.t1_screenshot_directory: Optional[Path] = None
self.t1_screenshot_directory: Path | None = None
self.t1_screenshot_counter = 0
# Optional file for saving text representation of the screen
self.screen_text_file: Optional[Path] = None
self.screen_text_file: Path | None = None
self.last_screen_content = ""
self.waiting_for_layout_change = False
@ -470,7 +467,7 @@ class DebugLink:
assert self.model is not None
return LayoutType.from_model(self.model)
def set_screen_text_file(self, file_path: Optional[Path]) -> None:
def set_screen_text_file(self, file_path: Path | None) -> None:
if file_path is not None:
file_path.write_bytes(b"")
self.screen_text_file = file_path
@ -609,7 +606,7 @@ class DebugLink:
"""
self._call(messages.DebugLinkWatchLayout(watch=watch))
def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
def encode_pin(self, pin: str, matrix: str | None = None) -> str:
"""Transform correct PIN according to the displayed matrix."""
if matrix is None:
matrix = self.state().matrix
@ -619,7 +616,7 @@ class DebugLink:
return "".join([str(matrix.index(p) + 1) for p in pin])
def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
def read_recovery_word(self) -> Tuple[str | None, int | None]:
state = self.state()
return (state.recovery_fake_word, state.recovery_word_pos)
@ -628,7 +625,7 @@ class DebugLink:
return state.reset_word
def _decision(
self, decision: messages.DebugLinkDecision, wait: Optional[bool] = None
self, decision: messages.DebugLinkDecision, wait: bool | None = None
) -> LayoutContent:
"""Send a debuglink decision and returns the resulting layout.
@ -691,15 +688,15 @@ class DebugLink:
)
"""Press right button. See `_decision` for more details."""
def input(self, word: str, wait: Optional[bool] = None) -> LayoutContent:
def input(self, word: str, wait: bool | None = None) -> LayoutContent:
"""Send text input to the device. See `_decision` for more details."""
return self._decision(messages.DebugLinkDecision(input=word), wait)
def click(
self,
click: Tuple[int, int],
hold_ms: Optional[int] = None,
wait: Optional[bool] = None,
hold_ms: int | None = None,
wait: bool | None = None,
) -> LayoutContent:
"""Send a click to the device. See `_decision` for more details."""
x, y = click
@ -750,9 +747,7 @@ class DebugLink:
def reseed(self, value: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkReseedRandom(value=value))
def start_recording(
self, directory: str, refresh_index: Optional[int] = None
) -> None:
def start_recording(self, directory: str, refresh_index: int | None = None) -> None:
self.screenshot_recording_dir = directory
# Different recording logic between core and legacy
if self.model is not models.T1B1:
@ -807,7 +802,7 @@ class DebugLink:
assert len(data) == 128 * 64 // 8
pixels: List[int] = []
pixels: list[int] = []
for byteline in range(64 // 8):
offset = byteline * 128
row = data[offset : offset + 128]
@ -840,7 +835,7 @@ class NullDebugLink(DebugLink):
def _call(
self, msg: protobuf.MessageType, nowait: bool = False
) -> Optional[messages.DebugLinkState]:
) -> messages.DebugLinkState | None:
if not nowait:
if isinstance(msg, messages.DebugLinkGetState):
return messages.DebugLinkState()
@ -858,7 +853,7 @@ class DebugUI:
self.clear()
def clear(self) -> None:
self.pins: Optional[Iterator[str]] = None
self.pins: Iterator[str] | None = None
self.passphrase = ""
self.input_flow: Union[
Generator[None, messages.ButtonRequest, None], object, None
@ -897,7 +892,7 @@ class DebugUI:
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
def get_pin(self, code: PinMatrixRequestType | None = None) -> str:
self.debuglink.snapshot_legacy()
if self.pins is None:
@ -914,7 +909,7 @@ class DebugUI:
class MessageFilter:
def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
self.message_type = message_type
self.fields: Dict[str, Any] = {}
self.update_fields(**fields)
@ -967,7 +962,7 @@ class MessageFilter:
return True
def to_string(self, maxwidth: int = 80) -> str:
fields: List[Tuple[str, str]] = []
fields: list[Tuple[str, str]] = []
for field in self.message_type.FIELDS.values():
if field.name not in self.fields:
continue
@ -988,7 +983,7 @@ class MessageFilter:
if len(oneline_str) < maxwidth:
return f"{self.message_type.__name__}({oneline_str})"
else:
item: List[str] = []
item: list[str] = []
item.append(f"{self.message_type.__name__}(")
for pair in pairs:
item.append(f" {pair}")
@ -1052,11 +1047,11 @@ class TrezorClientDebugLink(TrezorClient):
"""
self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False
self.expected_responses: Optional[List[MessageFilter]] = None
self.actual_responses: Optional[List[protobuf.MessageType]] = None
self.filters: Dict[
Type[protobuf.MessageType],
Optional[Callable[[protobuf.MessageType], protobuf.MessageType]],
self.expected_responses: list[MessageFilter] | None = None
self.actual_responses: list[protobuf.MessageType] | None = None
self.filters: dict[
type[protobuf.MessageType],
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {}
def ensure_open(self) -> None:
@ -1076,8 +1071,8 @@ class TrezorClientDebugLink(TrezorClient):
def set_filter(
self,
message_type: Type[protobuf.MessageType],
callback: Optional[Callable[[protobuf.MessageType], protobuf.MessageType]],
message_type: type[protobuf.MessageType],
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
@ -1102,7 +1097,7 @@ class TrezorClientDebugLink(TrezorClient):
return msg
def set_input_flow(
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
self, input_flow: Generator[None, messages.ButtonRequest | None, None]
) -> None:
"""Configure a sequence of input events for the current with-block.
@ -1184,7 +1179,7 @@ class TrezorClientDebugLink(TrezorClient):
input_flow.throw(exc_type, value, traceback)
def set_expected_responses(
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
) -> None:
"""Set a sequence of expected responses to client calls.
@ -1251,10 +1246,10 @@ class TrezorClientDebugLink(TrezorClient):
return super()._raw_write(self._filter_message(msg))
@staticmethod
def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
output: List[str] = []
output: list[str] = []
output.append("Expected responses:")
if start_at > 0:
output.append(f" (...{start_at} previous responses omitted)")
@ -1272,8 +1267,8 @@ class TrezorClientDebugLink(TrezorClient):
@classmethod
def _verify_responses(
cls,
expected: Optional[List[MessageFilter]],
actual: Optional[List[protobuf.MessageType]],
expected: list[MessageFilter] | None,
actual: list[protobuf.MessageType] | None,
) -> None:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
@ -1350,9 +1345,9 @@ class TrezorClientDebugLink(TrezorClient):
def load_device(
client: "TrezorClient",
mnemonic: Union[str, Iterable[str]],
pin: Optional[str],
pin: str | None,
passphrase_protection: bool,
label: Optional[str],
label: str | None,
skip_checksum: bool = False,
needs_backup: bool = False,
no_backup: bool = False,