mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-19 11:02:02 +00:00
debuglink and tests fixes
This commit is contained in:
parent
c33f922650
commit
4ee6ffa81d
@ -115,10 +115,8 @@ class TrezorClient:
|
|||||||
from .transport.session import SessionV1
|
from .transport.session import SessionV1
|
||||||
|
|
||||||
if isinstance(self.protocol, ProtocolV1):
|
if isinstance(self.protocol, ProtocolV1):
|
||||||
if passphrase is None:
|
|
||||||
passphrase = ""
|
|
||||||
return SessionV1.new(self, passphrase, derive_cardano)
|
return SessionV1.new(self, passphrase, derive_cardano)
|
||||||
raise NotImplementedError # TODO
|
raise NotImplementedError
|
||||||
|
|
||||||
def resume_session(self, session: Session):
|
def resume_session(self, session: Session):
|
||||||
"""
|
"""
|
||||||
|
@ -798,7 +798,7 @@ class DebugUI:
|
|||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
self.pins: t.Iterator[str] | None = None
|
self.pins: t.Iterator[str] | None = None
|
||||||
self.passphrase = ""
|
self.passphrase = None
|
||||||
self.input_flow: t.Union[
|
self.input_flow: t.Union[
|
||||||
t.Generator[None, messages.ButtonRequest, None], object, None
|
t.Generator[None, messages.ButtonRequest, None], object, None
|
||||||
] = None
|
] = None
|
||||||
@ -848,7 +848,7 @@ class DebugUI:
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise AssertionError("PIN sequence ended prematurely")
|
raise AssertionError("PIN sequence ended prematurely")
|
||||||
|
|
||||||
def get_passphrase(self, available_on_device: bool) -> str:
|
def get_passphrase(self, available_on_device: bool) -> str | None | object:
|
||||||
self.debuglink.snapshot_legacy()
|
self.debuglink.snapshot_legacy()
|
||||||
return self.passphrase
|
return self.passphrase
|
||||||
|
|
||||||
@ -968,6 +968,10 @@ class SessionDebugWrapper(Session):
|
|||||||
def id(self) -> bytes:
|
def id(self) -> bytes:
|
||||||
return self._session.id
|
return self._session.id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def passphrase(self) -> str | None | object:
|
||||||
|
return self._session.passphrase
|
||||||
|
|
||||||
def _write(self, msg: t.Any) -> None:
|
def _write(self, msg: t.Any) -> None:
|
||||||
print("writing message:", msg.__class__.__name__)
|
print("writing message:", msg.__class__.__name__)
|
||||||
self._session._write(self._filter_message(msg))
|
self._session._write(self._filter_message(msg))
|
||||||
@ -1090,7 +1094,6 @@ class SessionDebugWrapper(Session):
|
|||||||
self.button_callback = self.client.button_callback
|
self.button_callback = self.client.button_callback
|
||||||
self.pin_callback = self.client.pin_callback
|
self.pin_callback = self.client.pin_callback
|
||||||
self.passphrase_callback = self._session.passphrase_callback
|
self.passphrase_callback = self._session.passphrase_callback
|
||||||
self.passphrase = self._session.passphrase
|
|
||||||
|
|
||||||
def __enter__(self) -> "SessionDebugWrapper":
|
def __enter__(self) -> "SessionDebugWrapper":
|
||||||
# For usage in with/expected_responses
|
# For usage in with/expected_responses
|
||||||
@ -1224,7 +1227,6 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
# and know the supported debug capabilities
|
# and know the supported debug capabilities
|
||||||
self.debug.model = self.model
|
self.debug.model = self.model
|
||||||
self.debug.version = self.version
|
self.debug.version = self.version
|
||||||
self.passphrase: str | None = None
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def layout_type(self) -> LayoutType:
|
def layout_type(self) -> LayoutType:
|
||||||
@ -1308,7 +1310,7 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
resp = session.call_raw(msg)
|
resp = session.call_raw(msg)
|
||||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
# session.session_id = resp.state
|
session._session.id = resp.state
|
||||||
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
@ -1317,12 +1319,14 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
return send_passphrase(None, None)
|
return send_passphrase(None, None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if session.passphrase is None and isinstance(session, SessionV1):
|
if isinstance(session, SessionV1) or isinstance(session, SessionDebugWrapper):
|
||||||
passphrase = self.ui.get_passphrase(
|
passphrase = self.ui.get_passphrase(
|
||||||
available_on_device=available_on_device
|
available_on_device=available_on_device
|
||||||
)
|
)
|
||||||
|
if passphrase is None:
|
||||||
|
passphrase = session.passphrase
|
||||||
else:
|
else:
|
||||||
passphrase = session.passphrase
|
raise NotImplementedError
|
||||||
except Cancelled:
|
except Cancelled:
|
||||||
session.call_raw(messages.Cancel())
|
session.call_raw(messages.Cancel())
|
||||||
raise
|
raise
|
||||||
@ -1376,33 +1380,6 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
return super().get_session(passphrase, derive_cardano, session_id)
|
return super().get_session(passphrase, derive_cardano, session_id)
|
||||||
|
|
||||||
def set_filter(
|
|
||||||
self,
|
|
||||||
message_type: t.Type[protobuf.MessageType],
|
|
||||||
callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
|
||||||
) -> None:
|
|
||||||
"""Configure a filter function for a specified message type.
|
|
||||||
|
|
||||||
The `callback` must be a function that accepts a protobuf message, and returns
|
|
||||||
a (possibly modified) protobuf message of the same type. Whenever a message
|
|
||||||
is sent or received that matches `message_type`, `callback` is invoked on the
|
|
||||||
message and its result is substituted for the original.
|
|
||||||
|
|
||||||
Useful for test scenarios with an active malicious actor on the wire.
|
|
||||||
"""
|
|
||||||
if not self.in_with_statement:
|
|
||||||
raise RuntimeError("Must be called inside 'with' statement")
|
|
||||||
|
|
||||||
self.filters[message_type] = callback
|
|
||||||
|
|
||||||
def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType:
|
|
||||||
message_type = msg.__class__
|
|
||||||
callback = self.filters.get(message_type)
|
|
||||||
if callable(callback):
|
|
||||||
return callback(deepcopy(msg))
|
|
||||||
else:
|
|
||||||
return msg
|
|
||||||
|
|
||||||
def set_input_flow(
|
def set_input_flow(
|
||||||
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
|
self, input_flow: InputFlowType | t.Callable[[], InputFlowType]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -1536,7 +1513,6 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def use_passphrase(self, passphrase: str) -> None:
|
def use_passphrase(self, passphrase: str) -> None:
|
||||||
"""Respond to passphrase prompts from device with the provided passphrase."""
|
"""Respond to passphrase prompts from device with the provided passphrase."""
|
||||||
self.passphrase = passphrase
|
|
||||||
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
self.ui.passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
|
|
||||||
def use_mnemonic(self, mnemonic: str) -> None:
|
def use_mnemonic(self, mnemonic: str) -> None:
|
||||||
@ -1544,17 +1520,6 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
Only applies to T1, where device prompts the host for mnemonic words."""
|
Only applies to T1, where device prompts the host for mnemonic words."""
|
||||||
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
self.mnemonic = Mnemonic.normalize_string(mnemonic).split(" ")
|
||||||
|
|
||||||
def _raw_read(self) -> protobuf.MessageType:
|
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
|
||||||
resp = self.get_seedless_session()._read()
|
|
||||||
resp = self._filter_message(resp)
|
|
||||||
if self.actual_responses is not None:
|
|
||||||
self.actual_responses.append(resp)
|
|
||||||
return resp
|
|
||||||
|
|
||||||
def _raw_write(self, msg: protobuf.MessageType) -> None:
|
|
||||||
return self.get_seedless_session()._write(self._filter_message(msg))
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]:
|
||||||
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0)
|
||||||
|
@ -98,7 +98,6 @@ def prepare_passphrase_dialogue(
|
|||||||
) -> Generator["DebugLink", None, None]:
|
) -> Generator["DebugLink", None, None]:
|
||||||
debug = device_handler.debuglink()
|
debug = device_handler.debuglink()
|
||||||
device_handler.run_with_session(get_test_address) # type: ignore
|
device_handler.run_with_session(get_test_address) # type: ignore
|
||||||
# TODO
|
|
||||||
assert debug.read_layout().main_component() == "PassphraseKeyboard"
|
assert debug.read_layout().main_component() == "PassphraseKeyboard"
|
||||||
|
|
||||||
# Resetting the category as it could have been changed by previous tests
|
# Resetting the category as it could have been changed by previous tests
|
||||||
|
@ -372,8 +372,6 @@ def _client_unlocked(
|
|||||||
if request.node.get_closest_marker("experimental"):
|
if request.node.get_closest_marker("experimental"):
|
||||||
apply_settings(session, experimental_features=True)
|
apply_settings(session, experimental_features=True)
|
||||||
|
|
||||||
if use_passphrase and isinstance(setup_params["passphrase"], str):
|
|
||||||
_raw_client.use_passphrase(setup_params["passphrase"])
|
|
||||||
|
|
||||||
# TODO _raw_client.clear_session()
|
# TODO _raw_client.clear_session()
|
||||||
|
|
||||||
@ -399,7 +397,10 @@ def session(
|
|||||||
session = _client_unlocked.get_seedless_session()
|
session = _client_unlocked.get_seedless_session()
|
||||||
else:
|
else:
|
||||||
derive_cardano = bool(request.node.get_closest_marker("cardano"))
|
derive_cardano = bool(request.node.get_closest_marker("cardano"))
|
||||||
passphrase = _client_unlocked.passphrase or ""
|
passphrase = ""
|
||||||
|
marker = request.node.get_closest_marker("setup_client")
|
||||||
|
if marker and isinstance(marker.kwargs.get("passphrase"), str):
|
||||||
|
passphrase = marker.kwargs["passphrase"]
|
||||||
if _client_unlocked._setup_pin is not None:
|
if _client_unlocked._setup_pin is not None:
|
||||||
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
|
_client_unlocked.use_pin_sequence([_client_unlocked._setup_pin])
|
||||||
session = _client_unlocked.get_session(
|
session = _client_unlocked.get_session(
|
||||||
|
@ -65,8 +65,8 @@ class BackgroundDeviceHandler:
|
|||||||
raise RuntimeError("Wait for previous task first")
|
raise RuntimeError("Wait for previous task first")
|
||||||
|
|
||||||
# wait for the first UI change triggered by the task running in the background
|
# wait for the first UI change triggered by the task running in the background
|
||||||
|
session = self.client.get_session()
|
||||||
with self.debuglink().wait_for_layout_change():
|
with self.debuglink().wait_for_layout_change():
|
||||||
session = self.client.get_session()
|
|
||||||
self.task = self._pool.submit(function, session, *args, **kwargs)
|
self.task = self._pool.submit(function, session, *args, **kwargs)
|
||||||
|
|
||||||
def run_with_provided_session(
|
def run_with_provided_session(
|
||||||
|
@ -124,7 +124,6 @@ def test_load_device_utf(client: Client):
|
|||||||
skip_checksum=True,
|
skip_checksum=True,
|
||||||
)
|
)
|
||||||
session: Session = session.client.get_session(passphrase=passphrase_nfkd)
|
session: Session = session.client.get_session(passphrase=passphrase_nfkd)
|
||||||
session.client.use_passphrase(passphrase_nfkd) # TODO is needed?
|
|
||||||
address_nfkd = get_test_address(session)
|
address_nfkd = get_test_address(session)
|
||||||
|
|
||||||
device.wipe(session)
|
device.wipe(session)
|
||||||
@ -139,7 +138,6 @@ def test_load_device_utf(client: Client):
|
|||||||
skip_checksum=True,
|
skip_checksum=True,
|
||||||
)
|
)
|
||||||
session = client.get_session(passphrase=passphrase_nfc)
|
session = client.get_session(passphrase=passphrase_nfc)
|
||||||
session.client.use_passphrase(passphrase_nfc) # TODO is needed?
|
|
||||||
address_nfc = get_test_address(session)
|
address_nfc = get_test_address(session)
|
||||||
|
|
||||||
device.wipe(session)
|
device.wipe(session)
|
||||||
@ -154,7 +152,6 @@ def test_load_device_utf(client: Client):
|
|||||||
skip_checksum=True,
|
skip_checksum=True,
|
||||||
)
|
)
|
||||||
session = client.get_session(passphrase=passphrase_nfkc)
|
session = client.get_session(passphrase=passphrase_nfkc)
|
||||||
session.client.use_passphrase(passphrase_nfkc) # TODO is needed?
|
|
||||||
address_nfkc = get_test_address(session)
|
address_nfkc = get_test_address(session)
|
||||||
|
|
||||||
device.wipe(session)
|
device.wipe(session)
|
||||||
@ -169,7 +166,6 @@ def test_load_device_utf(client: Client):
|
|||||||
skip_checksum=True,
|
skip_checksum=True,
|
||||||
)
|
)
|
||||||
session = client.get_session(passphrase=passphrase_nfd)
|
session = client.get_session(passphrase=passphrase_nfd)
|
||||||
session.client.use_passphrase(passphrase_nfd) # TODO is needed?
|
|
||||||
address_nfd = get_test_address(session)
|
address_nfd = get_test_address(session)
|
||||||
assert address_nfkd == address_nfc
|
assert address_nfkd == address_nfc
|
||||||
assert address_nfkd == address_nfkc
|
assert address_nfkd == address_nfkc
|
||||||
|
@ -146,7 +146,6 @@ def test_session_recycling(client: Client):
|
|||||||
messages.Address,
|
messages.Address,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
client.use_passphrase("TREZOR")
|
|
||||||
_ = get_test_address(session)
|
_ = get_test_address(session)
|
||||||
# address = get_test_address(session)
|
# address = get_test_address(session)
|
||||||
|
|
||||||
|
@ -54,7 +54,6 @@ SESSIONS_STORED = 10
|
|||||||
def _get_xpub(
|
def _get_xpub(
|
||||||
session: Session,
|
session: Session,
|
||||||
expected_passphrase_req: bool = False,
|
expected_passphrase_req: bool = False,
|
||||||
passphrase_v1: str | None = None,
|
|
||||||
):
|
):
|
||||||
"""Get XPUB and check that the appropriate passphrase flow has happened."""
|
"""Get XPUB and check that the appropriate passphrase flow has happened."""
|
||||||
if expected_passphrase_req:
|
if expected_passphrase_req:
|
||||||
@ -66,11 +65,6 @@ def _get_xpub(
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
expected_responses = [messages.PublicKey]
|
expected_responses = [messages.PublicKey]
|
||||||
if (
|
|
||||||
passphrase_v1 is not None
|
|
||||||
and session.protocol_version == ProtocolVersion.PROTOCOL_V1
|
|
||||||
):
|
|
||||||
session.passphrase = passphrase_v1
|
|
||||||
|
|
||||||
with session:
|
with session:
|
||||||
session.set_expected_responses(expected_responses)
|
session.set_expected_responses(expected_responses)
|
||||||
@ -228,7 +222,6 @@ def test_max_sessions_with_passphrases(client: Client):
|
|||||||
_get_xpub(
|
_get_xpub(
|
||||||
resumed_session,
|
resumed_session,
|
||||||
expected_passphrase_req=True,
|
expected_passphrase_req=True,
|
||||||
passphrase_v1="whatever",
|
|
||||||
) # passphrase is prompted
|
) # passphrase is prompted
|
||||||
|
|
||||||
|
|
||||||
@ -435,7 +428,6 @@ def test_hide_passphrase_from_host(client: Client):
|
|||||||
messages.PublicKey,
|
messages.PublicKey,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
client.use_passphrase(passphrase)
|
|
||||||
result = session.call(XPUB_REQUEST)
|
result = session.call(XPUB_REQUEST)
|
||||||
assert isinstance(result, messages.PublicKey)
|
assert isinstance(result, messages.PublicKey)
|
||||||
xpub_hidden_passphrase = result.xpub
|
xpub_hidden_passphrase = result.xpub
|
||||||
@ -471,7 +463,6 @@ def test_hide_passphrase_from_host(client: Client):
|
|||||||
messages.PublicKey,
|
messages.PublicKey,
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
client.use_passphrase(passphrase)
|
|
||||||
result = session.call(XPUB_REQUEST)
|
result = session.call(XPUB_REQUEST)
|
||||||
assert isinstance(result, messages.PublicKey)
|
assert isinstance(result, messages.PublicKey)
|
||||||
xpub_shown_passphrase = result.xpub
|
xpub_shown_passphrase = result.xpub
|
||||||
|
@ -384,7 +384,7 @@ def test_upgrade_shamir_backup(gen: str, tag: Optional[str]):
|
|||||||
address = btc.get_address(session, "Bitcoin", PATH)
|
address = btc.get_address(session, "Bitcoin", PATH)
|
||||||
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
if session.protocol_version == ProtocolVersion.PROTOCOL_V1:
|
||||||
session.call(messages.Initialize(new_session=True))
|
session.call(messages.Initialize(new_session=True))
|
||||||
new_session = emu.client.get_session(passphrase="TREZOR")
|
new_session = Session(emu.client.get_session(passphrase="TREZOR"))
|
||||||
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)
|
address_passphrase = btc.get_address(new_session, "Bitcoin", PATH)
|
||||||
|
|
||||||
assert emu.client.features.backup_availability == BackupAvailability.Required
|
assert emu.client.features.backup_availability == BackupAvailability.Required
|
||||||
|
Loading…
Reference in New Issue
Block a user