1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-03-12 22:26:08 +00:00

chore(core): adapt trezorlib transports to session based

[no changelog]

Co-authored-by: mmilata <martin@martinmilata.cz>
This commit is contained in:
M1nd3r 2025-02-04 15:21:19 +01:00
parent 5c2d3c65b7
commit 72fe47ba3b
9 changed files with 380 additions and 434 deletions

View File

@ -2,7 +2,7 @@
import binascii
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate()
if len(devices) > 0:

View File

@ -14,24 +14,18 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
from typing import (
TYPE_CHECKING,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
)
import typing as t
from ..exceptions import TrezorException
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from ..models import TrezorModel
T = TypeVar("T", bound="Transport")
T = t.TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__)
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
""".strip()
MessagePayload = Tuple[int, bytes]
MessagePayload = t.Tuple[int, bytes]
class TransportException(TrezorException):
@ -53,73 +47,57 @@ class DeviceIsBusy(TransportException):
class Transport:
"""Raw connection to a Trezor device.
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
or USB-HID connection, or UDP socket of listening emulator(s).
It can also enumerate devices available over this communication link, and return
them as instances.
Transport instance is a thing that:
- can be identified and requested by a string URI-like path
- can open and close sessions, which enclose related operations
- can read and write protobuf messages
You need to implement a new Transport subclass if you invent a new way to connect
a Trezor device to a computer.
"""
PATH_PREFIX: str
ENABLED = False
def __str__(self) -> str:
return self.get_path()
@classmethod
def enumerate(
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["T"]:
raise NotImplementedError
@classmethod
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if device.get_path() == path:
return device
if prefix_search and device.get_path().startswith(path):
return device
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def get_path(self) -> str:
raise NotImplementedError
def begin_session(self) -> None:
raise NotImplementedError
def end_session(self) -> None:
raise NotImplementedError
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
def find_debug(self: "T") -> "T":
raise NotImplementedError
@classmethod
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
def open(self) -> None:
raise NotImplementedError
@classmethod
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
for device in cls.enumerate():
if (
path is None
or device.get_path() == path
or (prefix_search and device.get_path().startswith(path))
):
return device
def close(self) -> None:
raise NotImplementedError
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
def write_chunk(self, chunk: bytes) -> None:
raise NotImplementedError
def read_chunk(self) -> bytes:
raise NotImplementedError
def ping(self) -> bool:
raise NotImplementedError
CHUNK_SIZE: t.ClassVar[int | None]
def all_transports() -> Iterable[Type["Transport"]]:
from .ble import BleTransport
def all_transports() -> t.Iterable[t.Type["Transport"]]:
from .bridge import BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
transports: Tuple[Type["Transport"], ...] = (
transports: t.Tuple[t.Type["Transport"], ...] = (
BridgeTransport,
HidTransport,
UdpTransport,
@ -130,9 +108,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
models: t.Iterable["TrezorModel"] | None = None,
) -> t.Sequence["Transport"]:
devices: t.List["Transport"] = []
for transport in all_transports():
name = transport.__name__
try:
@ -147,9 +125,7 @@ def enumerate_devices(
return devices
def get_transport(
path: Optional[str] = None, prefix_search: bool = False
) -> "Transport":
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
if path is None:
try:
return next(iter(enumerate_devices()))

View File

@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import struct
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
import typing as t
import requests
from ..log import DUMP_PACKETS
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
from . import DeviceIsBusy, Transport, TransportException
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from ..models import TrezorModel
LOG = logging.getLogger(__name__)
@ -45,7 +46,7 @@ class BridgeException(TransportException):
super().__init__(f"trezord: {path} failed with code {status}: {message}")
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
def call_bridge(path: str, data: str | None = None) -> requests.Response:
url = TREZORD_HOST + "/" + path
r = CONNECTION.post(url, data=data)
if r.status_code != 200:
@ -53,10 +54,13 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
return r
def is_legacy_bridge() -> bool:
def get_bridge_version() -> t.Tuple[int, ...]:
config = call_bridge("configure").json()
version_tuple = tuple(map(int, config["version"].split(".")))
return version_tuple < TREZORD_VERSION_MODERN
return tuple(map(int, config["version"].split(".")))
def is_legacy_bridge() -> bool:
return get_bridge_version() < TREZORD_VERSION_MODERN
class BridgeHandle:
@ -84,7 +88,7 @@ class BridgeHandleModern(BridgeHandle):
class BridgeHandleLegacy(BridgeHandle):
def __init__(self, transport: "BridgeTransport") -> None:
super().__init__(transport)
self.request: Optional[str] = None
self.request: str | None = None
def write_buf(self, buf: bytes) -> None:
if self.request is not None:
@ -110,15 +114,15 @@ class BridgeTransport(Transport):
PATH_PREFIX = "bridge"
ENABLED: bool = True
CHUNK_SIZE = None
def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False
self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
) -> None:
if legacy and debug:
raise TransportException("Debugging not supported on legacy Bridge")
self.device = device
self.session: Optional[str] = None
self.session: str | None = device["session"]
self.debug = debug
self.legacy = legacy
@ -135,7 +139,7 @@ class BridgeTransport(Transport):
raise TransportException("Debug device not available")
return BridgeTransport(self.device, self.legacy, debug=True)
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
def _call(self, action: str, data: str | None = None) -> requests.Response:
session = self.session or "null"
uri = action + "/" + str(session)
if self.debug:
@ -144,8 +148,8 @@ class BridgeTransport(Transport):
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BridgeTransport"]:
cls, _models: t.Iterable["TrezorModel"] | None = None
) -> t.Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
return [
@ -154,7 +158,7 @@ class BridgeTransport(Transport):
except Exception:
return []
def begin_session(self) -> None:
def open(self) -> None:
try:
data = self._call("acquire/" + self.device["path"])
except BridgeException as e:
@ -163,18 +167,17 @@ class BridgeTransport(Transport):
raise
self.session = data.json()["session"]
def end_session(self) -> None:
def close(self) -> None:
if not self.session:
return
self._call("release")
self.session = None
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
self.handle.write_buf(header + message_data)
def write_chunk(self, chunk: bytes) -> None:
self.handle.write_buf(chunk)
def read(self) -> MessagePayload:
data = self.handle.read_buf()
headerlen = struct.calcsize(">HL")
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
return msg_type, data[headerlen : headerlen + datalen]
def read_chunk(self) -> bytes:
return self.handle.read_buf()
def ping(self) -> bool:
return self.session is not None

View File

@ -14,15 +14,16 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import sys
import time
from typing import Any, Dict, Iterable, List, Optional
import typing as t
from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import UDEV_RULES_STR, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -35,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = t.Any
class HidHandle:
def __init__(
self, path: bytes, serial: str, probe_hid_version: bool = False
) -> None:
self.path = path
self.serial = serial
class HidTransport(Transport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
self.handle.open_path(self.device_path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,)
@ -62,11 +101,11 @@ class HidHandle:
# and we wouldn't even know.
# So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string()
if serial != self.serial:
if serial != self.device_serial_number:
self.handle.close()
self.handle = None
raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}"
f"Unexpected device {serial} on path {self.device_path.decode()}"
)
self.handle.set_nonblocking(True)
@ -77,7 +116,7 @@ class HidHandle:
def close(self) -> None:
if self.handle is not None:
# reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string()
self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close()
self.handle = None
@ -115,53 +154,6 @@ class HidHandle:
raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
super().__init__(protocol=ProtocolV1(self.handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0

View File

@ -1,166 +0,0 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import struct
from typing import Tuple
from typing_extensions import Protocol as StructuralType
from . import MessagePayload, Transport
REPLEN = 64
V2_FIRST_CHUNK = 0x01
V2_NEXT_CHUNK = 0x02
V2_BEGIN_SESSION = 0x03
V2_END_SESSION = 0x04
LOG = logging.getLogger(__name__)
class Handle(StructuralType):
"""PEP 544 structural type for Handle functionality.
(called a "Protocol" in the proposed PEP, name which is impractical here)
Handle is a "physical" layer for a protocol.
It can open/close a connection and read/write bare data in 64-byte chunks.
Functionally we gain nothing from making this an (abstract) base class for handle
implementations, so this definition is for type hinting purposes only. You can,
but don't have to, inherit from it.
"""
def open(self) -> None: ...
def close(self) -> None: ...
def read_chunk(self) -> bytes: ...
def write_chunk(self, chunk: bytes) -> None: ...
class Protocol:
"""Wire protocol that can communicate with a Trezor device, given a Handle.
A Protocol implements the part of the Transport API that relates to communicating
logical messages over a physical layer. It is a thing that can:
- open and close sessions,
- send and receive protobuf messages,
given the ability to:
- open and close physical connections,
- and send and receive binary chunks.
For now, the class also handles session counting and opening the underlying Handle.
This will probably be removed in the future.
We will need a new Protocol class if we change the way a Trezor device encapsulates
its messages.
"""
def __init__(self, handle: Handle, replen: int = REPLEN) -> None:
self.handle = handle
self.replen = replen
self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling
def begin_session(self) -> None:
if self.session_counter == 0:
self.handle.open()
self.session_counter += 1
def end_session(self) -> None:
self.session_counter = max(self.session_counter - 1, 0)
if self.session_counter == 0:
self.handle.close()
def read(self) -> MessagePayload:
raise NotImplementedError
def write(self, message_type: int, message_data: bytes) -> None:
raise NotImplementedError
class ProtocolBasedTransport(Transport):
"""Transport that implements its communications through a Protocol.
Intended as a base class for implementations that proxy their communication
operations to a Protocol.
"""
def __init__(self, protocol: Protocol) -> None:
self.protocol = protocol
def write(self, message_type: int, message_data: bytes) -> None:
self.protocol.write(message_type, message_data)
def read(self) -> MessagePayload:
return self.protocol.read()
def begin_session(self) -> None:
self.protocol.begin_session()
def end_session(self) -> None:
self.protocol.end_session()
class ProtocolV1(Protocol):
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
Does not understand sessions.
"""
HEADER_LEN = struct.calcsize(">HL")
def write(self, message_type: int, message_data: bytes) -> None:
header = struct.pack(">HL", message_type, len(message_data))
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: self.replen - 1]
chunk = chunk.ljust(self.replen, b"\x00")
self.handle.write_chunk(chunk)
buffer = buffer[self.replen - 1 :]
def read(self) -> MessagePayload:
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_first(self) -> Tuple[int, int, bytes]:
chunk = self.handle.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.handle.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -0,0 +1,26 @@
from __future__ import annotations
import logging
from ... import messages
from ...mapping import ProtobufMapping
from .. import Transport
LOG = logging.getLogger(__name__)
class Channel:
def __init__(
self,
transport: Transport,
mapping: ProtobufMapping,
) -> None:
self.transport = transport
self.mapping = mapping
def get_features(self) -> messages.Features:
raise NotImplementedError()
def update_features(self) -> None:
raise NotImplementedError

