1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-05-06 08:59:15 +00:00

feat(python): add a timeout argument to read() from transport

also take the opportunity to switch to new style typing annotations
syntax

[no changelog]
This commit is contained in:
matejcik 2025-02-21 12:17:49 +01:00 committed by Roman Zeyde
parent 0fb1693ea8
commit e1ce484ba7
7 changed files with 122 additions and 96 deletions

View File

@ -14,17 +14,10 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
from typing import ( from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from ..exceptions import TrezorException from ..exceptions import TrezorException
@ -52,6 +45,10 @@ class DeviceIsBusy(TransportException):
pass pass
class Timeout(TransportException):
pass
class Transport: class Transport:
"""Raw connection to a Trezor device. """Raw connection to a Trezor device.
@ -84,23 +81,23 @@ class Transport:
def end_session(self) -> None: def end_session(self) -> None:
raise NotImplementedError raise NotImplementedError
def read(self) -> MessagePayload: def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None: def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError raise NotImplementedError
def find_debug(self: "T") -> "T": def find_debug(self: T) -> T:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def enumerate( def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None cls: type[T], models: Iterable[TrezorModel] | None = None
) -> Iterable["T"]: ) -> Iterable[T]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
for device in cls.enumerate(): for device in cls.enumerate():
if ( if (
path is None path is None
@ -112,13 +109,13 @@ class Transport:
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def all_transports() -> Iterable[Type["Transport"]]: def all_transports() -> Iterable[type["Transport"]]:
from .bridge import BridgeTransport from .bridge import BridgeTransport
from .hid import HidTransport from .hid import HidTransport
from .udp import UdpTransport from .udp import UdpTransport
from .webusb import WebUsbTransport from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = ( transports: Tuple[type["Transport"], ...] = (
BridgeTransport, BridgeTransport,
HidTransport, HidTransport,
UdpTransport, UdpTransport,
@ -128,9 +125,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices( def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None, models: Iterable[TrezorModel] | None = None,
) -> Sequence["Transport"]: ) -> Sequence[Transport]:
devices: List["Transport"] = [] devices: list[Transport] = []
for transport in all_transports(): for transport in all_transports():
name = transport.__name__ name = transport.__name__
try: try:
@ -145,9 +142,7 @@ def enumerate_devices(
return devices return devices
def get_transport( def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
if path is None: if path is None:
try: try:
return next(iter(enumerate_devices())) return next(iter(enumerate_devices()))

View File

@ -14,11 +14,14 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import struct import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional from typing import TYPE_CHECKING, Any, Iterable
import requests import requests
from typing_extensions import Self
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException from . import DeviceIsBusy, MessagePayload, Transport, TransportException
@ -45,9 +48,11 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}") super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: def call_bridge(
path: str, data: str | None = None, timeout: float | None = None
) -> requests.Response:
url = TREZORD_HOST + "/" + path url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data) r = CONNECTION.post(url, data=data, timeout=timeout)
if r.status_code != 200: if r.status_code != 200:
raise BridgeException(path, r.status_code, r.json()["error"]) raise BridgeException(path, r.status_code, r.json()["error"])
return r return r
@ -63,7 +68,7 @@ class BridgeHandle:
def __init__(self, transport: "BridgeTransport") -> None: def __init__(self, transport: "BridgeTransport") -> None:
self.transport = transport self.transport = transport
def read_buf(self) -> bytes: def read_buf(self, timeout: float | None = None) -> bytes:
raise NotImplementedError raise NotImplementedError
def write_buf(self, buf: bytes) -> None: def write_buf(self, buf: bytes) -> None:
@ -75,8 +80,8 @@ class BridgeHandleModern(BridgeHandle):
LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}") LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}")
self.transport._call("post", data=buf.hex()) self.transport._call("post", data=buf.hex())
def read_buf(self) -> bytes: def read_buf(self, timeout: float | None = None) -> bytes:
data = self.transport._call("read") data = self.transport._call("read", timeout=timeout)
LOG.log(DUMP_PACKETS, f"received message: {data.text}") LOG.log(DUMP_PACKETS, f"received message: {data.text}")
return bytes.fromhex(data.text) return bytes.fromhex(data.text)
@ -84,19 +89,19 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle): class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None: def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport) super().__init__(transport)
self.request: Optional[str] = None self.request: str | None = None
def write_buf(self, buf: bytes) -> None: def write_buf(self, buf: bytes) -> None:
if self.request is not None: if self.request is not None:
raise TransportException("Can't write twice on legacy Bridge") raise TransportException("Can't write twice on legacy Bridge")
self.request = buf.hex() self.request = buf.hex()
def read_buf(self) -> bytes: def read_buf(self, timeout: float | None = None) -> bytes:
if self.request is None: if self.request is None:
raise TransportException("Can't read without write on legacy Bridge") raise TransportException("Can't read without write on legacy Bridge")
try: try:
LOG.log(DUMP_PACKETS, f"calling with message: {self.request}") LOG.log(DUMP_PACKETS, f"calling with message: {self.request}")
data = self.transport._call("call", data=self.request) data = self.transport._call("call", data=self.request, timeout=timeout)
LOG.log(DUMP_PACKETS, f"received response: {data.text}") LOG.log(DUMP_PACKETS, f"received response: {data.text}")
return bytes.fromhex(data.text) return bytes.fromhex(data.text)
finally: finally:
@ -112,13 +117,13 @@ class BridgeTransport(Transport):
ENABLED: bool = True ENABLED: bool = True
def __init__( def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False self, device: dict[str, Any], legacy: bool, debug: bool = False
) -> None: ) -> None:
if legacy and debug: if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge") raise TransportException("Debugging not supported on legacy Bridge")
self.device = device self.device = device
self.session: Optional[str] = None self.session: str | None = None
self.debug = debug self.debug = debug
self.legacy = legacy self.legacy = legacy
@ -130,21 +135,26 @@ class BridgeTransport(Transport):
def get_path(self) -> str: def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path']}" return f"{self.PATH_PREFIX}:{self.device['path']}"
def find_debug(self) -> "BridgeTransport": def find_debug(self) -> Self:
if not self.device.get("debug"): if not self.device.get("debug"):
raise TransportException("Debug device not available") raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True) return self.__class__(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response: def _call(
self,
action: str,
data: str | None = None,
timeout: float | None = None,
) -> requests.Response:
session = self.session or "null" session = self.session or "null"
uri = action + "/" + str(session) uri = action + "/" + str(session)
if self.debug: if self.debug:
uri = "debug/" + uri uri = "debug/" + uri
return call_bridge(uri, data=data) return call_bridge(uri, data=data, timeout=timeout)
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: Iterable[TrezorModel] | None = None
) -> Iterable["BridgeTransport"]: ) -> Iterable["BridgeTransport"]:
try: try:
legacy = is_legacy_bridge() legacy = is_legacy_bridge()
@ -173,8 +183,8 @@ class BridgeTransport(Transport):
header = struct.pack(">HL", message_type, len(message_data)) header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data) self.handle.write_buf(header + message_data)
def read(self) -> MessagePayload: def read(self, timeout: float | None = None) -> MessagePayload:
data = self.handle.read_buf() data = self.handle.read_buf(timeout=timeout)
headerlen = struct.calcsize(">HL") headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen]) msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen] return msg_type, data[headerlen : headerlen + datalen]

View File

@ -14,14 +14,16 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable, List, Optional from typing import Any, Dict, Iterable
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException from . import UDEV_RULES_STR, Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -91,13 +93,16 @@ class HidHandle:
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}") LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
self.handle.write(chunk) self.handle.write(chunk)
def read_chunk(self) -> bytes: def read_chunk(self, timeout: float | None = None) -> bytes:
start = time.time()
while True: while True:
# hidapi seems to return lists of ints instead of bytes # hidapi seems to return lists of ints instead of bytes
chunk = bytes(self.handle.read(64)) chunk = bytes(self.handle.read(64))
if chunk: if chunk:
break break
else: else:
if timeout is not None and time.time() - start > timeout:
raise Timeout(f"Timeout reading HID packet ({timeout}s)")
time.sleep(0.001) time.sleep(0.001)
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}") LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
@ -134,13 +139,13 @@ class HidTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
) -> Iterable["HidTransport"]: ) -> Iterable[HidTransport]:
if models is None: if models is None:
models = {TREZOR_ONE} models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids] usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = [] devices: list[HidTransport] = []
for dev in hid.enumerate(0, 0): for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"]) usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids: if usb_id not in usb_ids:
@ -154,7 +159,7 @@ class HidTransport(ProtocolBasedTransport):
devices.append(HidTransport(dev)) devices.append(HidTransport(dev))
return devices return devices
def find_debug(self) -> "HidTransport": def find_debug(self) -> HidTransport:
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True): for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]: if debug.device["serial_number"] == self.device["serial_number"]:

View File

