1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 05:28:40 +00:00

fixup! feat(core): trezorctl working via BLE

This commit is contained in:
tychovrahe 2023-04-17 18:01:31 +02:00
parent f56a8710a8
commit f203932f54
6 changed files with 124 additions and 39 deletions

View File

@ -18,7 +18,7 @@ import functools
import sys
import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
import click
import dbus
@ -149,7 +149,9 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
loop = GLib.MainLoop()
dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
def callback_wrapper():
def callback_wrapper(
r: List[Optional["R"]], exc: List[Optional[Exception]]
) -> None:
try:
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
@ -158,9 +160,9 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
click.echo("Warning: failed to resume session.", err=True)
try:
return func(client, *args, **kwargs)
r.append(func(client, *args, **kwargs))
except Exception as e:
print(e)
exc[0] = e
finally:
if not session_was_resumed:
try:
@ -168,19 +170,70 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
except Exception:
pass
except Exception as e:
print(e)
exc[0] = e
finally:
loop.quit()
threading.Thread(target=callback_wrapper, daemon=True).start()
result: List["R"] = []
exc: List[Optional[Exception]] = [None]
threading.Thread(
target=callback_wrapper, daemon=True, args=(result, exc)
).start()
loop.run()
return None
if exc[0] is not None:
raise exc[0]
if len(result) == 0:
raise click.ClickException("Command did not return a result.")
return result[0]
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [cannot be assigned to return type]
def with_ble(func: "Callable[P, R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open.
"""
@functools.wraps(func)
def trezorctl_command(*args: "P.args", **kwargs: "P.kwargs") -> "R":
loop = GLib.MainLoop()
dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
def callback_wrapper(r: List[Optional["R"]], exc: List[Optional[Exception]]):
try:
r.append(func(*args, **kwargs))
except Exception as e:
exc[0] = e
finally:
loop.quit()
result: List["R"] = []
exc: List[Optional[Exception]] = [None]
threading.Thread(
target=callback_wrapper, daemon=True, args=(result, exc)
).start()
loop.run()
if exc[0] is not None:
raise exc[0]
if len(result) == 0:
raise click.ClickException("Command did not return a result.")
return result[0]
return trezorctl_command
class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility.

View File

@ -20,8 +20,9 @@ from typing import TYPE_CHECKING, BinaryIO
import click
from .. import ble, exceptions
from . import with_client
from .. import ble, exceptions, tealblue
from ..transport.ble import lookup_device, scan_device
from . import with_ble, with_client
if TYPE_CHECKING:
from ..client import TrezorClient
@ -60,3 +61,32 @@ def update(
except exceptions.TrezorException as e:
click.echo(f"Update failed: {e}")
sys.exit(3)
@cli.command()
@with_ble
def connect() -> None:
"""Connect to the device via BLE."""
adapter = tealblue.TealBlue().find_adapter()
devices = lookup_device(adapter)
devices = [d for d in devices if d.connected]
if len(devices) == 0:
print("Scanning...")
devices = scan_device(adapter, devices)
if len(devices) == 0:
print("No BLE devices found")
return
else:
print("Found %d BLE device(s)" % len(devices))
for device in devices:
print(f"Device: {device.name}, {device.address}")
device = devices[0]
print(f"Connecting to {device.name}...")
device.connect()
print("Connected")

View File

@ -49,6 +49,7 @@ from . import (
settings,
stellar,
tezos,
with_ble,
with_client,
)
@ -281,6 +282,7 @@ def format_device_name(features: messages.Features) -> str:
@cli.command(name="list")
@with_ble
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""

View File

@ -1,4 +1,5 @@
#!/usr/bin/python3
# !/usr/bin/python3
# pyright: off
import queue
import threading
@ -41,7 +42,7 @@ class TealBlue:
continue
properties = interfaces["org.bluez.Adapter1"]
return Adapter(self, path, properties)
return None # no adapter found
raise Exception("No adapter found")
# copied from:
# https://github.com/adafruit/Adafruit_Python_BluefruitLE/blob/master/Adafruit_BluefruitLE/bluez_dbus/provider.py

View File

@ -15,9 +15,10 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
from queue import Queue
from typing import TYPE_CHECKING, Iterable, Optional
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
from .. import tealblue
from ..tealblue import Adapter, Characteristic
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
@ -31,7 +32,7 @@ NUS_CHARACTERISTIC_RX = "6e400002-b5a3-f393-e0a9-e50e24dcca9e"
NUS_CHARACTERISTIC_TX = "6e400003-b5a3-f393-e0a9-e50e24dcca9e"
def scan_device(adapter, devices):
def scan_device(adapter: Adapter, devices: List[tealblue.Device]):
with adapter.scan(2) as scanner:
for device in scanner:
if NUS_SERVICE_UUID in device.UUIDs:
@ -41,7 +42,7 @@ def scan_device(adapter, devices):
return devices
def lookup_device(adapter):
def lookup_device(adapter: Adapter):
devices = []
for device in adapter.devices():
if NUS_SERVICE_UUID in device.UUIDs:
@ -53,7 +54,7 @@ class BleTransport(ProtocolBasedTransport):
ENABLED = True
PATH_PREFIX = "ble"
def __init__(self, mac_addr: str, adapter) -> None:
def __init__(self, mac_addr: str, adapter: Adapter) -> None:
self.tx = None
self.rx = None
@ -75,7 +76,7 @@ class BleTransport(ProtocolBasedTransport):
def find_debug(self) -> "BleTransport":
mac = self.device
return BleTransport(f"{mac}")
return BleTransport(f"{mac}", self.adapter)
@classmethod
def enumerate(
@ -84,32 +85,30 @@ class BleTransport(ProtocolBasedTransport):
adapter = tealblue.TealBlue().find_adapter()
devices = lookup_device(adapter)
devices = [d for d in devices if d.connected]
if len(devices) == 0:
print("Scanning...")
devices = scan_device(adapter, devices)
print("Found %d devices" % len(devices))
for device in devices:
print(f"Device: {device.name}, {device.address}")
devices = [d for d in devices if d.connected]
return [BleTransport(device.address, adapter) for device in devices]
# @classmethod
# def find_by_path(cls, path: str, prefix_search: bool = False) -> "BleTransport":
# try:
# path = path.replace(f"{cls.PATH_PREFIX}:", "")
# return cls._try_path(path)
# except TransportException:
# if not prefix_search:
# raise
#
# if prefix_search:
# return super().find_by_path(path, prefix_search)
# else:
# raise TransportException(f"No UDP device at {path}")
@classmethod
def _try_path(cls, path: str) -> "BleTransport":
devices = cls.enumerate(None)
devices = [d for d in devices if d.device == path]
if len(devices) == 0:
raise TransportException(f"No BLE device: {path}")
return devices[0]
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "BleTransport":
if not prefix_search:
raise TransportException
if prefix_search:
return super().find_by_path(path, prefix_search)
else:
raise TransportException(f"No BLE device: {path}")
def open(self) -> None:
@ -133,7 +132,7 @@ class BleTransport(ProtocolBasedTransport):
self.rx = service.characteristics[NUS_CHARACTERISTIC_RX]
self.tx = service.characteristics[NUS_CHARACTERISTIC_TX]
def on_notify(characteristic, value):
def on_notify(characteristic: Characteristic, value: Any):
self.received_data.put(bytes(value))
self.tx.on_notify = on_notify

View File

@ -75,7 +75,7 @@ class Protocol:
its messages.
"""
def __init__(self, handle: Handle, replen=REPLEN) -> None:
def __init__(self, handle: Handle, replen: int = REPLEN) -> None:
self.handle = handle
self.replen = replen
self.session_counter = 0