1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-04-19 08:39:03 +00:00

fixup! feat(python): implement session based trezorlib

This commit is contained in:
M1nd3r 2025-03-26 16:48:50 +01:00
parent 2e627a82dd
commit 107e0e3c7b
5 changed files with 136 additions and 106 deletions

View File

@ -28,7 +28,7 @@ from .transport.thp.protocol_and_channel import Channel
from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING:
from .transport.session import Session
from .transport.session import Session, SessionV1
LOG = logging.getLogger(__name__)
@ -36,6 +36,7 @@ MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50
PASSPHRASE_ON_DEVICE = object()
SEEDLESS = object()
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
OUTDATED_FIRMWARE_ERROR = """
@ -51,16 +52,17 @@ class ProtocolVersion(IntEnum):
class TrezorClient:
button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None
button_callback: t.Callable[[messages.ButtonRequest], None] | None = None
passphrase_callback: (
t.Callable[[Session, messages.PassphraseRequest], t.Any] | None
) = None
pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None
pin_callback: t.Callable[[messages.PinMatrixRequest], str] | None = None
_model: models.TrezorModel
_features: messages.Features | None = None
_protocol_version: int
_setup_pin: str | None = None # Should be used only by conftest
_last_active_session: SessionV1 | None = None
def __init__(
self,
@ -99,42 +101,36 @@ class TrezorClient:
def get_session(
self,
passphrase: str | object | None = None,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
should_derive: bool = True,
) -> Session:
"""
Returns a new session.
In case of seed derivation, the function will fail if the device is not initialized.
In the case of seed derivation, the function will fail if the device is not initialized.
"""
from .transport.session import SessionV1, derive_seed
if self.features.initialized is False and passphrase is not SEEDLESS:
raise exceptions.DerivationOnUninitaizedDeviceError(
"Calling uninitialized device with a passphrase. Call get_seedless_session instead."
)
if isinstance(self.protocol, ProtocolV1Channel):
if passphrase is None:
from .transport.session import SessionV1, derive_seed
if passphrase is SEEDLESS:
return SessionV1.new(client=self, derive_cardano=False)
session = SessionV1.new(
self,
derive_cardano=derive_cardano,
session_id=session_id,
)
if should_derive:
if isinstance(passphrase, str):
temporary = self.passphrase_callback
self.passphrase_callback = get_callback_passphrase_v1(
passphrase=passphrase
)
derive_seed(session)
self.passphrase_callback = temporary
elif passphrase is PASSPHRASE_ON_DEVICE:
derive_seed(session)
if self.features.passphrase_protection:
derive_seed(session, passphrase)
return session
raise NotImplementedError
def get_seedless_session(self) -> Session:
return self.get_session(passphrase=None)
return self.get_session(passphrase=SEEDLESS)
def invalidate(self) -> None:
self._is_invalidated = True

View File

@ -34,12 +34,11 @@ from mnemonic import Mnemonic
from . import btc, mapping, messages, models, protobuf
from .client import (
MAX_PASSPHRASE_LENGTH,
MAX_PIN_LENGTH,
PASSPHRASE_ON_DEVICE,
ProtocolVersion,
TrezorClient,
)
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES
from .messages import Capability, DebugWaitType
from .protobuf import MessageType
@ -51,7 +50,6 @@ from .transport.thp.protocol_v1 import ProtocolV1Channel
if t.TYPE_CHECKING:
from typing_extensions import Protocol
from .messages import PinMatrixRequestType
from .transport import Transport
ExpectedMessage = t.Union[
@ -839,7 +837,7 @@ class DebugUI:
except StopIteration:
self.input_flow = self.INPUT_FLOW_DONE
def get_pin(self, code: PinMatrixRequestType | None = None) -> str:
def get_pin(self) -> str:
self.debuglink.snapshot_legacy()
if self.pins is None:
@ -1251,6 +1249,16 @@ class TrezorClientDebugLink(TrezorClient):
self.transport = transport
self.ui: DebugUI = DebugUI(self.debug)
def get_pin(_msg: messages.PinMatrixRequest) -> str:
try:
pin = self.ui.get_pin()
except Cancelled:
raise
return pin
self.pin_callback = get_pin
self.button_callback = self.ui.button_request
self.sync_responses()
# So that we can choose right screenshotting logic (T1 vs TT)
@ -1274,35 +1282,6 @@ class TrezorClientDebugLink(TrezorClient):
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
return new_client
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
session._write(messages.ButtonAck())
self.ui.button_request(msg)
return session._read()
def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any:
try:
pin = self.ui.get_pin(msg.type)
except Cancelled:
session.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
session.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
def passphrase_callback(
self, session: Session, msg: messages.PassphraseRequest
) -> t.Any:
@ -1369,15 +1348,15 @@ class TrezorClientDebugLink(TrezorClient):
def get_session(
self,
passphrase: str | object | None = None,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionDebugWrapper:
if isinstance(passphrase, str):
passphrase = Mnemonic.normalize_string(passphrase)
session = SessionDebugWrapper(
super().get_session(
passphrase, derive_cardano, session_id, should_derive=False
passphrase,
derive_cardano,
)
)
session.passphrase = passphrase

View File

@ -92,3 +92,20 @@ class FailedSessionResumption(TrezorException):
Raised when `trezorctl -s <sesssion_id>` is used or `TREZOR_SESSION_ID = <session_id>`
is set and resumption of session with the `session_id` fails."""
def __init__(self, received_session_id: bytes | None = None):
# We keep the session id that was received from Trezor for test purposes
self.received_session_id = received_session_id
super().__init__("Failed to resume session")
class InvalidSessionError(TrezorException):
"""Session expired and is no longer valid.
Raised when Trezor returns unexpected PassphraseRequest"""
class DerivationOnUninitaizedDeviceError(TrezorException):
"""Tried to derive seed on uninitialized device.
To communicate with uninitialized device, use seedless session instead."""

View File

@ -4,6 +4,7 @@ import logging
import typing as t
from .. import exceptions, messages, models
from ..client import MAX_PIN_LENGTH
from ..protobuf import MessageType
from .thp.protocol_v1 import ProtocolV1Channel
@ -20,23 +21,28 @@ class Session:
self.client = client
self._id = id
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
def call(
self,
msg: MessageType,
expect: type[MT] = MessageType,
_passphrase_ack: messages.PassphraseAck | None = None,
) -> MT:
self.client.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
if self.client.pin_callback is None:
raise NotImplementedError("Missing pin_callback")
resp = self.client.pin_callback(self, resp)
raise RuntimeError("Missing pin_callback")
resp = self._callback_pin(resp)
elif isinstance(resp, messages.PassphraseRequest):
if self.client.passphrase_callback is None:
raise NotImplementedError("Missing passphrase_callback")
resp = self.client.passphrase_callback(self, resp)
if _passphrase_ack is None:
# we got a PassphraseRequest when not explicitly trying to unlock
# the session, this means that the session has expired
raise exceptions.InvalidSessionError
resp = self.call_raw(_passphrase_ack)
elif isinstance(resp, messages.ButtonRequest):
resp = (self.client.button_callback or default_button_callback)(
self, resp
)
resp = self._callback_button(resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
@ -50,6 +56,39 @@ class Session:
self._write(msg)
return self._read()
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
if self.client.pin_callback is None:
raise RuntimeError("No PIN provided")
try:
pin = self.client.pin_callback(msg)
except exceptions.Cancelled:
self.call_raw(messages.Cancel())
raise
if any(d not in "123456789" for d in pin) or not (
1 <= len(pin) <= MAX_PIN_LENGTH
):
self.call_raw(messages.Cancel())
raise ValueError("Invalid PIN provided")
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise exceptions.PinException(resp.code, resp.message)
else:
return resp
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
__tracebackhide__ = True # for pytest # pylint: disable=W0612
# do this raw - send ButtonAck first, notify UI later
self._write(messages.ButtonAck())
if self.client.button_callback:
self.client.button_callback(msg)
return self._read()
def _write(self, msg: t.Any) -> None:
raise NotImplementedError
@ -78,9 +117,7 @@ class Session:
if isinstance(resp, messages.ButtonRequest):
# device is PIN-locked.
# respond and hope for the best
resp = (self.client.button_callback or default_button_callback)(
self, resp
)
resp = self._callback_button(resp)
resp = messages.Success.ensure_isinstance(resp)
assert resp.message is not None
return resp.message
@ -120,6 +157,7 @@ class Session:
class SessionV1(Session):
derive_cardano: bool | None = False
_was_initialized_at_least_once = False
@classmethod
def new(
@ -131,7 +169,7 @@ class SessionV1(Session):
assert isinstance(client.protocol, ProtocolV1Channel)
session = SessionV1(client, id=session_id or b"")
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
session.init_session(derive_cardano=session.derive_cardano)
return session
@classmethod
@ -142,39 +180,58 @@ class SessionV1(Session):
return session
def resume(self) -> None:
self.init_session(self.derive_cardano)
self.init_session(derive_cardano=self.derive_cardano)
def _write(self, msg: t.Any) -> None:
self._activate_self()
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
self.client.protocol.write(msg)
def _activate_self(self) -> None:
if self.client._last_active_session is not self:
self.client._last_active_session = self
# self.resume()
def _read(self) -> t.Any:
assert self.client._last_active_session is self
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1Channel)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None) -> None:
if self.id == b"":
new_session = True
session_id = None
else:
new_session = False
session_id = self.id
resp: messages.Features = self.call_raw(
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
)
assert isinstance(resp, messages.Features)
if resp.session_id is not None:
if new_session:
assert resp.session_id is not None
assert len(resp.session_id) == 32
self.id = resp.session_id
elif self.id != resp.session_id:
raise exceptions.FailedSessionResumption(resp.session_id)
self.was_initialized_at_least_once = True
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
return session.call_raw(messages.ButtonAck())
def derive_seed(session: Session, passphrase: str | object) -> None:
from ..client import PASSPHRASE_ON_DEVICE, PASSPHRASE_TEST_PATH
def derive_seed(session: Session) -> None:
from ..btc import get_address
from ..client import PASSPHRASE_TEST_PATH
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
if passphrase is PASSPHRASE_ON_DEVICE:
ack = messages.PassphraseAck(on_device=True)
elif isinstance(passphrase, str):
ack = messages.PassphraseAck(passphrase=passphrase)
else:
raise ValueError("Invalid passphrase")
session.call(
messages.GetAddress(address_n=PASSPHRASE_TEST_PATH, coin_name="Testnet"),
expect=messages.Address,
_passphrase_ack=ack,
)
session.refresh_features()

View File

@ -23,7 +23,7 @@ from mnemonic import Mnemonic
from . import device, messages
from .client import MAX_PIN_LENGTH
from .exceptions import Cancelled, PinException
from .exceptions import Cancelled
from .messages import Capability, PinMatrixRequestType, WordRequestType
from .transport.session import Session
@ -91,15 +91,14 @@ class ClickUI:
return "Please confirm action on your Trezor device."
def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
def button_request(self, br: messages.ButtonRequest) -> None:
prompt = self._prompt_for_button(br)
if prompt != self.last_prompt_shown:
echo(prompt)
if not self.always_prompt:
self.last_prompt_shown = prompt
return session.call_raw(messages.ButtonAck())
def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
def get_pin(self, request: messages.PinMatrixRequest) -> str:
code = request.type
if code == PIN_CURRENT:
desc = "current PIN"
@ -123,7 +122,6 @@ class ClickUI:
try:
pin = prompt(f"Please enter {desc}", hide_input=True)
except click.Abort:
session.call_raw(messages.Cancel())
raise Cancelled from None
# translate letters to numbers if letters were used
@ -137,15 +135,7 @@ class ClickUI:
elif len(pin) > MAX_PIN_LENGTH:
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
else:
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
return pin
def get_passphrase(
self, session: Session, request: messages.PassphraseRequest
@ -206,13 +196,12 @@ class ScriptUI:
"""
@staticmethod
def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
def button_request(br: messages.ButtonRequest) -> None:
code = br.code.name if br.code else None
print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
return session.call_raw(messages.ButtonAck())
@staticmethod
def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
def get_pin(request: messages.PinMatrixRequest) -> str:
code = request.type
if code is None:
print("?PIN")
@ -226,15 +215,7 @@ class ScriptUI:
raise RuntimeError("Sent PIN must start with ':'")
else:
pin = pin[1:]
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
if isinstance(resp, messages.Failure) and resp.code in (
messages.FailureType.PinInvalid,
messages.FailureType.PinCancelled,
messages.FailureType.PinExpected,
):
raise PinException(resp.code, resp.message)
else:
return resp
return pin
@staticmethod
def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any: