diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 5ad859eec4..7b78e3e686 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -3,6 +3,11 @@ from __future__ import annotations import logging import typing as t +from mnemonic import Mnemonic +from trezorlib.client import MAX_PASSPHRASE_LENGTH, PASSPHRASE_ON_DEVICE + +from trezor.enums import Capability + from .. import exceptions, messages, models from .thp.protocol_v1 import ProtocolV1 from .thp.protocol_v2 import ProtocolV2 @@ -16,6 +21,7 @@ LOG = logging.getLogger(__name__) class Session: button_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None def __init__(self, client: TrezorClient, id: bytes) -> None: self.client = client @@ -37,8 +43,9 @@ class Session: raise Exception # TODO resp = self.pin_callback(self, resp) elif isinstance(resp, messages.PassphraseRequest): - raise NotImplementedError - # resp = self._callback_passphrase(resp) + if self.passphrase_callback is None: + raise Exception # TODO + resp = self.passphrase_callback(self, resp) elif isinstance(resp, messages.ButtonRequest): if self.button_callback is None: raise Exception # TODO @@ -97,6 +104,7 @@ class SessionV1(Session): session = SessionV1(client, session_id) session.button_callback = client.button_callback session.pin_callback = client.pin_callback + session.passphrase_callback = _callback_passphrase session._init_session(derive_cardano=derive_cardano) return session @@ -111,17 +119,58 @@ class SessionV1(Session): return self.client.protocol.read() def _init_session(self, derive_cardano: bool = False): - self._write( + self.call_raw( messages.Initialize(session_id=self.id, derive_cardano=derive_cardano) ) - _ = self._read() - def _callback_button(session: Session, msg: t.Any) -> t.Any: print("Please confirm action on your Trezor device.") # TODO how to handle UI? return session.call(messages.ButtonAck()) +def _callback_passphrase(session: Session, msg: messages.PassphraseRequest) -> t.Any: + available_on_device = Capability.PassphraseEntry in Session.features.capabilities + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> t.Any: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + session.session_id = resp.state + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + passphrase = session.client.ui.get_passphrase( + available_on_device=available_on_device + ) # TODO + except exceptions.Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + class SessionV2(Session): @classmethod