mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-12 22:26:08 +00:00
chore(python): session passphrase rework
This commit is contained in:
parent
7bcbe0aac4
commit
cd781d7a70
@ -161,11 +161,8 @@ class TrezorConnection:
|
||||
else:
|
||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
session = client.get_session(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
|
||||
)
|
||||
return session
|
||||
|
||||
|
@ -112,13 +112,14 @@ class TrezorClient:
|
||||
passphrase: str | object | None = None,
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
should_derive: bool = True,
|
||||
) -> Session:
|
||||
"""
|
||||
Returns initialized session (with derived seed).
|
||||
|
||||
Will fail if the device is not initialized
|
||||
"""
|
||||
from .transport.session import SessionV1
|
||||
from .transport.session import SessionV1, derive_seed
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
session = SessionV1.new(
|
||||
@ -158,11 +159,7 @@ class TrezorClient:
|
||||
if not new_session and self._seedless_session is not None:
|
||||
return self._seedless_session
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
self._seedless_session = SessionV1.new(
|
||||
client=self,
|
||||
passphrase="",
|
||||
derive_cardano=False,
|
||||
)
|
||||
self._seedless_session = SessionV1.new(client=self, derive_cardano=False)
|
||||
assert self._seedless_session is not None
|
||||
return self._seedless_session
|
||||
|
||||
@ -249,3 +246,13 @@ def get_default_client(
|
||||
transport.open()
|
||||
|
||||
return TrezorClient(transport, **kwargs)
|
||||
|
||||
|
||||
def get_callback_passphrase_v1(
|
||||
passphrase: str = "",
|
||||
) -> t.Callable[[Session, t.Any], t.Any] | None:
|
||||
|
||||
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
|
||||
return session.call(messages.PassphraseAck(passphrase=passphrase))
|
||||
|
||||
return _callback_passphrase_v1
|
||||
|
@ -44,7 +44,7 @@ from .log import DUMP_BYTES
|
||||
from .messages import Capability, DebugWaitType
|
||||
from .protobuf import MessageType
|
||||
from .tools import parse_path
|
||||
from .transport.session import Session
|
||||
from .transport.session import Session, SessionV1, derive_seed
|
||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
@ -1303,8 +1303,10 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
return send_passphrase(on_device=True)
|
||||
|
||||
# else process host-entered passphrase
|
||||
if passphrase is None:
|
||||
passphrase = ""
|
||||
if not isinstance(passphrase, str):
|
||||
raise RuntimeError("Passphrase must be a str")
|
||||
raise RuntimeError(f"Passphrase must be a str {type(passphrase)}")
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||
session.call_raw(messages.Cancel())
|
||||
@ -1322,15 +1324,23 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object | None = "",
|
||||
passphrase: str | object | None = None,
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
should_derive: bool = False,
|
||||
) -> SessionDebugWrapper:
|
||||
if isinstance(passphrase, str):
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
return SessionDebugWrapper(
|
||||
super().get_session(passphrase, derive_cardano, session_id)
|
||||
session = SessionDebugWrapper(
|
||||
super().get_session(
|
||||
passphrase, derive_cardano, session_id, should_derive=False
|
||||
)
|
||||
)
|
||||
session.passphrase = passphrase
|
||||
|
||||
if isinstance(session._session, SessionV1) and should_derive:
|
||||
derive_seed(session=session)
|
||||
return session
|
||||
|
||||
def get_seedless_session(
|
||||
self, *args: t.Any, **kwargs: t.Any
|
||||
|
@ -131,14 +131,11 @@ class SessionV1(Session):
|
||||
def new(
|
||||
cls,
|
||||
client: TrezorClient,
|
||||
passphrase: str | object = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1Channel)
|
||||
session = SessionV1(client, id=session_id or b"")
|
||||
|
||||
session.passphrase = passphrase
|
||||
session.derive_cardano = derive_cardano
|
||||
session.init_session(session.derive_cardano)
|
||||
return session
|
||||
@ -160,7 +157,7 @@ class SessionV1(Session):
|
||||
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
||||
return self.client.protocol.read()
|
||||
|
||||
def init_session(self, derive_cardano: bool | None = None):
|
||||
def init_session(self, derive_cardano: bool | None = None) -> None:
|
||||
if self.id == b"":
|
||||
session_id = None
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user