feat(python): use dbus-next for BLE

tychovrahe/T3W1/devkit1_with_ble_crypto2b
Martin Milata 12 months ago committed by tychovrahe
parent 331f337f6b
commit 502f6a065c

@ -77,9 +77,7 @@ binsize = "^0.1.3"
toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"} toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"}
# ble # ble
dbus-python = "*" dbus-next = "*"
PyGObject = "*"
nrfutil = "*"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
scan-build = "*" scan-build = "*"

@ -8,6 +8,4 @@ typing_extensions>=3.10
dataclasses ; python_version<'3.7' dataclasses ; python_version<'3.7'
simple-rlp>=0.1.2 ; python_version>='3.7' simple-rlp>=0.1.2 ; python_version>='3.7'
construct-classes>=0.1.2 construct-classes>=0.1.2
dbus-python>=1.3.2 dbus-next>=0.2.3
pygobject>=3.44.1
nrfutil>=5.0.0

@ -25,7 +25,7 @@ per-file-ignores =
helper-scripts/*:I helper-scripts/*:I
tools/*:I tools/*:I
tests/*:I tests/*:I
known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],simple-rlp:[rlp],dbus-python:[dbus] known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],simple-rlp:[rlp],dbus-next:[dbus_next]
[isort] [isort]
profile = black profile = black

@ -16,17 +16,12 @@
import functools import functools
import sys import sys
import threading
from contextlib import contextmanager from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click import click
import dbus
import dbus.mainloop.glib
import dbus.service
from gi.repository import GLib
from .. import exceptions, messages, transport from .. import exceptions, transport
from ..client import TrezorClient from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI from ..ui import ClickUI, ScriptUI
@ -109,12 +104,6 @@ class TrezorConnection:
except transport.DeviceIsBusy: except transport.DeviceIsBusy:
click.echo("Device is in use by another process.") click.echo("Device is in use by another process.")
sys.exit(1) sys.exit(1)
except exceptions.TrezorFailure as e:
if e.code is messages.FailureType.DeviceIsBusy:
click.echo(str(e))
sys.exit(1)
else:
raise e
except Exception: except Exception:
click.echo("Failed to find a Trezor device.") click.echo("Failed to find a Trezor device.")
if self.path is not None: if self.path is not None:
@ -146,95 +135,26 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[
def trezorctl_command_with_client( def trezorctl_command_with_client(
obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs"
) -> "R": ) -> "R":
with obj.client_context() as client:
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
loop = GLib.MainLoop()
dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
def callback_wrapper(
r: List[Optional["R"]], exc: List[Optional[Exception]]
) -> None:
try: try:
with obj.client_context() as client: return func(client, *args, **kwargs)
session_was_resumed = obj.session_id == client.session_id
if not session_was_resumed and obj.session_id is not None:
# tried to resume but failed
click.echo("Warning: failed to resume session.", err=True)
try:
r.append(func(client, *args, **kwargs))
except Exception as e:
exc[0] = e
finally:
if not session_was_resumed:
try:
client.end_session()
except Exception:
pass
except Exception as e:
exc[0] = e
finally: finally:
loop.quit() if not session_was_resumed:
try:
result: List["R"] = [] client.end_session()
exc: List[Optional[Exception]] = [None] except Exception:
threading.Thread( pass
target=callback_wrapper, daemon=True, args=(result, exc)
).start()
loop.run()
if exc[0] is not None:
raise exc[0]
if len(result) == 0:
raise click.ClickException("Command did not return a result.")
return result[0]
# the return type of @click.pass_obj is improperly specified and pyright doesn't # the return type of @click.pass_obj is improperly specified and pyright doesn't
# understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs)
return trezorctl_command_with_client # type: ignore [cannot be assigned to return type] return trezorctl_command_with_client # type: ignore [cannot be assigned to return type]
def with_ble(func: "Callable[P, R]") -> "Callable[P, R]":
"""Wrap a Click command in `with obj.client_context() as client`.
Sessions are handled transparently. The user is warned when session did not resume
cleanly. The session is closed after the command completes - unless the session
was resumed, in which case it should remain open.
"""
@functools.wraps(func)
def trezorctl_command(*args: "P.args", **kwargs: "P.kwargs") -> "R":
loop = GLib.MainLoop()
dbus.mainloop.glib.DBusGMainLoop(set_as_default=True)
def callback_wrapper(r: List[Optional["R"]], exc: List[Optional[Exception]]):
try:
r.append(func(*args, **kwargs))
except Exception as e:
exc[0] = e
finally:
loop.quit()
result: List["R"] = []
exc: List[Optional[Exception]] = [None]
threading.Thread(
target=callback_wrapper, daemon=True, args=(result, exc)
).start()
loop.run()
if exc[0] is not None:
raise exc[0]
if len(result) == 0:
raise click.ClickException("Command did not return a result.")
return result[0]
return trezorctl_command
class AliasedGroup(click.Group): class AliasedGroup(click.Group):
"""Command group that handles aliases and Click 6.x compatibility. """Command group that handles aliases and Click 6.x compatibility.

@ -20,9 +20,9 @@ from typing import TYPE_CHECKING, BinaryIO
import click import click
from .. import ble, exceptions, tealblue from .. import ble, exceptions
from ..transport.ble import lookup_device, scan_device from ..transport.ble import BleProxy
from . import with_ble, with_client from . import with_client
if TYPE_CHECKING: if TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
@ -64,32 +64,28 @@ def update(
@cli.command() @cli.command()
@with_ble
def connect() -> None: def connect() -> None:
"""Connect to the device via BLE.""" """Connect to the device via BLE."""
adapter = tealblue.TealBlue().find_adapter() ble = BleProxy()
devices = [d for d in ble.lookup() if d.connected]
devices = lookup_device(adapter)
devices = [d for d in devices if d.connected]
if len(devices) == 0: if len(devices) == 0:
print("Scanning...") click.echo("Scanning...")
devices = scan_device(adapter, devices) devices = ble.scan()
if len(devices) == 0: if len(devices) == 0:
print("No BLE devices found") click.echo("No BLE devices found")
return return
else: else:
print("Found %d BLE device(s)" % len(devices)) click.echo("Found %d BLE device(s)" % len(devices))
for device in devices: for device in devices:
print(f"Device: {device.name}, {device.address}") click.echo(f"Device: {device.name}, {device.address}")
device = devices[0] device = devices[0]
print(f"Connecting to {device.name}...") click.echo(f"Connecting to {device.name}...")
device.connect() ble.connect(device.address)
print("Connected") click.echo("Connected")
@cli.command() @cli.command()

@ -49,7 +49,6 @@ from . import (
settings, settings,
stellar, stellar,
tezos, tezos,
with_ble,
with_client, with_client,
) )
@ -282,7 +281,6 @@ def format_device_name(features: messages.Features) -> str:
@cli.command(name="list") @cli.command(name="list")
@with_ble
@click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names") @click.option("-n", "no_resolve", is_flag=True, help="Do not resolve Trezor names")
def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]:
"""List connected Trezor devices.""" """List connected Trezor devices."""

@ -1,201 +1,170 @@
# !/usr/bin/python3 # !/usr/bin/python3
# pyright: off # pyright: off
import queue import asyncio
import threading import io
import time import logging
import dbus import dbus_next
import dbus.mainloop.glib
import dbus.service
LOG = logging.getLogger(__name__)
class NotConnectedError(Exception):
pass
def unwrap_properties(properties):
class DBusInvalidArgsException(dbus.exceptions.DBusException): return {k: v.value for k, v in properties.items()}
_dbus_error_name = "org.freedesktop.DBus.Error.InvalidArgs"
def format_uuid(uuid):
if type(uuid) == int:
if uuid > 0xFFFF:
raise ValueError("32-bit UUID not supported yet")
uuid = "%04X" % uuid
return uuid
class TealBlue: class TealBlue:
def __init__(self): @classmethod
self._bus = dbus.SystemBus() async def create(cls):
self._bluez = dbus.Interface( self = cls()
self._bus.get_object("org.bluez", "/"), "org.freedesktop.DBus.ObjectManager" self._bus = await dbus_next.aio.MessageBus(
) bus_type=dbus_next.constants.BusType.SYSTEM, negotiate_unix_fd=True
).connect()
obj = await self.get_object("org.bluez", "/")
self._bluez = obj.get_interface("org.freedesktop.DBus.ObjectManager")
return self
def find_adapter(self): async def find_adapter(self, mac_filter=""):
# find the first adapter """Find the first adapter matching mac_filter."""
objects = self._bluez.GetManagedObjects() objects = await self._bluez.call_get_managed_objects()
for path in sorted(objects.keys()): for path in sorted(objects.keys()):
interfaces = objects[path] interfaces = objects[path]
if "org.bluez.Adapter1" not in interfaces: if "org.bluez.Adapter1" not in interfaces:
continue continue
properties = interfaces["org.bluez.Adapter1"] properties = unwrap_properties(interfaces["org.bluez.Adapter1"])
return Adapter(self, path, properties) if mac_filter not in properties["Address"]:
continue
return await Adapter.create(self, path, properties)
raise Exception("No adapter found") raise Exception("No adapter found")
# copied from: async def get_object(self, name, path):
# https://github.com/adafruit/Adafruit_Python_BluefruitLE/blob/master/Adafruit_BluefruitLE/bluez_dbus/provider.py introspection = await self._bus.introspect(name, path)
def _print_tree(self): obj = self._bus.get_proxy_object(name, path, introspection)
"""Print tree of all bluez objects, useful for debugging.""" return obj
# This is based on the bluez sample code get-managed-objects.py.
objects = self._bluez.GetManagedObjects()
for path in sorted(objects.keys()):
print("[ %s ]" % (path))
interfaces = objects[path]
for interface in sorted(interfaces.keys()):
if interface in [
"org.freedesktop.DBus.Introspectable",
"org.freedesktop.DBus.Properties",
]:
continue
print(" %s" % (interface))
properties = interfaces[interface]
for key in sorted(properties.keys()):
print(" %s = %s" % (key, properties[key]))
class Adapter: class Adapter:
def __init__(self, teal, path, properties): @classmethod
async def create(cls, teal, path, properties):
self = cls()
self._teal = teal self._teal = teal
self._path = path self._path = path
self._properties = properties self._properties = properties
self._object = dbus.Interface( obj = await self._teal.get_object("org.bluez", path)
teal._bus.get_object("org.bluez", path), "org.bluez.Adapter1" self._object = obj.get_interface("org.bluez.Adapter1")
)
self._advertisement = None return self
def __repr__(self): def __repr__(self):
return "<tealblue.Adapter address=%s>" % (self._properties["Address"]) return "<tealblue.Adapter address=%s>" % (self._properties["Address"])
def devices(self): async def devices(self):
""" """
Returns the devices that BlueZ has discovered. Returns the devices that BlueZ has discovered.
""" """
objects = self._teal._bluez.GetManagedObjects() objects = await self._teal._bluez.call_get_managed_objects()
devices = []
for path in sorted(objects.keys()): for path in sorted(objects.keys()):
interfaces = objects[path] interfaces = objects[path]
if "org.bluez.Device1" not in interfaces: if "org.bluez.Device1" not in interfaces:
continue continue
properties = interfaces["org.bluez.Device1"] properties = unwrap_properties(interfaces["org.bluez.Device1"])
yield Device(self._teal, path, properties) devices.append(await Device.create(self._teal, path, properties))
def scan(self, timeout_s): return devices
return Scanner(self._teal, self, self.devices(), timeout_s)
@property
def advertisement(self):
if self._advertisement is None:
self._advertisement = Advertisement(self._teal, self)
return self._advertisement
def advertise(self, enable):
if enable:
self.advertisement.enable()
else:
self.advertisement.disable()
def advertise_data( async def scan(self, timeout_s):
self, return await Scanner.create(self._teal, self, await self.devices(), timeout_s)
local_name=None,
service_data=None,
service_uuids=None,
manufacturer_data=None,
):
self.advertisement.local_name = local_name
self.advertisement.service_data = service_data
self.advertisement.service_uuids = service_uuids
self.advertisement.manufacturer_data = manufacturer_data
class Scanner: class Scanner:
def __init__(self, teal, adapter, initial_devices, timeout_s): @classmethod
async def create(cls, teal, adapter, initial_devices, timeout_s):
self = cls()
self._teal = teal self._teal = teal
self._adapter = adapter self._adapter = adapter
self._was_discovering = adapter._properties[ self._was_discovering = adapter._properties[
"Discovering" "Discovering"
] # TODO get current value, or watch property changes ] # TODO get current value, or watch property changes
self._queue = queue.Queue() self._queue = asyncio.Queue()
self.timeout_s = timeout_s self.timeout_s = timeout_s
for device in initial_devices: for device in initial_devices:
self._queue.put(device) self._queue.put_nowait((device._path, device._properties))
def new_device(path, interfaces): self._teal._bluez.on_interfaces_added(self._on_iface_added)
if "org.bluez.Device1" not in interfaces:
return
if not path.startswith(self._adapter._path + "/"):
return
# properties = interfaces["org.bluez.Device1"]
self._queue.put(Device(self._teal, path, interfaces["org.bluez.Device1"]))
self._signal_receiver = self._teal._bus.add_signal_receiver(
new_device,
dbus_interface="org.freedesktop.DBus.ObjectManager",
signal_name="InterfacesAdded",
)
if not self._was_discovering: if not self._was_discovering:
self._adapter._object.StartDiscovery() await self._adapter._object.call_start_discovery()
return self
def _on_iface_added(self, path, interfaces):
if "org.bluez.Device1" not in interfaces:
return
if not path.startswith(self._adapter._path + "/"):
return
properties = unwrap_properties(interfaces["org.bluez.Device1"])
self._queue.put_nowait((path, properties))
def __enter__(self): async def __aenter__(self):
return self return self
def __exit__(self, type, value, traceback): async def __aexit__(self, type, value, traceback):
if not self._was_discovering: if not self._was_discovering:
self._adapter._object.StopDiscovery() await self._adapter._object.call_stop_discovery()
self._signal_receiver.remove() self._teal._bluez.off_interfaces_added(self._on_iface_added)
def __iter__(self): def __aiter__(self):
return self return self
def __next__(self): async def __anext__(self):
try: try:
return self._queue.get(timeout=self.timeout_s) (path, properties) = await asyncio.wait_for(
except queue.Empty: self._queue.get(), self.timeout_s
raise StopIteration )
return await Device.create(self._teal, path, properties)
except asyncio.TimeoutError:
raise StopAsyncIteration
class Device: class Device:
def __init__(self, teal, path, properties): @classmethod
async def create(cls, teal, path, properties):
self = cls()
self._teal = teal self._teal = teal
self._path = path self._path = path
self._properties = properties self._properties = properties
self._services_resolved = threading.Event() self._services_resolved = asyncio.Event()
self._services = None self._services = None
if properties["ServicesResolved"]: if properties["ServicesResolved"]:
self._services_resolved.set() self._services_resolved.set()
# Listen to device events (connect, disconnect, ServicesResolved, ...) # Listen to device events (connect, disconnect, ServicesResolved, ...)
self._device = dbus.Interface( obj = await self._teal.get_object("org.bluez", path)
teal._bus.get_object("org.bluez", path), "org.bluez.Device1" self._device = obj.get_interface("org.bluez.Device1")
) obj = await self._teal.get_object("org.bluez", path)
self._device_props = dbus.Interface( self._device_props = obj.get_interface("org.freedesktop.DBus.Properties")
self._device, "org.freedesktop.DBus.Properties" self._device_props.on_properties_changed(self._on_prop_changed)
)
self._signal_receiver = self._device_props.connect_to_signal( return self
"PropertiesChanged",
lambda itf, ch, inv: self._on_prop_changed(itf, ch, inv),
)
def __del__(self): def __del__(self):
self._signal_receiver.remove() self._device_props.off_properties_changed(self._on_prop_changed)
def __repr__(self): def __repr__(self):
return "<tealblue.Device address=%s name=%r>" % (self.address, self.name) return "<tealblue.Device address=%s name=%r>" % (self.address, self.name)
def _on_prop_changed(self, properties, changed_props, invalidated_props): def _on_prop_changed(self, _interface, changed_props, invalidated_props):
changed_props = unwrap_properties(changed_props)
LOG.debug(
f"prop changed: device {self._path} {changed_props.keys()} {invalidated_props}"
)
for key, value in changed_props.items(): for key, value in changed_props.items():
self._properties[key] = value self._properties[key] = value
for key in invalidated_props:
del self._properties[key]
if "ServicesResolved" in changed_props: if "ServicesResolved" in changed_props:
if changed_props["ServicesResolved"]: if changed_props["ServicesResolved"]:
@ -203,36 +172,33 @@ class Device:
else: else:
self._services_resolved.clear() self._services_resolved.clear()
def _wait_for_discovery(self): async def connect(self):
# wait until ServicesResolved is True await self._device.call_connect()
self._services_resolved.wait()
def connect(self):
self._device.Connect()
def disconnect(self):
self._device.Disconnect()
def resolve_services(self): async def disconnect(self):
self._services_resolved.wait() await self._device.call_disconnect()
@property async def services(self):
def services(self): await self._services_resolved.wait()
if not self._services_resolved.is_set():
return None
if self._services is None: if self._services is None:
self._services = {} self._services = {}
objects = self._teal._bluez.GetManagedObjects() objects = await self._teal._bluez.call_get_managed_objects()
for path in sorted(objects.keys()): for path in sorted(objects.keys()):
if not path.startswith(self._path + "/"): if not path.startswith(self._path + "/"):
continue continue
if "org.bluez.GattService1" in objects[path]: if "org.bluez.GattService1" in objects[path]:
properties = objects[path]["org.bluez.GattService1"] properties = unwrap_properties(
objects[path]["org.bluez.GattService1"]
)
service = Service(self._teal, self, path, properties) service = Service(self._teal, self, path, properties)
self._services[service.uuid] = service self._services[service.uuid] = service
elif "org.bluez.GattCharacteristic1" in objects[path]: elif "org.bluez.GattCharacteristic1" in objects[path]:
properties = objects[path]["org.bluez.GattCharacteristic1"] properties = unwrap_properties(
characterstic = Characteristic(self._teal, self, path, properties) objects[path]["org.bluez.GattCharacteristic1"]
)
characterstic = await Characteristic.create(
self._teal, self, path, properties
)
for service in self._services.values(): for service in self._services.values():
if properties["Service"] == service._path: if properties["Service"] == service._path:
service.characteristics[characterstic.uuid] = characterstic service.characteristics[characterstic.uuid] = characterstic
@ -287,22 +253,24 @@ class Service:
class Characteristic: class Characteristic:
def __init__(self, teal, device, path, properties): def __init__(self):
self._properties = {}
@classmethod
async def create(cls, teal, device, path, properties):
self = cls()
self._device = device self._device = device
self._teal = teal self._teal = teal
self._path = path self._path = path
self._properties = properties self._properties = properties
self._values = asyncio.Queue()
self.on_notify = None obj = await self._teal.get_object("org.bluez", path)
self._char = obj.get_interface("org.bluez.GattCharacteristic1")
self._props = obj.get_interface("org.freedesktop.DBus.Properties")
self._props.on_properties_changed(self._on_prop_changed)
self._char = dbus.Interface( return self
teal._bus.get_object("org.bluez", path), "org.bluez.GattCharacteristic1"
)
char_props = dbus.Interface(self._char, "org.freedesktop.DBus.Properties")
self._signal_receiver = char_props.connect_to_signal(
"PropertiesChanged",
lambda itf, ch, inv: self._on_prop_changed(itf, ch, inv),
)
def __repr__(self): def __repr__(self):
return "<tealblue.Characteristic device=%s uuid=%s>" % ( return "<tealblue.Characteristic device=%s uuid=%s>" % (
@ -311,147 +279,72 @@ class Characteristic:
) )
def __del__(self): def __del__(self):
self._signal_receiver.remove() self._props.off_properties_changed(self._on_prop_changed)
def _on_prop_changed(self, properties, changed_props, invalidated_props): def _on_prop_changed(self, _interface, changed_props, invalidated_props):
changed_props = unwrap_properties(changed_props)
LOG.debug(
f"prop changed: characteristic {changed_props.keys()} {invalidated_props}"
)
for key, value in changed_props.items(): for key, value in changed_props.items():
self._properties[key] = bytes(value) self._properties[key] = bytes(value)
for key in invalidated_props:
del self._properties[key]
if "Value" in changed_props and self.on_notify is not None: if "Value" in changed_props:
self.on_notify(self, changed_props["Value"]) self._values.put_nowait(changed_props["Value"])
def read(self):
return bytes(self._char.ReadValue({}))
def write(self, value, command=True):
start = time.time()
try:
if command:
self._char.WriteValue(value, {"type": "command"})
else:
self._char.WriteValue(value, {"type": "request"})
except dbus.DBusException as e:
if (
e.get_dbus_name() == "org.bluez.Error.Failed"
and e.get_dbus_message() == "Not connected"
):
raise NotConnectedError()
else:
raise # some other error
# Workaround: if the write took very long, it is possible the connection
# broke (without causing an exception). So check whether we are still
# connected.
# I think this is a bug in BlueZ.
if time.time() - start > 0.5:
if not self._device._device_props.Get("org.bluez.Device1", "Connected"):
raise NotConnectedError()
def start_notify(self):
self._char.StartNotify()
def stop_notify(self): async def acquire(self, write: bool = False) -> tuple[io.FileIO, int]:
self._char.StopNotify() if write:
fd, mtu = await self._char.call_acquire_write({})
mode = "w"
else:
fd, mtu = await self._char.call_acquire_notify({})
mode = "r"
f = io.FileIO(fd, mode)
LOG.debug(f"acquired {self.uuid} ({mode})")
return f, mtu
async def read(self) -> bytes:
return bytes(await self._values.get())
async def write(self, value, command=True):
ty = "command" if command else "request"
await self._char.call_write_value(value, {"type": dbus_next.Variant("s", ty)})
#
# async def write(self, value, command=True):
# start = time.time()
# try:
# if command:
# await self._char.call_write_value(value, {"type": "command"})
# else:
# await self._char.call_write_value(value, {"type": "request"})
#
# except dbus_next.DBusError as e:
# if (
# e.type == "org.bluez.Error.Failed"
# and e.text == "Not connected"
# ):
# raise NotConnectedError()
# else:
# raise # some other error
#
# # Workaround: if the write took very long, it is possible the connection
# # broke (without causing an exception). So check whether we are still
# # connected.
# # I think this is a bug in BlueZ.
# if time.time() - start > 0.5:
# if not self._device._device_props.call_get("org.bluez.Device1", "Connected"):
# raise NotConnectedError()
async def start_notify(self):
await self._char.call_start_notify()
async def stop_notify(self):
await self._char.call_stop_notify()
@property @property
def uuid(self): def uuid(self):
return str(self._properties["UUID"]) return str(self._properties["UUID"])
class Advertisement(dbus.service.Object):
PATH = "/com/github/aykevl/pynus/advertisement"
def __init__(self, teal, adapter):
self._teal = teal
self._adapter = adapter
self._enabled = False
self.service_uuids = None
self.manufacturer_data = None
self.solicit_uuids = None
self.service_data = None
self.local_name = None
self.include_tx_power = None
self._manager = dbus.Interface(
teal._bus.get_object("org.bluez", self._adapter._path),
"org.bluez.LEAdvertisingManager1",
)
self._adv_enabled = threading.Event()
dbus.service.Object.__init__(self, teal._bus, self.PATH)
def enable(self):
if self._enabled:
return
self._manager.RegisterAdvertisement(
dbus.ObjectPath(self.PATH),
dbus.Dictionary({}, signature="sv"),
reply_handler=self._cb_enabled,
error_handler=self._cb_enabled_err,
)
self._adv_enabled.wait()
self._adv_enabled.clear()
def _cb_enabled(self):
self._enabled = True
if self._adv_enabled.is_set():
raise RuntimeError("called enable() twice")
self._adv_enabled.set()
def _cb_enabled_err(self, err):
self._enabled = False
if self._adv_enabled.is_set():
raise RuntimeError("called enable() twice")
self._adv_enabled.set()
def disable(self):
if not self._enabled:
return
self._bus.UnregisterAdvertisement(self.PATH)
self._enabled = False
@property
def enabled(self):
return self._enabled
@dbus.service.method(
"org.freedesktop.DBus.Properties", in_signature="s", out_signature="a{sv}"
)
def GetAll(self, interface):
print("GetAll")
if interface != "org.bluez.LEAdvertisement1":
raise DBusInvalidArgsException()
try:
properties = {
"Type": dbus.String("peripheral"),
}
if self.service_uuids is not None:
properties["ServiceUUIDs"] = dbus.Array(
map(format_uuid, self.service_uuids), signature="s"
)
if self.solicit_uuids is not None:
properties["SolicitUUIDs"] = dbus.Array(
map(format_uuid, self.solicit_uuids), signature="s"
)
if self.manufacturer_data is not None:
properties["ManufacturerData"] = dbus.Dictionary(
{k: v for k, v in self.manufacturer_data.items()}, signature="qv"
)
if self.service_data is not None:
properties["ServiceData"] = dbus.Dictionary(
self.service_data, signature="sv"
)
if self.local_name is not None:
properties["LocalName"] = dbus.String(self.local_name)
if self.include_tx_power is not None:
properties["IncludeTxPower"] = dbus.Boolean(self.include_tx_power)
except Exception as e:
print("err: ", e)
print("properties:", properties)
return properties
@dbus.service.method(
"org.bluez.LEAdvertisement1", in_signature="", out_signature=""
)
def Release(self):
self._enabled = True

@ -13,12 +13,13 @@
# #
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import asyncio
import logging import logging
from queue import Queue from dataclasses import dataclass
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import TYPE_CHECKING, Any, Iterable, List, Optional from typing import TYPE_CHECKING, Any, Iterable, List, Optional
from .. import tealblue
from ..tealblue import Adapter, Characteristic
from . import TransportException from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
@ -32,65 +33,35 @@ NUS_CHARACTERISTIC_RX = "6e400002-b5a3-f393-e0a9-e50e24dcca9e"
NUS_CHARACTERISTIC_TX = "6e400003-b5a3-f393-e0a9-e50e24dcca9e" NUS_CHARACTERISTIC_TX = "6e400003-b5a3-f393-e0a9-e50e24dcca9e"
def scan_device(adapter: Adapter, devices: List[tealblue.Device]): @dataclass
with adapter.scan(2) as scanner: class Device:
for device in scanner: address: str
if NUS_SERVICE_UUID in device.UUIDs: name: str
if device.address not in [d.address for d in devices]: connected: bool
print(f"Found device: {device.name}, {device.address}")
devices.append(device)
return devices
def lookup_device(adapter: Adapter):
devices = []
for device in adapter.devices():
if NUS_SERVICE_UUID in device.UUIDs:
devices.append(device)
return devices
class BleTransport(ProtocolBasedTransport): class BleTransport(ProtocolBasedTransport):
ENABLED = True ENABLED = True
PATH_PREFIX = "ble" PATH_PREFIX = "ble"
def __init__(self, mac_addr: str, adapter: Adapter) -> None: _ble = None
self.tx = None def __init__(self, mac_addr: str) -> None:
self.rx = None
self.device = mac_addr self.device = mac_addr
self.adapter = adapter
self.received_data = Queue()
devices = lookup_device(self.adapter)
for d in devices:
if d.address == mac_addr:
self.ble_device = d
break
super().__init__(protocol=ProtocolV1(self, replen=244)) super().__init__(protocol=ProtocolV1(self, replen=244))
def get_path(self) -> str: def get_path(self) -> str:
return "{}:{}".format(self.PATH_PREFIX, self.device) return "{}:{}".format(self.PATH_PREFIX, self.device)
def find_debug(self) -> "BleTransport": def find_debug(self) -> "BleTransport":
mac = self.device return BleTransport(self.device)
return BleTransport(f"{mac}", self.adapter)
@classmethod @classmethod
def enumerate( def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BleTransport"]: ) -> Iterable["BleTransport"]:
adapter = tealblue.TealBlue().find_adapter() devices = cls.ble().lookup()
devices = lookup_device(adapter) return [BleTransport(device.address) for device in devices if device.connected]
for device in devices:
print(f"Device: {device.name}, {device.address}")
devices = [d for d in devices if d.connected]
return [BleTransport(device.address, adapter) for device in devices]
@classmethod @classmethod
def _try_path(cls, path: str) -> "BleTransport": def _try_path(cls, path: str) -> "BleTransport":
@ -111,44 +82,167 @@ class BleTransport(ProtocolBasedTransport):
raise TransportException(f"No BLE device: {path}") raise TransportException(f"No BLE device: {path}")
def open(self) -> None: def open(self) -> None:
self.ble().connect(self.device)
if not self.ble_device.connected:
print(
"Connecting to %s (%s)..."
% (self.ble_device.name, self.ble_device.address)
)
self.ble_device.connect()
else:
print(
"Connected to %s (%s)."
% (self.ble_device.name, self.ble_device.address)
)
if not self.ble_device.services_resolved:
print("Resolving services...")
self.ble_device.resolve_services()
service = self.ble_device.services[NUS_SERVICE_UUID]
self.rx = service.characteristics[NUS_CHARACTERISTIC_RX]
self.tx = service.characteristics[NUS_CHARACTERISTIC_TX]
def on_notify(characteristic: Characteristic, value: Any):
self.received_data.put(bytes(value))
self.tx.on_notify = on_notify
self.tx.start_notify()
def close(self) -> None: def close(self) -> None:
pass pass
def write_chunk(self, chunk: bytes) -> None: def write_chunk(self, chunk: bytes) -> None:
assert self.rx is not None self.ble().write(chunk)
self.rx.write(chunk)
def read_chunk(self) -> bytes: def read_chunk(self) -> bytes:
assert self.tx is not None chunk = self.ble().read()
chunk = self.received_data.get()
# LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}") # LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}") raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk) return bytearray(chunk)
@classmethod
def ble(cls) -> "BleProxy":
if cls._ble is None:
cls._ble = BleProxy()
return cls._ble
class BleProxy:
pipe = None
process = None
def __init__(self):
if self.pipe is not None:
return
parent_pipe, child_pipe = Pipe()
self.pipe = parent_pipe
self.process = Process(target=BleAsync, args=(child_pipe,), daemon=True)
self.process.start()
def __getattr__(self, name: str):
def f(*args: Any, **kwargs: Any):
assert self.pipe is not None
self.pipe.send((name, args, kwargs))
result = self.pipe.recv()
if isinstance(result, BaseException):
raise result
return result
return f
class BleAsync:
def __init__(self, pipe: Connection):
asyncio.run(self.main(pipe))
async def main(self, pipe: Connection):
from ..tealblue import TealBlue
tb = await TealBlue.create()
# TODO: add cli option for mac_filter and pass it here
self.adapter = await tb.find_adapter()
# TODO: currently only one concurrent device is supported
# To support more devices, connect() needs to return a Connection and also has to
# spawn a task that will forward data between that Connection and rx,tx.
self.current = None
self.rx = None
self.tx = None
self.devices = {}
await self.lookup() # populate self.devices
LOG.debug("async BLE process started")
while True:
await ready(pipe)
cmd, args, kwargs = pipe.recv()
try:
result = await getattr(self, cmd)(*args, **kwargs)
except Exception as e:
LOG.exception("Error in async BLE process:")
await ready(pipe, write=True)
pipe.send(e)
break
else:
await ready(pipe, write=True)
pipe.send(result)
async def lookup(self) -> List[Device]:
self.devices.clear()
for device in await self.adapter.devices():
if NUS_SERVICE_UUID in device.UUIDs:
self.devices[device.address] = device
return [
Device(device.address, device.name, device.connected)
for device in self.devices.values()
]
async def scan(self) -> List[Device]:
LOG.debug("Initiating scan")
# TODO: configurable timeout
scanner = await self.adapter.scan(2)
self.devices.clear()
async with scanner:
async for device in scanner:
if NUS_SERVICE_UUID in device.UUIDs:
if device.address not in self.devices:
LOG.debug(f"scan: {device.address}: {device.name}")
self.devices[device.address] = device
return [
Device(device.address, device.name, device.connected)
for device in self.devices.values()
]
async def connect(self, address: str):
if self.current == address:
return
# elif self.current is not None:
# self.devices[self.current].disconnect()
ble_device = self.devices[address]
if not ble_device.connected:
LOG.info("Connecting to %s (%s)..." % (ble_device.name, ble_device.address))
await ble_device.connect()
else:
LOG.info("Connected to %s (%s)." % (ble_device.name, ble_device.address))
services = await ble_device.services()
nus_service = services[NUS_SERVICE_UUID]
self.rx, _mtu = await nus_service.characteristics[
NUS_CHARACTERISTIC_RX
].acquire(write=True)
self.tx, _mtu = await nus_service.characteristics[
NUS_CHARACTERISTIC_TX
].acquire()
self.current = address
async def read(self):
assert self.tx is not None
await ready(self.tx)
return self.tx.read()
async def write(self, chunk: bytes):
assert self.rx is not None
await ready(self.rx, write=True)
self.rx.write(chunk)
async def ready(f: Any, write: bool = False):
"""Asynchronously wait for file-like object to become ready for reading or writing."""
fd = f.fileno()
loop = asyncio.get_event_loop()
event = asyncio.Event()
if write:
def callback():
event.set()
loop.remove_writer(fd)
loop.add_writer(fd, callback)
else:
def callback():
event.set()
loop.remove_reader(fd)
loop.add_reader(fd, callback)
await event.wait()

Loading…
Cancel
Save