mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): update HidTransport to support NewTransport
[no changelog]
This commit is contained in:
parent
6778181e2a
commit
30f05643fd
@ -2,7 +2,7 @@
|
||||
|
||||
import binascii
|
||||
from trezorlib.client import TrezorClient
|
||||
from trezorlib.transport_hid import HidTransport
|
||||
from trezorlib.transport.hid import HidTransport
|
||||
|
||||
devices = HidTransport.enumerate()
|
||||
if len(devices) > 0:
|
||||
|
@ -170,14 +170,15 @@ class NewTransport:
|
||||
|
||||
def all_transports() -> Iterable[Type["NewTransport"]]:
|
||||
# from .bridge import BridgeTransport
|
||||
# from .hid import HidTransport
|
||||
# TODO add bridge and HID
|
||||
# TODO add BridgeTransport
|
||||
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
transports: Tuple[Type["NewTransport"], ...] = (
|
||||
# BridgeTransport,
|
||||
# HidTransport,
|
||||
HidTransport,
|
||||
UdpTransport,
|
||||
WebUsbTransport,
|
||||
)
|
||||
|
@ -1,6 +1,6 @@
|
||||
# This file is part of the Trezor project.
|
||||
#
|
||||
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
||||
# Copyright (C) 2012-2024 SatoshiLabs and contributors
|
||||
#
|
||||
# This library is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Lesser General Public License version 3
|
||||
@ -14,15 +14,16 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
import typing as t
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport
|
||||
from . import UDEV_RULES_STR, NewTransport, TransportException
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
@ -35,23 +36,61 @@ except Exception as e:
|
||||
HID_IMPORTED = False
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
HidDevice = t.Dict[str, t.Any]
|
||||
HidDeviceHandle = t.Any
|
||||
|
||||
|
||||
class HidHandle:
|
||||
def __init__(
|
||||
self, path: bytes, serial: str, probe_hid_version: bool = False
|
||||
) -> None:
|
||||
self.path = path
|
||||
self.serial = serial
|
||||
class HidTransport(NewTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None:
|
||||
self.device = device
|
||||
self.device_path = device["path"]
|
||||
self.device_serial_number = device["serial_number"]
|
||||
self.handle: HidDeviceHandle = None
|
||||
self.hid_version = None if probe_hid_version else 2
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False
|
||||
) -> t.Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: t.List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
def open(self) -> None:
|
||||
self.handle = hid.device()
|
||||
try:
|
||||
self.handle.open_path(self.path)
|
||||
self.handle.open_path(self.device_path)
|
||||
except (IOError, OSError) as e:
|
||||
if sys.platform.startswith("linux"):
|
||||
e.args = e.args + (UDEV_RULES_STR,)
|
||||
@ -62,11 +101,11 @@ class HidHandle:
|
||||
# and we wouldn't even know.
|
||||
# So we check that the serial matches what we expect.
|
||||
serial = self.handle.get_serial_number_string()
|
||||
if serial != self.serial:
|
||||
if serial != self.device_serial_number:
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
raise TransportException(
|
||||
f"Unexpected device {serial} on path {self.path.decode()}"
|
||||
f"Unexpected device {serial} on path {self.device_path.decode()}"
|
||||
)
|
||||
|
||||
self.handle.set_nonblocking(True)
|
||||
@ -77,7 +116,7 @@ class HidHandle:
|
||||
def close(self) -> None:
|
||||
if self.handle is not None:
|
||||
# reload serial, because device.wipe() can reset it
|
||||
self.serial = self.handle.get_serial_number_string()
|
||||
self.device_serial_number = self.handle.get_serial_number_string()
|
||||
self.handle.close()
|
||||
self.handle = None
|
||||
|
||||
@ -115,53 +154,6 @@ class HidHandle:
|
||||
raise TransportException("Unknown HID version")
|
||||
|
||||
|
||||
class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
HidTransport implements transport over USB HID interface.
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = HID_IMPORTED
|
||||
|
||||
def __init__(self, device: HidDevice) -> None:
|
||||
self.device = device
|
||||
self.handle = HidHandle(device["path"], device["serial_number"])
|
||||
protocol = self.get_protocol()
|
||||
super().__init__(protocol)
|
||||
|
||||
def get_path(self) -> str:
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
||||
) -> Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
continue
|
||||
else:
|
||||
if not is_wirelink(dev):
|
||||
continue
|
||||
devices.append(HidTransport(dev))
|
||||
return devices
|
||||
|
||||
def find_debug(self) -> "HidTransport":
|
||||
# For v1 protocol, find debug USB interface for the same serial number
|
||||
for debug in HidTransport.enumerate(debug=True):
|
||||
if debug.device["serial_number"] == self.device["serial_number"]:
|
||||
return debug
|
||||
raise TransportException("Debug HID device not found")
|
||||
|
||||
|
||||
def is_wirelink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user