From ad43218b6f2215a83da8b17c88180950f9925da0 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 28 Apr 2025 14:55:16 +0200 Subject: [PATCH] feat(python): add code entry pairing support to trezorlib --- python/src/trezorlib/cli/__init__.py | 26 +++++ python/src/trezorlib/client.py | 109 +++++++++++++++++- python/src/trezorlib/exceptions.py | 4 + python/src/trezorlib/mapping.py | 23 ++-- .../trezorlib/transport/thp/protocol_v2.py | 1 + 5 files changed, 150 insertions(+), 13 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index cc7e031b70..48a314cbc3 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -20,6 +20,7 @@ import atexit import functools import logging import os +import re import sys import typing as t from contextlib import contextmanager @@ -104,6 +105,28 @@ def get_passphrase( raise exceptions.Cancelled from None +def get_code_entry_code() -> int: + while True: + try: + code_input = ui.prompt( + "Enter code from Trezor", + hide_input=False, + default="", + show_default=False, + ) + + # Keep only digits 0-9, ignore all other symbols + code_str = re.sub(r"\D", "", code_input) + + if len(code_str) != 6: + ui.echo("Code must be 6-digits long.") + continue + code = int(code_str) + return code + except click.Abort: + raise exceptions.Cancelled from None + + def get_client(transport: Transport) -> TrezorClient: return TrezorClient(transport) @@ -278,6 +301,9 @@ class TrezorConnection: except transport.DeviceIsBusy: click.echo("Device is in use by another process.") sys.exit(1) + except exceptions.UnexpectedCodeEntryTagException: + click.echo("Entered Code is invalid.") + sys.exit(1) except exceptions.FailedSessionResumption: sys.exit(1) except Exception: diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index d48f64028d..7ab761358e 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -20,16 +20,18 @@ import os import typing as t import warnings from enum import IntEnum +from hashlib import sha256 from . import exceptions, mapping, messages, models from .tools import parse_path from .transport import Transport, get_transport +from .transport.thp.cpace import Cpace from .transport.thp.protocol_and_channel import Channel from .transport.thp.protocol_v1 import ProtocolV1Channel from .transport.thp.protocol_v2 import ProtocolV2Channel, TrezorState if t.TYPE_CHECKING: - from .transport.session import Session, SessionV1 + from .transport.session import Session, SessionV1, SessionV2 LOG = logging.getLogger(__name__) @@ -63,9 +65,6 @@ class TrezorClient: _last_active_session: SessionV1 | None = None _session_id_counter: int = 0 - _default_pairing_method: messages.ThpPairingMethod = ( - messages.ThpPairingMethod.SkipPairing - ) def __init__( self, @@ -112,21 +111,106 @@ class TrezorClient: assert self.protocol_version == ProtocolVersion.V2 if pairing_method is None: - pairing_method = self._default_pairing_method + supported_methods = self.device_properties.pairing_methods + if messages.ThpPairingMethod.SkipPairing in supported_methods: + pairing_method = messages.ThpPairingMethod.SkipPairing + elif messages.ThpPairingMethod.CodeEntry in supported_methods: + pairing_method = messages.ThpPairingMethod.CodeEntry + else: + raise RuntimeError( + "Connected Trezor does not support any trezorlib-compatible pairing method." + ) session = SessionV2(client=self, id=b"\x00") session.call( messages.ThpPairingRequest(host_name="Trezorlib"), expect=messages.ThpPairingRequestApproved, skip_firmware_version_check=True, ) + if pairing_method is messages.ThpPairingMethod.SkipPairing: + return self._handle_skip_pairing(session) + if pairing_method is messages.ThpPairingMethod.CodeEntry: + return self._handle_code_entry(session) + + raise RuntimeError("Unexpected pairing method") + + def _handle_skip_pairing(self, session: SessionV2) -> None: session.call( - messages.ThpSelectMethod(selected_pairing_method=pairing_method), + messages.ThpSelectMethod( + selected_pairing_method=messages.ThpPairingMethod.SkipPairing + ), expect=messages.ThpEndResponse, skip_firmware_version_check=True, ) assert isinstance(self.protocol, ProtocolV2Channel) self.protocol._has_valid_channel = True + def _handle_code_entry(self, session: SessionV2) -> None: + from .cli import get_code_entry_code + + commitment_msg = session.call( + messages.ThpSelectMethod( + selected_pairing_method=messages.ThpPairingMethod.CodeEntry + ), + expect=messages.ThpCodeEntryCommitment, + skip_firmware_version_check=True, + ) + challenge = os.urandom(16) + cpace_trezor_msg = session.call( + messages.ThpCodeEntryChallenge(challenge=challenge), + expect=messages.ThpCodeEntryCpaceTrezor, + skip_firmware_version_check=True, + ) + + code = get_code_entry_code() + assert isinstance(session.client.protocol, ProtocolV2Channel) + cpace = Cpace(handshake_hash=session.client.protocol.handshake_hash) + cpace.random_bytes = os.urandom + assert cpace_trezor_msg.cpace_trezor_public_key is not None + cpace.generate_keys_and_secret( + code.to_bytes(6, "big"), cpace_trezor_msg.cpace_trezor_public_key + ) + sha_ctx = sha256(cpace.shared_secret) + tag = sha_ctx.digest() + + try: + secret_msg = session.call( + messages.ThpCodeEntryCpaceHostTag( + cpace_host_public_key=cpace.host_public_key, + tag=tag, + ), + expect=messages.ThpCodeEntrySecret, + skip_firmware_version_check=True, + ) + except exceptions.TrezorFailure as e: + if e.message == "Unexpected Code Entry Tag": + raise exceptions.UnexpectedCodeEntryTagException + else: + raise e + + # Check `commitment` and `code` + assert secret_msg.secret is not None + sha_ctx = sha256(secret_msg.secret) + computed_commitment = sha_ctx.digest() + + assert commitment_msg.commitment == computed_commitment + + sha_ctx = sha256(messages.ThpPairingMethod.CodeEntry.to_bytes(1, "big")) + sha_ctx.update(session.client.protocol.handshake_hash) + sha_ctx.update(secret_msg.secret) + sha_ctx.update(challenge) + code_hash = sha_ctx.digest() + computed_code = int.from_bytes(code_hash, "big") % 1000000 + assert code == computed_code + + session.call( + messages.ThpEndRequest(), + expect=messages.ThpEndResponse, + skip_firmware_version_check=True, + ) + + assert isinstance(self.protocol, ProtocolV2Channel) + self.protocol._has_valid_channel = True + def get_session( self, passphrase: str | object = "", @@ -223,6 +307,19 @@ class TrezorClient: def is_invalidated(self) -> bool: return self._is_invalidated + @property + def device_properties(self) -> messages.ThpDeviceProperties: + if self.protocol_version == ProtocolVersion.V1: + raise RuntimeError("Device properties are not avaialble with ProtocolV1.") + assert isinstance(self.protocol, ProtocolV2Channel) + if self.protocol.device_properties is None: + raise RuntimeError("Device properties are not avaialble.") + dp = self.mapping.decode_without_wire_type( + messages.ThpDeviceProperties, self.protocol.device_properties + ) + assert isinstance(dp, messages.ThpDeviceProperties) + return dp + def refresh_features(self) -> messages.Features: self.protocol.update_features() self._features = self.protocol.get_features() diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index 5b3eb4c3ae..7a8aeb2df2 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -113,3 +113,7 @@ class DerivationOnUninitaizedDeviceError(TrezorException): class DeviceLockedException(TrezorException): pass + + +class UnexpectedCodeEntryTagException(TrezorException): + pass diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index a92fb062d6..b00c2b8709 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,27 +17,28 @@ from __future__ import annotations import io +import typing as t from types import ModuleType -from typing import Dict, Optional, Tuple, Type, TypeVar from typing_extensions import Self from . import messages, protobuf -T = TypeVar("T") +T = t.TypeVar("T") +MT = t.TypeVar("MT", bound=protobuf.MessageType) class ProtobufMapping: """Mapping of protobuf classes to Python classes""" def __init__(self) -> None: - self.type_to_class: Dict[int, Type[protobuf.MessageType]] = {} - self.class_to_type_override: Dict[Type[protobuf.MessageType], int] = {} + self.type_to_class: t.Dict[int, t.Type[protobuf.MessageType]] = {} + self.class_to_type_override: t.Dict[t.Type[protobuf.MessageType], int] = {} def register( self, - msg_class: Type[protobuf.MessageType], - msg_wire_type: Optional[int] = None, + msg_class: t.Type[protobuf.MessageType], + msg_wire_type: int | None = None, ) -> None: """Register a Python class as a protobuf type. @@ -55,7 +56,7 @@ class ProtobufMapping: self.type_to_class[msg_wire_type] = msg_class - def encode(self, msg: protobuf.MessageType) -> Tuple[int, bytes]: + def encode(self, msg: protobuf.MessageType) -> tuple[int, bytes]: """Serialize a Python protobuf class. Returns the message wire type and a byte representation of the protobuf message. @@ -86,6 +87,14 @@ class ProtobufMapping: buf = io.BytesIO(msg_bytes) return protobuf.load_message(buf, cls) + def decode_without_wire_type( + self, message_type: type[MT], msg_bytes: bytes + ) -> protobuf.MessageType: + """Deserialize a protobuf message into a Python class.""" + cls = message_type + buf = io.BytesIO(msg_bytes) + return protobuf.load_message(buf, cls) + @classmethod def from_module(cls, module: ModuleType) -> Self: """Generate a mapping from a module. diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 1218a1d6bb..f05707b12b 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -36,6 +36,7 @@ class ProtocolV2Channel(Channel): sync_bit_send: int sync_bit_receive: int handshake_hash: bytes + device_properties: bytes _has_valid_channel: bool = False _features: messages.Features | None = None