diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 2de3a4d97e..48d8b98b4f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -56,9 +56,11 @@ class ProtocolVersion(IntEnum): class TrezorClient: - button_callback: t.Callable[[Session, t.Any], t.Any] | None = None - passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None - pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None + passphrase_callback: ( + t.Callable[[Session, messages.PassphraseRequest], t.Any] | None + ) = None + pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None _seedless_session: Session | None = None _features: messages.Features | None = None diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 253f05e374..9ea029784e 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1065,9 +1065,6 @@ class SessionDebugWrapper(Session): t.Type[protobuf.MessageType], t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} - self.button_callback = self.client.button_callback - self.pin_callback = self.client.pin_callback - self.passphrase_callback = self._session.passphrase_callback def __enter__(self) -> "SessionDebugWrapper": # For usage in with/expected_responses @@ -1232,102 +1229,88 @@ class TrezorClientDebugLink(TrezorClient): self.ui: DebugUI = DebugUI(self.debug) self.in_with_statement = False - @property - def button_callback(self): + def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() - def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - session._write(messages.ButtonAck()) - self.ui.button_request(msg) - return session._read() + def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise - return _callback_button + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp - @property - def pin_callback(self): + def passphrase_callback( + self, session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) - def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: - try: - pin = self.ui.get_pin(msg.type) - except Cancelled: - session.call_raw(messages.Cancel()) - raise + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> MessageType: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + if resp.state is not None: + session.id = resp.state + else: + raise RuntimeError("Object resp.state is None") + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - session.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - resp = session.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise PinException(resp.code, resp.message) + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if isinstance(session, SessionDebugWrapper): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + if passphrase is None: + passphrase = session.passphrase else: - return resp + raise NotImplementedError + except Cancelled: + session.call_raw(messages.Cancel()) + raise - return _callback_pin - - @property - def passphrase_callback(self): - def _callback_passphrase( - session: Session, msg: messages.PassphraseRequest - ) -> t.Any: - available_on_device = ( - Capability.PassphraseEntry in session.features.capabilities - ) - - def send_passphrase( - passphrase: str | None = None, on_device: bool | None = None - ) -> MessageType: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = session.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - if resp.state is not None: - session.id = resp.state - else: - raise RuntimeError("Object resp.state is None") - resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - if isinstance(session, SessionDebugWrapper): - passphrase = self.ui.get_passphrase( - available_on_device=available_on_device - ) - if passphrase is None: - passphrase = session.passphrase - else: - raise NotImplementedError - except Cancelled: + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: session.call_raw(messages.Cancel()) - raise + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - session.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - session.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - return _callback_passphrase + return send_passphrase(passphrase, on_device=False) def close_transport(self) -> None: self.transport.close() diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 38d2b16ceb..527a2a1d6e 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -174,4 +174,13 @@ class SessionV1(Session): def default_button_callback(session: Session, msg: t.Any) -> t.Any: - return session.call(messages.ButtonAck()) + return session.call_raw(messages.ButtonAck()) + + +def derive_seed(session: Session) -> None: + + from ..btc import get_address + from ..client import PASSPHRASE_TEST_PATH + + get_address(session, "Testnet", PASSPHRASE_TEST_PATH) + session.refresh_features()