@ -14,9 +14,10 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import struct import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType from typing_extensions import Protocol as StructuralType
@ -31,6 +32,8 @@ V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
_DEFAULT_READ_TIMEOUT: float | None = None
class Handle(StructuralType): class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality. """PEP 544 structural type for Handle functionality.
@ -48,7 +51,7 @@ class Handle(StructuralType):
def close(self) -> None: ... def close(self) -> None: ...
def read_chunk(self) -> bytes: ... def read_chunk(self, timeout: float | None = None) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ... def write_chunk(self, chunk: bytes) -> None: ...
@ -86,7 +89,7 @@ class Protocol:
if self.session_counter == 0: if self.session_counter == 0:
self.handle.close() self.handle.close()
def read(self) -> MessagePayload: def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None: def write(self, message_type: int, message_data: bytes) -> None:
@ -106,8 +109,8 @@ class ProtocolBasedTransport(Transport):
def write(self, message_type: int, message_data: bytes) -> None: def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data) self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload: def read(self, timeout: float | None = None) -> MessagePayload:
return self.protocol.read() return self.protocol.read(timeout=timeout)
def begin_session(self) -> None: def begin_session(self) -> None:
self.protocol.begin_session() self.protocol.begin_session()
@ -134,20 +137,23 @@ class ProtocolV1(Protocol):
self.handle.write_chunk(chunk) self.handle.write_chunk(chunk)
buffer = buffer[63:] buffer = buffer[63:]
def read(self) -> MessagePayload: def read(self, timeout: float | None = None) -> MessagePayload:
if timeout is None:
timeout = _DEFAULT_READ_TIMEOUT
buffer = bytearray() buffer = bytearray()
# Read header with first part of message data # Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first() msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
buffer.extend(first_chunk) buffer.extend(first_chunk)
# Read the rest of the message # Read the rest of the message
while len(buffer) < datalen: while len(buffer) < datalen:
buffer.extend(self.read_next()) buffer.extend(self.read_next(timeout=timeout))
return msg_type, buffer[:datalen] return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]: def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
chunk = self.handle.read_chunk() chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:3] != b"?##": if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try: try:
@ -158,8 +164,8 @@ class ProtocolV1(Protocol):
data = chunk[3 + self.HEADER_LEN :] data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data return msg_type, datalen, data
def read_next(self) -> bytes: def read_next(self, timeout: float | None = None) -> bytes:
chunk = self.handle.read_chunk() chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:1] != b"?": if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}") raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:] return chunk[1:]

View File

@ -14,13 +14,15 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging import logging
import socket import socket
import time import time
from typing import TYPE_CHECKING, Iterable, Optional from typing import TYPE_CHECKING, Iterable
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TransportException from . import Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING: if TYPE_CHECKING:
@ -38,7 +40,7 @@ class UdpTransport(ProtocolBasedTransport):
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
ENABLED: bool = True ENABLED: bool = True
def __init__(self, device: Optional[str] = None) -> None: def __init__(self, device: str | None = None) -> None:
if not device: if not device:
host = UdpTransport.DEFAULT_HOST host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT port = UdpTransport.DEFAULT_PORT
@ -47,7 +49,7 @@ class UdpTransport(ProtocolBasedTransport):
host = devparts[0] host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port) self.device = (host, port)
self.socket: Optional[socket.socket] = None self.socket: socket.socket | None = None
super().__init__(protocol=ProtocolV1(self)) super().__init__(protocol=ProtocolV1(self))
@ -77,7 +79,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]: ) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try: try:
@ -94,10 +96,8 @@ class UdpTransport(ProtocolBasedTransport):
if not prefix_search: if not prefix_search:
raise raise
if prefix_search: assert prefix_search # otherwise we would have raised above
return super().find_by_path(path, prefix_search) return super().find_by_path(path, prefix_search)
else:
raise TransportException(f"No UDP device at {path}")
def wait_until_ready(self, timeout: float = 10) -> None: def wait_until_ready(self, timeout: float = 10) -> None:
try: try:
@ -108,7 +108,7 @@ class UdpTransport(ProtocolBasedTransport):
break break
elapsed = time.monotonic() - start elapsed = time.monotonic() - start
if elapsed >= timeout: if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.") raise Timeout("Timed out waiting for connection.")
time.sleep(0.05) time.sleep(0.05)
finally: finally:
@ -142,14 +142,16 @@ class UdpTransport(ProtocolBasedTransport):
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}") LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
self.socket.sendall(chunk) self.socket.sendall(chunk)
def read_chunk(self) -> bytes: def read_chunk(self, timeout: float | None = None) -> bytes:
assert self.socket is not None assert self.socket is not None
start = time.time()
while True: while True:
try: try:
chunk = self.socket.recv(64) chunk = self.socket.recv(64)
break break
except socket.timeout: except socket.timeout:
continue if timeout is not None and time.time() - start > timeout:
raise Timeout(f"Timeout reading UDP packet ({timeout}s)")
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}") LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")

View File

@ -14,15 +14,19 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit import atexit
import logging import logging
import sys import sys
import time import time
from typing import Iterable, List, Optional from typing import Iterable
from typing_extensions import Self
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -45,12 +49,12 @@ WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle: class WebUsbHandle:
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
self.device = device self.device = device
self.interface = DEBUG_INTERFACE if debug else INTERFACE self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0 self.count = 0
self.handle: Optional["usb1.USBDeviceHandle"] = None self.handle: usb1.USBDeviceHandle | None = None
def open(self) -> None: def open(self) -> None:
self.handle = self.device.open() self.handle = self.device.open()
@ -96,26 +100,24 @@ class WebUsbHandle:
) )
return return
def read_chunk(self) -> bytes: def read_chunk(self, timeout: float | None = None) -> bytes:
assert self.handle is not None assert self.handle is not None
endpoint = 0x80 | self.endpoint endpoint = 0x80 | self.endpoint
start = time.time()
while True: while True:
try: try:
chunk = self.handle.interruptRead( chunk = self.handle.interruptRead(
endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS
) )
if chunk: LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
break if len(chunk) != WEBUSB_CHUNK_SIZE:
else: raise TransportException(f"Unexpected chunk size: {len(chunk)}")
time.sleep(0.001) return chunk
except usb1.USBErrorTimeout: except usb1.USBErrorTimeout:
pass if timeout is not None and time.time() - start > timeout:
raise Timeout(f"Timeout reading WebUSB packet ({timeout}s)")
except Exception as e: except Exception as e:
raise TransportException(f"USB read failed: {e}") from e raise TransportException(f"USB read failed: {e}") from e
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk
class WebUsbTransport(ProtocolBasedTransport): class WebUsbTransport(ProtocolBasedTransport):
@ -129,8 +131,8 @@ class WebUsbTransport(ProtocolBasedTransport):
def __init__( def __init__(
self, self,
device: "usb1.USBDevice", device: usb1.USBDevice,
handle: Optional[WebUsbHandle] = None, handle: WebUsbHandle | None = None,
debug: bool = False, debug: bool = False,
) -> None: ) -> None:
if handle is None: if handle is None:
@ -147,8 +149,10 @@ class WebUsbTransport(ProtocolBasedTransport):
@classmethod @classmethod
def enumerate( def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False cls,
) -> Iterable["WebUsbTransport"]: models: Iterable[TrezorModel] | None = None,
usb_reset: bool = False,
) -> Iterable[WebUsbTransport]:
if cls.context is None: if cls.context is None:
cls.context = usb1.USBContext() cls.context = usb1.USBContext()
cls.context.open() cls.context.open()
@ -157,7 +161,7 @@ class WebUsbTransport(ProtocolBasedTransport):
if models is None: if models is None:
models = TREZORS models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids] usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = [] devices: list[WebUsbTransport] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True): for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID()) usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids: if usb_id not in usb_ids:
@ -181,12 +185,12 @@ class WebUsbTransport(ProtocolBasedTransport):
handle.close() handle.close()
return devices return devices
def find_debug(self) -> "WebUsbTransport": def find_debug(self) -> Self:
# For v1 protocol, find debug USB interface for the same serial number # For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True) return self.__class__(self.device, debug=True)
def is_vendor_class(dev: "usb1.USBDevice") -> bool: def is_vendor_class(dev: usb1.USBDevice) -> bool:
configurationId = 0 configurationId = 0
altSettingId = 0 altSettingId = 0
return ( return (
@ -195,7 +199,7 @@ def is_vendor_class(dev: "usb1.USBDevice") -> bool:
) )
def dev_to_str(dev: "usb1.USBDevice") -> str: def dev_to_str(dev: usb1.USBDevice) -> str:
return ":".join( return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList() str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
) )

View File

@ -30,7 +30,7 @@ from trezorlib import debuglink, log, models
from trezorlib.debuglink import TrezorClientDebugLink as Client from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device from trezorlib.device import wipe as wipe_device
from trezorlib.transport import enumerate_devices, get_transport from trezorlib.transport import enumerate_devices, get_transport, protocol
# register rewrites before importing from local package # register rewrites before importing from local package
# so that we see details of failed asserts from this module # so that we see details of failed asserts from this module
@ -134,6 +134,10 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
client = emu_fixture.client client = emu_fixture.client
else: else:
interact = os.environ.get("INTERACT") == "1" interact = os.environ.get("INTERACT") == "1"
if not interact:
# prevent tests from getting stuck in case there is an USB packet loss
protocol._DEFAULT_READ_TIMEOUT = 50.0
path = os.environ.get("TREZOR_PATH") path = os.environ.get("TREZOR_PATH")
if path: if path:
client = _client_from_path(request, path, interact) client = _client_from_path(request, path, interact)