From 39635fcf440d34473856544dc26289261e3bb8de Mon Sep 17 00:00:00 2001 From: tychovrahe Date: Mon, 17 Apr 2023 18:01:31 +0200 Subject: [PATCH] fixup! feat(core): trezorctl working via BLE --- python/src/trezorlib/cli/__init__.py | 67 +++++++++++++++++++--- python/src/trezorlib/cli/ble.py | 34 ++++++++++- python/src/trezorlib/cli/trezorctl.py | 2 + python/src/trezorlib/tealblue.py | 5 +- python/src/trezorlib/transport/ble.py | 53 +++++++++-------- python/src/trezorlib/transport/protocol.py | 2 +- 6 files changed, 124 insertions(+), 39 deletions(-) diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 058e3a514..ebbf6ae11 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -18,7 +18,7 @@ import functools import sys import threading from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional import click import dbus @@ -149,7 +149,9 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ loop = GLib.MainLoop() dbus.mainloop.glib.DBusGMainLoop(set_as_default=True) - def callback_wrapper(): + def callback_wrapper( + r: List[Optional["R"]], exc: List[Optional[Exception]] + ) -> None: try: with obj.client_context() as client: session_was_resumed = obj.session_id == client.session_id @@ -158,9 +160,9 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ click.echo("Warning: failed to resume session.", err=True) try: - return func(client, *args, **kwargs) + r.append(func(client, *args, **kwargs)) except Exception as e: - print(e) + exc[0] = e finally: if not session_was_resumed: try: @@ -168,19 +170,70 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ except Exception: pass except Exception as e: - print(e) + exc[0] = e finally: loop.quit() - threading.Thread(target=callback_wrapper, daemon=True).start() + result: List["R"] = [] + exc: List[Optional[Exception]] = [None] + threading.Thread( + target=callback_wrapper, daemon=True, args=(result, exc) + ).start() loop.run() - return None + + if exc[0] is not None: + raise exc[0] + + if len(result) == 0: + raise click.ClickException("Command did not return a result.") + + return result[0] # the return type of @click.pass_obj is improperly specified and pyright doesn't # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) return trezorctl_command_with_client # type: ignore [cannot be assigned to return type] +def with_ble(func: "Callable[P, R]") -> "Callable[P, R]": + """Wrap a Click command in `with obj.client_context() as client`. + + Sessions are handled transparently. The user is warned when session did not resume + cleanly. The session is closed after the command completes - unless the session + was resumed, in which case it should remain open. + """ + + @functools.wraps(func) + def trezorctl_command(*args: "P.args", **kwargs: "P.kwargs") -> "R": + + loop = GLib.MainLoop() + dbus.mainloop.glib.DBusGMainLoop(set_as_default=True) + + def callback_wrapper(r: List[Optional["R"]], exc: List[Optional[Exception]]): + try: + r.append(func(*args, **kwargs)) + except Exception as e: + exc[0] = e + finally: + loop.quit() + + result: List["R"] = [] + exc: List[Optional[Exception]] = [None] + threading.Thread( + target=callback_wrapper, daemon=True, args=(result, exc) + ).start() + loop.run() + + if exc[0] is not None: + raise exc[0] + + if len(result) == 0: + raise click.ClickException("Command did not return a result.") + + return result[0] + + return trezorctl_command + + class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. diff --git a/python/src/trezorlib/cli/ble.py b/python/src/trezorlib/cli/ble.py index a6ac7dadc..44d980e96 100644 --- a/python/src/trezorlib/cli/ble.py +++ b/python/src/trezorlib/cli/ble.py @@ -20,8 +20,9 @@ from typing import TYPE_CHECKING, BinaryIO import click -from .. import ble, exceptions -from . import with_client +from .. import ble, exceptions, tealblue +from ..transport.ble import lookup_device, scan_device +from . import with_ble, with_client if TYPE_CHECKING: from ..client import TrezorClient @@ -60,3 +61,32 @@ def update( except exceptions.TrezorException as e: click.echo(f"Update failed: {e}") sys.exit(3) + + +@cli.command() +@with_ble +def connect() -> None: + """Connect to the device via BLE.""" + adapter = tealblue.TealBlue().find_adapter() + + devices = lookup_device(adapter) + + devices = [d for d in devices if d.connected] + + if len(devices) == 0: + print("Scanning...") + devices = scan_device(adapter, devices) + + if len(devices) == 0: + print("No BLE devices found") + return + else: + print("Found %d BLE device(s)" % len(devices)) + + for device in devices: + print(f"Device: {device.name}, {device.address}") + + device = devices[0] + print(f"Connecting to {device.name}...") + device.connect() + print("Connected") diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index ff0fe7c6c..752337114 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -49,6 +49,7 @@ from . import ( settings, stellar, tezos, + with_ble, with_client, ) @@ -281,6 +282,7 @@ def format_device_name(features: messages.Features) -> str: @cli.command(name="list") +@with_ble @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" diff --git a/python/src/trezorlib/tealblue.py b/python/src/trezorlib/tealblue.py index f869507d9..3756c74e1 100755 --- a/python/src/trezorlib/tealblue.py +++ b/python/src/trezorlib/tealblue.py @@ -1,4 +1,5 @@ -#!/usr/bin/python3 +# !/usr/bin/python3 +# pyright: off import queue import threading @@ -41,7 +42,7 @@ class TealBlue: continue properties = interfaces["org.bluez.Adapter1"] return Adapter(self, path, properties) - return None # no adapter found + raise Exception("No adapter found") # copied from: # https://github.com/adafruit/Adafruit_Python_BluefruitLE/blob/master/Adafruit_BluefruitLE/bluez_dbus/provider.py diff --git a/python/src/trezorlib/transport/ble.py b/python/src/trezorlib/transport/ble.py index 55cdb289d..ee2474560 100644 --- a/python/src/trezorlib/transport/ble.py +++ b/python/src/trezorlib/transport/ble.py @@ -15,9 +15,10 @@ # If not, see . import logging from queue import Queue -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Any, Iterable, List, Optional from .. import tealblue +from ..tealblue import Adapter, Characteristic from . import TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -31,7 +32,7 @@ NUS_CHARACTERISTIC_RX = "6e400002-b5a3-f393-e0a9-e50e24dcca9e" NUS_CHARACTERISTIC_TX = "6e400003-b5a3-f393-e0a9-e50e24dcca9e" -def scan_device(adapter, devices): +def scan_device(adapter: Adapter, devices: List[tealblue.Device]): with adapter.scan(2) as scanner: for device in scanner: if NUS_SERVICE_UUID in device.UUIDs: @@ -41,7 +42,7 @@ def scan_device(adapter, devices): return devices -def lookup_device(adapter): +def lookup_device(adapter: Adapter): devices = [] for device in adapter.devices(): if NUS_SERVICE_UUID in device.UUIDs: @@ -53,7 +54,7 @@ class BleTransport(ProtocolBasedTransport): ENABLED = True PATH_PREFIX = "ble" - def __init__(self, mac_addr: str, adapter) -> None: + def __init__(self, mac_addr: str, adapter: Adapter) -> None: self.tx = None self.rx = None @@ -75,7 +76,7 @@ class BleTransport(ProtocolBasedTransport): def find_debug(self) -> "BleTransport": mac = self.device - return BleTransport(f"{mac}") + return BleTransport(f"{mac}", self.adapter) @classmethod def enumerate( @@ -84,32 +85,30 @@ class BleTransport(ProtocolBasedTransport): adapter = tealblue.TealBlue().find_adapter() devices = lookup_device(adapter) - devices = [d for d in devices if d.connected] - - if len(devices) == 0: - print("Scanning...") - devices = scan_device(adapter, devices) - - print("Found %d devices" % len(devices)) - for device in devices: print(f"Device: {device.name}, {device.address}") + devices = [d for d in devices if d.connected] + return [BleTransport(device.address, adapter) for device in devices] - # @classmethod - # def find_by_path(cls, path: str, prefix_search: bool = False) -> "BleTransport": - # try: - # path = path.replace(f"{cls.PATH_PREFIX}:", "") - # return cls._try_path(path) - # except TransportException: - # if not prefix_search: - # raise - # - # if prefix_search: - # return super().find_by_path(path, prefix_search) - # else: - # raise TransportException(f"No UDP device at {path}") + @classmethod + def _try_path(cls, path: str) -> "BleTransport": + devices = cls.enumerate(None) + devices = [d for d in devices if d.device == path] + if len(devices) == 0: + raise TransportException(f"No BLE device: {path}") + return devices[0] + + @classmethod + def find_by_path(cls, path: str, prefix_search: bool = False) -> "BleTransport": + if not prefix_search: + raise TransportException + + if prefix_search: + return super().find_by_path(path, prefix_search) + else: + raise TransportException(f"No BLE device: {path}") def open(self) -> None: @@ -133,7 +132,7 @@ class BleTransport(ProtocolBasedTransport): self.rx = service.characteristics[NUS_CHARACTERISTIC_RX] self.tx = service.characteristics[NUS_CHARACTERISTIC_TX] - def on_notify(characteristic, value): + def on_notify(characteristic: Characteristic, value: Any): self.received_data.put(bytes(value)) self.tx.on_notify = on_notify diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index 4e9433451..5aba296ee 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -75,7 +75,7 @@ class Protocol: its messages. """ - def __init__(self, handle: Handle, replen=REPLEN) -> None: + def __init__(self, handle: Handle, replen: int = REPLEN) -> None: self.handle = handle self.replen = replen self.session_counter = 0