1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-21 10:35:43 +00:00

feat(python): add a timeout argument to read() from transport

also take the opportunity to switch to new style typing annotations
syntax

[no changelog]
This commit is contained in:
matejcik 2025-02-21 12:17:49 +01:00 committed by Roman Zeyde
parent 0fb1693ea8
commit e1ce484ba7
7 changed files with 122 additions and 96 deletions

View File

@ -14,17 +14,10 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
from ..exceptions import TrezorException
@ -52,6 +45,10 @@ class DeviceIsBusy(TransportException):
pass
class Timeout(TransportException):
pass
class Transport:
"""Raw connection to a Trezor device.
@ -84,23 +81,23 @@ class Transport:
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessagePayload:
def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: "T") -> "T":
def find_debug(self: T) -> T:
raise NotImplementedError
@classmethod
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
cls: type[T], models: Iterable[TrezorModel] | None = None
) -> Iterable[T]:
raise NotImplementedError
@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
for device in cls.enumerate():
if (
path is None
@ -112,13 +109,13 @@ class Transport:
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def all_transports() -> Iterable[Type["Transport"]]:
def all_transports() -> Iterable[type["Transport"]]:
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = (
transports: Tuple[type["Transport"], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
@ -128,9 +125,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
models: Iterable[TrezorModel] | None = None,
) -> Sequence[Transport]:
devices: list[Transport] = []
for transport in all_transports():
name = transport.__name__
try:
@ -145,9 +142,7 @@ def enumerate_devices(
return devices
def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
if path is None:
try:
return next(iter(enumerate_devices()))

View File

@ -14,11 +14,14 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Any, Iterable
import requests
from typing_extensions import Self
from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
@ -45,9 +48,11 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
def call_bridge(
path: str, data: str | None = None, timeout: float | None = None
) -> requests.Response:
url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data)
r = CONNECTION.post(url, data=data, timeout=timeout)
if r.status_code != 200:
raise BridgeException(path, r.status_code, r.json()["error"])
return r
@ -63,7 +68,7 @@ class BridgeHandle:
def __init__(self, transport: "BridgeTransport") -> None:
self.transport = transport
def read_buf(self) -> bytes:
def read_buf(self, timeout: float | None = None) -> bytes:
raise NotImplementedError
def write_buf(self, buf: bytes) -> None:
@ -75,8 +80,8 @@ class BridgeHandleModern(BridgeHandle):
LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}")
self.transport._call("post", data=buf.hex())
def read_buf(self) -> bytes:
data = self.transport._call("read")
def read_buf(self, timeout: float | None = None) -> bytes:
data = self.transport._call("read", timeout=timeout)
LOG.log(DUMP_PACKETS, f"received message: {data.text}")
return bytes.fromhex(data.text)
@ -84,19 +89,19 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport)
self.request: Optional[str] = None
self.request: str | None = None
def write_buf(self, buf: bytes) -> None:
if self.request is not None:
raise TransportException("Can't write twice on legacy Bridge")
self.request = buf.hex()
def read_buf(self) -> bytes:
def read_buf(self, timeout: float | None = None) -> bytes:
if self.request is None:
raise TransportException("Can't read without write on legacy Bridge")
try:
LOG.log(DUMP_PACKETS, f"calling with message: {self.request}")
data = self.transport._call("call", data=self.request)
data = self.transport._call("call", data=self.request, timeout=timeout)
LOG.log(DUMP_PACKETS, f"received response: {data.text}")
return bytes.fromhex(data.text)
finally:
@ -112,13 +117,13 @@ class BridgeTransport(Transport):
ENABLED: bool = True
def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False
self, device: dict[str, Any], legacy: bool, debug: bool = False
) -> None:
if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge")
self.device = device
self.session: Optional[str] = None
self.session: str | None = None
self.debug = debug
self.legacy = legacy
@ -130,21 +135,26 @@ class BridgeTransport(Transport):
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path']}"
def find_debug(self) -> "BridgeTransport":
def find_debug(self) -> Self:
if not self.device.get("debug"):
raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True)
return self.__class__(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
def _call(
self,
action: str,
data: str | None = None,
timeout: float | None = None,
) -> requests.Response:
session = self.session or "null"
uri = action + "/" + str(session)
if self.debug:
uri = "debug/" + uri
return call_bridge(uri, data=data)
return call_bridge(uri, data=data, timeout=timeout)
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
cls, _models: Iterable[TrezorModel] | None = None
) -> Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
@ -173,8 +183,8 @@ class BridgeTransport(Transport):
header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data)
def read(self) -> MessagePayload:
data = self.handle.read_buf()
def read(self, timeout: float | None = None) -> MessagePayload:
data = self.handle.read_buf(timeout=timeout)
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen]

View File

@ -14,14 +14,16 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import sys
import time
from typing import Any, Dict, Iterable, List, Optional
from typing import Any, Dict, Iterable
from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from . import UDEV_RULES_STR, Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
@ -91,13 +93,16 @@ class HidHandle:
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
self.handle.write(chunk)
def read_chunk(self) -> bytes:
def read_chunk(self, timeout: float | None = None) -> bytes:
start = time.time()
while True:
# hidapi seems to return lists of ints instead of bytes
chunk = bytes(self.handle.read(64))
if chunk:
break
else:
if timeout is not None and time.time() - start > timeout:
raise Timeout(f"Timeout reading HID packet ({timeout}s)")
time.sleep(0.001)
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
@ -134,13 +139,13 @@ class HidTransport(ProtocolBasedTransport):
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
) -> Iterable[HidTransport]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
devices: list[HidTransport] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
@ -154,7 +159,7 @@ class HidTransport(ProtocolBasedTransport):
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
def find_debug(self) -> HidTransport:
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:

View File

@ -14,9 +14,10 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType
@ -31,6 +32,8 @@ V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
_DEFAULT_READ_TIMEOUT: float | None = None
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
@ -48,7 +51,7 @@ class Handle(StructuralType):
def close(self) -> None: ...
def read_chunk(self) -> bytes: ...
def read_chunk(self, timeout: float | None = None) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
@ -86,7 +89,7 @@ class Protocol:
if self.session_counter == 0:
self.handle.close()
def read(self) -> MessagePayload:
def read(self, timeout: float | None = None) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
@ -106,8 +109,8 @@ class ProtocolBasedTransport(Transport):
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload:
return self.protocol.read()
def read(self, timeout: float | None = None) -> MessagePayload:
return self.protocol.read(timeout=timeout)
def begin_session(self) -> None:
self.protocol.begin_session()
@ -134,20 +137,23 @@ class ProtocolV1(Protocol):
self.handle.write_chunk(chunk)
buffer = buffer[63:]
def read(self) -> MessagePayload:
def read(self, timeout: float | None = None) -> MessagePayload:
if timeout is None:
timeout = _DEFAULT_READ_TIMEOUT
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
buffer.extend(self.read_next(timeout=timeout))
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
@ -158,8 +164,8 @@ class ProtocolV1(Protocol):
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
def read_next(self, timeout: float | None = None) -> bytes:
chunk = self.handle.read_chunk(timeout=timeout)
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -14,13 +14,15 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import socket
import time
from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, Iterable
from ..log import DUMP_PACKETS
from . import TransportException
from . import Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING:
@ -38,7 +40,7 @@ class UdpTransport(ProtocolBasedTransport):
PATH_PREFIX = "udp"
ENABLED: bool = True
def __init__(self, device: Optional[str] = None) -> None:
def __init__(self, device: str | None = None) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
@ -47,7 +49,7 @@ class UdpTransport(ProtocolBasedTransport):
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port)
self.socket: Optional[socket.socket] = None
self.socket: socket.socket | None = None
super().__init__(protocol=ProtocolV1(self))
@ -77,7 +79,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try:
@ -94,10 +96,8 @@ class UdpTransport(ProtocolBasedTransport):
if not prefix_search:
raise
if prefix_search:
return super().find_by_path(path, prefix_search)
else:
raise TransportException(f"No UDP device at {path}")
assert prefix_search # otherwise we would have raised above
return super().find_by_path(path, prefix_search)
def wait_until_ready(self, timeout: float = 10) -> None:
try:
@ -108,7 +108,7 @@ class UdpTransport(ProtocolBasedTransport):
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
raise Timeout("Timed out waiting for connection.")
time.sleep(0.05)
finally:
@ -142,14 +142,16 @@ class UdpTransport(ProtocolBasedTransport):
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
self.socket.sendall(chunk)
def read_chunk(self) -> bytes:
def read_chunk(self, timeout: float | None = None) -> bytes:
assert self.socket is not None
start = time.time()
while True:
try:
chunk = self.socket.recv(64)
break
except socket.timeout:
continue
if timeout is not None and time.time() - start > timeout:
raise Timeout(f"Timeout reading UDP packet ({timeout}s)")
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")

View File

@ -14,15 +14,19 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit
import logging
import sys
import time
from typing import Iterable, List, Optional
from typing import Iterable
from typing_extensions import Self
from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
@ -45,12 +49,12 @@ WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle:
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
self.device = device
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle: Optional["usb1.USBDeviceHandle"] = None
self.handle: usb1.USBDeviceHandle | None = None
def open(self) -> None:
self.handle = self.device.open()
@ -96,26 +100,24 @@ class WebUsbHandle:
)
return
def read_chunk(self) -> bytes:
def read_chunk(self, timeout: float | None = None) -> bytes:
assert self.handle is not None
endpoint = 0x80 | self.endpoint
start = time.time()
while True:
try:
chunk = self.handle.interruptRead(
endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS
)
if chunk:
break
else:
time.sleep(0.001)
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk
except usb1.USBErrorTimeout:
pass
if timeout is not None and time.time() - start > timeout:
raise Timeout(f"Timeout reading WebUSB packet ({timeout}s)")
except Exception as e:
raise TransportException(f"USB read failed: {e}") from e
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk
class WebUsbTransport(ProtocolBasedTransport):
@ -129,8 +131,8 @@ class WebUsbTransport(ProtocolBasedTransport):
def __init__(
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
device: usb1.USBDevice,
handle: WebUsbHandle | None = None,
debug: bool = False,
) -> None:
if handle is None:
@ -147,8 +149,10 @@ class WebUsbTransport(ProtocolBasedTransport):
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
cls,
models: Iterable[TrezorModel] | None = None,
usb_reset: bool = False,
) -> Iterable[WebUsbTransport]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
@ -157,7 +161,7 @@ class WebUsbTransport(ProtocolBasedTransport):
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
devices: list[WebUsbTransport] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
@ -181,12 +185,12 @@ class WebUsbTransport(ProtocolBasedTransport):
handle.close()
return devices
def find_debug(self) -> "WebUsbTransport":
def find_debug(self) -> Self:
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)
return self.__class__(self.device, debug=True)
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
def is_vendor_class(dev: usb1.USBDevice) -> bool:
configurationId = 0
altSettingId = 0
return (
@ -195,7 +199,7 @@ def is_vendor_class(dev: "usb1.USBDevice") -> bool:
)
def dev_to_str(dev: "usb1.USBDevice") -> str:
def dev_to_str(dev: usb1.USBDevice) -> str:
return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
)

View File

@ -30,7 +30,7 @@ from trezorlib import debuglink, log, models
from trezorlib.debuglink import TrezorClientDebugLink as Client
from trezorlib.device import apply_settings
from trezorlib.device import wipe as wipe_device
from trezorlib.transport import enumerate_devices, get_transport
from trezorlib.transport import enumerate_devices, get_transport, protocol
# register rewrites before importing from local package
# so that we see details of failed asserts from this module
@ -134,6 +134,10 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
client = emu_fixture.client
else:
interact = os.environ.get("INTERACT") == "1"
if not interact:
# prevent tests from getting stuck in case there is an USB packet loss
protocol._DEFAULT_READ_TIMEOUT = 50.0
path = os.environ.get("TREZOR_PATH")
if path:
client = _client_from_path(request, path, interact)