1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 22:26:08 +00:00

chore(core): adapt trezorlib transports to session based

[no changelog]

Co-authored-by: mmilata <martin@martinmilata.cz>
This commit is contained in:
M1nd3r 2025-02-04 15:21:19 +01:00
parent 5c2d3c65b7
commit 72fe47ba3b
9 changed files with 380 additions and 434 deletions

View File

@ -2,7 +2,7 @@
import binascii import binascii
from trezorlib.client import TrezorClient from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate() devices = HidTransport.enumerate()
if len(devices) > 0: if len(devices) > 0:

View File

@ -14,24 +14,18 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
from typing import ( import typing as t
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from ..exceptions import TrezorException from ..exceptions import TrezorException
if TYPE_CHECKING: if t.TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
T = TypeVar("T", bound="Transport") T = t.TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip() """.strip()
MessagePayload = Tuple[int, bytes] MessagePayload = t.Tuple[int, bytes]
class TransportException(TrezorException): class TransportException(TrezorException):
@ -53,73 +47,57 @@ class DeviceIsBusy(TransportException):
class Transport: class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX: str PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str: @classmethod
return self.get_path() def enumerate(
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str: def get_path(self) -> str:
raise NotImplementedError raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: "T") -> "T": def find_debug(self: "T") -> "T":
raise NotImplementedError raise NotImplementedError
@classmethod def open(self) -> None:
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
raise NotImplementedError raise NotImplementedError
@classmethod def close(self) -> None:
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": raise NotImplementedError
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes:
raise NotImplementedError
def ping(self) -> bool:
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int | None]
def all_transports() -> Iterable[Type["Transport"]]: def all_transports() -> t.Iterable[t.Type["Transport"]]:
from .ble import BleTransport
from .bridge import BridgeTransport from .bridge import BridgeTransport
from .hid import HidTransport from .hid import HidTransport
from .udp import UdpTransport from .udp import UdpTransport
from .webusb import WebUsbTransport from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = ( transports: t.Tuple[t.Type["Transport"], ...] = (
BridgeTransport, BridgeTransport,
HidTransport, HidTransport,
UdpTransport, UdpTransport,
@ -130,9 +108,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices( def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None, models: t.Iterable["TrezorModel"] | None = None,
) -> Sequence["Transport"]: ) -> t.Sequence["Transport"]:
devices: List["Transport"] = [] devices: t.List["Transport"] = []
for transport in all_transports(): for transport in all_transports():
name = transport.__name__ name = transport.__name__
try: try:
@ -147,9 +125,7 @@ def enumerate_devices(
return devices return devices
def get_transport( def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
if path is None: if path is None:
try: try:
return next(iter(enumerate_devices())) return next(iter(enumerate_devices()))

View File

@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import struct import typing as t
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
import requests import requests
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException from . import DeviceIsBusy, Transport, TransportException
if TYPE_CHECKING: if t.TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -45,7 +46,7 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}") super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: def call_bridge(path: str, data: str | None = None) -> requests.Response:
url = TREZORD_HOST + "/" + path url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data) r = CONNECTION.post(url, data=data)
if r.status_code != 200: if r.status_code != 200:
@ -53,10 +54,13 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
return r return r
def is_legacy_bridge() -> bool: def get_bridge_version() -> t.Tuple[int, ...]:
config = call_bridge("configure").json() config = call_bridge("configure").json()
version_tuple = tuple(map(int, config["version"].split("."))) return tuple(map(int, config["version"].split(".")))
return version_tuple < TREZORD_VERSION_MODERN
def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
class BridgeHandle: class BridgeHandle:
@ -84,7 +88,7 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle): class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None: def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport) super().__init__(transport)
self.request: Optional[str] = None self.request: str | None = None
def write_buf(self, buf: bytes) -> None: def write_buf(self, buf: bytes) -> None:
if self.request is not None: if self.request is not None:
@ -110,15 +114,15 @@ class BridgeTransport(Transport):
PATH_PREFIX = "bridge" PATH_PREFIX = "bridge"
ENABLED: bool = True ENABLED: bool = True
CHUNK_SIZE = None
def __init__( def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
) -> None: ) -> None:
if legacy and debug: if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge") raise TransportException("Debugging not supported on legacy Bridge")
self.device = device self.device = device
self.session: Optional[str] = None self.session: str | None = device["session"]
self.debug = debug self.debug = debug
self.legacy = legacy self.legacy = legacy
@ -135,7 +139,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available") raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True) return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response: def _call(self, action: str, data: str | None = None) -> requests.Response:
session = self.session or "null" session = self.session or "null"
uri = action + "/" + str(session) uri = action + "/" + str(session)
if self.debug: if self.debug:
@ -144,8 +148,8 @@ class BridgeTransport(Transport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: t.Iterable["TrezorModel"] | None = None
) -> Iterable["BridgeTransport"]: ) -> t.Iterable["BridgeTransport"]:
try: try:
legacy = is_legacy_bridge() legacy = is_legacy_bridge()
return [ return [
@ -154,7 +158,7 @@ class BridgeTransport(Transport):
except Exception: except Exception:
return [] return []
def begin_session(self) -> None: def open(self) -> None:
try: try:
data = self._call("acquire/" + self.device["path"]) data = self._call("acquire/" + self.device["path"])
except BridgeException as e: except BridgeException as e:
@ -163,18 +167,17 @@ class BridgeTransport(Transport):
raise raise
self.session = data.json()["session"] self.session = data.json()["session"]
def end_session(self) -> None: def close(self) -> None:
if not self.session: if not self.session:
return return
self._call("release") self._call("release")
self.session = None self.session = None
def write(self, message_type: int, message_data: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(chunk)
self.handle.write_buf(header + message_data)
def read(self) -> MessagePayload: def read_chunk(self) -> bytes:
data = self.handle.read_buf() return self.handle.read_buf()
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen]) def ping(self) -> bool:
return msg_type, data[headerlen : headerlen + datalen] return self.session is not None

View File

@ -14,15 +14,16 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable, List, Optional import typing as t
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException from . import UDEV_RULES_STR, Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -35,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False HID_IMPORTED = False
HidDevice = Dict[str, Any] HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = Any HidDeviceHandle = t.Any
class HidHandle: class HidTransport(Transport):
def __init__( """
self, path: bytes, serial: str, probe_hid_version: bool = False HidTransport implements transport over USB HID interface.
) -> None: """
self.path = path
self.serial = serial PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2 self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None: def open(self) -> None:
self.handle = hid.device() self.handle = hid.device()
try: try:
self.handle.open_path(self.path) self.handle.open_path(self.device_path)
except (IOError, OSError) as e: except (IOError, OSError) as e:
if sys.platform.startswith("linux"): if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,) e.args = e.args + (UDEV_RULES_STR,)
@ -62,11 +101,11 @@ class HidHandle:
# and we wouldn't even know. # and we wouldn't even know.
# So we check that the serial matches what we expect. # So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string() serial = self.handle.get_serial_number_string()
if serial != self.serial: if serial != self.device_serial_number:
self.handle.close() self.handle.close()
self.handle = None self.handle = None
raise TransportException( raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}" f"Unexpected device {serial} on path {self.device_path.decode()}"
) )
self.handle.set_nonblocking(True) self.handle.set_nonblocking(True)
@ -77,7 +116,7 @@ class HidHandle:
def close(self) -> None: def close(self) -> None:
if self.handle is not None: if self.handle is not None:
# reload serial, because device.wipe() can reset it # reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string() self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close() self.handle.close()
self.handle = None self.handle = None
@ -115,53 +154,6 @@ class HidHandle:
raise TransportException("Unknown HID version") raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool: def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0

View File

@ -1,166 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import MessagePayload, Transport
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None: ...
def close(self) -> None: ...
def read_chunk(self) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
def __init__(self, handle: Handle, replen: int = REPLEN) -> None:
self.handle = handle
self.replen = replen
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: self.replen - 1]
chunk = chunk.ljust(self.replen, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[self.replen - 1 :]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -0,0 +1,26 @@
from __future__ import annotations
import logging
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
LOG = logging.getLogger(__name__)
class Channel:
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
) -> None:
self.transport = transport
self.mapping = mapping
def get_features(self) -> messages.Features:
raise NotImplementedError()
def update_features(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,109 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
class ProtocolV1Channel(Channel):
HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features | None = None
def get_features(self) -> messages.Features:
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
self.transport.close()
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
if chunk_size is None:
self.transport.write_chunk(header + message_data)
return
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[63:]
def _read(self) -> t.Tuple[int, bytes]:
if self.transport.CHUNK_SIZE is None:
return self.read_chunkless()
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_chunkless(self) -> t.Tuple[int, bytes]:
data = self.transport.read_chunk()
msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN])
return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen]
def read_first(self) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.transport.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -14,14 +14,15 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import socket import socket
import time import time
from typing import TYPE_CHECKING, Iterable, Optional from typing import TYPE_CHECKING, Iterable, Tuple
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TransportException from . import Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING: if TYPE_CHECKING:
from ..models import TrezorModel from ..models import TrezorModel
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport): class UdpTransport(Transport):
DEFAULT_HOST = "127.0.0.1" DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324 DEFAULT_PORT = 21324
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
ENABLED: bool = True ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(self, device: Optional[str] = None) -> None: def __init__(
self,
device: str | None = None,
) -> None:
if not device: if not device:
host = UdpTransport.DEFAULT_HOST host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT port = UdpTransport.DEFAULT_PORT
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
devparts = device.split(":") devparts = device.split(":")
host = devparts[0] host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port) self.device: Tuple[str, int] = (host, port)
self.socket: Optional[socket.socket] = None
super().__init__(protocol=ProtocolV1(self)) self.socket: socket.socket | None = None
super().__init__()
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
@classmethod @classmethod
def _try_path(cls, path: str) -> "UdpTransport": def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path) d = cls(path)
try: try:
d.open() d.open()
if d._ping(): if d.ping():
return d return d
else: else:
raise TransportException( raise TransportException(
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]: ) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try: try:
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
else: else:
raise TransportException(f"No UDP device at {path}") raise TransportException(f"No UDP device at {path}")
def wait_until_ready(self, timeout: float = 10) -> None: def get_path(self) -> str:
try: return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
self.open()
start = time.monotonic()
while True:
if self._ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def open(self) -> None: def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.close() self.socket.close()
self.socket = None self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected data length") raise TransportException("Unexpected data length")
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.socket is None:
self.open()
assert self.socket is not None assert self.socket is not None
while True: while True:
try: try:
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk) return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

View File

@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit import atexit
import logging import logging
import sys import sys
import time import time
from typing import Iterable, List, Optional from typing import Iterable, List
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
WEBUSB_CHUNK_SIZE = 64 WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle: class WebUsbTransport(Transport):
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: """
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0 self.handle: usb1.USBDeviceHandle | None = None
self.handle: Optional["usb1.USBDeviceHandle"] = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None: def open(self) -> None:
self.handle = self.device.open() self.handle = self.device.open()
@ -77,6 +134,8 @@ class WebUsbHandle:
self.handle = None self.handle = None
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
if len(chunk) != WEBUSB_CHUNK_SIZE: if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
@ -99,6 +158,8 @@ class WebUsbHandle:
return return
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
if self.handle is None:
self.open()
assert self.handle is not None assert self.handle is not None
endpoint = 0x80 | self.endpoint endpoint = 0x80 | self.endpoint
while True: while True:
@ -119,70 +180,6 @@ class WebUsbHandle:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
def __init__(
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
except usb1.USBErrorPipe:
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
return devices
def find_debug(self) -> "WebUsbTransport": def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True) return WebUsbTransport(self.device, debug=True)