diff --git a/common/protob/Makefile b/common/protob/Makefile index f8df2d2d5d..c2feff5a84 100644 --- a/common/protob/Makefile +++ b/common/protob/Makefile @@ -1,4 +1,4 @@ -check: messages.pb messages-binance.pb messages-bitcoin.pb messages-bootloader.pb messages-cardano.pb messages-common.pb messages-crypto.pb messages-debug.pb messages-ethereum.pb messages-management.pb messages-monero.pb messages-nem.pb messages-ripple.pb messages-stellar.pb messages-tezos.pb messages-eos.pb +check: messages.pb messages-binance.pb messages-bitcoin.pb messages-ble.pb messages-bootloader.pb messages-cardano.pb messages-common.pb messages-crypto.pb messages-debug.pb messages-ethereum.pb messages-management.pb messages-monero.pb messages-nem.pb messages-ripple.pb messages-stellar.pb messages-tezos.pb messages-eos.pb %.pb: %.proto protoc -I/usr/include -I. $< -o $@ diff --git a/poetry.lock b/poetry.lock index 4ffb20f24d..4e56502a4e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -346,6 +346,17 @@ ssh = ["bcrypt (>=3.1.5)"] test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test-randomorder = ["pytest-randomly"] +[[package]] +name = "dbus-next" +version = "0.2.3" +description = "A zero-dependency DBus library for Python with asyncio support" +optional = false +python-versions = ">=3.6.0" +files = [ + {file = "dbus_next-0.2.3-py3-none-any.whl", hash = "sha256:58948f9aff9db08316734c0be2a120f6dc502124d9642f55e90ac82ffb16a18b"}, + {file = "dbus_next-0.2.3.tar.gz", hash = "sha256:f4eae26909332ada528c0a3549dda8d4f088f9b365153952a408e28023a626a5"}, +] + [[package]] name = "demjson3" version = "3.0.5" @@ -1798,4 +1809,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "971d0f6f2926d839954b35b2029978046e282df3d8d595b1e176dc0cf37889fb" +content-hash = "5de61e56cadc4aaa5c61a18e8e756d9e6cf195ff461f77ae27d076cd8350c856" diff --git a/pyproject.toml b/pyproject.toml index be12042a7f..22e269af4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,9 @@ toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"} trezor-pylint-plugin = {path = "./tools/trezor-pylint-plugin", develop = true} trezor-core-tools = {path = "./core/tools", develop = true} +# ble +dbus-next = "*" + [tool.poetry.dev-dependencies] scan-build = "*" towncrier = "^23.6.0" diff --git a/python/requirements.txt b/python/requirements.txt index 440bc2a2be..f703db947b 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -6,3 +6,4 @@ libusb1>=1.6.4 construct>=2.9,!=2.10.55 typing_extensions>=4.7.1 construct-classes>=0.1.2 +dbus-next>=0.2.3 diff --git a/python/setup.cfg b/python/setup.cfg index 5d52e2c5a4..adc3b577d7 100644 --- a/python/setup.cfg +++ b/python/setup.cfg @@ -27,7 +27,7 @@ per-file-ignores = helper-scripts/*:I tools/*:I tests/*:I -known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore] +known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],dbus-next:[dbus_next] [isort] profile = black diff --git a/python/src/trezorlib/ble/__init__.py b/python/src/trezorlib/ble/__init__.py new file mode 100644 index 0000000000..be1c1eb682 --- /dev/null +++ b/python/src/trezorlib/ble/__init__.py @@ -0,0 +1,59 @@ +import typing as t + +from .. import messages +from ..tools import session + +if t.TYPE_CHECKING: + from ..client import TrezorClient + + +@session +def update( + client: "TrezorClient", + datfile: bytes, + binfile: bytes, + progress_update: t.Callable[[int], t.Any] = lambda _: None, +): + chunk_len = 4096 + offset = 0 + + resp = client.call( + messages.UploadBLEFirmwareInit(init_data=datfile, binsize=len(binfile)) + ) + + while isinstance(resp, messages.UploadBLEFirmwareNextChunk): + + payload = binfile[offset : offset + chunk_len] + resp = client.call(messages.UploadBLEFirmwareChunk(data=payload)) + progress_update(chunk_len) + offset += chunk_len + + if isinstance(resp, messages.Success): + return + else: + raise RuntimeError(f"Unexpected message {resp}") + + +@session +def erase_bonds( + client: "TrezorClient", +): + + resp = client.call(messages.EraseBonds()) + + if isinstance(resp, messages.Success): + return + else: + raise RuntimeError(f"Unexpected message {resp}") + + +@session +def disconnect( + client: "TrezorClient", +): + resp = client.call(messages.Disconnect()) + + if isinstance(resp, messages.Success): + return + else: + raise RuntimeError(f"Unexpected message {resp}") diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 050e3788f7..3931b8f8a0 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport +from .. import exceptions, messages, transport from ..client import TrezorClient from ..ui import ClickUI, ScriptUI @@ -110,6 +110,12 @@ class TrezorConnection: except transport.DeviceIsBusy: click.echo("Device is in use by another process.") sys.exit(1) + except exceptions.TrezorFailure as e: + if e.code is messages.FailureType.DeviceIsBusy: + click.echo(str(e)) + sys.exit(1) + else: + raise e except Exception: click.echo("Failed to find a Trezor device.") if self.path is not None: diff --git a/python/src/trezorlib/cli/ble.py b/python/src/trezorlib/cli/ble.py new file mode 100644 index 0000000000..c4ecf8fba6 --- /dev/null +++ b/python/src/trezorlib/cli/ble.py @@ -0,0 +1,139 @@ +# This file is part of the Trezor project. +# +# Copyright (C) 2012-2022 SatoshiLabs and contributors +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the License along with this library. +# If not, see . + +import json +import sys +import zipfile +from typing import TYPE_CHECKING, BinaryIO + +import click + +from .. import ble, exceptions +from ..transport.ble import BleProxy +from . import with_client + +if TYPE_CHECKING: + from ..client import TrezorClient + + +@click.group(name="ble") +def cli() -> None: + """BLE commands.""" + + +@cli.command() +# fmt: off +@click.argument("package", type=click.File("rb")) +# fmt: on +@with_client +def update( + client: "TrezorClient", + package: BinaryIO, +) -> None: + """Upload new BLE firmware to device.""" + + with zipfile.ZipFile(package) as archive: + manifest = archive.read("manifest.json") + mainfest_data = json.loads(manifest.decode("utf-8"))["manifest"] + + for k in mainfest_data.keys(): + + binfile = archive.read(mainfest_data[k]["bin_file"]) + datfile = archive.read(mainfest_data[k]["dat_file"]) + + """Perform the final act of loading the firmware into Trezor.""" + try: + click.echo("Uploading...\r", nl=False) + with click.progressbar( + label="Uploading", length=len(binfile), show_eta=False + ) as bar: + ble.update(client, datfile, binfile, bar.update) + click.echo("Update successful.") + except exceptions.Cancelled: + click.echo("Update aborted on device.") + except exceptions.TrezorException as e: + click.echo(f"Update failed: {e}") + sys.exit(3) + + +@cli.command() +def connect() -> None: + """Connect to the device via BLE.""" + ble = BleProxy() + devices = [d for d in ble.lookup() if d.connected] + + if len(devices) == 0: + click.echo("Scanning...") + devices = ble.scan() + + if len(devices) == 0: + click.echo("No BLE devices found") + return + else: + click.echo("Found %d BLE device(s)" % len(devices)) + + for device in devices: + click.echo(f"Device: {device.name}, {device.address}") + + device = devices[0] + click.echo(f"Connecting to {device.name}...") + ble.connect(device.address) + click.echo("Connected") + + +@with_client +def disconnect_device(client: "TrezorClient") -> None: + """Disconnect from device side.""" + try: + ble.disconnect(client) + except exceptions.Cancelled: + click.echo("Disconnect aborted on device.") + except exceptions.TrezorException as e: + click.echo(f"Disconnect failed: {e}") + sys.exit(3) + + +@cli.command() +@click.option("--device", is_flag=True, help="Disconnect from device side.") +def disconnect(device: bool) -> None: + + if device: + disconnect_device() + else: + ble_proxy = BleProxy() + devices = [d for d in ble_proxy.lookup() if d.connected] + if len(devices) == 0: + click.echo("No BLE devices found") + return + ble_proxy.connect(devices[0].address) + ble_proxy.disconnect() + + +@cli.command() +@with_client +def erase_bonds( + client: "TrezorClient", +) -> None: + """Erase BLE bonds on device.""" + + try: + ble.erase_bonds(client) + click.echo("Erase successful.") + except exceptions.Cancelled: + click.echo("Erase aborted on device.") + except exceptions.TrezorException as e: + click.echo(f"Update failed: {e}") + sys.exit(3) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 60f8e8d309..18380abd0e 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -33,6 +33,7 @@ from . import ( TrezorConnection, benchmark, binance, + ble, btc, cardano, crypto, @@ -88,6 +89,7 @@ COMMAND_ALIASES = { "upgrade-firmware": firmware.update, "firmware-upgrade": firmware.update, "firmware-update": firmware.update, + "ble-update": ble.update, } @@ -418,6 +420,7 @@ cli.add_command(tezos.cli) cli.add_command(firmware.cli) cli.add_command(debug.cli) cli.add_command(benchmark.cli) +cli.add_command(ble.cli) # # Main diff --git a/python/src/trezorlib/tealblue.py b/python/src/trezorlib/tealblue.py new file mode 100755 index 0000000000..42b9ff7ec4 --- /dev/null +++ b/python/src/trezorlib/tealblue.py @@ -0,0 +1,350 @@ +# !/usr/bin/python3 +# pyright: off + +import asyncio +import io +import logging + +import dbus_next + +LOG = logging.getLogger(__name__) + + +def unwrap_properties(properties): + return {k: v.value for k, v in properties.items()} + + +class TealBlue: + @classmethod + async def create(cls): + self = cls() + self._bus = await dbus_next.aio.MessageBus( + bus_type=dbus_next.constants.BusType.SYSTEM, negotiate_unix_fd=True + ).connect() + obj = await self.get_object("org.bluez", "/") + self._bluez = obj.get_interface("org.freedesktop.DBus.ObjectManager") + + return self + + async def find_adapter(self, mac_filter=""): + """Find the first adapter matching mac_filter.""" + objects = await self._bluez.call_get_managed_objects() + for path in sorted(objects.keys()): + interfaces = objects[path] + if "org.bluez.Adapter1" not in interfaces: + continue + properties = unwrap_properties(interfaces["org.bluez.Adapter1"]) + if mac_filter not in properties["Address"]: + continue + return await Adapter.create(self, path, properties) + raise Exception("No adapter found") + + async def get_object(self, name, path): + introspection = await self._bus.introspect(name, path) + obj = self._bus.get_proxy_object(name, path, introspection) + return obj + + +class Adapter: + @classmethod + async def create(cls, teal, path, properties): + self = cls() + self._teal = teal + self._path = path + self._properties = properties + obj = await self._teal.get_object("org.bluez", path) + self._object = obj.get_interface("org.bluez.Adapter1") + + return self + + def __repr__(self): + return "" % (self._properties["Address"]) + + async def devices(self): + """ + Returns the devices that BlueZ has discovered. + """ + objects = await self._teal._bluez.call_get_managed_objects() + devices = [] + for path in sorted(objects.keys()): + interfaces = objects[path] + if "org.bluez.Device1" not in interfaces: + continue + properties = unwrap_properties(interfaces["org.bluez.Device1"]) + devices.append(await Device.create(self._teal, path, properties)) + + return devices + + async def scan(self, timeout_s): + return await Scanner.create(self._teal, self, await self.devices(), timeout_s) + + +class Scanner: + @classmethod + async def create(cls, teal, adapter, initial_devices, timeout_s): + self = cls() + self._teal = teal + self._adapter = adapter + self._was_discovering = adapter._properties[ + "Discovering" + ] # TODO get current value, or watch property changes + self._queue = asyncio.Queue() + self.timeout_s = timeout_s + for device in initial_devices: + self._queue.put_nowait((device._path, device._properties)) + + self._teal._bluez.on_interfaces_added(self._on_iface_added) + if not self._was_discovering: + await self._adapter._object.call_start_discovery() + + return self + + def _on_iface_added(self, path, interfaces): + if "org.bluez.Device1" not in interfaces: + return + if not path.startswith(self._adapter._path + "/"): + return + properties = unwrap_properties(interfaces["org.bluez.Device1"]) + self._queue.put_nowait((path, properties)) + + async def __aenter__(self): + return self + + async def __aexit__(self, type, value, traceback): + if not self._was_discovering: + await self._adapter._object.call_stop_discovery() + self._teal._bluez.off_interfaces_added(self._on_iface_added) + + def __aiter__(self): + return self + + async def __anext__(self): + try: + (path, properties) = await asyncio.wait_for( + self._queue.get(), self.timeout_s + ) + return await Device.create(self._teal, path, properties) + except asyncio.TimeoutError: + raise StopAsyncIteration + + +class Device: + @classmethod + async def create(cls, teal, path, properties): + self = cls() + self._teal = teal + self._path = path + self._properties = properties + self._services_resolved = asyncio.Event() + self._services = None + + if properties["ServicesResolved"]: + self._services_resolved.set() + + # Listen to device events (connect, disconnect, ServicesResolved, ...) + obj = await self._teal.get_object("org.bluez", path) + self._device = obj.get_interface("org.bluez.Device1") + obj = await self._teal.get_object("org.bluez", path) + self._device_props = obj.get_interface("org.freedesktop.DBus.Properties") + self._device_props.on_properties_changed(self._on_prop_changed) + + return self + + def __del__(self): + self._device_props.off_properties_changed(self._on_prop_changed) + + def __repr__(self): + return "" % (self.address, self.name) + + def _on_prop_changed(self, _interface, changed_props, invalidated_props): + changed_props = unwrap_properties(changed_props) + LOG.debug( + f"prop changed: device {self._path} {changed_props.keys()} {invalidated_props}" + ) + for key, value in changed_props.items(): + self._properties[key] = value + for key in invalidated_props: + del self._properties[key] + + if "ServicesResolved" in changed_props: + if changed_props["ServicesResolved"]: + self._services_resolved.set() + else: + self._services_resolved.clear() + + async def connect(self): + await self._device.call_connect() + + async def disconnect(self): + await self._device.call_disconnect() + + async def services(self): + await self._services_resolved.wait() + if self._services is None: + self._services = {} + objects = await self._teal._bluez.call_get_managed_objects() + for path in sorted(objects.keys()): + if not path.startswith(self._path + "/"): + continue + if "org.bluez.GattService1" in objects[path]: + properties = unwrap_properties( + objects[path]["org.bluez.GattService1"] + ) + service = Service(self._teal, self, path, properties) + self._services[service.uuid] = service + elif "org.bluez.GattCharacteristic1" in objects[path]: + properties = unwrap_properties( + objects[path]["org.bluez.GattCharacteristic1"] + ) + characterstic = await Characteristic.create( + self._teal, self, path, properties + ) + for service in self._services.values(): + if properties["Service"] == service._path: + service.characteristics[characterstic.uuid] = characterstic + return self._services + + @property + def connected(self): + return bool(self._properties["Connected"]) + + @property + def services_resolved(self): + return bool(self._properties["ServicesResolved"]) + + @property + def UUIDs(self): + return [str(s) for s in self._properties["UUIDs"]] + + @property + def address(self): + return str(self._properties["Address"]) + + @property + def name(self): + if "Name" not in self._properties: + return None + return str(self._properties["Name"]) + + @property + def alias(self): + if "Alias" not in self._properties: + return None + return str(self._properties["Alias"]) + + +class Service: + def __init__(self, teal, device, path, properties): + self._device = device + self._teal = teal + self._path = path + self._properties = properties + self.characteristics = {} + + def __repr__(self): + return "" % ( + self._device.address, + self.uuid, + ) + + @property + def uuid(self): + return str(self._properties["UUID"]) + + +class Characteristic: + def __init__(self): + self._properties = {} + + @classmethod + async def create(cls, teal, device, path, properties): + self = cls() + self._device = device + self._teal = teal + self._path = path + self._properties = properties + self._values = asyncio.Queue() + + obj = await self._teal.get_object("org.bluez", path) + self._char = obj.get_interface("org.bluez.GattCharacteristic1") + self._props = obj.get_interface("org.freedesktop.DBus.Properties") + self._props.on_properties_changed(self._on_prop_changed) + + return self + + def __repr__(self): + return "" % ( + self._device.address, + self.uuid, + ) + + def __del__(self): + self._props.off_properties_changed(self._on_prop_changed) + + def _on_prop_changed(self, _interface, changed_props, invalidated_props): + changed_props = unwrap_properties(changed_props) + LOG.debug( + f"prop changed: characteristic {changed_props.keys()} {invalidated_props}" + ) + for key, value in changed_props.items(): + self._properties[key] = bytes(value) + for key in invalidated_props: + del self._properties[key] + + if "Value" in changed_props: + self._values.put_nowait(changed_props["Value"]) + + async def acquire(self, write: bool = False) -> tuple[io.FileIO, int]: + if write: + fd, mtu = await self._char.call_acquire_write({}) + mode = "w" + else: + fd, mtu = await self._char.call_acquire_notify({}) + mode = "r" + + f = io.FileIO(fd, mode) + LOG.debug(f"acquired {self.uuid} ({mode})") + return f, mtu + + async def read(self) -> bytes: + return bytes(await self._values.get()) + + async def write(self, value, command=True): + ty = "command" if command else "request" + await self._char.call_write_value(value, {"type": dbus_next.Variant("s", ty)}) + + # + # async def write(self, value, command=True): + # start = time.time() + # try: + # if command: + # await self._char.call_write_value(value, {"type": "command"}) + # else: + # await self._char.call_write_value(value, {"type": "request"}) + # + # except dbus_next.DBusError as e: + # if ( + # e.type == "org.bluez.Error.Failed" + # and e.text == "Not connected" + # ): + # raise NotConnectedError() + # else: + # raise # some other error + # + # # Workaround: if the write took very long, it is possible the connection + # # broke (without causing an exception). So check whether we are still + # # connected. + # # I think this is a bug in BlueZ. + # if time.time() - start > 0.5: + # if not self._device._device_props.call_get("org.bluez.Device1", "Connected"): + # raise NotConnectedError() + + async def start_notify(self): + await self._char.call_start_notify() + + async def stop_notify(self): + await self._char.call_stop_notify() + + @property + def uuid(self): + return str(self._properties["UUID"]) diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b7..97a3b740c9 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -113,6 +113,7 @@ class Transport: def all_transports() -> Iterable[Type["Transport"]]: + from .ble import BleTransport from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport @@ -123,6 +124,7 @@ def all_transports() -> Iterable[Type["Transport"]]: HidTransport, UdpTransport, WebUsbTransport, + BleTransport, ) return set(t for t in transports if t.ENABLED) diff --git a/python/src/trezorlib/transport/ble.py b/python/src/trezorlib/transport/ble.py new file mode 100644 index 0000000000..bf09835af2 --- /dev/null +++ b/python/src/trezorlib/transport/ble.py @@ -0,0 +1,265 @@ +# This file is part of the Trezor project. +# +# Copyright (C) 2012-2022 SatoshiLabs and contributors +# +# This library is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License version 3 +# as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the License along with this library. +# If not, see . +import asyncio +import logging +from dataclasses import dataclass +from multiprocessing import Pipe, Process +from multiprocessing.connection import Connection +from typing import TYPE_CHECKING, Any, Iterable, List, Optional + +from . import TransportException +from .protocol import ProtocolBasedTransport, ProtocolV1 + +if TYPE_CHECKING: + from ..models import TrezorModel + +LOG = logging.getLogger(__name__) + +NUS_SERVICE_UUID = "8c000001-a59b-4d58-a9ad-073df69fa1b1" +NUS_CHARACTERISTIC_RX = "8c000002-a59b-4d58-a9ad-073df69fa1b1" +NUS_CHARACTERISTIC_TX = "8c000003-a59b-4d58-a9ad-073df69fa1b1" + + +@dataclass +class Device: + address: str + name: str + connected: bool + + +class BleTransport(ProtocolBasedTransport): + ENABLED = True + PATH_PREFIX = "ble" + + _ble = None + + def __init__(self, mac_addr: str) -> None: + self.device = mac_addr + super().__init__(protocol=ProtocolV1(self, replen=244)) + + def get_path(self) -> str: + return "{}:{}".format(self.PATH_PREFIX, self.device) + + def find_debug(self) -> "BleTransport": + return BleTransport(self.device) + + @classmethod + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["BleTransport"]: + devices = cls.ble().lookup() + return [BleTransport(device.address) for device in devices if device.connected] + + @classmethod + def _try_path(cls, path: str) -> "BleTransport": + devices = cls.enumerate(None) + devices = [d for d in devices if d.device == path] + if len(devices) == 0: + raise TransportException(f"No BLE device: {path}") + return devices[0] + + @classmethod + def find_by_path(cls, path: str, prefix_search: bool = False) -> "BleTransport": + if not prefix_search: + raise TransportException + + if prefix_search: + return super().find_by_path(path, prefix_search) + else: + raise TransportException(f"No BLE device: {path}") + + def open(self) -> None: + self.ble().connect(self.device) + + def close(self) -> None: + pass + + def write_chunk(self, chunk: bytes) -> None: + self.ble().write(chunk) + + def read_chunk(self) -> bytes: + chunk = self.ble().read() + # LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}") + if len(chunk) != 64: + raise TransportException(f"Unexpected chunk size: {len(chunk)}") + return bytearray(chunk) + + @classmethod + def ble(cls) -> "BleProxy": + if cls._ble is None: + cls._ble = BleProxy() + return cls._ble + + +class BleProxy: + pipe = None + process = None + + def __init__(self): + if self.pipe is not None: + return + + parent_pipe, child_pipe = Pipe() + self.pipe = parent_pipe + self.process = Process(target=BleAsync, args=(child_pipe,), daemon=True) + self.process.start() + + def __getattr__(self, name: str): + def f(*args: Any, **kwargs: Any): + assert self.pipe is not None + self.pipe.send((name, args, kwargs)) + result = self.pipe.recv() + if isinstance(result, BaseException): + raise result + return result + + return f + + +class BleAsync: + def __init__(self, pipe: Connection): + asyncio.run(self.main(pipe)) + + async def main(self, pipe: Connection): + from ..tealblue import TealBlue + + tb = await TealBlue.create() + # TODO: add cli option for mac_filter and pass it here + self.adapter = await tb.find_adapter() + # TODO: currently only one concurrent device is supported + # To support more devices, connect() needs to return a Connection and also has to + # spawn a task that will forward data between that Connection and rx,tx. + self.current = None + self.rx = None + self.tx = None + + self.devices = {} + await self.lookup() # populate self.devices + LOG.debug("async BLE process started") + + while True: + await ready(pipe) + cmd, args, kwargs = pipe.recv() + try: + result = await getattr(self, cmd)(*args, **kwargs) + except Exception as e: + LOG.exception("Error in async BLE process:") + await ready(pipe, write=True) + pipe.send(e) + break + else: + await ready(pipe, write=True) + pipe.send(result) + + async def lookup(self) -> List[Device]: + self.devices.clear() + for device in await self.adapter.devices(): + if NUS_SERVICE_UUID in device.UUIDs: + self.devices[device.address] = device + return [ + Device(device.address, device.name, device.connected) + for device in self.devices.values() + ] + + async def scan(self) -> List[Device]: + LOG.debug("Initiating scan") + # TODO: configurable timeout + scanner = await self.adapter.scan(2) + self.devices.clear() + async with scanner: + async for device in scanner: + if NUS_SERVICE_UUID in device.UUIDs: + if device.address not in self.devices: + LOG.debug(f"scan: {device.address}: {device.name}") + self.devices[device.address] = device + return [ + Device(device.address, device.name, device.connected) + for device in self.devices.values() + ] + + async def connect(self, address: str): + if self.current == address: + return + # elif self.current is not None: + # self.devices[self.current].disconnect() + + ble_device = self.devices[address] + if not ble_device.connected: + LOG.info("Connecting to %s (%s)..." % (ble_device.name, ble_device.address)) + await ble_device.connect() + else: + LOG.info("Connected to %s (%s)." % (ble_device.name, ble_device.address)) + + services = await ble_device.services() + nus_service = services[NUS_SERVICE_UUID] + self.rx, _mtu = await nus_service.characteristics[ + NUS_CHARACTERISTIC_RX + ].acquire(write=True) + self.tx, _mtu = await nus_service.characteristics[ + NUS_CHARACTERISTIC_TX + ].acquire() + self.current = address + + async def disconnect(self): + if self.current is None: + return + ble_device = self.devices[self.current] + if ble_device.connected: + LOG.info( + "Disconnecting from %s (%s)..." % (ble_device.name, ble_device.address) + ) + await ble_device.disconnect() + else: + LOG.info( + "Disconnected from %s (%s)." % (ble_device.name, ble_device.address) + ) + self.current = None + self.rx = None + self.tx = None + + async def read(self): + assert self.tx is not None + await ready(self.tx) + return self.tx.read() + + async def write(self, chunk: bytes): + assert self.rx is not None + await ready(self.rx, write=True) + self.rx.write(chunk) + + +async def ready(f: Any, write: bool = False): + """Asynchronously wait for file-like object to become ready for reading or writing.""" + fd = f.fileno() + loop = asyncio.get_event_loop() + event = asyncio.Event() + + if write: + + def callback(): + event.set() + loop.remove_writer(fd) + + loop.add_writer(fd, callback) + else: + + def callback(): + event.set() + loop.remove_reader(fd) + + loop.add_reader(fd, callback) + + await event.wait() diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py index a5a0ee6be4..2e8c0cfabc 100644 --- a/python/src/trezorlib/transport/protocol.py +++ b/python/src/trezorlib/transport/protocol.py @@ -71,8 +71,9 @@ class Protocol: its messages. """ - def __init__(self, handle: Handle) -> None: + def __init__(self, handle: Handle, replen: int = REPLEN) -> None: self.handle = handle + self.replen = replen self.session_counter = 0 # XXX we might be able to remove this now that TrezorClient does session handling @@ -129,10 +130,10 @@ class ProtocolV1(Protocol): while buffer: # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") + chunk = b"?" + buffer[: self.replen - 1] + chunk = chunk.ljust(self.replen, b"\x00") self.handle.write_chunk(chunk) - buffer = buffer[63:] + buffer = buffer[self.replen - 1 :] def read(self) -> MessagePayload: buffer = bytearray() diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a..ce216ec002 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -64,6 +64,8 @@ class WebUsbHandle: self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: