1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-29 10:58:21 +00:00

TREZORCTL

This commit is contained in:
tychovrahe 2023-06-07 13:18:57 +02:00
parent 4d19f6a7fd
commit 426acb8c16
14 changed files with 850 additions and 8 deletions

View File

@ -1,4 +1,4 @@
check: messages.pb messages-binance.pb messages-bitcoin.pb messages-bootloader.pb messages-cardano.pb messages-common.pb messages-crypto.pb messages-debug.pb messages-ethereum.pb messages-management.pb messages-monero.pb messages-nem.pb messages-ripple.pb messages-stellar.pb messages-tezos.pb messages-eos.pb check: messages.pb messages-binance.pb messages-bitcoin.pb messages-ble.pb messages-bootloader.pb messages-cardano.pb messages-common.pb messages-crypto.pb messages-debug.pb messages-ethereum.pb messages-management.pb messages-monero.pb messages-nem.pb messages-ripple.pb messages-stellar.pb messages-tezos.pb messages-eos.pb
%.pb: %.proto %.pb: %.proto
protoc -I/usr/include -I. $< -o $@ protoc -I/usr/include -I. $< -o $@

13
poetry.lock generated
View File

@ -346,6 +346,17 @@ ssh = ["bcrypt (>=3.1.5)"]
test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"] test = ["certifi", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-xdist"]
test-randomorder = ["pytest-randomly"] test-randomorder = ["pytest-randomly"]
[[package]]
name = "dbus-next"
version = "0.2.3"
description = "A zero-dependency DBus library for Python with asyncio support"
optional = false
python-versions = ">=3.6.0"
files = [
{file = "dbus_next-0.2.3-py3-none-any.whl", hash = "sha256:58948f9aff9db08316734c0be2a120f6dc502124d9642f55e90ac82ffb16a18b"},
{file = "dbus_next-0.2.3.tar.gz", hash = "sha256:f4eae26909332ada528c0a3549dda8d4f088f9b365153952a408e28023a626a5"},
]
[[package]] [[package]]
name = "demjson3" name = "demjson3"
version = "3.0.5" version = "3.0.5"
@ -1798,4 +1809,4 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-it
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.8.1" python-versions = "^3.8.1"
content-hash = "971d0f6f2926d839954b35b2029978046e282df3d8d595b1e176dc0cf37889fb" content-hash = "5de61e56cadc4aaa5c61a18e8e756d9e6cf195ff461f77ae27d076cd8350c856"

View File

@ -74,6 +74,9 @@ toiftool = {path = "./python/tools/toiftool", develop = true, python = ">=3.8"}
trezor-pylint-plugin = {path = "./tools/trezor-pylint-plugin", develop = true} trezor-pylint-plugin = {path = "./tools/trezor-pylint-plugin", develop = true}
trezor-core-tools = {path = "./core/tools", develop = true} trezor-core-tools = {path = "./core/tools", develop = true}
# ble
dbus-next = "*"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
scan-build = "*" scan-build = "*"
towncrier = "^23.6.0" towncrier = "^23.6.0"

View File

@ -6,3 +6,4 @@ libusb1>=1.6.4
construct>=2.9,!=2.10.55 construct>=2.9,!=2.10.55
typing_extensions>=4.7.1 typing_extensions>=4.7.1
construct-classes>=0.1.2 construct-classes>=0.1.2
dbus-next>=0.2.3

View File

@ -27,7 +27,7 @@ per-file-ignores =
helper-scripts/*:I helper-scripts/*:I
tools/*:I tools/*:I
tests/*:I tests/*:I
known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore] known-modules = libusb1:[usb1],hidapi:[hid],PyQt5:[PyQt5.QtWidgets,PyQt5.QtGui,PyQt5.QtCore],dbus-next:[dbus_next]
[isort] [isort]
profile = black profile = black

View File

@ -0,0 +1,59 @@
import typing as t
from .. import messages
from ..tools import session
if t.TYPE_CHECKING:
from ..client import TrezorClient
@session
def update(
client: "TrezorClient",
datfile: bytes,
binfile: bytes,
progress_update: t.Callable[[int], t.Any] = lambda _: None,
):
chunk_len = 4096
offset = 0
resp = client.call(
messages.UploadBLEFirmwareInit(init_data=datfile, binsize=len(binfile))
)
while isinstance(resp, messages.UploadBLEFirmwareNextChunk):
payload = binfile[offset : offset + chunk_len]
resp = client.call(messages.UploadBLEFirmwareChunk(data=payload))
progress_update(chunk_len)
offset += chunk_len
if isinstance(resp, messages.Success):
return
else:
raise RuntimeError(f"Unexpected message {resp}")
@session
def erase_bonds(
client: "TrezorClient",
):
resp = client.call(messages.EraseBonds())
if isinstance(resp, messages.Success):
return
else:
raise RuntimeError(f"Unexpected message {resp}")
@session
def disconnect(
client: "TrezorClient",
):
resp = client.call(messages.Disconnect())
if isinstance(resp, messages.Success):
return
else:
raise RuntimeError(f"Unexpected message {resp}")

View File

@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
import click import click
from .. import exceptions, transport from .. import exceptions, messages, transport
from ..client import TrezorClient from ..client import TrezorClient
from ..ui import ClickUI, ScriptUI from ..ui import ClickUI, ScriptUI
@ -110,6 +110,12 @@ class TrezorConnection:
except transport.DeviceIsBusy: except transport.DeviceIsBusy:
click.echo("Device is in use by another process.") click.echo("Device is in use by another process.")
sys.exit(1) sys.exit(1)
except exceptions.TrezorFailure as e:
if e.code is messages.FailureType.DeviceIsBusy:
click.echo(str(e))
sys.exit(1)
else:
raise e
except Exception: except Exception:
click.echo("Failed to find a Trezor device.") click.echo("Failed to find a Trezor device.")
if self.path is not None: if self.path is not None:

View File

@ -0,0 +1,139 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import json
import sys
import zipfile
from typing import TYPE_CHECKING, BinaryIO
import click
from .. import ble, exceptions
from ..transport.ble import BleProxy
from . import with_client
if TYPE_CHECKING:
from ..client import TrezorClient
@click.group(name="ble")
def cli() -> None:
"""BLE commands."""
@cli.command()
# fmt: off
@click.argument("package", type=click.File("rb"))
# fmt: on
@with_client
def update(
client: "TrezorClient",
package: BinaryIO,
) -> None:
"""Upload new BLE firmware to device."""
with zipfile.ZipFile(package) as archive:
manifest = archive.read("manifest.json")
mainfest_data = json.loads(manifest.decode("utf-8"))["manifest"]
for k in mainfest_data.keys():
binfile = archive.read(mainfest_data[k]["bin_file"])
datfile = archive.read(mainfest_data[k]["dat_file"])
"""Perform the final act of loading the firmware into Trezor."""
try:
click.echo("Uploading...\r", nl=False)
with click.progressbar(
label="Uploading", length=len(binfile), show_eta=False
) as bar:
ble.update(client, datfile, binfile, bar.update)
click.echo("Update successful.")
except exceptions.Cancelled:
click.echo("Update aborted on device.")
except exceptions.TrezorException as e:
click.echo(f"Update failed: {e}")
sys.exit(3)
@cli.command()
def connect() -> None:
"""Connect to the device via BLE."""
ble = BleProxy()
devices = [d for d in ble.lookup() if d.connected]
if len(devices) == 0:
click.echo("Scanning...")
devices = ble.scan()
if len(devices) == 0:
click.echo("No BLE devices found")
return
else:
click.echo("Found %d BLE device(s)" % len(devices))
for device in devices:
click.echo(f"Device: {device.name}, {device.address}")
device = devices[0]
click.echo(f"Connecting to {device.name}...")
ble.connect(device.address)
click.echo("Connected")
@with_client
def disconnect_device(client: "TrezorClient") -> None:
"""Disconnect from device side."""
try:
ble.disconnect(client)
except exceptions.Cancelled:
click.echo("Disconnect aborted on device.")
except exceptions.TrezorException as e:
click.echo(f"Disconnect failed: {e}")
sys.exit(3)
@cli.command()
@click.option("--device", is_flag=True, help="Disconnect from device side.")
def disconnect(device: bool) -> None:
if device:
disconnect_device()
else:
ble_proxy = BleProxy()
devices = [d for d in ble_proxy.lookup() if d.connected]
if len(devices) == 0:
click.echo("No BLE devices found")
return
ble_proxy.connect(devices[0].address)
ble_proxy.disconnect()
@cli.command()
@with_client
def erase_bonds(
client: "TrezorClient",
) -> None:
"""Erase BLE bonds on device."""
try:
ble.erase_bonds(client)
click.echo("Erase successful.")
except exceptions.Cancelled:
click.echo("Erase aborted on device.")
except exceptions.TrezorException as e:
click.echo(f"Update failed: {e}")
sys.exit(3)

View File

@ -33,6 +33,7 @@ from . import (
TrezorConnection, TrezorConnection,
benchmark, benchmark,
binance, binance,
ble,
btc, btc,
cardano, cardano,
crypto, crypto,
@ -88,6 +89,7 @@ COMMAND_ALIASES = {
"upgrade-firmware": firmware.update, "upgrade-firmware": firmware.update,
"firmware-upgrade": firmware.update, "firmware-upgrade": firmware.update,
"firmware-update": firmware.update, "firmware-update": firmware.update,
"ble-update": ble.update,
} }
@ -418,6 +420,7 @@ cli.add_command(tezos.cli)
cli.add_command(firmware.cli) cli.add_command(firmware.cli)
cli.add_command(debug.cli) cli.add_command(debug.cli)
cli.add_command(benchmark.cli) cli.add_command(benchmark.cli)
cli.add_command(ble.cli)
# #
# Main # Main

350
python/src/trezorlib/tealblue.py Executable file
View File

@ -0,0 +1,350 @@
# !/usr/bin/python3
# pyright: off
import asyncio
import io
import logging
import dbus_next
LOG = logging.getLogger(__name__)
def unwrap_properties(properties):
return {k: v.value for k, v in properties.items()}
class TealBlue:
@classmethod
async def create(cls):
self = cls()
self._bus = await dbus_next.aio.MessageBus(
bus_type=dbus_next.constants.BusType.SYSTEM, negotiate_unix_fd=True
).connect()
obj = await self.get_object("org.bluez", "/")
self._bluez = obj.get_interface("org.freedesktop.DBus.ObjectManager")
return self
async def find_adapter(self, mac_filter=""):
"""Find the first adapter matching mac_filter."""
objects = await self._bluez.call_get_managed_objects()
for path in sorted(objects.keys()):
interfaces = objects[path]
if "org.bluez.Adapter1" not in interfaces:
continue
properties = unwrap_properties(interfaces["org.bluez.Adapter1"])
if mac_filter not in properties["Address"]:
continue
return await Adapter.create(self, path, properties)
raise Exception("No adapter found")
async def get_object(self, name, path):
introspection = await self._bus.introspect(name, path)
obj = self._bus.get_proxy_object(name, path, introspection)
return obj
class Adapter:
@classmethod
async def create(cls, teal, path, properties):
self = cls()
self._teal = teal
self._path = path
self._properties = properties
obj = await self._teal.get_object("org.bluez", path)
self._object = obj.get_interface("org.bluez.Adapter1")
return self
def __repr__(self):
return "<tealblue.Adapter address=%s>" % (self._properties["Address"])
async def devices(self):
"""
Returns the devices that BlueZ has discovered.
"""
objects = await self._teal._bluez.call_get_managed_objects()
devices = []
for path in sorted(objects.keys()):
interfaces = objects[path]
if "org.bluez.Device1" not in interfaces:
continue
properties = unwrap_properties(interfaces["org.bluez.Device1"])
devices.append(await Device.create(self._teal, path, properties))
return devices
async def scan(self, timeout_s):
return await Scanner.create(self._teal, self, await self.devices(), timeout_s)
class Scanner:
@classmethod
async def create(cls, teal, adapter, initial_devices, timeout_s):
self = cls()
self._teal = teal
self._adapter = adapter
self._was_discovering = adapter._properties[
"Discovering"
] # TODO get current value, or watch property changes
self._queue = asyncio.Queue()
self.timeout_s = timeout_s
for device in initial_devices:
self._queue.put_nowait((device._path, device._properties))
self._teal._bluez.on_interfaces_added(self._on_iface_added)
if not self._was_discovering:
await self._adapter._object.call_start_discovery()
return self
def _on_iface_added(self, path, interfaces):
if "org.bluez.Device1" not in interfaces:
return
if not path.startswith(self._adapter._path + "/"):
return
properties = unwrap_properties(interfaces["org.bluez.Device1"])
self._queue.put_nowait((path, properties))
async def __aenter__(self):
return self
async def __aexit__(self, type, value, traceback):
if not self._was_discovering:
await self._adapter._object.call_stop_discovery()
self._teal._bluez.off_interfaces_added(self._on_iface_added)
def __aiter__(self):
return self
async def __anext__(self):
try:
(path, properties) = await asyncio.wait_for(
self._queue.get(), self.timeout_s
)
return await Device.create(self._teal, path, properties)
except asyncio.TimeoutError:
raise StopAsyncIteration
class Device:
@classmethod
async def create(cls, teal, path, properties):
self = cls()
self._teal = teal
self._path = path
self._properties = properties
self._services_resolved = asyncio.Event()
self._services = None
if properties["ServicesResolved"]:
self._services_resolved.set()
# Listen to device events (connect, disconnect, ServicesResolved, ...)
obj = await self._teal.get_object("org.bluez", path)
self._device = obj.get_interface("org.bluez.Device1")
obj = await self._teal.get_object("org.bluez", path)
self._device_props = obj.get_interface("org.freedesktop.DBus.Properties")
self._device_props.on_properties_changed(self._on_prop_changed)
return self
def __del__(self):
self._device_props.off_properties_changed(self._on_prop_changed)
def __repr__(self):
return "<tealblue.Device address=%s name=%r>" % (self.address, self.name)
def _on_prop_changed(self, _interface, changed_props, invalidated_props):
changed_props = unwrap_properties(changed_props)
LOG.debug(
f"prop changed: device {self._path} {changed_props.keys()} {invalidated_props}"
)
for key, value in changed_props.items():
self._properties[key] = value
for key in invalidated_props:
del self._properties[key]
if "ServicesResolved" in changed_props:
if changed_props["ServicesResolved"]:
self._services_resolved.set()
else:
self._services_resolved.clear()
async def connect(self):
await self._device.call_connect()
async def disconnect(self):
await self._device.call_disconnect()
async def services(self):
await self._services_resolved.wait()
if self._services is None:
self._services = {}
objects = await self._teal._bluez.call_get_managed_objects()
for path in sorted(objects.keys()):
if not path.startswith(self._path + "/"):
continue
if "org.bluez.GattService1" in objects[path]:
properties = unwrap_properties(
objects[path]["org.bluez.GattService1"]
)
service = Service(self._teal, self, path, properties)
self._services[service.uuid] = service
elif "org.bluez.GattCharacteristic1" in objects[path]:
properties = unwrap_properties(
objects[path]["org.bluez.GattCharacteristic1"]
)
characterstic = await Characteristic.create(
self._teal, self, path, properties
)
for service in self._services.values():
if properties["Service"] == service._path:
service.characteristics[characterstic.uuid] = characterstic
return self._services
@property
def connected(self):
return bool(self._properties["Connected"])
@property
def services_resolved(self):
return bool(self._properties["ServicesResolved"])
@property
def UUIDs(self):
return [str(s) for s in self._properties["UUIDs"]]
@property
def address(self):
return str(self._properties["Address"])
@property
def name(self):
if "Name" not in self._properties:
return None
return str(self._properties["Name"])
@property
def alias(self):
if "Alias" not in self._properties:
return None
return str(self._properties["Alias"])
class Service:
def __init__(self, teal, device, path, properties):
self._device = device
self._teal = teal
self._path = path
self._properties = properties
self.characteristics = {}
def __repr__(self):
return "<tealblue.Service device=%s uuid=%s>" % (
self._device.address,
self.uuid,
)
@property
def uuid(self):
return str(self._properties["UUID"])
class Characteristic:
def __init__(self):
self._properties = {}
@classmethod
async def create(cls, teal, device, path, properties):
self = cls()
self._device = device
self._teal = teal
self._path = path
self._properties = properties
self._values = asyncio.Queue()
obj = await self._teal.get_object("org.bluez", path)
self._char = obj.get_interface("org.bluez.GattCharacteristic1")
self._props = obj.get_interface("org.freedesktop.DBus.Properties")
self._props.on_properties_changed(self._on_prop_changed)
return self
def __repr__(self):
return "<tealblue.Characteristic device=%s uuid=%s>" % (
self._device.address,
self.uuid,
)
def __del__(self):
self._props.off_properties_changed(self._on_prop_changed)
def _on_prop_changed(self, _interface, changed_props, invalidated_props):
changed_props = unwrap_properties(changed_props)
LOG.debug(
f"prop changed: characteristic {changed_props.keys()} {invalidated_props}"
)
for key, value in changed_props.items():
self._properties[key] = bytes(value)
for key in invalidated_props:
del self._properties[key]
if "Value" in changed_props:
self._values.put_nowait(changed_props["Value"])
async def acquire(self, write: bool = False) -> tuple[io.FileIO, int]:
if write:
fd, mtu = await self._char.call_acquire_write({})
mode = "w"
else:
fd, mtu = await self._char.call_acquire_notify({})
mode = "r"
f = io.FileIO(fd, mode)
LOG.debug(f"acquired {self.uuid} ({mode})")
return f, mtu
async def read(self) -> bytes:
return bytes(await self._values.get())
async def write(self, value, command=True):
ty = "command" if command else "request"
await self._char.call_write_value(value, {"type": dbus_next.Variant("s", ty)})
#
# async def write(self, value, command=True):
# start = time.time()
# try:
# if command:
# await self._char.call_write_value(value, {"type": "command"})
# else:
# await self._char.call_write_value(value, {"type": "request"})
#
# except dbus_next.DBusError as e:
# if (
# e.type == "org.bluez.Error.Failed"
# and e.text == "Not connected"
# ):
# raise NotConnectedError()
# else:
# raise # some other error
#
# # Workaround: if the write took very long, it is possible the connection
# # broke (without causing an exception). So check whether we are still
# # connected.
# # I think this is a bug in BlueZ.
# if time.time() - start > 0.5:
# if not self._device._device_props.call_get("org.bluez.Device1", "Connected"):
# raise NotConnectedError()
async def start_notify(self):
await self._char.call_start_notify()
async def stop_notify(self):
await self._char.call_stop_notify()
@property
def uuid(self):
return str(self._properties["UUID"])

View File

@ -113,6 +113,7 @@ class Transport:
def all_transports() -> Iterable[Type["Transport"]]: def all_transports() -> Iterable[Type["Transport"]]:
from .ble import BleTransport
from .bridge import BridgeTransport from .bridge import BridgeTransport
from .hid import HidTransport from .hid import HidTransport
from .udp import UdpTransport from .udp import UdpTransport
@ -123,6 +124,7 @@ def all_transports() -> Iterable[Type["Transport"]]:
HidTransport, HidTransport,
UdpTransport, UdpTransport,
WebUsbTransport, WebUsbTransport,
BleTransport,
) )
return set(t for t in transports if t.ENABLED) return set(t for t in transports if t.ENABLED)

View File

@ -0,0 +1,265 @@
# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import asyncio
import logging
from dataclasses import dataclass
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from typing import TYPE_CHECKING, Any, Iterable, List, Optional
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING:
from ..models import TrezorModel
LOG = logging.getLogger(__name__)
NUS_SERVICE_UUID = "8c000001-a59b-4d58-a9ad-073df69fa1b1"
NUS_CHARACTERISTIC_RX = "8c000002-a59b-4d58-a9ad-073df69fa1b1"
NUS_CHARACTERISTIC_TX = "8c000003-a59b-4d58-a9ad-073df69fa1b1"
@dataclass
class Device:
address: str
name: str
connected: bool
class BleTransport(ProtocolBasedTransport):
ENABLED = True
PATH_PREFIX = "ble"
_ble = None
def __init__(self, mac_addr: str) -> None:
self.device = mac_addr
super().__init__(protocol=ProtocolV1(self, replen=244))
def get_path(self) -> str:
return "{}:{}".format(self.PATH_PREFIX, self.device)
def find_debug(self) -> "BleTransport":
return BleTransport(self.device)
@classmethod
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BleTransport"]:
devices = cls.ble().lookup()
return [BleTransport(device.address) for device in devices if device.connected]
@classmethod
def _try_path(cls, path: str) -> "BleTransport":
devices = cls.enumerate(None)
devices = [d for d in devices if d.device == path]
if len(devices) == 0:
raise TransportException(f"No BLE device: {path}")
return devices[0]
@classmethod
def find_by_path(cls, path: str, prefix_search: bool = False) -> "BleTransport":
if not prefix_search:
raise TransportException
if prefix_search:
return super().find_by_path(path, prefix_search)
else:
raise TransportException(f"No BLE device: {path}")
def open(self) -> None:
self.ble().connect(self.device)
def close(self) -> None:
pass
def write_chunk(self, chunk: bytes) -> None:
self.ble().write(chunk)
def read_chunk(self) -> bytes:
chunk = self.ble().read()
# LOG.log(DUMP_PACKETS, f"received packet: {chunk.hex()}")
if len(chunk) != 64:
raise TransportException(f"Unexpected chunk size: {len(chunk)}")
return bytearray(chunk)
@classmethod
def ble(cls) -> "BleProxy":
if cls._ble is None:
cls._ble = BleProxy()
return cls._ble
class BleProxy:
pipe = None
process = None
def __init__(self):
if self.pipe is not None:
return
parent_pipe, child_pipe = Pipe()
self.pipe = parent_pipe
self.process = Process(target=BleAsync, args=(child_pipe,), daemon=True)
self.process.start()
def __getattr__(self, name: str):
def f(*args: Any, **kwargs: Any):
assert self.pipe is not None
self.pipe.send((name, args, kwargs))
result = self.pipe.recv()
if isinstance(result, BaseException):
raise result
return result
return f
class BleAsync:
def __init__(self, pipe: Connection):
asyncio.run(self.main(pipe))
async def main(self, pipe: Connection):
from ..tealblue import TealBlue
tb = await TealBlue.create()
# TODO: add cli option for mac_filter and pass it here
self.adapter = await tb.find_adapter()
# TODO: currently only one concurrent device is supported
# To support more devices, connect() needs to return a Connection and also has to
# spawn a task that will forward data between that Connection and rx,tx.
self.current = None
self.rx = None
self.tx = None
self.devices = {}
await self.lookup() # populate self.devices
LOG.debug("async BLE process started")
while True:
await ready(pipe)
cmd, args, kwargs = pipe.recv()
try:
result = await getattr(self, cmd)(*args, **kwargs)
except Exception as e:
LOG.exception("Error in async BLE process:")
await ready(pipe, write=True)
pipe.send(e)
break
else:
await ready(pipe, write=True)
pipe.send(result)
async def lookup(self) -> List[Device]:
self.devices.clear()
for device in await self.adapter.devices():
if NUS_SERVICE_UUID in device.UUIDs:
self.devices[device.address] = device
return [
Device(device.address, device.name, device.connected)
for device in self.devices.values()
]
async def scan(self) -> List[Device]:
LOG.debug("Initiating scan")
# TODO: configurable timeout
scanner = await self.adapter.scan(2)
self.devices.clear()
async with scanner:
async for device in scanner:
if NUS_SERVICE_UUID in device.UUIDs:
if device.address not in self.devices:
LOG.debug(f"scan: {device.address}: {device.name}")
self.devices[device.address] = device
return [
Device(device.address, device.name, device.connected)
for device in self.devices.values()
]
async def connect(self, address: str):
if self.current == address:
return
# elif self.current is not None:
# self.devices[self.current].disconnect()
ble_device = self.devices[address]
if not ble_device.connected:
LOG.info("Connecting to %s (%s)..." % (ble_device.name, ble_device.address))
await ble_device.connect()
else:
LOG.info("Connected to %s (%s)." % (ble_device.name, ble_device.address))
services = await ble_device.services()
nus_service = services[NUS_SERVICE_UUID]
self.rx, _mtu = await nus_service.characteristics[
NUS_CHARACTERISTIC_RX
].acquire(write=True)
self.tx, _mtu = await nus_service.characteristics[
NUS_CHARACTERISTIC_TX
].acquire()
self.current = address
async def disconnect(self):
if self.current is None:
return
ble_device = self.devices[self.current]
if ble_device.connected:
LOG.info(
"Disconnecting from %s (%s)..." % (ble_device.name, ble_device.address)
)
await ble_device.disconnect()
else:
LOG.info(
"Disconnected from %s (%s)." % (ble_device.name, ble_device.address)
)
self.current = None
self.rx = None
self.tx = None
async def read(self):
assert self.tx is not None
await ready(self.tx)
return self.tx.read()
async def write(self, chunk: bytes):
assert self.rx is not None
await ready(self.rx, write=True)
self.rx.write(chunk)
async def ready(f: Any, write: bool = False):
"""Asynchronously wait for file-like object to become ready for reading or writing."""
fd = f.fileno()
loop = asyncio.get_event_loop()
event = asyncio.Event()
if write:
def callback():
event.set()
loop.remove_writer(fd)
loop.add_writer(fd, callback)
else:
def callback():
event.set()
loop.remove_reader(fd)
loop.add_reader(fd, callback)
await event.wait()

View File

@ -71,8 +71,9 @@ class Protocol:
its messages. its messages.
""" """
def __init__(self, handle: Handle) -> None: def __init__(self, handle: Handle, replen: int = REPLEN) -> None:
self.handle = handle self.handle = handle
self.replen = replen
self.session_counter = 0 self.session_counter = 0
# XXX we might be able to remove this now that TrezorClient does session handling # XXX we might be able to remove this now that TrezorClient does session handling
@ -129,10 +130,10 @@ class ProtocolV1(Protocol):
while buffer: while buffer:
# Report ID, data padded to 63 bytes # Report ID, data padded to 63 bytes
chunk = b"?" + buffer[: REPLEN - 1] chunk = b"?" + buffer[: self.replen - 1]
chunk = chunk.ljust(REPLEN, b"\x00") chunk = chunk.ljust(self.replen, b"\x00")
self.handle.write_chunk(chunk) self.handle.write_chunk(chunk)
buffer = buffer[63:] buffer = buffer[self.replen - 1 :]
def read(self) -> MessagePayload: def read(self) -> MessagePayload:
buffer = bytearray() buffer = bytearray()

View File

@ -64,6 +64,8 @@ class WebUsbHandle:
self.handle.claimInterface(self.interface) self.handle.claimInterface(self.interface)
except usb1.USBErrorAccess as e: except usb1.USBErrorAccess as e:
raise DeviceIsBusy(self.device) from e raise DeviceIsBusy(self.device) from e
except usb1.USBErrorBusy as e:
raise DeviceIsBusy(self.device) from e
def close(self) -> None: def close(self) -> None:
if self.handle is not None: if self.handle is not None: