mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-05-06 17:09:11 +00:00
feat(python): add a timeout argument to read() from transport
also take the opportunity to switch to new style typing annotations syntax [no changelog]
This commit is contained in:
parent
0fb1693ea8
commit
e1ce484ba7
@ -14,17 +14,10 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
|
||||||
TYPE_CHECKING,
|
|
||||||
Iterable,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Sequence,
|
|
||||||
Tuple,
|
|
||||||
Type,
|
|
||||||
TypeVar,
|
|
||||||
)
|
|
||||||
|
|
||||||
from ..exceptions import TrezorException
|
from ..exceptions import TrezorException
|
||||||
|
|
||||||
@ -52,6 +45,10 @@ class DeviceIsBusy(TransportException):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Timeout(TransportException):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Transport:
|
class Transport:
|
||||||
"""Raw connection to a Trezor device.
|
"""Raw connection to a Trezor device.
|
||||||
|
|
||||||
@ -84,23 +81,23 @@ class Transport:
|
|||||||
def end_session(self) -> None:
|
def end_session(self) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def find_debug(self: "T") -> "T":
|
def find_debug(self: T) -> T:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
cls: type[T], models: Iterable[TrezorModel] | None = None
|
||||||
) -> Iterable["T"]:
|
) -> Iterable[T]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
|
||||||
for device in cls.enumerate():
|
for device in cls.enumerate():
|
||||||
if (
|
if (
|
||||||
path is None
|
path is None
|
||||||
@ -112,13 +109,13 @@ class Transport:
|
|||||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||||
|
|
||||||
|
|
||||||
def all_transports() -> Iterable[Type["Transport"]]:
|
def all_transports() -> Iterable[type["Transport"]]:
|
||||||
from .bridge import BridgeTransport
|
from .bridge import BridgeTransport
|
||||||
from .hid import HidTransport
|
from .hid import HidTransport
|
||||||
from .udp import UdpTransport
|
from .udp import UdpTransport
|
||||||
from .webusb import WebUsbTransport
|
from .webusb import WebUsbTransport
|
||||||
|
|
||||||
transports: Tuple[Type["Transport"], ...] = (
|
transports: Tuple[type["Transport"], ...] = (
|
||||||
BridgeTransport,
|
BridgeTransport,
|
||||||
HidTransport,
|
HidTransport,
|
||||||
UdpTransport,
|
UdpTransport,
|
||||||
@ -128,9 +125,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
|||||||
|
|
||||||
|
|
||||||
def enumerate_devices(
|
def enumerate_devices(
|
||||||
models: Optional[Iterable["TrezorModel"]] = None,
|
models: Iterable[TrezorModel] | None = None,
|
||||||
) -> Sequence["Transport"]:
|
) -> Sequence[Transport]:
|
||||||
devices: List["Transport"] = []
|
devices: list[Transport] = []
|
||||||
for transport in all_transports():
|
for transport in all_transports():
|
||||||
name = transport.__name__
|
name = transport.__name__
|
||||||
try:
|
try:
|
||||||
@ -145,9 +142,7 @@ def enumerate_devices(
|
|||||||
return devices
|
return devices
|
||||||
|
|
||||||
|
|
||||||
def get_transport(
|
def get_transport(path: str | None = None, prefix_search: bool = False) -> Transport:
|
||||||
path: Optional[str] = None, prefix_search: bool = False
|
|
||||||
) -> "Transport":
|
|
||||||
if path is None:
|
if path is None:
|
||||||
try:
|
try:
|
||||||
return next(iter(enumerate_devices()))
|
return next(iter(enumerate_devices()))
|
||||||
|
@ -14,11 +14,14 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
from typing import TYPE_CHECKING, Any, Iterable
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||||
@ -45,9 +48,11 @@ class BridgeException(TransportException):
|
|||||||
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
||||||
|
|
||||||
|
|
||||||
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
def call_bridge(
|
||||||
|
path: str, data: str | None = None, timeout: float | None = None
|
||||||
|
) -> requests.Response:
|
||||||
url = TREZORD_HOST + "/" + path
|
url = TREZORD_HOST + "/" + path
|
||||||
r = CONNECTION.post(url, data=data)
|
r = CONNECTION.post(url, data=data, timeout=timeout)
|
||||||
if r.status_code != 200:
|
if r.status_code != 200:
|
||||||
raise BridgeException(path, r.status_code, r.json()["error"])
|
raise BridgeException(path, r.status_code, r.json()["error"])
|
||||||
return r
|
return r
|
||||||
@ -63,7 +68,7 @@ class BridgeHandle:
|
|||||||
def __init__(self, transport: "BridgeTransport") -> None:
|
def __init__(self, transport: "BridgeTransport") -> None:
|
||||||
self.transport = transport
|
self.transport = transport
|
||||||
|
|
||||||
def read_buf(self) -> bytes:
|
def read_buf(self, timeout: float | None = None) -> bytes:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write_buf(self, buf: bytes) -> None:
|
def write_buf(self, buf: bytes) -> None:
|
||||||
@ -75,8 +80,8 @@ class BridgeHandleModern(BridgeHandle):
|
|||||||
LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}")
|
LOG.log(DUMP_PACKETS, f"sending message: {buf.hex()}")
|
||||||
self.transport._call("post", data=buf.hex())
|
self.transport._call("post", data=buf.hex())
|
||||||
|
|
||||||
def read_buf(self) -> bytes:
|
def read_buf(self, timeout: float | None = None) -> bytes:
|
||||||
data = self.transport._call("read")
|
data = self.transport._call("read", timeout=timeout)
|
||||||
LOG.log(DUMP_PACKETS, f"received message: {data.text}")
|
LOG.log(DUMP_PACKETS, f"received message: {data.text}")
|
||||||
return bytes.fromhex(data.text)
|
return bytes.fromhex(data.text)
|
||||||
|
|
||||||
@ -84,19 +89,19 @@ class BridgeHandleModern(BridgeHandle):
|
|||||||
class BridgeHandleLegacy(BridgeHandle):
|
class BridgeHandleLegacy(BridgeHandle):
|
||||||
def __init__(self, transport: "BridgeTransport") -> None:
|
def __init__(self, transport: "BridgeTransport") -> None:
|
||||||
super().__init__(transport)
|
super().__init__(transport)
|
||||||
self.request: Optional[str] = None
|
self.request: str | None = None
|
||||||
|
|
||||||
def write_buf(self, buf: bytes) -> None:
|
def write_buf(self, buf: bytes) -> None:
|
||||||
if self.request is not None:
|
if self.request is not None:
|
||||||
raise TransportException("Can't write twice on legacy Bridge")
|
raise TransportException("Can't write twice on legacy Bridge")
|
||||||
self.request = buf.hex()
|
self.request = buf.hex()
|
||||||
|
|
||||||
def read_buf(self) -> bytes:
|
def read_buf(self, timeout: float | None = None) -> bytes:
|
||||||
if self.request is None:
|
if self.request is None:
|
||||||
raise TransportException("Can't read without write on legacy Bridge")
|
raise TransportException("Can't read without write on legacy Bridge")
|
||||||
try:
|
try:
|
||||||
LOG.log(DUMP_PACKETS, f"calling with message: {self.request}")
|
LOG.log(DUMP_PACKETS, f"calling with message: {self.request}")
|
||||||
data = self.transport._call("call", data=self.request)
|
data = self.transport._call("call", data=self.request, timeout=timeout)
|
||||||
LOG.log(DUMP_PACKETS, f"received response: {data.text}")
|
LOG.log(DUMP_PACKETS, f"received response: {data.text}")
|
||||||
return bytes.fromhex(data.text)
|
return bytes.fromhex(data.text)
|
||||||
finally:
|
finally:
|
||||||
@ -112,13 +117,13 @@ class BridgeTransport(Transport):
|
|||||||
ENABLED: bool = True
|
ENABLED: bool = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
self, device: dict[str, Any], legacy: bool, debug: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
if legacy and debug:
|
if legacy and debug:
|
||||||
raise TransportException("Debugging not supported on legacy Bridge")
|
raise TransportException("Debugging not supported on legacy Bridge")
|
||||||
|
|
||||||
self.device = device
|
self.device = device
|
||||||
self.session: Optional[str] = None
|
self.session: str | None = None
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.legacy = legacy
|
self.legacy = legacy
|
||||||
|
|
||||||
@ -130,21 +135,26 @@ class BridgeTransport(Transport):
|
|||||||
def get_path(self) -> str:
|
def get_path(self) -> str:
|
||||||
return f"{self.PATH_PREFIX}:{self.device['path']}"
|
return f"{self.PATH_PREFIX}:{self.device['path']}"
|
||||||
|
|
||||||
def find_debug(self) -> "BridgeTransport":
|
def find_debug(self) -> Self:
|
||||||
if not self.device.get("debug"):
|
if not self.device.get("debug"):
|
||||||
raise TransportException("Debug device not available")
|
raise TransportException("Debug device not available")
|
||||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
return self.__class__(self.device, self.legacy, debug=True)
|
||||||
|
|
||||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
def _call(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
data: str | None = None,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> requests.Response:
|
||||||
session = self.session or "null"
|
session = self.session or "null"
|
||||||
uri = action + "/" + str(session)
|
uri = action + "/" + str(session)
|
||||||
if self.debug:
|
if self.debug:
|
||||||
uri = "debug/" + uri
|
uri = "debug/" + uri
|
||||||
return call_bridge(uri, data=data)
|
return call_bridge(uri, data=data, timeout=timeout)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
cls, _models: Iterable[TrezorModel] | None = None
|
||||||
) -> Iterable["BridgeTransport"]:
|
) -> Iterable["BridgeTransport"]:
|
||||||
try:
|
try:
|
||||||
legacy = is_legacy_bridge()
|
legacy = is_legacy_bridge()
|
||||||
@ -173,8 +183,8 @@ class BridgeTransport(Transport):
|
|||||||
header = struct.pack(">HL", message_type, len(message_data))
|
header = struct.pack(">HL", message_type, len(message_data))
|
||||||
self.handle.write_buf(header + message_data)
|
self.handle.write_buf(header + message_data)
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||||
data = self.handle.read_buf()
|
data = self.handle.read_buf(timeout=timeout)
|
||||||
headerlen = struct.calcsize(">HL")
|
headerlen = struct.calcsize(">HL")
|
||||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||||
return msg_type, data[headerlen : headerlen + datalen]
|
return msg_type, data[headerlen : headerlen + datalen]
|
||||||
|
@ -14,14 +14,16 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Iterable, List, Optional
|
from typing import Any, Dict, Iterable
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from ..models import TREZOR_ONE, TrezorModel
|
from ..models import TREZOR_ONE, TrezorModel
|
||||||
from . import UDEV_RULES_STR, TransportException
|
from . import UDEV_RULES_STR, Timeout, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -91,13 +93,16 @@ class HidHandle:
|
|||||||
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
LOG.log(DUMP_PACKETS, f"writing packet: {chunk.hex()}")
|
||||||
self.handle.write(chunk)
|
self.handle.write(chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||||
|
start = time.time()
|
||||||
while True:
|
while True:
|
||||||
# hidapi seems to return lists of ints instead of bytes
|
# hidapi seems to return lists of ints instead of bytes
|
||||||
chunk = bytes(self.handle.read(64))
|
chunk = bytes(self.handle.read(64))
|
||||||
if chunk:
|
if chunk:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
if timeout is not None and time.time() - start > timeout:
|
||||||
|
raise Timeout(f"Timeout reading HID packet ({timeout}s)")
|
||||||
time.sleep(0.001)
|
time.sleep(0.001)
|
||||||
|
|
||||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||||
@ -134,13 +139,13 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
|
||||||
) -> Iterable["HidTransport"]:
|
) -> Iterable[HidTransport]:
|
||||||
if models is None:
|
if models is None:
|
||||||
models = {TREZOR_ONE}
|
models = {TREZOR_ONE}
|
||||||
usb_ids = [id for model in models for id in model.usb_ids]
|
usb_ids = [id for model in models for id in model.usb_ids]
|
||||||
|
|
||||||
devices: List["HidTransport"] = []
|
devices: list[HidTransport] = []
|
||||||
for dev in hid.enumerate(0, 0):
|
for dev in hid.enumerate(0, 0):
|
||||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||||
if usb_id not in usb_ids:
|
if usb_id not in usb_ids:
|
||||||
@ -154,7 +159,7 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
devices.append(HidTransport(dev))
|
devices.append(HidTransport(dev))
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
def find_debug(self) -> "HidTransport":
|
def find_debug(self) -> HidTransport:
|
||||||
# For v1 protocol, find debug USB interface for the same serial number
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
for debug in HidTransport.enumerate(debug=True):
|
for debug in HidTransport.enumerate(debug=True):
|
||||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||||
|
@ -14,9 +14,10 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import struct
|
import struct
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from typing_extensions import Protocol as StructuralType
|
from typing_extensions import Protocol as StructuralType
|
||||||
|
|
||||||
@ -31,6 +32,8 @@ V2_END_SESSION = 0x04
|
|||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_DEFAULT_READ_TIMEOUT: float | None = None
|
||||||
|
|
||||||
|
|
||||||
class Handle(StructuralType):
|
class Handle(StructuralType):
|
||||||
"""PEP 544 structural type for Handle functionality.
|
"""PEP 544 structural type for Handle functionality.
|
||||||
@ -48,7 +51,7 @@ class Handle(StructuralType):
|
|||||||
|
|
||||||
def close(self) -> None: ...
|
def close(self) -> None: ...
|
||||||
|
|
||||||
def read_chunk(self) -> bytes: ...
|
def read_chunk(self, timeout: float | None = None) -> bytes: ...
|
||||||
|
|
||||||
def write_chunk(self, chunk: bytes) -> None: ...
|
def write_chunk(self, chunk: bytes) -> None: ...
|
||||||
|
|
||||||
@ -86,7 +89,7 @@ class Protocol:
|
|||||||
if self.session_counter == 0:
|
if self.session_counter == 0:
|
||||||
self.handle.close()
|
self.handle.close()
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
@ -106,8 +109,8 @@ class ProtocolBasedTransport(Transport):
|
|||||||
def write(self, message_type: int, message_data: bytes) -> None:
|
def write(self, message_type: int, message_data: bytes) -> None:
|
||||||
self.protocol.write(message_type, message_data)
|
self.protocol.write(message_type, message_data)
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||||
return self.protocol.read()
|
return self.protocol.read(timeout=timeout)
|
||||||
|
|
||||||
def begin_session(self) -> None:
|
def begin_session(self) -> None:
|
||||||
self.protocol.begin_session()
|
self.protocol.begin_session()
|
||||||
@ -134,20 +137,23 @@ class ProtocolV1(Protocol):
|
|||||||
self.handle.write_chunk(chunk)
|
self.handle.write_chunk(chunk)
|
||||||
buffer = buffer[63:]
|
buffer = buffer[63:]
|
||||||
|
|
||||||
def read(self) -> MessagePayload:
|
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||||
|
if timeout is None:
|
||||||
|
timeout = _DEFAULT_READ_TIMEOUT
|
||||||
|
|
||||||
buffer = bytearray()
|
buffer = bytearray()
|
||||||
# Read header with first part of message data
|
# Read header with first part of message data
|
||||||
msg_type, datalen, first_chunk = self.read_first()
|
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
|
||||||
buffer.extend(first_chunk)
|
buffer.extend(first_chunk)
|
||||||
|
|
||||||
# Read the rest of the message
|
# Read the rest of the message
|
||||||
while len(buffer) < datalen:
|
while len(buffer) < datalen:
|
||||||
buffer.extend(self.read_next())
|
buffer.extend(self.read_next(timeout=timeout))
|
||||||
|
|
||||||
return msg_type, buffer[:datalen]
|
return msg_type, buffer[:datalen]
|
||||||
|
|
||||||
def read_first(self) -> Tuple[int, int, bytes]:
|
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
|
||||||
chunk = self.handle.read_chunk()
|
chunk = self.handle.read_chunk(timeout=timeout)
|
||||||
if chunk[:3] != b"?##":
|
if chunk[:3] != b"?##":
|
||||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||||
try:
|
try:
|
||||||
@ -158,8 +164,8 @@ class ProtocolV1(Protocol):
|
|||||||
data = chunk[3 + self.HEADER_LEN :]
|
data = chunk[3 + self.HEADER_LEN :]
|
||||||
return msg_type, datalen, data
|
return msg_type, datalen, data
|
||||||
|
|
||||||
def read_next(self) -> bytes:
|
def read_next(self, timeout: float | None = None) -> bytes:
|
||||||
chunk = self.handle.read_chunk()
|
chunk = self.handle.read_chunk(timeout=timeout)
|
||||||
if chunk[:1] != b"?":
|
if chunk[:1] != b"?":
|
||||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||||
return chunk[1:]
|
return chunk[1:]
|
||||||
|
@ -14,13 +14,15 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
import socket
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Iterable, Optional
|
from typing import TYPE_CHECKING, Iterable
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from . import TransportException
|
from . import Timeout, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -38,7 +40,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
PATH_PREFIX = "udp"
|
PATH_PREFIX = "udp"
|
||||||
ENABLED: bool = True
|
ENABLED: bool = True
|
||||||
|
|
||||||
def __init__(self, device: Optional[str] = None) -> None:
|
def __init__(self, device: str | None = None) -> None:
|
||||||
if not device:
|
if not device:
|
||||||
host = UdpTransport.DEFAULT_HOST
|
host = UdpTransport.DEFAULT_HOST
|
||||||
port = UdpTransport.DEFAULT_PORT
|
port = UdpTransport.DEFAULT_PORT
|
||||||
@ -47,7 +49,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
host = devparts[0]
|
host = devparts[0]
|
||||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||||
self.device = (host, port)
|
self.device = (host, port)
|
||||||
self.socket: Optional[socket.socket] = None
|
self.socket: socket.socket | None = None
|
||||||
|
|
||||||
super().__init__(protocol=ProtocolV1(self))
|
super().__init__(protocol=ProtocolV1(self))
|
||||||
|
|
||||||
@ -77,7 +79,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
cls, _models: Iterable["TrezorModel"] | None = None
|
||||||
) -> Iterable["UdpTransport"]:
|
) -> Iterable["UdpTransport"]:
|
||||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||||
try:
|
try:
|
||||||
@ -94,10 +96,8 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
if not prefix_search:
|
if not prefix_search:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if prefix_search:
|
assert prefix_search # otherwise we would have raised above
|
||||||
return super().find_by_path(path, prefix_search)
|
return super().find_by_path(path, prefix_search)
|
||||||
else:
|
|
||||||
raise TransportException(f"No UDP device at {path}")
|
|
||||||
|
|
||||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||||
try:
|
try:
|
||||||
@ -108,7 +108,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
break
|
break
|
||||||
elapsed = time.monotonic() - start
|
elapsed = time.monotonic() - start
|
||||||
if elapsed >= timeout:
|
if elapsed >= timeout:
|
||||||
raise TransportException("Timed out waiting for connection.")
|
raise Timeout("Timed out waiting for connection.")
|
||||||
|
|
||||||
time.sleep(0.05)
|
time.sleep(0.05)
|
||||||
finally:
|
finally:
|
||||||
@ -142,14 +142,16 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
|
LOG.log(DUMP_PACKETS, f"sending packet: {chunk.hex()}")
|
||||||
self.socket.sendall(chunk)
|
self.socket.sendall(chunk)
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||||
assert self.socket is not None
|
assert self.socket is not None
|
||||||
|
start = time.time()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = self.socket.recv(64)
|
chunk = self.socket.recv(64)
|
||||||
break
|
break
|
||||||
except socket.timeout:
|
except socket.timeout:
|
||||||
continue
|
if timeout is not None and time.time() - start > timeout:
|
||||||
|
raise Timeout(f"Timeout reading UDP packet ({timeout}s)")
|
||||||
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
|
LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
|
||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
|
@ -14,15 +14,19 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, List, Optional
|
from typing import Iterable
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
from ..log import DUMP_PACKETS
|
from ..log import DUMP_PACKETS
|
||||||
from ..models import TREZORS, TrezorModel
|
from ..models import TREZORS, TrezorModel
|
||||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@ -45,12 +49,12 @@ WEBUSB_CHUNK_SIZE = 64
|
|||||||
|
|
||||||
|
|
||||||
class WebUsbHandle:
|
class WebUsbHandle:
|
||||||
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
|
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
|
||||||
self.device = device
|
self.device = device
|
||||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||||
self.count = 0
|
self.count = 0
|
||||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
self.handle: usb1.USBDeviceHandle | None = None
|
||||||
|
|
||||||
def open(self) -> None:
|
def open(self) -> None:
|
||||||
self.handle = self.device.open()
|
self.handle = self.device.open()
|
||||||
@ -96,26 +100,24 @@ class WebUsbHandle:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
def read_chunk(self) -> bytes:
|
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||||
assert self.handle is not None
|
assert self.handle is not None
|
||||||
endpoint = 0x80 | self.endpoint
|
endpoint = 0x80 | self.endpoint
|
||||||
|
start = time.time()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
chunk = self.handle.interruptRead(
|
chunk = self.handle.interruptRead(
|
||||||
endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS
|
endpoint, WEBUSB_CHUNK_SIZE, USB_COMM_TIMEOUT_MS
|
||||||
)
|
)
|
||||||
if chunk:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
time.sleep(0.001)
|
|
||||||
except usb1.USBErrorTimeout:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
raise TransportException(f"USB read failed: {e}") from e
|
|
||||||
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
LOG.log(DUMP_PACKETS, f"read packet: {chunk.hex()}")
|
||||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||||
return chunk
|
return chunk
|
||||||
|
except usb1.USBErrorTimeout:
|
||||||
|
if timeout is not None and time.time() - start > timeout:
|
||||||
|
raise Timeout(f"Timeout reading WebUSB packet ({timeout}s)")
|
||||||
|
except Exception as e:
|
||||||
|
raise TransportException(f"USB read failed: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
class WebUsbTransport(ProtocolBasedTransport):
|
class WebUsbTransport(ProtocolBasedTransport):
|
||||||
@ -129,8 +131,8 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: "usb1.USBDevice",
|
device: usb1.USBDevice,
|
||||||
handle: Optional[WebUsbHandle] = None,
|
handle: WebUsbHandle | None = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if handle is None:
|
if handle is None:
|
||||||
@ -147,8 +149,10 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def enumerate(
|
def enumerate(
|
||||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
cls,
|
||||||
) -> Iterable["WebUsbTransport"]:
|
models: Iterable[TrezorModel] | None = None,
|
||||||
|
usb_reset: bool = False,
|
||||||
|
) -> Iterable[WebUsbTransport]:
|
||||||
if cls.context is None:
|
if cls.context is None:
|
||||||
cls.context = usb1.USBContext()
|
cls.context = usb1.USBContext()
|
||||||
cls.context.open()
|
cls.context.open()
|
||||||
@ -157,7 +161,7 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
if models is None:
|
if models is None:
|
||||||
models = TREZORS
|
models = TREZORS
|
||||||
usb_ids = [id for model in models for id in model.usb_ids]
|
usb_ids = [id for model in models for id in model.usb_ids]
|
||||||
devices: List["WebUsbTransport"] = []
|
devices: list[WebUsbTransport] = []
|
||||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||||
if usb_id not in usb_ids:
|
if usb_id not in usb_ids:
|
||||||
@ -181,12 +185,12 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
handle.close()
|
handle.close()
|
||||||
return devices
|
return devices
|
||||||
|
|
||||||
def find_debug(self) -> "WebUsbTransport":
|
def find_debug(self) -> Self:
|
||||||
# For v1 protocol, find debug USB interface for the same serial number
|
# For v1 protocol, find debug USB interface for the same serial number
|
||||||
return WebUsbTransport(self.device, debug=True)
|
return self.__class__(self.device, debug=True)
|
||||||
|
|
||||||
|
|
||||||
def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
def is_vendor_class(dev: usb1.USBDevice) -> bool:
|
||||||
configurationId = 0
|
configurationId = 0
|
||||||
altSettingId = 0
|
altSettingId = 0
|
||||||
return (
|
return (
|
||||||
@ -195,7 +199,7 @@ def is_vendor_class(dev: "usb1.USBDevice") -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def dev_to_str(dev: "usb1.USBDevice") -> str:
|
def dev_to_str(dev: usb1.USBDevice) -> str:
|
||||||
return ":".join(
|
return ":".join(
|
||||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||||
)
|
)
|
||||||
|
@ -30,7 +30,7 @@ from trezorlib import debuglink, log, models
|
|||||||
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
from trezorlib.debuglink import TrezorClientDebugLink as Client
|
||||||
from trezorlib.device import apply_settings
|
from trezorlib.device import apply_settings
|
||||||
from trezorlib.device import wipe as wipe_device
|
from trezorlib.device import wipe as wipe_device
|
||||||
from trezorlib.transport import enumerate_devices, get_transport
|
from trezorlib.transport import enumerate_devices, get_transport, protocol
|
||||||
|
|
||||||
# register rewrites before importing from local package
|
# register rewrites before importing from local package
|
||||||
# so that we see details of failed asserts from this module
|
# so that we see details of failed asserts from this module
|
||||||
@ -134,6 +134,10 @@ def _raw_client(request: pytest.FixtureRequest) -> Client:
|
|||||||
client = emu_fixture.client
|
client = emu_fixture.client
|
||||||
else:
|
else:
|
||||||
interact = os.environ.get("INTERACT") == "1"
|
interact = os.environ.get("INTERACT") == "1"
|
||||||
|
if not interact:
|
||||||
|
# prevent tests from getting stuck in case there is an USB packet loss
|
||||||
|
protocol._DEFAULT_READ_TIMEOUT = 50.0
|
||||||
|
|
||||||
path = os.environ.get("TREZOR_PATH")
|
path = os.environ.get("TREZOR_PATH")
|
||||||
if path:
|
if path:
|
||||||
client = _client_from_path(request, path, interact)
|
client = _client_from_path(request, path, interact)
|
||||||
|
Loading…
Reference in New Issue
Block a user