From 51d8c54fcb136dacb4b1645a472df1b3846b84d1 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 28 Feb 2025 17:13:59 +0100 Subject: [PATCH] chore(python): session passphrase rework --- python/src/trezorlib/cli/__init__.py | 5 +---- python/src/trezorlib/client.py | 10 ++++++++++ python/src/trezorlib/debuglink.py | 14 ++++++++++---- python/src/trezorlib/transport/session.py | 5 +---- 4 files changed, 22 insertions(+), 12 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 74b6c11d91..92de8014a3 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -161,11 +161,8 @@ class TrezorConnection: else: available_on_device = Capability.PassphraseEntry in features.capabilities passphrase = get_passphrase(available_on_device, self.passphrase_on_host) - # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") session = client.get_session( - passphrase=passphrase, derive_cardano=derive_cardano + passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True ) return session diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index a470ce992b..0f735c24b4 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -212,3 +212,13 @@ def get_default_client( transport.open() return TrezorClient(transport, **kwargs) + + +def get_callback_passphrase_v1( + passphrase: str = "", +) -> t.Callable[[Session, t.Any], t.Any] | None: + + def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any: + return session.call(messages.PassphraseAck(passphrase=passphrase)) + + return _callback_passphrase_v1 diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index c13c71b24f..cd021433b2 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1312,8 +1312,10 @@ class TrezorClientDebugLink(TrezorClient): return send_passphrase(on_device=True) # else process host-entered passphrase + if passphrase is None: + passphrase = "" if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") + raise RuntimeError(f"Passphrase must be a str {type(passphrase)}") passphrase = Mnemonic.normalize_string(passphrase) if len(passphrase) > MAX_PASSPHRASE_LENGTH: session.call_raw(messages.Cancel()) @@ -1331,15 +1333,19 @@ class TrezorClientDebugLink(TrezorClient): def get_session( self, - passphrase: str | object | None = "", + passphrase: str | object | None = None, derive_cardano: bool = False, session_id: bytes | None = None, ) -> SessionDebugWrapper: if isinstance(passphrase, str): passphrase = Mnemonic.normalize_string(passphrase) - return SessionDebugWrapper( - super().get_session(passphrase, derive_cardano, session_id) + session = SessionDebugWrapper( + super().get_session( + passphrase, derive_cardano, session_id, should_derive=False + ) ) + session.passphrase = passphrase + return session def get_seedless_session( self, *args: t.Any, **kwargs: t.Any diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index a2e24f310a..95e12f59d5 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -163,14 +163,11 @@ class SessionV1(Session): def new( cls, client: TrezorClient, - passphrase: str | object = "", derive_cardano: bool = False, session_id: bytes | None = None, ) -> SessionV1: assert isinstance(client.protocol, ProtocolV1Channel) session = SessionV1(client, id=session_id or b"") - - session.passphrase = passphrase session.derive_cardano = derive_cardano session.init_session(derive_cardano=session.derive_cardano) return session @@ -202,7 +199,7 @@ class SessionV1(Session): assert isinstance(self.client.protocol, ProtocolV1Channel) return self.client.protocol.read() - def init_session(self, derive_cardano: bool | None = None): + def init_session(self, derive_cardano: bool | None = None) -> None: if self.id == b"": new_session = True session_id = None