mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-07 21:22:41 +00:00
chore(core): adapt trezorlib transports to session based
[no changelog]
This commit is contained in:
parent
fbff05a89f
commit
61b2156a1e
@ -14,24 +14,18 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
)
|
||||
import typing as t
|
||||
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="Transport")
|
||||
T = t.TypeVar("T", bound="Transport")
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
""".strip()
|
||||
|
||||
|
||||
MessagePayload = Tuple[int, bytes]
|
||||
MessagePayload = t.Tuple[int, bytes]
|
||||
|
||||
|
||||
class TransportException(TrezorException):
|
||||
@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException):
|
||||
|
||||
|
||||
class Transport:
|
||||
"""Raw connection to a Trezor device.
|
||||
|
||||
Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB
|
||||
or USB-HID connection, or UDP socket of listening emulator(s).
|
||||
It can also enumerate devices available over this communication link, and return
|
||||
them as instances.
|
||||
|
||||
Transport instance is a thing that:
|
||||
- can be identified and requested by a string URI-like path
|
||||
- can open and close sessions, which enclose related operations
|
||||
- can read and write protobuf messages
|
||||
|
||||
You need to implement a new Transport subclass if you invent a new way to connect
|
||||
a Trezor device to a computer.
|
||||
"""
|
||||
|
||||
PATH_PREFIX: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None
|
||||
) -> t.Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
|
||||
if device.get_path() == path:
|
||||
return device
|
||||
|
||||
if prefix_search and device.get_path().startswith(path):
|
||||
return device
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
|
||||
def get_path(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def begin_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def end_session(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def find_debug(self: "T") -> "T":
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["T"]:
|
||||
def open(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T":
|
||||
for device in cls.enumerate():
|
||||
if (
|
||||
path is None
|
||||
or device.get_path() == path
|
||||
or (prefix_search and device.get_path().startswith(path))
|
||||
):
|
||||
return device
|
||||
def close(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}")
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
CHUNK_SIZE: t.ClassVar[int]
|
||||
|
||||
|
||||
def all_transports() -> Iterable[Type["Transport"]]:
|
||||
def all_transports() -> t.Iterable[t.Type["Transport"]]:
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[Type["Transport"], ...] = (
|
||||
transports: t.Tuple[t.Type["Transport"], ...] = (
|
||||
BridgeTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
||||
|
||||
|
||||
def enumerate_devices(
|
||||
models: Optional[Iterable["TrezorModel"]] = None,
|
||||
) -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
models: t.Iterable["TrezorModel"] | None = None,
|
||||
) -> t.Sequence["Transport"]:
|
||||
devices: t.List["Transport"] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
@ -145,9 +121,7 @@ def enumerate_devices(
|
||||
return devices
|
||||
|
||||
|
||||
def get_transport(
|
||||
path: Optional[str] = None, prefix_search: bool = False
|
||||
) -> "Transport":
|
||||
def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport":
|
||||
if path is None:
|
||||
try:
|
||||
return next(iter(enumerate_devices()))
|
||||
|
@ -14,24 +14,30 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
||||
import typing as t
|
||||
|
||||
import requests
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DeviceIsBusy, MessagePayload, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
if t.TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
PROTOCOL_VERSION_1 = 1
|
||||
PROTOCOL_VERSION_2 = 2
|
||||
|
||||
TREZORD_HOST = "http://127.0.0.1:21325"
|
||||
TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"}
|
||||
|
||||
TREZORD_VERSION_MODERN = (2, 0, 25)
|
||||
TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value
|
||||
|
||||
CONNECTION = requests.Session()
|
||||
CONNECTION.headers.update(TREZORD_ORIGIN_HEADER)
|
||||
@ -45,7 +51,7 @@ class BridgeException(TransportException):
|
||||
super().__init__(f"trezord: {path} failed with code {status}: {message}")
|
||||
|
||||
|
||||
def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
def call_bridge(path: str, data: str | None = None) -> requests.Response:
|
||||
url = TREZORD_HOST + "/" + path
|
||||
r = CONNECTION.post(url, data=data)
|
||||
if r.status_code != 200:
|
||||
@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response:
|
||||
return r
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
def get_bridge_version() -> t.Tuple[int, ...]:
|
||||
config = call_bridge("configure").json()
|
||||
version_tuple = tuple(map(int, config["version"].split(".")))
|
||||
return version_tuple < TREZORD_VERSION_MODERN
|
||||
return tuple(map(int, config["version"].split(".")))
|
||||
|
||||
|
||||
def is_legacy_bridge() -> bool:
|
||||
return get_bridge_version() < TREZORD_VERSION_MODERN
|
||||
|
||||
|
||||
def supports_protocolV2() -> bool:
|
||||
return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT
|
||||
|
||||
|
||||
def detect_protocol_version(transport: "BridgeTransport") -> int:
|
||||
from .. import mapping, messages
|
||||
from ..messages import FailureType
|
||||
|
||||
protocol_version = PROTOCOL_VERSION_1
|
||||
request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize())
|
||||
transport.deprecated_begin_session()
|
||||
transport.deprecated_write(request_type, request_data)
|
||||
|
||||
response_type, response_data = transport.deprecated_read()
|
||||
response = mapping.DEFAULT_MAPPING.decode(response_type, response_data)
|
||||
transport.deprecated_begin_session()
|
||||
if isinstance(response, messages.Failure):
|
||||
if response.code == FailureType.InvalidProtocol:
|
||||
LOG.debug("Protocol V2 detected")
|
||||
protocol_version = PROTOCOL_VERSION_2
|
||||
|
||||
return protocol_version
|
||||
|
||||
|
||||
def _is_transport_valid(transport: "BridgeTransport") -> bool:
|
||||
is_valid = (
|
||||
supports_protocolV2()
|
||||
or detect_protocol_version(transport) == PROTOCOL_VERSION_1
|
||||
)
|
||||
if not is_valid:
|
||||
LOG.warning("Detected unsupported Bridge transport!")
|
||||
return is_valid
|
||||
|
||||
|
||||
def filter_invalid_bridge_transports(
|
||||
transports: t.Iterable["BridgeTransport"],
|
||||
) -> t.Sequence["BridgeTransport"]:
|
||||
"""Filters out invalid bridge transports. Keeps only valid ones."""
|
||||
return [t for t in transports if _is_transport_valid(t)]
|
||||
|
||||
|
||||
class BridgeHandle:
|
||||
@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle):
|
||||
class BridgeHandleLegacy(BridgeHandle):
|
||||
def __init__(self, transport: "BridgeTransport") -> None:
|
||||
super().__init__(transport)
|
||||
self.request: Optional[str] = None
|
||||
self.request: str | None = None
|
||||
|
||||
def write_buf(self, buf: bytes) -> None:
|
||||
if self.request is not None:
|
||||
@ -112,13 +162,12 @@ class BridgeTransport(Transport):
|
||||
ENABLED: bool = True
|
||||
|
||||
def __init__(
|
||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
||||
self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False
|
||||
) -> None:
|
||||
if legacy and debug:
|
||||
raise TransportException("Debugging not supported on legacy Bridge")
|
||||
|
||||
self.device = device
|
||||
self.session: Optional[str] = None
|
||||
self.session: str | None = device["session"]
|
||||
self.debug = debug
|
||||
self.legacy = legacy
|
||||
|
||||
@ -135,7 +184,7 @@ class BridgeTransport(Transport):
|
||||
raise TransportException("Debug device not available")
|
||||
return BridgeTransport(self.device, self.legacy, debug=True)
|
||||
|
||||
def _call(self, action: str, data: Optional[str] = None) -> requests.Response:
|
||||
def _call(self, action: str, data: str | None = None) -> requests.Response:
|
||||
session = self.session or "null"
|
||||
uri = action + "/" + str(session)
|
||||
if self.debug:
|
||||
@ -144,17 +193,20 @@ class BridgeTransport(Transport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
cls, _models: t.Iterable["TrezorModel"] | None = None
|
||||
) -> t.Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
return [
|
||||
BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json()
|
||||
return filter_invalid_bridge_transports(
|
||||
[
|
||||
BridgeTransport(dev, legacy)
|
||||
for dev in call_bridge("enumerate").json()
|
||||
]
|
||||
)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def begin_session(self) -> None:
|
||||
def deprecated_begin_session(self) -> None:
|
||||
try:
|
||||
data = self._call("acquire/" + self.device["path"])
|
||||
except BridgeException as e:
|
||||
@ -163,18 +215,32 @@ class BridgeTransport(Transport):
|
||||
raise
|
||||
self.session = data.json()["session"]
|
||||
|
||||
def end_session(self) -> None:
|
||||
def deprecated_end_session(self) -> None:
|
||||
if not self.session:
|
||||
return
|
||||
self._call("release")
|
||||
self.session = None
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
def deprecated_write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
self.handle.write_buf(header + message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
def deprecated_read(self) -> MessagePayload:
|
||||
data = self.handle.read_buf()
|
||||
headerlen = struct.calcsize(">HL")
|
||||
msg_type, datalen = struct.unpack(">HL", data[:headerlen])
|
||||
return msg_type, data[headerlen : headerlen + datalen]
|
||||
|
||||
def open(self) -> None:
|
||||
pass
|
||||
# TODO self.handle.open()
|
||||
|
||||
def close(self) -> None:
|
||||
pass
|
||||
# TODO self.handle.close()
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :)
|
||||
self.handle.write_buf(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes: # TODO check if it works :)
|
||||
return self.handle.read_buf()
|
||||
|
@ -14,15 +14,16 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
import typing as t
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -35,23 +36,61 @@ except Exception as e:
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
HidDevice = t.Dict[str, t.Any]
|
||||
HidDeviceHandle = t.Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(
|
||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.serial = serial
|
||||
class HidTransport(Transport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||
self.device = device
|
||||
self.device_path = device["path"]
|
||||
self.device_serial_number = device["serial_number"]
|
||||
self.handle: HidDeviceHandle = None
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||
) -> t.Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: t.List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
self.handle.open_path(self.device_path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (UDEV_RULES_STR,)
|
||||
@ -62,11 +101,11 @@ class HidHandle:
|
||||
# and we wouldn't even know.
|
||||
# So we check that the serial matches what we expect.
|
||||
serial = self.handle.get_serial_number_string()
|
||||
if serial != self.serial:
|
||||
if serial != self.device_serial_number:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
raise TransportException(
|
||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
||||
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||
)
|
||||
|
||||
self.handle.set_nonblocking(True)
|
||||
@ -77,7 +116,7 @@ class HidHandle:
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
# reload serial, because device.wipe() can reset it
|
||||
self.serial = self.handle.get_serial_number_string()
|
||||
self.device_serial_number = self.handle.get_serial_number_string()
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
@ -115,53 +154,6 @@ class HidHandle:
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self.handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
||||
) -> Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||
|
||||
|
@ -1,165 +0,0 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
# as published by the Free Software Foundation.
|
||||
#
|
||||
# This library is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Lesser General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
from typing_extensions import Protocol as StructuralType
|
||||
|
||||
from . import MessagePayload, Transport
|
||||
|
||||
REPLEN = 64
|
||||
|
||||
V2_FIRST_CHUNK = 0x01
|
||||
V2_NEXT_CHUNK = 0x02
|
||||
V2_BEGIN_SESSION = 0x03
|
||||
V2_END_SESSION = 0x04
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Handle(StructuralType):
|
||||
"""PEP 544 structural type for Handle functionality.
|
||||
(called a "Protocol" in the proposed PEP, name which is impractical here)
|
||||
|
||||
Handle is a "physical" layer for a protocol.
|
||||
It can open/close a connection and read/write bare data in 64-byte chunks.
|
||||
|
||||
Functionally we gain nothing from making this an (abstract) base class for handle
|
||||
implementations, so this definition is for type hinting purposes only. You can,
|
||||
but don't have to, inherit from it.
|
||||
"""
|
||||
|
||||
def open(self) -> None: ...
|
||||
|
||||
def close(self) -> None: ...
|
||||
|
||||
def read_chunk(self) -> bytes: ...
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None: ...
|
||||
|
||||
|
||||
class Protocol:
|
||||
"""Wire protocol that can communicate with a Trezor device, given a Handle.
|
||||
|
||||
A Protocol implements the part of the Transport API that relates to communicating
|
||||
logical messages over a physical layer. It is a thing that can:
|
||||
- open and close sessions,
|
||||
- send and receive protobuf messages,
|
||||
given the ability to:
|
||||
- open and close physical connections,
|
||||
- and send and receive binary chunks.
|
||||
|
||||
For now, the class also handles session counting and opening the underlying Handle.
|
||||
This will probably be removed in the future.
|
||||
|
||||
We will need a new Protocol class if we change the way a Trezor device encapsulates
|
||||
its messages.
|
||||
"""
|
||||
|
||||
def __init__(self, handle: Handle) -> None:
|
||||
self.handle = handle
|
||||
self.session_counter = 0
|
||||
|
||||
# XXX we might be able to remove this now that TrezorClient does session handling
|
||||
def begin_session(self) -> None:
|
||||
if self.session_counter == 0:
|
||||
self.handle.open()
|
||||
self.session_counter += 1
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.session_counter = max(self.session_counter - 1, 0)
|
||||
if self.session_counter == 0:
|
||||
self.handle.close()
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
raise NotImplementedError
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ProtocolBasedTransport(Transport):
|
||||
"""Transport that implements its communications through a Protocol.
|
||||
|
||||
Intended as a base class for implementations that proxy their communication
|
||||
operations to a Protocol.
|
||||
"""
|
||||
|
||||
def __init__(self, protocol: Protocol) -> None:
|
||||
self.protocol = protocol
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
self.protocol.write(message_type, message_data)
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
return self.protocol.read()
|
||||
|
||||
def begin_session(self) -> None:
|
||||
self.protocol.begin_session()
|
||||
|
||||
def end_session(self) -> None:
|
||||
self.protocol.end_session()
|
||||
|
||||
|
||||
class ProtocolV1(Protocol):
|
||||
"""Protocol version 1. Currently (11/2018) in use on all Trezors.
|
||||
Does not understand sessions.
|
||||
"""
|
||||
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
|
||||
def write(self, message_type: int, message_data: bytes) -> None:
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: REPLEN - 1]
|
||||
chunk = chunk.ljust(REPLEN, b"\x00")
|
||||
self.handle.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def read(self) -> MessagePayload:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> Tuple[int, int, bytes]:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.handle.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
223
python/src/trezorlib/transport/session.py
Normal file
223
python/src/trezorlib/transport/session.py
Normal file
@ -0,0 +1,223 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import typing as t
|
||||
|
||||
from .. import exceptions, messages, models
|
||||
from ..protobuf import MessageType
|
||||
from .thp.protocol_v1 import ProtocolV1
|
||||
from .thp.protocol_v2 import ProtocolV2
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ..client import TrezorClient
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
MT = t.TypeVar("MT", bound=MessageType)
|
||||
|
||||
|
||||
class Session:
|
||||
button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
|
||||
|
||||
def __init__(
|
||||
self, client: TrezorClient, id: bytes, passphrase: str | object | None = None
|
||||
) -> None:
|
||||
self.client = client
|
||||
self._id = id
|
||||
self.passphrase = passphrase
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool
|
||||
) -> Session:
|
||||
raise NotImplementedError
|
||||
|
||||
def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT:
|
||||
# TODO self.check_firmware_version()
|
||||
resp = self.call_raw(msg)
|
||||
|
||||
while True:
|
||||
if isinstance(resp, messages.PinMatrixRequest):
|
||||
if self.pin_callback is None:
|
||||
raise Exception # TODO
|
||||
resp = self.pin_callback(self, resp)
|
||||
elif isinstance(resp, messages.PassphraseRequest):
|
||||
if self.passphrase_callback is None:
|
||||
raise Exception # TODO
|
||||
resp = self.passphrase_callback(self, resp)
|
||||
elif isinstance(resp, messages.ButtonRequest):
|
||||
if self.button_callback is None:
|
||||
raise Exception # TODO
|
||||
resp = self.button_callback(self, resp)
|
||||
elif isinstance(resp, messages.Failure):
|
||||
if resp.code == messages.FailureType.ActionCancelled:
|
||||
raise exceptions.Cancelled
|
||||
raise exceptions.TrezorFailure(resp)
|
||||
elif not isinstance(resp, expect):
|
||||
raise exceptions.UnexpectedMessageError(expect, resp)
|
||||
else:
|
||||
return resp
|
||||
|
||||
def call_raw(self, msg: t.Any) -> t.Any:
|
||||
self._write(msg)
|
||||
return self._read()
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def refresh_features(self) -> None:
|
||||
self.client.refresh_features()
|
||||
|
||||
def end(self) -> t.Any:
|
||||
return self.call(messages.EndSession())
|
||||
|
||||
def ping(self, message: str, button_protection: bool | None = None) -> str:
|
||||
resp = self.call(
|
||||
messages.Ping(message=message, button_protection=button_protection),
|
||||
expect=messages.Success,
|
||||
)
|
||||
assert resp.message is not None
|
||||
return resp.message
|
||||
|
||||
def invalidate(self) -> None:
|
||||
self.client.invalidate()
|
||||
|
||||
@property
|
||||
def features(self) -> messages.Features:
|
||||
return self.client.features
|
||||
|
||||
@property
|
||||
def model(self) -> models.TrezorModel:
|
||||
return self.client.model
|
||||
|
||||
@property
|
||||
def version(self) -> t.Tuple[int, int, int]:
|
||||
return self.client.version
|
||||
|
||||
@property
|
||||
def id(self) -> bytes:
|
||||
return self._id
|
||||
|
||||
@id.setter
|
||||
def id(self, value: bytes) -> None:
|
||||
if not isinstance(value, bytes):
|
||||
raise ValueError("id must be of type bytes")
|
||||
self._id = value
|
||||
|
||||
|
||||
class SessionV1(Session):
|
||||
derive_cardano: bool | None = False
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
client: TrezorClient,
|
||||
passphrase: str | object = "",
|
||||
derive_cardano: bool = False,
|
||||
session_id: bytes | None = None,
|
||||
) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, id=session_id or b"")
|
||||
|
||||
session._init_callbacks()
|
||||
session.passphrase = passphrase
|
||||
session.derive_cardano = derive_cardano
|
||||
session.init_session(session.derive_cardano)
|
||||
return session
|
||||
|
||||
@classmethod
|
||||
def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1:
|
||||
assert isinstance(client.protocol, ProtocolV1)
|
||||
session = SessionV1(client, session_id)
|
||||
session.init_session()
|
||||
return session
|
||||
|
||||
def _init_callbacks(self) -> None:
|
||||
self.button_callback = self.client.button_callback
|
||||
if self.button_callback is None:
|
||||
self.button_callback = _callback_button
|
||||
self.pin_callback = self.client.pin_callback
|
||||
self.passphrase_callback = self.client.passphrase_callback
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
self.client.protocol.write(msg)
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
if t.TYPE_CHECKING:
|
||||
assert isinstance(self.client.protocol, ProtocolV1)
|
||||
return self.client.protocol.read()
|
||||
|
||||
def init_session(self, derive_cardano: bool | None = None):
|
||||
if self.id == b"":
|
||||
session_id = None
|
||||
else:
|
||||
session_id = self.id
|
||||
resp: messages.Features = self.call_raw(
|
||||
messages.Initialize(session_id=session_id, derive_cardano=derive_cardano)
|
||||
)
|
||||
if isinstance(self.passphrase, str):
|
||||
self.passphrase_callback = self.client.passphrase_callback
|
||||
self._id = resp.session_id
|
||||
|
||||
|
||||
def _callback_button(session: Session, msg: t.Any) -> t.Any:
|
||||
print("Please confirm action on your Trezor device.") # TODO how to handle UI?
|
||||
return session.call(messages.ButtonAck())
|
||||
|
||||
|
||||
class SessionV2(Session):
|
||||
|
||||
@classmethod
|
||||
def new(
|
||||
cls,
|
||||
client: TrezorClient,
|
||||
passphrase: str | None,
|
||||
derive_cardano: bool,
|
||||
session_id: int = 0,
|
||||
) -> SessionV2:
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
session = cls(client, session_id.to_bytes(1, "big"))
|
||||
session.call(
|
||||
messages.ThpCreateNewSession(
|
||||
passphrase=passphrase, derive_cardano=derive_cardano
|
||||
),
|
||||
expect=messages.Success,
|
||||
)
|
||||
session.update_id_and_sid(session_id.to_bytes(1, "big"))
|
||||
return session
|
||||
|
||||
def __init__(self, client: TrezorClient, id: bytes) -> None:
|
||||
from ..debuglink import TrezorClientDebugLink
|
||||
|
||||
super().__init__(client, id)
|
||||
assert isinstance(client.protocol, ProtocolV2)
|
||||
|
||||
self.pin_callback = client.pin_callback
|
||||
self.button_callback = client.button_callback
|
||||
if self.button_callback is None:
|
||||
self.button_callback = _callback_button
|
||||
helper_debug = None
|
||||
if isinstance(client, TrezorClientDebugLink):
|
||||
helper_debug = client.debug
|
||||
self.channel: ProtocolV2 = client.protocol.get_channel(helper_debug)
|
||||
self.update_id_and_sid(id)
|
||||
|
||||
def _write(self, msg: t.Any) -> None:
|
||||
LOG.debug("writing message %s", type(msg))
|
||||
self.channel.write(self.sid, msg)
|
||||
|
||||
def _read(self) -> t.Any:
|
||||
msg = self.channel.read(self.sid)
|
||||
LOG.debug("reading message %s", type(msg))
|
||||
return msg
|
||||
|
||||
def update_id_and_sid(self, id: bytes) -> None:
|
||||
self._id = id
|
||||
self.sid = int.from_bytes(id, "big") # TODO update to extract only sid
|
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
102
python/src/trezorlib/transport/thp/alternating_bit_protocol.py
Normal file
@ -0,0 +1,102 @@
|
||||
# from storage.cache_thp import ChannelCache
|
||||
# from trezor import log
|
||||
# from trezor.wire.thp import ThpError
|
||||
|
||||
|
||||
# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool:
|
||||
# """
|
||||
# Checks if:
|
||||
# - an ACK message is expected
|
||||
# - the received ACK message acknowledges correct sequence number (bit)
|
||||
# """
|
||||
# if not _is_ack_expected(cache):
|
||||
# return False
|
||||
|
||||
# if not _has_ack_correct_sync_bit(cache, ack_bit):
|
||||
# return False
|
||||
|
||||
# return True
|
||||
|
||||
|
||||
# def _is_ack_expected(cache: ChannelCache) -> bool:
|
||||
# is_expected: bool = not is_sending_allowed(cache)
|
||||
# if __debug__ and not is_expected:
|
||||
# log.debug(__name__, "Received unexpected ACK message")
|
||||
# return is_expected
|
||||
|
||||
|
||||
# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool:
|
||||
# is_correct: bool = get_send_seq_bit(cache) == sync_bit
|
||||
# if __debug__ and not is_correct:
|
||||
# log.debug(__name__, "Received ACK message with wrong ack bit")
|
||||
# return is_correct
|
||||
|
||||
|
||||
# def is_sending_allowed(cache: ChannelCache) -> bool:
|
||||
# """
|
||||
# Checks whether sending a message in the provided channel is allowed.
|
||||
|
||||
# Note: Sending a message in a channel before receipt of ACK message for the previously
|
||||
# sent message (in the channel) is prohibited, as it can lead to desynchronization.
|
||||
# """
|
||||
# return bool(cache.sync >> 7)
|
||||
|
||||
|
||||
# def get_send_seq_bit(cache: ChannelCache) -> int:
|
||||
# """
|
||||
# Returns the sequential number (bit) of the next message to be sent
|
||||
# in the provided channel.
|
||||
# """
|
||||
# return (cache.sync & 0x20) >> 5
|
||||
|
||||
|
||||
# def get_expected_receive_seq_bit(cache: ChannelCache) -> int:
|
||||
# """
|
||||
# Returns the (expected) sequential number (bit) of the next message
|
||||
# to be received in the provided channel.
|
||||
# """
|
||||
# return (cache.sync & 0x40) >> 6
|
||||
|
||||
|
||||
# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None:
|
||||
# """
|
||||
# Set the flag whether sending a message in this channel is allowed or not.
|
||||
# """
|
||||
# cache.sync &= 0x7F
|
||||
# if sending_allowed:
|
||||
# cache.sync |= 0x80
|
||||
|
||||
|
||||
# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||
# """
|
||||
# Set the expected sequential number (bit) of the next message to be received
|
||||
# in the provided channel
|
||||
# """
|
||||
# if __debug__:
|
||||
# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit)
|
||||
# if seq_bit not in (0, 1):
|
||||
# raise ThpError("Unexpected receive sync bit")
|
||||
|
||||
# # set second bit to "seq_bit" value
|
||||
# cache.sync &= 0xBF
|
||||
# if seq_bit:
|
||||
# cache.sync |= 0x40
|
||||
|
||||
|
||||
# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None:
|
||||
# if seq_bit not in (0, 1):
|
||||
# raise ThpError("Unexpected send seq bit")
|
||||
# if __debug__:
|
||||
# log.debug(__name__, "setting sync send seq bit to %d", seq_bit)
|
||||
# # set third bit to "seq_bit" value
|
||||
# cache.sync &= 0xDF
|
||||
# if seq_bit:
|
||||
# cache.sync |= 0x20
|
||||
|
||||
|
||||
# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None:
|
||||
# """
|
||||
# Set the sequential bit of the "next message to be send" to the opposite value,
|
||||
# i.e. 1 -> 0 and 0 -> 1
|
||||
# """
|
||||
# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache))
|
47
python/src/trezorlib/transport/thp/channel_data.py
Normal file
47
python/src/trezorlib/transport/thp/channel_data.py
Normal file
@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from binascii import hexlify
|
||||
|
||||
|
||||
class ChannelData:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocol_version_major: int,
|
||||
protocol_version_minor: int,
|
||||
transport_path: str,
|
||||
channel_id: int,
|
||||
key_request: bytes,
|
||||
key_response: bytes,
|
||||
nonce_request: int,
|
||||
nonce_response: int,
|
||||
sync_bit_send: int,
|
||||
sync_bit_receive: int,
|
||||
handshake_hash: bytes,
|
||||
) -> None:
|
||||
self.protocol_version_major: int = protocol_version_major
|
||||
self.protocol_version_minor: int = protocol_version_minor
|
||||
self.transport_path: str = transport_path
|
||||
self.channel_id: int = channel_id
|
||||
self.key_request: str = hexlify(key_request).decode()
|
||||
self.key_response: str = hexlify(key_response).decode()
|
||||
self.nonce_request: int = nonce_request
|
||||
self.nonce_response: int = nonce_response
|
||||
self.sync_bit_receive: int = sync_bit_receive
|
||||
self.sync_bit_send: int = sync_bit_send
|
||||
self.handshake_hash: str = hexlify(handshake_hash).decode()
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"protocol_version_major": self.protocol_version_major,
|
||||
"protocol_version_minor": self.protocol_version_minor,
|
||||
"transport_path": self.transport_path,
|
||||
"channel_id": self.channel_id,
|
||||
"key_request": self.key_request,
|
||||
"key_response": self.key_response,
|
||||
"nonce_request": self.nonce_request,
|
||||
"nonce_response": self.nonce_response,
|
||||
"sync_bit_send": self.sync_bit_send,
|
||||
"sync_bit_receive": self.sync_bit_receive,
|
||||
"handshake_hash": self.handshake_hash,
|
||||
}
|
148
python/src/trezorlib/transport/thp/channel_database.py
Normal file
148
python/src/trezorlib/transport/thp/channel_database.py
Normal file
@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
|
||||
from ..thp.channel_data import ChannelData
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
db: "ChannelDatabase | None" = None
|
||||
|
||||
|
||||
def get_channel_db() -> ChannelDatabase:
|
||||
if db is None:
|
||||
set_channel_database(should_not_store=True)
|
||||
assert db is not None
|
||||
return db
|
||||
|
||||
|
||||
class ChannelDatabase:
|
||||
|
||||
def load_stored_channels(self) -> t.List[ChannelData]: ...
|
||||
def clear_stored_channels(self) -> None: ...
|
||||
def read_all_channels(self) -> t.List: ...
|
||||
def save_all_channels(self, channels: t.List[t.Dict]) -> None: ...
|
||||
def save_channel(self, new_channel: ProtocolAndChannel): ...
|
||||
def remove_channel(self, transport_path: str) -> None: ...
|
||||
|
||||
|
||||
class DummyChannelDatabase(ChannelDatabase):
|
||||
|
||||
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||
return []
|
||||
|
||||
def clear_stored_channels(self) -> None:
|
||||
pass
|
||||
|
||||
def read_all_channels(self) -> t.List:
|
||||
return []
|
||||
|
||||
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||
return
|
||||
|
||||
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||
pass
|
||||
|
||||
def remove_channel(self, transport_path: str) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class JsonChannelDatabase(ChannelDatabase):
|
||||
def __init__(self, data_path: str) -> None:
|
||||
self.data_path = data_path
|
||||
super().__init__()
|
||||
|
||||
def load_stored_channels(self) -> t.List[ChannelData]:
|
||||
dicts = self.read_all_channels()
|
||||
return [dict_to_channel_data(d) for d in dicts]
|
||||
|
||||
def clear_stored_channels(self) -> None:
|
||||
LOG.debug("Clearing contents of %s", self.data_path)
|
||||
with open(self.data_path, "w") as f:
|
||||
json.dump([], f)
|
||||
try:
|
||||
os.remove(self.data_path)
|
||||
except Exception as e:
|
||||
LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e)))
|
||||
|
||||
def read_all_channels(self) -> t.List:
|
||||
ensure_file_exists(self.data_path)
|
||||
with open(self.data_path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def save_all_channels(self, channels: t.List[t.Dict]) -> None:
|
||||
LOG.debug("saving all channels")
|
||||
with open(self.data_path, "w") as f:
|
||||
json.dump(channels, f, indent=4)
|
||||
|
||||
def save_channel(self, new_channel: ProtocolAndChannel):
|
||||
|
||||
LOG.debug("save channel")
|
||||
channels = self.read_all_channels()
|
||||
transport_path = new_channel.transport.get_path()
|
||||
|
||||
# If the channel is found in database: replace the old entry by the new
|
||||
for i, channel in enumerate(channels):
|
||||
if channel["transport_path"] == transport_path:
|
||||
LOG.debug("Modified channel entry for %s", transport_path)
|
||||
channels[i] = new_channel.get_channel_data().to_dict()
|
||||
self.save_all_channels(channels)
|
||||
return
|
||||
|
||||
# Channel was not found: add a new channel entry
|
||||
LOG.debug("Created a new channel entry on path %s", transport_path)
|
||||
channels.append(new_channel.get_channel_data().to_dict())
|
||||
self.save_all_channels(channels)
|
||||
|
||||
def remove_channel(self, transport_path: str) -> None:
|
||||
LOG.debug(
|
||||
"Removing channel with path %s from the channel database.",
|
||||
transport_path,
|
||||
)
|
||||
channels = self.read_all_channels()
|
||||
remaining_channels = [
|
||||
ch for ch in channels if ch["transport_path"] != transport_path
|
||||
]
|
||||
self.save_all_channels(remaining_channels)
|
||||
|
||||
|
||||
def dict_to_channel_data(dict: t.Dict) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version_major=dict["protocol_version_minor"],
|
||||
protocol_version_minor=dict["protocol_version_major"],
|
||||
transport_path=dict["transport_path"],
|
||||
channel_id=dict["channel_id"],
|
||||
key_request=bytes.fromhex(dict["key_request"]),
|
||||
key_response=bytes.fromhex(dict["key_response"]),
|
||||
nonce_request=dict["nonce_request"],
|
||||
nonce_response=dict["nonce_response"],
|
||||
sync_bit_send=dict["sync_bit_send"],
|
||||
sync_bit_receive=dict["sync_bit_receive"],
|
||||
handshake_hash=bytes.fromhex(dict["handshake_hash"]),
|
||||
)
|
||||
|
||||
|
||||
def ensure_file_exists(file_path: str) -> None:
|
||||
LOG.debug("checking if file %s exists", file_path)
|
||||
if not os.path.exists(file_path):
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
LOG.debug("File %s does not exist. Creating a new one.", file_path)
|
||||
with open(file_path, "w") as f:
|
||||
json.dump([], f)
|
||||
|
||||
|
||||
def set_channel_database(should_not_store: bool):
|
||||
global db
|
||||
if should_not_store:
|
||||
db = DummyChannelDatabase()
|
||||
else:
|
||||
from platformdirs import user_cache_dir
|
||||
|
||||
APP_NAME = "@trezor" # TODO
|
||||
DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json")
|
||||
|
||||
db = JsonChannelDatabase(DATA_PATH)
|
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
19
python/src/trezorlib/transport/thp/checksum.py
Normal file
@ -0,0 +1,19 @@
|
||||
import zlib
|
||||
|
||||
CHECKSUM_LENGTH = 4
|
||||
|
||||
|
||||
def compute(data: bytes) -> bytes:
|
||||
"""
|
||||
Returns a CRC-32 checksum of the provided `data`.
|
||||
"""
|
||||
return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big")
|
||||
|
||||
|
||||
def is_valid(checksum: bytes, data: bytes) -> bool:
|
||||
"""
|
||||
Checks whether the CRC-32 checksum of the `data` is the same
|
||||
as the checksum provided in `checksum`.
|
||||
"""
|
||||
data_checksum = compute(data)
|
||||
return checksum == data_checksum
|
63
python/src/trezorlib/transport/thp/control_byte.py
Normal file
63
python/src/trezorlib/transport/thp/control_byte.py
Normal file
@ -0,0 +1,63 @@
|
||||
CODEC_V1 = 0x3F
|
||||
CONTINUATION_PACKET = 0x80
|
||||
HANDSHAKE_INIT_REQ = 0x00
|
||||
HANDSHAKE_INIT_RES = 0x01
|
||||
HANDSHAKE_COMP_REQ = 0x02
|
||||
HANDSHAKE_COMP_RES = 0x03
|
||||
ENCRYPTED_TRANSPORT = 0x04
|
||||
|
||||
CONTINUATION_PACKET_MASK = 0x80
|
||||
ACK_MASK = 0xF7
|
||||
DATA_MASK = 0xE7
|
||||
|
||||
ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x42
|
||||
CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x41
|
||||
|
||||
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
|
||||
def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int:
|
||||
if seq_bit == 0:
|
||||
return ctrl_byte & 0xEF
|
||||
if seq_bit == 1:
|
||||
return ctrl_byte | 0x10
|
||||
raise Exception("Unexpected sequence bit")
|
||||
|
||||
|
||||
def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int:
|
||||
if ack_bit == 0:
|
||||
return ctrl_byte & 0xF7
|
||||
if ack_bit == 1:
|
||||
return ctrl_byte | 0x08
|
||||
raise Exception("Unexpected acknowledgement bit")
|
||||
|
||||
|
||||
def get_seq_bit(ctrl_byte: int) -> int:
|
||||
return (ctrl_byte & 0x10) >> 4
|
||||
|
||||
|
||||
def is_ack(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||
|
||||
|
||||
def is_error(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte == _ERROR
|
||||
|
||||
|
||||
def is_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET
|
||||
|
||||
|
||||
def is_encrypted_transport(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||
|
||||
|
||||
def is_handshake_init_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ
|
||||
|
||||
|
||||
def is_handshake_comp_req(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ
|
40
python/src/trezorlib/transport/thp/cpace.py
Normal file
40
python/src/trezorlib/transport/thp/cpace.py
Normal file
@ -0,0 +1,40 @@
|
||||
import typing as t
|
||||
from hashlib import sha512
|
||||
|
||||
from . import curve25519
|
||||
|
||||
_PREFIX = b"\x08\x43\x50\x61\x63\x65\x32\x35\x35\x06"
|
||||
_PADDING = b"\x6f\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x20"
|
||||
|
||||
|
||||
class Cpace:
|
||||
"""
|
||||
CPace, a balanced composable PAKE: https://datatracker.ietf.org/doc/draft-irtf-cfrg-cpace/
|
||||
"""
|
||||
|
||||
random_bytes: t.Callable[[int], bytes]
|
||||
|
||||
def __init__(self, handshake_hash: bytes) -> None:
|
||||
self.handshake_hash: bytes = handshake_hash
|
||||
self.shared_secret: bytes
|
||||
self.host_private_key: bytes
|
||||
self.host_public_key: bytes
|
||||
|
||||
def generate_keys_and_secret(
|
||||
self, code_code_entry: bytes, trezor_public_key: bytes
|
||||
) -> None:
|
||||
"""
|
||||
Generate ephemeral key pair and a shared secret using Elligator2 with X25519.
|
||||
"""
|
||||
sha_ctx = sha512(_PREFIX)
|
||||
sha_ctx.update(code_code_entry)
|
||||
sha_ctx.update(_PADDING)
|
||||
sha_ctx.update(self.handshake_hash)
|
||||
sha_ctx.update(b"\x00")
|
||||
pregenerator = sha_ctx.digest()[:32]
|
||||
generator = curve25519.elligator2(pregenerator)
|
||||
self.host_private_key = self.random_bytes(32)
|
||||
self.host_public_key = curve25519.multiply(self.host_private_key, generator)
|
||||
self.shared_secret = curve25519.multiply(
|
||||
self.host_private_key, trezor_public_key
|
||||
)
|
159
python/src/trezorlib/transport/thp/curve25519.py
Normal file
159
python/src/trezorlib/transport/thp/curve25519.py
Normal file
@ -0,0 +1,159 @@
|
||||
from typing import Tuple
|
||||
|
||||
p = 2**255 - 19
|
||||
J = 486662
|
||||
|
||||
c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1)
|
||||
c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8
|
||||
a24 = 121666 # (J + 2) // 4
|
||||
|
||||
|
||||
def decode_scalar(scalar: bytes) -> int:
|
||||
# decodeScalar25519 from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
|
||||
if len(scalar) != 32:
|
||||
raise ValueError("Invalid length of scalar")
|
||||
|
||||
array = bytearray(scalar)
|
||||
array[0] &= 248
|
||||
array[31] &= 127
|
||||
array[31] |= 64
|
||||
|
||||
return int.from_bytes(array, "little")
|
||||
|
||||
|
||||
def decode_coordinate(coordinate: bytes) -> int:
|
||||
# decodeUCoordinate from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
if len(coordinate) != 32:
|
||||
raise ValueError("Invalid length of coordinate")
|
||||
|
||||
array = bytearray(coordinate)
|
||||
array[-1] &= 0x7F
|
||||
return int.from_bytes(array, "little") % p
|
||||
|
||||
|
||||
def encode_coordinate(coordinate: int) -> bytes:
|
||||
# encodeUCoordinate from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
return coordinate.to_bytes(32, "little")
|
||||
|
||||
|
||||
def get_private_key(secret: bytes) -> bytes:
|
||||
return decode_scalar(secret).to_bytes(32, "little")
|
||||
|
||||
|
||||
def get_public_key(private_key: bytes) -> bytes:
|
||||
base_point = int.to_bytes(9, 32, "little")
|
||||
return multiply(private_key, base_point)
|
||||
|
||||
|
||||
def multiply(private_scalar: bytes, public_point: bytes):
|
||||
# X25519 from
|
||||
# https://datatracker.ietf.org/doc/html/rfc7748#section-5
|
||||
|
||||
def ladder_operation(
|
||||
x1: int, x2: int, z2: int, x3: int, z3: int
|
||||
) -> Tuple[int, int, int, int]:
|
||||
# https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3
|
||||
# (x4, z4) = 2 * (x2, z2)
|
||||
# (x5, z5) = (x2, z2) + (x3, z3)
|
||||
# where (x1, 1) = (x3, z3) - (x2, z2)
|
||||
|
||||
a = (x2 + z2) % p
|
||||
aa = (a * a) % p
|
||||
b = (x2 - z2) % p
|
||||
bb = (b * b) % p
|
||||
e = (aa - bb) % p
|
||||
c = (x3 + z3) % p
|
||||
d = (x3 - z3) % p
|
||||
da = (d * a) % p
|
||||
cb = (c * b) % p
|
||||
t0 = (da + cb) % p
|
||||
x5 = (t0 * t0) % p
|
||||
t1 = (da - cb) % p
|
||||
t2 = (t1 * t1) % p
|
||||
z5 = (x1 * t2) % p
|
||||
x4 = (aa * bb) % p
|
||||
t3 = (a24 * e) % p
|
||||
t4 = (bb + t3) % p
|
||||
z4 = (e * t4) % p
|
||||
|
||||
return x4, z4, x5, z5
|
||||
|
||||
def conditional_swap(first: int, second: int, condition: int):
|
||||
# Returns (second, first) if condition is true and (first, second) otherwise
|
||||
# Must be implemented in a way that it is constant time
|
||||
true_mask = -condition
|
||||
false_mask = ~true_mask
|
||||
return (first & false_mask) | (second & true_mask), (second & false_mask) | (
|
||||
first & true_mask
|
||||
)
|
||||
|
||||
k = decode_scalar(private_scalar)
|
||||
u = decode_coordinate(public_point)
|
||||
|
||||
x_1 = u
|
||||
x_2 = 1
|
||||
z_2 = 0
|
||||
x_3 = u
|
||||
z_3 = 1
|
||||
swap = 0
|
||||
|
||||
for i in reversed(range(256)):
|
||||
bit = (k >> i) & 1
|
||||
swap = bit ^ swap
|
||||
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||
swap = bit
|
||||
x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3)
|
||||
|
||||
(x_2, x_3) = conditional_swap(x_2, x_3, swap)
|
||||
(z_2, z_3) = conditional_swap(z_2, z_3, swap)
|
||||
|
||||
x = pow(z_2, p - 2, p) * x_2 % p
|
||||
return encode_coordinate(x)
|
||||
|
||||
|
||||
def elligator2(point: bytes) -> bytes:
|
||||
# map_to_curve_elligator2_curve25519 from
|
||||
# https://www.rfc-editor.org/rfc/rfc9380.html#ell2-opt
|
||||
|
||||
def conditional_move(first: int, second: int, condition: bool):
|
||||
# Returns second if condition is true and first otherwise
|
||||
# Must be implemented in a way that it is constant time
|
||||
true_mask = -condition
|
||||
false_mask = ~true_mask
|
||||
return (first & false_mask) | (second & true_mask)
|
||||
|
||||
u = decode_coordinate(point)
|
||||
tv1 = (u * u) % p
|
||||
tv1 = (2 * tv1) % p
|
||||
xd = (tv1 + 1) % p
|
||||
x1n = (-J) % p
|
||||
tv2 = (xd * xd) % p
|
||||
gxd = (tv2 * xd) % p
|
||||
gx1 = (J * tv1) % p
|
||||
gx1 = (gx1 * x1n) % p
|
||||
gx1 = (gx1 + tv2) % p
|
||||
gx1 = (gx1 * x1n) % p
|
||||
tv3 = (gxd * gxd) % p
|
||||
tv2 = (tv3 * tv3) % p
|
||||
tv3 = (tv3 * gxd) % p
|
||||
tv3 = (tv3 * gx1) % p
|
||||
tv2 = (tv2 * tv3) % p
|
||||
y11 = pow(tv2, c4, p)
|
||||
y11 = (y11 * tv3) % p
|
||||
y12 = (y11 * c3) % p
|
||||
tv2 = (y11 * y11) % p
|
||||
tv2 = (tv2 * gxd) % p
|
||||
e1 = tv2 == gx1
|
||||
y1 = conditional_move(y12, y11, e1)
|
||||
x2n = (x1n * tv1) % p
|
||||
tv2 = (y1 * y1) % p
|
||||
tv2 = (tv2 * gxd) % p
|
||||
e3 = tv2 == gx1
|
||||
xn = conditional_move(x2n, x1n, e3)
|
||||
x = xn * pow(xd, p - 2, p) % p
|
||||
return encode_coordinate(x)
|
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
82
python/src/trezorlib/transport/thp/message_header.py
Normal file
@ -0,0 +1,82 @@
|
||||
import struct
|
||||
|
||||
CODEC_V1 = 0x3F
|
||||
CONTINUATION_PACKET = 0x80
|
||||
HANDSHAKE_INIT_REQ = 0x00
|
||||
HANDSHAKE_INIT_RES = 0x01
|
||||
HANDSHAKE_COMP_REQ = 0x02
|
||||
HANDSHAKE_COMP_RES = 0x03
|
||||
ENCRYPTED_TRANSPORT = 0x04
|
||||
|
||||
CONTINUATION_PACKET_MASK = 0x80
|
||||
ACK_MASK = 0xF7
|
||||
DATA_MASK = 0xE7
|
||||
|
||||
ACK_MESSAGE = 0x20
|
||||
_ERROR = 0x42
|
||||
CHANNEL_ALLOCATION_REQ = 0x40
|
||||
_CHANNEL_ALLOCATION_RES = 0x41
|
||||
|
||||
TREZOR_STATE_UNPAIRED = b"\x00"
|
||||
TREZOR_STATE_PAIRED = b"\x01"
|
||||
|
||||
BROADCAST_CHANNEL_ID = 0xFFFF
|
||||
|
||||
|
||||
class MessageHeader:
|
||||
format_str_init = ">BHH"
|
||||
format_str_cont = ">BH"
|
||||
|
||||
def __init__(self, ctrl_byte: int, cid: int, length: int) -> None:
|
||||
self.ctrl_byte = ctrl_byte
|
||||
self.cid = cid
|
||||
self.data_length = length
|
||||
|
||||
def to_bytes_init(self) -> bytes:
|
||||
return struct.pack(
|
||||
self.format_str_init, self.ctrl_byte, self.cid, self.data_length
|
||||
)
|
||||
|
||||
def to_bytes_cont(self) -> bytes:
|
||||
return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid)
|
||||
|
||||
def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||
struct.pack_into(
|
||||
self.format_str_init,
|
||||
buffer,
|
||||
buffer_offset,
|
||||
self.ctrl_byte,
|
||||
self.cid,
|
||||
self.data_length,
|
||||
)
|
||||
|
||||
def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None:
|
||||
struct.pack_into(
|
||||
self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid
|
||||
)
|
||||
|
||||
def is_ack(self) -> bool:
|
||||
return self.ctrl_byte & ACK_MASK == ACK_MESSAGE
|
||||
|
||||
def is_channel_allocation_response(self):
|
||||
return (
|
||||
self.cid == BROADCAST_CHANNEL_ID
|
||||
and self.ctrl_byte == _CHANNEL_ALLOCATION_RES
|
||||
)
|
||||
|
||||
def is_handshake_init_response(self) -> bool:
|
||||
return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES
|
||||
|
||||
def is_handshake_comp_response(self) -> bool:
|
||||
return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES
|
||||
|
||||
def is_encrypted_transport(self) -> bool:
|
||||
return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT
|
||||
|
||||
@classmethod
|
||||
def get_error_header(cls, cid: int, length: int):
|
||||
return cls(_ERROR, cid, length)
|
||||
|
||||
@classmethod
|
||||
def get_channel_allocation_request_header(cls, length: int):
|
||||
return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length)
|
32
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
32
python/src/trezorlib/transport/thp/protocol_and_channel.py
Normal file
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from ... import messages
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
from ..thp.channel_data import ChannelData
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolAndChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.transport = transport
|
||||
self.mapping = mapping
|
||||
self.channel_keys = channel_data
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
raise NotImplementedError
|
||||
|
||||
def update_features(self) -> None:
|
||||
raise NotImplementedError
|
97
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
97
python/src/trezorlib/transport/thp/protocol_v1.py
Normal file
@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import typing as t
|
||||
|
||||
from ... import exceptions, messages
|
||||
from ...log import DUMP_BYTES
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProtocolV1(ProtocolAndChannel):
|
||||
HEADER_LEN = struct.calcsize(">HL")
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
self.write(messages.GetFeatures())
|
||||
resp = self.read()
|
||||
if not isinstance(resp, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = resp
|
||||
|
||||
def read(self) -> t.Any:
|
||||
msg_type, msg_bytes = self._read()
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
msg = self.mapping.decode(msg_type, msg_bytes)
|
||||
LOG.debug(
|
||||
f"received message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
self.transport.close()
|
||||
return msg
|
||||
|
||||
def write(self, msg: t.Any) -> None:
|
||||
LOG.debug(
|
||||
f"sending message: {msg.__class__.__name__}",
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
msg_type, msg_bytes = self.mapping.encode(msg)
|
||||
LOG.log(
|
||||
DUMP_BYTES,
|
||||
f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}",
|
||||
)
|
||||
self._write(msg_type, msg_bytes)
|
||||
|
||||
def _write(self, message_type: int, message_data: bytes) -> None:
|
||||
chunk_size = self.transport.CHUNK_SIZE
|
||||
header = struct.pack(">HL", message_type, len(message_data))
|
||||
buffer = bytearray(b"##" + header + message_data)
|
||||
|
||||
while buffer:
|
||||
# Report ID, data padded to 63 bytes
|
||||
chunk = b"?" + buffer[: chunk_size - 1]
|
||||
chunk = chunk.ljust(chunk_size, b"\x00")
|
||||
self.transport.write_chunk(chunk)
|
||||
buffer = buffer[63:]
|
||||
|
||||
def _read(self) -> t.Tuple[int, bytes]:
|
||||
buffer = bytearray()
|
||||
# Read header with first part of message data
|
||||
msg_type, datalen, first_chunk = self.read_first()
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < datalen:
|
||||
buffer.extend(self.read_next())
|
||||
|
||||
return msg_type, buffer[:datalen]
|
||||
|
||||
def read_first(self) -> t.Tuple[int, int, bytes]:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:3] != b"?##":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
try:
|
||||
msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN])
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[3 + self.HEADER_LEN :]
|
||||
return msg_type, datalen, data
|
||||
|
||||
def read_next(self) -> bytes:
|
||||
chunk = self.transport.read_chunk()
|
||||
if chunk[:1] != b"?":
|
||||
raise RuntimeError("Unexpected magic characters")
|
||||
return chunk[1:]
|
490
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
490
python/src/trezorlib/transport/thp/protocol_v2.py
Normal file
@ -0,0 +1,490 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import typing as t
|
||||
from binascii import hexlify
|
||||
|
||||
import click
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
from ... import exceptions, messages, protobuf
|
||||
from ...mapping import ProtobufMapping
|
||||
from .. import Transport
|
||||
from ..thp import checksum, curve25519, thp_io
|
||||
from ..thp.channel_data import ChannelData
|
||||
from ..thp.checksum import CHECKSUM_LENGTH
|
||||
from ..thp.message_header import MessageHeader
|
||||
from . import control_byte
|
||||
from .channel_database import ChannelDatabase, get_channel_db
|
||||
from .protocol_and_channel import ProtocolAndChannel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_SESSION_ID: int = 0
|
||||
|
||||
if t.TYPE_CHECKING:
|
||||
from ...debuglink import DebugLink
|
||||
MT = t.TypeVar("MT", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes:
|
||||
hash = hashlib.sha256(val_1)
|
||||
hash.update(val_2)
|
||||
return hash.digest()
|
||||
|
||||
|
||||
def _hkdf(chaining_key: bytes, input: bytes):
|
||||
temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest()
|
||||
output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest()
|
||||
ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256)
|
||||
ctx_output_2.update(b"\x02")
|
||||
output_2 = ctx_output_2.digest()
|
||||
return (output_1, output_2)
|
||||
|
||||
|
||||
def _get_iv_from_nonce(nonce: int) -> bytes:
|
||||
if not nonce <= 0xFFFFFFFFFFFFFFFF:
|
||||
raise ValueError("Nonce overflow, terminate the channel")
|
||||
return bytes(4) + nonce.to_bytes(8, "big")
|
||||
|
||||
|
||||
class ProtocolV2(ProtocolAndChannel):
|
||||
channel_id: int
|
||||
channel_database: ChannelDatabase
|
||||
key_request: bytes
|
||||
key_response: bytes
|
||||
nonce_request: int
|
||||
nonce_response: int
|
||||
sync_bit_send: int
|
||||
sync_bit_receive: int
|
||||
handshake_hash: bytes
|
||||
|
||||
_has_valid_channel: bool = False
|
||||
_features: messages.Features | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Transport,
|
||||
mapping: ProtobufMapping,
|
||||
channel_data: ChannelData | None = None,
|
||||
) -> None:
|
||||
self.channel_database: ChannelDatabase = get_channel_db()
|
||||
super().__init__(transport, mapping, channel_data)
|
||||
if channel_data is not None:
|
||||
self.channel_id = channel_data.channel_id
|
||||
self.key_request = bytes.fromhex(channel_data.key_request)
|
||||
self.key_response = bytes.fromhex(channel_data.key_response)
|
||||
self.nonce_request = channel_data.nonce_request
|
||||
self.nonce_response = channel_data.nonce_response
|
||||
self.sync_bit_receive = channel_data.sync_bit_receive
|
||||
self.sync_bit_send = channel_data.sync_bit_send
|
||||
self.handshake_hash = bytes.fromhex(channel_data.handshake_hash)
|
||||
self._has_valid_channel = True
|
||||
|
||||
def get_channel(self, helper_debug: DebugLink | None = None) -> ProtocolV2:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel(helper_debug)
|
||||
return self
|
||||
|
||||
def get_channel_data(self) -> ChannelData:
|
||||
return ChannelData(
|
||||
protocol_version_major=2,
|
||||
protocol_version_minor=2,
|
||||
transport_path=self.transport.get_path(),
|
||||
channel_id=self.channel_id,
|
||||
key_request=self.key_request,
|
||||
key_response=self.key_response,
|
||||
nonce_request=self.nonce_request,
|
||||
nonce_response=self.nonce_response,
|
||||
sync_bit_receive=self.sync_bit_receive,
|
||||
sync_bit_send=self.sync_bit_send,
|
||||
handshake_hash=self.handshake_hash,
|
||||
)
|
||||
|
||||
def read(self, session_id: int) -> t.Any:
|
||||
sid, msg_type, msg_data = self.read_and_decrypt()
|
||||
if sid != session_id:
|
||||
raise Exception("Received messsage on a different session.")
|
||||
self.channel_database.save_channel(self)
|
||||
return self.mapping.decode(msg_type, msg_data)
|
||||
|
||||
def write(self, session_id: int, msg: t.Any) -> None:
|
||||
msg_type, msg_data = self.mapping.encode(msg)
|
||||
self._encrypt_and_write(session_id, msg_type, msg_data)
|
||||
self.channel_database.save_channel(self)
|
||||
|
||||
def get_features(self) -> messages.Features:
|
||||
if not self._has_valid_channel:
|
||||
self._establish_new_channel()
|
||||
if self._features is None:
|
||||
self.update_features()
|
||||
assert self._features is not None
|
||||
return self._features
|
||||
|
||||
def update_features(self) -> None:
|
||||
message = messages.GetFeatures()
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
self.session_id: int = DEFAULT_SESSION_ID
|
||||
self._encrypt_and_write(DEFAULT_SESSION_ID, message_type, message_data)
|
||||
_ = self._read_until_valid_crc_check() # TODO check ACK
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
features = self.mapping.decode(msg_type, msg_data)
|
||||
if not isinstance(features, messages.Features):
|
||||
raise exceptions.TrezorException("Unexpected response to GetFeatures")
|
||||
self._features = features
|
||||
|
||||
def _send_message(
|
||||
self,
|
||||
message: protobuf.MessageType,
|
||||
session_id: int = DEFAULT_SESSION_ID,
|
||||
):
|
||||
message_type, message_data = self.mapping.encode(message)
|
||||
self._encrypt_and_write(session_id, message_type, message_data)
|
||||
self._read_ack()
|
||||
|
||||
def _read_message(self, message_type: type[MT]) -> MT:
|
||||
_, msg_type, msg_data = self.read_and_decrypt()
|
||||
msg = self.mapping.decode(msg_type, msg_data)
|
||||
assert isinstance(msg, message_type)
|
||||
return msg
|
||||
|
||||
def _establish_new_channel(self, helper_debug: DebugLink | None = None) -> None:
|
||||
self._reset_sync_bits()
|
||||
self._do_channel_allocation()
|
||||
self._do_handshake()
|
||||
self._do_pairing(helper_debug)
|
||||
|
||||
def _reset_sync_bits(self) -> None:
|
||||
self.sync_bit_send = 0
|
||||
self.sync_bit_receive = 0
|
||||
|
||||
def _do_channel_allocation(self) -> None:
|
||||
channel_allocation_nonce = os.urandom(8)
|
||||
self._send_channel_allocation_request(channel_allocation_nonce)
|
||||
cid, dp = self._read_channel_allocation_response(channel_allocation_nonce)
|
||||
self.channel_id = cid
|
||||
self.device_properties = dp
|
||||
|
||||
def _send_channel_allocation_request(self, nonce: bytes):
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
MessageHeader.get_channel_allocation_request_header(12),
|
||||
nonce,
|
||||
)
|
||||
|
||||
def _read_channel_allocation_response(
|
||||
self, expected_nonce: bytes
|
||||
) -> tuple[int, bytes]:
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not self._is_valid_channel_allocation_response(
|
||||
header, payload, expected_nonce
|
||||
):
|
||||
raise Exception("Invalid channel allocation response.")
|
||||
|
||||
channel_id = int.from_bytes(payload[8:10], "big")
|
||||
device_properties = payload[10:]
|
||||
return (channel_id, device_properties)
|
||||
|
||||
def _do_handshake(
|
||||
self, credential: bytes | None = None, host_static_privkey: bytes | None = None
|
||||
):
|
||||
host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32))
|
||||
host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey)
|
||||
|
||||
self._send_handshake_init_request(host_ephemeral_pubkey)
|
||||
self._read_ack()
|
||||
init_response = self._read_handshake_init_response()
|
||||
|
||||
trezor_ephemeral_pubkey = init_response[:32]
|
||||
encrypted_trezor_static_pubkey = init_response[32:80]
|
||||
noise_tag = init_response[80:96]
|
||||
LOG.debug("noise_tag: %s", hexlify(noise_tag).decode())
|
||||
|
||||
# TODO check noise_tag is valid
|
||||
|
||||
ck = self._send_handshake_completion_request(
|
||||
host_ephemeral_pubkey,
|
||||
host_ephemeral_privkey,
|
||||
trezor_ephemeral_pubkey,
|
||||
encrypted_trezor_static_pubkey,
|
||||
credential,
|
||||
host_static_privkey,
|
||||
)
|
||||
self._read_ack()
|
||||
self._read_handshake_completion_response()
|
||||
self.key_request, self.key_response = _hkdf(ck, b"")
|
||||
self.nonce_request = 0
|
||||
self.nonce_response = 1
|
||||
|
||||
def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None:
|
||||
ha_init_req_header = MessageHeader(0, self.channel_id, 36)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, ha_init_req_header, host_ephemeral_pubkey
|
||||
)
|
||||
|
||||
def _read_handshake_init_response(self) -> bytes:
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
self._send_ack_0()
|
||||
|
||||
if header.ctrl_byte == 0x42:
|
||||
if payload == b"\x05":
|
||||
raise exceptions.DeviceLockedException()
|
||||
|
||||
if not header.is_handshake_init_response():
|
||||
LOG.debug("Received message is not a valid handshake init response message")
|
||||
|
||||
click.echo(
|
||||
"Received message is not a valid handshake init response message",
|
||||
err=True,
|
||||
)
|
||||
return payload
|
||||
|
||||
def _send_handshake_completion_request(
|
||||
self,
|
||||
host_ephemeral_pubkey: bytes,
|
||||
host_ephemeral_privkey: bytes,
|
||||
trezor_ephemeral_pubkey: bytes,
|
||||
encrypted_trezor_static_pubkey: bytes,
|
||||
credential: bytes | None = None,
|
||||
host_static_privkey: bytes | None = None,
|
||||
) -> bytes:
|
||||
PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00"
|
||||
IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
|
||||
h = _sha256_of_two(PROTOCOL_NAME, self.device_properties)
|
||||
h = _sha256_of_two(h, host_ephemeral_pubkey)
|
||||
h = _sha256_of_two(h, trezor_ephemeral_pubkey)
|
||||
ck, k = _hkdf(
|
||||
PROTOCOL_NAME,
|
||||
curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey),
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
try:
|
||||
trezor_masked_static_pubkey = aes_ctx.decrypt(
|
||||
IV_1, encrypted_trezor_static_pubkey, h
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(
|
||||
f"Exception of type{type(e)}", err=True
|
||||
) # TODO how to handle potential exceptions? Q for Matejcik
|
||||
h = _sha256_of_two(h, encrypted_trezor_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey)
|
||||
)
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h)
|
||||
h = _sha256_of_two(h, tag_of_empty_string)
|
||||
|
||||
# TODO: search for saved credentials
|
||||
if host_static_privkey is not None and credential is not None:
|
||||
host_static_pubkey = curve25519.get_public_key(host_static_privkey)
|
||||
else:
|
||||
credential = None
|
||||
zeroes_32 = int.to_bytes(0, 32, "little")
|
||||
temp_host_static_privkey = curve25519.get_private_key(zeroes_32)
|
||||
temp_host_static_pubkey = curve25519.get_public_key(
|
||||
temp_host_static_privkey
|
||||
)
|
||||
host_static_privkey = temp_host_static_privkey
|
||||
host_static_pubkey = temp_host_static_pubkey
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h)
|
||||
h = _sha256_of_two(h, encrypted_host_static_pubkey)
|
||||
ck, k = _hkdf(
|
||||
ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey)
|
||||
)
|
||||
msg_data = self.mapping.encode_without_wire_type(
|
||||
messages.ThpHandshakeCompletionReqNoisePayload(
|
||||
host_pairing_credential=credential,
|
||||
)
|
||||
)
|
||||
|
||||
aes_ctx = AESGCM(k)
|
||||
|
||||
encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h)
|
||||
h = _sha256_of_two(h, encrypted_payload[:-16])
|
||||
ha_completion_req_header = MessageHeader(
|
||||
0x12,
|
||||
self.channel_id,
|
||||
len(encrypted_host_static_pubkey)
|
||||
+ len(encrypted_payload)
|
||||
+ CHECKSUM_LENGTH,
|
||||
)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport,
|
||||
ha_completion_req_header,
|
||||
encrypted_host_static_pubkey + encrypted_payload,
|
||||
)
|
||||
self.handshake_hash = h
|
||||
return ck
|
||||
|
||||
def _read_handshake_completion_response(self) -> None:
|
||||
# Read handshake completion response, ignore payload as we do not care about the state
|
||||
header, _ = self._read_until_valid_crc_check()
|
||||
if not header.is_handshake_comp_response():
|
||||
click.echo(
|
||||
"Received message is not a valid handshake completion response",
|
||||
err=True,
|
||||
)
|
||||
self._send_ack_1()
|
||||
|
||||
def _do_pairing(self, helper_debug: DebugLink | None):
|
||||
|
||||
self._send_message(messages.ThpPairingRequest())
|
||||
self._read_message(messages.ButtonRequest)
|
||||
self._send_message(messages.ButtonAck())
|
||||
|
||||
if helper_debug is not None:
|
||||
helper_debug.press_yes()
|
||||
|
||||
self._read_message(messages.ThpPairingRequestApproved)
|
||||
self._send_message(
|
||||
messages.ThpSelectMethod(
|
||||
selected_pairing_method=messages.ThpPairingMethod.SkipPairing
|
||||
)
|
||||
)
|
||||
self._read_message(messages.ThpEndResponse)
|
||||
|
||||
self._has_valid_channel = True
|
||||
|
||||
def _read_ack(self):
|
||||
header, payload = self._read_until_valid_crc_check()
|
||||
if not header.is_ack() or len(payload) > 0:
|
||||
click.echo("Received message is not a valid ACK", err=True)
|
||||
|
||||
def _send_ack_0(self):
|
||||
LOG.debug("sending ack 0")
|
||||
header = MessageHeader(0x20, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _send_ack_1(self):
|
||||
LOG.debug("sending ack 1")
|
||||
header = MessageHeader(0x28, self.channel_id, 4)
|
||||
thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"")
|
||||
|
||||
def _encrypt_and_write(
|
||||
self,
|
||||
session_id: int,
|
||||
message_type: int,
|
||||
message_data: bytes,
|
||||
ctrl_byte: int | None = None,
|
||||
) -> None:
|
||||
assert self.key_request is not None
|
||||
aes_ctx = AESGCM(self.key_request)
|
||||
|
||||
if ctrl_byte is None:
|
||||
ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send)
|
||||
self.sync_bit_send = 1 - self.sync_bit_send
|
||||
|
||||
sid = session_id.to_bytes(1, "big")
|
||||
msg_type = message_type.to_bytes(2, "big")
|
||||
data = sid + msg_type + message_data
|
||||
nonce = _get_iv_from_nonce(self.nonce_request)
|
||||
self.nonce_request += 1
|
||||
encrypted_message = aes_ctx.encrypt(nonce, data, b"")
|
||||
header = MessageHeader(
|
||||
ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH
|
||||
)
|
||||
|
||||
thp_io.write_payload_to_wire_and_add_checksum(
|
||||
self.transport, header, encrypted_message
|
||||
)
|
||||
|
||||
def read_and_decrypt(self) -> t.Tuple[int, int, bytes]:
|
||||
header, raw_payload = self._read_until_valid_crc_check()
|
||||
if control_byte.is_ack(header.ctrl_byte):
|
||||
# TODO fix this recursion
|
||||
return self.read_and_decrypt()
|
||||
if control_byte.is_error(header.ctrl_byte):
|
||||
# TODO check for different channel
|
||||
err = _get_error_from_int(raw_payload[0])
|
||||
raise Exception("Received ThpError: " + err)
|
||||
if not header.is_encrypted_transport():
|
||||
click.echo(
|
||||
"Trying to decrypt not encrypted message! ("
|
||||
+ hexlify(header.to_bytes_init() + raw_payload).decode()
|
||||
+ ")",
|
||||
err=True,
|
||||
)
|
||||
|
||||
if not control_byte.is_ack(header.ctrl_byte):
|
||||
LOG.debug(
|
||||
"--> Get sequence bit %d %s %s",
|
||||
control_byte.get_seq_bit(header.ctrl_byte),
|
||||
"from control byte",
|
||||
hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(),
|
||||
)
|
||||
if control_byte.get_seq_bit(header.ctrl_byte):
|
||||
self._send_ack_1()
|
||||
else:
|
||||
self._send_ack_0()
|
||||
aes_ctx = AESGCM(self.key_response)
|
||||
nonce = _get_iv_from_nonce(self.nonce_response)
|
||||
self.nonce_response += 1
|
||||
|
||||
message = aes_ctx.decrypt(nonce, raw_payload, b"")
|
||||
session_id = message[0]
|
||||
message_type = message[1:3]
|
||||
message_data = message[3:]
|
||||
return (
|
||||
session_id,
|
||||
int.from_bytes(message_type, "big"),
|
||||
message_data,
|
||||
)
|
||||
|
||||
def _read_until_valid_crc_check(
|
||||
self,
|
||||
) -> t.Tuple[MessageHeader, bytes]:
|
||||
is_valid = False
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
while not is_valid:
|
||||
is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload)
|
||||
if not is_valid:
|
||||
click.echo(
|
||||
"Received a message with an invalid checksum:"
|
||||
+ hexlify(header.to_bytes_init() + payload + chksum).decode(),
|
||||
err=True,
|
||||
)
|
||||
header, payload, chksum = thp_io.read(self.transport)
|
||||
|
||||
return header, payload
|
||||
|
||||
def _is_valid_channel_allocation_response(
|
||||
self, header: MessageHeader, payload: bytes, original_nonce: bytes
|
||||
) -> bool:
|
||||
if not header.is_channel_allocation_response():
|
||||
click.echo(
|
||||
"Received message is not a channel allocation response", err=True
|
||||
)
|
||||
return False
|
||||
if len(payload) < 10:
|
||||
click.echo("Invalid channel allocation response payload", err=True)
|
||||
return False
|
||||
if payload[:8] != original_nonce:
|
||||
click.echo(
|
||||
"Invalid channel allocation response payload (nonce mismatch)", err=True
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_error_from_int(error_code: int) -> str:
|
||||
# TODO FIXME improve this (ThpErrorType)
|
||||
if error_code == 1:
|
||||
return "TRANSPORT BUSY"
|
||||
if error_code == 2:
|
||||
return "UNALLOCATED CHANNEL"
|
||||
if error_code == 3:
|
||||
return "DECRYPTION FAILED"
|
||||
if error_code == 4:
|
||||
return "INVALID DATA"
|
||||
if error_code == 5:
|
||||
return "DEVICE LOCKED"
|
||||
raise Exception("Not Implemented error case")
|
93
python/src/trezorlib/transport/thp/thp_io.py
Normal file
93
python/src/trezorlib/transport/thp/thp_io.py
Normal file
@ -0,0 +1,93 @@
|
||||
import struct
|
||||
from typing import Tuple
|
||||
|
||||
from .. import Transport
|
||||
from ..thp import checksum
|
||||
from .message_header import MessageHeader
|
||||
|
||||
INIT_HEADER_LENGTH = 5
|
||||
CONT_HEADER_LENGTH = 3
|
||||
MAX_PAYLOAD_LEN = 60000
|
||||
MESSAGE_TYPE_LENGTH = 2
|
||||
|
||||
CONTINUATION_PACKET = 0x80
|
||||
|
||||
|
||||
def write_payload_to_wire_and_add_checksum(
|
||||
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||
):
|
||||
chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload)
|
||||
data = transport_payload + chksum
|
||||
write_payload_to_wire(transport, header, data)
|
||||
|
||||
|
||||
def write_payload_to_wire(
|
||||
transport: Transport, header: MessageHeader, transport_payload: bytes
|
||||
):
|
||||
transport.open()
|
||||
buffer = bytearray(transport_payload)
|
||||
chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH]
|
||||
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
|
||||
buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :]
|
||||
while buffer:
|
||||
chunk = (
|
||||
header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH]
|
||||
)
|
||||
chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00")
|
||||
transport.write_chunk(chunk)
|
||||
buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :]
|
||||
|
||||
|
||||
def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]:
|
||||
"""
|
||||
Reads from the given wire transport.
|
||||
|
||||
Returns `Tuple[MessageHeader, bytes, bytes]`:
|
||||
1. `header` (`MessageHeader`): Header of the message.
|
||||
2. `data` (`bytes`): Contents of the message (if any).
|
||||
3. `checksum` (`bytes`): crc32 checksum of the header + data.
|
||||
|
||||
"""
|
||||
buffer = bytearray()
|
||||
|
||||
# Read header with first part of message data
|
||||
header, first_chunk = read_first(transport)
|
||||
buffer.extend(first_chunk)
|
||||
|
||||
# Read the rest of the message
|
||||
while len(buffer) < header.data_length:
|
||||
buffer.extend(read_next(transport, header.cid))
|
||||
|
||||
data_len = header.data_length - checksum.CHECKSUM_LENGTH
|
||||
msg_data = buffer[:data_len]
|
||||
chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH]
|
||||
|
||||
return (header, msg_data, chksum)
|
||||
|
||||
|
||||
def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]:
|
||||
chunk = transport.read_chunk()
|
||||
try:
|
||||
ctrl_byte, cid, data_length = struct.unpack(
|
||||
MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH]
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError("Cannot parse header")
|
||||
|
||||
data = chunk[INIT_HEADER_LENGTH:]
|
||||
return MessageHeader(ctrl_byte, cid, data_length), data
|
||||
|
||||
|
||||
def read_next(transport: Transport, cid: int) -> bytes:
|
||||
chunk = transport.read_chunk()
|
||||
ctrl_byte, read_cid = struct.unpack(
|
||||
MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH]
|
||||
)
|
||||
if ctrl_byte != CONTINUATION_PACKET:
|
||||
raise RuntimeError("Continuation packet with incorrect control byte")
|
||||
if read_cid != cid:
|
||||
raise RuntimeError("Continuation packet for different channel")
|
||||
|
||||
return chunk[CONT_HEADER_LENGTH:]
|
@ -14,14 +14,15 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Iterable, Tuple
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UdpTransport(ProtocolBasedTransport):
|
||||
class UdpTransport(Transport):
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED: bool = True
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(self, device: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
device: str | None = None,
|
||||
) -> None:
|
||||
if not device:
|
||||
host = UdpTransport.DEFAULT_HOST
|
||||
port = UdpTransport.DEFAULT_PORT
|
||||
@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
devparts = device.split(":")
|
||||
host = devparts[0]
|
||||
port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT
|
||||
self.device = (host, port)
|
||||
self.socket: Optional[socket.socket] = None
|
||||
self.device: Tuple[str, int] = (host, port)
|
||||
|
||||
super().__init__(protocol=ProtocolV1(self))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
self.socket: socket.socket | None = None
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def _try_path(cls, path: str) -> "UdpTransport":
|
||||
d = cls(path)
|
||||
try:
|
||||
d.open()
|
||||
if d._ping():
|
||||
if d.ping():
|
||||
return d
|
||||
else:
|
||||
raise TransportException(
|
||||
@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
cls, _models: Iterable["TrezorModel"] | None = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
else:
|
||||
raise TransportException(f"No UDP device at {path}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self._ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
def get_path(self) -> str:
|
||||
return "{}:{}:{}".format(self.PATH_PREFIX, *self.device)
|
||||
|
||||
def open(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.close()
|
||||
self.socket = None
|
||||
|
||||
def _ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected data length")
|
||||
@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
self.socket.sendall(chunk)
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
if self.socket is None:
|
||||
self.open()
|
||||
assert self.socket is not None
|
||||
while True:
|
||||
try:
|
||||
@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return bytearray(chunk)
|
||||
|
||||
def find_debug(self) -> "UdpTransport":
|
||||
host, port = self.device
|
||||
return UdpTransport(f"{host}:{port + 1}")
|
||||
|
||||
def wait_until_ready(self, timeout: float = 10) -> None:
|
||||
try:
|
||||
self.open()
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if self.ping():
|
||||
break
|
||||
elapsed = time.monotonic() - start
|
||||
if elapsed >= timeout:
|
||||
raise TransportException("Timed out waiting for connection.")
|
||||
|
||||
time.sleep(0.05)
|
||||
finally:
|
||||
self.close()
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Test if the device is listening."""
|
||||
assert self.socket is not None
|
||||
resp = None
|
||||
try:
|
||||
self.socket.sendall(b"PINGPING")
|
||||
resp = self.socket.recv(8)
|
||||
except Exception:
|
||||
pass
|
||||
return resp == b"PONGPONG"
|
||||
|
@ -14,16 +14,17 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, List
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300
|
||||
WEBUSB_CHUNK_SIZE = 64
|
||||
|
||||
|
||||
class WebUsbHandle:
|
||||
def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None:
|
||||
class WebUsbTransport(Transport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
CHUNK_SIZE = 64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
|
||||
self.device = device
|
||||
self.debug = debug
|
||||
|
||||
self.interface = DEBUG_INTERFACE if debug else INTERFACE
|
||||
self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT
|
||||
self.count = 0
|
||||
self.handle: Optional["usb1.USBDeviceHandle"] = None
|
||||
self.handle: usb1.USBDeviceHandle | None = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
return devices
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = self.device.open()
|
||||
@ -64,6 +121,8 @@ class WebUsbHandle:
|
||||
self.handle.claimInterface(self.interface)
|
||||
except usb1.USBErrorAccess as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
except usb1.USBErrorBusy as e:
|
||||
raise DeviceIsBusy(self.device) from e
|
||||
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
@ -75,6 +134,8 @@ class WebUsbHandle:
|
||||
self.handle = None
|
||||
|
||||
def write_chunk(self, chunk: bytes) -> None:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
if len(chunk) != WEBUSB_CHUNK_SIZE:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
@ -97,6 +158,8 @@ class WebUsbHandle:
|
||||
return
|
||||
|
||||
def read_chunk(self) -> bytes:
|
||||
if self.handle is None:
|
||||
self.open()
|
||||
assert self.handle is not None
|
||||
endpoint = 0x80 | self.endpoint
|
||||
while True:
|
||||
@ -117,70 +180,6 @@ class WebUsbHandle:
|
||||
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
|
||||
return chunk
|
||||
|
||||
|
||||
class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
WebUsbTransport implements transport over WebUSB interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = USB_IMPORTED
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device: "usb1.USBDevice",
|
||||
handle: Optional[WebUsbHandle] = None,
|
||||
debug: bool = False,
|
||||
) -> None:
|
||||
if handle is None:
|
||||
handle = WebUsbHandle(device, debug)
|
||||
|
||||
self.device = device
|
||||
self.handle = handle
|
||||
self.debug = debug
|
||||
|
||||
super().__init__(protocol=ProtocolV1(handle))
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close)
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
try:
|
||||
# workaround for issue #223:
|
||||
# on certain combinations of Windows USB drivers and libusb versions,
|
||||
# Trezor is returned twice (possibly because Windows know it as both
|
||||
# a HID and a WebUSB device), and one of the returned devices is
|
||||
# non-functional.
|
||||
dev.getProduct()
|
||||
devices.append(WebUsbTransport(dev))
|
||||
except usb1.USBErrorNotSupported:
|
||||
pass
|
||||
except usb1.USBErrorPipe:
|
||||
if usb_reset:
|
||||
handle = dev.open()
|
||||
handle.resetDevice()
|
||||
handle.close()
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "WebUsbTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
return WebUsbTransport(self.device, debug=True)
|
||||
|
Loading…
Reference in New Issue
Block a user