From 61b2156a1e599fe027fe029e8d5347e0bede7116 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 4 Feb 2025 15:21:19 +0100 Subject: [PATCH] chore(core): adapt trezorlib transports to session based [no changelog] --- python/src/trezorlib/transport/__init__.py | 106 ++-- python/src/trezorlib/transport/bridge.py | 106 +++- python/src/trezorlib/transport/hid.py | 116 ++--- python/src/trezorlib/transport/protocol.py | 165 ------ python/src/trezorlib/transport/session.py | 223 ++++++++ .../transport/thp/alternating_bit_protocol.py | 102 ++++ .../trezorlib/transport/thp/channel_data.py | 47 ++ .../transport/thp/channel_database.py | 148 ++++++ .../src/trezorlib/transport/thp/checksum.py | 19 + .../trezorlib/transport/thp/control_byte.py | 63 +++ python/src/trezorlib/transport/thp/cpace.py | 40 ++ .../src/trezorlib/transport/thp/curve25519.py | 159 ++++++ .../trezorlib/transport/thp/message_header.py | 82 +++ .../transport/thp/protocol_and_channel.py | 32 ++ .../trezorlib/transport/thp/protocol_v1.py | 97 ++++ .../trezorlib/transport/thp/protocol_v2.py | 490 ++++++++++++++++++ python/src/trezorlib/transport/thp/thp_io.py | 93 ++++ python/src/trezorlib/transport/udp.py | 93 ++-- python/src/trezorlib/transport/webusb.py | 141 +++-- 19 files changed, 1896 insertions(+), 426 deletions(-) delete mode 100644 python/src/trezorlib/transport/protocol.py create mode 100644 python/src/trezorlib/transport/session.py create mode 100644 python/src/trezorlib/transport/thp/alternating_bit_protocol.py create mode 100644 python/src/trezorlib/transport/thp/channel_data.py create mode 100644 python/src/trezorlib/transport/thp/channel_database.py create mode 100644 python/src/trezorlib/transport/thp/checksum.py create mode 100644 python/src/trezorlib/transport/thp/control_byte.py create mode 100644 python/src/trezorlib/transport/thp/cpace.py create mode 100644 python/src/trezorlib/transport/thp/curve25519.py create mode 100644 python/src/trezorlib/transport/thp/message_header.py create mode 100644 python/src/trezorlib/transport/thp/protocol_and_channel.py create mode 100644 python/src/trezorlib/transport/thp/protocol_v1.py create mode 100644 python/src/trezorlib/transport/thp/protocol_v2.py create mode 100644 python/src/trezorlib/transport/thp/thp_io.py diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b7..2c208be36d 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -14,24 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException): class Transport: - """Raw connection to a Trezor device. - - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. - - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages - - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ - PATH_PREFIX: str - ENABLED = False - def __str__(self) -> str: - return self.get_path() + @classmethod + def enumerate( + cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["T"]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): + + if device.get_path() == path: + return device + + if prefix_search and device.get_path().startswith(path): + return device + + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: - raise NotImplementedError - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - def find_debug(self: "T") -> "T": raise NotImplementedError - @classmethod - def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + def open(self) -> None: raise NotImplementedError - @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device + def close(self) -> None: + raise NotImplementedError - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + def write_chunk(self, chunk: bytes) -> None: + raise NotImplementedError + + def read_chunk(self) -> bytes: + raise NotImplementedError + + CHUNK_SIZE: t.ClassVar[int] -def all_transports() -> Iterable[Type["Transport"]]: +def all_transports() -> t.Iterable[t.Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( + transports: t.Tuple[t.Type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: t.Iterable["TrezorModel"] | None = None, +) -> t.Sequence["Transport"]: + devices: t.List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -145,9 +121,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f70..7f136608b0 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -14,24 +14,30 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +import typing as t import requests from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) +PROTOCOL_VERSION_1 = 1 +PROTOCOL_VERSION_2 = 2 + TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_VERSION_MODERN = (2, 0, 25) +TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) @@ -45,7 +51,7 @@ class BridgeException(TransportException): super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge(path: str, data: str | None = None) -> requests.Response: url = TREZORD_HOST + "/" + path r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN + + +def supports_protocolV2() -> bool: + return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT + + +def detect_protocol_version(transport: "BridgeTransport") -> int: + from .. import mapping, messages + from ..messages import FailureType + + protocol_version = PROTOCOL_VERSION_1 + request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) + transport.deprecated_begin_session() + transport.deprecated_write(request_type, request_data) + + response_type, response_data = transport.deprecated_read() + response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) + transport.deprecated_begin_session() + if isinstance(response, messages.Failure): + if response.code == FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol_version = PROTOCOL_VERSION_2 + + return protocol_version + + +def _is_transport_valid(transport: "BridgeTransport") -> bool: + is_valid = ( + supports_protocolV2() + or detect_protocol_version(transport) == PROTOCOL_VERSION_1 + ) + if not is_valid: + LOG.warning("Detected unsupported Bridge transport!") + return is_valid + + +def filter_invalid_bridge_transports( + transports: t.Iterable["BridgeTransport"], +) -> t.Sequence["BridgeTransport"]: + """Filters out invalid bridge transports. Keeps only valid ones.""" + return [t for t in transports if _is_transport_valid(t)] class BridgeHandle: @@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle): class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: @@ -112,13 +162,12 @@ class BridgeTransport(Transport): ENABLED: bool = True def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: Optional[str] = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -135,7 +184,7 @@ class BridgeTransport(Transport): raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call(self, action: str, data: str | None = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: @@ -144,17 +193,20 @@ class BridgeTransport(Transport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() - return [ - BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() - ] + return filter_invalid_bridge_transports( + [ + BridgeTransport(dev, legacy) + for dev in call_bridge("enumerate").json() + ] + ) except Exception: return [] - def begin_session(self) -> None: + def deprecated_begin_session(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -163,18 +215,32 @@ class BridgeTransport(Transport): raise self.session = data.json()["session"] - def end_session(self) -> None: + def deprecated_end_session(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: + def deprecated_write(self, message_type: int, message_data: bytes) -> None: header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(header + message_data) - def read(self) -> MessagePayload: + def deprecated_read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) return msg_type, data[headerlen : headerlen + datalen] + + def open(self) -> None: + pass + # TODO self.handle.open() + + def close(self) -> None: + pass + # TODO self.handle.close() + + def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :) + self.handle.write_buf(chunk) + + def read_chunk(self) -> bytes: # TODO check if it works :) + return self.handle.read_buf() diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd7..65e2cddf7d 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -14,15 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Transport, TransportException LOG = logging.getLogger(__name__) @@ -35,23 +36,61 @@ except Exception as e: HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED + + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -62,11 +101,11 @@ class HidHandle: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -77,7 +116,7 @@ class HidHandle: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -115,53 +154,6 @@ class HidHandle: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: List["HidTransport"] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> "HidTransport": - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index a5a0ee6be4..0000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,165 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -import logging -import struct -from typing import Tuple - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle) -> None: - self.handle = handle - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self) -> MessagePayload: - return self.protocol.read() - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError("Unexpected magic characters") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError("Cannot parse header") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError("Unexpected magic characters") - return chunk[1:] diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py new file mode 100644 index 0000000000..90dab10bfa --- /dev/null +++ b/python/src/trezorlib/transport/session.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging +import typing as t + +from .. import exceptions, messages, models +from ..protobuf import MessageType +from .thp.protocol_v1 import ProtocolV1 +from .thp.protocol_v2 import ProtocolV2 + +if t.TYPE_CHECKING: + from ..client import TrezorClient + +LOG = logging.getLogger(__name__) + +MT = t.TypeVar("MT", bound=MessageType) + + +class Session: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + def __init__( + self, client: TrezorClient, id: bytes, passphrase: str | object | None = None + ) -> None: + self.client = client + self._id = id + self.passphrase = passphrase + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool + ) -> Session: + raise NotImplementedError + + def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: + # TODO self.check_firmware_version() + resp = self.call_raw(msg) + + while True: + if isinstance(resp, messages.PinMatrixRequest): + if self.pin_callback is None: + raise Exception # TODO + resp = self.pin_callback(self, resp) + elif isinstance(resp, messages.PassphraseRequest): + if self.passphrase_callback is None: + raise Exception # TODO + resp = self.passphrase_callback(self, resp) + elif isinstance(resp, messages.ButtonRequest): + if self.button_callback is None: + raise Exception # TODO + resp = self.button_callback(self, resp) + elif isinstance(resp, messages.Failure): + if resp.code == messages.FailureType.ActionCancelled: + raise exceptions.Cancelled + raise exceptions.TrezorFailure(resp) + elif not isinstance(resp, expect): + raise exceptions.UnexpectedMessageError(expect, resp) + else: + return resp + + def call_raw(self, msg: t.Any) -> t.Any: + self._write(msg) + return self._read() + + def _write(self, msg: t.Any) -> None: + raise NotImplementedError + + def _read(self) -> t.Any: + raise NotImplementedError + + def refresh_features(self) -> None: + self.client.refresh_features() + + def end(self) -> t.Any: + return self.call(messages.EndSession()) + + def ping(self, message: str, button_protection: bool | None = None) -> str: + resp = self.call( + messages.Ping(message=message, button_protection=button_protection), + expect=messages.Success, + ) + assert resp.message is not None + return resp.message + + def invalidate(self) -> None: + self.client.invalidate() + + @property + def features(self) -> messages.Features: + return self.client.features + + @property + def model(self) -> models.TrezorModel: + return self.client.model + + @property + def version(self) -> t.Tuple[int, int, int]: + return self.client.version + + @property + def id(self) -> bytes: + return self._id + + @id.setter + def id(self, value: bytes) -> None: + if not isinstance(value, bytes): + raise ValueError("id must be of type bytes") + self._id = value + + +class SessionV1(Session): + derive_cardano: bool | None = False + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | object = "", + derive_cardano: bool = False, + session_id: bytes | None = None, + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, id=session_id or b"") + + session._init_callbacks() + session.passphrase = passphrase + session.derive_cardano = derive_cardano + session.init_session(session.derive_cardano) + return session + + @classmethod + def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, session_id) + session.init_session() + return session + + def _init_callbacks(self) -> None: + self.button_callback = self.client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self.client.passphrase_callback + + def _write(self, msg: t.Any) -> None: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + self.client.protocol.write(msg) + + def _read(self) -> t.Any: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + return self.client.protocol.read() + + def init_session(self, derive_cardano: bool | None = None): + if self.id == b"": + session_id = None + else: + session_id = self.id + resp: messages.Features = self.call_raw( + messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) + ) + if isinstance(self.passphrase, str): + self.passphrase_callback = self.client.passphrase_callback + self._id = resp.session_id + + +def _callback_button(session: Session, msg: t.Any) -> t.Any: + print("Please confirm action on your Trezor device.") # TODO how to handle UI? + return session.call(messages.ButtonAck()) + + +class SessionV2(Session): + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | None, + derive_cardano: bool, + session_id: int = 0, + ) -> SessionV2: + assert isinstance(client.protocol, ProtocolV2) + session = cls(client, session_id.to_bytes(1, "big")) + session.call( + messages.ThpCreateNewSession( + passphrase=passphrase, derive_cardano=derive_cardano + ), + expect=messages.Success, + ) + session.update_id_and_sid(session_id.to_bytes(1, "big")) + return session + + def __init__(self, client: TrezorClient, id: bytes) -> None: + from ..debuglink import TrezorClientDebugLink + + super().__init__(client, id) + assert isinstance(client.protocol, ProtocolV2) + + self.pin_callback = client.pin_callback + self.button_callback = client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + helper_debug = None + if isinstance(client, TrezorClientDebugLink): + helper_debug = client.debug + self.channel: ProtocolV2 = client.protocol.get_channel(helper_debug) + self.update_id_and_sid(id) + + def _write(self, msg: t.Any) -> None: + LOG.debug("writing message %s", type(msg)) + self.channel.write(self.sid, msg) + + def _read(self) -> t.Any: + msg = self.channel.read(self.sid) + LOG.debug("reading message %s", type(msg)) + return msg + + def update_id_and_sid(self, id: bytes) -> None: + self._id = id + self.sid = int.from_bytes(id, "big") # TODO update to extract only sid diff --git a/python/src/trezorlib/transport/thp/alternating_bit_protocol.py b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py new file mode 100644 index 0000000000..62fb650fab --- /dev/null +++ b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +# from storage.cache_thp import ChannelCache +# from trezor import log +# from trezor.wire.thp import ThpError + + +# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: +# """ +# Checks if: +# - an ACK message is expected +# - the received ACK message acknowledges correct sequence number (bit) +# """ +# if not _is_ack_expected(cache): +# return False + +# if not _has_ack_correct_sync_bit(cache, ack_bit): +# return False + +# return True + + +# def _is_ack_expected(cache: ChannelCache) -> bool: +# is_expected: bool = not is_sending_allowed(cache) +# if __debug__ and not is_expected: +# log.debug(__name__, "Received unexpected ACK message") +# return is_expected + + +# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: +# is_correct: bool = get_send_seq_bit(cache) == sync_bit +# if __debug__ and not is_correct: +# log.debug(__name__, "Received ACK message with wrong ack bit") +# return is_correct + + +# def is_sending_allowed(cache: ChannelCache) -> bool: +# """ +# Checks whether sending a message in the provided channel is allowed. + +# Note: Sending a message in a channel before receipt of ACK message for the previously +# sent message (in the channel) is prohibited, as it can lead to desynchronization. +# """ +# return bool(cache.sync >> 7) + + +# def get_send_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the sequential number (bit) of the next message to be sent +# in the provided channel. +# """ +# return (cache.sync & 0x20) >> 5 + + +# def get_expected_receive_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the (expected) sequential number (bit) of the next message +# to be received in the provided channel. +# """ +# return (cache.sync & 0x40) >> 6 + + +# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: +# """ +# Set the flag whether sending a message in this channel is allowed or not. +# """ +# cache.sync &= 0x7F +# if sending_allowed: +# cache.sync |= 0x80 + + +# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# """ +# Set the expected sequential number (bit) of the next message to be received +# in the provided channel +# """ +# if __debug__: +# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected receive sync bit") + +# # set second bit to "seq_bit" value +# cache.sync &= 0xBF +# if seq_bit: +# cache.sync |= 0x40 + + +# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected send seq bit") +# if __debug__: +# log.debug(__name__, "setting sync send seq bit to %d", seq_bit) +# # set third bit to "seq_bit" value +# cache.sync &= 0xDF +# if seq_bit: +# cache.sync |= 0x20 + + +# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: +# """ +# Set the sequential bit of the "next message to be send" to the opposite value, +# i.e. 1 -> 0 and 0 -> 1 +# """ +# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/python/src/trezorlib/transport/thp/channel_data.py b/python/src/trezorlib/transport/thp/channel_data.py new file mode 100644 index 0000000000..4d9d11d8d0 --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_data.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from binascii import hexlify + + +class ChannelData: + + def __init__( + self, + protocol_version_major: int, + protocol_version_minor: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + handshake_hash: bytes, + ) -> None: + self.protocol_version_major: int = protocol_version_major + self.protocol_version_minor: int = protocol_version_minor + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + self.handshake_hash: str = hexlify(handshake_hash).decode() + + def to_dict(self): + return { + "protocol_version_major": self.protocol_version_major, + "protocol_version_minor": self.protocol_version_minor, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + "handshake_hash": self.handshake_hash, + } diff --git a/python/src/trezorlib/transport/thp/channel_database.py b/python/src/trezorlib/transport/thp/channel_database.py new file mode 100644 index 0000000000..03be0f7ece --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_database.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +import json +import logging +import os +import typing as t + +from ..thp.channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +db: "ChannelDatabase | None" = None + + +def get_channel_db() -> ChannelDatabase: + if db is None: + set_channel_database(should_not_store=True) + assert db is not None + return db + + +class ChannelDatabase: + + def load_stored_channels(self) -> t.List[ChannelData]: ... + def clear_stored_channels(self) -> None: ... + def read_all_channels(self) -> t.List: ... + def save_all_channels(self, channels: t.List[t.Dict]) -> None: ... + def save_channel(self, new_channel: ProtocolAndChannel): ... + def remove_channel(self, transport_path: str) -> None: ... + + +class DummyChannelDatabase(ChannelDatabase): + + def load_stored_channels(self) -> t.List[ChannelData]: + return [] + + def clear_stored_channels(self) -> None: + pass + + def read_all_channels(self) -> t.List: + return [] + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + return + + def save_channel(self, new_channel: ProtocolAndChannel): + pass + + def remove_channel(self, transport_path: str) -> None: + pass + + +class JsonChannelDatabase(ChannelDatabase): + def __init__(self, data_path: str) -> None: + self.data_path = data_path + super().__init__() + + def load_stored_channels(self) -> t.List[ChannelData]: + dicts = self.read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + def clear_stored_channels(self) -> None: + LOG.debug("Clearing contents of %s", self.data_path) + with open(self.data_path, "w") as f: + json.dump([], f) + try: + os.remove(self.data_path) + except Exception as e: + LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e))) + + def read_all_channels(self) -> t.List: + ensure_file_exists(self.data_path) + with open(self.data_path, "r") as f: + return json.load(f) + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(self.data_path, "w") as f: + json.dump(channels, f, indent=4) + + def save_channel(self, new_channel: ProtocolAndChannel): + + LOG.debug("save channel") + channels = self.read_all_channels() + transport_path = new_channel.transport.get_path() + + # If the channel is found in database: replace the old entry by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + self.save_all_channels(channels) + return + + # Channel was not found: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + self.save_all_channels(channels) + + def remove_channel(self, transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = self.read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + self.save_all_channels(remaining_channels) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version_major=dict["protocol_version_minor"], + protocol_version_minor=dict["protocol_version_major"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + handshake_hash=bytes.fromhex(dict["handshake_hash"]), + ) + + +def ensure_file_exists(file_path: str) -> None: + LOG.debug("checking if file %s exists", file_path) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + LOG.debug("File %s does not exist. Creating a new one.", file_path) + with open(file_path, "w") as f: + json.dump([], f) + + +def set_channel_database(should_not_store: bool): + global db + if should_not_store: + db = DummyChannelDatabase() + else: + from platformdirs import user_cache_dir + + APP_NAME = "@trezor" # TODO + DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") + + db = JsonChannelDatabase(DATA_PATH) diff --git a/python/src/trezorlib/transport/thp/checksum.py b/python/src/trezorlib/transport/thp/checksum.py new file mode 100644 index 0000000000..8e0f32f013 --- /dev/null +++ b/python/src/trezorlib/transport/thp/checksum.py @@ -0,0 +1,19 @@ +import zlib + +CHECKSUM_LENGTH = 4 + + +def compute(data: bytes) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/python/src/trezorlib/transport/thp/control_byte.py b/python/src/trezorlib/transport/thp/control_byte.py new file mode 100644 index 0000000000..dca681ef02 --- /dev/null +++ b/python/src/trezorlib/transport/thp/control_byte.py @@ -0,0 +1,63 @@ +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise Exception("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise Exception("Unexpected acknowledgement bit") + + +def get_seq_bit(ctrl_byte: int) -> int: + return (ctrl_byte & 0x10) >> 4 + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & ACK_MASK == ACK_MESSAGE + + +def is_error(ctrl_byte: int) -> bool: + return ctrl_byte == _ERROR + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/python/src/trezorlib/transport/thp/cpace.py b/python/src/trezorlib/transport/thp/cpace.py new file mode 100644 index 0000000000..d0b28e265c --- /dev/null +++ b/python/src/trezorlib/transport/thp/cpace.py @@ -0,0 +1,40 @@ +import typing as t +from hashlib import sha512 + +from . import curve25519 + +_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06" +_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20" + + +class Cpace: + """ + CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/ + """ + + random_bytes: t.Callable[[int], bytes] + + def __init__(self, handshake_hash: bytes) -> None: + self.handshake_hash: bytes = handshake_hash + self.shared_secret: bytes + self.host_private_key: bytes + self.host_public_key: bytes + + def generate_keys_and_secret( + self, code_code_entry: bytes, trezor_public_key: bytes + ) -> None: + """ + Generate ephemeral key pair and a shared secret using Elligator2 with X25519. + """ + sha_ctx = sha512(_PREFIX) + sha_ctx.update(code_code_entry) + sha_ctx.update(_PADDING) + sha_ctx.update(self.handshake_hash) + sha_ctx.update(b"\x00") + pregenerator = sha_ctx.digest()[:32] + generator = curve25519.elligator2(pregenerator) + self.host_private_key = self.random_bytes(32) + self.host_public_key = curve25519.multiply(self.host_private_key, generator) + self.shared_secret = curve25519.multiply( + self.host_private_key, trezor_public_key + ) diff --git a/python/src/trezorlib/transport/thp/curve25519.py b/python/src/trezorlib/transport/thp/curve25519.py new file mode 100644 index 0000000000..e4416225f1 --- /dev/null +++ b/python/src/trezorlib/transport/thp/curve25519.py @@ -0,0 +1,159 @@ +from typing import Tuple + +p = 2**255 - 19 +J = 486662 + +c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1) +c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8 +a24 = 121666 # (J + 2) // 4 + + +def decode_scalar(scalar: bytes) -> int: + # decodeScalar25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + if len(scalar) != 32: + raise ValueError("Invalid length of scalar") + + array = bytearray(scalar) + array[0] &= 248 + array[31] &= 127 + array[31] |= 64 + + return int.from_bytes(array, "little") + + +def decode_coordinate(coordinate: bytes) -> int: + # decodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + if len(coordinate) != 32: + raise ValueError("Invalid length of coordinate") + + array = bytearray(coordinate) + array[-1] &= 0x7F + return int.from_bytes(array, "little") % p + + +def encode_coordinate(coordinate: int) -> bytes: + # encodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + return coordinate.to_bytes(32, "little") + + +def get_private_key(secret: bytes) -> bytes: + return decode_scalar(secret).to_bytes(32, "little") + + +def get_public_key(private_key: bytes) -> bytes: + base_point = int.to_bytes(9, 32, "little") + return multiply(private_key, base_point) + + +def multiply(private_scalar: bytes, public_point: bytes): + # X25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + def ladder_operation( + x1: int, x2: int, z2: int, x3: int, z3: int + ) -> Tuple[int, int, int, int]: + # https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3 + # (x4, z4) = 2 * (x2, z2) + # (x5, z5) = (x2, z2) + (x3, z3) + # where (x1, 1) = (x3, z3) - (x2, z2) + + a = (x2 + z2) % p + aa = (a * a) % p + b = (x2 - z2) % p + bb = (b * b) % p + e = (aa - bb) % p + c = (x3 + z3) % p + d = (x3 - z3) % p + da = (d * a) % p + cb = (c * b) % p + t0 = (da + cb) % p + x5 = (t0 * t0) % p + t1 = (da - cb) % p + t2 = (t1 * t1) % p + z5 = (x1 * t2) % p + x4 = (aa * bb) % p + t3 = (a24 * e) % p + t4 = (bb + t3) % p + z4 = (e * t4) % p + + return x4, z4, x5, z5 + + def conditional_swap(first: int, second: int, condition: int): + # Returns (second, first) if condition is true and (first, second) otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask), (second & false_mask) | ( + first & true_mask + ) + + k = decode_scalar(private_scalar) + u = decode_coordinate(public_point) + + x_1 = u + x_2 = 1 + z_2 = 0 + x_3 = u + z_3 = 1 + swap = 0 + + for i in reversed(range(256)): + bit = (k >> i) & 1 + swap = bit ^ swap + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + swap = bit + x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3) + + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + + x = pow(z_2, p - 2, p) * x_2 % p + return encode_coordinate(x) + + +def elligator2(point: bytes) -> bytes: + # map_to_curve_elligator2_curve25519 from + # https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt + + def conditional_move(first: int, second: int, condition: bool): + # Returns second if condition is true and first otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask) + + u = decode_coordinate(point) + tv1 = (u * u) % p + tv1 = (2 * tv1) % p + xd = (tv1 + 1) % p + x1n = (-J) % p + tv2 = (xd * xd) % p + gxd = (tv2 * xd) % p + gx1 = (J * tv1) % p + gx1 = (gx1 * x1n) % p + gx1 = (gx1 + tv2) % p + gx1 = (gx1 * x1n) % p + tv3 = (gxd * gxd) % p + tv2 = (tv3 * tv3) % p + tv3 = (tv3 * gxd) % p + tv3 = (tv3 * gx1) % p + tv2 = (tv2 * tv3) % p + y11 = pow(tv2, c4, p) + y11 = (y11 * tv3) % p + y12 = (y11 * c3) % p + tv2 = (y11 * y11) % p + tv2 = (tv2 * gxd) % p + e1 = tv2 == gx1 + y1 = conditional_move(y12, y11, e1) + x2n = (x1n * tv1) % p + tv2 = (y1 * y1) % p + tv2 = (tv2 * gxd) % p + e3 = tv2 == gx1 + xn = conditional_move(x2n, x1n, e3) + x = xn * pow(xd, p - 2, p) % p + return encode_coordinate(x) diff --git a/python/src/trezorlib/transport/thp/message_header.py b/python/src/trezorlib/transport/thp/message_header.py new file mode 100644 index 0000000000..d2ff002d63 --- /dev/null +++ b/python/src/trezorlib/transport/thp/message_header.py @@ -0,0 +1,82 @@ +import struct + +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +BROADCAST_CHANNEL_ID = 0xFFFF + + +class MessageHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.data_length = length + + def to_bytes_init(self) -> bytes: + return struct.pack( + self.format_str_init, self.ctrl_byte, self.cid, self.data_length + ) + + def to_bytes_cont(self) -> bytes: + return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.data_length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + def is_ack(self) -> bool: + return self.ctrl_byte & ACK_MASK == ACK_MESSAGE + + def is_channel_allocation_response(self): + return ( + self.cid == BROADCAST_CHANNEL_ID + and self.ctrl_byte == _CHANNEL_ALLOCATION_RES + ) + + def is_handshake_init_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES + + def is_handshake_comp_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES + + def is_encrypted_transport(self) -> bool: + return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_request_header(cls, length: int): + return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length) diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py new file mode 100644 index 0000000000..fa420ac0af --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp.channel_data import ChannelData + +LOG = logging.getLogger(__name__) + + +class ProtocolAndChannel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.transport = transport + self.mapping = mapping + self.channel_keys = channel_data + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def get_channel_data(self) -> ChannelData: + raise NotImplementedError + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py new file mode 100644 index 0000000000..baea7e7401 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + + +class ProtocolV1(ProtocolAndChannel): + HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py new file mode 100644 index 0000000000..b073a0264d --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import typing as t +from binascii import hexlify + +import click +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ... import exceptions, messages, protobuf +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp import checksum, curve25519, thp_io +from ..thp.channel_data import ChannelData +from ..thp.checksum import CHECKSUM_LENGTH +from ..thp.message_header import MessageHeader +from . import control_byte +from .channel_database import ChannelDatabase, get_channel_db +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +DEFAULT_SESSION_ID: int = 0 + +if t.TYPE_CHECKING: + from ...debuglink import DebugLink +MT = t.TypeVar("MT", bound=protobuf.MessageType) + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") + + +class ProtocolV2(ProtocolAndChannel): + channel_id: int + channel_database: ChannelDatabase + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + sync_bit_send: int + sync_bit_receive: int + handshake_hash: bytes + + _has_valid_channel: bool = False + _features: messages.Features | None = None + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.channel_database: ChannelDatabase = get_channel_db() + super().__init__(transport, mapping, channel_data) + if channel_data is not None: + self.channel_id = channel_data.channel_id + self.key_request = bytes.fromhex(channel_data.key_request) + self.key_response = bytes.fromhex(channel_data.key_response) + self.nonce_request = channel_data.nonce_request + self.nonce_response = channel_data.nonce_response + self.sync_bit_receive = channel_data.sync_bit_receive + self.sync_bit_send = channel_data.sync_bit_send + self.handshake_hash = bytes.fromhex(channel_data.handshake_hash) + self._has_valid_channel = True + + def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2: + if not self._has_valid_channel: + self._establish_new_channel(helper_debug) + return self + + def get_channel_data(self) -> ChannelData: + return ChannelData( + protocol_version_major=2, + protocol_version_minor=2, + transport_path=self.transport.get_path(), + channel_id=self.channel_id, + key_request=self.key_request, + key_response=self.key_response, + nonce_request=self.nonce_request, + nonce_response=self.nonce_response, + sync_bit_receive=self.sync_bit_receive, + sync_bit_send=self.sync_bit_send, + handshake_hash=self.handshake_hash, + ) + + def read(self, session_id: int) -> t.Any: + sid, msg_type, msg_data = self.read_and_decrypt() + if sid != session_id: + raise Exception("Received messsage on a different session.") + self.channel_database.save_channel(self) + return self.mapping.decode(msg_type, msg_data) + + def write(self, session_id: int, msg: t.Any) -> None: + msg_type, msg_data = self.mapping.encode(msg) + self._encrypt_and_write(session_id, msg_type, msg_data) + self.channel_database.save_channel(self) + + def get_features(self) -> messages.Features: + if not self._has_valid_channel: + self._establish_new_channel() + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + message = messages.GetFeatures() + message_type, message_data = self.mapping.encode(message) + self.session_id: int = DEFAULT_SESSION_ID + self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data) + _ = self._read_until_valid_crc_check() # TODO check ACK + _, msg_type, msg_data = self.read_and_decrypt() + features = self.mapping.decode(msg_type, msg_data) + if not isinstance(features, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = features + + def _send_message( + self, + message: protobuf.MessageType, + session_id: int = DEFAULT_SESSION_ID, + ): + message_type, message_data = self.mapping.encode(message) + self._encrypt_and_write(session_id, message_type, message_data) + self._read_ack() + + def _read_message(self, message_type: type[MT]) -> MT: + _, msg_type, msg_data = self.read_and_decrypt() + msg = self.mapping.decode(msg_type, msg_data) + assert isinstance(msg, message_type) + return msg + + def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None: + self._reset_sync_bits() + self._do_channel_allocation() + self._do_handshake() + self._do_pairing(helper_debug) + + def _reset_sync_bits(self) -> None: + self.sync_bit_send = 0 + self.sync_bit_receive = 0 + + def _do_channel_allocation(self) -> None: + channel_allocation_nonce = os.urandom(8) + self._send_channel_allocation_request(channel_allocation_nonce) + cid, dp = self._read_channel_allocation_response(channel_allocation_nonce) + self.channel_id = cid + self.device_properties = dp + + def _send_channel_allocation_request(self, nonce: bytes): + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + MessageHeader.get_channel_allocation_request_header(12), + nonce, + ) + + def _read_channel_allocation_response( + self, expected_nonce: bytes + ) -> tuple[int, bytes]: + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, expected_nonce + ): + raise Exception("Invalid channel allocation response.") + + channel_id = int.from_bytes(payload[8:10], "big") + device_properties = payload[10:] + return (channel_id, device_properties) + + def _do_handshake( + self, credential: bytes | None = None, host_static_privkey: bytes | None = None + ): + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + self._send_handshake_init_request(host_ephemeral_pubkey) + self._read_ack() + init_response = self._read_handshake_init_response() + + trezor_ephemeral_pubkey = init_response[:32] + encrypted_trezor_static_pubkey = init_response[32:80] + noise_tag = init_response[80:96] + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) + + # TODO check noise_tag is valid + + ck = self._send_handshake_completion_request( + host_ephemeral_pubkey, + host_ephemeral_privkey, + trezor_ephemeral_pubkey, + encrypted_trezor_static_pubkey, + credential, + host_static_privkey, + ) + self._read_ack() + self._read_handshake_completion_response() + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None: + ha_init_req_header = MessageHeader(0, self.channel_id, 36) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, ha_init_req_header, host_ephemeral_pubkey + ) + + def _read_handshake_init_response(self) -> bytes: + header, payload = self._read_until_valid_crc_check() + self._send_ack_0() + + if header.ctrl_byte == 0x42: + if payload == b"\x05": + raise exceptions.DeviceLockedException() + + if not header.is_handshake_init_response(): + LOG.debug("Received message is not a valid handshake init response message") + + click.echo( + "Received message is not a valid handshake init response message", + err=True, + ) + return payload + + def _send_handshake_completion_request( + self, + host_ephemeral_pubkey: bytes, + host_ephemeral_privkey: bytes, + trezor_ephemeral_pubkey: bytes, + encrypted_trezor_static_pubkey: bytes, + credential: bytes | None = None, + host_static_privkey: bytes | None = None, + ) -> bytes: + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + except Exception as e: + click.echo( + f"Exception of type{type(e)}", err=True + ) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + + # TODO: search for saved credentials + if host_static_privkey is not None and credential is not None: + host_static_pubkey = curve25519.get_public_key(host_static_privkey) + else: + credential = None + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key( + temp_host_static_privkey + ) + host_static_privkey = temp_host_static_privkey + host_static_pubkey = temp_host_static_pubkey + + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = self.mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + host_pairing_credential=credential, + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload[:-16]) + ha_completion_req_header = MessageHeader( + 0x12, + self.channel_id, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + self.handshake_hash = h + return ck + + def _read_handshake_completion_response(self) -> None: + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + click.echo( + "Received message is not a valid handshake completion response", + err=True, + ) + self._send_ack_1() + + def _do_pairing(self, helper_debug: DebugLink | None): + + self._send_message(messages.ThpPairingRequest()) + self._read_message(messages.ButtonRequest) + self._send_message(messages.ButtonAck()) + + if helper_debug is not None: + helper_debug.press_yes() + + self._read_message(messages.ThpPairingRequestApproved) + self._send_message( + messages.ThpSelectMethod( + selected_pairing_method=messages.ThpPairingMethod.SkipPairing + ) + ) + self._read_message(messages.ThpEndResponse) + + self._has_valid_channel = True + + def _read_ack(self): + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + def _send_ack_0(self): + LOG.debug("sending ack 0") + header = MessageHeader(0x20, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _send_ack_1(self): + LOG.debug("sending ack 1") + header = MessageHeader(0x28, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _encrypt_and_write( + self, + session_id: int, + message_type: int, + message_data: bytes, + ctrl_byte: int | None = None, + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + + if ctrl_byte is None: + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) + self.sync_bit_send = 1 - self.sync_bit_send + + sid = session_id.to_bytes(1, "big") + msg_type = message_type.to_bytes(2, "big") + data = sid + msg_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = MessageHeader( + ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, header, encrypted_message + ) + + def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if control_byte.is_ack(header.ctrl_byte): + # TODO fix this recursion + return self.read_and_decrypt() + if control_byte.is_error(header.ctrl_byte): + # TODO check for different channel + err = _get_error_from_int(raw_payload[0]) + raise Exception("Received ThpError: " + err) + if not header.is_encrypted_transport(): + click.echo( + "Trying to decrypt not encrypted message! (" + + hexlify(header.to_bytes_init() + raw_payload).decode() + + ")", + err=True, + ) + + if not control_byte.is_ack(header.ctrl_byte): + LOG.debug( + "--> Get sequence bit %d %s %s", + control_byte.get_seq_bit(header.ctrl_byte), + "from control byte", + hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(), + ) + if control_byte.get_seq_bit(header.ctrl_byte): + self._send_ack_1() + else: + self._send_ack_0() + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + session_id, + int.from_bytes(message_type, "big"), + message_data, + ) + + def _read_until_valid_crc_check( + self, + ) -> t.Tuple[MessageHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.transport) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + click.echo( + "Received a message with an invalid checksum:" + + hexlify(header.to_bytes_init() + payload + chksum).decode(), + err=True, + ) + header, payload, chksum = thp_io.read(self.transport) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: MessageHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + click.echo( + "Received message is not a channel allocation response", err=True + ) + return False + if len(payload) < 10: + click.echo("Invalid channel allocation response payload", err=True) + return False + if payload[:8] != original_nonce: + click.echo( + "Invalid channel allocation response payload (nonce mismatch)", err=True + ) + return False + return True + + +def _get_error_from_int(error_code: int) -> str: + # TODO FIXME improve this (ThpErrorType) + if error_code == 1: + return "TRANSPORT BUSY" + if error_code == 2: + return "UNALLOCATED CHANNEL" + if error_code == 3: + return "DECRYPTION FAILED" + if error_code == 4: + return "INVALID DATA" + if error_code == 5: + return "DEVICE LOCKED" + raise Exception("Not Implemented error case") diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py new file mode 100644 index 0000000000..d0237f9e36 --- /dev/null +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -0,0 +1,93 @@ +import struct +from typing import Tuple + +from .. import Transport +from ..thp import checksum +from .message_header import MessageHeader + +INIT_HEADER_LENGTH = 5 +CONT_HEADER_LENGTH = 3 +MAX_PAYLOAD_LEN = 60000 +MESSAGE_TYPE_LENGTH = 2 + +CONTINUATION_PACKET = 0x80 + + +def write_payload_to_wire_and_add_checksum( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) + data = transport_payload + chksum + write_payload_to_wire(transport, header, data) + + +def write_payload_to_wire( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + transport.open() + buffer = bytearray(transport_payload) + chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH] + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + + buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :] + while buffer: + chunk = ( + header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH] + ) + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] + + +def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]: + """ + Reads from the given wire transport. + + Returns `Tuple[MessageHeader, bytes, bytes]`: + 1. `header` (`MessageHeader`): Header of the message. + 2. `data` (`bytes`): Contents of the message (if any). + 3. `checksum` (`bytes`): crc32 checksum of the header + data. + + """ + buffer = bytearray() + + # Read header with first part of message data + header, first_chunk = read_first(transport) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < header.data_length: + buffer.extend(read_next(transport, header.cid)) + + data_len = header.data_length - checksum.CHECKSUM_LENGTH + msg_data = buffer[:data_len] + chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH] + + return (header, msg_data, chksum) + + +def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]: + chunk = transport.read_chunk() + try: + ctrl_byte, cid, data_length = struct.unpack( + MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] + ) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[INIT_HEADER_LENGTH:] + return MessageHeader(ctrl_byte, cid, data_length), data + + +def read_next(transport: Transport, cid: int) -> bytes: + chunk = transport.read_chunk() + ctrl_byte, read_cid = struct.unpack( + MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] + ) + if ctrl_byte != CONTINUATION_PACKET: + raise RuntimeError("Continuation packet with incorrect control byte") + if read_cid != cid: + raise RuntimeError("Continuation packet for different channel") + + return chunk[CONT_HEADER_LENGTH:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c6..e17d6f4500 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -14,14 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10 LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 - def __init__(self, device: Optional[str] = None) -> None: + def __init__( + self, + device: str | None = None, + ) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport): devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) - self.socket: Optional[socket.socket] = None + self.device: Tuple[str, int] = (host, port) - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) - - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + self.socket: socket.socket | None = None + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport): else: raise TransportException(f"No UDP device at {path}") - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport): self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport): self.socket.sendall(chunk) def read_chunk(self) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None while True: try: @@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport): if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a..872d961960 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException LOG = logging.getLogger(__name__) @@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300 WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" def open(self) -> None: self.handle = self.device.open() @@ -64,6 +121,8 @@ class WebUsbHandle: self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: @@ -75,6 +134,8 @@ class WebUsbHandle: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -97,6 +158,8 @@ class WebUsbHandle: return def read_chunk(self) -> bytes: + if self.handle is None: + self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -117,70 +180,6 @@ class WebUsbHandle: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return chunk - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True)