mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-05 08:29:13 +00:00
284 lines
9.4 KiB
Python
284 lines
9.4 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import typing as t
|
|
|
|
from .. import exceptions, messages, models
|
|
from ..client import MAX_PIN_LENGTH
|
|
from ..protobuf import MessageType
|
|
from .thp.protocol_v1 import ProtocolV1Channel
|
|
from .thp.protocol_v2 import ProtocolV2Channel
|
|
|
|
if t.TYPE_CHECKING:
|
|
from ..client import TrezorClient
|
|
|
|
# Module-level logger, named after this module's import path.
LOG = logging.getLogger(__name__)
|
|
|
|
# Type variable bound to MessageType; Session.call uses it to tie the type
# passed as `expect` to the type of the returned message.
MT = t.TypeVar("MT", bound=MessageType)
|
|
|
|
|
|
class Session:
    """Base class for a request/response exchange with a device.

    The transport-specific parts, ``_write`` and ``_read``, raise
    ``NotImplementedError`` here and are expected to be supplied by
    subclasses. Everything else (PIN / passphrase / button handling, ping,
    feature accessors) is built on top of those two primitives and on the
    owning ``client``.
    """

    def __init__(self, client: TrezorClient, id: bytes) -> None:
        """Store the owning client and the raw session identifier."""
        self._id = id
        self.client = client

    def call(
        self,
        msg: MessageType,
        expect: type[MT] = MessageType,
        skip_firmware_version_check: bool = False,
        _passphrase_ack: messages.PassphraseAck | None = None,
    ) -> MT:
        """Send ``msg`` and drive the exchange until a final reply arrives.

        Intermediate PIN, passphrase and button requests are answered in a
        loop; the first reply that is none of those is either returned (when
        it is an instance of ``expect``) or turned into an exception.

        Args:
            msg: the message to send.
            expect: message type the final reply must be an instance of.
            skip_firmware_version_check: when true, do not call
                ``client.check_firmware_version()`` before sending.
            _passphrase_ack: reply to use for a ``PassphraseRequest``.

        Raises:
            RuntimeError: a ``PinMatrixRequest`` arrived but the client has
                no ``pin_callback``.
            exceptions.InvalidSessionError: a ``PassphraseRequest`` arrived
                and no ``_passphrase_ack`` was supplied.
            exceptions.Cancelled: the device replied with a ``Failure`` whose
                code is ``ActionCancelled``.
            exceptions.TrezorFailure: any other ``Failure`` reply.
            exceptions.UnexpectedMessageError: the final reply is not an
                instance of ``expect``.
        """
        if not skip_firmware_version_check:
            self.client.check_firmware_version()

        reply = self.call_raw(msg)
        while True:
            if isinstance(reply, messages.PinMatrixRequest):
                if self.client.pin_callback is None:
                    raise RuntimeError("Missing pin_callback")
                reply = self._callback_pin(reply)
                continue

            if isinstance(reply, messages.PassphraseRequest):
                if _passphrase_ack is None:
                    # we got a PassphraseRequest when not explicitly trying to
                    # unlock the session, this means that the session has
                    # expired
                    raise exceptions.InvalidSessionError
                reply = self.call_raw(_passphrase_ack)
                continue

            if isinstance(reply, messages.ButtonRequest):
                reply = self._callback_button(reply)
                continue

            if isinstance(reply, messages.Failure):
                if reply.code == messages.FailureType.ActionCancelled:
                    raise exceptions.Cancelled
                raise exceptions.TrezorFailure(reply)

            if isinstance(reply, expect):
                return reply
            raise exceptions.UnexpectedMessageError(expect, reply)

    def call_raw(self, msg: t.Any) -> t.Any:
        """Write ``msg`` and return the next message read, with no handling."""
        self._write(msg)
        reply = self._read()
        return reply

    def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType:
        """Obtain a PIN from the client's callback and submit it.

        A ``Cancel`` message is sent to the device before re-raising when the
        callback raises ``exceptions.Cancelled``, and before raising
        ``ValueError`` when the returned PIN is malformed.

        Raises:
            RuntimeError: the client has no ``pin_callback``.
            ValueError: the PIN contains characters other than 1-9, or its
                length is outside ``1..MAX_PIN_LENGTH``.
            exceptions.PinException: the device answered with a PIN-related
                ``Failure`` (PinInvalid, PinCancelled or PinExpected).
        """
        ask_for_pin = self.client.pin_callback
        if ask_for_pin is None:
            raise RuntimeError("No PIN provided")

        try:
            pin = ask_for_pin(msg)
        except exceptions.Cancelled:
            self.call_raw(messages.Cancel())
            raise

        pin_is_valid = (
            all(d in "123456789" for d in pin)
            and 1 <= len(pin) <= MAX_PIN_LENGTH
        )
        if not pin_is_valid:
            self.call_raw(messages.Cancel())
            raise ValueError("Invalid PIN provided")

        reply = self.call_raw(messages.PinMatrixAck(pin=pin))
        if not isinstance(reply, messages.Failure):
            return reply
        if reply.code in (
            messages.FailureType.PinInvalid,
            messages.FailureType.PinCancelled,
            messages.FailureType.PinExpected,
        ):
            raise exceptions.PinException(reply.code, reply.message)
        return reply

    def _callback_button(self, msg: messages.ButtonRequest) -> MessageType:
        """Acknowledge a button request, notify the UI, return the next reply."""
        __tracebackhide__ = True  # for pytest # pylint: disable=W0612
        # do this raw - send ButtonAck first, notify UI later
        self._write(messages.ButtonAck())
        notify_ui = self.client.button_callback
        if notify_ui:
            notify_ui(msg)
        return self._read()

    def _write(self, msg: t.Any) -> None:
        """Send one message; not implemented in the base class."""
        raise NotImplementedError

    def _read(self) -> t.Any:
        """Receive one message; not implemented in the base class."""
        raise NotImplementedError

    def refresh_features(self) -> messages.Features:
        """Delegate to ``client.refresh_features()`` and return its result."""
        return self.client.refresh_features()

    def resume(self) -> None:
        """Do nothing in the base class; subclasses may override."""
        pass

    def end(self) -> t.Any:
        """Send ``EndSession`` through ``call`` and return the reply."""
        return self.call(messages.EndSession())

    def cancel(self) -> None:
        """Write a ``Cancel`` message without reading a reply."""
        self._write(messages.Cancel())

    def ping(self, message: str, button_protection: bool | None = None) -> str:
        """Send a ``Ping`` and return the message carried by the ``Success``.

        Args:
            message: text to send in the ping.
            button_protection: when truthy, the ping goes through ``call``
                with this flag set; otherwise the raw path below is used.
        """
        if button_protection:
            resp = self.call(
                messages.Ping(message=message, button_protection=button_protection),
                expect=messages.Success,
            )
        else:
            # We would like ping to work on any valid TrezorClient instance,
            # but due to the protection modes, we need to go through
            # self.call, and that will raise an exception if the firmware is
            # too old.
            # So we short-circuit the simplest variant of ping with call_raw.
            resp = self.call_raw(messages.Ping(message=message))
            if isinstance(resp, messages.ButtonRequest):
                # device is PIN-locked.
                # respond and hope for the best
                resp = self._callback_button(resp)
            resp = messages.Success.ensure_isinstance(resp)

        assert resp.message is not None
        return resp.message

    def invalidate(self) -> None:
        """Delegate to ``client.invalidate()``."""
        self.client.invalidate()

    @property
    def features(self) -> messages.Features:
        """The owning client's ``features``."""
        return self.client.features

    @property
    def model(self) -> models.TrezorModel:
        """The owning client's ``model``."""
        return self.client.model

    @property
    def version(self) -> t.Tuple[int, int, int]:
        """The owning client's ``version``."""
        return self.client.version

    @property
    def id(self) -> bytes:
        """The raw session identifier."""
        return self._id

    @id.setter
    def id(self, value: bytes) -> None:
        """Replace the session identifier; only ``bytes`` are accepted."""
        if isinstance(value, bytes):
            self._id = value
            return
        raise ValueError("id must be of type bytes")
