mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-04-19 08:39:03 +00:00
fixup! feat(python): implement session based trezorlib
This commit is contained in:
parent
2e627a82dd
commit
107e0e3c7b
@ -28,7 +28,7 @@ from .transport.thp.protocol_and_channel import Channel
|
||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .transport.session import Session
|
||||
from .transport.session import Session, SessionV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -36,6 +36,7 @@ MAX_PASSPHRASE_LENGTH = 50
|
||||
MAX_PIN_LENGTH = 50
|
||||
|
||||
PASSPHRASE_ON_DEVICE = object()
|
||||
SEEDLESS = object()
|
||||
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")
|
||||
|
||||
OUTDATED_FIRMWARE_ERROR = """
|
||||
@ -51,16 +52,17 @@ class ProtocolVersion(IntEnum):
|
||||
|
||||
|
||||
class TrezorClient:
|
||||
button_callback: t.Callable[[Session, messages.ButtonRequest], t.Any] | None = None
|
||||
button_callback: t.Callable[[messages.ButtonRequest], None] | None = None
|
||||
passphrase_callback: (
|
||||
t.Callable[[Session, messages.PassphraseRequest], t.Any] | None
|
||||
) = None
|
||||
pin_callback: t.Callable[[Session, messages.PinMatrixRequest], t.Any] | None = None
|
||||
pin_callback: t.Callable[[messages.PinMatrixRequest], str] | None = None
|
||||
|
||||
_model: models.TrezorModel
|
||||
_features: messages.Features | None = None
|
||||
_protocol_version: int
|
||||
_setup_pin: str | None = None # Should be used only by conftest
|
||||
_last_active_session: SessionV1 | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -99,42 +101,36 @@ class TrezorClient:
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object | None = None,
|
||||
passphrase: str | object = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
should_derive: bool = True,
|
||||
) -> Session:
|
||||
"""
|
||||
Returns a new session.
|
||||
|
||||
In case of seed derivation, the function will fail if the device is not initialized.
|
||||
In the case of seed derivation, the function will fail if the device is not initialized.
|
||||
"""
|
||||
from .transport.session import SessionV1, derive_seed
|
||||
if self.features.initialized is False and passphrase is not SEEDLESS:
|
||||
raise exceptions.DerivationOnUninitaizedDeviceError(
|
||||
"Calling uninitialized device with a passphrase. Call get_seedless_session instead."
|
||||
)
|
||||
|
||||
if isinstance(self.protocol, ProtocolV1Channel):
|
||||
if passphrase is None:
|
||||
from .transport.session import SessionV1, derive_seed
|
||||
|
||||
if passphrase is SEEDLESS:
|
||||
return SessionV1.new(client=self, derive_cardano=False)
|
||||
session = SessionV1.new(
|
||||
self,
|
||||
derive_cardano=derive_cardano,
|
||||
session_id=session_id,
|
||||
)
|
||||
if should_derive:
|
||||
if isinstance(passphrase, str):
|
||||
temporary = self.passphrase_callback
|
||||
self.passphrase_callback = get_callback_passphrase_v1(
|
||||
passphrase=passphrase
|
||||
)
|
||||
derive_seed(session)
|
||||
self.passphrase_callback = temporary
|
||||
elif passphrase is PASSPHRASE_ON_DEVICE:
|
||||
derive_seed(session)
|
||||
if self.features.passphrase_protection:
|
||||
derive_seed(session, passphrase)
|
||||
|
||||
return session
|
||||
raise NotImplementedError
|
||||
|
||||
def get_seedless_session(self) -> Session:
|
||||
return self.get_session(passphrase=None)
|
||||
return self.get_session(passphrase=SEEDLESS)
|
||||
|
||||
def invalidate(self) -> None:
|
||||
self._is_invalidated = True
|
||||
|
@ -34,12 +34,11 @@ from mnemonic import Mnemonic
|
||||
from . import btc, mapping, messages, models, protobuf
|
||||
from .client import (
|
||||
MAX_PASSPHRASE_LENGTH,
|
||||
MAX_PIN_LENGTH,
|
||||
PASSPHRASE_ON_DEVICE,
|
||||
ProtocolVersion,
|
||||
TrezorClient,
|
||||
)
|
||||
from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError
|
||||
from .exceptions import Cancelled, TrezorFailure, UnexpectedMessageError
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability, DebugWaitType
|
||||
from .protobuf import MessageType
|
||||
@ -51,7 +50,6 @@ from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
if t.TYPE_CHECKING:
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from .messages import PinMatrixRequestType
|
||||
from .transport import Transport
|
||||
|
||||
ExpectedMessage = t.Union[
|
||||
@ -839,7 +837,7 @@ class DebugUI:
|
||||
except StopIteration:
|
||||
self.input_flow = self.INPUT_FLOW_DONE
|
||||
|
||||
def get_pin(self, code: PinMatrixRequestType | None = None) -> str:
|
||||
def get_pin(self) -> str:
|
||||
self.debuglink.snapshot_legacy()
|
||||
|
||||
if self.pins is None:
|
||||
@ -1251,6 +1249,16 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
self.transport = transport
|
||||
self.ui: DebugUI = DebugUI(self.debug)
|
||||
|
||||
def get_pin(_msg: messages.PinMatrixRequest) -> str:
|
||||
try:
|
||||
pin = self.ui.get_pin()
|
||||
except Cancelled:
|
||||
raise
|
||||
return pin
|
||||
|
||||
self.pin_callback = get_pin
|
||||
self.button_callback = self.ui.button_request
|
||||
|
||||
self.sync_responses()
|
||||
|
||||
# So that we can choose right screenshotting logic (T1 vs TT)
|
||||
@ -1274,35 +1282,6 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
new_client.debug.t1_screenshot_counter = self.debug.t1_screenshot_counter
|
||||
return new_client
|
||||
|
||||
def button_callback(self, session: Session, msg: messages.ButtonRequest) -> t.Any:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
session._write(messages.ButtonAck())
|
||||
self.ui.button_request(msg)
|
||||
return session._read()
|
||||
|
||||
def pin_callback(self, session: Session, msg: messages.PinMatrixRequest) -> t.Any:
|
||||
try:
|
||||
pin = self.ui.get_pin(msg.type)
|
||||
except Cancelled:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if any(d not in "123456789" for d in pin) or not (
|
||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||
):
|
||||
session.call_raw(messages.Cancel())
|
||||
raise ValueError("Invalid PIN provided")
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def passphrase_callback(
|
||||
self, session: Session, msg: messages.PassphraseRequest
|
||||
) -> t.Any:
|
||||
@ -1369,15 +1348,15 @@ class TrezorClientDebugLink(TrezorClient):
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object | None = None,
|
||||
passphrase: str | object = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
) -> SessionDebugWrapper:
|
||||
if isinstance(passphrase, str):
|
||||
passphrase = Mnemonic.normalize_string(passphrase)
|
||||
session = SessionDebugWrapper(
|
||||
super().get_session(
|
||||
passphrase, derive_cardano, session_id, should_derive=False
|
||||
passphrase,
|
||||
derive_cardano,
|
||||
)
|
||||
)
|
||||
session.passphrase = passphrase
|
||||
|
@ -92,3 +92,20 @@ class FailedSessionResumption(TrezorException):
|
||||
|
||||
Raised when `trezorctl -s <sesssion_id>` is used or `TREZOR_SESSION_ID = <session_id>`
|
||||
is set and resumption of session with the `session_id` fails."""
|
||||
|
||||
def __init__(self, received_session_id: bytes | None = None):
|
||||
# We keep the session id that was received from Trezor for test purposes
|
||||
self.received_session_id = received_session_id
|
||||
super().__init__("Failed to resume session")
|
||||
|
||||
|
||||
class InvalidSessionError(TrezorException):
|
||||
"""Session expired and is no longer valid.
|
||||
|
||||
Raised when Trezor returns unexpected PassphraseRequest"""
|
||||
|
||||
|
||||
class DerivationOnUninitaizedDeviceError(TrezorException):
|
||||
"""Tried to derive seed on uninitialized device.
|
||||
|
||||
To communicate with uninitialized device, use seedless session instead."""
|
||||
|
@ -4,6 +4,7 @@ import logging
|
||||
import typing as t
|
||||
|
||||
from .. import exceptions, messages, models
|
||||
from ..client import MAX_PIN_LENGTH
|
||||
from ..protobuf import MessageType
|
||||
from .thp.protocol_v1 import ProtocolV1Channel
|
||||
|
||||
@ -20,23 +21,28 @@ class Session:
|
||||
self.client = client
|
||||
self._id = id
|
||||
|
||||
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
||||
def call(
|
||||
self,
|
||||
msg: MessageType,
|
||||
expect: type[MT] = MessageType,
|
||||
_passphrase_ack: messages.PassphraseAck | None = None,
|
||||
) -> MT:
|
||||
self.client.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
|
||||
while True:
|
||||
if isinstance(resp, messages.PinMatrixRequest):
|
||||
if self.client.pin_callback is None:
|
||||
raise NotImplementedError("Missing pin_callback")
|
||||
resp = self.client.pin_callback(self, resp)
|
||||
raise RuntimeError("Missing pin_callback")
|
||||
resp = self._callback_pin(resp)
|
||||
elif isinstance(resp, messages.PassphraseRequest):
|
||||
if self.client.passphrase_callback is None:
|
||||
raise NotImplementedError("Missing passphrase_callback")
|
||||
resp = self.client.passphrase_callback(self, resp)
|
||||
if _passphrase_ack is None:
|
||||
# we got a PassphraseRequest when not explicitly trying to unlock
|
||||
# the session, this means that the session has expired
|
||||
raise exceptions.InvalidSessionError
|
||||
resp = self.call_raw(_passphrase_ack)
|
||||
elif isinstance(resp, messages.ButtonRequest):
|
||||
resp = (self.client.button_callback or default_button_callback)(
|
||||
self, resp
|
||||
)
|
||||
resp = self._callback_button(resp)
|
||||
elif isinstance(resp, messages.Failure):
|
||||
if resp.code == messages.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
@ -50,6 +56,39 @@ class Session:
|
||||
self._write(msg)
|
||||
return self._read()
|
||||
|
||||
def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
|
||||
if self.client.pin_callback is None:
|
||||
raise RuntimeError("No PIN provided")
|
||||
try:
|
||||
pin = self.client.pin_callback(msg)
|
||||
except exceptions.Cancelled:
|
||||
self.call_raw(messages.Cancel())
|
||||
raise
|
||||
|
||||
if any(d not in "123456789" for d in pin) or not (
|
||||
1 <= len(pin) <= MAX_PIN_LENGTH
|
||||
):
|
||||
self.call_raw(messages.Cancel())
|
||||
raise ValueError("Invalid PIN provided")
|
||||
|
||||
resp = self.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise exceptions.PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
|
||||
__tracebackhide__ = True # for pytest # pylint: disable=W0612
|
||||
# do this raw - send ButtonAck first, notify UI later
|
||||
self._write(messages.ButtonAck())
|
||||
if self.client.button_callback:
|
||||
self.client.button_callback(msg)
|
||||
return self._read()
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@ -78,9 +117,7 @@ class Session:
|
||||
if isinstance(resp, messages.ButtonRequest):
|
||||
# device is PIN-locked.
|
||||
# respond and hope for the best
|
||||
resp = (self.client.button_callback or default_button_callback)(
|
||||
self, resp
|
||||
)
|
||||
resp = self._callback_button(resp)
|
||||
resp = messages.Success.ensure_isinstance(resp)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
@ -120,6 +157,7 @@ class Session:
|
||||
|
||||
class SessionV1(Session):
|
||||
derive_cardano: bool | None = False
|
||||
_was_initialized_at_least_once = False
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
@ -131,7 +169,7 @@ class SessionV1(Session):
|
||||
assert isinstance(client.protocol, ProtocolV1Channel)
|
||||
session = SessionV1(client, id=session_id or b"")
|
||||
session.derive_cardano = derive_cardano
|
||||
session.init_session(session.derive_cardano)
|
||||
session.init_session(derive_cardano=session.derive_cardano)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
@ -142,39 +180,58 @@ class SessionV1(Session):
|
||||
return session
|
||||
|
||||
def resume(self) -> None:
|
||||
self.init_session(self.derive_cardano)
|
||||
self.init_session(derive_cardano=self.derive_cardano)
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
self._activate_self()
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
||||
self.client.protocol.write(msg)
|
||||
|
||||
def _activate_self(self) -> None:
|
||||
if self.client._last_active_session is not self:
|
||||
self.client._last_active_session = self
|
||||
# self.resume()
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
assert self.client._last_active_session is self
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1Channel)
|
||||
return self.client.protocol.read()
|
||||
|
||||
def init_session(self, derive_cardano: bool | None = None) -> None:
|
||||
if self.id == b"":
|
||||
new_session = True
|
||||
session_id = None
|
||||
else:
|
||||
new_session = False
|
||||
session_id = self.id
|
||||
resp: messages.Features = self.call_raw(
|
||||
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||
)
|
||||
assert isinstance(resp, messages.Features)
|
||||
if resp.session_id is not None:
|
||||
if new_session:
|
||||
assert resp.session_id is not None
|
||||
assert len(resp.session_id) == 32
|
||||
self.id = resp.session_id
|
||||
elif self.id != resp.session_id:
|
||||
raise exceptions.FailedSessionResumption(resp.session_id)
|
||||
self.was_initialized_at_least_once = True
|
||||
|
||||
|
||||
def default_button_callback(session: Session, msg: t.Any) -> t.Any:
|
||||
return session.call_raw(messages.ButtonAck())
|
||||
def derive_seed(session: Session, passphrase: str | object) -> None:
|
||||
|
||||
from ..client import PASSPHRASE_ON_DEVICE, PASSPHRASE_TEST_PATH
|
||||
|
||||
def derive_seed(session: Session) -> None:
|
||||
|
||||
from ..btc import get_address
|
||||
from ..client import PASSPHRASE_TEST_PATH
|
||||
|
||||
get_address(session, "Testnet", PASSPHRASE_TEST_PATH)
|
||||
if passphrase is PASSPHRASE_ON_DEVICE:
|
||||
ack = messages.PassphraseAck(on_device=True)
|
||||
elif isinstance(passphrase, str):
|
||||
ack = messages.PassphraseAck(passphrase=passphrase)
|
||||
else:
|
||||
raise ValueError("Invalid passphrase")
|
||||
session.call(
|
||||
messages.GetAddress(address_n=PASSPHRASE_TEST_PATH, coin_name="Testnet"),
|
||||
expect=messages.Address,
|
||||
_passphrase_ack=ack,
|
||||
)
|
||||
session.refresh_features()
|
||||
|
@ -23,7 +23,7 @@ from mnemonic import Mnemonic
|
||||
|
||||
from . import device, messages
|
||||
from .client import MAX_PIN_LENGTH
|
||||
from .exceptions import Cancelled, PinException
|
||||
from .exceptions import Cancelled
|
||||
from .messages import Capability, PinMatrixRequestType, WordRequestType
|
||||
from .transport.session import Session
|
||||
|
||||
@ -91,15 +91,14 @@ class ClickUI:
|
||||
|
||||
return "Please confirm action on your Trezor device."
|
||||
|
||||
def button_request(self, session: Session, br: messages.ButtonRequest) -> t.Any:
|
||||
def button_request(self, br: messages.ButtonRequest) -> None:
|
||||
prompt = self._prompt_for_button(br)
|
||||
if prompt != self.last_prompt_shown:
|
||||
echo(prompt)
|
||||
if not self.always_prompt:
|
||||
self.last_prompt_shown = prompt
|
||||
return session.call_raw(messages.ButtonAck())
|
||||
|
||||
def get_pin(self, session: Session, request: messages.PinMatrixRequest) -> t.Any:
|
||||
def get_pin(self, request: messages.PinMatrixRequest) -> str:
|
||||
code = request.type
|
||||
if code == PIN_CURRENT:
|
||||
desc = "current PIN"
|
||||
@ -123,7 +122,6 @@ class ClickUI:
|
||||
try:
|
||||
pin = prompt(f"Please enter {desc}", hide_input=True)
|
||||
except click.Abort:
|
||||
session.call_raw(messages.Cancel())
|
||||
raise Cancelled from None
|
||||
|
||||
# translate letters to numbers if letters were used
|
||||
@ -137,15 +135,7 @@ class ClickUI:
|
||||
elif len(pin) > MAX_PIN_LENGTH:
|
||||
echo(f"The value must be at most {MAX_PIN_LENGTH} digits in length.")
|
||||
else:
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
return pin
|
||||
|
||||
def get_passphrase(
|
||||
self, session: Session, request: messages.PassphraseRequest
|
||||
@ -206,13 +196,12 @@ class ScriptUI:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def button_request(session: Session, br: messages.ButtonRequest) -> t.Any:
|
||||
def button_request(br: messages.ButtonRequest) -> None:
|
||||
code = br.code.name if br.code else None
|
||||
print(f"?BUTTON code={code} pages={br.pages} name={br.name}")
|
||||
return session.call_raw(messages.ButtonAck())
|
||||
|
||||
@staticmethod
|
||||
def get_pin(session: Session, request: messages.PinMatrixRequest) -> t.Any:
|
||||
def get_pin(request: messages.PinMatrixRequest) -> str:
|
||||
code = request.type
|
||||
if code is None:
|
||||
print("?PIN")
|
||||
@ -226,15 +215,7 @@ class ScriptUI:
|
||||
raise RuntimeError("Sent PIN must start with ':'")
|
||||
else:
|
||||
pin = pin[1:]
|
||||
resp = session.call_raw(messages.PinMatrixAck(pin=pin))
|
||||
if isinstance(resp, messages.Failure) and resp.code in (
|
||||
messages.FailureType.PinInvalid,
|
||||
messages.FailureType.PinCancelled,
|
||||
messages.FailureType.PinExpected,
|
||||
):
|
||||
raise PinException(resp.code, resp.message)
|
||||
else:
|
||||
return resp
|
||||
return pin
|
||||
|
||||
@staticmethod
|
||||
def get_passphrase(session: Session, request: messages.PassphraseRequest) -> t.Any:
|
||||
|
Loading…
Reference in New Issue
Block a user