View File

@ -0,0 +1,109 @@
from __future__ import annotations
import logging
import struct
import typing as t
from ... import exceptions, messages
from ...log import DUMP_BYTES
from .protocol_and_channel import Channel
LOG = logging.getLogger(__name__)
class ProtocolV1Channel(Channel):
HEADER_LEN = struct.calcsize(">HL")
_features: messages.Features | None = None
def get_features(self) -> messages.Features:
if self._features is None:
self.update_features()
assert self._features is not None
return self._features
def update_features(self) -> None:
self.write(messages.GetFeatures())
resp = self.read()
if not isinstance(resp, messages.Features):
raise exceptions.TrezorException("Unexpected response to GetFeatures")
self._features = resp
def read(self) -> t.Any:
msg_type, msg_bytes = self._read()
LOG.log(
DUMP_BYTES,
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
msg = self.mapping.decode(msg_type, msg_bytes)
LOG.debug(
f"received message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
self.transport.close()
return msg
def write(self, msg: t.Any) -> None:
LOG.debug(
f"sending message: {msg.__class__.__name__}",
extra={"protobuf": msg},
)
msg_type, msg_bytes = self.mapping.encode(msg)
LOG.log(
DUMP_BYTES,
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
)
self._write(msg_type, msg_bytes)
def _write(self, message_type: int, message_data: bytes) -> None:
chunk_size = self.transport.CHUNK_SIZE
header = struct.pack(">HL", message_type, len(message_data))
if chunk_size is None:
self.transport.write_chunk(header + message_data)
return
buffer = bytearray(b"##" + header + message_data)
while buffer:
# Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: chunk_size - 1]
chunk = chunk.ljust(chunk_size, b"\x00")
self.transport.write_chunk(chunk)
buffer = buffer[63:]
def _read(self) -> t.Tuple[int, bytes]:
if self.transport.CHUNK_SIZE is None:
return self.read_chunkless()
buffer = bytearray()
# Read header with first part of message data
msg_type, datalen, first_chunk = self.read_first()
buffer.extend(first_chunk)
# Read the rest of the message
while len(buffer) < datalen:
buffer.extend(self.read_next())
return msg_type, buffer[:datalen]
def read_chunkless(self) -> t.Tuple[int, bytes]:
data = self.transport.read_chunk()
msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN])
return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen]
def read_first(self) -> t.Tuple[int, int, bytes]:
chunk = self.transport.read_chunk()
if chunk[:3] != b"?##":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
try:
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
except Exception:
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
data = chunk[3 + self.HEADER_LEN :]
return msg_type, datalen, data
def read_next(self) -> bytes:
chunk = self.transport.read_chunk()
if chunk[:1] != b"?":
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
return chunk[1:]

View File

@ -14,14 +14,15 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import socket
import time
from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, Iterable, Tuple
from ..log import DUMP_PACKETS
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import Transport, TransportException
if TYPE_CHECKING:
from ..models import TrezorModel
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__)
class UdpTransport(ProtocolBasedTransport):
class UdpTransport(Transport):
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324
PATH_PREFIX = "udp"
ENABLED: bool = True
CHUNK_SIZE = 64
def __init__(self, device: Optional[str] = None) -> None:
def __init__(
self,
device: str | None = None,
) -> None:
if not device:
host = UdpTransport.DEFAULT_HOST
port = UdpTransport.DEFAULT_PORT
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
devparts = device.split(":")
host = devparts[0]
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
self.device = (host, port)
self.socket: Optional[socket.socket] = None
self.device: Tuple[str, int] = (host, port)
super().__init__(protocol=ProtocolV1(self))
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
self.socket: socket.socket | None = None
super().__init__()
@classmethod
def _try_path(cls, path: str) -> "UdpTransport":
d = cls(path)
try:
d.open()
if d._ping():
if d.ping():
return d
else:
raise TransportException(
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
cls, _models: Iterable["TrezorModel"] | None = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try:
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
else:
raise TransportException(f"No UDP device at {path}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self._ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def get_path(self) -> str:
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
def open(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.close()
self.socket = None
def _ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"
def write_chunk(self, chunk: bytes) -> None:
if self.socket is None:
self.open()
assert self.socket is not None
if len(chunk) != 64:
raise TransportException("Unexpected data length")
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
self.socket.sendall(chunk)
def read_chunk(self) -> bytes:
if self.socket is None:
self.open()
assert self.socket is not None
while True:
try:
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk)
def find_debug(self) -> "UdpTransport":
host, port = self.device
return UdpTransport(f"{host}:{port + 1}")
def wait_until_ready(self, timeout: float = 10) -> None:
try:
self.open()
start = time.monotonic()
while True:
if self.ping():
break
elapsed = time.monotonic() - start
if elapsed >= timeout:
raise TransportException("Timed out waiting for connection.")
time.sleep(0.05)
finally:
self.close()
def ping(self) -> bool:
"""Test if the device is listening."""
assert self.socket is not None
resp = None
try:
self.socket.sendall(b"PINGPING")
resp = self.socket.recv(8)
except Exception:
pass
return resp == b"PONGPONG"

View File

@ -14,16 +14,17 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import atexit
import logging
import sys
import time
from typing import Iterable, List, Optional
from typing import Iterable, List
from ..log import DUMP_PACKETS
from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
LOG = logging.getLogger(__name__)
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
WEBUSB_CHUNK_SIZE = 64
class WebUsbHandle:
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
class WebUsbTransport(Transport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
CHUNK_SIZE = 64
def __init__(
self,
device: "usb1.USBDevice",
debug: bool = False,
) -> None:
self.device = device
self.debug = debug
self.interface = DEBUG_INTERFACE if debug else INTERFACE
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
self.count = 0
self.handle: Optional["usb1.USBDeviceHandle"] = None
self.handle: usb1.USBDeviceHandle | None = None
super().__init__()
@classmethod
def enumerate(
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
return devices
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
def open(self) -> None:
self.handle = self.device.open()
@ -77,6 +134,8 @@ class WebUsbHandle:
self.handle = None
def write_chunk(self, chunk: bytes) -> None:
if self.handle is None:
self.open()
assert self.handle is not None
if len(chunk) != WEBUSB_CHUNK_SIZE:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
@ -99,6 +158,8 @@ class WebUsbHandle:
return
def read_chunk(self) -> bytes:
if self.handle is None:
self.open()
assert self.handle is not None
endpoint = 0x80 | self.endpoint
while True:
@ -119,70 +180,6 @@ class WebUsbHandle:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return chunk
class WebUsbTransport(ProtocolBasedTransport):
"""
WebUsbTransport implements transport over WebUSB interface.
"""
PATH_PREFIX = "webusb"
ENABLED = USB_IMPORTED
context = None
def __init__(
self,
device: "usb1.USBDevice",
handle: Optional[WebUsbHandle] = None,
debug: bool = False,
) -> None:
if handle is None:
handle = WebUsbHandle(device, debug)
self.device = device
self.handle = handle
self.debug = debug
super().__init__(protocol=ProtocolV1(handle))
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close)
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue
try:
# workaround for issue #223:
# on certain combinations of Windows USB drivers and libusb versions,
# Trezor is returned twice (possibly because Windows know it as both
# a HID and a WebUSB device), and one of the returned devices is
# non-functional.
dev.getProduct()
devices.append(WebUsbTransport(dev))
except usb1.USBErrorNotSupported:
pass
except usb1.USBErrorPipe:
if usb_reset:
handle = dev.open()
handle.resetDevice()
handle.close()
return devices
def find_debug(self) -> "WebUsbTransport":
# For v1 protocol, find debug USB interface for the same serial number
return WebUsbTransport(self.device, debug=True)