1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

debuglink and tests fixes

[no changelog]
This commit is contained in:
Martin Milata 2025-02-12 22:12:31 +01:00 committed by M1nd3r
parent 224e3825be
commit 12a00e53b0
9 changed files with 15 additions and 72 deletions

View File

@ -116,7 +116,7 @@ class TrezorClient:
if isinstance(self.protocol, ProtocolV1Channel): if isinstance(self.protocol, ProtocolV1Channel):
return SessionV1.new(self, passphrase, derive_cardano) return SessionV1.new(self, passphrase, derive_cardano)
raise NotImplementedError # TODO raise NotImplementedError
def resume_session(self, session: Session): def resume_session(self, session: Session):
""" """

View File

@ -798,7 +798,7 @@ class DebugUI:
def clear(self) -> None: def clear(self) -> None:
self.pins: t.Iterator[str] | None = None self.pins: t.Iterator[str] | None = None
self.passphrase = "" self.passphrase = None
self.input_flow: t.Union[ self.input_flow: t.Union[
t.Generator[None, messages.ButtonRequest, None], object, None t.Generator[None, messages.ButtonRequest, None], object, None
] = None ] = None
@ -848,7 +848,7 @@ class DebugUI:
except StopIteration: except StopIteration:
raise AssertionError("PIN sequence ended prematurely") raise AssertionError("PIN sequence ended prematurely")
def get_passphrase(self, available_on_device: bool) -> str: def get_passphrase(self, available_on_device: bool) -> str | None | object:
self.debuglink.snapshot_legacy() self.debuglink.snapshot_legacy()
return self.passphrase return self.passphrase
@ -1098,7 +1098,6 @@ class SessionDebugWrapper(Session):
self.button_callback = self.client.button_callback self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback self.passphrase_callback = self._session.passphrase_callback
self.passphrase = self._session.passphrase
def __enter__(self) -> "SessionDebugWrapper": def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses # For usage in with/expected_responses
@ -1232,7 +1231,6 @@ class TrezorClientDebugLink(TrezorClient):
# and know the supported debug capabilities # and know the supported debug capabilities
self.debug.model = self.model self.debug.model = self.model
self.debug.version = self.version self.debug.version = self.version
self.passphrase: str | None = None
@property @property
def layout_type(self) -> LayoutType: def layout_type(self) -> LayoutType:
@ -1328,12 +1326,16 @@ class TrezorClientDebugLink(TrezorClient):
return send_passphrase(None, None) return send_passphrase(None, None)
try: try:
if session.passphrase is None and isinstance(session, SessionV1): if isinstance(session, SessionV1) or isinstance(
session, SessionDebugWrapper
):
passphrase = self.ui.get_passphrase( passphrase = self.ui.get_passphrase(
available_on_device=available_on_device available_on_device=available_on_device
) )
else: if passphrase is None:
passphrase = session.passphrase passphrase = session.passphrase
else:
raise NotImplementedError
except Cancelled: except Cancelled:
session.call_raw(messages.Cancel()) session.call_raw(messages.Cancel())
raise raise
@ -1387,33 +1389,6 @@ class TrezorClientDebugLink(TrezorClient):
passphrase = Mnemonic.normalize_string(passphrase) passphrase = Mnemonic.normalize_string(passphrase)
return super().get_session(passphrase, derive_cardano, session_id) return super().get_session(passphrase, derive_cardano, session_id)
def set_filter(
self,
message_type: t.Type[protobuf.MessageType],
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
) -> None:
"""Configure a filter function for a specified message type.
The `callback` must be a function that accepts a protobuf message, and returns
a (possibly modified) protobuf message of the same type. Whenever a message
is sent or received that matches `message_type`, `callback` is invoked on the
message and its result is substituted for the original.
Useful for test scenarios with an active malicious actor on the wire.
"""
if not self.in_with_statement:
raise RuntimeError("Must be called inside 'with' statement")
self.filters[message_type] = callback
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
message_type = msg.__class__
callback = self.filters.get(message_type)
if callable(callback):
return callback(deepcopy(msg))
else:
return msg
def set_input_flow( def set_input_flow(
self, input_flow: InputFlowType | t.Callable[[], InputFlowType] self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
) -> None: ) -> None:
@ -1545,27 +1520,11 @@ class TrezorClientDebugLink(TrezorClient):
""" """
self.ui.pins = iter(pins) self.ui.pins = iter(pins)
def use_passphrase(self, passphrase: str) -> None:
"""Respond to passphrase prompts from device with the provided passphrase."""
self.passphrase = passphrase
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
def use_mnemonic(self, mnemonic: str) -> None: def use_mnemonic(self, mnemonic: str) -> None:
"""Use the provided mnemonic to respond to device. """Use the provided mnemonic to respond to device.
Only applies to T1, where device prompts the host for mnemonic words.""" Only applies to T1, where device prompts the host for mnemonic words."""
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ") self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
def _raw_read(self) -> protobuf.MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
resp = self.get_seedless_session()._read()
resp = self._filter_message(resp)
if self.actual_responses is not None:
self.actual_responses.append(resp)
return resp
def _raw_write(self, msg: protobuf.MessageType) -> None:
return self.get_seedless_session()._write(self._filter_message(msg))
@staticmethod @staticmethod
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)

