diff --git a/core/tools/codegen/get_trezor_keys.py b/core/tools/codegen/get_trezor_keys.py index 31c40fef1f..b511abd807 100755 --- a/core/tools/codegen/get_trezor_keys.py +++ b/core/tools/codegen/get_trezor_keys.py @@ -2,7 +2,7 @@ import binascii from trezorlib.client import TrezorClient -from trezorlib.transport_hid import HidTransport +from trezorlib.transport.hid import HidTransport devices = HidTransport.enumerate() if len(devices) > 0: diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 97a3b740c9..3a2862fc2e 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -14,24 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -53,73 +47,57 @@ class DeviceIsBusy(TransportException): class Transport: - """Raw connection to a Trezor device. - - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. - - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages - - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ - PATH_PREFIX: str - ENABLED = False - def __str__(self) -> str: - return self.get_path() + @classmethod + def enumerate( + cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["T"]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): + + if device.get_path() == path: + return device + + if prefix_search and device.get_path().startswith(path): + return device + + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: - raise NotImplementedError - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - def find_debug(self: "T") -> "T": raise NotImplementedError - @classmethod - def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + def open(self) -> None: raise NotImplementedError - @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device + def close(self) -> None: + raise NotImplementedError - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + def write_chunk(self, chunk: bytes) -> None: + raise NotImplementedError + + def read_chunk(self) -> bytes: + raise NotImplementedError + + def ping(self) -> bool: + raise NotImplementedError + + CHUNK_SIZE: t.ClassVar[int | None] -def all_transports() -> Iterable[Type["Transport"]]: - from .ble import BleTransport +def all_transports() -> t.Iterable[t.Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( + transports: t.Tuple[t.Type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -130,9 +108,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: t.Iterable["TrezorModel"] | None = None, +) -> t.Sequence["Transport"]: + devices: t.List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -147,9 +125,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f70..65c545a768 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +import typing as t import requests from ..log import DUMP_PACKETS -from . import DeviceIsBusy, MessagePayload, Transport, TransportException +from . import DeviceIsBusy, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) @@ -45,7 +46,7 @@ class BridgeException(TransportException): super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge(path: str, data: str | None = None) -> requests.Response: url = TREZORD_HOST + "/" + path r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -53,10 +54,13 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN class BridgeHandle: @@ -84,7 +88,7 @@ class BridgeHandleModern(BridgeHandle): class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: @@ -110,15 +114,15 @@ class BridgeTransport(Transport): PATH_PREFIX = "bridge" ENABLED: bool = True + CHUNK_SIZE = None def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: Optional[str] = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -135,7 +139,7 @@ class BridgeTransport(Transport): raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call(self, action: str, data: str | None = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: @@ -144,8 +148,8 @@ class BridgeTransport(Transport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() return [ @@ -154,7 +158,7 @@ class BridgeTransport(Transport): except Exception: return [] - def begin_session(self) -> None: + def open(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -163,18 +167,17 @@ class BridgeTransport(Transport): raise self.session = data.json()["session"] - def end_session(self) -> None: + def close(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - self.handle.write_buf(header + message_data) + def write_chunk(self, chunk: bytes) -> None: + self.handle.write_buf(chunk) - def read(self) -> MessagePayload: - data = self.handle.read_buf() - headerlen = struct.calcsize(">HL") - msg_type, datalen = struct.unpack(">HL", data[:headerlen]) - return msg_type, data[headerlen : headerlen + datalen] + def read_chunk(self) -> bytes: + return self.handle.read_buf() + + def ping(self) -> bool: + return self.session is not None diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd7..65e2cddf7d 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -14,15 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Transport, TransportException LOG = logging.getLogger(__name__) @@ -35,23 +36,61 @@ except Exception as e: HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED + + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -62,11 +101,11 @@ class HidHandle: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -77,7 +116,7 @@ class HidHandle: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -115,53 +154,6 @@ class HidHandle: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: List["HidTransport"] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> "HidTransport": - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index 82ffefb0a9..0000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,166 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -import logging -import struct -from typing import Tuple - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle, replen: int = REPLEN) -> None: - self.handle = handle - self.replen = replen - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self) -> MessagePayload: - return self.protocol.read() - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: self.replen - 1] - chunk = chunk.ljust(self.replen, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[self.replen - 1 :] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError(f"Cannot parse header: {chunk.hex()}") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") - return chunk[1:] diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py new file mode 100644 index 0000000000..1ce918e893 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport + +LOG = logging.getLogger(__name__) + + +class Channel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + ) -> None: + self.transport = transport + self.mapping = mapping + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py new file mode 100644 index 0000000000..633d500381 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import Channel + +LOG = logging.getLogger(__name__) + + +class ProtocolV1Channel(Channel): + HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + + if chunk_size is None: + self.transport.write_chunk(header + message_data) + return + + buffer = bytearray(b"##" + header + message_data) + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + if self.transport.CHUNK_SIZE is None: + return self.read_chunkless() + + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_chunkless(self) -> t.Tuple[int, bytes]: + data = self.transport.read_chunk() + msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN]) + return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError(f"Cannot parse header: {chunk.hex()}") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") + return chunk[1:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c6..e17d6f4500 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -14,14 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10 LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 - def __init__(self, device: Optional[str] = None) -> None: + def __init__( + self, + device: str | None = None, + ) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport): devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) - self.socket: Optional[socket.socket] = None + self.device: Tuple[str, int] = (host, port) - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) - - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + self.socket: socket.socket | None = None + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport): else: raise TransportException(f"No UDP device at {path}") - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport): self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport): self.socket.sendall(chunk) def read_chunk(self) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None while True: try: @@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport): if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index ce216ec002..872d961960 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException LOG = logging.getLogger(__name__) @@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300 WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" def open(self) -> None: self.handle = self.device.open() @@ -77,6 +134,8 @@ class WebUsbHandle: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -99,6 +158,8 @@ class WebUsbHandle: return def read_chunk(self) -> bytes: + if self.handle is None: + self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -119,70 +180,6 @@ class WebUsbHandle: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return chunk - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True)