diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index a238e1ea53..4f6375a4db 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -800,7 +800,7 @@ class DebugUI: def clear(self) -> None: self.pins: t.Iterator[str] | None = None - self.passphrase = "" + self.passphrase = None self.input_flow: t.Union[ t.Generator[None, messages.ButtonRequest, None], object, None ] = None @@ -850,7 +850,7 @@ class DebugUI: except StopIteration: raise AssertionError("PIN sequence ended prematurely") - def get_passphrase(self, available_on_device: bool) -> str: + def get_passphrase(self, available_on_device: bool) -> str | None | object: self.debuglink.snapshot_legacy() return self.passphrase @@ -970,6 +970,10 @@ class SessionDebugWrapper(Session): def id(self) -> bytes: return self._session.id + @property + def passphrase(self) -> str | None | object: + return self._session.passphrase + def _write(self, msg: t.Any) -> None: print("writing message:", msg.__class__.__name__) self._session._write(self._filter_message(msg)) @@ -1092,7 +1096,6 @@ class SessionDebugWrapper(Session): self.button_callback = self.client.button_callback self.pin_callback = self.client.pin_callback self.passphrase_callback = self._session.passphrase_callback - self.passphrase = self._session.passphrase def __enter__(self) -> "SessionDebugWrapper": # For usage in with/expected_responses @@ -1226,7 +1229,6 @@ class TrezorClientDebugLink(TrezorClient): # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version - self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: @@ -1319,12 +1321,14 @@ class TrezorClientDebugLink(TrezorClient): return send_passphrase(None, None) try: - if isinstance(session, SessionV1): + if isinstance(session, SessionV1) or isinstance(session, SessionDebugWrapper): passphrase = self.ui.get_passphrase( available_on_device=available_on_device ) + if passphrase is None: + passphrase = session.passphrase else: - passphrase = session.passphrase + raise NotImplementedError except Cancelled: session.call_raw(messages.Cancel()) raise @@ -1378,33 +1382,6 @@ class TrezorClientDebugLink(TrezorClient): passphrase = Mnemonic.normalize_string(passphrase) return super().get_session(passphrase, derive_cardano, session_id) - def set_filter( - self, - message_type: t.Type[protobuf.MessageType], - callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, - ) -> None: - """Configure a filter function for a specified message type. - - The `callback` must be a function that accepts a protobuf message, and returns - a (possibly modified) protobuf message of the same type. Whenever a message - is sent or received that matches `message_type`, `callback` is invoked on the - message and its result is substituted for the original. - - Useful for test scenarios with an active malicious actor on the wire. - """ - if not self.in_with_statement: - raise RuntimeError("Must be called inside 'with' statement") - - self.filters[message_type] = callback - - def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: - message_type = msg.__class__ - callback = self.filters.get(message_type) - if callable(callback): - return callback(deepcopy(msg)) - else: - return msg - def set_input_flow( self, input_flow: InputFlowType | t.Callable[[], InputFlowType] ) -> None: @@ -1538,7 +1515,6 @@ class TrezorClientDebugLink(TrezorClient): def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" - self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1546,17 +1522,6 @@ class TrezorClientDebugLink(TrezorClient): Only applies to T1, where device prompts the host for mnemonic words.""" self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") - def _raw_read(self) -> protobuf.MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - resp = self.get_seedless_session()._read() - resp = self._filter_message(resp) - if self.actual_responses is not None: - self.actual_responses.append(resp) - return resp - - def _raw_write(self, msg: protobuf.MessageType) -> None: - return self.get_seedless_session()._write(self._filter_message(msg)) - @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) diff --git a/tests/conftest.py b/tests/conftest.py index 6b659916c9..17c298c3aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -372,8 +372,6 @@ def _client_unlocked( if request.node.get_closest_marker("experimental"): apply_settings(session, experimental_features=True) - if use_passphrase and isinstance(setup_params["passphrase"], str): - _raw_client.use_passphrase(setup_params["passphrase"]) # TODO _raw_client.clear_session() @@ -399,7 +397,10 @@ def session( session = _client_unlocked.get_seedless_session() else: derive_cardano = bool(request.node.get_closest_marker("cardano")) - passphrase = _client_unlocked.passphrase or "" + passphrase = "" + marker = request.node.get_closest_marker("setup_client") + if marker and isinstance(marker.kwargs.get("passphrase"), str): + passphrase = marker.kwargs["passphrase"] if _client_unlocked._setup_pin is not None: _client_unlocked.use_pin_sequence([_client_unlocked._setup_pin]) session = _client_unlocked.get_session( diff --git a/tests/device_tests/test_msg_loaddevice.py b/tests/device_tests/test_msg_loaddevice.py index 89e9ebe4df..57c8459e9e 100644 --- a/tests/device_tests/test_msg_loaddevice.py +++ b/tests/device_tests/test_msg_loaddevice.py @@ -124,7 +124,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session: Session = session.client.get_session(passphrase=passphrase_nfkd) - session.client.use_passphrase(passphrase_nfkd) # TODO is needed? address_nfkd = get_test_address(session) device.wipe(session) @@ -139,7 +138,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfc) - session.client.use_passphrase(passphrase_nfc) # TODO is needed? address_nfc = get_test_address(session) device.wipe(session) @@ -154,7 +152,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfkc) - session.client.use_passphrase(passphrase_nfkc) # TODO is needed? address_nfkc = get_test_address(session) device.wipe(session) @@ -169,7 +166,6 @@ def test_load_device_utf(client: Client): skip_checksum=True, ) session = client.get_session(passphrase=passphrase_nfd) - session.client.use_passphrase(passphrase_nfd) # TODO is needed? address_nfd = get_test_address(session) assert address_nfkd == address_nfc assert address_nfkd == address_nfkc diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 9c8dfb924f..e8df254129 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -146,7 +146,6 @@ def test_session_recycling(client: Client): messages.Address, ] ) - client.use_passphrase("TREZOR") _ = get_test_address(session) # address = get_test_address(session) diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index c8a4b3c36c..f4b4097f5b 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -435,7 +435,6 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_hidden_passphrase = result.xpub @@ -471,7 +470,6 @@ def test_hide_passphrase_from_host(client: Client): messages.PublicKey, ] ) - client.use_passphrase(passphrase) result = session.call(XPUB_REQUEST) assert isinstance(result, messages.PublicKey) xpub_shown_passphrase = result.xpub