mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-21 10:35:43 +00:00
feat(python): add a timeout argument to read() from transport
also take the opportunity to switch to new style typing annotations syntax [no changelog]
This commit is contained in:
parent
0fb1693ea8
commit
e1ce484ba7
@ -14,17 +14,10 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
@ -52,6 +45,10 @@ class DeviceIsBusy(TransportException):
|
||||
pass
|
||||
|
||||
|
||||
class Timeout(TransportException):
|
||||
pass
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
@ -84,23 +81,23 @@ class Transport:
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def find_debug(self: "T") -> "T":
|
||||
def find_debug(self: T) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["T"]:
|
||||
cls: type[T], models: Iterable[TrezorModel] | None = None
|
||||
) -> Iterable[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
@ -112,13 +109,13 @@ class Transport:
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
|
||||
def all_transports() -> Iterable[Type["Transport"]]:
|
||||
def all_transports() -> Iterable[type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[Type["Transport"], ...] = (
|
||||
transports: Tuple[type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
@ -128,9 +125,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Optional[Iterable["TrezorModel"]] = None,
|
||||
) -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
models: Iterable[TrezorModel] | None = None,
|
||||
) -> Sequence[Transport]:
|
||||
devices: list[Transport] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
@ -145,9 +142,7 @@ def enumerate_devices(
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(
|
||||
path: Optional[str] = None, prefix_search: bool = False
|
||||
) -> "Transport":
|
||||
def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices()))
|
||||
|
@ -14,11 +14,14 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
@ -45,9 +48,11 @@ class BridgeException(TransportException):
|
||||
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
||||
|
||||
|
||||
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
def call_bridge(
|
||||
path: str, data: str | None = None, timeout: float | None = None
|
||||
) -> requests.Response:
|
||||
url = TREZORD_HOST + "/" + path
|
||||
r = CONNECTION.post(url, data=data)
|
||||
r = CONNECTION.post(url, data=data, timeout=timeout)
|
||||
if r.status_code != 200:
|
||||
raise BridgeException(path, r.status_code, r.json()["error"])
|
||||
return r
|
||||
@ -63,7 +68,7 @@ class BridgeHandle:
|
||||
def __init__(self, transport: "BridgeTransport") -> None:
|
||||
self.transport = transport
|
||||
|
||||
def read_buf(self) -> bytes:
|
||||
def read_buf(self, timeout: float | None = None) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def write_buf(self, buf: bytes) -> None:
|
||||
@ -75,8 +80,8 @@ class BridgeHandleModern(BridgeHandle):
|
||||
LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}")
|
||||
self.transport._call("post", data=buf.hex())
|
||||
|
||||
def read_buf(self) -> bytes:
|
||||
data = self.transport._call("read")
|
||||
def read_buf(self, timeout: float | None = None) -> bytes:
|
||||
data = self.transport._call("read", timeout=timeout)
|
||||
LOG.log(DUMP_PACKETS, f"received message: {data.text}")
|
||||
return bytes.fromhex(data.text)
|
||||
|
||||
@ -84,19 +89,19 @@ class BridgeHandleModern(BridgeHandle):
|
||||
class BridgeHandleLegacy(BridgeHandle):
|
||||
def __init__(self, transport: "BridgeTransport") -> None:
|
||||
super().__init__(transport)
|
||||
self.request: Optional[str] = None
|
||||
self.request: str | None = None
|
||||
|
||||
def write_buf(self, buf: bytes) -> None:
|
||||
if self.request is not None:
|
||||
raise TransportException("Can't write twice on legacy Bridge")
|
||||
self.request = buf.hex()
|
||||
|
||||
def read_buf(self) -> bytes:
|
||||
def read_buf(self, timeout: float | None = None) -> bytes:
|
||||
if self.request is None:
|
||||
raise TransportException("Can't read without write on legacy Bridge")
|
||||
try:
|
||||
LOG.log(DUMP_PACKETS, f"calling with message: {self.request}")
|
||||
data = self.transport._call("call", data=self.request)
|
||||
data = self.transport._call("call", data=self.request, timeout=timeout)
|
||||
LOG.log(DUMP_PACKETS, f"received response: {data.text}")
|
||||
return bytes.fromhex(data.text)
|
||||
finally:
|
||||
@ -112,13 +117,13 @@ class BridgeTransport(Transport):
|
||||
ENABLED: bool = True
|
||||
|
||||
def __init__(
|
||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
||||
self, device: dict[str, Any], legacy: bool, debug: bool = False
|
||||
) -> None:
|
||||
if legacy and debug:
|
||||
raise TransportException("Debugging not supported on legacy Bridge")
|
||||
|
||||
self.device = device
|
||||
self.session: Optional[str] = None
|
||||
self.session: str | None = None
|
||||
self.debug = debug
|
||||
self.legacy = legacy
|
||||
|
||||
@ -130,21 +135,26 @@ class BridgeTransport(Transport):
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path']}"
|
||||
|
||||
def find_debug(self) -> "BridgeTransport":
|
||||
def find_debug(self) -> Self:
|
||||
if not self.device.get("debug"):
|
||||
raise TransportException("Debug device not available")
|
||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||
return self.__class__(self.device, self.legacy, debug=True)
|
||||
|
||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
||||
def _call(
|
||||
self,
|
||||
action: str,
|
||||
data: str | None = None,
|
||||
timeout: float | None = None,
|
||||
) -> requests.Response:
|
||||
session = self.session or "null"
|
||||
uri = action + "/" + str(session)
|
||||
if self.debug:
|
||||
uri = "debug/" + uri
|
||||
return call_bridge(uri, data=data)
|
||||
return call_bridge(uri, data=data, timeout=timeout)
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
cls, _models: Iterable[TrezorModel] | None = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
@ -173,8 +183,8 @@ class BridgeTransport(Transport):
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
data = self.handle.read_buf()
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
data = self.handle.read_buf(timeout=timeout)
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
|
@ -14,14 +14,16 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from . import UDEV_RULES_STR, Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -91,13 +93,16 @@ class HidHandle:
|
||||
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
||||
self.handle.write(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
start = time.time()
|
||||
while True:
|
||||
# hidapi seems to return lists of ints instead of bytes
|
||||
chunk = bytes(self.handle.read(64))
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
if timeout is not None and time.time() - start > timeout:
|
||||
raise Timeout(f"Timeout reading HID packet ({timeout}s)")
|
||||
time.sleep(0.001)
|
||||
|
||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||
@ -134,13 +139,13 @@ class HidTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
||||
) -> Iterable["HidTransport"]:
|
||||
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
|
||||
) -> Iterable[HidTransport]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: List["HidTransport"] = []
|
||||
devices: list[HidTransport] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
@ -154,7 +159,7 @@ class HidTransport(ProtocolBasedTransport):
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
def find_debug(self) -> HidTransport:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
|
@ -14,9 +14,10 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
@ -31,6 +32,8 @@ V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_READ_TIMEOUT: float | None = None
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
@ -48,7 +51,7 @@ class Handle(StructuralType):
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def read_chunk(self) -> bytes: ...
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
@ -86,7 +89,7 @@ class Protocol:
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
@ -106,8 +109,8 @@ class ProtocolBasedTransport(Transport):
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
return self.protocol.read()
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
return self.protocol.read(timeout=timeout)
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
@ -134,20 +137,23 @@ class ProtocolV1(Protocol):
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
if timeout is None:
|
||||
timeout = _DEFAULT_READ_TIMEOUT
|
||||
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
buffer.extend(self.read_next(timeout=timeout))
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk(timeout=timeout)
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
@ -158,8 +164,8 @@ class ProtocolV1(Protocol):
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
def read_next(self, timeout: float | None = None) -> bytes:
|
||||
chunk = self.handle.read_chunk(timeout=timeout)
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
||||
|
@ -14,13 +14,15 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from . import Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -38,7 +40,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
|
||||
def __init__(self, device: Optional[str] = None) -> None:
|
||||
def __init__(self, device: str | None = None) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
@ -47,7 +49,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device = (host, port)
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.socket: socket.socket | None = None
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
@ -77,7 +79,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
cls, _models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
@ -94,10 +96,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if not prefix_search:
|
||||
raise
|
||||
|
||||
if prefix_search:
|
||||
return super().find_by_path(path, prefix_search)
|
||||
else:
|
||||
raise TransportException(f"No UDP device at {path}")
|
||||
assert prefix_search # otherwise we would have raised above
|
||||
return super().find_by_path(path, prefix_search)
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
@ -108,7 +108,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
raise Timeout("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
@ -142,14 +142,16 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
assert self.socket is not None
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
chunk = self.socket.recv(64)
|
||||
break
|
||||
except socket.timeout:
|
||||
continue
|
||||
if timeout is not None and time.time() - start > timeout:
|
||||
raise Timeout(f"Timeout reading UDP packet ({timeout}s)")
|
||||
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
|
@ -14,15 +14,19 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -45,12 +49,12 @@ WEBUSB_CHUNK_SIZE = 64
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
|
||||
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
|
||||
self.device = device
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
@ -96,26 +100,24 @@ class WebUsbHandle:
|
||||
)
|
||||
return
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
start = time.time()
|
||||
while True:
|
||||
try:
|
||||
chunk = self.handle.interruptRead(
|
||||
endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS
|
||||
)
|
||||
if chunk:
|
||||
break
|
||||
else:
|
||||
time.sleep(0.001)
|
||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
except usb1.USBErrorTimeout:
|
||||
pass
|
||||
if timeout is not None and time.time() - start > timeout:
|
||||
raise Timeout(f"Timeout reading WebUSB packet ({timeout}s)")
|
||||
except Exception as e:
|
||||
raise TransportException(f"USB read failed: {e}") from e
|
||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
|
||||
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
@ -129,8 +131,8 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
handle: Optional[WebUsbHandle] = None,
|
||||
device: usb1.USBDevice,
|
||||
handle: WebUsbHandle | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
@ -147,8 +149,10 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
cls,
|
||||
models: Iterable[TrezorModel] | None = None,
|
||||
usb_reset: bool = False,
|
||||
) -> Iterable[WebUsbTransport]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
@ -157,7 +161,7 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
devices: list[WebUsbTransport] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
@ -181,12 +185,12 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
handle.close()
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
def find_debug(self) -> Self:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
return self.__class__(self.device, debug=True)
|
||||
|
||||
|
||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||
def is_vendor_class(dev: usb1.USBDevice) -> bool:
|
||||
configurationId = 0
|
||||
altSettingId = 0
|
||||
return (
|
||||
@ -195,7 +199,7 @@ def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
||||
)
|
||||
|
||||
|
||||
def dev_to_str(dev: "usb1.USBDevice") -> str:
|
||||
def dev_to_str(dev: usb1.USBDevice) -> str:
|
||||
return ":".join(
|
||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||
)
|
||||
|
@ -30,7 +30,7 @@ from trezorlib import debuglink, log, models
|
||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||
from trezorlib.device import apply_settings
|
||||
from trezorlib.device import wipe as wipe_device
|
||||
from trezorlib.transport import enumerate_devices, get_transport
|
||||
from trezorlib.transport import enumerate_devices, get_transport, protocol
|
||||
|
||||
# register rewrites before importing from local package
|
||||
# so that we see details of failed asserts from this module
|
||||
@ -134,6 +134,10 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
|
||||
client = emu_fixture.client
|
||||
else:
|
||||
interact = os.environ.get("INTERACT") == "1"
|
||||
if not interact:
|
||||
# prevent tests from getting stuck in case there is an USB packet loss
|
||||
protocol._DEFAULT_READ_TIMEOUT = 50.0
|
||||
|
||||
path = os.environ.get("TREZOR_PATH")
|
||||
if path:
|
||||
client = _client_from_path(request, path, interact)
|
||||
|
Loading…
Reference in New Issue
Block a user