1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(core): update HidTransport to support NewTransport

[no changelog]
This commit is contained in:
M1nd3r 2024-09-23 17:15:15 +02:00
parent 6778181e2a
commit 30f05643fd
3 changed files with 60 additions and 67 deletions

View File

@ -2,7 +2,7 @@
import binascii
from trezorlib.client import TrezorClient
from trezorlib.transport_hid import HidTransport
from trezorlib.transport.hid import HidTransport
devices = HidTransport.enumerate()
if len(devices) > 0:

View File

@ -170,14 +170,15 @@ class NewTransport:
def all_transports() -> Iterable[Type["NewTransport"]]:
# from .bridge import BridgeTransport
# from .hid import HidTransport
# TODO add bridge and HID
# TODO add BridgeTransport
from .hid import HidTransport
from .udp import UdpTransport
from .webusb import WebUsbTransport
transports: Tuple[Type["NewTransport"], ...] = (
# BridgeTransport,
# HidTransport,
HidTransport,
UdpTransport,
WebUsbTransport,
)

View File

@ -1,6 +1,6 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
# Copyright (C) 2012-2024 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
@ -14,15 +14,16 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import logging
import sys
import time
from typing import Any, Dict, Iterable, List, Optional
import typing as t
from ..log import DUMP_PACKETS
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport
from . import UDEV_RULES_STR, NewTransport, TransportException
LOG = logging.getLogger(__name__)
@ -35,23 +36,61 @@ except Exception as e:
HID_IMPORTED = False
HidDevice = Dict[str, Any]
HidDeviceHandle = Any
HidDevice = t.Dict[str, t.Any]
HidDeviceHandle = t.Any
class HidHandle:
def __init__(
self, path: bytes, serial: str, probe_hid_version: bool = False
) -> None:
self.path = path
self.serial = serial
class HidTransport(NewTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
self.device = device
self.device_path = device["path"]
self.device_serial_number = device["serial_number"]
self.handle: HidDeviceHandle = None
self.hid_version = None if probe_hid_version else 2
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
) -> t.Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: t.List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def open(self) -> None:
self.handle = hid.device()
try:
self.handle.open_path(self.path)
self.handle.open_path(self.device_path)
except (IOError, OSError) as e:
if sys.platform.startswith("linux"):
e.args = e.args + (UDEV_RULES_STR,)
@ -62,11 +101,11 @@ class HidHandle:
# and we wouldn't even know.
# So we check that the serial matches what we expect.
serial = self.handle.get_serial_number_string()
if serial != self.serial:
if serial != self.device_serial_number:
self.handle.close()
self.handle = None
raise TransportException(
f"Unexpected device {serial} on path {self.path.decode()}"
f"Unexpected device {serial} on path {self.device_path.decode()}"
)
self.handle.set_nonblocking(True)
@ -77,7 +116,7 @@ class HidHandle:
def close(self) -> None:
if self.handle is not None:
# reload serial, because device.wipe() can reset it
self.serial = self.handle.get_serial_number_string()
self.device_serial_number = self.handle.get_serial_number_string()
self.handle.close()
self.handle = None
@ -115,53 +154,6 @@ class HidHandle:
raise TransportException("Unknown HID version")
class HidTransport(ProtocolBasedTransport):
"""
HidTransport implements transport over USB HID interface.
"""
PATH_PREFIX = "hid"
ENABLED = HID_IMPORTED
def __init__(self, device: HidDevice) -> None:
self.device = device
self.handle = HidHandle(device["path"], device["serial_number"])
protocol = self.get_protocol()
super().__init__(protocol)
def get_path(self) -> str:
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):
continue
else:
if not is_wirelink(dev):
continue
devices.append(HidTransport(dev))
return devices
def find_debug(self) -> "HidTransport":
# For v1 protocol, find debug USB interface for the same serial number
for debug in HidTransport.enumerate(debug=True):
if debug.device["serial_number"] == self.device["serial_number"]:
return debug
raise TransportException("Debug HID device not found")
def is_wirelink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0