From 107e0e3c7b898226a706518fc020f4a43f81df13 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 26 Mar 2025 16:48:50 +0100 Subject: [PATCH] fixup! feat(python): implement session based trezorlib --- python/src/trezorlib/client.py | 38 ++++---- python/src/trezorlib/debuglink.py | 51 ++++------- python/src/trezorlib/exceptions.py | 17 ++++ python/src/trezorlib/transport/session.py | 103 +++++++++++++++++----- python/src/trezorlib/ui.py | 33 ++----- 5 files changed, 136 insertions(+), 106 deletions(-) diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index d1679a5482..cede4c8af5 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -28,7 +28,7 @@ from .transport.thp.protocol_and_channel import Channel from .transport.thp.protocol_v1 import ProtocolV1Channel if t.TYPE_CHECKING: - from .transport.session import Session + from .transport.session import Session, SessionV1 LOG = logging.getLogger(__name__) @@ -36,6 +36,7 @@ MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 PASSPHRASE_ON_DEVICE = object() +SEEDLESS = object() PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") OUTDATED_FIRMWARE_ERROR = """ @@ -51,16 +52,17 @@ class ProtocolVersion(IntEnum): class TrezorClient: - button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None + button_callback: t.Callable[[messages.ButtonRequest], None] | None = None passphrase_callback: ( t.Callable[[Session, messages.PassphraseRequest], t.Any] | None ) = None - pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None + pin_callback: t.Callable[[messages.PinMatrixRequest], str] | None = None _model: models.TrezorModel _features: messages.Features | None = None _protocol_version: int _setup_pin: str | None = None # Should be used only by conftest + _last_active_session: SessionV1 | None = None def __init__( self, @@ -99,42 +101,36 @@ class TrezorClient: def get_session( self, - passphrase: str | object | None = None, + passphrase: str | object = "", derive_cardano: bool = False, - session_id: bytes | None = None, - should_derive: bool = True, ) -> Session: """ Returns a new session. - In case of seed derivation, the function will fail if the device is not initialized. + In the case of seed derivation, the function will fail if the device is not initialized. """ - from .transport.session import SessionV1, derive_seed + if self.features.initialized is False and passphrase is not SEEDLESS: + raise exceptions.DerivationOnUninitaizedDeviceError( + "Calling uninitialized device with a passphrase. Call get_seedless_session instead." + ) if isinstance(self.protocol, ProtocolV1Channel): - if passphrase is None: + from .transport.session import SessionV1, derive_seed + + if passphrase is SEEDLESS: return SessionV1.new(client=self, derive_cardano=False) session = SessionV1.new( self, derive_cardano=derive_cardano, - session_id=session_id, ) - if should_derive: - if isinstance(passphrase, str): - temporary = self.passphrase_callback - self.passphrase_callback = get_callback_passphrase_v1( - passphrase=passphrase - ) - derive_seed(session) - self.passphrase_callback = temporary - elif passphrase is PASSPHRASE_ON_DEVICE: - derive_seed(session) + if self.features.passphrase_protection: + derive_seed(session, passphrase) return session raise NotImplementedError def get_seedless_session(self) -> Session: - return self.get_session(passphrase=None) + return self.get_session(passphrase=SEEDLESS) def invalidate(self) -> None: self._is_invalidated = True diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index dbf86e2f91..6801374c5c 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -34,12 +34,11 @@ from mnemonic import Mnemonic from . import btc, mapping, messages, models, protobuf from .client import ( MAX_PASSPHRASE_LENGTH, - MAX_PIN_LENGTH, PASSPHRASE_ON_DEVICE, ProtocolVersion, TrezorClient, ) -from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError +from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES from .messages import Capability, DebugWaitType from .protobuf import MessageType @@ -51,7 +50,6 @@ from .transport.thp.protocol_v1 import ProtocolV1Channel if t.TYPE_CHECKING: from typing_extensions import Protocol - from .messages import PinMatrixRequestType from .transport import Transport ExpectedMessage = t.Union[ @@ -839,7 +837,7 @@ class DebugUI: except StopIteration: self.input_flow = self.INPUT_FLOW_DONE - def get_pin(self, code: PinMatrixRequestType | None = None) -> str: + def get_pin(self) -> str: self.debuglink.snapshot_legacy() if self.pins is None: @@ -1251,6 +1249,16 @@ class TrezorClientDebugLink(TrezorClient): self.transport = transport self.ui: DebugUI = DebugUI(self.debug) + def get_pin(_msg: messages.PinMatrixRequest) -> str: + try: + pin = self.ui.get_pin() + except Cancelled: + raise + return pin + + self.pin_callback = get_pin + self.button_callback = self.ui.button_request + self.sync_responses() # So that we can choose right screenshotting logic (T1 vs TT) @@ -1274,35 +1282,6 @@ class TrezorClientDebugLink(TrezorClient): new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter return new_client - def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - session._write(messages.ButtonAck()) - self.ui.button_request(msg) - return session._read() - - def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any: - try: - pin = self.ui.get_pin(msg.type) - except Cancelled: - session.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - session.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - resp = session.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise PinException(resp.code, resp.message) - else: - return resp - def passphrase_callback( self, session: Session, msg: messages.PassphraseRequest ) -> t.Any: @@ -1369,15 +1348,15 @@ class TrezorClientDebugLink(TrezorClient): def get_session( self, - passphrase: str | object | None = None, + passphrase: str | object = "", derive_cardano: bool = False, - session_id: bytes | None = None, ) -> SessionDebugWrapper: if isinstance(passphrase, str): passphrase = Mnemonic.normalize_string(passphrase) session = SessionDebugWrapper( super().get_session( - passphrase, derive_cardano, session_id, should_derive=False + passphrase, + derive_cardano, ) ) session.passphrase = passphrase diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index 0d0ab892ed..66b726dbfa 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -92,3 +92,20 @@ class FailedSessionResumption(TrezorException): Raised when `trezorctl -s ` is used or `TREZOR_SESSION_ID = ` is set and resumption of session with the `session_id` fails.""" + + def __init__(self, received_session_id: bytes | None = None): + # We keep the session id that was received from Trezor for test purposes + self.received_session_id = received_session_id + super().__init__("Failed to resume session") + + +class InvalidSessionError(TrezorException): + """Session expired and is no longer valid. + + Raised when Trezor returns unexpected PassphraseRequest""" + + +class DerivationOnUninitaizedDeviceError(TrezorException): + """Tried to derive seed on uninitialized device. + + To communicate with uninitialized device, use seedless session instead.""" diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py index 0cf91599cb..cfd03a3b46 100644 --- a/python/src/trezorlib/transport/session.py +++ b/python/src/trezorlib/transport/session.py @@ -4,6 +4,7 @@ import logging import typing as t from .. import exceptions, messages, models +from ..client import MAX_PIN_LENGTH from ..protobuf import MessageType from .thp.protocol_v1 import ProtocolV1Channel @@ -20,23 +21,28 @@ class Session: self.client = client self._id = id - def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: + def call( + self, + msg: MessageType, + expect: type[MT] = MessageType, + _passphrase_ack: messages.PassphraseAck | None = None, + ) -> MT: self.client.check_firmware_version() resp = self.call_raw(msg) while True: if isinstance(resp, messages.PinMatrixRequest): if self.client.pin_callback is None: - raise NotImplementedError("Missing pin_callback") - resp = self.client.pin_callback(self, resp) + raise RuntimeError("Missing pin_callback") + resp = self._callback_pin(resp) elif isinstance(resp, messages.PassphraseRequest): - if self.client.passphrase_callback is None: - raise NotImplementedError("Missing passphrase_callback") - resp = self.client.passphrase_callback(self, resp) + if _passphrase_ack is None: + # we got a PassphraseRequest when not explicitly trying to unlock + # the session, this means that the session has expired + raise exceptions.InvalidSessionError + resp = self.call_raw(_passphrase_ack) elif isinstance(resp, messages.ButtonRequest): - resp = (self.client.button_callback or default_button_callback)( - self, resp - ) + resp = self._callback_button(resp) elif isinstance(resp, messages.Failure): if resp.code == messages.FailureType.ActionCancelled: raise exceptions.Cancelled @@ -50,6 +56,39 @@ class Session: self._write(msg) return self._read() + def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType: + if self.client.pin_callback is None: + raise RuntimeError("No PIN provided") + try: + pin = self.client.pin_callback(msg) + except exceptions.Cancelled: + self.call_raw(messages.Cancel()) + raise + + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + self.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + + resp = self.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise exceptions.PinException(resp.code, resp.message) + else: + return resp + + def _callback_button(self, msg: messages.ButtonRequest) -> MessageType: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + self._write(messages.ButtonAck()) + if self.client.button_callback: + self.client.button_callback(msg) + return self._read() + def _write(self, msg: t.Any) -> None: raise NotImplementedError @@ -78,9 +117,7 @@ class Session: if isinstance(resp, messages.ButtonRequest): # device is PIN-locked. # respond and hope for the best - resp = (self.client.button_callback or default_button_callback)( - self, resp - ) + resp = self._callback_button(resp) resp = messages.Success.ensure_isinstance(resp) assert resp.message is not None return resp.message @@ -120,6 +157,7 @@ class Session: class SessionV1(Session): derive_cardano: bool | None = False + _was_initialized_at_least_once = False @classmethod def new( @@ -131,7 +169,7 @@ class SessionV1(Session): assert isinstance(client.protocol, ProtocolV1Channel) session = SessionV1(client, id=session_id or b"") session.derive_cardano = derive_cardano - session.init_session(session.derive_cardano) + session.init_session(derive_cardano=session.derive_cardano) return session @classmethod @@ -142,39 +180,58 @@ class SessionV1(Session): return session def resume(self) -> None: - self.init_session(self.derive_cardano) + self.init_session(derive_cardano=self.derive_cardano) def _write(self, msg: t.Any) -> None: + self._activate_self() if t.TYPE_CHECKING: assert isinstance(self.client.protocol, ProtocolV1Channel) self.client.protocol.write(msg) + def _activate_self(self) -> None: + if self.client._last_active_session is not self: + self.client._last_active_session = self + # self.resume() + def _read(self) -> t.Any: + assert self.client._last_active_session is self if t.TYPE_CHECKING: assert isinstance(self.client.protocol, ProtocolV1Channel) return self.client.protocol.read() def init_session(self, derive_cardano: bool | None = None) -> None: if self.id == b"": + new_session = True session_id = None else: + new_session = False session_id = self.id resp: messages.Features = self.call_raw( messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) ) assert isinstance(resp, messages.Features) - if resp.session_id is not None: + if new_session: + assert resp.session_id is not None + assert len(resp.session_id) == 32 self.id = resp.session_id + elif self.id != resp.session_id: + raise exceptions.FailedSessionResumption(resp.session_id) + self.was_initialized_at_least_once = True -def default_button_callback(session: Session, msg: t.Any) -> t.Any: - return session.call_raw(messages.ButtonAck()) +def derive_seed(session: Session, passphrase: str | object) -> None: + from ..client import PASSPHRASE_ON_DEVICE, PASSPHRASE_TEST_PATH -def derive_seed(session: Session) -> None: - - from ..btc import get_address - from ..client import PASSPHRASE_TEST_PATH - - get_address(session, "Testnet", PASSPHRASE_TEST_PATH) + if passphrase is PASSPHRASE_ON_DEVICE: + ack = messages.PassphraseAck(on_device=True) + elif isinstance(passphrase, str): + ack = messages.PassphraseAck(passphrase=passphrase) + else: + raise ValueError("Invalid passphrase") + session.call( + messages.GetAddress(address_n=PASSPHRASE_TEST_PATH, coin_name="Testnet"), + expect=messages.Address, + _passphrase_ack=ack, + ) session.refresh_features() diff --git a/python/src/trezorlib/ui.py b/python/src/trezorlib/ui.py index 5d8ec4dfd7..644993e709 100644 --- a/python/src/trezorlib/ui.py +++ b/python/src/trezorlib/ui.py @@ -23,7 +23,7 @@ from mnemonic import Mnemonic from . import device, messages from .client import MAX_PIN_LENGTH -from .exceptions import Cancelled, PinException +from .exceptions import Cancelled from .messages import Capability, PinMatrixRequestType, WordRequestType from .transport.session import Session @@ -91,15 +91,14 @@ class ClickUI: return "Please confirm action on your Trezor device." - def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any: + def button_request(self, br: messages.ButtonRequest) -> None: prompt = self._prompt_for_button(br) if prompt != self.last_prompt_shown: echo(prompt) if not self.always_prompt: self.last_prompt_shown = prompt - return session.call_raw(messages.ButtonAck()) - def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any: + def get_pin(self, request: messages.PinMatrixRequest) -> str: code = request.type if code == PIN_CURRENT: desc = "current PIN" @@ -123,7 +122,6 @@ class ClickUI: try: pin = prompt(f"Please enter {desc}", hide_input=True) except click.Abort: - session.call_raw(messages.Cancel()) raise Cancelled from None # translate letters to numbers if letters were used @@ -137,15 +135,7 @@ class ClickUI: elif len(pin) > MAX_PIN_LENGTH: echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.") else: - resp = session.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise PinException(resp.code, resp.message) - else: - return resp + return pin def get_passphrase( self, session: Session, request: messages.PassphraseRequest @@ -206,13 +196,12 @@ class ScriptUI: """ @staticmethod - def button_request(session: Session, br: messages.ButtonRequest) -> t.Any: + def button_request(br: messages.ButtonRequest) -> None: code = br.code.name if br.code else None print(f"?BUTTON code={code} pages={br.pages} name={br.name}") - return session.call_raw(messages.ButtonAck()) @staticmethod - def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any: + def get_pin(request: messages.PinMatrixRequest) -> str: code = request.type if code is None: print("?PIN") @@ -226,15 +215,7 @@ class ScriptUI: raise RuntimeError("Sent PIN must start with ':'") else: pin = pin[1:] - resp = session.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise PinException(resp.code, resp.message) - else: - return resp + return pin @staticmethod def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any: