1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 22:26:08 +00:00

chore(python): session passphrase rework

This commit is contained in:
M1nd3r 2025-02-28 17:13:59 +01:00
parent 7bcbe0aac4
commit cd781d7a70
4 changed files with 30 additions and 19 deletions

View File

@ -161,11 +161,8 @@ class TrezorConnection:
else:
available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano
passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
)
return session

View File

@ -112,13 +112,14 @@ class TrezorClient:
passphrase: str | object | None = None,
derive_cardano: bool = False,
session_id: bytes | None = None,
should_derive: bool = True,
) -> Session:
"""
Returns initialized session (with derived seed).
Will fail if the device is not initialized
"""
from .transport.session import SessionV1
from .transport.session import SessionV1, derive_seed
if isinstance(self.protocol, ProtocolV1Channel):
session = SessionV1.new(
@ -158,11 +159,7 @@ class TrezorClient:
if not new_session and self._seedless_session is not None:
return self._seedless_session
if isinstance(self.protocol, ProtocolV1Channel):
self._seedless_session = SessionV1.new(
client=self,
passphrase="",
derive_cardano=False,
)
self._seedless_session = SessionV1.new(client=self, derive_cardano=False)
assert self._seedless_session is not None
return self._seedless_session
@ -249,3 +246,13 @@ def get_default_client(
transport.open()
return TrezorClient(transport, **kwargs)
def get_callback_passphrase_v1(
passphrase: str = "",
) -> t.Callable[[Session, t.Any], t.Any] | None:
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
return session.call(messages.PassphraseAck(passphrase=passphrase))
return _callback_passphrase_v1

View File

@ -44,7 +44,7 @@ from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType
from .protobuf import MessageType
from .tools import parse_path
from .transport.session import Session
from .transport.session import Session, SessionV1, derive_seed
from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING:
@ -1303,8 +1303,10 @@ class TrezorClientDebugLink(TrezorClient):
return send_passphrase(on_device=True)
# else process host-entered passphrase
if passphrase is None:
passphrase = ""
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
raise RuntimeError(f"Passphrase must be a str {type(passphrase)}")
passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel())
@ -1322,15 +1324,23 @@ class TrezorClientDebugLink(TrezorClient):
def get_session(
self,
passphrase: str | object | None = "",
passphrase: str | object | None = None,
derive_cardano: bool = False,
session_id: bytes | None = None,
should_derive: bool = False,
) -> SessionDebugWrapper:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
return SessionDebugWrapper(
super().get_session(passphrase, derive_cardano, session_id)
session = SessionDebugWrapper(
super().get_session(
passphrase, derive_cardano, session_id, should_derive=False
)
)
session.passphrase = passphrase
if isinstance(session._session, SessionV1) and should_derive:
derive_seed(session=session)
return session
def get_seedless_session(
self, *args: t.Any, **kwargs: t.Any

View File

@ -131,14 +131,11 @@ class SessionV1(Session):
def new(
cls,
client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1Channel)
session = SessionV1(client, id=session_id or b"")
session.passphrase = passphrase
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
return session
@ -160,7 +157,7 @@ class SessionV1(Session):
assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None):
def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"":
session_id = None
else: