1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-07 21:22:41 +00:00

chore(core): adapt trezorlib transports to session based

[no changelog]
This commit is contained in:
M1nd3r 2025-02-04 15:21:19 +01:00
parent fbff05a89f
commit 61b2156a1e
19 changed files with 1896 additions and 426 deletions

View File

@ -14,24 +14,18 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
from typing import ( import typing as t
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from ..exceptions import TrezorException from ..exceptions import TrezorException
if TYPE_CHECKING: if t.TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
T = TypeVar("T", bound="Transport") T = t.TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip() """.strip()
MessagePayload = Tuple[int, bytes] MessagePayload = t.Tuple[int, bytes]
class TransportException(TrezorException): class TransportException(TrezorException):
@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException):
class Transport: class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX: str PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str: @classmethod
return self.get_path() def enumerate(
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str: def get_path(self) -> str:
raise NotImplementedError raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: "T") -> "T": def find_debug(self: "T") -> "T":
raise NotImplementedError raise NotImplementedError
@classmethod def open(self) -> None:
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
raise NotImplementedError raise NotImplementedError
@classmethod def close(self) -> None:
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": raise NotImplementedError
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes:
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int]
def all_transports() -> Iterable[Type["Transport"]]: def all_transports() -> t.Iterable[t.Type["Transport"]]:
from .bridge import BridgeTransport from .bridge import BridgeTransport
from .hid import HidTransport from .hid import HidTransport
from .udp import UdpTransport from .udp import UdpTransport
from .webusb import WebUsbTransport from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = ( transports: t.Tuple[t.Type["Transport"], ...] = (
BridgeTransport, BridgeTransport,
HidTransport, HidTransport,
UdpTransport, UdpTransport,
@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices( def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None, models: t.Iterable["TrezorModel"] | None = None,
) -> Sequence["Transport"]: ) -> t.Sequence["Transport"]:
devices: List["Transport"] = [] devices: t.List["Transport"] = []
for transport in all_transports(): for transport in all_transports():
name = transport.__name__ name = transport.__name__
try: try:
@ -145,9 +121,7 @@ def enumerate_devices(
return devices return devices
def get_transport( def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
if path is None: if path is None:
try: try:
return next(iter(enumerate_devices())) return next(iter(enumerate_devices()))

View File

@ -14,24 +14,30 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import struct import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional import typing as t
import requests import requests
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException from . import DeviceIsBusy, MessagePayload, Transport, TransportException
if TYPE_CHECKING: if t.TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
PROTOCOL_VERSION_1 = 1
PROTOCOL_VERSION_2 = 2
TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_HOST = "http://127.0.0.1:21325"
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
TREZORD_VERSION_MODERN = (2, 0, 25) TREZORD_VERSION_MODERN = (2, 0, 25)
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
CONNECTION = requests.Session() CONNECTION = requests.Session()
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
@ -45,7 +51,7 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}") super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: def call_bridge(path: str, data: str | None = None) -> requests.Response:
url = TREZORD_HOST + "/" + path url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data) r = CONNECTION.post(url, data=data)
if r.status_code != 200: if r.status_code != 200:
@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
return r return r
def is_legacy_bridge() -> bool: def get_bridge_version() -> t.Tuple[int, ...]:
config = call_bridge("configure").json() config = call_bridge("configure").json()
version_tuple = tuple(map(int, config["version"].split("."))) return tuple(map(int, config["version"].split(".")))
return version_tuple < TREZORD_VERSION_MODERN
def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
def supports_protocolV2() -> bool:
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
def detect_protocol_version(transport: "BridgeTransport") -> int:
from .. import mapping, messages
from ..messages import FailureType
protocol_version = PROTOCOL_VERSION_1
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
transport.deprecated_begin_session()
transport.deprecated_write(request_type, request_data)
response_type, response_data = transport.deprecated_read()
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
transport.deprecated_begin_session()
if isinstance(response, messages.Failure):
if response.code == FailureType.InvalidProtocol:
LOG.debug("Protocol V2 detected")
protocol_version = PROTOCOL_VERSION_2
return protocol_version
def _is_transport_valid(transport: "BridgeTransport") -> bool:
is_valid = (
supports_protocolV2()
or detect_protocol_version(transport) == PROTOCOL_VERSION_1
)
if not is_valid:
LOG.warning("Detected unsupported Bridge transport!")
return is_valid
def filter_invalid_bridge_transports(
transports: t.Iterable["BridgeTransport"],
) -> t.Sequence["BridgeTransport"]:
"""Filters out invalid bridge transports. Keeps only valid ones."""
return [t for t in transports if _is_transport_valid(t)]
class BridgeHandle: class BridgeHandle:
@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle): class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None: def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport) super().__init__(transport)
self.request: Optional[str] = None self.request: str | None = None
def write_buf(self, buf: bytes) -> None: def write_buf(self, buf: bytes) -> None:
if self.request is not None: if self.request is not None:
@ -112,13 +162,12 @@ class BridgeTransport(Transport):
ENABLED: bool = True ENABLED: bool = True
def __init__( def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
) -> None: ) -> None:
if legacy and debug: if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge") raise TransportException("Debugging not supported on legacy Bridge")
self.device = device self.device = device
self.session: Optional[str] = None self.session: str | None = device["session"]
self.debug = debug self.debug = debug
self.legacy = legacy self.legacy = legacy
@ -135,7 +184,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available") raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True) return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response: def _call(self, action: str, data: str | None = None) -> requests.Response:
session = self.session or "null" session = self.session or "null"
uri = action + "/" + str(session) uri = action + "/" + str(session)
if self.debug: if self.debug:
@ -144,17 +193,20 @@ class BridgeTransport(Transport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: t.Iterable["TrezorModel"] | None = None
) -> Iterable["BridgeTransport"]: ) -> t.Iterable["BridgeTransport"]:
try: try:
legacy = is_legacy_bridge() legacy = is_legacy_bridge()
return [ return filter_invalid_bridge_transports(
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() [
] BridgeTransport(dev, legacy)
for dev in call_bridge("enumerate").json()
]
)
except Exception: except Exception:
return [] return []
def begin_session(self) -> None: def deprecated_begin_session(self) -> None:
try: try:
data = self._call("acquire/" + self.device["path"]) data = self._call("acquire/" + self.device["path"])
except BridgeException as e: except BridgeException as e:
@ -163,18 +215,32 @@ class BridgeTransport(Transport):
raise raise
self.session = data.json()["session"] self.session = data.json()["session"]
def end_session(self) -> None: def deprecated_end_session(self) -> None:
if not self.session: if not self.session:
return return
self._call("release") self._call("release")
self.session = None self.session = None
def write(self, message_type: int, message_data: bytes) -> None: def deprecated_write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data)) header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data) self.handle.write_buf(header + message_data)
def read(self) -> MessagePayload: def deprecated_read(self) -> MessagePayload:
data = self.handle.read_buf() data = self.handle.read_buf()
headerlen = struct.calcsize(">HL") headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen]) msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen] return msg_type, data[headerlen : headerlen + datalen]
def open(self) -> None:
pass
# TODO self.handle.open()
def close(self) -> None:
pass
# TODO self.handle.close()
def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :)
self.handle.write_buf(chunk)
def read_chunk(self) -> bytes: # TODO check if it works :)
return self.handle.read_buf()

View File

@ -14,15 +14,16 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable, List, Optional import typing as t
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException from . import UDEV_RULES_STR, Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -35,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False HID_IMPORTED = False
HidDevice = Dict[str, Any] HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = Any HidDeviceHandle = t.Any
class HidHandle: class HidTransport(Transport):
def __init__( """
self, path: bytes, serial: str, probe_hid_version: bool = False HidTransport implements transport over USB HID interface.
) -> None: """
self.path = path
self.serial = serial PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2 self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None: def open(self) -> None:
self.handle = hid.device() self.handle = hid.device()
try: try:
self.handle.open_path(self.path) self.handle.open_path(self.device_path)
except (IOError, OSError) as e: except (IOError, OSError) as e:
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,) e.args = e.args + (UDEV_RULES_STR,)
@ -62,11 +101,11 @@ class HidHandle:
# and we wouldn't even know. # and we wouldn't even know.
# So we check that the serial matches what we expect. # So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string() serial = self.handle.get_serial_number_string()
if serial != self.serial: if serial != self.device_serial_number:
self.handle.close() self.handle.close()
self.handle = None self.handle = None
raise TransportException( raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}" f"Unexpected device {serial} on path {self.device_path.decode()}"
) )
self.handle.set_nonblocking(True) self.handle.set_nonblocking(True)
@ -77,7 +116,7 @@ class HidHandle:
def close(self) -> None: def close(self) -> None:
if self.handle is not None: if self.handle is not None:
# reload serial, because device.wipe() can reset it # reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string() self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close() self.handle.close()
self.handle = None self.handle = None
@ -115,53 +154,6 @@ class HidHandle:
raise TransportException("Unknown HID version") raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool: def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0

View File

@ -1,165 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import MessagePayload, Transport
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None: ...
def close(self) -> None: ...
def read_chunk(self) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
def __init__(self, handle: Handle) -> None:
self.handle = handle
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1]
chunk = chunk.ljust(REPLEN, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,223 @@
from __future__ import annotations
import logging
import typing as t
from .. import exceptions, messages, models
from ..protobuf import MessageType
from .thp.protocol_v1 import ProtocolV1
from .thp.protocol_v2 import ProtocolV2
if t.TYPE_CHECKING:
from ..client import TrezorClient
LOG = logging.getLogger(__name__)
MT = t.TypeVar("MT", bound=MessageType)
class Session:
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
def __init__(
self, client: TrezorClient, id: bytes, passphrase: str | object | None = None
) -> None:
self.client = client
self._id = id
self.passphrase = passphrase
@classmethod
def new(
cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool
) -> Session:
raise NotImplementedError
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
# TODO self.check_firmware_version()
resp = self.call_raw(msg)
while True:
if isinstance(resp, messages.PinMatrixRequest):
if self.pin_callback is None:
raise Exception # TODO
resp = self.pin_callback(self, resp)
elif isinstance(resp, messages.PassphraseRequest):
if self.passphrase_callback is None:
raise Exception # TODO
resp = self.passphrase_callback(self, resp)
elif isinstance(resp, messages.ButtonRequest):
if self.button_callback is None:
raise Exception # TODO
resp = self.button_callback(self, resp)
elif isinstance(resp, messages.Failure):
if resp.code == messages.FailureType.ActionCancelled:
raise exceptions.Cancelled
raise exceptions.TrezorFailure(resp)
elif not isinstance(resp, expect):
raise exceptions.UnexpectedMessageError(expect, resp)
else:
return resp
def call_raw(self, msg: t.Any) -> t.Any:
self._write(msg)
return self._read()
def _write(self, msg: t.Any) -> None:
raise NotImplementedError
def _read(self) -> t.Any:
raise NotImplementedError
def refresh_features(self) -> None:
self.client.refresh_features()
def end(self) -> t.Any:
return self.call(messages.EndSession())
def ping(self, message: str, button_protection: bool | None = None) -> str:
resp = self.call(
messages.Ping(message=message, button_protection=button_protection),
expect=messages.Success,
)
assert resp.message is not None
return resp.message
def invalidate(self) -> None:
self.client.invalidate()
@property
def features(self) -> messages.Features:
return self.client.features
@property
def model(self) -> models.TrezorModel:
return self.client.model
@property
def version(self) -> t.Tuple[int, int, int]:
return self.client.version
@property
def id(self) -> bytes:
return self._id
@id.setter
def id(self, value: bytes) -> None:
if not isinstance(value, bytes):
raise ValueError("id must be of type bytes")
self._id = value
class SessionV1(Session):
derive_cardano: bool | None = False
@classmethod
def new(
cls,
client: TrezorClient,
passphrase: str | object = "",
derive_cardano: bool = False,
session_id: bytes | None = None,
) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, id=session_id or b"")
session._init_callbacks()
session.passphrase = passphrase
session.derive_cardano = derive_cardano
session.init_session(session.derive_cardano)
return session
@classmethod
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
assert isinstance(client.protocol, ProtocolV1)
session = SessionV1(client, session_id)
session.init_session()
return session
def _init_callbacks(self) -> None:
self.button_callback = self.client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
self.pin_callback = self.client.pin_callback
self.passphrase_callback = self.client.passphrase_callback
def _write(self, msg: t.Any) -> None:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
self.client.protocol.write(msg)
def _read(self) -> t.Any:
if t.TYPE_CHECKING:
assert isinstance(self.client.protocol, ProtocolV1)
return self.client.protocol.read()
def init_session(self, derive_cardano: bool | None = None):
if self.id == b"":
session_id = None
else:
session_id = self.id
resp: messages.Features = self.call_raw(
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
)
if isinstance(self.passphrase, str):
self.passphrase_callback = self.client.passphrase_callback
self._id = resp.session_id
def _callback_button(session: Session, msg: t.Any) -> t.Any:
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
return session.call(messages.ButtonAck())
class SessionV2(Session):
@classmethod
def new(
cls,
client: TrezorClient,
passphrase: str | None,
derive_cardano: bool,
session_id: int = 0,
) -> SessionV2:
assert isinstance(client.protocol, ProtocolV2)
session = cls(client, session_id.to_bytes(1, "big"))
session.call(
messages.ThpCreateNewSession(
passphrase=passphrase, derive_cardano=derive_cardano
),
expect=messages.Success,
)
session.update_id_and_sid(session_id.to_bytes(1, "big"))
return session
def __init__(self, client: TrezorClient, id: bytes) -> None:
from ..debuglink import TrezorClientDebugLink
super().__init__(client, id)
assert isinstance(client.protocol, ProtocolV2)
self.pin_callback = client.pin_callback
self.button_callback = client.button_callback
if self.button_callback is None:
self.button_callback = _callback_button
helper_debug = None
if isinstance(client, TrezorClientDebugLink):
helper_debug = client.debug
self.channel: ProtocolV2 = client.protocol.get_channel(helper_debug)
self.update_id_and_sid(id)
def _write(self, msg: t.Any) -> None:
LOG.debug("writing message %s", type(msg))
self.channel.write(self.sid, msg)
def _read(self) -> t.Any:
msg = self.channel.read(self.sid)
LOG.debug("reading message %s", type(msg))
return msg
def update_id_and_sid(self, id: bytes) -> None:
self._id = id
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid

View File

@ -0,0 +1,102 @@
# from storage.cache_thp import ChannelCache
# from trezor import log
# from trezor.wire.thp import ThpError
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
# """
# Checks if:
# - an ACK message is expected
# - the received ACK message acknowledges correct sequence number (bit)
# """
# if not _is_ack_expected(cache):
# return False
# if not _has_ack_correct_sync_bit(cache, ack_bit):
# return False
# return True
# def _is_ack_expected(cache: ChannelCache) -> bool:
# is_expected: bool = not is_sending_allowed(cache)
# if __debug__ and not is_expected:
# log.debug(__name__, "Received unexpected ACK message")
# return is_expected
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
# if __debug__ and not is_correct:
# log.debug(__name__, "Received ACK message with wrong ack bit")
# return is_correct
# def is_sending_allowed(cache: ChannelCache) -> bool:
# """
# Checks whether sending a message in the provided channel is allowed.
# Note: Sending a message in a channel before receipt of ACK message for the previously
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
# """
# return bool(cache.sync >> 7)
# def get_send_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the sequential number (bit) of the next message to be sent
# in the provided channel.
# """
# return (cache.sync & 0x20) >> 5
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
# """
# Returns the (expected) sequential number (bit) of the next message
# to be received in the provided channel.
# """
# return (cache.sync & 0x40) >> 6
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
# """
# Set the flag whether sending a message in this channel is allowed or not.
# """
# cache.sync &= 0x7F
# if sending_allowed:
# cache.sync |= 0x80
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# """
# Set the expected sequential number (bit) of the next message to be received
# in the provided channel
# """
# if __debug__:
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected receive sync bit")
# # set second bit to "seq_bit" value
# cache.sync &= 0xBF
# if seq_bit:
# cache.sync |= 0x40
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
# if seq_bit not in (0, 1):
# raise ThpError("Unexpected send seq bit")
# if __debug__:
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
# # set third bit to "seq_bit" value
# cache.sync &= 0xDF
# if seq_bit:
# cache.sync |= 0x20
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
# """
# Set the sequential bit of the "next message to be send" to the opposite value,
# i.e. 1 -> 0 and 0 -> 1
# """
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))

View File

@ -0,0 +1,47 @@
from __future__ import annotations
from binascii import hexlify
class ChannelData:
def __init__(
self,
protocol_version_major: int,
protocol_version_minor: int,
transport_path: str,
channel_id: int,
key_request: bytes,
key_response: bytes,
nonce_request: int,
nonce_response: int,
sync_bit_send: int,
sync_bit_receive: int,
handshake_hash: bytes,
) -> None:
self.protocol_version_major: int = protocol_version_major
self.protocol_version_minor: int = protocol_version_minor
self.transport_path: str = transport_path
self.channel_id: int = channel_id
self.key_request: str = hexlify(key_request).decode()
self.key_response: str = hexlify(key_response).decode()
self.nonce_request: int = nonce_request
self.nonce_response: int = nonce_response
self.sync_bit_receive: int = sync_bit_receive
self.sync_bit_send: int = sync_bit_send
self.handshake_hash: str = hexlify(handshake_hash).decode()
def to_dict(self):
return {
"protocol_version_major": self.protocol_version_major,
"protocol_version_minor": self.protocol_version_minor,
"transport_path": self.transport_path,
"channel_id": self.channel_id,
"key_request": self.key_request,
"key_response": self.key_response,
"nonce_request": self.nonce_request,
"nonce_response": self.nonce_response,
"sync_bit_send": self.sync_bit_send,
"sync_bit_receive": self.sync_bit_receive,
"handshake_hash": self.handshake_hash,
}

View File

@ -0,0 +1,148 @@
from __future__ import annotations
import json
import logging
import os
import typing as t
from ..thp.channel_data import ChannelData
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
db: "ChannelDatabase | None" = None
def get_channel_db() -> ChannelDatabase:
if db is None:
set_channel_database(should_not_store=True)
assert db is not None
return db
class ChannelDatabase:
def load_stored_channels(self) -> t.List[ChannelData]: ...
def clear_stored_channels(self) -> None: ...
def read_all_channels(self) -> t.List: ...
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
def save_channel(self, new_channel: ProtocolAndChannel): ...
def remove_channel(self, transport_path: str) -> None: ...
class DummyChannelDatabase(ChannelDatabase):
def load_stored_channels(self) -> t.List[ChannelData]:
return []
def clear_stored_channels(self) -> None:
pass
def read_all_channels(self) -> t.List:
return []
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
return
def save_channel(self, new_channel: ProtocolAndChannel):
pass
def remove_channel(self, transport_path: str) -> None:
pass
class JsonChannelDatabase(ChannelDatabase):
def __init__(self, data_path: str) -> None:
self.data_path = data_path
super().__init__()
def load_stored_channels(self) -> t.List[ChannelData]:
dicts = self.read_all_channels()
return [dict_to_channel_data(d) for d in dicts]
def clear_stored_channels(self) -> None:
LOG.debug("Clearing contents of %s", self.data_path)
with open(self.data_path, "w") as f:
json.dump([], f)
try:
os.remove(self.data_path)
except Exception as e:
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
def read_all_channels(self) -> t.List:
ensure_file_exists(self.data_path)
with open(self.data_path, "r") as f:
return json.load(f)
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
LOG.debug("saving all channels")
with open(self.data_path, "w") as f:
json.dump(channels, f, indent=4)
def save_channel(self, new_channel: ProtocolAndChannel):
LOG.debug("save channel")
channels = self.read_all_channels()
transport_path = new_channel.transport.get_path()
# If the channel is found in database: replace the old entry by the new
for i, channel in enumerate(channels):
if channel["transport_path"] == transport_path:
LOG.debug("Modified channel entry for %s", transport_path)
channels[i] = new_channel.get_channel_data().to_dict()
self.save_all_channels(channels)
return
# Channel was not found: add a new channel entry
LOG.debug("Created a new channel entry on path %s", transport_path)
channels.append(new_channel.get_channel_data().to_dict())
self.save_all_channels(channels)
def remove_channel(self, transport_path: str) -> None:
LOG.debug(
"Removing channel with path %s from the channel database.",
transport_path,
)
channels = self.read_all_channels()
remaining_channels = [
ch for ch in channels if ch["transport_path"] != transport_path
]
self.save_all_channels(remaining_channels)
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
return ChannelData(
protocol_version_major=dict["protocol_version_minor"],
protocol_version_minor=dict["protocol_version_major"],
transport_path=dict["transport_path"],
channel_id=dict["channel_id"],
key_request=bytes.fromhex(dict["key_request"]),
key_response=bytes.fromhex(dict["key_response"]),
nonce_request=dict["nonce_request"],
nonce_response=dict["nonce_response"],
sync_bit_send=dict["sync_bit_send"],
sync_bit_receive=dict["sync_bit_receive"],
handshake_hash=bytes.fromhex(dict["handshake_hash"]),
)
def ensure_file_exists(file_path: str) -> None:
LOG.debug("checking if file %s exists", file_path)
if not os.path.exists(file_path):
os.makedirs(os.path.dirname(file_path), exist_ok=True)
LOG.debug("File %s does not exist. Creating a new one.", file_path)
with open(file_path, "w") as f:
json.dump([], f)
def set_channel_database(should_not_store: bool):
global db
if should_not_store:
db = DummyChannelDatabase()
else:
from platformdirs import user_cache_dir
APP_NAME = "@trezor" # TODO
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
db = JsonChannelDatabase(DATA_PATH)

View File

@ -0,0 +1,19 @@
import zlib
CHECKSUM_LENGTH = 4
def compute(data: bytes) -> bytes:
"""
Returns a CRC-32 checksum of the provided `data`.
"""
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
def is_valid(checksum: bytes, data: bytes) -> bool:
"""
Checks whether the CRC-32 checksum of the `data` is the same
as the checksum provided in `checksum`.
"""
data_checksum = compute(data)
return checksum == data_checksum

View File

@ -0,0 +1,63 @@
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
if seq_bit == 0:
return ctrl_byte & 0xEF
if seq_bit == 1:
return ctrl_byte | 0x10
raise Exception("Unexpected sequence bit")
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
if ack_bit == 0:
return ctrl_byte & 0xF7
if ack_bit == 1:
return ctrl_byte | 0x08
raise Exception("Unexpected acknowledgement bit")
def get_seq_bit(ctrl_byte: int) -> int:
return (ctrl_byte & 0x10) >> 4
def is_ack(ctrl_byte: int) -> bool:
return ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_error(ctrl_byte: int) -> bool:
return ctrl_byte == _ERROR
def is_continuation(ctrl_byte: int) -> bool:
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
def is_encrypted_transport(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
def is_handshake_init_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
def is_handshake_comp_req(ctrl_byte: int) -> bool:
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ

View File

@ -0,0 +1,40 @@
import typing as t
from hashlib import sha512
from . import curve25519
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
class Cpace:
"""
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
"""
random_bytes: t.Callable[[int], bytes]
def __init__(self, handshake_hash: bytes) -> None:
self.handshake_hash: bytes = handshake_hash
self.shared_secret: bytes
self.host_private_key: bytes
self.host_public_key: bytes
def generate_keys_and_secret(
self, code_code_entry: bytes, trezor_public_key: bytes
) -> None:
"""
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
"""
sha_ctx = sha512(_PREFIX)
sha_ctx.update(code_code_entry)
sha_ctx.update(_PADDING)
sha_ctx.update(self.handshake_hash)
sha_ctx.update(b"\x00")
pregenerator = sha_ctx.digest()[:32]
generator = curve25519.elligator2(pregenerator)
self.host_private_key = self.random_bytes(32)
self.host_public_key = curve25519.multiply(self.host_private_key, generator)
self.shared_secret = curve25519.multiply(
self.host_private_key, trezor_public_key
)

View File

@ -0,0 +1,159 @@
from typing import Tuple
p = 2**255 - 19
J = 486662
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
a24 = 121666 # (J + 2) // 4
def decode_scalar(scalar: bytes) -> int:
# decodeScalar25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(scalar) != 32:
raise ValueError("Invalid length of scalar")
array = bytearray(scalar)
array[0] &= 248
array[31] &= 127
array[31] |= 64
return int.from_bytes(array, "little")
def decode_coordinate(coordinate: bytes) -> int:
# decodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
if len(coordinate) != 32:
raise ValueError("Invalid length of coordinate")
array = bytearray(coordinate)
array[-1] &= 0x7F
return int.from_bytes(array, "little") % p
def encode_coordinate(coordinate: int) -> bytes:
# encodeUCoordinate from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
return coordinate.to_bytes(32, "little")
def get_private_key(secret: bytes) -> bytes:
return decode_scalar(secret).to_bytes(32, "little")
def get_public_key(private_key: bytes) -> bytes:
base_point = int.to_bytes(9, 32, "little")
return multiply(private_key, base_point)
def multiply(private_scalar: bytes, public_point: bytes):
# X25519 from
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
def ladder_operation(
x1: int, x2: int, z2: int, x3: int, z3: int
) -> Tuple[int, int, int, int]:
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
# (x4, z4) = 2 * (x2, z2)
# (x5, z5) = (x2, z2) + (x3, z3)
# where (x1, 1) = (x3, z3) - (x2, z2)
a = (x2 + z2) % p
aa = (a * a) % p
b = (x2 - z2) % p
bb = (b * b) % p
e = (aa - bb) % p
c = (x3 + z3) % p
d = (x3 - z3) % p
da = (d * a) % p
cb = (c * b) % p
t0 = (da + cb) % p
x5 = (t0 * t0) % p
t1 = (da - cb) % p
t2 = (t1 * t1) % p
z5 = (x1 * t2) % p
x4 = (aa * bb) % p
t3 = (a24 * e) % p
t4 = (bb + t3) % p
z4 = (e * t4) % p
return x4, z4, x5, z5
def conditional_swap(first: int, second: int, condition: int):
# Returns (second, first) if condition is true and (first, second) otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
first & true_mask
)
k = decode_scalar(private_scalar)
u = decode_coordinate(public_point)
x_1 = u
x_2 = 1
z_2 = 0
x_3 = u
z_3 = 1
swap = 0
for i in reversed(range(256)):
bit = (k >> i) & 1
swap = bit ^ swap
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
swap = bit
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
x = pow(z_2, p - 2, p) * x_2 % p
return encode_coordinate(x)
def elligator2(point: bytes) -> bytes:
# map_to_curve_elligator2_curve25519 from
# https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt
def conditional_move(first: int, second: int, condition: bool):
# Returns second if condition is true and first otherwise
# Must be implemented in a way that it is constant time
true_mask = -condition
false_mask = ~true_mask
return (first & false_mask) | (second & true_mask)
u = decode_coordinate(point)
tv1 = (u * u) % p
tv1 = (2 * tv1) % p
xd = (tv1 + 1) % p
x1n = (-J) % p
tv2 = (xd * xd) % p
gxd = (tv2 * xd) % p
gx1 = (J * tv1) % p
gx1 = (gx1 * x1n) % p
gx1 = (gx1 + tv2) % p
gx1 = (gx1 * x1n) % p
tv3 = (gxd * gxd) % p
tv2 = (tv3 * tv3) % p
tv3 = (tv3 * gxd) % p
tv3 = (tv3 * gx1) % p
tv2 = (tv2 * tv3) % p
y11 = pow(tv2, c4, p)
y11 = (y11 * tv3) % p
y12 = (y11 * c3) % p
tv2 = (y11 * y11) % p
tv2 = (tv2 * gxd) % p
e1 = tv2 == gx1
y1 = conditional_move(y12, y11, e1)
x2n = (x1n * tv1) % p
tv2 = (y1 * y1) % p
tv2 = (tv2 * gxd) % p
e3 = tv2 == gx1
xn = conditional_move(x2n, x1n, e3)
x = xn * pow(xd, p - 2, p) % p
return encode_coordinate(x)

View File

@ -0,0 +1,82 @@
import struct
CODEC_V1 = 0x3F
CONTINUATION_PACKET = 0x80
HANDSHAKE_INIT_REQ = 0x00
HANDSHAKE_INIT_RES = 0x01
HANDSHAKE_COMP_REQ = 0x02
HANDSHAKE_COMP_RES = 0x03
ENCRYPTED_TRANSPORT = 0x04
CONTINUATION_PACKET_MASK = 0x80
ACK_MASK = 0xF7
DATA_MASK = 0xE7
ACK_MESSAGE = 0x20
_ERROR = 0x42
CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x41
TREZOR_STATE_UNPAIRED = b"\x00"
TREZOR_STATE_PAIRED = b"\x01"
BROADCAST_CHANNEL_ID = 0xFFFF
class MessageHeader:
format_str_init = ">BHH"
format_str_cont = ">BH"
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
self.ctrl_byte = ctrl_byte
self.cid = cid
self.data_length = length
def to_bytes_init(self) -> bytes:
return struct.pack(
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
)
def to_bytes_cont(self) -> bytes:
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_init,
buffer,
buffer_offset,
self.ctrl_byte,
self.cid,
self.data_length,
)
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
struct.pack_into(
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
)
def is_ack(self) -> bool:
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
def is_channel_allocation_response(self):
return (
self.cid == BROADCAST_CHANNEL_ID
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
)
def is_handshake_init_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
def is_handshake_comp_response(self) -> bool:
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
def is_encrypted_transport(self) -> bool:
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
@classmethod
def get_error_header(cls, cid: int, length: int):
return cls(_ERROR, cid, length)
@classmethod
def get_channel_allocation_request_header(cls, length: int):
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)

View File

@ -0,0 +1,32 @@
from __future__ import annotations
import logging
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp.channel_data import ChannelData
LOG = logging.getLogger(__name__)
class ProtocolAndChannel:
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.transport = transport
self.mapping = mapping
self.channel_keys = channel_data
def get_features(self) -> messages.Features:
raise NotImplementedError()
def get_channel_data(self) -> ChannelData:
raise NotImplementedError
def update_features(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
class ProtocolV1(ProtocolAndChannel):
HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features | None = None
def get_features(self) -> messages.Features:
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
self.transport.close()
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[63:]
def _read(self) -> t.Tuple[int, bytes]:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError("Unexpected magic characters")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.transport.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError("Unexpected magic characters")
return chunk[1:]

View File

@ -0,0 +1,490 @@
from __future__ import annotations
import hashlib
import hmac
import logging
import os
import typing as t
from binascii import hexlify
import click
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from ... import exceptions, messages, protobuf
from ...mapping import ProtobufMapping
from .. import Transport
from ..thp import checksum, curve25519, thp_io
from ..thp.channel_data import ChannelData
from ..thp.checksum import CHECKSUM_LENGTH
from ..thp.message_header import MessageHeader
from . import control_byte
from .channel_database import ChannelDatabase, get_channel_db
from .protocol_and_channel import ProtocolAndChannel
LOG = logging.getLogger(__name__)
DEFAULT_SESSION_ID: int = 0
if t.TYPE_CHECKING:
from ...debuglink import DebugLink
MT = t.TypeVar("MT", bound=protobuf.MessageType)
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
hash = hashlib.sha256(val_1)
hash.update(val_2)
return hash.digest()
def _hkdf(chaining_key: bytes, input: bytes):
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
ctx_output_2.update(b"\x02")
output_2 = ctx_output_2.digest()
return (output_1, output_2)
def _get_iv_from_nonce(nonce: int) -> bytes:
if not nonce <= 0xFFFFFFFFFFFFFFFF:
raise ValueError("Nonce overflow, terminate the channel")
return bytes(4) + nonce.to_bytes(8, "big")
class ProtocolV2(ProtocolAndChannel):
channel_id: int
channel_database: ChannelDatabase
key_request: bytes
key_response: bytes
nonce_request: int
nonce_response: int
sync_bit_send: int
sync_bit_receive: int
handshake_hash: bytes
_has_valid_channel: bool = False
_features: messages.Features | None = None
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
channel_data: ChannelData | None = None,
) -> None:
self.channel_database: ChannelDatabase = get_channel_db()
super().__init__(transport, mapping, channel_data)
if channel_data is not None:
self.channel_id = channel_data.channel_id
self.key_request = bytes.fromhex(channel_data.key_request)
self.key_response = bytes.fromhex(channel_data.key_response)
self.nonce_request = channel_data.nonce_request
self.nonce_response = channel_data.nonce_response
self.sync_bit_receive = channel_data.sync_bit_receive
self.sync_bit_send = channel_data.sync_bit_send
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
self._has_valid_channel = True
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2:
if not self._has_valid_channel:
self._establish_new_channel(helper_debug)
return self
def get_channel_data(self) -> ChannelData:
return ChannelData(
protocol_version_major=2,
protocol_version_minor=2,
transport_path=self.transport.get_path(),
channel_id=self.channel_id,
key_request=self.key_request,
key_response=self.key_response,
nonce_request=self.nonce_request,
nonce_response=self.nonce_response,
sync_bit_receive=self.sync_bit_receive,
sync_bit_send=self.sync_bit_send,
handshake_hash=self.handshake_hash,
)
def read(self, session_id: int) -> t.Any:
sid, msg_type, msg_data = self.read_and_decrypt()
if sid != session_id:
raise Exception("Received messsage on a different session.")
self.channel_database.save_channel(self)
return self.mapping.decode(msg_type, msg_data)
def write(self, session_id: int, msg: t.Any) -> None:
msg_type, msg_data = self.mapping.encode(msg)
self._encrypt_and_write(session_id, msg_type, msg_data)
self.channel_database.save_channel(self)
def get_features(self) -> messages.Features:
if not self._has_valid_channel:
self._establish_new_channel()
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
message = messages.GetFeatures()
message_type, message_data = self.mapping.encode(message)
self.session_id: int = DEFAULT_SESSION_ID
self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data)
_ = self._read_until_valid_crc_check() # TODO check ACK
_, msg_type, msg_data = self.read_and_decrypt()
features = self.mapping.decode(msg_type, msg_data)
if not isinstance(features, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = features
def _send_message(
self,
message: protobuf.MessageType,
session_id: int = DEFAULT_SESSION_ID,
):
message_type, message_data = self.mapping.encode(message)
self._encrypt_and_write(session_id, message_type, message_data)
self._read_ack()
def _read_message(self, message_type: type[MT]) -> MT:
_, msg_type, msg_data = self.read_and_decrypt()
msg = self.mapping.decode(msg_type, msg_data)
assert isinstance(msg, message_type)
return msg
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
self._reset_sync_bits()
self._do_channel_allocation()
self._do_handshake()
self._do_pairing(helper_debug)
def _reset_sync_bits(self) -> None:
self.sync_bit_send = 0
self.sync_bit_receive = 0
def _do_channel_allocation(self) -> None:
channel_allocation_nonce = os.urandom(8)
self._send_channel_allocation_request(channel_allocation_nonce)
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
self.channel_id = cid
self.device_properties = dp
def _send_channel_allocation_request(self, nonce: bytes):
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
MessageHeader.get_channel_allocation_request_header(12),
nonce,
)
def _read_channel_allocation_response(
self, expected_nonce: bytes
) -> tuple[int, bytes]:
header, payload = self._read_until_valid_crc_check()
if not self._is_valid_channel_allocation_response(
header, payload, expected_nonce
):
raise Exception("Invalid channel allocation response.")
channel_id = int.from_bytes(payload[8:10], "big")
device_properties = payload[10:]
return (channel_id, device_properties)
def _do_handshake(
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
):
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
self._send_handshake_init_request(host_ephemeral_pubkey)
self._read_ack()
init_response = self._read_handshake_init_response()
trezor_ephemeral_pubkey = init_response[:32]
encrypted_trezor_static_pubkey = init_response[32:80]
noise_tag = init_response[80:96]
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
# TODO check noise_tag is valid
ck = self._send_handshake_completion_request(
host_ephemeral_pubkey,
host_ephemeral_privkey,
trezor_ephemeral_pubkey,
encrypted_trezor_static_pubkey,
credential,
host_static_privkey,
)
self._read_ack()
self._read_handshake_completion_response()
self.key_request, self.key_response = _hkdf(ck, b"")
self.nonce_request = 0
self.nonce_response = 1
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, ha_init_req_header, host_ephemeral_pubkey
)
def _read_handshake_init_response(self) -> bytes:
header, payload = self._read_until_valid_crc_check()
self._send_ack_0()
if header.ctrl_byte == 0x42:
if payload == b"\x05":
raise exceptions.DeviceLockedException()
if not header.is_handshake_init_response():
LOG.debug("Received message is not a valid handshake init response message")
click.echo(
"Received message is not a valid handshake init response message",
err=True,
)
return payload
def _send_handshake_completion_request(
self,
host_ephemeral_pubkey: bytes,
host_ephemeral_privkey: bytes,
trezor_ephemeral_pubkey: bytes,
encrypted_trezor_static_pubkey: bytes,
credential: bytes | None = None,
host_static_privkey: bytes | None = None,
) -> bytes:
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
h = _sha256_of_two(h, host_ephemeral_pubkey)
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
ck, k = _hkdf(
PROTOCOL_NAME,
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
)
aes_ctx = AESGCM(k)
try:
trezor_masked_static_pubkey = aes_ctx.decrypt(
IV_1, encrypted_trezor_static_pubkey, h
)
except Exception as e:
click.echo(
f"Exception of type{type(e)}", err=True
) # TODO how to handle potential exceptions? Q for Matejcik
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
)
aes_ctx = AESGCM(k)
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
h = _sha256_of_two(h, tag_of_empty_string)
# TODO: search for saved credentials
if host_static_privkey is not None and credential is not None:
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
else:
credential = None
zeroes_32 = int.to_bytes(0, 32, "little")
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
temp_host_static_pubkey = curve25519.get_public_key(
temp_host_static_privkey
)
host_static_privkey = temp_host_static_privkey
host_static_pubkey = temp_host_static_pubkey
aes_ctx = AESGCM(k)
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
h = _sha256_of_two(h, encrypted_host_static_pubkey)
ck, k = _hkdf(
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
)
msg_data = self.mapping.encode_without_wire_type(
messages.ThpHandshakeCompletionReqNoisePayload(
host_pairing_credential=credential,
)
)
aes_ctx = AESGCM(k)
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
h = _sha256_of_two(h, encrypted_payload[:-16])
ha_completion_req_header = MessageHeader(
0x12,
self.channel_id,
len(encrypted_host_static_pubkey)
+ len(encrypted_payload)
+ CHECKSUM_LENGTH,
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport,
ha_completion_req_header,
encrypted_host_static_pubkey + encrypted_payload,
)
self.handshake_hash = h
return ck
def _read_handshake_completion_response(self) -> None:
# Read handshake completion response, ignore payload as we do not care about the state
header, _ = self._read_until_valid_crc_check()
if not header.is_handshake_comp_response():
click.echo(
"Received message is not a valid handshake completion response",
err=True,
)
self._send_ack_1()
def _do_pairing(self, helper_debug: DebugLink | None):
self._send_message(messages.ThpPairingRequest())
self._read_message(messages.ButtonRequest)
self._send_message(messages.ButtonAck())
if helper_debug is not None:
helper_debug.press_yes()
self._read_message(messages.ThpPairingRequestApproved)
self._send_message(
messages.ThpSelectMethod(
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
)
)
self._read_message(messages.ThpEndResponse)
self._has_valid_channel = True
def _read_ack(self):
header, payload = self._read_until_valid_crc_check()
if not header.is_ack() or len(payload) > 0:
click.echo("Received message is not a valid ACK", err=True)
def _send_ack_0(self):
LOG.debug("sending ack 0")
header = MessageHeader(0x20, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _send_ack_1(self):
LOG.debug("sending ack 1")
header = MessageHeader(0x28, self.channel_id, 4)
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
def _encrypt_and_write(
self,
session_id: int,
message_type: int,
message_data: bytes,
ctrl_byte: int | None = None,
) -> None:
assert self.key_request is not None
aes_ctx = AESGCM(self.key_request)
if ctrl_byte is None:
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
self.sync_bit_send = 1 - self.sync_bit_send
sid = session_id.to_bytes(1, "big")
msg_type = message_type.to_bytes(2, "big")
data = sid + msg_type + message_data
nonce = _get_iv_from_nonce(self.nonce_request)
self.nonce_request += 1
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
header = MessageHeader(
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
)
thp_io.write_payload_to_wire_and_add_checksum(
self.transport, header, encrypted_message
)
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
header, raw_payload = self._read_until_valid_crc_check()
if control_byte.is_ack(header.ctrl_byte):
# TODO fix this recursion
return self.read_and_decrypt()
if control_byte.is_error(header.ctrl_byte):
# TODO check for different channel
err = _get_error_from_int(raw_payload[0])
raise Exception("Received ThpError: " + err)
if not header.is_encrypted_transport():
click.echo(
"Trying to decrypt not encrypted message! ("
+ hexlify(header.to_bytes_init() + raw_payload).decode()
+ ")",
err=True,
)
if not control_byte.is_ack(header.ctrl_byte):
LOG.debug(
"--> Get sequence bit %d %s %s",
control_byte.get_seq_bit(header.ctrl_byte),
"from control byte",
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
)
if control_byte.get_seq_bit(header.ctrl_byte):
self._send_ack_1()
else:
self._send_ack_0()
aes_ctx = AESGCM(self.key_response)
nonce = _get_iv_from_nonce(self.nonce_response)
self.nonce_response += 1
message = aes_ctx.decrypt(nonce, raw_payload, b"")
session_id = message[0]
message_type = message[1:3]
message_data = message[3:]
return (
session_id,
int.from_bytes(message_type, "big"),
message_data,
)
def _read_until_valid_crc_check(
self,
) -> t.Tuple[MessageHeader, bytes]:
is_valid = False
header, payload, chksum = thp_io.read(self.transport)
while not is_valid:
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
if not is_valid:
click.echo(
"Received a message with an invalid checksum:"
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
err=True,
)
header, payload, chksum = thp_io.read(self.transport)
return header, payload
def _is_valid_channel_allocation_response(
self, header: MessageHeader, payload: bytes, original_nonce: bytes
) -> bool:
if not header.is_channel_allocation_response():
click.echo(
"Received message is not a channel allocation response", err=True
)
return False
if len(payload) < 10:
click.echo("Invalid channel allocation response payload", err=True)
return False
if payload[:8] != original_nonce:
click.echo(
"Invalid channel allocation response payload (nonce mismatch)", err=True
)
return False
return True
def _get_error_from_int(error_code: int) -> str:
# TODO FIXME improve this (ThpErrorType)
if error_code == 1:
return "TRANSPORT BUSY"
if error_code == 2:
return "UNALLOCATED CHANNEL"
if error_code == 3:
return "DECRYPTION FAILED"
if error_code == 4:
return "INVALID DATA"
if error_code == 5:
return "DEVICE LOCKED"
raise Exception("Not Implemented error case")

View File

@ -0,0 +1,93 @@
import struct
from typing import Tuple
from .. import Transport
from ..thp import checksum
from .message_header import MessageHeader
INIT_HEADER_LENGTH = 5
CONT_HEADER_LENGTH = 3
MAX_PAYLOAD_LEN = 60000
MESSAGE_TYPE_LENGTH = 2
CONTINUATION_PACKET = 0x80
def write_payload_to_wire_and_add_checksum(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
data = transport_payload + chksum
write_payload_to_wire(transport, header, data)
def write_payload_to_wire(
transport: Transport, header: MessageHeader, transport_payload: bytes
):
transport.open()
buffer = bytearray(transport_payload)
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
while buffer:
chunk = (
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
)
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
transport.write_chunk(chunk)
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
"""
Reads from the given wire transport.
Returns `Tuple[MessageHeader, bytes, bytes]`:
1. `header` (`MessageHeader`): Header of the message.
2. `data` (`bytes`): Contents of the message (if any).
3. `checksum` (`bytes`): crc32 checksum of the header + data.
"""
buffer = bytearray()
# Read header with first part of message data
header, first_chunk = read_first(transport)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < header.data_length:
buffer.extend(read_next(transport, header.cid))
data_len = header.data_length - checksum.CHECKSUM_LENGTH
msg_data = buffer[:data_len]
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
return (header, msg_data, chksum)
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
chunk = transport.read_chunk()
try:
ctrl_byte, cid, data_length = struct.unpack(
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
)
except Exception:
raise RuntimeError("Cannot parse header")
data = chunk[INIT_HEADER_LENGTH:]
return MessageHeader(ctrl_byte, cid, data_length), data
def read_next(transport: Transport, cid: int) -> bytes:
chunk = transport.read_chunk()
ctrl_byte, read_cid = struct.unpack(
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
)
if ctrl_byte != CONTINUATION_PACKET:
raise RuntimeError("Continuation packet with incorrect control byte")
if read_cid != cid:
raise RuntimeError("Continuation packet for different channel")
return chunk[CONT_HEADER_LENGTH:]

View File

@ -14,14 +14,15 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import socket import socket
import time import time
from typing import TYPE_CHECKING, Iterable, Optional from typing import TYPE_CHECKING, Iterable, Tuple
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TransportException from . import Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING: if TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport): class UdpTransport(Transport):
DEFAULT_HOST = "127.0.0.1" DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324 DEFAULT_PORT = 21324
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
ENABLED: bool = True ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(self, device: Optional[str] = None) -> None: def __init__(
self,
device: str | None = None,
) -> None:
if not device: if not device:
host = UdpTransport.DEFAULT_HOST host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT port = UdpTransport.DEFAULT_PORT
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
devparts = device.split(":") devparts = device.split(":")
host = devparts[0] host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port) self.device: Tuple[str, int] = (host, port)
self.socket: Optional[socket.socket] = None
super().__init__(protocol=ProtocolV1(self)) self.socket: socket.socket | None = None
super().__init__()
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
@classmethod @classmethod
def _try_path(cls, path: str) -> "UdpTransport": def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path) d = cls(path)
try: try:
d.open() d.open()
if d._ping(): if d.ping():
return d return d
else: else:
raise TransportException( raise TransportException(
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]: ) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try: try:
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
else: else:
raise TransportException(f"No UDP device at {path}") raise TransportException(f"No UDP device at {path}")
def wait_until_ready(self, timeout: float = 10) -> None: def get_path(self) -> str:
try: return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
self.open()
start = time.monotonic()
while True:
if self._ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def open(self) -> None: def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.close() self.socket.close()
self.socket = None self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected data length") raise TransportException("Unexpected data length")
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
while True: while True:
try: try:
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk) return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

View File

@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit import atexit
import logging import logging
import sys import sys
import time import time
from typing import Iterable, List, Optional from typing import Iterable, List
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
WEBUSB_CHUNK_SIZE = 64 WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle: class WebUsbTransport(Transport):
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: """
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0 self.handle: usb1.USBDeviceHandle | None = None
self.handle: Optional["usb1.USBDeviceHandle"] = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None: def open(self) -> None:
self.handle = self.device.open() self.handle = self.device.open()
@ -64,6 +121,8 @@ class WebUsbHandle:
self.handle.claimInterface(self.interface) self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e: except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e raise DeviceIsBusy(self.device) from e
except usb1.USBErrorBusy as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None: def close(self) -> None:
if self.handle is not None: if self.handle is not None:
@ -75,6 +134,8 @@ class WebUsbHandle:
self.handle = None self.handle = None
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
if len(chunk) != WEBUSB_CHUNK_SIZE: if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
@ -97,6 +158,8 @@ class WebUsbHandle:
return return
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
endpoint = 0x80 | self.endpoint endpoint = 0x80 | self.endpoint
while True: while True:
@ -117,70 +180,6 @@ class WebUsbHandle:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
def __init__(
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
except usb1.USBErrorPipe:
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
return devices
def find_debug(self) -> "WebUsbTransport": def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True) return WebUsbTransport(self.device, debug=True)