From e1ce484ba749fcfe7416fd9f15dfadb2442b8fb5 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 21 Feb 2025 12:17:49 +0100 Subject: [PATCH] feat(python): add a timeout argument to read() from transport also take the opportunity to switch to new style typing annotations syntax [no changelog] --- python/src/trezorlib/transport/__init__.py | 41 ++++++++---------- python/src/trezorlib/transport/bridge.py | 46 ++++++++++++-------- python/src/trezorlib/transport/hid.py | 19 +++++--- python/src/trezorlib/transport/protocol.py | 30 +++++++------ python/src/trezorlib/transport/udp.py | 26 +++++------ python/src/trezorlib/transport/webusb.py | 50 ++++++++++++---------- tests/conftest.py | 6 ++- 7 files changed, 122 insertions(+), 96 deletions(-) diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b7..8aa759b173 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -14,17 +14,10 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar from ..exceptions import TrezorException @@ -52,6 +45,10 @@ class DeviceIsBusy(TransportException): pass +class Timeout(TransportException): + pass + + class Transport: """Raw connection to a Trezor device. @@ -84,23 +81,23 @@ class Transport: def end_session(self) -> None: raise NotImplementedError - def read(self) -> MessagePayload: + def read(self, timeout: float | None = None) -> MessagePayload: raise NotImplementedError def write(self, message_type: int, message_data: bytes) -> None: raise NotImplementedError - def find_debug(self: "T") -> "T": + def find_debug(self: T) -> T: raise NotImplementedError @classmethod def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + cls: type[T], models: Iterable[TrezorModel] | None = None + ) -> Iterable[T]: raise NotImplementedError @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": + def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T: for device in cls.enumerate(): if ( path is None @@ -112,13 +109,13 @@ class Transport: raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") -def all_transports() -> Iterable[Type["Transport"]]: +def all_transports() -> Iterable[type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( + transports: Tuple[type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -128,9 +125,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: Iterable[TrezorModel] | None = None, +) -> Sequence[Transport]: + devices: list[Transport] = [] for transport in all_transports(): name = transport.__name__ try: @@ -145,9 +142,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport: if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f70..ae7c79e903 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -14,11 +14,14 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable import requests +from typing_extensions import Self from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException @@ -45,9 +48,11 @@ class BridgeException(TransportException): super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge( + path: str, data: str | None = None, timeout: float | None = None +) -> requests.Response: url = TREZORD_HOST + "/" + path - r = CONNECTION.post(url, data=data) + r = CONNECTION.post(url, data=data, timeout=timeout) if r.status_code != 200: raise BridgeException(path, r.status_code, r.json()["error"]) return r @@ -63,7 +68,7 @@ class BridgeHandle: def __init__(self, transport: "BridgeTransport") -> None: self.transport = transport - def read_buf(self) -> bytes: + def read_buf(self, timeout: float | None = None) -> bytes: raise NotImplementedError def write_buf(self, buf: bytes) -> None: @@ -75,8 +80,8 @@ class BridgeHandleModern(BridgeHandle): LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}") self.transport._call("post", data=buf.hex()) - def read_buf(self) -> bytes: - data = self.transport._call("read") + def read_buf(self, timeout: float | None = None) -> bytes: + data = self.transport._call("read", timeout=timeout) LOG.log(DUMP_PACKETS, f"received message: {data.text}") return bytes.fromhex(data.text) @@ -84,19 +89,19 @@ class BridgeHandleModern(BridgeHandle): class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: raise TransportException("Can't write twice on legacy Bridge") self.request = buf.hex() - def read_buf(self) -> bytes: + def read_buf(self, timeout: float | None = None) -> bytes: if self.request is None: raise TransportException("Can't read without write on legacy Bridge") try: LOG.log(DUMP_PACKETS, f"calling with message: {self.request}") - data = self.transport._call("call", data=self.request) + data = self.transport._call("call", data=self.request, timeout=timeout) LOG.log(DUMP_PACKETS, f"received response: {data.text}") return bytes.fromhex(data.text) finally: @@ -112,13 +117,13 @@ class BridgeTransport(Transport): ENABLED: bool = True def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: dict[str, Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") self.device = device - self.session: Optional[str] = None + self.session: str | None = None self.debug = debug self.legacy = legacy @@ -130,21 +135,26 @@ class BridgeTransport(Transport): def get_path(self) -> str: return f"{self.PATH_PREFIX}:{self.device['path']}" - def find_debug(self) -> "BridgeTransport": + def find_debug(self) -> Self: if not self.device.get("debug"): raise TransportException("Debug device not available") - return BridgeTransport(self.device, self.legacy, debug=True) + return self.__class__(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call( + self, + action: str, + data: str | None = None, + timeout: float | None = None, + ) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: uri = "debug/" + uri - return call_bridge(uri, data=data) + return call_bridge(uri, data=data, timeout=timeout) @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable[TrezorModel] | None = None ) -> Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() @@ -173,8 +183,8 @@ class BridgeTransport(Transport): header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(header + message_data) - def read(self) -> MessagePayload: - data = self.handle.read_buf() + def read(self, timeout: float | None = None) -> MessagePayload: + data = self.handle.read_buf(timeout=timeout) headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) return msg_type, data[headerlen : headerlen + datalen] diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd7..61cf8bafd9 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -14,14 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Dict, Iterable from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException +from . import UDEV_RULES_STR, Timeout, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -91,13 +93,16 @@ class HidHandle: LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}") self.handle.write(chunk) - def read_chunk(self) -> bytes: + def read_chunk(self, timeout: float | None = None) -> bytes: + start = time.time() while True: # hidapi seems to return lists of ints instead of bytes chunk = bytes(self.handle.read(64)) if chunk: break else: + if timeout is not None and time.time() - start > timeout: + raise Timeout(f"Timeout reading HID packet ({timeout}s)") time.sleep(0.001) LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}") @@ -134,13 +139,13 @@ class HidTransport(ProtocolBasedTransport): @classmethod def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: + cls, models: Iterable[TrezorModel] | None = None, debug: bool = False + ) -> Iterable[HidTransport]: if models is None: models = {TREZOR_ONE} usb_ids = [id for model in models for id in model.usb_ids] - devices: List["HidTransport"] = [] + devices: list[HidTransport] = [] for dev in hid.enumerate(0, 0): usb_id = (dev["vendor_id"], dev["product_id"]) if usb_id not in usb_ids: @@ -154,7 +159,7 @@ class HidTransport(ProtocolBasedTransport): devices.append(HidTransport(dev)) return devices - def find_debug(self) -> "HidTransport": + def find_debug(self) -> HidTransport: # For v1 protocol, find debug USB interface for the same serial number for debug in HidTransport.enumerate(debug=True): if debug.device["serial_number"] == self.device["serial_number"]: diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index b91e5d58a4..4a2c129ec6 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -14,9 +14,10 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import Tuple from typing_extensions import Protocol as StructuralType @@ -31,6 +32,8 @@ V2_END_SESSION = 0x04 LOG = logging.getLogger(__name__) +_DEFAULT_READ_TIMEOUT: float | None = None + class Handle(StructuralType): """PEP 544 structural type for Handle functionality. @@ -48,7 +51,7 @@ class Handle(StructuralType): def close(self) -> None: ... - def read_chunk(self) -> bytes: ... + def read_chunk(self, timeout: float | None = None) -> bytes: ... def write_chunk(self, chunk: bytes) -> None: ... @@ -86,7 +89,7 @@ class Protocol: if self.session_counter == 0: self.handle.close() - def read(self) -> MessagePayload: + def read(self, timeout: float | None = None) -> MessagePayload: raise NotImplementedError def write(self, message_type: int, message_data: bytes) -> None: @@ -106,8 +109,8 @@ class ProtocolBasedTransport(Transport): def write(self, message_type: int, message_data: bytes) -> None: self.protocol.write(message_type, message_data) - def read(self) -> MessagePayload: - return self.protocol.read() + def read(self, timeout: float | None = None) -> MessagePayload: + return self.protocol.read(timeout=timeout) def begin_session(self) -> None: self.protocol.begin_session() @@ -134,20 +137,23 @@ class ProtocolV1(Protocol): self.handle.write_chunk(chunk) buffer = buffer[63:] - def read(self) -> MessagePayload: + def read(self, timeout: float | None = None) -> MessagePayload: + if timeout is None: + timeout = _DEFAULT_READ_TIMEOUT + buffer = bytearray() # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() + msg_type, datalen, first_chunk = self.read_first(timeout=timeout) buffer.extend(first_chunk) # Read the rest of the message while len(buffer) < datalen: - buffer.extend(self.read_next()) + buffer.extend(self.read_next(timeout=timeout)) return msg_type, buffer[:datalen] - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() + def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]: + chunk = self.handle.read_chunk(timeout=timeout) if chunk[:3] != b"?##": raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") try: @@ -158,8 +164,8 @@ class ProtocolV1(Protocol): data = chunk[3 + self.HEADER_LEN :] return msg_type, datalen, data - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() + def read_next(self, timeout: float | None = None) -> bytes: + chunk = self.handle.read_chunk(timeout=timeout) if chunk[:1] != b"?": raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") return chunk[1:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c6..a4652b6fbf 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -14,13 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable from ..log import DUMP_PACKETS -from . import TransportException +from . import Timeout, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 if TYPE_CHECKING: @@ -38,7 +40,7 @@ class UdpTransport(ProtocolBasedTransport): PATH_PREFIX = "udp" ENABLED: bool = True - def __init__(self, device: Optional[str] = None) -> None: + def __init__(self, device: str | None = None) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -47,7 +49,7 @@ class UdpTransport(ProtocolBasedTransport): host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT self.device = (host, port) - self.socket: Optional[socket.socket] = None + self.socket: socket.socket | None = None super().__init__(protocol=ProtocolV1(self)) @@ -77,7 +79,7 @@ class UdpTransport(ProtocolBasedTransport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -94,10 +96,8 @@ class UdpTransport(ProtocolBasedTransport): if not prefix_search: raise - if prefix_search: - return super().find_by_path(path, prefix_search) - else: - raise TransportException(f"No UDP device at {path}") + assert prefix_search # otherwise we would have raised above + return super().find_by_path(path, prefix_search) def wait_until_ready(self, timeout: float = 10) -> None: try: @@ -108,7 +108,7 @@ class UdpTransport(ProtocolBasedTransport): break elapsed = time.monotonic() - start if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") + raise Timeout("Timed out waiting for connection.") time.sleep(0.05) finally: @@ -142,14 +142,16 @@ class UdpTransport(ProtocolBasedTransport): LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}") self.socket.sendall(chunk) - def read_chunk(self) -> bytes: + def read_chunk(self, timeout: float | None = None) -> bytes: assert self.socket is not None + start = time.time() while True: try: chunk = self.socket.recv(64) break except socket.timeout: - continue + if timeout is not None and time.time() - start > timeout: + raise Timeout(f"Timeout reading UDP packet ({timeout}s)") LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}") if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a..6fa7868c0e 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -14,15 +14,19 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable + +from typing_extensions import Self from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException +from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -45,12 +49,12 @@ WEBUSB_CHUNK_SIZE = 64 class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: + def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None: self.device = device self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None def open(self) -> None: self.handle = self.device.open() @@ -96,26 +100,24 @@ class WebUsbHandle: ) return - def read_chunk(self) -> bytes: + def read_chunk(self, timeout: float | None = None) -> bytes: assert self.handle is not None endpoint = 0x80 | self.endpoint + start = time.time() while True: try: chunk = self.handle.interruptRead( endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS ) - if chunk: - break - else: - time.sleep(0.001) + LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}") + if len(chunk) != WEBUSB_CHUNK_SIZE: + raise TransportException(f"Unexpected chunk size: {len(chunk)}") + return chunk except usb1.USBErrorTimeout: - pass + if timeout is not None and time.time() - start > timeout: + raise Timeout(f"Timeout reading WebUSB packet ({timeout}s)") except Exception as e: raise TransportException(f"USB read failed: {e}") from e - LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}") - if len(chunk) != WEBUSB_CHUNK_SIZE: - raise TransportException(f"Unexpected chunk size: {len(chunk)}") - return chunk class WebUsbTransport(ProtocolBasedTransport): @@ -129,8 +131,8 @@ class WebUsbTransport(ProtocolBasedTransport): def __init__( self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, + device: usb1.USBDevice, + handle: WebUsbHandle | None = None, debug: bool = False, ) -> None: if handle is None: @@ -147,8 +149,10 @@ class WebUsbTransport(ProtocolBasedTransport): @classmethod def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: + cls, + models: Iterable[TrezorModel] | None = None, + usb_reset: bool = False, + ) -> Iterable[WebUsbTransport]: if cls.context is None: cls.context = usb1.USBContext() cls.context.open() @@ -157,7 +161,7 @@ class WebUsbTransport(ProtocolBasedTransport): if models is None: models = TREZORS usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] + devices: list[WebUsbTransport] = [] for dev in cls.context.getDeviceIterator(skip_on_error=True): usb_id = (dev.getVendorID(), dev.getProductID()) if usb_id not in usb_ids: @@ -181,12 +185,12 @@ class WebUsbTransport(ProtocolBasedTransport): handle.close() return devices - def find_debug(self) -> "WebUsbTransport": + def find_debug(self) -> Self: # For v1 protocol, find debug USB interface for the same serial number - return WebUsbTransport(self.device, debug=True) + return self.__class__(self.device, debug=True) -def is_vendor_class(dev: "usb1.USBDevice") -> bool: +def is_vendor_class(dev: usb1.USBDevice) -> bool: configurationId = 0 altSettingId = 0 return ( @@ -195,7 +199,7 @@ def is_vendor_class(dev: "usb1.USBDevice") -> bool: ) -def dev_to_str(dev: "usb1.USBDevice") -> str: +def dev_to_str(dev: usb1.USBDevice) -> str: return ":".join( str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList() ) diff --git a/tests/conftest.py b/tests/conftest.py index 004a3b7b66..b2b1dc4eea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ from trezorlib import debuglink, log, models from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.device import apply_settings from trezorlib.device import wipe as wipe_device -from trezorlib.transport import enumerate_devices, get_transport +from trezorlib.transport import enumerate_devices, get_transport, protocol # register rewrites before importing from local package # so that we see details of failed asserts from this module @@ -134,6 +134,10 @@ def _raw_client(request: pytest.FixtureRequest) -> Client: client = emu_fixture.client else: interact = os.environ.get("INTERACT") == "1" + if not interact: + # prevent tests from getting stuck in case there is an USB packet loss + protocol._DEFAULT_READ_TIMEOUT = 50.0 + path = os.environ.get("TREZOR_PATH") if path: client = _client_from_path(request, path, interact)