mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-12 22:26:08 +00:00
chore(python): session passphrase rework
This commit is contained in:
parent
7bcbe0aac4
commit
cd781d7a70
@ -161,11 +161,8 @@ class TrezorConnection:
|
|||||||
else:
|
else:
|
||||||
available_on_device = Capability.PassphraseEntry in features.capabilities
|
available_on_device = Capability.PassphraseEntry in features.capabilities
|
||||||
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
passphrase = get_passphrase(available_on_device, self.passphrase_on_host)
|
||||||
# TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func
|
|
||||||
if not isinstance(passphrase, str):
|
|
||||||
raise RuntimeError("Passphrase must be a str")
|
|
||||||
session = client.get_session(
|
session = client.get_session(
|
||||||
passphrase=passphrase, derive_cardano=derive_cardano
|
passphrase=passphrase, derive_cardano=derive_cardano, should_derive=True
|
||||||
)
|
)
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
@ -112,13 +112,14 @@ class TrezorClient:
|
|||||||
passphrase: str | object | None = None,
|
passphrase: str | object | None = None,
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
session_id: bytes | None = None,
|
session_id: bytes | None = None,
|
||||||
|
should_derive: bool = True,
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""
|
"""
|
||||||
Returns initialized session (with derived seed).
|
Returns initialized session (with derived seed).
|
||||||
|
|
||||||
Will fail if the device is not initialized
|
Will fail if the device is not initialized
|
||||||
"""
|
"""
|
||||||
from .transport.session import SessionV1
|
from .transport.session import SessionV1, derive_seed
|
||||||
|
|
||||||
if isinstance(self.protocol, ProtocolV1Channel):
|
if isinstance(self.protocol, ProtocolV1Channel):
|
||||||
session = SessionV1.new(
|
session = SessionV1.new(
|
||||||
@ -158,11 +159,7 @@ class TrezorClient:
|
|||||||
if not new_session and self._seedless_session is not None:
|
if not new_session and self._seedless_session is not None:
|
||||||
return self._seedless_session
|
return self._seedless_session
|
||||||
if isinstance(self.protocol, ProtocolV1Channel):
|
if isinstance(self.protocol, ProtocolV1Channel):
|
||||||
self._seedless_session = SessionV1.new(
|
self._seedless_session = SessionV1.new(client=self, derive_cardano=False)
|
||||||
client=self,
|
|
||||||
passphrase="",
|
|
||||||
derive_cardano=False,
|
|
||||||
)
|
|
||||||
assert self._seedless_session is not None
|
assert self._seedless_session is not None
|
||||||
return self._seedless_session
|
return self._seedless_session
|
||||||
|
|
||||||
@ -249,3 +246,13 @@ def get_default_client(
|
|||||||
transport.open()
|
transport.open()
|
||||||
|
|
||||||
return TrezorClient(transport, **kwargs)
|
return TrezorClient(transport, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_callback_passphrase_v1(
|
||||||
|
passphrase: str = "",
|
||||||
|
) -> t.Callable[[Session, t.Any], t.Any] | None:
|
||||||
|
|
||||||
|
def _callback_passphrase_v1(session: Session, msg: t.Any) -> t.Any:
|
||||||
|
return session.call(messages.PassphraseAck(passphrase=passphrase))
|
||||||
|
|
||||||
|
return _callback_passphrase_v1
|
||||||
|
@ -44,7 +44,7 @@ from .log import DUMP_BYTES
|
|||||||
from .messages import Capability, DebugWaitType
|
from .messages import Capability, DebugWaitType
|
||||||
from .protobuf import MessageType
|
from .protobuf import MessageType
|
||||||
from .tools import parse_path
|
from .tools import parse_path
|
||||||
from .transport.session import Session
|
from .transport.session import Session, SessionV1, derive_seed
|
||||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||||
|
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
@ -1303,8 +1303,10 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
return send_passphrase(on_device=True)
|
return send_passphrase(on_device=True)
|
||||||
|
|
||||||
# else process host-entered passphrase
|
# else process host-entered passphrase
|
||||||
|
if passphrase is None:
|
||||||
|
passphrase = ""
|
||||||
if not isinstance(passphrase, str):
|
if not isinstance(passphrase, str):
|
||||||
raise RuntimeError("Passphrase must be a str")
|
raise RuntimeError(f"Passphrase must be a str {type(passphrase)}")
|
||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
if len(passphrase) > MAX_PASSPHRASE_LENGTH:
|
||||||
session.call_raw(messages.Cancel())
|
session.call_raw(messages.Cancel())
|
||||||
@ -1322,15 +1324,23 @@ class TrezorClientDebugLink(TrezorClient):
|
|||||||
|
|
||||||
def get_session(
|
def get_session(
|
||||||
self,
|
self,
|
||||||
passphrase: str | object | None = "",
|
passphrase: str | object | None = None,
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
session_id: bytes | None = None,
|
session_id: bytes | None = None,
|
||||||
|
should_derive: bool = False,
|
||||||
) -> SessionDebugWrapper:
|
) -> SessionDebugWrapper:
|
||||||
if isinstance(passphrase, str):
|
if isinstance(passphrase, str):
|
||||||
passphrase = Mnemonic.normalize_string(passphrase)
|
passphrase = Mnemonic.normalize_string(passphrase)
|
||||||
return SessionDebugWrapper(
|
session = SessionDebugWrapper(
|
||||||
super().get_session(passphrase, derive_cardano, session_id)
|
super().get_session(
|
||||||
|
passphrase, derive_cardano, session_id, should_derive=False
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
session.passphrase = passphrase
|
||||||
|
|
||||||
|
if isinstance(session._session, SessionV1) and should_derive:
|
||||||
|
derive_seed(session=session)
|
||||||
|
return session
|
||||||
|
|
||||||
def get_seedless_session(
|
def get_seedless_session(
|
||||||
self, *args: t.Any, **kwargs: t.Any
|
self, *args: t.Any, **kwargs: t.Any
|
||||||
|
@ -131,14 +131,11 @@ class SessionV1(Session):
|
|||||||
def new(
|
def new(
|
||||||
cls,
|
cls,
|
||||||
client: TrezorClient,
|
client: TrezorClient,
|
||||||
passphrase: str | object = "",
|
|
||||||
derive_cardano: bool = False,
|
derive_cardano: bool = False,
|
||||||
session_id: bytes | None = None,
|
session_id: bytes | None = None,
|
||||||
) -> SessionV1:
|
) -> SessionV1:
|
||||||
assert isinstance(client.protocol, ProtocolV1Channel)
|
assert isinstance(client.protocol, ProtocolV1Channel)
|
||||||
session = SessionV1(client, id=session_id or b"")
|
session = SessionV1(client, id=session_id or b"")
|
||||||
|
|
||||||
session.passphrase = passphrase
|
|
||||||
session.derive_cardano = derive_cardano
|
session.derive_cardano = derive_cardano
|
||||||
session.init_session(session.derive_cardano)
|
session.init_session(session.derive_cardano)
|
||||||
return session
|
return session
|
||||||
@ -160,7 +157,7 @@ class SessionV1(Session):
|
|||||||
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
||||||
return self.client.protocol.read()
|
return self.client.protocol.read()
|
||||||
|
|
||||||
def init_session(self, derive_cardano: bool | None = None):
|
def init_session(self, derive_cardano: bool | None = None) -> None:
|
||||||
if self.id == b"":
|
if self.id == b"":
|
||||||
session_id = None
|
session_id = None
|
||||||
else:
|
else:
|
||||||
|
Loading…
Reference in New Issue
Block a user