1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-29 20:28:45 +00:00

feat(python): add code entry pairing support to trezorlib

This commit is contained in:
M1nd3r 2025-04-28 14:55:16 +02:00
parent 961a405387
commit ad43218b6f
5 changed files with 150 additions and 13 deletions

View File

@ -20,6 +20,7 @@ import atexit
import functools
import logging
import os
import re
import sys
import typing as t
from contextlib import contextmanager
@ -104,6 +105,28 @@ def get_passphrase(
raise exceptions.Cancelled from None
def get_code_entry_code() -> int:
while True:
try:
code_input = ui.prompt(
"Enter code from Trezor",
hide_input=False,
default="",
show_default=False,
)
# Keep only digits 0-9, ignore all other symbols
code_str = re.sub(r"\D", "", code_input)
if len(code_str) != 6:
ui.echo("Code must be 6-digits long.")
continue
code = int(code_str)
return code
except click.Abort:
raise exceptions.Cancelled from None
def get_client(transport: Transport) -> TrezorClient:
return TrezorClient(transport)
@ -278,6 +301,9 @@ class TrezorConnection:
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except exceptions.UnexpectedCodeEntryTagException:
click.echo("Entered Code is invalid.")
sys.exit(1)
except exceptions.FailedSessionResumption:
sys.exit(1)
except Exception:

View File

@ -20,16 +20,18 @@ import os
import typing as t
import warnings
from enum import IntEnum
from hashlib import sha256
from . import exceptions, mapping, messages, models
from .tools import parse_path
from .transport import Transport, get_transport
from .transport.thp.cpace import Cpace
from .transport.thp.protocol_and_channel import Channel
from .transport.thp.protocol_v1 import ProtocolV1Channel
from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState
if t.TYPE_CHECKING:
from .transport.session import Session, SessionV1
from .transport.session import Session, SessionV1, SessionV2
LOG = logging.getLogger(__name__)
@ -63,9 +65,6 @@ class TrezorClient:
_last_active_session: SessionV1 | None = None
_session_id_counter: int = 0
_default_pairing_method: messages.ThpPairingMethod = (
messages.ThpPairingMethod.SkipPairing
)
def __init__(
self,
@ -112,21 +111,106 @@ class TrezorClient:
assert self.protocol_version == ProtocolVersion.V2
if pairing_method is None:
pairing_method = self._default_pairing_method
supported_methods = self.device_properties.pairing_methods
if messages.ThpPairingMethod.SkipPairing in supported_methods:
pairing_method = messages.ThpPairingMethod.SkipPairing
elif messages.ThpPairingMethod.CodeEntry in supported_methods:
pairing_method = messages.ThpPairingMethod.CodeEntry
else:
raise RuntimeError(
"Connected Trezor does not support any trezorlib-compatible pairing method."
)
session = SessionV2(client=self, id=b"\x00")
session.call(
messages.ThpPairingRequest(host_name="Trezorlib"),
expect=messages.ThpPairingRequestApproved,
skip_firmware_version_check=True,
)
if pairing_method is messages.ThpPairingMethod.SkipPairing:
return self._handle_skip_pairing(session)
if pairing_method is messages.ThpPairingMethod.CodeEntry:
return self._handle_code_entry(session)
raise RuntimeError("Unexpected pairing method")
def _handle_skip_pairing(self, session: SessionV2) -> None:
session.call(
messages.ThpSelectMethod(selected_pairing_method=pairing_method),
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
),
expect=messages.ThpEndResponse,
skip_firmware_version_check=True,
)
assert isinstance(self.protocol, ProtocolV2Channel)
self.protocol._has_valid_channel = True
def _handle_code_entry(self, session: SessionV2) -> None:
from .cli import get_code_entry_code
commitment_msg = session.call(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.CodeEntry
),
expect=messages.ThpCodeEntryCommitment,
skip_firmware_version_check=True,
)
challenge = os.urandom(16)
cpace_trezor_msg = session.call(
messages.ThpCodeEntryChallenge(challenge=challenge),
expect=messages.ThpCodeEntryCpaceTrezor,
skip_firmware_version_check=True,
)
code = get_code_entry_code()
assert isinstance(session.client.protocol, ProtocolV2Channel)
cpace = Cpace(handshake_hash=session.client.protocol.handshake_hash)
cpace.random_bytes = os.urandom
assert cpace_trezor_msg.cpace_trezor_public_key is not None
cpace.generate_keys_and_secret(
code.to_bytes(6, "big"), cpace_trezor_msg.cpace_trezor_public_key
)
sha_ctx = sha256(cpace.shared_secret)
tag = sha_ctx.digest()
try:
secret_msg = session.call(
messages.ThpCodeEntryCpaceHostTag(
cpace_host_public_key=cpace.host_public_key,
tag=tag,
),
expect=messages.ThpCodeEntrySecret,
skip_firmware_version_check=True,
)
except exceptions.TrezorFailure as e:
if e.message == "Unexpected Code Entry Tag":
raise exceptions.UnexpectedCodeEntryTagException
else:
raise e
# Check `commitment` and `code`
assert secret_msg.secret is not None
sha_ctx = sha256(secret_msg.secret)
computed_commitment = sha_ctx.digest()
assert commitment_msg.commitment == computed_commitment
sha_ctx = sha256(messages.ThpPairingMethod.CodeEntry.to_bytes(1, "big"))
sha_ctx.update(session.client.protocol.handshake_hash)
sha_ctx.update(secret_msg.secret)
sha_ctx.update(challenge)
code_hash = sha_ctx.digest()
computed_code = int.from_bytes(code_hash, "big") % 1000000
assert code == computed_code
session.call(
messages.ThpEndRequest(),
expect=messages.ThpEndResponse,
skip_firmware_version_check=True,
)
assert isinstance(self.protocol, ProtocolV2Channel)
self.protocol._has_valid_channel = True
def get_session(
self,
passphrase: str | object = "",
@ -223,6 +307,19 @@ class TrezorClient:
def is_invalidated(self) -> bool:
return self._is_invalidated
@property
def device_properties(self) -> messages.ThpDeviceProperties:
if self.protocol_version == ProtocolVersion.V1:
raise RuntimeError("Device properties are not avaialble with ProtocolV1.")
assert isinstance(self.protocol, ProtocolV2Channel)
if self.protocol.device_properties is None:
raise RuntimeError("Device properties are not avaialble.")
dp = self.mapping.decode_without_wire_type(
messages.ThpDeviceProperties, self.protocol.device_properties
)
assert isinstance(dp, messages.ThpDeviceProperties)
return dp
def refresh_features(self) -> messages.Features:
self.protocol.update_features()
self._features = self.protocol.get_features()

View File

@ -113,3 +113,7 @@ class DerivationOnUninitaizedDeviceError(TrezorException):
class DeviceLockedException(TrezorException):
pass
class UnexpectedCodeEntryTagException(TrezorException):
pass

View File

@ -17,27 +17,28 @@
from __future__ import annotations
import io
import typing as t
from types import ModuleType
from typing import Dict, Optional, Tuple, Type, TypeVar
from typing_extensions import Self
from . import messages, protobuf
T = TypeVar("T")
T = t.TypeVar("T")
MT = t.TypeVar("MT", bound=protobuf.MessageType)
class ProtobufMapping:
"""Mapping of protobuf classes to Python classes"""
def __init__(self) -> None:
self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {}
self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {}
self.type_to_class: t.Dict[int, t.Type[protobuf.MessageType]] = {}
self.class_to_type_override: t.Dict[t.Type[protobuf.MessageType], int] = {}
def register(
self,
msg_class: Type[protobuf.MessageType],
msg_wire_type: Optional[int] = None,
msg_class: t.Type[protobuf.MessageType],
msg_wire_type: int | None = None,
) -> None:
"""Register a Python class as a protobuf type.
@ -55,7 +56,7 @@ class ProtobufMapping:
self.type_to_class[msg_wire_type] = msg_class
def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]:
def encode(self, msg: protobuf.MessageType) -> tuple[int, bytes]:
"""Serialize a Python protobuf class.
Returns the message wire type and a byte representation of the protobuf message.
@ -86,6 +87,14 @@ class ProtobufMapping:
buf = io.BytesIO(msg_bytes)
return protobuf.load_message(buf, cls)
def decode_without_wire_type(
self, message_type: type[MT], msg_bytes: bytes
) -> protobuf.MessageType:
"""Deserialize a protobuf message into a Python class."""
cls = message_type
buf = io.BytesIO(msg_bytes)
return protobuf.load_message(buf, cls)
@classmethod
def from_module(cls, module: ModuleType) -> Self:
"""Generate a mapping from a module.

View File

@ -36,6 +36,7 @@ class ProtocolV2Channel(Channel):
sync_bit_send: int
sync_bit_receive: int
handshake_hash: bytes
device_properties: bytes
_has_valid_channel: bool = False
_features: messages.Features | None = None