diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 56409f722e..5a7f04c44a 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -40,6 +40,7 @@ from .tools import parse_path from .transport import Timeout from .transport.session import ProtocolV2Channel, Session from .transport.thp.protocol_v1 import ProtocolV1Channel +from trezorlib import exceptions if t.TYPE_CHECKING: from typing_extensions import Protocol @@ -1031,6 +1032,7 @@ class TrezorClientDebugLink(TrezorClient): # by the device. protocol: ProtocolV1Channel | ProtocolV2Channel + actual_responses: list[protobuf.MessageType] | None = None def __init__( self, @@ -1056,8 +1058,6 @@ class TrezorClientDebugLink(TrezorClient): transport.open() # set transport explicitly so that sync_responses can work - super().__init__(transport) - self.transport = transport self.ui: DebugUI = DebugUI(self.debug) @@ -1071,6 +1071,13 @@ class TrezorClientDebugLink(TrezorClient): self.pin_callback = get_pin self.button_callback = self.ui.button_request + try: + super().__init__(transport) + except exceptions.DeviceLockedException: + self.use_pin_sequence(["1234"]) + self.debug.input(self.debug.encode_pin("1234")) + super().__init__(transport) + self.sync_responses() # So that we can choose right screenshotting logic (T1 vs TT) @@ -1303,6 +1310,7 @@ class TrezorClientDebugLink(TrezorClient): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) + self.actual_responses = None @classmethod def _verify_responses( @@ -1406,6 +1414,14 @@ class TrezorClientDebugLink(TrezorClient): self.actual_responses.append(resp) return resp + def notify_read(self, msg: protobuf.MessageType) -> None: + pass + try: + if self.actual_responses is not None: + self.actual_responses.append(msg) + except Exception as e: + print(e) + def load_device( session: "Session", diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 7a0af18e20..5e8f0a0d19 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging import typing as t + from .. import exceptions, messages, models from ..client import MAX_PIN_LENGTH from ..protobuf import MessageType @@ -274,8 +275,12 @@ class SessionV2(Session): self.channel.write(self.sid, msg) def _read(self) -> t.Any: + from trezorlib.debuglink import TrezorClientDebugLink + msg = self.channel.read(self.sid) LOG.debug("reading message %s", type(msg)) + if isinstance(self.client, TrezorClientDebugLink): + self.client.notify_read(msg) return msg def update_id_and_sid(self, id: bytes) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 2a701b0d70..c6f6a0a282 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -345,7 +345,11 @@ def _client_unlocked( LOG = logging.getLogger(__name__) LOG.error(f"Failed to re-create a client: {e}") sleep(LOCK_TIME) - _raw_client = _get_raw_client(request) + try: + _raw_client = _raw_client.get_new_client() + except Exception as e: + sleep(1.5) + _raw_client = _get_raw_client(request) session = _raw_client.get_seedless_session() wipe_device(session) diff --git a/tests/device_tests/test_cancel.py b/tests/device_tests/test_cancel.py index 5fa14ceb84..8c56e64e90 100644 --- a/tests/device_tests/test_cancel.py +++ b/tests/device_tests/test_cancel.py @@ -58,6 +58,7 @@ def test_cancel_message_via_cancel(session: Session, message): ), ], ) +@pytest.mark.protocol("protocol_v1") def test_cancel_message_via_initialize(session: Session, message): resp = session.call_raw(message) assert isinstance(resp, m.ButtonRequest) diff --git a/tests/device_tests/test_msg_applysettings.py b/tests/device_tests/test_msg_applysettings.py index 9a000ea295..5da2ed0e9e 100644 --- a/tests/device_tests/test_msg_applysettings.py +++ b/tests/device_tests/test_msg_applysettings.py @@ -19,6 +19,7 @@ from pathlib import Path import pytest from trezorlib import btc, device, exceptions, messages, misc, models +from trezorlib.client import ProtocolVersion from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.tools import parse_path @@ -211,7 +212,8 @@ def test_apply_homescreen_toif(session: Session): def test_apply_homescreen_jpeg(session: Session): with open(HERE / "test_bg.jpg", "rb") as f: img = f.read() - # raise Exception("FAILS FOR SOME REASON ") + if session.protocol_version is ProtocolVersion.V2: + raise Exception("Message too large for THP") with session.client as client: _set_expected_responses(client) device.apply_settings(session, homescreen=img) @@ -338,7 +340,9 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.Strict - with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client: + with pytest.raises( + exceptions.TrezorFailure, match="Forbidden key path" + ), session.client as client: client.set_expected_responses([messages.Failure]) get_bad_address() @@ -351,11 +355,11 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.PromptAlways - with client: + with session.client as client: client.set_expected_responses( [messages.ButtonRequest, messages.ButtonRequest, messages.Address] ) - IF = InputFlowConfirmAllWarnings(session.client) + IF = InputFlowConfirmAllWarnings(client) client.set_input_flow(IF.get()) get_bad_address() @@ -365,11 +369,13 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.Strict - with pytest.raises(exceptions.TrezorFailure, match="Forbidden key path"), client: + with pytest.raises( + exceptions.TrezorFailure, match="Forbidden key path" + ), session.client as client: client.set_expected_responses([messages.Failure]) get_bad_address() - with client: + with session.client as client: client.set_expected_responses(EXPECTED_RESPONSES_NOPIN) device.apply_settings( session, safety_checks=messages.SafetyCheckLevel.PromptTemporarily @@ -377,11 +383,11 @@ def test_safety_checks(session: Session): assert session.features.safety_checks == messages.SafetyCheckLevel.PromptTemporarily - with client: + with session.client as client: client.set_expected_responses( [messages.ButtonRequest, messages.ButtonRequest, messages.Address] ) - if session.model is not models.T1B1: + if client.model is not models.T1B1: IF = InputFlowConfirmAllWarnings(session.client) client.set_input_flow(IF.get()) get_bad_address() @@ -403,7 +409,9 @@ def test_experimental_features(session: Session): assert not session.features.experimental_features - with pytest.raises(exceptions.TrezorFailure, match="DataError"), client: + with pytest.raises( + exceptions.TrezorFailure, match="DataError" + ), session.client as client: client.set_expected_responses([messages.Failure]) experimental_call() @@ -413,7 +421,7 @@ def test_experimental_features(session: Session): assert session.features.experimental_features - with client: + with session.client as client: client.set_expected_responses([messages.Nonce]) experimental_call()