From 4a3e763a1bdd8317ae96ceedc38cb3bb6f3483a3 Mon Sep 17 00:00:00 2001 From: Martin Milata Date: Tue, 27 Jun 2023 20:19:08 +0200 Subject: [PATCH] feat(python): use dbus-next for BLE --- pyproject.toml | 4 +- python/requirements.txt | 4 +- python/setup.cfg | 2 +- python/src/trezorlib/cli/__init__.py | 106 +----- python/src/trezorlib/cli/ble.py | 30 +- python/src/trezorlib/cli/trezorctl.py | 2 - python/src/trezorlib/tealblue.py | 465 ++++++++++---------------- python/src/trezorlib/transport/ble.py | 240 +++++++++---- 8 files changed, 375 insertions(+), 478 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e2878dfcc..fe0d4eddf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,9 +77,7 @@ binsize = "^0.1.3" toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"} # ble -dbus-python = "*" -PyGObject = "*" -nrfutil = "*" +dbus-next = "*" [tool.poetry.dev-dependencies] scan-build = "*" diff --git a/python/requirements.txt b/python/requirements.txt index 1493b3e27..ebd5bdd66 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -8,6 +8,4 @@ typing_extensions>=3.10 dataclasses ; python_version<'3.7' simple-rlp>=0.1.2 ; python_version>='3.7' construct-classes>=0.1.2 -dbus-python>=1.3.2 -pygobject>=3.44.1 -nrfutil>=5.0.0 +dbus-next>=0.2.3 diff --git a/python/setup.cfg b/python/setup.cfg index 2889dd2e6..c7c16c34e 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -25,7 +25,7 @@ per-file-ignores = helper-scripts/*:I tools/*:I tests/*:I -known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],simple-rlp:[rlp],dbus-python:[dbus] +known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],simple-rlp:[rlp],dbus-next:[dbus_next] [isort] profile = black diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index d975de8bf..8a4191a8e 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -16,17 +16,12 @@ import functools import sys -import threading from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -import dbus -import dbus.mainloop.glib -import dbus.service -from gi.repository import GLib -from .. import exceptions, messages, transport +from .. import exceptions, transport from ..client import TrezorClient from ..ui import ClickUI, ScriptUI @@ -109,12 +104,6 @@ class TrezorConnection: except transport.DeviceIsBusy: click.echo("Device is in use by another process.") sys.exit(1) - except exceptions.TrezorFailure as e: - if e.code is messages.FailureType.DeviceIsBusy: - click.echo(str(e)) - sys.exit(1) - else: - raise e except Exception: click.echo("Failed to find a Trezor device.") if self.path is not None: @@ -146,95 +135,26 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ def trezorctl_command_with_client( obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": + with obj.client_context() as client: + session_was_resumed = obj.session_id == client.session_id + if not session_was_resumed and obj.session_id is not None: + # tried to resume but failed + click.echo("Warning: failed to resume session.", err=True) - loop = GLib.MainLoop() - dbus.mainloop.glib.DBusGMainLoop(set_as_default=True) - - def callback_wrapper( - r: List[Optional["R"]], exc: List[Optional[Exception]] - ) -> None: try: - with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) - - try: - r.append(func(client, *args, **kwargs)) - except Exception as e: - exc[0] = e - finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass - except Exception as e: - exc[0] = e + return func(client, *args, **kwargs) finally: - loop.quit() - - result: List["R"] = [] - exc: List[Optional[Exception]] = [None] - threading.Thread( - target=callback_wrapper, daemon=True, args=(result, exc) - ).start() - loop.run() - - if exc[0] is not None: - raise exc[0] - - if len(result) == 0: - raise click.ClickException("Command did not return a result.") - - return result[0] + if not session_was_resumed: + try: + client.end_session() + except Exception: + pass # the return type of @click.pass_obj is improperly specified and pyright doesn't # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) return trezorctl_command_with_client # type: ignore [cannot be assigned to return type] -def with_ble(func: "Callable[P, R]") -> "Callable[P, R]": - """Wrap a Click command in `with obj.client_context() as client`. - - Sessions are handled transparently. The user is warned when session did not resume - cleanly. The session is closed after the command completes - unless the session - was resumed, in which case it should remain open. - """ - - @functools.wraps(func) - def trezorctl_command(*args: "P.args", **kwargs: "P.kwargs") -> "R": - - loop = GLib.MainLoop() - dbus.mainloop.glib.DBusGMainLoop(set_as_default=True) - - def callback_wrapper(r: List[Optional["R"]], exc: List[Optional[Exception]]): - try: - r.append(func(*args, **kwargs)) - except Exception as e: - exc[0] = e - finally: - loop.quit() - - result: List["R"] = [] - exc: List[Optional[Exception]] = [None] - threading.Thread( - target=callback_wrapper, daemon=True, args=(result, exc) - ).start() - loop.run() - - if exc[0] is not None: - raise exc[0] - - if len(result) == 0: - raise click.ClickException("Command did not return a result.") - - return result[0] - - return trezorctl_command - - class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. diff --git a/python/src/trezorlib/cli/ble.py b/python/src/trezorlib/cli/ble.py index 53596dd67..c89a872b8 100644 --- a/python/src/trezorlib/cli/ble.py +++ b/python/src/trezorlib/cli/ble.py @@ -21,9 +21,9 @@ from typing import TYPE_CHECKING, BinaryIO import click -from .. import ble, exceptions, tealblue -from ..transport.ble import lookup_device, scan_device -from . import with_ble, with_client +from .. import ble, exceptions +from ..transport.ble import BleProxy +from . import with_client if TYPE_CHECKING: from ..client import TrezorClient @@ -70,32 +70,28 @@ def update( @cli.command() -@with_ble def connect() -> None: """Connect to the device via BLE.""" - adapter = tealblue.TealBlue().find_adapter() - - devices = lookup_device(adapter) - - devices = [d for d in devices if d.connected] + ble = BleProxy() + devices = [d for d in ble.lookup() if d.connected] if len(devices) == 0: - print("Scanning...") - devices = scan_device(adapter, devices) + click.echo("Scanning...") + devices = ble.scan() if len(devices) == 0: - print("No BLE devices found") + click.echo("No BLE devices found") return else: - print("Found %d BLE device(s)" % len(devices)) + click.echo("Found %d BLE device(s)" % len(devices)) for device in devices: - print(f"Device: {device.name}, {device.address}") + click.echo(f"Device: {device.name}, {device.address}") device = devices[0] - print(f"Connecting to {device.name}...") - device.connect() - print("Connected") + click.echo(f"Connecting to {device.name}...") + ble.connect(device.address) + click.echo("Connected") @cli.command() diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 752337114..ff0fe7c6c 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -49,7 +49,6 @@ from . import ( settings, stellar, tezos, - with_ble, with_client, ) @@ -282,7 +281,6 @@ def format_device_name(features: messages.Features) -> str: @cli.command(name="list") -@with_ble @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" diff --git a/python/src/trezorlib/tealblue.py b/python/src/trezorlib/tealblue.py index 3756c74e1..42b9ff7ec 100755 --- a/python/src/trezorlib/tealblue.py +++ b/python/src/trezorlib/tealblue.py @@ -1,201 +1,170 @@ # !/usr/bin/python3 # pyright: off -import queue -import threading -import time +import asyncio +import io +import logging -import dbus -import dbus.mainloop.glib -import dbus.service +import dbus_next +LOG = logging.getLogger(__name__) -class NotConnectedError(Exception): - pass - -class DBusInvalidArgsException(dbus.exceptions.DBusException): - _dbus_error_name = "org.freedesktop.DBus.Error.InvalidArgs" - - -def format_uuid(uuid): - if type(uuid) == int: - if uuid > 0xFFFF: - raise ValueError("32-bit UUID not supported yet") - uuid = "%04X" % uuid - return uuid +def unwrap_properties(properties): + return {k: v.value for k, v in properties.items()} class TealBlue: - def __init__(self): - self._bus = dbus.SystemBus() - self._bluez = dbus.Interface( - self._bus.get_object("org.bluez", "/"), "org.freedesktop.DBus.ObjectManager" - ) + @classmethod + async def create(cls): + self = cls() + self._bus = await dbus_next.aio.MessageBus( + bus_type=dbus_next.constants.BusType.SYSTEM, negotiate_unix_fd=True + ).connect() + obj = await self.get_object("org.bluez", "/") + self._bluez = obj.get_interface("org.freedesktop.DBus.ObjectManager") + + return self - def find_adapter(self): - # find the first adapter - objects = self._bluez.GetManagedObjects() + async def find_adapter(self, mac_filter=""): + """Find the first adapter matching mac_filter.""" + objects = await self._bluez.call_get_managed_objects() for path in sorted(objects.keys()): interfaces = objects[path] if "org.bluez.Adapter1" not in interfaces: continue - properties = interfaces["org.bluez.Adapter1"] - return Adapter(self, path, properties) + properties = unwrap_properties(interfaces["org.bluez.Adapter1"]) + if mac_filter not in properties["Address"]: + continue + return await Adapter.create(self, path, properties) raise Exception("No adapter found") - # copied from: - # https://github.com/adafruit/Adafruit_Python_BluefruitLE/blob/master/Adafruit_BluefruitLE/bluez_dbus/provider.py - def _print_tree(self): - """Print tree of all bluez objects, useful for debugging.""" - # This is based on the bluez sample code get-managed-objects.py. - objects = self._bluez.GetManagedObjects() - for path in sorted(objects.keys()): - print("[ %s ]" % (path)) - interfaces = objects[path] - for interface in sorted(interfaces.keys()): - if interface in [ - "org.freedesktop.DBus.Introspectable", - "org.freedesktop.DBus.Properties", - ]: - continue - print(" %s" % (interface)) - properties = interfaces[interface] - for key in sorted(properties.keys()): - print(" %s = %s" % (key, properties[key])) + async def get_object(self, name, path): + introspection = await self._bus.introspect(name, path) + obj = self._bus.get_proxy_object(name, path, introspection) + return obj class Adapter: - def __init__(self, teal, path, properties): + @classmethod + async def create(cls, teal, path, properties): + self = cls() self._teal = teal self._path = path self._properties = properties - self._object = dbus.Interface( - teal._bus.get_object("org.bluez", path), "org.bluez.Adapter1" - ) - self._advertisement = None + obj = await self._teal.get_object("org.bluez", path) + self._object = obj.get_interface("org.bluez.Adapter1") + + return self def __repr__(self): return "" % (self._properties["Address"]) - def devices(self): + async def devices(self): """ Returns the devices that BlueZ has discovered. """ - objects = self._teal._bluez.GetManagedObjects() + objects = await self._teal._bluez.call_get_managed_objects() + devices = [] for path in sorted(objects.keys()): interfaces = objects[path] if "org.bluez.Device1" not in interfaces: continue - properties = interfaces["org.bluez.Device1"] - yield Device(self._teal, path, properties) + properties = unwrap_properties(interfaces["org.bluez.Device1"]) + devices.append(await Device.create(self._teal, path, properties)) - def scan(self, timeout_s): - return Scanner(self._teal, self, self.devices(), timeout_s) - - @property - def advertisement(self): - if self._advertisement is None: - self._advertisement = Advertisement(self._teal, self) - return self._advertisement - - def advertise(self, enable): - if enable: - self.advertisement.enable() - else: - self.advertisement.disable() + return devices - def advertise_data( - self, - local_name=None, - service_data=None, - service_uuids=None, - manufacturer_data=None, - ): - self.advertisement.local_name = local_name - self.advertisement.service_data = service_data - self.advertisement.service_uuids = service_uuids - self.advertisement.manufacturer_data = manufacturer_data + async def scan(self, timeout_s): + return await Scanner.create(self._teal, self, await self.devices(), timeout_s) class Scanner: - def __init__(self, teal, adapter, initial_devices, timeout_s): + @classmethod + async def create(cls, teal, adapter, initial_devices, timeout_s): + self = cls() self._teal = teal self._adapter = adapter self._was_discovering = adapter._properties[ "Discovering" ] # TODO get current value, or watch property changes - self._queue = queue.Queue() + self._queue = asyncio.Queue() self.timeout_s = timeout_s for device in initial_devices: - self._queue.put(device) + self._queue.put_nowait((device._path, device._properties)) - def new_device(path, interfaces): - if "org.bluez.Device1" not in interfaces: - return - if not path.startswith(self._adapter._path + "/"): - return - # properties = interfaces["org.bluez.Device1"] - self._queue.put(Device(self._teal, path, interfaces["org.bluez.Device1"])) - - self._signal_receiver = self._teal._bus.add_signal_receiver( - new_device, - dbus_interface="org.freedesktop.DBus.ObjectManager", - signal_name="InterfacesAdded", - ) + self._teal._bluez.on_interfaces_added(self._on_iface_added) if not self._was_discovering: - self._adapter._object.StartDiscovery() + await self._adapter._object.call_start_discovery() + + return self + + def _on_iface_added(self, path, interfaces): + if "org.bluez.Device1" not in interfaces: + return + if not path.startswith(self._adapter._path + "/"): + return + properties = unwrap_properties(interfaces["org.bluez.Device1"]) + self._queue.put_nowait((path, properties)) - def __enter__(self): + async def __aenter__(self): return self - def __exit__(self, type, value, traceback): + async def __aexit__(self, type, value, traceback): if not self._was_discovering: - self._adapter._object.StopDiscovery() - self._signal_receiver.remove() + await self._adapter._object.call_stop_discovery() + self._teal._bluez.off_interfaces_added(self._on_iface_added) - def __iter__(self): + def __aiter__(self): return self - def __next__(self): + async def __anext__(self): try: - return self._queue.get(timeout=self.timeout_s) - except queue.Empty: - raise StopIteration + (path, properties) = await asyncio.wait_for( + self._queue.get(), self.timeout_s + ) + return await Device.create(self._teal, path, properties) + except asyncio.TimeoutError: + raise StopAsyncIteration class Device: - def __init__(self, teal, path, properties): + @classmethod + async def create(cls, teal, path, properties): + self = cls() self._teal = teal self._path = path self._properties = properties - self._services_resolved = threading.Event() + self._services_resolved = asyncio.Event() self._services = None if properties["ServicesResolved"]: self._services_resolved.set() # Listen to device events (connect, disconnect, ServicesResolved, ...) - self._device = dbus.Interface( - teal._bus.get_object("org.bluez", path), "org.bluez.Device1" - ) - self._device_props = dbus.Interface( - self._device, "org.freedesktop.DBus.Properties" - ) - self._signal_receiver = self._device_props.connect_to_signal( - "PropertiesChanged", - lambda itf, ch, inv: self._on_prop_changed(itf, ch, inv), - ) + obj = await self._teal.get_object("org.bluez", path) + self._device = obj.get_interface("org.bluez.Device1") + obj = await self._teal.get_object("org.bluez", path) + self._device_props = obj.get_interface("org.freedesktop.DBus.Properties") + self._device_props.on_properties_changed(self._on_prop_changed) + + return self def __del__(self): - self._signal_receiver.remove() + self._device_props.off_properties_changed(self._on_prop_changed) def __repr__(self): return "" % (self.address, self.name) - def _on_prop_changed(self, properties, changed_props, invalidated_props): + def _on_prop_changed(self, _interface, changed_props, invalidated_props): + changed_props = unwrap_properties(changed_props) + LOG.debug( + f"prop changed: device {self._path} {changed_props.keys()} {invalidated_props}" + ) for key, value in changed_props.items(): self._properties[key] = value + for key in invalidated_props: + del self._properties[key] if "ServicesResolved" in changed_props: if changed_props["ServicesResolved"]: @@ -203,36 +172,33 @@ class Device: else: self._services_resolved.clear() - def _wait_for_discovery(self): - # wait until ServicesResolved is True - self._services_resolved.wait() - - def connect(self): - self._device.Connect() - - def disconnect(self): - self._device.Disconnect() + async def connect(self): + await self._device.call_connect() - def resolve_services(self): - self._services_resolved.wait() + async def disconnect(self): + await self._device.call_disconnect() - @property - def services(self): - if not self._services_resolved.is_set(): - return None + async def services(self): + await self._services_resolved.wait() if self._services is None: self._services = {} - objects = self._teal._bluez.GetManagedObjects() + objects = await self._teal._bluez.call_get_managed_objects() for path in sorted(objects.keys()): if not path.startswith(self._path + "/"): continue if "org.bluez.GattService1" in objects[path]: - properties = objects[path]["org.bluez.GattService1"] + properties = unwrap_properties( + objects[path]["org.bluez.GattService1"] + ) service = Service(self._teal, self, path, properties) self._services[service.uuid] = service elif "org.bluez.GattCharacteristic1" in objects[path]: - properties = objects[path]["org.bluez.GattCharacteristic1"] - characterstic = Characteristic(self._teal, self, path, properties) + properties = unwrap_properties( + objects[path]["org.bluez.GattCharacteristic1"] + ) + characterstic = await Characteristic.create( + self._teal, self, path, properties + ) for service in self._services.values(): if properties["Service"] == service._path: service.characteristics[characterstic.uuid] = characterstic @@ -287,22 +253,24 @@ class Service: class Characteristic: - def __init__(self, teal, device, path, properties): + def __init__(self): + self._properties = {} + + @classmethod + async def create(cls, teal, device, path, properties): + self = cls() self._device = device self._teal = teal self._path = path self._properties = properties + self._values = asyncio.Queue() - self.on_notify = None + obj = await self._teal.get_object("org.bluez", path) + self._char = obj.get_interface("org.bluez.GattCharacteristic1") + self._props = obj.get_interface("org.freedesktop.DBus.Properties") + self._props.on_properties_changed(self._on_prop_changed) - self._char = dbus.Interface( - teal._bus.get_object("org.bluez", path), "org.bluez.GattCharacteristic1" - ) - char_props = dbus.Interface(self._char, "org.freedesktop.DBus.Properties") - self._signal_receiver = char_props.connect_to_signal( - "PropertiesChanged", - lambda itf, ch, inv: self._on_prop_changed(itf, ch, inv), - ) + return self def __repr__(self): return "" % ( @@ -311,147 +279,72 @@ class Characteristic: ) def __del__(self): - self._signal_receiver.remove() + self._props.off_properties_changed(self._on_prop_changed) - def _on_prop_changed(self, properties, changed_props, invalidated_props): + def _on_prop_changed(self, _interface, changed_props, invalidated_props): + changed_props = unwrap_properties(changed_props) + LOG.debug( + f"prop changed: characteristic {changed_props.keys()} {invalidated_props}" + ) for key, value in changed_props.items(): self._properties[key] = bytes(value) + for key in invalidated_props: + del self._properties[key] - if "Value" in changed_props and self.on_notify is not None: - self.on_notify(self, changed_props["Value"]) - - def read(self): - return bytes(self._char.ReadValue({})) - - def write(self, value, command=True): - start = time.time() - try: - if command: - self._char.WriteValue(value, {"type": "command"}) - else: - self._char.WriteValue(value, {"type": "request"}) - - except dbus.DBusException as e: - if ( - e.get_dbus_name() == "org.bluez.Error.Failed" - and e.get_dbus_message() == "Not connected" - ): - raise NotConnectedError() - else: - raise # some other error - - # Workaround: if the write took very long, it is possible the connection - # broke (without causing an exception). So check whether we are still - # connected. - # I think this is a bug in BlueZ. - if time.time() - start > 0.5: - if not self._device._device_props.Get("org.bluez.Device1", "Connected"): - raise NotConnectedError() - - def start_notify(self): - self._char.StartNotify() + if "Value" in changed_props: + self._values.put_nowait(changed_props["Value"]) - def stop_notify(self): - self._char.StopNotify() + async def acquire(self, write: bool = False) -> tuple[io.FileIO, int]: + if write: + fd, mtu = await self._char.call_acquire_write({}) + mode = "w" + else: + fd, mtu = await self._char.call_acquire_notify({}) + mode = "r" + + f = io.FileIO(fd, mode) + LOG.debug(f"acquired {self.uuid} ({mode})") + return f, mtu + + async def read(self) -> bytes: + return bytes(await self._values.get()) + + async def write(self, value, command=True): + ty = "command" if command else "request" + await self._char.call_write_value(value, {"type": dbus_next.Variant("s", ty)}) + + # + # async def write(self, value, command=True): + # start = time.time() + # try: + # if command: + # await self._char.call_write_value(value, {"type": "command"}) + # else: + # await self._char.call_write_value(value, {"type": "request"}) + # + # except dbus_next.DBusError as e: + # if ( + # e.type == "org.bluez.Error.Failed" + # and e.text == "Not connected" + # ): + # raise NotConnectedError() + # else: + # raise # some other error + # + # # Workaround: if the write took very long, it is possible the connection + # # broke (without causing an exception). So check whether we are still + # # connected. + # # I think this is a bug in BlueZ. + # if time.time() - start > 0.5: + # if not self._device._device_props.call_get("org.bluez.Device1", "Connected"): + # raise NotConnectedError() + + async def start_notify(self): + await self._char.call_start_notify() + + async def stop_notify(self): + await self._char.call_stop_notify() @property def uuid(self): return str(self._properties["UUID"]) - - -class Advertisement(dbus.service.Object): - PATH = "/com/github/aykevl/pynus/advertisement" - - def __init__(self, teal, adapter): - self._teal = teal - self._adapter = adapter - self._enabled = False - self.service_uuids = None - self.manufacturer_data = None - self.solicit_uuids = None - self.service_data = None - self.local_name = None - self.include_tx_power = None - self._manager = dbus.Interface( - teal._bus.get_object("org.bluez", self._adapter._path), - "org.bluez.LEAdvertisingManager1", - ) - self._adv_enabled = threading.Event() - dbus.service.Object.__init__(self, teal._bus, self.PATH) - - def enable(self): - if self._enabled: - return - self._manager.RegisterAdvertisement( - dbus.ObjectPath(self.PATH), - dbus.Dictionary({}, signature="sv"), - reply_handler=self._cb_enabled, - error_handler=self._cb_enabled_err, - ) - self._adv_enabled.wait() - self._adv_enabled.clear() - - def _cb_enabled(self): - self._enabled = True - if self._adv_enabled.is_set(): - raise RuntimeError("called enable() twice") - self._adv_enabled.set() - - def _cb_enabled_err(self, err): - self._enabled = False - if self._adv_enabled.is_set(): - raise RuntimeError("called enable() twice") - self._adv_enabled.set() - - def disable(self): - if not self._enabled: - return - self._bus.UnregisterAdvertisement(self.PATH) - self._enabled = False - - @property - def enabled(self): - return self._enabled - - @dbus.service.method( - "org.freedesktop.DBus.Properties", in_signature="s", out_signature="a{sv}" - ) - def GetAll(self, interface): - print("GetAll") - if interface != "org.bluez.LEAdvertisement1": - raise DBusInvalidArgsException() - - try: - properties = { - "Type": dbus.String("peripheral"), - } - if self.service_uuids is not None: - properties["ServiceUUIDs"] = dbus.Array( - map(format_uuid, self.service_uuids), signature="s" - ) - if self.solicit_uuids is not None: - properties["SolicitUUIDs"] = dbus.Array( - map(format_uuid, self.solicit_uuids), signature="s" - ) - if self.manufacturer_data is not None: - properties["ManufacturerData"] = dbus.Dictionary( - {k: v for k, v in self.manufacturer_data.items()}, signature="qv" - ) - if self.service_data is not None: - properties["ServiceData"] = dbus.Dictionary( - self.service_data, signature="sv" - ) - if self.local_name is not None: - properties["LocalName"] = dbus.String(self.local_name) - if self.include_tx_power is not None: - properties["IncludeTxPower"] = dbus.Boolean(self.include_tx_power) - except Exception as e: - print("err: ", e) - print("properties:", properties) - return properties - - @dbus.service.method( - "org.bluez.LEAdvertisement1", in_signature="", out_signature="" - ) - def Release(self): - self._enabled = True diff --git a/python/src/trezorlib/transport/ble.py b/python/src/trezorlib/transport/ble.py index ee2474560..c8a7415f1 100644 --- a/python/src/trezorlib/transport/ble.py +++ b/python/src/trezorlib/transport/ble.py @@ -13,12 +13,13 @@ # # You should have received a copy of the License along with this library. # If not, see . +import asyncio import logging -from queue import Queue +from dataclasses import dataclass +from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection from typing import TYPE_CHECKING, Any, Iterable, List, Optional -from .. import tealblue -from ..tealblue import Adapter, Characteristic from . import TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 @@ -32,65 +33,35 @@ NUS_CHARACTERISTIC_RX = "6e400002-b5a3-f393-e0a9-e50e24dcca9e" NUS_CHARACTERISTIC_TX = "6e400003-b5a3-f393-e0a9-e50e24dcca9e" -def scan_device(adapter: Adapter, devices: List[tealblue.Device]): - with adapter.scan(2) as scanner: - for device in scanner: - if NUS_SERVICE_UUID in device.UUIDs: - if device.address not in [d.address for d in devices]: - print(f"Found device: {device.name}, {device.address}") - devices.append(device) - return devices - - -def lookup_device(adapter: Adapter): - devices = [] - for device in adapter.devices(): - if NUS_SERVICE_UUID in device.UUIDs: - devices.append(device) - return devices +@dataclass +class Device: + address: str + name: str + connected: bool class BleTransport(ProtocolBasedTransport): ENABLED = True PATH_PREFIX = "ble" - def __init__(self, mac_addr: str, adapter: Adapter) -> None: + _ble = None - self.tx = None - self.rx = None + def __init__(self, mac_addr: str) -> None: self.device = mac_addr - self.adapter = adapter - self.received_data = Queue() - - devices = lookup_device(self.adapter) - - for d in devices: - if d.address == mac_addr: - self.ble_device = d - break - super().__init__(protocol=ProtocolV1(self, replen=244)) def get_path(self) -> str: return "{}:{}".format(self.PATH_PREFIX, self.device) def find_debug(self) -> "BleTransport": - mac = self.device - return BleTransport(f"{mac}", self.adapter) + return BleTransport(self.device) @classmethod def enumerate( cls, _models: Optional[Iterable["TrezorModel"]] = None ) -> Iterable["BleTransport"]: - adapter = tealblue.TealBlue().find_adapter() - devices = lookup_device(adapter) - - for device in devices: - print(f"Device: {device.name}, {device.address}") - - devices = [d for d in devices if d.connected] - - return [BleTransport(device.address, adapter) for device in devices] + devices = cls.ble().lookup() + return [BleTransport(device.address) for device in devices if device.connected] @classmethod def _try_path(cls, path: str) -> "BleTransport": @@ -111,44 +82,167 @@ class BleTransport(ProtocolBasedTransport): raise TransportException(f"No BLE device: {path}") def open(self) -> None: - - if not self.ble_device.connected: - print( - "Connecting to %s (%s)..." - % (self.ble_device.name, self.ble_device.address) - ) - self.ble_device.connect() - else: - print( - "Connected to %s (%s)." - % (self.ble_device.name, self.ble_device.address) - ) - - if not self.ble_device.services_resolved: - print("Resolving services...") - self.ble_device.resolve_services() - - service = self.ble_device.services[NUS_SERVICE_UUID] - self.rx = service.characteristics[NUS_CHARACTERISTIC_RX] - self.tx = service.characteristics[NUS_CHARACTERISTIC_TX] - - def on_notify(characteristic: Characteristic, value: Any): - self.received_data.put(bytes(value)) - - self.tx.on_notify = on_notify - self.tx.start_notify() + self.ble().connect(self.device) def close(self) -> None: pass def write_chunk(self, chunk: bytes) -> None: - assert self.rx is not None - self.rx.write(chunk) + self.ble().write(chunk) def read_chunk(self) -> bytes: - assert self.tx is not None - chunk = self.received_data.get() + chunk = self.ble().read() # LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}") if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + @classmethod + def ble(cls) -> "BleProxy": + if cls._ble is None: + cls._ble = BleProxy() + return cls._ble + + +class BleProxy: + pipe = None + process = None + + def __init__(self): + if self.pipe is not None: + return + + parent_pipe, child_pipe = Pipe() + self.pipe = parent_pipe + self.process = Process(target=BleAsync, args=(child_pipe,), daemon=True) + self.process.start() + + def __getattr__(self, name: str): + def f(*args: Any, **kwargs: Any): + assert self.pipe is not None + self.pipe.send((name, args, kwargs)) + result = self.pipe.recv() + if isinstance(result, BaseException): + raise result + return result + + return f + + +class BleAsync: + def __init__(self, pipe: Connection): + asyncio.run(self.main(pipe)) + + async def main(self, pipe: Connection): + from ..tealblue import TealBlue + + tb = await TealBlue.create() + # TODO: add cli option for mac_filter and pass it here + self.adapter = await tb.find_adapter() + # TODO: currently only one concurrent device is supported + # To support more devices, connect() needs to return a Connection and also has to + # spawn a task that will forward data between that Connection and rx,tx. + self.current = None + self.rx = None + self.tx = None + + self.devices = {} + await self.lookup() # populate self.devices + LOG.debug("async BLE process started") + + while True: + await ready(pipe) + cmd, args, kwargs = pipe.recv() + try: + result = await getattr(self, cmd)(*args, **kwargs) + except Exception as e: + LOG.exception("Error in async BLE process:") + await ready(pipe, write=True) + pipe.send(e) + break + else: + await ready(pipe, write=True) + pipe.send(result) + + async def lookup(self) -> List[Device]: + self.devices.clear() + for device in await self.adapter.devices(): + if NUS_SERVICE_UUID in device.UUIDs: + self.devices[device.address] = device + return [ + Device(device.address, device.name, device.connected) + for device in self.devices.values() + ] + + async def scan(self) -> List[Device]: + LOG.debug("Initiating scan") + # TODO: configurable timeout + scanner = await self.adapter.scan(2) + self.devices.clear() + async with scanner: + async for device in scanner: + if NUS_SERVICE_UUID in device.UUIDs: + if device.address not in self.devices: + LOG.debug(f"scan: {device.address}: {device.name}") + self.devices[device.address] = device + return [ + Device(device.address, device.name, device.connected) + for device in self.devices.values() + ] + + async def connect(self, address: str): + if self.current == address: + return + # elif self.current is not None: + # self.devices[self.current].disconnect() + + ble_device = self.devices[address] + if not ble_device.connected: + LOG.info("Connecting to %s (%s)..." % (ble_device.name, ble_device.address)) + await ble_device.connect() + else: + LOG.info("Connected to %s (%s)." % (ble_device.name, ble_device.address)) + + services = await ble_device.services() + nus_service = services[NUS_SERVICE_UUID] + self.rx, _mtu = await nus_service.characteristics[ + NUS_CHARACTERISTIC_RX + ].acquire(write=True) + self.tx, _mtu = await nus_service.characteristics[ + NUS_CHARACTERISTIC_TX + ].acquire() + self.current = address + + async def read(self): + assert self.tx is not None + await ready(self.tx) + return self.tx.read() + + async def write(self, chunk: bytes): + assert self.rx is not None + await ready(self.rx, write=True) + self.rx.write(chunk) + + +async def ready(f: Any, write: bool = False): + """Asynchronously wait for file-like object to become ready for reading or writing.""" + fd = f.fileno() + loop = asyncio.get_event_loop() + event = asyncio.Event() + + if write: + + def callback(): + event.set() + loop.remove_writer(fd) + + loop.add_writer(fd, callback) + else: + + def callback(): + event.set() + loop.remove_reader(fd) + + loop.add_reader(fd, callback) + + await event.wait()