1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 14:16:06 +00:00

fix(python): simplify UI callbacks

This commit is contained in:
Martin Milata 2025-02-27 18:49:04 +01:00 committed by M1nd3r
parent 21b69d06c6
commit 7bcbe0aac4
3 changed files with 87 additions and 93 deletions

View File

@ -56,9 +56,11 @@ class ProtocolVersion(IntEnum):
class TrezorClient:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None
passphrase_callback: (
t.Callable[[Session, messages.PassphraseRequest], t.Any] | None
) = None
pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None
_seedless_session: Session | None = None
_features: messages.Features | None = None

View File

@ -1065,9 +1065,6 @@ class SessionDebugWrapper(Session):
t.Type[protobuf.MessageType],
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
] = {}
self.button_callback = self.client.button_callback
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self._session.passphrase_callback
def __enter__(self) -> "SessionDebugWrapper":
# For usage in with/expected_responses
@ -1232,102 +1229,88 @@ class TrezorClientDebugLink(TrezorClient):
self.ui: DebugUI = DebugUI(self.debug)
self.in_with_statement = False
@property
def button_callback(self):
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
return _callback_button
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
@property
def pin_callback(self):
def passphrase_callback(
self, session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
if resp.state is not None:
session.id = resp.state
else:
raise RuntimeError("Object resp.state is None")
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if isinstance(session, SessionDebugWrapper):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
if passphrase is None:
passphrase = session.passphrase
else:
return resp
raise NotImplementedError
except Cancelled:
session.call_raw(messages.Cancel())
raise
return _callback_pin
@property
def passphrase_callback(self):
def _callback_passphrase(
session: Session, msg: messages.PassphraseRequest
) -> t.Any:
available_on_device = (
Capability.PassphraseEntry in session.features.capabilities
)
def send_passphrase(
passphrase: str | None = None, on_device: bool | None = None
) -> MessageType:
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
resp = session.call_raw(msg)
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
if resp.state is not None:
session.id = resp.state
else:
raise RuntimeError("Object resp.state is None")
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
return resp
# short-circuit old style entry
if msg._on_device is True:
return send_passphrase(None, None)
try:
if isinstance(session, SessionDebugWrapper):
passphrase = self.ui.get_passphrase(
available_on_device=available_on_device
)
if passphrase is None:
passphrase = session.passphrase
else:
raise NotImplementedError
except Cancelled:
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
if passphrase is PASSPHRASE_ON_DEVICE:
if not available_on_device:
session.call_raw(messages.Cancel())
raise RuntimeError("Device is not capable of entering passphrase")
else:
return send_passphrase(on_device=True)
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
# else process host-entered passphrase
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
raise ValueError("Passphrase too long")
return send_passphrase(passphrase, on_device=False)
return _callback_passphrase
return send_passphrase(passphrase, on_device=False)
def close_transport(self) -> None:
self.transport.close()

View File

@ -174,4 +174,13 @@ class SessionV1(Session):
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
return session.call(messages.ButtonAck())
return session.call_raw(messages.ButtonAck())
def derive_seed(session: Session) -> None:
from ..btc import get_address
from ..client import PASSPHRASE_TEST_PATH
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
session.refresh_features()