mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-15 09:50:57 +00:00
style(python): upgrade debuglink.py to 3.10 style type annotations
This commit is contained in:
parent
3a8f92f64d
commit
dfac2ae4dd
@ -14,6 +14,8 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
@ -33,11 +35,8 @@ from typing import (
|
||||
Generator,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -57,7 +56,7 @@ if TYPE_CHECKING:
|
||||
from .transport import Transport
|
||||
|
||||
ExpectedMessage = Union[
|
||||
protobuf.MessageType, Type[protobuf.MessageType], "MessageFilter"
|
||||
protobuf.MessageType, type[protobuf.MessageType], "MessageFilter"
|
||||
]
|
||||
|
||||
AnyDict = Dict[str, Any]
|
||||
@ -65,8 +64,8 @@ if TYPE_CHECKING:
|
||||
class InputFunc(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
hold_ms: Optional[int] = None,
|
||||
wait: Optional[bool] = None,
|
||||
hold_ms: int | None = None,
|
||||
wait: bool | None = None,
|
||||
) -> "LayoutContent": ...
|
||||
|
||||
|
||||
@ -101,14 +100,14 @@ class UnstructuredJSONReader:
|
||||
self.json_str = json_str
|
||||
# We may not receive valid JSON, e.g. from an old model in upgrade tests
|
||||
try:
|
||||
self.dict: "AnyDict" = json.loads(json_str)
|
||||
self.dict: AnyDict = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
self.dict = {}
|
||||
|
||||
def top_level_value(self, key: str) -> Any:
|
||||
return self.dict.get(key)
|
||||
|
||||
def find_objects_with_key_and_value(self, key: str, value: Any) -> List["AnyDict"]:
|
||||
def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]:
|
||||
def recursively_find(data: Any) -> Iterator[Any]:
|
||||
if isinstance(data, dict):
|
||||
if data.get(key) == value:
|
||||
@ -123,16 +122,14 @@ class UnstructuredJSONReader:
|
||||
|
||||
def find_unique_object_with_key_and_value(
|
||||
self, key: str, value: Any
|
||||
) -> Optional["AnyDict"]:
|
||||
) -> AnyDict | None:
|
||||
objects = self.find_objects_with_key_and_value(key, value)
|
||||
if not objects:
|
||||
return None
|
||||
assert len(objects) == 1
|
||||
return objects[0]
|
||||
|
||||
def find_values_by_key(
|
||||
self, key: str, only_type: Optional[type] = None
|
||||
) -> List[Any]:
|
||||
def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]:
|
||||
def recursively_find(data: Any) -> Iterator[Any]:
|
||||
if isinstance(data, dict):
|
||||
if key in data:
|
||||
@ -151,7 +148,7 @@ class UnstructuredJSONReader:
|
||||
return values
|
||||
|
||||
def find_unique_value_by_key(
|
||||
self, key: str, default: Any, only_type: Optional[type] = None
|
||||
self, key: str, default: Any, only_type: type | None = None
|
||||
) -> Any:
|
||||
values = self.find_values_by_key(key, only_type=only_type)
|
||||
if not values:
|
||||
@ -171,7 +168,7 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
"""Getting the main component of the layout."""
|
||||
return self.top_level_value("component") or "no main component"
|
||||
|
||||
def all_components(self) -> List[str]:
|
||||
def all_components(self) -> list[str]:
|
||||
"""Getting all components of the layout."""
|
||||
return self.find_values_by_key("component", only_type=str)
|
||||
|
||||
@ -209,7 +206,7 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
def title(self) -> str:
|
||||
"""Getting text that is displayed as a title and potentially subtitle."""
|
||||
# There could be possibly subtitle as well
|
||||
title_parts: List[str] = []
|
||||
title_parts: list[str] = []
|
||||
|
||||
title = self._get_str_or_dict_text("title")
|
||||
if title:
|
||||
@ -244,7 +241,7 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
# Look for paragraphs first (will match most of the time for TT)
|
||||
paragraphs = self.raw_content_paragraphs()
|
||||
if paragraphs:
|
||||
main_text_blocks: List[str] = []
|
||||
main_text_blocks: list[str] = []
|
||||
for par in paragraphs:
|
||||
par_content = ""
|
||||
for line_or_newline in par:
|
||||
@ -294,13 +291,13 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
# Default when not finding anything
|
||||
return self.main_component()
|
||||
|
||||
def raw_content_paragraphs(self) -> Optional[List[List[str]]]:
|
||||
def raw_content_paragraphs(self) -> list[list[str]] | None:
|
||||
"""Getting raw paragraphs as sent from Rust."""
|
||||
return self.find_unique_value_by_key("paragraphs", default=None, only_type=list)
|
||||
|
||||
def tt_check_seed_button_contents(self) -> List[str]:
|
||||
def tt_check_seed_button_contents(self) -> list[str]:
|
||||
"""Getting list of button contents."""
|
||||
buttons: List[str] = []
|
||||
buttons: list[str] = []
|
||||
button_objects = self.find_objects_with_key_and_value("component", "Button")
|
||||
for button in button_objects:
|
||||
if button.get("icon"):
|
||||
@ -309,7 +306,7 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
buttons.append(button["text"])
|
||||
return buttons
|
||||
|
||||
def button_contents(self) -> List[str]:
|
||||
def button_contents(self) -> list[str]:
|
||||
"""Getting list of button contents."""
|
||||
buttons = self.find_unique_value_by_key("buttons", default={}, only_type=dict)
|
||||
|
||||
@ -331,13 +328,13 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
button_keys = ("left_btn", "middle_btn", "right_btn")
|
||||
return [get_button_content(btn_key) for btn_key in button_keys]
|
||||
|
||||
def seed_words(self) -> List[str]:
|
||||
def seed_words(self) -> list[str]:
|
||||
"""Get all the seed words on the screen in order.
|
||||
|
||||
Example content: "1. ladybug\n2. acid\n3. academic\n4. afraid"
|
||||
-> ["ladybug", "acid", "academic", "afraid"]
|
||||
"""
|
||||
words: List[str] = []
|
||||
words: list[str] = []
|
||||
for line in self.screen_content().split("\n"):
|
||||
# Dot after index is optional (present on TT, not on TR)
|
||||
match = re.match(r"^\s*\d+\.? (\w+)$", line)
|
||||
@ -377,7 +374,7 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
"""What is the choice being selected right now."""
|
||||
return self.choice_items()[1]
|
||||
|
||||
def choice_items(self) -> Tuple[str, str, str]:
|
||||
def choice_items(self) -> tuple[str, str, str]:
|
||||
"""Getting actions for all three possible buttons."""
|
||||
choice_obj = self.find_unique_value_by_key(
|
||||
"choice_page", default={}, only_type=dict
|
||||
@ -396,15 +393,15 @@ class LayoutContent(UnstructuredJSONReader):
|
||||
return footer.get("description", "") + " " + footer.get("instruction", "")
|
||||
|
||||
|
||||
def multipage_content(layouts: List[LayoutContent]) -> str:
|
||||
def multipage_content(layouts: list[LayoutContent]) -> str:
|
||||
"""Get overall content from multiple-page layout."""
|
||||
return "".join(layout.text_content() for layout in layouts)
|
||||
|
||||
|
||||
def _make_input_func(
|
||||
button: Optional[messages.DebugButton] = None,
|
||||
physical_button: Optional[messages.DebugPhysicalButton] = None,
|
||||
swipe: Optional[messages.DebugSwipeDirection] = None,
|
||||
button: messages.DebugButton | None = None,
|
||||
physical_button: messages.DebugPhysicalButton | None = None,
|
||||
swipe: messages.DebugSwipeDirection | None = None,
|
||||
) -> "InputFunc":
|
||||
decision = messages.DebugLinkDecision(
|
||||
button=button,
|
||||
@ -414,8 +411,8 @@ def _make_input_func(
|
||||
|
||||
def input_func(
|
||||
self: "DebugLink",
|
||||
hold_ms: Optional[int] = None,
|
||||
wait: Optional[bool] = None,
|
||||
hold_ms: int | None = None,
|
||||
wait: bool | None = None,
|
||||
) -> LayoutContent:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
decision.hold_ms = hold_ms
|
||||
@ -431,18 +428,18 @@ class DebugLink:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
|
||||
# To be set by TrezorClientDebugLink (is not known during creation time)
|
||||
self.model: Optional[models.TrezorModel] = None
|
||||
self.version: Tuple[int, int, int] = (0, 0, 0)
|
||||
self.model: models.TrezorModel | None = None
|
||||
self.version: tuple[int, int, int] = (0, 0, 0)
|
||||
|
||||
# Where screenshots are being saved
|
||||
self.screenshot_recording_dir: Optional[str] = None
|
||||
self.screenshot_recording_dir: str | None = None
|
||||
|
||||
# For T1 screenshotting functionality in DebugUI
|
||||
self.t1_screenshot_directory: Optional[Path] = None
|
||||
self.t1_screenshot_directory: Path | None = None
|
||||
self.t1_screenshot_counter = 0
|
||||
|
||||
# Optional file for saving text representation of the screen
|
||||
self.screen_text_file: Optional[Path] = None
|
||||
self.screen_text_file: Path | None = None
|
||||
self.last_screen_content = ""
|
||||
|
||||
self.waiting_for_layout_change = False
|
||||
@ -470,7 +467,7 @@ class DebugLink:
|
||||
assert self.model is not None
|
||||
return LayoutType.from_model(self.model)
|
||||
|
||||
def set_screen_text_file(self, file_path: Optional[Path]) -> None:
|
||||
def set_screen_text_file(self, file_path: Path | None) -> None:
|
||||
if file_path is not None:
|
||||
file_path.write_bytes(b"")
|
||||
self.screen_text_file = file_path
|
||||
@ -609,7 +606,7 @@ class DebugLink:
|
||||
"""
|
||||
self._call(messages.DebugLinkWatchLayout(watch=watch))
|
||||
|
||||
def encode_pin(self, pin: str, matrix: Optional[str] = None) -> str:
|
||||
def encode_pin(self, pin: str, matrix: str | None = None) -> str:
|
||||
"""Transform correct PIN according to the displayed matrix."""
|
||||
if matrix is None:
|
||||
matrix = self.state().matrix
|
||||
@ -619,7 +616,7 @@ class DebugLink:
|
||||
|
||||
return "".join([str(matrix.index(p) + 1) for p in pin])
|
||||
|
||||
def read_recovery_word(self) -> Tuple[Optional[str], Optional[int]]:
|
||||
def read_recovery_word(self) -> Tuple[str | None, int | None]:
|
||||
state = self.state()
|
||||
return (state.recovery_fake_word, state.recovery_word_pos)
|
||||
|
||||
@ -628,7 +625,7 @@ class DebugLink:
|
||||
return state.reset_word
|
||||
|
||||
def _decision(
|
||||
self, decision: messages.DebugLinkDecision, wait: Optional[bool] = None
|
||||
self, decision: messages.DebugLinkDecision, wait: bool | None = None
|
||||
) -> LayoutContent:
|
||||
"""Send a debuglink decision and returns the resulting layout.
|
||||
|
||||
@ -691,15 +688,15 @@ class DebugLink:
|
||||
)
|
||||
"""Press right button. See `_decision` for more details."""
|
||||
|
||||
def input(self, word: str, wait: Optional[bool] = None) -> LayoutContent:
|
||||
def input(self, word: str, wait: bool | None = None) -> LayoutContent:
|
||||
"""Send text input to the device. See `_decision` for more details."""
|
||||
return self._decision(messages.DebugLinkDecision(input=word), wait)
|
||||
|
||||
def click(
|
||||
self,
|
||||
click: Tuple[int, int],
|
||||
hold_ms: Optional[int] = None,
|
||||
wait: Optional[bool] = None,
|
||||
hold_ms: int | None = None,
|
||||
wait: bool | None = None,
|
||||
) -> LayoutContent:
|
||||
"""Send a click to the device. See `_decision` for more details."""
|
||||
x, y = click
|
||||
@ -750,9 +747,7 @@ class DebugLink:
|
||||
def reseed(self, value: int) -> protobuf.MessageType:
|
||||
return self._call(messages.DebugLinkReseedRandom(value=value))
|
||||
|
||||
def start_recording(
|
||||
self, directory: str, refresh_index: Optional[int] = None
|
||||
) -> None:
|
||||
def start_recording(self, directory: str, refresh_index: int | None = None) -> None:
|
||||
self.screenshot_recording_dir = directory
|
||||
# Different recording logic between core and legacy
|
||||
if self.model is not models.T1B1:
|
||||
@ -807,7 +802,7 @@ class DebugLink:
|
||||
|
||||
assert len(data) == 128 * 64 // 8
|
||||
|
||||
pixels: List[int] = []
|
||||
pixels: list[int] = []
|
||||
for byteline in range(64 // 8):
|
||||
offset = byteline * 128
|
||||
row = data[offset : offset + 128]
|
||||
@ -840,7 +835,7 @@ class NullDebugLink(DebugLink):
|
||||
|
||||
def _call(
|
||||
self, msg: protobuf.MessageType, nowait: bool = False
|
||||
) -> Optional[messages.DebugLinkState]:
|
||||
) -> messages.DebugLinkState | None:
|
||||
if not nowait:
|
||||
if isinstance(msg, messages.DebugLinkGetState):
|
||||
return messages.DebugLinkState()
|
||||
@ -858,7 +853,7 @@ class DebugUI:
|
||||
self.clear()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.pins: Optional[Iterator[str]] = None
|
||||
self.pins: Iterator[str] | None = None
|
||||
self.passphrase = ""
|
||||
self.input_flow: Union[
|
||||
Generator[None, messages.ButtonRequest, None], object, None
|
||||
@ -897,7 +892,7 @@ class DebugUI:
|
||||
except StopIteration:
|
||||
self.input_flow = self.INPUT_FLOW_DONE
|
||||
|
||||
def get_pin(self, code: Optional["PinMatrixRequestType"] = None) -> str:
|
||||
def get_pin(self, code: PinMatrixRequestType | None = None) -> str:
|
||||
self.debuglink.snapshot_legacy()
|
||||
|
||||
if self.pins is None:
|
||||
@ -914,7 +909,7 @@ class DebugUI:
|
||||
|
||||
|
||||
class MessageFilter:
|
||||
def __init__(self, message_type: Type[protobuf.MessageType], **fields: Any) -> None:
|
||||
def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None:
|
||||
self.message_type = message_type
|
||||
self.fields: Dict[str, Any] = {}
|
||||
self.update_fields(**fields)
|
||||
@ -967,7 +962,7 @@ class MessageFilter:
|
||||
return True
|
||||
|
||||
def to_string(self, maxwidth: int = 80) -> str:
|
||||
fields: List[Tuple[str, str]] = []
|
||||
fields: list[Tuple[str, str]] = []
|
||||
for field in self.message_type.FIELDS.values():
|
||||
if field.name not in self.fields:
|
||||
continue
|
||||
@ -988,7 +983,7 @@ class MessageFilter:
|
||||
if len(oneline_str) < maxwidth:
|
||||
return f"{self.message_type.__name__}({oneline_str})"
|
||||
else:
|
||||
item: List[str] = []
|
||||
item: list[str] = []
|
||||
item.append(f"{self.message_type.__name__}(")
|
||||
for pair in pairs:
|
||||
item.append(f" {pair}")
|
||||
@ -1052,11 +1047,11 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
"""
|
||||
self.ui: DebugUI = DebugUI(self.debug)
|
||||
self.in_with_statement = False
|
||||
self.expected_responses: Optional[List[MessageFilter]] = None
|
||||
self.actual_responses: Optional[List[protobuf.MessageType]] = None
|
||||
self.filters: Dict[
|
||||
Type[protobuf.MessageType],
|
||||
Optional[Callable[[protobuf.MessageType], protobuf.MessageType]],
|
||||
self.expected_responses: list[MessageFilter] | None = None
|
||||
self.actual_responses: list[protobuf.MessageType] | None = None
|
||||
self.filters: dict[
|
||||
type[protobuf.MessageType],
|
||||
Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
] = {}
|
||||
|
||||
def ensure_open(self) -> None:
|
||||
@ -1076,8 +1071,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def set_filter(
|
||||
self,
|
||||
message_type: Type[protobuf.MessageType],
|
||||
callback: Optional[Callable[[protobuf.MessageType], protobuf.MessageType]],
|
||||
message_type: type[protobuf.MessageType],
|
||||
callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
"""Configure a filter function for a specified message type.
|
||||
|
||||
@ -1102,7 +1097,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
return msg
|
||||
|
||||
def set_input_flow(
|
||||
self, input_flow: Generator[None, Optional[messages.ButtonRequest], None]
|
||||
self, input_flow: Generator[None, messages.ButtonRequest | None, None]
|
||||
) -> None:
|
||||
"""Configure a sequence of input events for the current with-block.
|
||||
|
||||
@ -1184,7 +1179,7 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
input_flow.throw(exc_type, value, traceback)
|
||||
|
||||
def set_expected_responses(
|
||||
self, expected: List[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
|
||||
self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]]
|
||||
) -> None:
|
||||
"""Set a sequence of expected responses to client calls.
|
||||
|
||||
@ -1251,10 +1246,10 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
return super()._raw_write(self._filter_message(msg))
|
||||
|
||||
@staticmethod
|
||||
def _expectation_lines(expected: List[MessageFilter], current: int) -> List[str]:
|
||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||
stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected))
|
||||
output: List[str] = []
|
||||
output: list[str] = []
|
||||
output.append("Expected responses:")
|
||||
if start_at > 0:
|
||||
output.append(f" (...{start_at} previous responses omitted)")
|
||||
@ -1272,8 +1267,8 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
@classmethod
|
||||
def _verify_responses(
|
||||
cls,
|
||||
expected: Optional[List[MessageFilter]],
|
||||
actual: Optional[List[protobuf.MessageType]],
|
||||
expected: list[MessageFilter] | None,
|
||||
actual: list[protobuf.MessageType] | None,
|
||||
) -> None:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
|
||||
@ -1350,9 +1345,9 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
def load_device(
|
||||
client: "TrezorClient",
|
||||
mnemonic: Union[str, Iterable[str]],
|
||||
pin: Optional[str],
|
||||
pin: str | None,
|
||||
passphrase_protection: bool,
|
||||
label: Optional[str],
|
||||
label: str | None,
|
||||
skip_checksum: bool = False,
|
||||
needs_backup: bool = False,
|
||||
no_backup: bool = False,
|
||||
|
Loading…
Reference in New Issue
Block a user