mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-03-12 22:26:08 +00:00
chore(core): adapt trezorlib transports to session based
[no changelog] Co-authored-by: mmilata <martin@martinmilata.cz>
This commit is contained in:
parent
5c2d3c65b7
commit
72fe47ba3b
@ -2,7 +2,7 @@
|
||||
|
||||
import binascii
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.transport.hid import HidTransport
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
if len(devices) > 0:
|
||||
|
@ -14,24 +14,18 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
import typing as t
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="Transport")
|
||||
T = t.TypeVar("T", bound="Transport")
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
""".strip()
|
||||
|
||||
|
||||
MessagePayload = Tuple[int, bytes]
|
||||
MessagePayload = t.Tuple[int, bytes]
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
@ -53,73 +47,57 @@ class DeviceIsBusy(TransportException):
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
||||
It can also enumerate devices available over this communication link, and return
|
||||
them as instances.
|
||||
|
||||
Transport instance is a thing that:
|
||||
- can be identified and requested by a string URI-like path
|
||||
- can open and close sessions, which enclose related operations
|
||||
- can read and write protobuf messages
|
||||
|
||||
You need to implement a new Transport subclass if you invent a new way to connect
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
|
||||
) -> t.Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def begin_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def find_debug(self: "T") -> "T":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["T"]:
|
||||
def open(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
or device.get_path() == path
|
||||
or (prefix_search and device.get_path().startswith(path))
|
||||
):
|
||||
return device
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
def ping(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int | None]
|
||||
|
||||
|
||||
def all_transports() -> Iterable[Type["Transport"]]:
|
||||
from .ble import BleTransport
|
||||
def all_transports() -> t.Iterable[t.Type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[Type["Transport"], ...] = (
|
||||
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
@ -130,9 +108,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Optional[Iterable["TrezorModel"]] = None,
|
||||
) -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
models: t.Iterable["TrezorModel"] | None = None,
|
||||
) -> t.Sequence["Transport"]:
|
||||
devices: t.List["Transport"] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
@ -147,9 +125,7 @@ def enumerate_devices(
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(
|
||||
path: Optional[str] = None, prefix_search: bool = False
|
||||
) -> "Transport":
|
||||
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices()))
|
||||
|
@ -14,16 +14,17 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
||||
import typing as t
|
||||
|
||||
import requests
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
from . import DeviceIsBusy, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -45,7 +46,7 @@ class BridgeException(TransportException):
|
||||
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
||||
|
||||
|
||||
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
def call_bridge(path: str, data: str | None = None) -> requests.Response:
|
||||
url = TREZORD_HOST + "/" + path
|
||||
r = CONNECTION.post(url, data=data)
|
||||
if r.status_code != 200:
|
||||
@ -53,10 +54,13 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
return r
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
def get_bridge_version() -> t.Tuple[int, ...]:
|
||||
config = call_bridge("configure").json()
|
||||
version_tuple = tuple(map(int, config["version"].split(".")))
|
||||
return version_tuple < TREZORD_VERSION_MODERN
|
||||
return tuple(map(int, config["version"].split(".")))
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||
|
||||
|
||||
class BridgeHandle:
|
||||
@ -84,7 +88,7 @@ class BridgeHandleModern(BridgeHandle):
|
||||
class BridgeHandleLegacy(BridgeHandle):
|
||||
def __init__(self, transport: "BridgeTransport") -> None:
|
||||
super().__init__(transport)
|
||||
self.request: Optional[str] = None
|
||||
self.request: str | None = None
|
||||
|
||||
def write_buf(self, buf: bytes) -> None:
|
||||
if self.request is not None:
|
||||
@ -110,15 +114,15 @@ class BridgeTransport(Transport):
|
||||
|
||||
PATH_PREFIX = "bridge"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = None
|
||||
|
||||
def __init__(
|
||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
||||
self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
|
||||
) -> None:
|
||||
if legacy and debug:
|
||||
raise TransportException("Debugging not supported on legacy Bridge")
|
||||
|
||||
self.device = device
|
||||
self.session: Optional[str] = None
|
||||
self.session: str | None = device["session"]
|
||||
self.debug = debug
|
||||
self.legacy = legacy
|
||||
|
||||
@ -135,7 +139,7 @@ class BridgeTransport(Transport):
|
||||
raise TransportException("Debug device not available")
|
||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||
|
||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
||||
def _call(self, action: str, data: str | None = None) -> requests.Response:
|
||||
session = self.session or "null"
|
||||
uri = action + "/" + str(session)
|
||||
if self.debug:
|
||||
@ -144,8 +148,8 @@ class BridgeTransport(Transport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
cls, _models: t.Iterable["TrezorModel"] | None = None
|
||||
) -> t.Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
return [
|
||||
@ -154,7 +158,7 @@ class BridgeTransport(Transport):
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def begin_session(self) -> None:
|
||||
def open(self) -> None:
|
||||
try:
|
||||
data = self._call("acquire/" + self.device["path"])
|
||||
except BridgeException as e:
|
||||
@ -163,18 +167,17 @@ class BridgeTransport(Transport):
|
||||
raise
|
||||
self.session = data.json()["session"]
|
||||
|
||||
def end_session(self) -> None:
|
||||
def close(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
self._call("release")
|
||||
self.session = None
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
self.handle.write_buf(chunk)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
data = self.handle.read_buf()
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
def read_chunk(self) -> bytes:
|
||||
return self.handle.read_buf()
|
||||
|
||||
def ping(self) -> bool:
|
||||
return self.session is not None
|
||||
|
@ -14,15 +14,16 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
import typing as t
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -35,23 +36,61 @@ except Exception as e:
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
HidDevice = t.Dict[str, t.Any]
|
||||
HidDeviceHandle = t.Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(
|
||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.serial = serial
|
||||
class HidTransport(Transport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||
self.device = device
|
||||
self.device_path = device["path"]
|
||||
self.device_serial_number = device["serial_number"]
|
||||
self.handle: HidDeviceHandle = None
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||
) -> t.Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: t.List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
self.handle.open_path(self.device_path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (UDEV_RULES_STR,)
|
||||
@ -62,11 +101,11 @@ class HidHandle:
|
||||
# and we wouldn't even know.
|
||||
# So we check that the serial matches what we expect.
|
||||
serial = self.handle.get_serial_number_string()
|
||||
if serial != self.serial:
|
||||
if serial != self.device_serial_number:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
raise TransportException(
|
||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
||||
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||
)
|
||||
|
||||
self.handle.set_nonblocking(True)
|
||||
@ -77,7 +116,7 @@ class HidHandle:
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
# reload serial, because device.wipe() can reset it
|
||||
self.serial = self.handle.get_serial_number_string()
|
||||
self.device_serial_number = self.handle.get_serial_number_string()
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
@ -115,53 +154,6 @@ class HidHandle:
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self.handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
||||
) -> Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||
|
||||
|
@ -1,166 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from . import MessagePayload, Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
V2_NEXT_CHUNK = 0x02
|
||||
V2_BEGIN_SESSION = 0x03
|
||||
V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
||||
|
||||
Handle is a "physical" layer for a protocol.
|
||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
||||
|
||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
||||
implementations, so this definition is for type hinting purposes only. You can,
|
||||
but don't have to, inherit from it.
|
||||
"""
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def read_chunk(self) -> bytes: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
||||
|
||||
A Protocol implements the part of the Transport API that relates to communicating
|
||||
logical messages over a physical layer. It is a thing that can:
|
||||
- open and close sessions,
|
||||
- send and receive protobuf messages,
|
||||
given the ability to:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
||||
its messages.
|
||||
"""
|
||||
|
||||
def __init__(self, handle: Handle, replen: int = REPLEN) -> None:
|
||||
self.handle = handle
|
||||
self.replen = replen
|
||||
self.session_counter = 0
|
||||
|
||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolBasedTransport(Transport):
|
||||
"""Transport that implements its communications through a Protocol.
|
||||
|
||||
Intended as a base class for implementations that proxy their communication
|
||||
operations to a Protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
return self.protocol.read()
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.protocol.end_session()
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: self.replen - 1]
|
||||
chunk = chunk.ljust(self.replen, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[self.replen - 1 :]
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
26
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
26
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Channel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
raise NotImplementedError()
|
||||
|
||||
def update_features(self) -> None:
|
||||
raise NotImplementedError
|
109
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
109
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
@ -0,0 +1,109 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from .protocol_and_channel import Channel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1Channel(Channel):
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
self.write(messages.GetFeatures())
|
||||
resp = self.read()
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = resp
|
||||
|
||||
def read(self) -> t.Any:
|
||||
msg_type, msg_bytes = self._read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
self.transport.close()
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
|
||||
if chunk_size is None:
|
||||
self.transport.write_chunk(header + message_data)
|
||||
return
|
||||
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def _read(self) -> t.Tuple[int, bytes]:
|
||||
if self.transport.CHUNK_SIZE is None:
|
||||
return self.read_chunkless()
|
||||
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_chunkless(self) -> t.Tuple[int, bytes]:
|
||||
data = self.transport.read_chunk()
|
||||
msg_type, datalen = struct.unpack(">HL", data[: self.HEADER_LEN])
|
||||
return msg_type, data[self.HEADER_LEN : self.HEADER_LEN + datalen]
|
||||
|
||||
def read_first(self) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError(f"Cannot parse header: {chunk.hex()}")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError(f"Unexpected magic characters: {chunk.hex()}")
|
||||
return chunk[1:]
|
@ -14,14 +14,15 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
class UdpTransport(Transport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(self, device: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device = (host, port)
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
self.socket: socket.socket | None = None
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d._ping():
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
cls, _models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
else:
|
||||
raise TransportException(f"No UDP device at {path}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self._ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def _ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
while True:
|
||||
try:
|
||||
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
@ -14,16 +14,17 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, List
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
|
||||
WEBUSB_CHUNK_SIZE = 64
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
|
||||
class WebUsbTransport(Transport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
@ -77,6 +134,8 @@ class WebUsbHandle:
|
||||
self.handle = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
@ -99,6 +158,8 @@ class WebUsbHandle:
|
||||
return
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
while True:
|
||||
@ -119,70 +180,6 @@ class WebUsbHandle:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
|
||||
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
handle: Optional[WebUsbHandle] = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device, debug)
|
||||
|
||||
self.device = device
|
||||
self.handle = handle
|
||||
self.debug = debug
|
||||
|
||||
super().__init__(protocol=ProtocolV1(handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
except usb1.USBErrorPipe:
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
Loading…
Reference in New Issue
Block a user