diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 05b8b75e9f..1d324bf3f1 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -161,11 +161,8 @@ class TrezorConnection: else: available_on_device = Capability.PassphraseEntry in features.capabilities passphrase = get_passphrase(available_on_device, self.passphrase_on_host) - # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") session = client.get_session( - passphrase=passphrase, derive_cardano=derive_cardano + passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True ) return session diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 48d8b98b4f..fb9ac1dc8f 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -112,13 +112,14 @@ class TrezorClient: passphrase: str | object | None = None, derive_cardano: bool = False, session_id: bytes | None = None, + should_derive: bool = True, ) -> Session: """ Returns initialized session (with derived seed). Will fail if the device is not initialized """ - from .transport.session import SessionV1 + from .transport.session import SessionV1, derive_seed if isinstance(self.protocol, ProtocolV1Channel): session = SessionV1.new( @@ -158,11 +159,7 @@ class TrezorClient: if not new_session and self._seedless_session is not None: return self._seedless_session if isinstance(self.protocol, ProtocolV1Channel): - self._seedless_session = SessionV1.new( - client=self, - passphrase="", - derive_cardano=False, - ) + self._seedless_session = SessionV1.new(client=self, derive_cardano=False) assert self._seedless_session is not None return self._seedless_session @@ -249,3 +246,13 @@ def get_default_client( transport.open() return TrezorClient(transport, **kwargs) + + +def get_callback_passphrase_v1( + passphrase: str = "", +) -> t.Callable[[Session, t.Any], t.Any] | None: + + def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any: + return session.call(messages.PassphraseAck(passphrase=passphrase)) + + return _callback_passphrase_v1 diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 9ea029784e..3bf126037c 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -44,7 +44,7 @@ from .log import DUMP_BYTES from .messages import Capability, DebugWaitType from .protobuf import MessageType from .tools import parse_path -from .transport.session import Session +from .transport.session import Session, SessionV1, derive_seed from .transport.thp.protocol_v1 import ProtocolV1Channel if t.TYPE_CHECKING: @@ -1303,8 +1303,10 @@ class TrezorClientDebugLink(TrezorClient): return send_passphrase(on_device=True) # else process host-entered passphrase + if passphrase is None: + passphrase = "" if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") + raise RuntimeError(f"Passphrase must be a str {type(passphrase)}") passphrase = Mnemonic.normalize_string(passphrase) if len(passphrase) > MAX_PASSPHRASE_LENGTH: session.call_raw(messages.Cancel()) @@ -1322,15 +1324,23 @@ class TrezorClientDebugLink(TrezorClient): def get_session( self, - passphrase: str | object | None = "", + passphrase: str | object | None = None, derive_cardano: bool = False, session_id: bytes | None = None, + should_derive: bool = False, ) -> SessionDebugWrapper: if isinstance(passphrase, str): passphrase = Mnemonic.normalize_string(passphrase) - return SessionDebugWrapper( - super().get_session(passphrase, derive_cardano, session_id) + session = SessionDebugWrapper( + super().get_session( + passphrase, derive_cardano, session_id, should_derive=False + ) ) + session.passphrase = passphrase + + if isinstance(session._session, SessionV1) and should_derive: + derive_seed(session=session) + return session def get_seedless_session( self, *args: t.Any, **kwargs: t.Any diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 527a2a1d6e..8e3e1d0a94 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -131,14 +131,11 @@ class SessionV1(Session): def new( cls, client: TrezorClient, - passphrase: str | object = "", derive_cardano: bool = False, session_id: bytes | None = None, ) -> SessionV1: assert isinstance(client.protocol, ProtocolV1Channel) session = SessionV1(client, id=session_id or b"") - - session.passphrase = passphrase session.derive_cardano = derive_cardano session.init_session(session.derive_cardano) return session @@ -160,7 +157,7 @@ class SessionV1(Session): assert isinstance(self.client.protocol, ProtocolV1Channel) return self.client.protocol.read() - def init_session(self, derive_cardano: bool | None = None): + def init_session(self, derive_cardano: bool | None = None) -> None: if self.id == b"": session_id = None else: