mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-12 14:16:06 +00:00
fix(python): simplify UI callbacks
This commit is contained in:
parent
21b69d06c6
commit
7bcbe0aac4
@ -56,9 +56,11 @@ class ProtocolVersion(IntEnum):
|
|||||||
|
|
||||||
|
|
||||||
class TrezorClient:
|
class TrezorClient:
|
||||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None
|
||||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
passphrase_callback: (
|
||||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
t.Callable[[Session, messages.PassphraseRequest], t.Any] | None
|
||||||
|
) = None
|
||||||
|
pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None
|
||||||
|
|
||||||
_seedless_session: Session | None = None
|
_seedless_session: Session | None = None
|
||||||
_features: messages.Features | None = None
|
_features: messages.Features | None = None
|
||||||
|
@ -1065,9 +1065,6 @@ class SessionDebugWrapper(Session):
|
|||||||
t.Type[protobuf.MessageType],
|
t.Type[protobuf.MessageType],
|
||||||
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
t.Callable[[protobuf.MessageType], protobuf.MessageType] | None,
|
||||||
] = {}
|
] = {}
|
||||||
self.button_callback = self.client.button_callback
|
|
||||||
self.pin_callback = self.client.pin_callback
|
|
||||||
self.passphrase_callback = self._session.passphrase_callback
|
|
||||||
|
|
||||||
def __enter__(self) -> "SessionDebugWrapper":
|
def __enter__(self) -> "SessionDebugWrapper":
|
||||||
# For usage in with/expected_responses
|
# For usage in with/expected_responses
|
||||||
@ -1232,102 +1229,88 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
self.ui: DebugUI = DebugUI(self.debug)
|
self.ui: DebugUI = DebugUI(self.debug)
|
||||||
self.in_with_statement = False
|
self.in_with_statement = False
|
||||||
|
|
||||||
@property
|
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
|
||||||
def button_callback(self):
|
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||||
|
# do this raw - send ButtonAck first, notify UI later
|
||||||
|
session._write(messages.ButtonAck())
|
||||||
|
self.ui.button_request(msg)
|
||||||
|
return session._read()
|
||||||
|
|
||||||
def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any:
|
def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
||||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
try:
|
||||||
# do this raw - send ButtonAck first, notify UI later
|
pin = self.ui.get_pin(msg.type)
|
||||||
session._write(messages.ButtonAck())
|
except Cancelled:
|
||||||
self.ui.button_request(msg)
|
session.call_raw(messages.Cancel())
|
||||||
return session._read()
|
raise
|
||||||
|
|
||||||
return _callback_button
|
if any(d not in "123456789" for d in pin) or not (
|
||||||
|
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||||
|
):
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Invalid PIN provided")
|
||||||
|
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||||
|
if isinstance(resp, messages.Failure) and resp.code in (
|
||||||
|
messages.FailureType.PinInvalid,
|
||||||
|
messages.FailureType.PinCancelled,
|
||||||
|
messages.FailureType.PinExpected,
|
||||||
|
):
|
||||||
|
raise PinException(resp.code, resp.message)
|
||||||
|
else:
|
||||||
|
return resp
|
||||||
|
|
||||||
@property
|
def passphrase_callback(
|
||||||
def pin_callback(self):
|
self, session: Session, msg: messages.PassphraseRequest
|
||||||
|
) -> t.Any:
|
||||||
|
available_on_device = (
|
||||||
|
Capability.PassphraseEntry in session.features.capabilities
|
||||||
|
)
|
||||||
|
|
||||||
def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
def send_passphrase(
|
||||||
try:
|
passphrase: str | None = None, on_device: bool | None = None
|
||||||
pin = self.ui.get_pin(msg.type)
|
) -> MessageType:
|
||||||
except Cancelled:
|
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
||||||
session.call_raw(messages.Cancel())
|
resp = session.call_raw(msg)
|
||||||
raise
|
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
||||||
|
if resp.state is not None:
|
||||||
|
session.id = resp.state
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Object resp.state is None")
|
||||||
|
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
||||||
|
return resp
|
||||||
|
|
||||||
if any(d not in "123456789" for d in pin) or not (
|
# short-circuit old style entry
|
||||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
if msg._on_device is True:
|
||||||
):
|
return send_passphrase(None, None)
|
||||||
session.call_raw(messages.Cancel())
|
|
||||||
raise ValueError("Invalid PIN provided")
|
try:
|
||||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
if isinstance(session, SessionDebugWrapper):
|
||||||
if isinstance(resp, messages.Failure) and resp.code in (
|
passphrase = self.ui.get_passphrase(
|
||||||
messages.FailureType.PinInvalid,
|
available_on_device=available_on_device
|
||||||
messages.FailureType.PinCancelled,
|
)
|
||||||
messages.FailureType.PinExpected,
|
if passphrase is None:
|
||||||
):
|
passphrase = session.passphrase
|
||||||
raise PinException(resp.code, resp.message)
|
|
||||||
else:
|
else:
|
||||||
return resp
|
raise NotImplementedError
|
||||||
|
except Cancelled:
|
||||||
|
session.call_raw(messages.Cancel())
|
||||||
|
raise
|
||||||
|
|
||||||
return _callback_pin
|
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||||
|
if not available_on_device:
|
||||||
@property
|
|
||||||
def passphrase_callback(self):
|
|
||||||
def _callback_passphrase(
|
|
||||||
session: Session, msg: messages.PassphraseRequest
|
|
||||||
) -> t.Any:
|
|
||||||
available_on_device = (
|
|
||||||
Capability.PassphraseEntry in session.features.capabilities
|
|
||||||
)
|
|
||||||
|
|
||||||
def send_passphrase(
|
|
||||||
passphrase: str | None = None, on_device: bool | None = None
|
|
||||||
) -> MessageType:
|
|
||||||
msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device)
|
|
||||||
resp = session.call_raw(msg)
|
|
||||||
if isinstance(resp, messages.Deprecated_PassphraseStateRequest):
|
|
||||||
if resp.state is not None:
|
|
||||||
session.id = resp.state
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Object resp.state is None")
|
|
||||||
resp = session.call_raw(messages.Deprecated_PassphraseStateAck())
|
|
||||||
return resp
|
|
||||||
|
|
||||||
# short-circuit old style entry
|
|
||||||
if msg._on_device is True:
|
|
||||||
return send_passphrase(None, None)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if isinstance(session, SessionDebugWrapper):
|
|
||||||
passphrase = self.ui.get_passphrase(
|
|
||||||
available_on_device=available_on_device
|
|
||||||
)
|
|
||||||
if passphrase is None:
|
|
||||||
passphrase = session.passphrase
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
except Cancelled:
|
|
||||||
session.call_raw(messages.Cancel())
|
session.call_raw(messages.Cancel())
|
||||||
raise
|
raise RuntimeError("Device is not capable of entering passphrase")
|
||||||
|
else:
|
||||||
|
return send_passphrase(on_device=True)
|
||||||
|
|
||||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
# else process host-entered passphrase
|
||||||
if not available_on_device:
|
if not isinstance(passphrase, str):
|
||||||
session.call_raw(messages.Cancel())
|
raise RuntimeError("Passphrase must be a str")
|
||||||
raise RuntimeError("Device is not capable of entering passphrase")
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
else:
|
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||||
return send_passphrase(on_device=True)
|
session.call_raw(messages.Cancel())
|
||||||
|
raise ValueError("Passphrase too long")
|
||||||
|
|
||||||
# else process host-entered passphrase
|
return send_passphrase(passphrase, on_device=False)
|
||||||
if not isinstance(passphrase, str):
|
|
||||||
raise RuntimeError("Passphrase must be a str")
|
|
||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
|
||||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
|
||||||
session.call_raw(messages.Cancel())
|
|
||||||
raise ValueError("Passphrase too long")
|
|
||||||
|
|
||||||
return send_passphrase(passphrase, on_device=False)
|
|
||||||
|
|
||||||
return _callback_passphrase
|
|
||||||
|
|
||||||
def close_transport(self) -> None:
|
def close_transport(self) -> None:
|
||||||
self.transport.close()
|
self.transport.close()
|
||||||
|
@ -174,4 +174,13 @@ class SessionV1(Session):
|
|||||||
|
|
||||||
|
|
||||||
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
|
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
|
||||||
return session.call(messages.ButtonAck())
|
return session.call_raw(messages.ButtonAck())
|
||||||
|
|
||||||
|
|
||||||
|
def derive_seed(session: Session) -> None:
|
||||||
|
|
||||||
|
from ..btc import get_address
|
||||||
|
from ..client import PASSPHRASE_TEST_PATH
|
||||||
|
|
||||||
|
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
|
||||||
|
session.refresh_features()
|
||||||
|
Loading…
Reference in New Issue
Block a user