mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-21 02:26:10 +00:00
chore(core): adapt trezorlib transports to session based
[no changelog] Co-authored-by: mmilata <martin@martinmilata.cz>
This commit is contained in:
parent
d0cc62dfb0
commit
10ba97a903
@ -2,7 +2,7 @@
|
||||
|
||||
import binascii
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.transport.hid import HidTransport
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
if len(devices) > 0:
|
||||
|
@ -17,14 +17,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Iterable, Sequence, Tuple, TypeVar
|
||||
import typing as t
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="Transport")
|
||||
T = t.TypeVar("T", bound="Transport")
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -34,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
""".strip()
|
||||
|
||||
|
||||
MessagePayload = Tuple[int, bytes]
|
||||
MessagePayload = t.Tuple[int, bytes]
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
@ -50,72 +51,57 @@ class Timeout(TransportException):
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
||||
It can also enumerate devices available over this communication link, and return
|
||||
them as instances.
|
||||
|
||||
Transport instance is a thing that:
|
||||
- can be identified and requested by a string URI-like path
|
||||
- can open and close sessions, which enclose related operations
|
||||
- can read and write protobuf messages
|
||||
|
||||
You need to implement a new Transport subclass if you invent a new way to connect
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: t.Type[T], models: t.Iterable[TrezorModel] | None = None
|
||||
) -> t.Iterable[T]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: t.Type[T], path: str, prefix_search: bool = False) -> T:
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def begin_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def find_debug(self: T) -> T:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: type[T], models: Iterable[TrezorModel] | None = None
|
||||
) -> Iterable[T]:
|
||||
def open(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: type[T], path: str, prefix_search: bool = False) -> T:
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
or device.get_path() == path
|
||||
or (prefix_search and device.get_path().startswith(path))
|
||||
):
|
||||
return device
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def ping(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int | None]
|
||||
|
||||
|
||||
def all_transports() -> Iterable[type["Transport"]]:
|
||||
def all_transports() -> t.Iterable[t.Type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[type["Transport"], ...] = (
|
||||
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
@ -125,9 +111,9 @@ def all_transports() -> Iterable[type["Transport"]]:
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Iterable[TrezorModel] | None = None,
|
||||
) -> Sequence[Transport]:
|
||||
devices: list[Transport] = []
|
||||
models: t.Iterable[TrezorModel] | None = None,
|
||||
) -> t.Sequence[Transport]:
|
||||
devices: t.List[Transport] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
|
@ -17,16 +17,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Iterable
|
||||
import typing as t
|
||||
|
||||
import requests
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
from . import DeviceIsBusy, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -58,10 +57,13 @@ def call_bridge(
|
||||
return r
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
def get_bridge_version() -> t.Tuple[int, ...]:
|
||||
config = call_bridge("configure").json()
|
||||
version_tuple = tuple(map(int, config["version"].split(".")))
|
||||
return version_tuple < TREZORD_VERSION_MODERN
|
||||
return tuple(map(int, config["version"].split(".")))
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||
|
||||
|
||||
class BridgeHandle:
|
||||
@ -115,15 +117,15 @@ class BridgeTransport(Transport):
|
||||
|
||||
PATH_PREFIX = "bridge"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = None
|
||||
|
||||
def __init__(
|
||||
self, device: dict[str, Any], legacy: bool, debug: bool = False
|
||||
self, device: dict[str, t.Any], legacy: bool, debug: bool = False
|
||||
) -> None:
|
||||
if legacy and debug:
|
||||
raise TransportException("Debugging not supported on legacy Bridge")
|
||||
|
||||
self.device = device
|
||||
self.session: str | None = None
|
||||
self.session: str | None = device["session"]
|
||||
self.debug = debug
|
||||
self.legacy = legacy
|
||||
|
||||
@ -154,8 +156,8 @@ class BridgeTransport(Transport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Iterable[TrezorModel] | None = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
cls, _models: t.Iterable[TrezorModel] | None = None
|
||||
) -> t.Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
return [
|
||||
@ -164,7 +166,7 @@ class BridgeTransport(Transport):
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def begin_session(self) -> None:
|
||||
def open(self) -> None:
|
||||
try:
|
||||
data = self._call("acquire/" + self.device["path"])
|
||||
except BridgeException as e:
|
||||
@ -173,18 +175,17 @@ class BridgeTransport(Transport):
|
||||
raise
|
||||
self.session = data.json()["session"]
|
||||
|
||||
def end_session(self) -> None:
|
||||
def close(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
self._call("release")
|
||||
self.session = None
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
self.handle.write_buf(chunk)
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
data = self.handle.read_buf(timeout=timeout)
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
return self.handle.read_buf(timeout=timeout)
|
||||
|
||||
def ping(self) -> bool:
|
||||
return self.session is not None
|
||||
|
@ -19,12 +19,11 @@ from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable
|
||||
import typing as t
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, Timeout, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -37,23 +36,61 @@ except Exception as e:
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
HidDevice = t.Dict[str, t.Any]
|
||||
HidDeviceHandle = t.Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(
|
||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.serial = serial
|
||||
class HidTransport(Transport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||
self.device = device
|
||||
self.device_path = device["path"]
|
||||
self.device_serial_number = device["serial_number"]
|
||||
self.handle: HidDeviceHandle = None
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||
) -> t.Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: t.List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
self.handle.open_path(self.device_path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (UDEV_RULES_STR,)
|
||||
@ -64,11 +101,11 @@ class HidHandle:
|
||||
# and we wouldn't even know.
|
||||
# So we check that the serial matches what we expect.
|
||||
serial = self.handle.get_serial_number_string()
|
||||
if serial != self.serial:
|
||||
if serial != self.device_serial_number:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
raise TransportException(
|
||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
||||
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||
)
|
||||
|
||||
self.handle.set_nonblocking(True)
|
||||
@ -79,7 +116,7 @@ class HidHandle:
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
# reload serial, because device.wipe() can reset it
|
||||
self.serial = self.handle.get_serial_number_string()
|
||||
self.device_serial_number = self.handle.get_serial_number_string()
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
@ -120,53 +157,6 @@ class HidHandle:
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self.handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable[TrezorModel] | None = None, debug: bool = False
|
||||
) -> Iterable[HidTransport]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: list[HidTransport] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> HidTransport:
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||
|
||||
|
@ -1,179 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from . import MessagePayload, Timeout, Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
V2_NEXT_CHUNK = 0x02
|
||||
V2_BEGIN_SESSION = 0x03
|
||||
V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_READ_TIMEOUT: float | None = None
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
||||
|
||||
Handle is a "physical" layer for a protocol.
|
||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
||||
|
||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
||||
implementations, so this definition is for type hinting purposes only. You can,
|
||||
but don't have to, inherit from it.
|
||||
"""
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
||||
|
||||
A Protocol implements the part of the Transport API that relates to communicating
|
||||
logical messages over a physical layer. It is a thing that can:
|
||||
- open and close sessions,
|
||||
- send and receive protobuf messages,
|
||||
given the ability to:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
||||
its messages.
|
||||
"""
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
|
||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
try:
|
||||
# Drop queued responses to old requests
|
||||
while True:
|
||||
msg = self.handle.read_chunk(timeout=0.1)
|
||||
LOG.warning("ignored: %s", msg)
|
||||
except Timeout:
|
||||
pass
|
||||
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolBasedTransport(Transport):
|
||||
"""Transport that implements its communications through a Protocol.
|
||||
|
||||
Intended as a base class for implementations that proxy their communication
|
||||
operations to a Protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
return self.protocol.read(timeout=timeout)
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.protocol.end_session()
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self, timeout: float | None = None) -> MessagePayload:
|
||||
if timeout is None:
|
||||
timeout = _DEFAULT_READ_TIMEOUT
|
||||
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next(timeout=timeout))
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self, timeout: float | None = None) -> tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk(timeout=timeout)
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self, timeout: float | None = None) -> bytes:
|
||||
chunk = self.handle.read_chunk(timeout=timeout)
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
26
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
26
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Channel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_features(self) -> None:
|
||||
raise NotImplementedError
|
129
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
129
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
@ -0,0 +1,129 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2025 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from .protocol_and_channel import Channel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1Channel(Channel):
|
||||
_DEFAULT_READ_TIMEOUT: t.ClassVar[float | None] = None
|
||||
HEADER_LEN: t.ClassVar[int] = struct.calcsize(">HL")
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
self.write(messages.GetFeatures())
|
||||
resp = self.read()
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = resp
|
||||
|
||||
def read(self, timeout: float | None = None) -> t.Any:
|
||||
msg_type, msg_bytes = self._read(timeout=timeout)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
self.transport.close()
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
|
||||
if chunk_size is None:
|
||||
self.transport.write_chunk(header + message_data)
|
||||
return
|
||||
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def _read(self, timeout: float | None = None) -> t.Tuple[int, bytes]:
|
||||
if timeout is None:
|
||||
timeout = self._DEFAULT_READ_TIMEOUT
|
||||
|
||||
if self.transport.CHUNK_SIZE is None:
|
||||
return self.read_chunkless(timeout=timeout)
|
||||
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first(timeout=timeout)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next(timeout=timeout))
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_chunkless(self, timeout: float | None = None) -> t.Tuple[int, bytes]:
|
||||
data = self.transport.read_chunk(timeout=timeout)
|
||||
msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN])
|
||||
return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen]
|
||||
|
||||
def read_first(self, timeout: float | None = None) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk(timeout=timeout)
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self, timeout: float | None = None) -> bytes:
|
||||
chunk = self.transport.read_chunk(timeout=timeout)
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
@ -19,11 +19,10 @@ from __future__ import annotations
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import Timeout, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -33,12 +32,13 @@ SOCKET_TIMEOUT = 0.1
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
class UdpTransport(Transport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(self, device: str | None = None) -> None:
|
||||
if not device:
|
||||
@ -48,24 +48,17 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device = (host, port)
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
self.socket: socket.socket | None = None
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d._ping():
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
@ -99,20 +92,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
assert prefix_search # otherwise we would have raised above
|
||||
return super().find_by_path(path, prefix_search)
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self._ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise Timeout("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
@ -124,18 +105,9 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def _ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
@ -143,6 +115,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self, timeout: float | None = None) -> bytes:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
start = time.time()
|
||||
while True:
|
||||
@ -156,3 +130,33 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise Timeout("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
@ -20,14 +20,11 @@ import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable
|
||||
|
||||
from typing_extensions import Self
|
||||
from typing import Iterable, List
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Timeout, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -48,14 +45,70 @@ USB_COMM_TIMEOUT_MS = 300
|
||||
WEBUSB_CHUNK_SIZE = 64
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device: usb1.USBDevice, debug: bool = False) -> None:
|
||||
class WebUsbTransport(Transport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
if self.handle is None:
|
||||
@ -68,6 +121,8 @@ class WebUsbHandle:
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
except usb1.USBErrorBusy as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
@ -79,6 +134,8 @@ class WebUsbHandle:
|
||||
self.handle = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
@ -119,73 +176,7 @@ class WebUsbHandle:
|
||||
except Exception as e:
|
||||
raise TransportException(f"USB read failed: {e}") from e
|
||||
|
||||
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: usb1.USBDevice,
|
||||
handle: WebUsbHandle | None = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device, debug)
|
||||
|
||||
self.device = device
|
||||
self.handle = handle
|
||||
self.debug = debug
|
||||
|
||||
super().__init__(protocol=ProtocolV1(handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls,
|
||||
models: Iterable[TrezorModel] | None = None,
|
||||
usb_reset: bool = False,
|
||||
) -> Iterable[WebUsbTransport]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: list[WebUsbTransport] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
except usb1.USBErrorPipe:
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> Self:
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return self.__class__(self.device, debug=True)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user