1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

feat(python): use dbus-next for BLE

This commit is contained in:
Martin Milata 2023-06-27 20:19:08 +02:00 committed by tychovrahe
parent 331f337f6b
commit 502f6a065c
8 changed files with 372 additions and 475 deletions

View File

@ -77,9 +77,7 @@ binsize = "^0.1.3"
toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"}
# ble
dbus-python = "*"
PyGObject = "*"
nrfutil = "*"
dbus-next = "*"
[tool.poetry.dev-dependencies]
scan-build = "*"

View File

@ -8,6 +8,4 @@ typing_extensions>=3.10
dataclasses ; python_version<'3.7'
simple-rlp>=0.1.2 ; python_version>='3.7'
construct-classes>=0.1.2
dbus-python>=1.3.2
pygobject>=3.44.1
nrfutil>=5.0.0
dbus-next>=0.2.3

View File

@ -25,7 +25,7 @@ per-file-ignores =
helper-scripts/*:I
tools/*:I
tests/*:I
known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],simple-rlp:[rlp],dbus-python:[dbus]
known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],simple-rlp:[rlp],dbus-next:[dbus_next]
[isort]
profile = black

View File

@ -16,17 +16,12 @@
import functools
import sys
import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click
import dbus
import dbus.mainloop.glib
import dbus.service
from gi.repository import GLib
from .. import exceptions, messages, transport
from .. import exceptions, transport
from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI
@ -109,12 +104,6 @@ class TrezorConnection:
except transport.DeviceIsBusy:
click.echo("Device is in use by another process.")
sys.exit(1)
except exceptions.TrezorFailure as e:
if e.code is messages.FailureType.DeviceIsBusy:
click.echo(str(e))
sys.exit(1)
else:
raise e
except Exception:
click.echo("Failed to find a Trezor device.")
if self.path is not None:
@ -146,95 +135,26 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
loop = GLib.MainLoop()
dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
def callback_wrapper(
r: List[Optional["R"]], exc: List[Optional[Exception]]
) -> None:
try:
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
try:
r.append(func(client, *args, **kwargs))
except Exception as e:
exc[0] = e
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
except Exception as e:
exc[0] = e
return func(client, *args, **kwargs)
finally:
loop.quit()
result: List["R"] = []
exc: List[Optional[Exception]] = [None]
threading.Thread(
target=callback_wrapper, daemon=True, args=(result, exc)
).start()
loop.run()
if exc[0] is not None:
raise exc[0]
if len(result) == 0:
raise click.ClickException("Command did not return a result.")
return result[0]
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
# the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [cannot be assigned to return type]
def with_ble(func: "Callable[P, R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open.
"""
@functools.wraps(func)
def trezorctl_command(*args: "P.args", **kwargs: "P.kwargs") -> "R":
loop = GLib.MainLoop()
dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
def callback_wrapper(r: List[Optional["R"]], exc: List[Optional[Exception]]):
try:
r.append(func(*args, **kwargs))
except Exception as e:
exc[0] = e
finally:
loop.quit()
result: List["R"] = []
exc: List[Optional[Exception]] = [None]
threading.Thread(
target=callback_wrapper, daemon=True, args=(result, exc)
).start()
loop.run()
if exc[0] is not None:
raise exc[0]
if len(result) == 0:
raise click.ClickException("Command did not return a result.")
return result[0]
return trezorctl_command
class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility.

View File

@ -20,9 +20,9 @@ from typing import TYPE_CHECKING, BinaryIO
import click
from .. import ble, exceptions, tealblue
from ..transport.ble import lookup_device, scan_device
from . import with_ble, with_client
from .. import ble, exceptions
from ..transport.ble import BleProxy
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
@ -64,32 +64,28 @@ def update(
@cli.command()
@with_ble
def connect() -> None:
"""Connect to the device via BLE."""
adapter = tealblue.TealBlue().find_adapter()
devices = lookup_device(adapter)
devices = [d for d in devices if d.connected]
ble = BleProxy()
devices = [d for d in ble.lookup() if d.connected]
if len(devices) == 0:
print("Scanning...")
devices = scan_device(adapter, devices)
click.echo("Scanning...")
devices = ble.scan()
if len(devices) == 0:
print("No BLE devices found")
click.echo("No BLE devices found")
return
else:
print("Found %d BLE device(s)" % len(devices))
click.echo("Found %d BLE device(s)" % len(devices))
for device in devices:
print(f"Device: {device.name}, {device.address}")
click.echo(f"Device: {device.name}, {device.address}")
device = devices[0]
print(f"Connecting to {device.name}...")
device.connect()
print("Connected")
click.echo(f"Connecting to {device.name}...")
ble.connect(device.address)
click.echo("Connected")
@cli.command()

View File

@ -49,7 +49,6 @@ from . import (
settings,
stellar,
tezos,
with_ble,
with_client,
)
@ -282,7 +281,6 @@ def format_device_name(features: messages.Features) -> str:
@cli.command(name="list")
@with_ble
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices."""

View File

@ -1,201 +1,170 @@
# !/usr/bin/python3
# pyright: off
import queue
import threading
import time
import asyncio
import io
import logging
import dbus
import dbus.mainloop.glib
import dbus.service
import dbus_next
LOG = logging.getLogger(__name__)
class NotConnectedError(Exception):
pass
class DBusInvalidArgsException(dbus.exceptions.DBusException):
_dbus_error_name = "org.freedesktop.DBus.Error.InvalidArgs"
def format_uuid(uuid):
if type(uuid) == int:
if uuid > 0xFFFF:
raise ValueError("32-bit UUID not supported yet")
uuid = "%04X" % uuid
return uuid
def unwrap_properties(properties):
return {k: v.value for k, v in properties.items()}
class TealBlue:
def __init__(self):
self._bus = dbus.SystemBus()
self._bluez = dbus.Interface(
self._bus.get_object("org.bluez", "/"), "org.freedesktop.DBus.ObjectManager"
)
@classmethod
async def create(cls):
self = cls()
self._bus = await dbus_next.aio.MessageBus(
bus_type=dbus_next.constants.BusType.SYSTEM, negotiate_unix_fd=True
).connect()
obj = await self.get_object("org.bluez", "/")
self._bluez = obj.get_interface("org.freedesktop.DBus.ObjectManager")
def find_adapter(self):
# find the first adapter
objects = self._bluez.GetManagedObjects()
return self
async def find_adapter(self, mac_filter=""):
"""Find the first adapter matching mac_filter."""
objects = await self._bluez.call_get_managed_objects()
for path in sorted(objects.keys()):
interfaces = objects[path]
if "org.bluez.Adapter1" not in interfaces:
continue
properties = interfaces["org.bluez.Adapter1"]
return Adapter(self, path, properties)
properties = unwrap_properties(interfaces["org.bluez.Adapter1"])
if mac_filter not in properties["Address"]:
continue
return await Adapter.create(self, path, properties)
raise Exception("No adapter found")
# copied from:
# https://github.com/adafruit/Adafruit_Python_BluefruitLE/blob/master/Adafruit_BluefruitLE/bluez_dbus/provider.py
def _print_tree(self):
"""Print tree of all bluez objects, useful for debugging."""
# This is based on the bluez sample code get-managed-objects.py.
objects = self._bluez.GetManagedObjects()
for path in sorted(objects.keys()):
print("[ %s ]" % (path))
interfaces = objects[path]
for interface in sorted(interfaces.keys()):
if interface in [
"org.freedesktop.DBus.Introspectable",
"org.freedesktop.DBus.Properties",
]:
continue
print(" %s" % (interface))
properties = interfaces[interface]
for key in sorted(properties.keys()):
print(" %s = %s" % (key, properties[key]))
async def get_object(self, name, path):
introspection = await self._bus.introspect(name, path)
obj = self._bus.get_proxy_object(name, path, introspection)
return obj
class Adapter:
def __init__(self, teal, path, properties):
@classmethod
async def create(cls, teal, path, properties):
self = cls()
self._teal = teal
self._path = path
self._properties = properties
self._object = dbus.Interface(
teal._bus.get_object("org.bluez", path), "org.bluez.Adapter1"
)
self._advertisement = None
obj = await self._teal.get_object("org.bluez", path)
self._object = obj.get_interface("org.bluez.Adapter1")
return self
def __repr__(self):
return "<tealblue.Adapter address=%s>" % (self._properties["Address"])
def devices(self):
async def devices(self):
"""
Returns the devices that BlueZ has discovered.
"""
objects = self._teal._bluez.GetManagedObjects()
objects = await self._teal._bluez.call_get_managed_objects()
devices = []
for path in sorted(objects.keys()):
interfaces = objects[path]
if "org.bluez.Device1" not in interfaces:
continue
properties = interfaces["org.bluez.Device1"]
yield Device(self._teal, path, properties)
properties = unwrap_properties(interfaces["org.bluez.Device1"])
devices.append(await Device.create(self._teal, path, properties))
def scan(self, timeout_s):
return Scanner(self._teal, self, self.devices(), timeout_s)
return devices
@property
def advertisement(self):
if self._advertisement is None:
self._advertisement = Advertisement(self._teal, self)
return self._advertisement
def advertise(self, enable):
if enable:
self.advertisement.enable()
else:
self.advertisement.disable()
def advertise_data(
self,
local_name=None,
service_data=None,
service_uuids=None,
manufacturer_data=None,
):
self.advertisement.local_name = local_name
self.advertisement.service_data = service_data
self.advertisement.service_uuids = service_uuids
self.advertisement.manufacturer_data = manufacturer_data
async def scan(self, timeout_s):
return await Scanner.create(self._teal, self, await self.devices(), timeout_s)
class Scanner:
def __init__(self, teal, adapter, initial_devices, timeout_s):
@classmethod
async def create(cls, teal, adapter, initial_devices, timeout_s):
self = cls()
self._teal = teal
self._adapter = adapter
self._was_discovering = adapter._properties[
"Discovering"
] # TODO get current value, or watch property changes
self._queue = queue.Queue()
self._queue = asyncio.Queue()
self.timeout_s = timeout_s
for device in initial_devices:
self._queue.put(device)
self._queue.put_nowait((device._path, device._properties))
def new_device(path, interfaces):
if "org.bluez.Device1" not in interfaces:
return
if not path.startswith(self._adapter._path + "/"):
return
# properties = interfaces["org.bluez.Device1"]
self._queue.put(Device(self._teal, path, interfaces["org.bluez.Device1"]))
self._signal_receiver = self._teal._bus.add_signal_receiver(
new_device,
dbus_interface="org.freedesktop.DBus.ObjectManager",
signal_name="InterfacesAdded",
)
self._teal._bluez.on_interfaces_added(self._on_iface_added)
if not self._was_discovering:
self._adapter._object.StartDiscovery()
await self._adapter._object.call_start_discovery()
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
if not self._was_discovering:
self._adapter._object.StopDiscovery()
self._signal_receiver.remove()
def _on_iface_added(self, path, interfaces):
if "org.bluez.Device1" not in interfaces:
return
if not path.startswith(self._adapter._path + "/"):
return
properties = unwrap_properties(interfaces["org.bluez.Device1"])
self._queue.put_nowait((path, properties))
def __iter__(self):
async def __aenter__(self):
return self
def __next__(self):
async def __aexit__(self, type, value, traceback):
if not self._was_discovering:
await self._adapter._object.call_stop_discovery()
self._teal._bluez.off_interfaces_added(self._on_iface_added)
def __aiter__(self):
return self
async def __anext__(self):
try:
return self._queue.get(timeout=self.timeout_s)
except queue.Empty:
raise StopIteration
(path, properties) = await asyncio.wait_for(
self._queue.get(), self.timeout_s
)
return await Device.create(self._teal, path, properties)
except asyncio.TimeoutError:
raise StopAsyncIteration
class Device:
def __init__(self, teal, path, properties):
@classmethod
async def create(cls, teal, path, properties):
self = cls()
self._teal = teal
self._path = path
self._properties = properties
self._services_resolved = threading.Event()
self._services_resolved = asyncio.Event()
self._services = None
if properties["ServicesResolved"]:
self._services_resolved.set()
# Listen to device events (connect, disconnect, ServicesResolved, ...)
self._device = dbus.Interface(
teal._bus.get_object("org.bluez", path), "org.bluez.Device1"
)
self._device_props = dbus.Interface(
self._device, "org.freedesktop.DBus.Properties"
)
self._signal_receiver = self._device_props.connect_to_signal(
"PropertiesChanged",
lambda itf, ch, inv: self._on_prop_changed(itf, ch, inv),
)
obj = await self._teal.get_object("org.bluez", path)
self._device = obj.get_interface("org.bluez.Device1")
obj = await self._teal.get_object("org.bluez", path)
self._device_props = obj.get_interface("org.freedesktop.DBus.Properties")
self._device_props.on_properties_changed(self._on_prop_changed)
return self
def __del__(self):
self._signal_receiver.remove()
self._device_props.off_properties_changed(self._on_prop_changed)
def __repr__(self):
return "<tealblue.Device address=%s name=%r>" % (self.address, self.name)
def _on_prop_changed(self, properties, changed_props, invalidated_props):
def _on_prop_changed(self, _interface, changed_props, invalidated_props):
changed_props = unwrap_properties(changed_props)
LOG.debug(
f"prop changed: device {self._path} {changed_props.keys()} {invalidated_props}"
)
for key, value in changed_props.items():
self._properties[key] = value
for key in invalidated_props:
del self._properties[key]
if "ServicesResolved" in changed_props:
if changed_props["ServicesResolved"]:
@ -203,36 +172,33 @@ class Device:
else:
self._services_resolved.clear()
def _wait_for_discovery(self):
# wait until ServicesResolved is True
self._services_resolved.wait()
async def connect(self):
await self._device.call_connect()
def connect(self):
self._device.Connect()
async def disconnect(self):
await self._device.call_disconnect()
def disconnect(self):
self._device.Disconnect()
def resolve_services(self):
self._services_resolved.wait()
@property
def services(self):
if not self._services_resolved.is_set():
return None
async def services(self):
await self._services_resolved.wait()
if self._services is None:
self._services = {}
objects = self._teal._bluez.GetManagedObjects()
objects = await self._teal._bluez.call_get_managed_objects()
for path in sorted(objects.keys()):
if not path.startswith(self._path + "/"):
continue
if "org.bluez.GattService1" in objects[path]:
properties = objects[path]["org.bluez.GattService1"]
properties = unwrap_properties(
objects[path]["org.bluez.GattService1"]
)
service = Service(self._teal, self, path, properties)
self._services[service.uuid] = service
elif "org.bluez.GattCharacteristic1" in objects[path]:
properties = objects[path]["org.bluez.GattCharacteristic1"]
characterstic = Characteristic(self._teal, self, path, properties)
properties = unwrap_properties(
objects[path]["org.bluez.GattCharacteristic1"]
)
characterstic = await Characteristic.create(
self._teal, self, path, properties
)
for service in self._services.values():
if properties["Service"] == service._path:
service.characteristics[characterstic.uuid] = characterstic
@ -287,22 +253,24 @@ class Service:
class Characteristic:
def __init__(self, teal, device, path, properties):
def __init__(self):
self._properties = {}
@classmethod
async def create(cls, teal, device, path, properties):
self = cls()
self._device = device
self._teal = teal
self._path = path
self._properties = properties
self._values = asyncio.Queue()
self.on_notify = None
obj = await self._teal.get_object("org.bluez", path)
self._char = obj.get_interface("org.bluez.GattCharacteristic1")
self._props = obj.get_interface("org.freedesktop.DBus.Properties")
self._props.on_properties_changed(self._on_prop_changed)
self._char = dbus.Interface(
teal._bus.get_object("org.bluez", path), "org.bluez.GattCharacteristic1"
)
char_props = dbus.Interface(self._char, "org.freedesktop.DBus.Properties")
self._signal_receiver = char_props.connect_to_signal(
"PropertiesChanged",
lambda itf, ch, inv: self._on_prop_changed(itf, ch, inv),
)
return self
def __repr__(self):
return "<tealblue.Characteristic device=%s uuid=%s>" % (
@ -311,147 +279,72 @@ class Characteristic:
)
def __del__(self):
self._signal_receiver.remove()
self._props.off_properties_changed(self._on_prop_changed)
def _on_prop_changed(self, properties, changed_props, invalidated_props):
def _on_prop_changed(self, _interface, changed_props, invalidated_props):
changed_props = unwrap_properties(changed_props)
LOG.debug(
f"prop changed: characteristic {changed_props.keys()} {invalidated_props}"
)
for key, value in changed_props.items():
self._properties[key] = bytes(value)
for key in invalidated_props:
del self._properties[key]
if "Value" in changed_props and self.on_notify is not None:
self.on_notify(self, changed_props["Value"])
if "Value" in changed_props:
self._values.put_nowait(changed_props["Value"])
def read(self):
return bytes(self._char.ReadValue({}))
async def acquire(self, write: bool = False) -> tuple[io.FileIO, int]:
if write:
fd, mtu = await self._char.call_acquire_write({})
mode = "w"
else:
fd, mtu = await self._char.call_acquire_notify({})
mode = "r"
def write(self, value, command=True):
start = time.time()
try:
if command:
self._char.WriteValue(value, {"type": "command"})
else:
self._char.WriteValue(value, {"type": "request"})
f = io.FileIO(fd, mode)
LOG.debug(f"acquired {self.uuid} ({mode})")
return f, mtu
except dbus.DBusException as e:
if (
e.get_dbus_name() == "org.bluez.Error.Failed"
and e.get_dbus_message() == "Not connected"
):
raise NotConnectedError()
else:
raise # some other error
async def read(self) -> bytes:
return bytes(await self._values.get())
# Workaround: if the write took very long, it is possible the connection
# broke (without causing an exception). So check whether we are still
# connected.
# I think this is a bug in BlueZ.
if time.time() - start > 0.5:
if not self._device._device_props.Get("org.bluez.Device1", "Connected"):
raise NotConnectedError()
async def write(self, value, command=True):
ty = "command" if command else "request"
await self._char.call_write_value(value, {"type": dbus_next.Variant("s", ty)})
def start_notify(self):
self._char.StartNotify()
#
# async def write(self, value, command=True):
# start = time.time()
# try:
# if command:
# await self._char.call_write_value(value, {"type": "command"})
# else:
# await self._char.call_write_value(value, {"type": "request"})
#
# except dbus_next.DBusError as e:
# if (
# e.type == "org.bluez.Error.Failed"
# and e.text == "Not connected"
# ):
# raise NotConnectedError()
# else:
# raise # some other error
#
# # Workaround: if the write took very long, it is possible the connection
# # broke (without causing an exception). So check whether we are still
# # connected.
# # I think this is a bug in BlueZ.
# if time.time() - start > 0.5:
# if not self._device._device_props.call_get("org.bluez.Device1", "Connected"):
# raise NotConnectedError()
def stop_notify(self):
self._char.StopNotify()
async def start_notify(self):
await self._char.call_start_notify()
async def stop_notify(self):
await self._char.call_stop_notify()
@property
def uuid(self):
return str(self._properties["UUID"])
class Advertisement(dbus.service.Object):
PATH = "/com/github/aykevl/pynus/advertisement"
def __init__(self, teal, adapter):
self._teal = teal
self._adapter = adapter
self._enabled = False
self.service_uuids = None
self.manufacturer_data = None
self.solicit_uuids = None
self.service_data = None
self.local_name = None
self.include_tx_power = None
self._manager = dbus.Interface(
teal._bus.get_object("org.bluez", self._adapter._path),
"org.bluez.LEAdvertisingManager1",
)
self._adv_enabled = threading.Event()
dbus.service.Object.__init__(self, teal._bus, self.PATH)
def enable(self):
if self._enabled:
return
self._manager.RegisterAdvertisement(
dbus.ObjectPath(self.PATH),
dbus.Dictionary({}, signature="sv"),
reply_handler=self._cb_enabled,
error_handler=self._cb_enabled_err,
)
self._adv_enabled.wait()
self._adv_enabled.clear()
def _cb_enabled(self):
self._enabled = True
if self._adv_enabled.is_set():
raise RuntimeError("called enable() twice")
self._adv_enabled.set()
def _cb_enabled_err(self, err):
self._enabled = False
if self._adv_enabled.is_set():
raise RuntimeError("called enable() twice")
self._adv_enabled.set()
def disable(self):
if not self._enabled:
return
self._bus.UnregisterAdvertisement(self.PATH)
self._enabled = False
@property
def enabled(self):
return self._enabled
@dbus.service.method(
"org.freedesktop.DBus.Properties", in_signature="s", out_signature="a{sv}"
)
def GetAll(self, interface):
print("GetAll")
if interface != "org.bluez.LEAdvertisement1":
raise DBusInvalidArgsException()
try:
properties = {
"Type": dbus.String("peripheral"),
}
if self.service_uuids is not None:
properties["ServiceUUIDs"] = dbus.Array(
map(format_uuid, self.service_uuids), signature="s"
)
if self.solicit_uuids is not None:
properties["SolicitUUIDs"] = dbus.Array(
map(format_uuid, self.solicit_uuids), signature="s"
)
if self.manufacturer_data is not None:
properties["ManufacturerData"] = dbus.Dictionary(
{k: v for k, v in self.manufacturer_data.items()}, signature="qv"
)
if self.service_data is not None:
properties["ServiceData"] = dbus.Dictionary(
self.service_data, signature="sv"
)
if self.local_name is not None:
properties["LocalName"] = dbus.String(self.local_name)
if self.include_tx_power is not None:
properties["IncludeTxPower"] = dbus.Boolean(self.include_tx_power)
except Exception as e:
print("err: ", e)
print("properties:", properties)
return properties
@dbus.service.method(
"org.bluez.LEAdvertisement1", in_signature="", out_signature=""
)
def Release(self):
self._enabled = True

View File

@ -13,12 +13,13 @@
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import asyncio
import logging
from queue import Queue
from dataclasses import dataclass
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
from .. import tealblue
from ..tealblue import Adapter, Characteristic
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
@ -32,65 +33,35 @@ NUS_CHARACTERISTIC_RX = "6e400002-b5a3-f393-e0a9-e50e24dcca9e"
NUS_CHARACTERISTIC_TX = "6e400003-b5a3-f393-e0a9-e50e24dcca9e"
def scan_device(adapter: Adapter, devices: List[tealblue.Device]):
with adapter.scan(2) as scanner:
for device in scanner:
if NUS_SERVICE_UUID in device.UUIDs:
if device.address not in [d.address for d in devices]:
print(f"Found device: {device.name}, {device.address}")
devices.append(device)
return devices
def lookup_device(adapter: Adapter):
devices = []
for device in adapter.devices():
if NUS_SERVICE_UUID in device.UUIDs:
devices.append(device)
return devices
@dataclass
class Device:
address: str
name: str
connected: bool
class BleTransport(ProtocolBasedTransport):
ENABLED = True
PATH_PREFIX = "ble"
def __init__(self, mac_addr: str, adapter: Adapter) -> None:
_ble = None
self.tx = None
self.rx = None
def __init__(self, mac_addr: str) -> None:
self.device = mac_addr
self.adapter = adapter
self.received_data = Queue()
devices = lookup_device(self.adapter)
for d in devices:
if d.address == mac_addr:
self.ble_device = d
break
super().__init__(protocol=ProtocolV1(self, replen=244))
def get_path(self) -> str:
return "{}:{}".format(self.PATH_PREFIX, self.device)
def find_debug(self) -> "BleTransport":
mac = self.device
return BleTransport(f"{mac}", self.adapter)
return BleTransport(self.device)
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BleTransport"]:
adapter = tealblue.TealBlue().find_adapter()
devices = lookup_device(adapter)
for device in devices:
print(f"Device: {device.name}, {device.address}")
devices = [d for d in devices if d.connected]
return [BleTransport(device.address, adapter) for device in devices]
devices = cls.ble().lookup()
return [BleTransport(device.address) for device in devices if device.connected]
@classmethod
def _try_path(cls, path: str) -> "BleTransport":
@ -111,44 +82,167 @@ class BleTransport(ProtocolBasedTransport):
raise TransportException(f"No BLE device: {path}")
def open(self) -> None:
if not self.ble_device.connected:
print(
"Connecting to %s (%s)..."
% (self.ble_device.name, self.ble_device.address)
)
self.ble_device.connect()
else:
print(
"Connected to %s (%s)."
% (self.ble_device.name, self.ble_device.address)
)
if not self.ble_device.services_resolved:
print("Resolving services...")
self.ble_device.resolve_services()
service = self.ble_device.services[NUS_SERVICE_UUID]
self.rx = service.characteristics[NUS_CHARACTERISTIC_RX]
self.tx = service.characteristics[NUS_CHARACTERISTIC_TX]
def on_notify(characteristic: Characteristic, value: Any):
self.received_data.put(bytes(value))
self.tx.on_notify = on_notify
self.tx.start_notify()
self.ble().connect(self.device)
def close(self) -> None:
pass
def write_chunk(self, chunk: bytes) -> None:
assert self.rx is not None
self.rx.write(chunk)
self.ble().write(chunk)
def read_chunk(self) -> bytes:
assert self.tx is not None
chunk = self.received_data.get()
chunk = self.ble().read()
# LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk)
@classmethod
def ble(cls) -> "BleProxy":
if cls._ble is None:
cls._ble = BleProxy()
return cls._ble
class BleProxy:
pipe = None
process = None
def __init__(self):
if self.pipe is not None:
return
parent_pipe, child_pipe = Pipe()
self.pipe = parent_pipe
self.process = Process(target=BleAsync, args=(child_pipe,), daemon=True)
self.process.start()
def __getattr__(self, name: str):
def f(*args: Any, **kwargs: Any):
assert self.pipe is not None
self.pipe.send((name, args, kwargs))
result = self.pipe.recv()
if isinstance(result, BaseException):
raise result
return result
return f
class BleAsync:
def __init__(self, pipe: Connection):
asyncio.run(self.main(pipe))
async def main(self, pipe: Connection):
from ..tealblue import TealBlue
tb = await TealBlue.create()
# TODO: add cli option for mac_filter and pass it here
self.adapter = await tb.find_adapter()
# TODO: currently only one concurrent device is supported
# To support more devices, connect() needs to return a Connection and also has to
# spawn a task that will forward data between that Connection and rx,tx.
self.current = None
self.rx = None
self.tx = None
self.devices = {}
await self.lookup() # populate self.devices
LOG.debug("async BLE process started")
while True:
await ready(pipe)
cmd, args, kwargs = pipe.recv()
try:
result = await getattr(self, cmd)(*args, **kwargs)
except Exception as e:
LOG.exception("Error in async BLE process:")
await ready(pipe, write=True)
pipe.send(e)
break
else:
await ready(pipe, write=True)
pipe.send(result)
async def lookup(self) -> List[Device]:
self.devices.clear()
for device in await self.adapter.devices():
if NUS_SERVICE_UUID in device.UUIDs:
self.devices[device.address] = device
return [
Device(device.address, device.name, device.connected)
for device in self.devices.values()
]
async def scan(self) -> List[Device]:
LOG.debug("Initiating scan")
# TODO: configurable timeout
scanner = await self.adapter.scan(2)
self.devices.clear()
async with scanner:
async for device in scanner:
if NUS_SERVICE_UUID in device.UUIDs:
if device.address not in self.devices:
LOG.debug(f"scan: {device.address}: {device.name}")
self.devices[device.address] = device
return [
Device(device.address, device.name, device.connected)
for device in self.devices.values()
]
async def connect(self, address: str):
if self.current == address:
return
# elif self.current is not None:
# self.devices[self.current].disconnect()
ble_device = self.devices[address]
if not ble_device.connected:
LOG.info("Connecting to %s (%s)..." % (ble_device.name, ble_device.address))
await ble_device.connect()
else:
LOG.info("Connected to %s (%s)." % (ble_device.name, ble_device.address))
services = await ble_device.services()
nus_service = services[NUS_SERVICE_UUID]
self.rx, _mtu = await nus_service.characteristics[
NUS_CHARACTERISTIC_RX
].acquire(write=True)
self.tx, _mtu = await nus_service.characteristics[
NUS_CHARACTERISTIC_TX
].acquire()
self.current = address
async def read(self):
assert self.tx is not None
await ready(self.tx)
return self.tx.read()
async def write(self, chunk: bytes):
assert self.rx is not None
await ready(self.rx, write=True)
self.rx.write(chunk)
async def ready(f: Any, write: bool = False):
"""Asynchronously wait for file-like object to become ready for reading or writing."""
fd = f.fileno()
loop = asyncio.get_event_loop()
event = asyncio.Event()
if write:
def callback():
event.set()
loop.remove_writer(fd)
loop.add_writer(fd, callback)
else:
def callback():
event.set()
loop.remove_reader(fd)
loop.add_reader(fd, callback)
await event.wait()