View File

@ -98,7 +98,6 @@ def prepare_passphrase_dialogue(
) -> Generator["DebugLink", None, None]: ) -> Generator["DebugLink", None, None]:
debug = device_handler.debuglink() debug = device_handler.debuglink()
device_handler.run_with_session(get_test_address) # type: ignore device_handler.run_with_session(get_test_address) # type: ignore
# TODO
assert debug.read_layout().main_component() == "PassphraseKeyboard" assert debug.read_layout().main_component() == "PassphraseKeyboard"
# Resetting the category as it could have been changed by previous tests # Resetting the category as it could have been changed by previous tests

View File

@ -364,9 +364,6 @@ def _client_unlocked(
if request.node.get_closest_marker("experimental"): if request.node.get_closest_marker("experimental"):
apply_settings(session, experimental_features=True) apply_settings(session, experimental_features=True)
if use_passphrase and isinstance(setup_params["passphrase"], str):
_raw_client.use_passphrase(setup_params["passphrase"])
# TODO _raw_client.clear_session() # TODO _raw_client.clear_session()
yield _raw_client yield _raw_client
@ -391,7 +388,10 @@ def session(
session = _client_unlocked.get_seedless_session() session = _client_unlocked.get_seedless_session()
else: else:
derive_cardano = bool(request.node.get_closest_marker("cardano")) derive_cardano = bool(request.node.get_closest_marker("cardano"))
passphrase = _client_unlocked.passphrase or "" passphrase = ""
marker = request.node.get_closest_marker("setup_client")
if marker and isinstance(marker.kwargs.get("passphrase"), str):
passphrase = marker.kwargs["passphrase"]
if _client_unlocked._setup_pin is not None: if _client_unlocked._setup_pin is not None:
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin]) _client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
session = _client_unlocked.get_session( session = _client_unlocked.get_session(

View File

@ -64,8 +64,8 @@ class BackgroundDeviceHandler:
raise RuntimeError("Wait for previous task first") raise RuntimeError("Wait for previous task first")
# wait for the first UI change triggered by the task running in the background # wait for the first UI change triggered by the task running in the background
with self.debuglink().wait_for_layout_change():
session = self.client.get_session() session = self.client.get_session()
with self.debuglink().wait_for_layout_change():
self.task = self._pool.submit(function, session, *args, **kwargs) self.task = self._pool.submit(function, session, *args, **kwargs)
def run_with_provided_session( def run_with_provided_session(

View File

@ -124,7 +124,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True, skip_checksum=True,
) )
session: Session = session.client.get_session(passphrase=passphrase_nfkd) session: Session = session.client.get_session(passphrase=passphrase_nfkd)
session.client.use_passphrase(passphrase_nfkd) # TODO is needed?
address_nfkd = get_test_address(session) address_nfkd = get_test_address(session)
device.wipe(session) device.wipe(session)
@ -139,7 +138,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True, skip_checksum=True,
) )
session = client.get_session(passphrase=passphrase_nfc) session = client.get_session(passphrase=passphrase_nfc)
session.client.use_passphrase(passphrase_nfc) # TODO is needed?
address_nfc = get_test_address(session) address_nfc = get_test_address(session)
device.wipe(session) device.wipe(session)
@ -154,7 +152,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True, skip_checksum=True,
) )
session = client.get_session(passphrase=passphrase_nfkc) session = client.get_session(passphrase=passphrase_nfkc)
session.client.use_passphrase(passphrase_nfkc) # TODO is needed?
address_nfkc = get_test_address(session) address_nfkc = get_test_address(session)
device.wipe(session) device.wipe(session)
@ -169,7 +166,6 @@ def test_load_device_utf(client: Client):
skip_checksum=True, skip_checksum=True,
) )
session = client.get_session(passphrase=passphrase_nfd) session = client.get_session(passphrase=passphrase_nfd)
session.client.use_passphrase(passphrase_nfd) # TODO is needed?
address_nfd = get_test_address(session) address_nfd = get_test_address(session)
assert address_nfkd == address_nfc assert address_nfkd == address_nfc
assert address_nfkd == address_nfkc assert address_nfkd == address_nfkc

