diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 8a3fff6d10..2c017a58d3 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -116,7 +116,7 @@ class TrezorClient: if isinstance(self.protocol, ProtocolV1Channel): return SessionV1.new(self, passphrase, derive_cardano) - raise NotImplementedError # TODO + raise NotImplementedError def resume_session(self, session: Session): """ diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 0067d644c1..ccfa67f682 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -798,7 +798,7 @@ class DebugUI: def clear(self) -> None: self.pins: t.Iterator[str] | None = None - self.passphrase = "" + self.passphrase = None self.input_flow: t.Union[ t.Generator[None, messages.ButtonRequest, None], object, None ] = None @@ -848,7 +848,7 @@ class DebugUI: except StopIteration: raise AssertionError("PIN sequence ended prematurely") - def get_passphrase(self, available_on_device: bool) -> str: + def get_passphrase(self, available_on_device: bool) -> str | None | object: self.debuglink.snapshot_legacy() return self.passphrase @@ -1098,7 +1098,6 @@ class SessionDebugWrapper(Session): self.button_callback = self.client.button_callback self.pin_callback = self.client.pin_callback self.passphrase_callback = self._session.passphrase_callback - self.passphrase = self._session.passphrase def __enter__(self) -> "SessionDebugWrapper": # For usage in with/expected_responses @@ -1232,7 +1231,6 @@ class TrezorClientDebugLink(TrezorClient): # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version - self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: @@ -1328,12 +1326,16 @@ class TrezorClientDebugLink(TrezorClient): return send_passphrase(None, None) try: - if session.passphrase is None and isinstance(session, SessionV1): + if isinstance(session, SessionV1) or isinstance( + session, SessionDebugWrapper + ): passphrase = self.ui.get_passphrase( available_on_device=available_on_device ) + if passphrase is None: + passphrase = session.passphrase else: - passphrase = session.passphrase + raise NotImplementedError except Cancelled: session.call_raw(messages.Cancel()) raise @@ -1387,33 +1389,6 @@ class TrezorClientDebugLink(TrezorClient): passphrase = Mnemonic.normalize_string(passphrase) return super().get_session(passphrase, derive_cardano, session_id) - def set_filter( - self, - message_type: t.Type[protobuf.MessageType], - callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, - ) -> None: - """Configure a filter function for a specified message type. - - The `callback` must be a function that accepts a protobuf message, and returns - a (possibly modified) protobuf message of the same type. Whenever a message - is sent or received that matches `message_type`, `callback` is invoked on the - message and its result is substituted for the original. - - Useful for test scenarios with an active malicious actor on the wire. - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - self.filters[message_type] = callback - - def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: - message_type = msg.__class__ - callback = self.filters.get(message_type) - if callable(callback): - return callback(deepcopy(msg)) - else: - return msg - def set_input_flow( self, input_flow: InputFlowType | t.Callable[[], InputFlowType] ) -> None: @@ -1545,27 +1520,11 @@ class TrezorClientDebugLink(TrezorClient): """ self.ui.pins = iter(pins) - def use_passphrase(self, passphrase: str) -> None: - """Respond to passphrase prompts from device with the provided passphrase.""" - self.passphrase = passphrase - self.ui.passphrase = Mnemonic.normalize_string(passphrase) - def use_mnemonic(self, mnemonic: str) -> None: """Use the provided mnemonic to respond to device. Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def _raw_read(self) -> protobuf.MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - resp = self.get_seedless_session()._read() - resp = self._filter_message(resp) - if self.actual_responses is not None: - self.actual_responses.append(resp) - return resp - - def _raw_write(self, msg: protobuf.MessageType) -> None: - return self.get_seedless_session()._write(self._filter_message(msg)) - @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) diff --git a/tests/click_tests/test_passphrase_delizia.py b/tests/click_tests/test_passphrase_delizia.py index fc7d79610e..e0b9fbb91c 100644 --- a/tests/click_tests/test_passphrase_delizia.py +++ b/tests/click_tests/test_passphrase_delizia.py @@ -98,7 +98,6 @@ def prepare_passphrase_dialogue( ) -> Generator["DebugLink", None, None]: debug = device_handler.debuglink() device_handler.run_with_session(get_test_address) # type: ignore - # TODO assert debug.read_layout().main_component() == "PassphraseKeyboard" # Resetting the category as it could have been changed by previous tests diff --git a/tests/conftest.py b/tests/conftest.py index 01d11e3b22..6cb5d1b6f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -364,9 +364,6 @@ def _client_unlocked( if request.node.get_closest_marker("experimental"): apply_settings(session, experimental_features=True) - if use_passphrase and isinstance(setup_params["passphrase"], str): - _raw_client.use_passphrase(setup_params["passphrase"]) - # TODO _raw_client.clear_session() yield _raw_client @@ -391,7 +388,10 @@ def session( session = _client_unlocked.get_seedless_session() else: derive_cardano = bool(request.node.get_closest_marker("cardano")) - passphrase = _client_unlocked.passphrase or "" + passphrase = "" + marker = request.node.get_closest_marker("setup_client") + if marker and isinstance(marker.kwargs.get("passphrase"), str): + passphrase = marker.kwargs["passphrase"] if _client_unlocked._setup_pin is not None: _client_unlocked.use_pin_sequence([_client_unlocked._setup_pin]) session = _client_unlocked.get_session( diff --git a/tests/device_handler.py b/tests/device_handler.py index 605b160c98..c060a405e9 100644 --- a/tests/device_handler.py +++ b/tests/device_handler.py @@ -64,8 +64,8 @@ class BackgroundDeviceHandler: raise RuntimeError("Wait for previous task first") # wait for the first UI change triggered by the task running in the background + session = self.client.get_session() with self.debuglink().wait_for_layout_change(): - session = self.client.get_session() self.task = self._pool.submit(function, session, *args, **kwargs) def run_with_provided_session( diff --git a/tests/device_tests/test_msg_loaddevice.py b/tests/device_tests/test_msg_loaddevice.py index 89e9ebe4df..57c8459e9e 100644 --- a/tests/device_tests/test_msg_loaddevice.py +++ b/tests/device_tests/test_msg_loaddevice.py @@ -124,7 +124,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session: Session = session.client.get_session(passphrase=passphrase_nfkd) - session.client.use_passphrase(passphrase_nfkd) # TODO is needed? address_nfkd = get_test_address(session) device.wipe(session) @@ -139,7 +138,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfc) - session.client.use_passphrase(passphrase_nfc) # TODO is needed? address_nfc = get_test_address(session) device.wipe(session) @@ -154,7 +152,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfkc) - session.client.use_passphrase(passphrase_nfkc) # TODO is needed? address_nfkc = get_test_address(session) device.wipe(session) @@ -169,7 +166,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfd) - session.client.use_passphrase(passphrase_nfd) # TODO is needed? address_nfd = get_test_address(session) assert address_nfkd == address_nfc assert address_nfkd == address_nfkc diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 9c8dfb924f..e8df254129 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -146,7 +146,6 @@ def test_session_recycling(client: Client): messages.Address, ] ) - client.use_passphrase("TREZOR") _ = get_test_address(session) # address = get_test_address(session) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index c8a4b3c36c..93ce1eeb45 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -19,7 +19,6 @@ import random import pytest from trezorlib import device, exceptions, messages -from trezorlib.client import ProtocolVersion from trezorlib.debuglink import LayoutType from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import TrezorClientDebugLink as Client @@ -54,7 +53,6 @@ SESSIONS_STORED = 10 def _get_xpub( session: Session, expected_passphrase_req: bool = False, - passphrase_v1: str | None = None, ): """Get XPUB and check that the appropriate passphrase flow has happened.""" if expected_passphrase_req: @@ -66,11 +64,6 @@ def _get_xpub( ] else: expected_responses = [messages.PublicKey] - if ( - passphrase_v1 is not None - and session.protocol_version == ProtocolVersion.PROTOCOL_V1 - ): - session.passphrase = passphrase_v1 with session: session.set_expected_responses(expected_responses) @@ -228,7 +221,6 @@ def test_max_sessions_with_passphrases(client: Client): _get_xpub( resumed_session, expected_passphrase_req=True, - passphrase_v1="whatever", ) # passphrase is prompted @@ -435,7 +427,6 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_hidden_passphrase = result.xpub @@ -471,7 +462,6 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_shown_passphrase = result.xpub diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 49cd73d6c8..7a0bee8887 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -384,7 +384,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]): address = btc.get_address(session, "Bitcoin", PATH) if session.protocol_version == ProtocolVersion.PROTOCOL_V1: session.call(messages.Initialize(new_session=True)) - new_session = emu.client.get_session(passphrase="TREZOR") + new_session = Session(emu.client.get_session(passphrase="TREZOR")) address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) assert emu.client.features.backup_availability == BackupAvailability.Required