1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-15 17:12:04 +00:00
This commit is contained in:
Martin Milata 2025-02-13 01:34:50 +01:00
parent d9d9c10fe7
commit eee595c87a
5 changed files with 14 additions and 55 deletions

View File

@ -800,7 +800,7 @@ class DebugUI:
def clear(self) -> None:
self.pins: t.Iterator[str] | None = None
self.passphrase = ""
self.passphrase = None
self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None
] = None
@ -850,7 +850,7 @@ class DebugUI:
except StopIteration:
raise AssertionError("PIN sequence ended prematurely")
def get_passphrase(self, available_on_device: bool) -> str:
def get_passphrase(self, available_on_device: bool) -> str | None | object:
self.debuglink.snapshot_legacy()
return self.passphrase
@ -970,6 +970,10 @@ class SessionDebugWrapper(Session):
def id(self) -> bytes:
return self._session.id
@property
def passphrase(self) -> str | None | object:
return self._session.passphrase
def _write(self, msg: t.Any) -> None:
print("writing message:", msg.__class__.__name__)
self._session._write(self._filter_message(msg))
@ -1092,7 +1096,6 @@ class SessionDebugWrapper(Session):
self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback
self.passphrase = self._session.passphrase
def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses
@ -1226,7 +1229,6 @@ class TrezorClientDebugLink(TrezorClient):
# and know the supported debug capabilities
self.debug.model = self.model
self.debug.version = self.version
self.passphrase: str | None = None
@property
def layout_type(self) -> LayoutType:
@ -1319,12 +1321,14 @@ class TrezorClientDebugLink(TrezorClient):
return send_passphrase(None, None)
try:
if isinstance(session, SessionV1):
if isinstance(session, SessionV1) or isinstance(session, SessionDebugWrapper):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
if passphrase is None:
passphrase = session.passphrase
else:
passphrase = session.passphrase
raise NotImplementedError
except Cancelled:
session.call_raw(messages.Cancel())
raise
@ -1378,33 +1382,6 @@ class TrezorClientDebugLink(TrezorClient):
passphrase = Mnemonic.normalize_string(passphrase)
return super().get_session(passphrase, derive_cardano, session_id)
def set_filter(
self,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.filters[message_type] = callback
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def set_input_flow(
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
) -> None:
@ -1538,7 +1515,6 @@ class TrezorClientDebugLink(TrezorClient):
def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase."""
self.passphrase = passphrase
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic: str) -> None:
@ -1546,17 +1522,6 @@ class TrezorClientDebugLink(TrezorClient):
Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = self.get_seedless_session()._read()
resp = self._filter_message(resp)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def _raw_write(self, msg: protobuf.MessageType) -> None:
return self.get_seedless_session()._write(self._filter_message(msg))
@staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)

View File

@ -372,8 +372,6 @@ def _client_unlocked(
if request.node.get_closest_marker("experimental"):
apply_settings(session, experimental_features=True)
if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"])
# TODO _raw_client.clear_session()
@ -399,7 +397,10 @@ def session(
session = _client_unlocked.get_seedless_session()
else:
derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = _client_unlocked.passphrase or ""
passphrase = ""
marker = request.node.get_closest_marker("setup_client")
if marker and isinstance(marker.kwargs.get("passphrase"), str):
passphrase = marker.kwargs["passphrase"]
if _client_unlocked._setup_pin is not None:
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
session = _client_unlocked.get_session(

View File

@ -124,7 +124,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session: Session = session.client.get_session(passphrase=passphrase_nfkd)
session.client.use_passphrase(passphrase_nfkd) # TODO is needed?
address_nfkd = get_test_address(session)
device.wipe(session)
@ -139,7 +138,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session = client.get_session(passphrase=passphrase_nfc)
session.client.use_passphrase(passphrase_nfc) # TODO is needed?
address_nfc = get_test_address(session)
device.wipe(session)
@ -154,7 +152,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session = client.get_session(passphrase=passphrase_nfkc)
session.client.use_passphrase(passphrase_nfkc) # TODO is needed?
address_nfkc = get_test_address(session)
device.wipe(session)
@ -169,7 +166,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True,
)
session = client.get_session(passphrase=passphrase_nfd)
session.client.use_passphrase(passphrase_nfd) # TODO is needed?
address_nfd = get_test_address(session)
assert address_nfkd == address_nfc
assert address_nfkd == address_nfkc

View File

@ -146,7 +146,6 @@ def test_session_recycling(client: Client):
messages.Address,
]
)
client.use_passphrase("TREZOR")
_ = get_test_address(session)
# address = get_test_address(session)

View File

@ -435,7 +435,6 @@ def test_hide_passphrase_from_host(client: Client):
messages.PublicKey,
]
)
client.use_passphrase(passphrase)
result = session.call(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey)
xpub_hidden_passphrase = result.xpub
@ -471,7 +470,6 @@ def test_hide_passphrase_from_host(client: Client):
messages.PublicKey,
]
)
client.use_passphrase(passphrase)
result = session.call(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey)
xpub_shown_passphrase = result.xpub