View File

@ -146,7 +146,6 @@ def test_session_recycling(client: Client):
messages.Address, messages.Address,
] ]
) )
client.use_passphrase("TREZOR")
_ = get_test_address(session) _ = get_test_address(session)
# address = get_test_address(session) # address = get_test_address(session)

View File

@ -19,7 +19,6 @@ import random
import pytest import pytest
from trezorlib import device, exceptions, messages from trezorlib import device, exceptions, messages
from trezorlib.client import ProtocolVersion
from trezorlib.debuglink import LayoutType from trezorlib.debuglink import LayoutType
from trezorlib.debuglink import SessionDebugWrapper as Session from trezorlib.debuglink import SessionDebugWrapper as Session
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
@ -54,7 +53,6 @@ SESSIONS_STORED = 10
def _get_xpub( def _get_xpub(
session: Session, session: Session,
expected_passphrase_req: bool = False, expected_passphrase_req: bool = False,
passphrase_v1: str | None = None,
): ):
"""Get XPUB and check that the appropriate passphrase flow has happened.""" """Get XPUB and check that the appropriate passphrase flow has happened."""
if expected_passphrase_req: if expected_passphrase_req:
@ -66,11 +64,6 @@ def _get_xpub(
] ]
else: else:
expected_responses = [messages.PublicKey] expected_responses = [messages.PublicKey]
if (
passphrase_v1 is not None
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
):
session.passphrase = passphrase_v1
with session: with session:
session.set_expected_responses(expected_responses) session.set_expected_responses(expected_responses)
@ -228,7 +221,6 @@ def test_max_sessions_with_passphrases(client: Client):
_get_xpub( _get_xpub(
resumed_session, resumed_session,
expected_passphrase_req=True, expected_passphrase_req=True,
passphrase_v1="whatever",
) # passphrase is prompted ) # passphrase is prompted
@ -435,7 +427,6 @@ def test_hide_passphrase_from_host(client: Client):
messages.PublicKey, messages.PublicKey,
] ]
) )
client.use_passphrase(passphrase)
result = session.call(XPUB_REQUEST) result = session.call(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey) assert isinstance(result, messages.PublicKey)
xpub_hidden_passphrase = result.xpub xpub_hidden_passphrase = result.xpub
@ -471,7 +462,6 @@ def test_hide_passphrase_from_host(client: Client):
messages.PublicKey, messages.PublicKey,
] ]
) )
client.use_passphrase(passphrase)
result = session.call(XPUB_REQUEST) result = session.call(XPUB_REQUEST)
assert isinstance(result, messages.PublicKey) assert isinstance(result, messages.PublicKey)
xpub_shown_passphrase = result.xpub xpub_shown_passphrase = result.xpub

View File

@ -384,7 +384,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
address = btc.get_address(session, "Bitcoin", PATH) address = btc.get_address(session, "Bitcoin", PATH)
if session.protocol_version == ProtocolVersion.PROTOCOL_V1: if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
session.call(messages.Initialize(new_session=True)) session.call(messages.Initialize(new_session=True))
new_session = emu.client.get_session(passphrase="TREZOR") new_session = Session(emu.client.get_session(passphrase="TREZOR"))
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH) address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)
assert emu.client.features.backup_availability == BackupAvailability.Required assert emu.client.features.backup_availability == BackupAvailability.Required