|
|
|
|
|
|
class SessionV1(Session):
    """Session that talks through the client's ``ProtocolV1Channel``.

    Messages are written and read via ``client._write`` / ``client._read``.
    Before every write the session makes itself the client's
    ``_last_active_session``; if a different session was active, it first
    calls ``resume()``, which sends ``Initialize`` again.
    """

    derive_cardano: bool | None = False
    # Set to True on the instance once ``init_session`` has completed.
    _was_initialized_at_least_once = False

    @classmethod
    def new(
        cls,
        client: TrezorClient,
        derive_cardano: bool = False,
        session_id: bytes | None = None,
    ) -> SessionV1:
        """Create a session and initialize it on the device.

        Args:
            client: client whose protocol must be a ``ProtocolV1Channel``.
            derive_cardano: value passed on in the ``Initialize`` message.
            session_id: id to resume; ``None`` (or empty) starts a new one.

        Returns:
            The initialized session.
        """
        assert isinstance(client.protocol, ProtocolV1Channel)
        # Instantiate through ``cls`` so subclasses get instances of their
        # own type, matching ``SessionV2.new``.
        session = cls(client, id=session_id or b"")
        session.derive_cardano = derive_cardano
        session.init_session(derive_cardano=session.derive_cardano)
        return session

    @classmethod
    def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
        """Create a session object for ``session_id`` and initialize it.

        ``init_session`` is called without ``derive_cardano``, so ``None`` is
        sent in the ``Initialize`` message.
        """
        assert isinstance(client.protocol, ProtocolV1Channel)
        session = cls(client, session_id)
        session.init_session()
        return session

    def resume(self) -> None:
        """Re-send ``Initialize`` using the stored ``derive_cardano`` value."""
        self.init_session(derive_cardano=self.derive_cardano)

    def _write(self, msg: t.Any) -> None:
        """Make this session the active one, then write via the client."""
        self._activate_self()
        if t.TYPE_CHECKING:
            assert isinstance(self.client.protocol, ProtocolV1Channel)
        self.client._write(msg)

    def _activate_self(self) -> None:
        """Become the client's active session, resuming if it was another."""
        if self.client._last_active_session is not self:
            self.client._last_active_session = self
            self.resume()

    def _read(self) -> t.Any:
        """Read via the client; this session must be the active one."""
        assert self.client._last_active_session is self
        if t.TYPE_CHECKING:
            assert isinstance(self.client.protocol, ProtocolV1Channel)
        return self.client._read()

    def init_session(self, derive_cardano: bool | None = None) -> None:
        """Send ``Initialize`` and reconcile the session id with the reply.

        An empty ``self.id`` means a new session: ``Initialize`` is sent with
        ``session_id=None`` and the id from the reply is adopted. Otherwise
        the stored id is sent and the reply must carry the same id.

        Args:
            derive_cardano: value passed on in the ``Initialize`` message.

        Raises:
            exceptions.FailedSessionResumption: an existing id was sent but
                the device answered with a different one.
        """
        new_session = self.id == b""
        session_id = None if new_session else self.id

        # Mark ourselves active before sending: ``_write`` goes through
        # ``_activate_self``, which would otherwise call ``resume()`` and
        # re-enter this method.
        self.client._last_active_session = self
        resp: messages.Features = self.call_raw(
            messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
        )
        assert isinstance(resp, messages.Features)

        msg_id = resp.session_id or b""
        if new_session:
            self.id = msg_id
        elif self.id != msg_id:
            raise exceptions.FailedSessionResumption(resp.session_id)

        # Update the flag the class actually declares. Previously only the
        # un-prefixed name was written, leaving the declared flag permanently
        # False; the un-prefixed attribute is still set for any code that
        # reads it.
        self._was_initialized_at_least_once = True
        self.was_initialized_at_least_once = True
|
|
|
|
|
|
def derive_seed(session: Session, passphrase: str | object) -> None:
    """Answer the device's passphrase prompt by issuing a test request.

    A ``GetAddress`` for ``PASSPHRASE_TEST_PATH`` on "Testnet" is sent through
    ``session.call`` with a prepared ``PassphraseAck``, after which the
    session's features are refreshed.

    Args:
        session: session to send the request on.
        passphrase: either the ``PASSPHRASE_ON_DEVICE`` sentinel (entry on
            the device) or the passphrase as a string.

    Raises:
        ValueError: ``passphrase`` is neither the sentinel nor a ``str``.
    """
    # Imported here rather than at module level, as in the original code
    # (presumably to avoid a circular import with ..client).
    from ..client import PASSPHRASE_ON_DEVICE, PASSPHRASE_TEST_PATH

    enter_on_device = passphrase is PASSPHRASE_ON_DEVICE
    if not enter_on_device and not isinstance(passphrase, str):
        raise ValueError("Invalid passphrase")

    if enter_on_device:
        ack = messages.PassphraseAck(on_device=True)
    else:
        ack = messages.PassphraseAck(passphrase=passphrase)

    probe = messages.GetAddress(address_n=PASSPHRASE_TEST_PATH, coin_name="Testnet")
    session.call(probe, expect=messages.Address, _passphrase_ack=ack)
    session.refresh_features()
|
|
|
|
|
|
class SessionV2(Session):
    """Session that talks through the client's ``ProtocolV2Channel``.

    Each write and read is addressed to the channel with the integer ``sid``
    derived from the session id bytes.
    """

    @classmethod
    def new(
        cls,
        client: TrezorClient,
        passphrase: str | None,
        derive_cardano: bool,
        session_id: int = 0,
    ) -> SessionV2:
        """Create a session object and ask the device to create the session.

        Args:
            client: client whose protocol must be a ``ProtocolV2Channel``.
            passphrase: passphrase sent in ``ThpCreateNewSession``.
            derive_cardano: flag sent in ``ThpCreateNewSession``.
            session_id: numeric id, encoded as a single big-endian byte.

        Returns:
            The session, after the device answered with ``Success``.
        """
        assert isinstance(client.protocol, ProtocolV2Channel)
        id_bytes = session_id.to_bytes(1, "big")
        session = cls(client, id_bytes)
        create_request = messages.ThpCreateNewSession(
            passphrase=passphrase, derive_cardano=derive_cardano
        )
        session.call(create_request, expect=messages.Success)
        session.update_id_and_sid(id_bytes)
        return session

    def __init__(self, client: TrezorClient, id: bytes) -> None:
        """Bind to the client's v2 channel and derive ``sid`` from ``id``."""
        super().__init__(client, id)
        protocol = client.protocol
        assert isinstance(protocol, ProtocolV2Channel)

        self.channel: ProtocolV2Channel = protocol
        self.update_id_and_sid(id)

    def _write(self, msg: t.Any) -> None:
        """Log the message type, then write it to the channel under ``sid``."""
        LOG.debug("writing message %s", type(msg))
        self.channel.write(self.sid, msg)

    def _read(self) -> t.Any:
        """Read the next message for ``sid`` and log its type."""
        received = self.channel.read(self.sid)
        LOG.debug("reading message %s", type(received))
        return received

    def update_id_and_sid(self, id: bytes) -> None:
        """Store ``id`` and set ``sid`` to its big-endian integer value."""
        self._id = id
        # TODO update to extract only sid
        self.sid = int.from_bytes(id, byteorder="big")
|