From b48c36c2bfba53d3624ca031eae6810ba21e6630 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 19 Jul 2024 10:51:07 +0200 Subject: [PATCH] fix(python): improve robustness of TrezorClientDebugLink setup * improve sync_responses to work on uninitialized instance * sync responses at construction time --- python/src/trezorlib/debuglink.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 1bdaddcfc3..cd445049a6 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -48,7 +48,7 @@ from .client import TrezorClient from .exceptions import TrezorFailure from .log import DUMP_BYTES from .messages import DebugWaitType -from .tools import expect +from .tools import expect, session if TYPE_CHECKING: from typing_extensions import Protocol @@ -1014,8 +1014,11 @@ class TrezorClientDebugLink(TrezorClient): else: raise - self.reset_debug_features() + # set transport explicitly so that sync_responses can work + self.transport = transport + self.reset_debug_features() + self.sync_responses() super().__init__(transport, ui=self.ui) # So that we can choose right screenshotting logic (T1 vs TT) @@ -1300,14 +1303,23 @@ class TrezorClientDebugLink(TrezorClient): # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - # go to super() to avoid message filtering - super()._raw_write(messages.Cancel()) + cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) + self.transport.begin_session() + try: + self.transport.write(*cancel_msg) - message = "SYNC" + secrets.token_hex(8) - super()._raw_write(messages.Ping(message=message)) - resp = None - while resp != messages.Success(message=message): - resp = super()._raw_read() + message = "SYNC" + secrets.token_hex(8) + ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) + self.transport.write(*ping_msg) + resp = None + while resp != messages.Success(message=message): + msg_id, msg_bytes = self.transport.read() + try: + resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) + except Exception: + pass + finally: + self.transport.end_session() def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word()