From 4ee6ffa81d584039b8430e0b4ffe70d85673418c Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Wed, 12 Feb 2025 22:12:31 +0100 Subject: [PATCH] debuglink and tests fixes --- python/src/trezorlib/client.py | 4 +- python/src/trezorlib/debuglink.py | 57 ++++--------------- tests/click_tests/test_passphrase_delizia.py | 1 - tests/conftest.py | 7 ++- tests/device_handler.py | 2 +- tests/device_tests/test_msg_loaddevice.py | 4 -- tests/device_tests/test_session.py | 1 - .../test_session_id_and_passphrase.py | 9 --- tests/upgrade_tests/test_firmware_upgrades.py | 2 +- 9 files changed, 18 insertions(+), 69 deletions(-) diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 76702fdda6..cc090fc17d 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -115,10 +115,8 @@ class TrezorClient: from .transport.session import SessionV1 if isinstance(self.protocol, ProtocolV1): - if passphrase is None: - passphrase = "" return SessionV1.new(self, passphrase, derive_cardano) - raise NotImplementedError # TODO + raise NotImplementedError def resume_session(self, session: Session): """ diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index f81593c43f..3a9adf2be2 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -798,7 +798,7 @@ class DebugUI: def clear(self) -> None: self.pins: t.Iterator[str] | None = None - self.passphrase = "" + self.passphrase = None self.input_flow: t.Union[ t.Generator[None, messages.ButtonRequest, None], object, None ] = None @@ -848,7 +848,7 @@ class DebugUI: except StopIteration: raise AssertionError("PIN sequence ended prematurely") - def get_passphrase(self, available_on_device: bool) -> str: + def get_passphrase(self, available_on_device: bool) -> str | None | object: self.debuglink.snapshot_legacy() return self.passphrase @@ -968,6 +968,10 @@ class SessionDebugWrapper(Session): def id(self) -> bytes: return self._session.id + @property + def passphrase(self) -> str | None | object: + return self._session.passphrase + def _write(self, msg: t.Any) -> None: print("writing message:", msg.__class__.__name__) self._session._write(self._filter_message(msg)) @@ -1090,7 +1094,6 @@ class SessionDebugWrapper(Session): self.button_callback = self.client.button_callback self.pin_callback = self.client.pin_callback self.passphrase_callback = self._session.passphrase_callback - self.passphrase = self._session.passphrase def __enter__(self) -> "SessionDebugWrapper": # For usage in with/expected_responses @@ -1224,7 +1227,6 @@ class TrezorClientDebugLink(TrezorClient): # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version - self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: @@ -1308,7 +1310,7 @@ class TrezorClientDebugLink(TrezorClient): msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) resp = session.call_raw(msg) if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - # session.session_id = resp.state + session._session.id = resp.state resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) return resp @@ -1317,12 +1319,14 @@ class TrezorClientDebugLink(TrezorClient): return send_passphrase(None, None) try: - if session.passphrase is None and isinstance(session, SessionV1): + if isinstance(session, SessionV1) or isinstance(session, SessionDebugWrapper): passphrase = self.ui.get_passphrase( available_on_device=available_on_device ) + if passphrase is None: + passphrase = session.passphrase else: - passphrase = session.passphrase + raise NotImplementedError except Cancelled: session.call_raw(messages.Cancel()) raise @@ -1376,33 +1380,6 @@ class TrezorClientDebugLink(TrezorClient): passphrase = Mnemonic.normalize_string(passphrase) return super().get_session(passphrase, derive_cardano, session_id) - def set_filter( - self, - message_type: t.Type[protobuf.MessageType], - callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, - ) -> None: - """Configure a filter function for a specified message type. - - The `callback` must be a function that accepts a protobuf message, and returns - a (possibly modified) protobuf message of the same type. Whenever a message - is sent or received that matches `message_type`, `callback` is invoked on the - message and its result is substituted for the original. - - Useful for test scenarios with an active malicious actor on the wire. - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - self.filters[message_type] = callback - - def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: - message_type = msg.__class__ - callback = self.filters.get(message_type) - if callable(callback): - return callback(deepcopy(msg)) - else: - return msg - def set_input_flow( self, input_flow: InputFlowType | t.Callable[[], InputFlowType] ) -> None: @@ -1536,7 +1513,6 @@ class TrezorClientDebugLink(TrezorClient): def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" - self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1544,17 +1520,6 @@ class TrezorClientDebugLink(TrezorClient): Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def _raw_read(self) -> protobuf.MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - resp = self.get_seedless_session()._read() - resp = self._filter_message(resp) - if self.actual_responses is not None: - self.actual_responses.append(resp) - return resp - - def _raw_write(self, msg: protobuf.MessageType) -> None: - return self.get_seedless_session()._write(self._filter_message(msg)) - @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) diff --git a/tests/click_tests/test_passphrase_delizia.py b/tests/click_tests/test_passphrase_delizia.py index fc7d79610e..e0b9fbb91c 100644 --- a/tests/click_tests/test_passphrase_delizia.py +++ b/tests/click_tests/test_passphrase_delizia.py @@ -98,7 +98,6 @@ def prepare_passphrase_dialogue( ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() device_handler.run_with_session(get_test_address) # type: ignore - # TODO assert debug.read_layout().main_component() == "PassphraseKeyboard" # Resetting the category as it could have been changed by previous tests diff --git a/tests/conftest.py b/tests/conftest.py index 6b659916c9..17c298c3aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -372,8 +372,6 @@ def _client_unlocked( if request.node.get_closest_marker("experimental"): apply_settings(session, experimental_features=True) - if use_passphrase and isinstance(setup_params["passphrase"], str): - _raw_client.use_passphrase(setup_params["passphrase"]) # TODO _raw_client.clear_session() @@ -399,7 +397,10 @@ def session( session = _client_unlocked.get_seedless_session() else: derive_cardano = bool(request.node.get_closest_marker("cardano")) - passphrase = _client_unlocked.passphrase or "" + passphrase = "" + marker = request.node.get_closest_marker("setup_client") + if marker and isinstance(marker.kwargs.get("passphrase"), str): + passphrase = marker.kwargs["passphrase"] if _client_unlocked._setup_pin is not None: _client_unlocked.use_pin_sequence([_client_unlocked._setup_pin]) session = _client_unlocked.get_session( diff --git a/tests/device_handler.py b/tests/device_handler.py index ec1bc1828b..331881bcd0 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -65,8 +65,8 @@ class BackgroundDeviceHandler: raise RuntimeError("Wait for previous task first") # wait for the first UI change triggered by the task running in the background + session = self.client.get_session() with self.debuglink().wait_for_layout_change(): - session = self.client.get_session() self.task = self._pool.submit(function, session, *args, **kwargs) def run_with_provided_session( diff --git a/tests/device_tests/test_msg_loaddevice.py b/tests/device_tests/test_msg_loaddevice.py index 89e9ebe4df..57c8459e9e 100644 --- a/tests/device_tests/test_msg_loaddevice.py +++ b/tests/device_tests/test_msg_loaddevice.py @@ -124,7 +124,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session: Session = session.client.get_session(passphrase=passphrase_nfkd) - session.client.use_passphrase(passphrase_nfkd) # TODO is needed? address_nfkd = get_test_address(session) device.wipe(session) @@ -139,7 +138,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfc) - session.client.use_passphrase(passphrase_nfc) # TODO is needed? address_nfc = get_test_address(session) device.wipe(session) @@ -154,7 +152,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfkc) - session.client.use_passphrase(passphrase_nfkc) # TODO is needed? address_nfkc = get_test_address(session) device.wipe(session) @@ -169,7 +166,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfd) - session.client.use_passphrase(passphrase_nfd) # TODO is needed? address_nfd = get_test_address(session) assert address_nfkd == address_nfc assert address_nfkd == address_nfkc diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 9c8dfb924f..e8df254129 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -146,7 +146,6 @@ def test_session_recycling(client: Client): messages.Address, ] ) - client.use_passphrase("TREZOR") _ = get_test_address(session) # address = get_test_address(session) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index c8a4b3c36c..da668b914c 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -54,7 +54,6 @@ SESSIONS_STORED = 10 def _get_xpub( session: Session, expected_passphrase_req: bool = False, - passphrase_v1: str | None = None, ): """Get XPUB and check that the appropriate passphrase flow has happened.""" if expected_passphrase_req: @@ -66,11 +65,6 @@ def _get_xpub( ] else: expected_responses = [messages.PublicKey] - if ( - passphrase_v1 is not None - and session.protocol_version == ProtocolVersion.PROTOCOL_V1 - ): - session.passphrase = passphrase_v1 with session: session.set_expected_responses(expected_responses) @@ -228,7 +222,6 @@ def test_max_sessions_with_passphrases(client: Client): _get_xpub( resumed_session, expected_passphrase_req=True, - passphrase_v1="whatever", ) # passphrase is prompted @@ -435,7 +428,6 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_hidden_passphrase = result.xpub @@ -471,7 +463,6 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_shown_passphrase = result.xpub diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 49cd73d6c8..7a0bee8887 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -384,7 +384,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): address = btc.get_address(session, "Bitcoin", PATH) if session.protocol_version == ProtocolVersion.PROTOCOL_V1: session.call(messages.Initialize(new_session=True)) - new_session = emu.client.get_session(passphrase="TREZOR") + new_session = Session(emu.client.get_session(passphrase="TREZOR")) address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required