1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-25 03:29:02 +00:00

chore(python): session passphrase rework

This commit is contained in:
M1nd3r 2025-02-28 17:13:59 +01:00
parent 99dc4ace6c
commit 51d8c54fcb
4 changed files with 22 additions and 12 deletions

View File

@ -161,11 +161,8 @@ class TrezorConnection:
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
)
return session

View File

@ -212,3 +212,13 @@ def get_default_client(
transport.open()
return TrezorClient(transport, **kwargs)
def get_callback_passphrase_v1(
passphrase: str = "",
) -> t.Callable[[Session, t.Any], t.Any] | None:
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
return session.call(messages.PassphraseAck(passphrase=passphrase))
return _callback_passphrase_v1

View File

@ -1312,8 +1312,10 @@ class TrezorClientDebugLink(TrezorClient):
return send_passphrase(on_device=True)
# else process host-entered passphrase
if passphrase is None:
passphrase = ""
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
raise RuntimeError(f"Passphrase must be a str {type(passphrase)}")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
@ -1331,15 +1333,19 @@ class TrezorClientDebugLink(TrezorClient):
def get_session(
self,
passphrase: str | object | None = "",
passphrase: str | object | None = None,
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionDebugWrapper:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
return SessionDebugWrapper(
super().get_session(passphrase, derive_cardano, session_id)
session = SessionDebugWrapper(
super().get_session(
passphrase, derive_cardano, session_id, should_derive=False
)
)
session.passphrase = passphrase
return session
def get_seedless_session(
self, *args: t.Any, **kwargs: t.Any

View File

@ -163,14 +163,11 @@ class SessionV1(Session):
def new(
cls,
client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1Channel)
session = SessionV1(client, id=session_id or b"")
session.passphrase = passphrase
session.derive_cardano = derive_cardano
session.init_session(derive_cardano=session.derive_cardano)
return session
@ -202,7 +199,7 @@ class SessionV1(Session):
assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None):
def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"":
new_session = True
session_id = None