1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 22:26:08 +00:00

chore(python): session passphrase rework

This commit is contained in:
M1nd3r 2025-02-28 17:13:59 +01:00
parent 7bcbe0aac4
commit cd781d7a70
4 changed files with 30 additions and 19 deletions

View File

@ -161,11 +161,8 @@ class TrezorConnection:
else: else:
available_on_device = Capability.PassphraseEntry in features.capabilities available_on_device = Capability.PassphraseEntry in features.capabilities
passphrase = get_passphrase(available_on_device, self.passphrase_on_host) passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str")
session = client.get_session( session = client.get_session(
passphrase=passphrase, derive_cardano=derive_cardano passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
) )
return session return session

View File

@ -112,13 +112,14 @@ class TrezorClient:
passphrase: str | object | None = None, passphrase: str | object | None = None,
derive_cardano: bool = False, derive_cardano: bool = False,
session_id: bytes | None = None, session_id: bytes | None = None,
should_derive: bool = True,
) -> Session: ) -> Session:
""" """
Returns initialized session (with derived seed). Returns initialized session (with derived seed).
Will fail if the device is not initialized Will fail if the device is not initialized
""" """
from .transport.session import SessionV1 from .transport.session import SessionV1, derive_seed
if isinstance(self.protocol, ProtocolV1Channel): if isinstance(self.protocol, ProtocolV1Channel):
session = SessionV1.new( session = SessionV1.new(
@ -158,11 +159,7 @@ class TrezorClient:
if not new_session and self._seedless_session is not None: if not new_session and self._seedless_session is not None:
return self._seedless_session return self._seedless_session
if isinstance(self.protocol, ProtocolV1Channel): if isinstance(self.protocol, ProtocolV1Channel):
self._seedless_session = SessionV1.new( self._seedless_session = SessionV1.new(client=self, derive_cardano=False)
client=self,
passphrase="",
derive_cardano=False,
)
assert self._seedless_session is not None assert self._seedless_session is not None
return self._seedless_session return self._seedless_session
@ -249,3 +246,13 @@ def get_default_client(
transport.open() transport.open()
return TrezorClient(transport, **kwargs) return TrezorClient(transport, **kwargs)
def get_callback_passphrase_v1(
passphrase: str = "",
) -> t.Callable[[Session, t.Any], t.Any] | None:
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
return session.call(messages.PassphraseAck(passphrase=passphrase))
return _callback_passphrase_v1

View File

@ -44,7 +44,7 @@ from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType from .messages import Capability, DebugWaitType
from .protobuf import MessageType from .protobuf import MessageType
from .tools import parse_path from .tools import parse_path
from .transport.session import Session from .transport.session import Session, SessionV1, derive_seed
from .transport.thp.protocol_v1 import ProtocolV1Channel from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
@ -1303,8 +1303,10 @@ class TrezorClientDebugLink(TrezorClient):
return send_passphrase(on_device=True) return send_passphrase(on_device=True)
# else process host-entered passphrase # else process host-entered passphrase
if passphrase is None:
passphrase = ""
if not isinstance(passphrase, str): if not isinstance(passphrase, str):
raise RuntimeError("Passphrase must be a str") raise RuntimeError(f"Passphrase must be a str {type(passphrase)}")
passphrase = Mnemonic.normalize_string(passphrase) passphrase = Mnemonic.normalize_string(passphrase)
if len(passphrase) > MAX_PASSPHRASE_LENGTH: if len(passphrase) > MAX_PASSPHRASE_LENGTH:
session.call_raw(messages.Cancel()) session.call_raw(messages.Cancel())
@ -1322,15 +1324,23 @@ class TrezorClientDebugLink(TrezorClient):
def get_session( def get_session(
self, self,
passphrase: str | object | None = "", passphrase: str | object | None = None,
derive_cardano: bool = False, derive_cardano: bool = False,
session_id: bytes | None = None, session_id: bytes | None = None,
should_derive: bool = False,
) -> SessionDebugWrapper: ) -> SessionDebugWrapper:
if isinstance(passphrase, str): if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase) passphrase = Mnemonic.normalize_string(passphrase)
return SessionDebugWrapper( session = SessionDebugWrapper(
super().get_session(passphrase, derive_cardano, session_id) super().get_session(
passphrase, derive_cardano, session_id, should_derive=False
)
) )
session.passphrase = passphrase
if isinstance(session._session, SessionV1) and should_derive:
derive_seed(session=session)
return session
def get_seedless_session( def get_seedless_session(
self, *args: t.Any, **kwargs: t.Any self, *args: t.Any, **kwargs: t.Any

View File

@ -131,14 +131,11 @@ class SessionV1(Session):
def new( def new(
cls, cls,
client: TrezorClient, client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False, derive_cardano: bool = False,
session_id: bytes | None = None, session_id: bytes | None = None,
) -> SessionV1: ) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1Channel) assert isinstance(client.protocol, ProtocolV1Channel)
session = SessionV1(client, id=session_id or b"") session = SessionV1(client, id=session_id or b"")
session.passphrase = passphrase
session.derive_cardano = derive_cardano session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano) session.init_session(session.derive_cardano)
return session return session
@ -160,7 +157,7 @@ class SessionV1(Session):
assert isinstance(self.client.protocol, ProtocolV1Channel) assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client.protocol.read() return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None): def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"": if self.id == b"":
session_id = None session_id = None
else: else: