mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-29 20:28:45 +00:00
feat(python): add code entry pairing support to trezorlib
This commit is contained in:
parent
961a405387
commit
ad43218b6f
@ -20,6 +20,7 @@ import atexit
|
||||
import functools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import typing as t
|
||||
from contextlib import contextmanager
|
||||
@ -104,6 +105,28 @@ def get_passphrase(
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
def get_code_entry_code() -> int:
|
||||
while True:
|
||||
try:
|
||||
code_input = ui.prompt(
|
||||
"Enter code from Trezor",
|
||||
hide_input=False,
|
||||
default="",
|
||||
show_default=False,
|
||||
)
|
||||
|
||||
# Keep only digits 0-9, ignore all other symbols
|
||||
code_str = re.sub(r"\D", "", code_input)
|
||||
|
||||
if len(code_str) != 6:
|
||||
ui.echo("Code must be 6-digits long.")
|
||||
continue
|
||||
code = int(code_str)
|
||||
return code
|
||||
except click.Abort:
|
||||
raise exceptions.Cancelled from None
|
||||
|
||||
|
||||
def get_client(transport: Transport) -> TrezorClient:
|
||||
return TrezorClient(transport)
|
||||
|
||||
@ -278,6 +301,9 @@ class TrezorConnection:
|
||||
except transport.DeviceIsBusy:
|
||||
click.echo("Device is in use by another process.")
|
||||
sys.exit(1)
|
||||
except exceptions.UnexpectedCodeEntryTagException:
|
||||
click.echo("Entered Code is invalid.")
|
||||
sys.exit(1)
|
||||
except exceptions.FailedSessionResumption:
|
||||
sys.exit(1)
|
||||
except Exception:
|
||||
|
@ -20,16 +20,18 @@ import os
|
||||
import typing as t
|
||||
import warnings
|
||||
from enum import IntEnum
|
||||
from hashlib import sha256
|
||||
|
||||
from . import exceptions, mapping, messages, models
|
||||
from .tools import parse_path
|
||||
from .transport import Transport, get_transport
|
||||
from .transport.thp.cpace import Cpace
|
||||
from .transport.thp.protocol_and_channel import Channel
|
||||
from .transport.thp.protocol_v1 import ProtocolV1Channel
|
||||
from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from .transport.session import Session, SessionV1
|
||||
from .transport.session import Session, SessionV1, SessionV2
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -63,9 +65,6 @@ class TrezorClient:
|
||||
_last_active_session: SessionV1 | None = None
|
||||
|
||||
_session_id_counter: int = 0
|
||||
_default_pairing_method: messages.ThpPairingMethod = (
|
||||
messages.ThpPairingMethod.SkipPairing
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -112,21 +111,106 @@ class TrezorClient:
|
||||
|
||||
assert self.protocol_version == ProtocolVersion.V2
|
||||
if pairing_method is None:
|
||||
pairing_method = self._default_pairing_method
|
||||
supported_methods = self.device_properties.pairing_methods
|
||||
if messages.ThpPairingMethod.SkipPairing in supported_methods:
|
||||
pairing_method = messages.ThpPairingMethod.SkipPairing
|
||||
elif messages.ThpPairingMethod.CodeEntry in supported_methods:
|
||||
pairing_method = messages.ThpPairingMethod.CodeEntry
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Connected Trezor does not support any trezorlib-compatible pairing method."
|
||||
)
|
||||
session = SessionV2(client=self, id=b"\x00")
|
||||
session.call(
|
||||
messages.ThpPairingRequest(host_name="Trezorlib"),
|
||||
expect=messages.ThpPairingRequestApproved,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
if pairing_method is messages.ThpPairingMethod.SkipPairing:
|
||||
return self._handle_skip_pairing(session)
|
||||
if pairing_method is messages.ThpPairingMethod.CodeEntry:
|
||||
return self._handle_code_entry(session)
|
||||
|
||||
raise RuntimeError("Unexpected pairing method")
|
||||
|
||||
def _handle_skip_pairing(self, session: SessionV2) -> None:
|
||||
session.call(
|
||||
messages.ThpSelectMethod(selected_pairing_method=pairing_method),
|
||||
messages.ThpSelectMethod(
|
||||
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
||||
),
|
||||
expect=messages.ThpEndResponse,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||
self.protocol._has_valid_channel = True
|
||||
|
||||
def _handle_code_entry(self, session: SessionV2) -> None:
|
||||
from .cli import get_code_entry_code
|
||||
|
||||
commitment_msg = session.call(
|
||||
messages.ThpSelectMethod(
|
||||
selected_pairing_method=messages.ThpPairingMethod.CodeEntry
|
||||
),
|
||||
expect=messages.ThpCodeEntryCommitment,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
challenge = os.urandom(16)
|
||||
cpace_trezor_msg = session.call(
|
||||
messages.ThpCodeEntryChallenge(challenge=challenge),
|
||||
expect=messages.ThpCodeEntryCpaceTrezor,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
|
||||
code = get_code_entry_code()
|
||||
assert isinstance(session.client.protocol, ProtocolV2Channel)
|
||||
cpace = Cpace(handshake_hash=session.client.protocol.handshake_hash)
|
||||
cpace.random_bytes = os.urandom
|
||||
assert cpace_trezor_msg.cpace_trezor_public_key is not None
|
||||
cpace.generate_keys_and_secret(
|
||||
code.to_bytes(6, "big"), cpace_trezor_msg.cpace_trezor_public_key
|
||||
)
|
||||
sha_ctx = sha256(cpace.shared_secret)
|
||||
tag = sha_ctx.digest()
|
||||
|
||||
try:
|
||||
secret_msg = session.call(
|
||||
messages.ThpCodeEntryCpaceHostTag(
|
||||
cpace_host_public_key=cpace.host_public_key,
|
||||
tag=tag,
|
||||
),
|
||||
expect=messages.ThpCodeEntrySecret,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
except exceptions.TrezorFailure as e:
|
||||
if e.message == "Unexpected Code Entry Tag":
|
||||
raise exceptions.UnexpectedCodeEntryTagException
|
||||
else:
|
||||
raise e
|
||||
|
||||
# Check `commitment` and `code`
|
||||
assert secret_msg.secret is not None
|
||||
sha_ctx = sha256(secret_msg.secret)
|
||||
computed_commitment = sha_ctx.digest()
|
||||
|
||||
assert commitment_msg.commitment == computed_commitment
|
||||
|
||||
sha_ctx = sha256(messages.ThpPairingMethod.CodeEntry.to_bytes(1, "big"))
|
||||
sha_ctx.update(session.client.protocol.handshake_hash)
|
||||
sha_ctx.update(secret_msg.secret)
|
||||
sha_ctx.update(challenge)
|
||||
code_hash = sha_ctx.digest()
|
||||
computed_code = int.from_bytes(code_hash, "big") % 1000000
|
||||
assert code == computed_code
|
||||
|
||||
session.call(
|
||||
messages.ThpEndRequest(),
|
||||
expect=messages.ThpEndResponse,
|
||||
skip_firmware_version_check=True,
|
||||
)
|
||||
|
||||
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||
self.protocol._has_valid_channel = True
|
||||
|
||||
def get_session(
|
||||
self,
|
||||
passphrase: str | object = "",
|
||||
@ -223,6 +307,19 @@ class TrezorClient:
|
||||
def is_invalidated(self) -> bool:
|
||||
return self._is_invalidated
|
||||
|
||||
@property
|
||||
def device_properties(self) -> messages.ThpDeviceProperties:
|
||||
if self.protocol_version == ProtocolVersion.V1:
|
||||
raise RuntimeError("Device properties are not avaialble with ProtocolV1.")
|
||||
assert isinstance(self.protocol, ProtocolV2Channel)
|
||||
if self.protocol.device_properties is None:
|
||||
raise RuntimeError("Device properties are not avaialble.")
|
||||
dp = self.mapping.decode_without_wire_type(
|
||||
messages.ThpDeviceProperties, self.protocol.device_properties
|
||||
)
|
||||
assert isinstance(dp, messages.ThpDeviceProperties)
|
||||
return dp
|
||||
|
||||
def refresh_features(self) -> messages.Features:
|
||||
self.protocol.update_features()
|
||||
self._features = self.protocol.get_features()
|
||||
|
@ -113,3 +113,7 @@ class DerivationOnUninitaizedDeviceError(TrezorException):
|
||||
|
||||
class DeviceLockedException(TrezorException):
|
||||
pass
|
||||
|
||||
|
||||
class UnexpectedCodeEntryTagException(TrezorException):
|
||||
pass
|
||||
|
@ -17,27 +17,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import typing as t
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from . import messages, protobuf
|
||||
|
||||
T = TypeVar("T")
|
||||
T = t.TypeVar("T")
|
||||
MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class ProtobufMapping:
|
||||
"""Mapping of protobuf classes to Python classes"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
|
||||
self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {}
|
||||
self.type_to_class: t.Dict[int, t.Type[protobuf.MessageType]] = {}
|
||||
self.class_to_type_override: t.Dict[t.Type[protobuf.MessageType], int] = {}
|
||||
|
||||
def register(
|
||||
self,
|
||||
msg_class: Type[protobuf.MessageType],
|
||||
msg_wire_type: Optional[int] = None,
|
||||
msg_class: t.Type[protobuf.MessageType],
|
||||
msg_wire_type: int | None = None,
|
||||
) -> None:
|
||||
"""Register a Python class as a protobuf type.
|
||||
|
||||
@ -55,7 +56,7 @@ class ProtobufMapping:
|
||||
|
||||
self.type_to_class[msg_wire_type] = msg_class
|
||||
|
||||
def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]:
|
||||
def encode(self, msg: protobuf.MessageType) -> tuple[int, bytes]:
|
||||
"""Serialize a Python protobuf class.
|
||||
|
||||
Returns the message wire type and a byte representation of the protobuf message.
|
||||
@ -86,6 +87,14 @@ class ProtobufMapping:
|
||||
buf = io.BytesIO(msg_bytes)
|
||||
return protobuf.load_message(buf, cls)
|
||||
|
||||
def decode_without_wire_type(
|
||||
self, message_type: type[MT], msg_bytes: bytes
|
||||
) -> protobuf.MessageType:
|
||||
"""Deserialize a protobuf message into a Python class."""
|
||||
cls = message_type
|
||||
buf = io.BytesIO(msg_bytes)
|
||||
return protobuf.load_message(buf, cls)
|
||||
|
||||
@classmethod
|
||||
def from_module(cls, module: ModuleType) -> Self:
|
||||
"""Generate a mapping from a module.
|
||||
|
@ -36,6 +36,7 @@ class ProtocolV2Channel(Channel):
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
handshake_hash: bytes
|
||||
device_properties: bytes
|
||||
|
||||
_has_valid_channel: bool = False
|
||||
_features: messages.Features | None = None
|
||||
|
Loading…
Reference in New Issue
Block a user