From e779a251fb95c6a23daa26f3827e549d24090638 Mon Sep 17 00:00:00 2001 From: matejcik Date: Thu, 24 May 2018 19:14:05 +0200 Subject: [PATCH] transport: better ways to handle errors when enumerating devices --- trezorlib/client.py | 1 + trezorlib/transport/__init__.py | 69 ++++++++++++++++++--------------- trezorlib/transport/bridge.py | 3 ++ trezorlib/transport/hid.py | 3 ++ trezorlib/transport/udp.py | 3 ++ trezorlib/transport/webusb.py | 3 ++ 6 files changed, 50 insertions(+), 32 deletions(-) diff --git a/trezorlib/client.py b/trezorlib/client.py index c27fb8393e..b0b9780696 100644 --- a/trezorlib/client.py +++ b/trezorlib/client.py @@ -150,6 +150,7 @@ class BaseClient(object): # Implements very basic layer of sending raw protobuf # messages to device and getting its response back. def __init__(self, transport, **kwargs): + LOG.info("creating client instance for device: {}".format(transport.get_path())) self.transport = transport super(BaseClient, self).__init__() # *args, **kwargs) diff --git a/trezorlib/transport/__init__.py b/trezorlib/transport/__init__.py index e6a4bca13e..0aa22a92e9 100644 --- a/trezorlib/transport/__init__.py +++ b/trezorlib/transport/__init__.py @@ -17,6 +17,13 @@ # You should have received a copy of the GNU Lesser General Public License # along with this library. If not, see . +import importlib +import logging + +from typing import Iterable, Type, List, Set + +LOG = logging.getLogger(__name__) + class TransportException(Exception): pass @@ -63,54 +70,52 @@ class Transport(object): raise TransportException('{} device not found: {}'.format(cls.PATH_PREFIX, path)) -def all_transports(): - transports = [] - try: - from .bridge import BridgeTransport - transports.append(BridgeTransport) - except: - pass - - try: - from .hid import HidTransport - transports.append(HidTransport) - except: - pass - - try: - from .udp import UdpTransport - transports.append(UdpTransport) - except: - pass - - try: - from .webusb import WebUsbTransport - transports.append(WebUsbTransport) - except: - pass +def all_transports() -> Iterable[Type[Transport]]: + transports = set() # type: Set[Type[Transport]] + for modname in ("bridge", "hid", "udp", "webusb"): + try: + # Import the module and find the Transport class. + # To avoid iterating over every item, the module should assign its Transport class + # to a constant named TRANSPORT. + module = importlib.import_module("." + modname, __name__) + try: + transports.add(getattr(module, "TRANSPORT")) + except AttributeError: + LOG.warning("Skipping broken module {}".format(modname)) + except ImportError as e: + LOG.info("Failed to import module {}: {}".format(modname, e)) return transports -def enumerate_devices(): - return [device - for transport in all_transports() - for device in transport.enumerate()] +def enumerate_devices() -> Iterable[Transport]: + devices = [] # type: List[Transport] + for transport in all_transports(): + try: + found = transport.enumerate() + LOG.info("Enumerating {}: found {} devices".format(transport.__name__, len(found))) + devices.extend(found) + except NotImplementedError: + LOG.error("{} does not implement device enumeration".format(transport.__name__)) + except Exception as e: + LOG.error("Failed to enumerate {}. {}: {}".format(transport.__name__, e.__class__.__name__, e)) + return devices -def get_transport(path=None, prefix_search=False): +def get_transport(path: str = None, prefix_search: bool = False) -> Transport: if path is None: try: - return enumerate_devices()[0] + return next(iter(enumerate_devices())) except IndexError: raise Exception("No TREZOR device found") from None # Find whether B is prefix of A (transport name is part of the path) # or A is prefix of B (path is a prefix, or a name, of transport). # This naively expects that no two transports have a common prefix. - def match_prefix(a, b): + def match_prefix(a: str, b: str) -> bool: return a.startswith(b) or b.startswith(a) + LOG.info("looking for device by {}: {}".format("prefix" if prefix_search else "full path", path)) transports = [t for t in all_transports() if match_prefix(path, t.PATH_PREFIX)] if transports: return transports[0].find_by_path(path, prefix_search=prefix_search) diff --git a/trezorlib/transport/bridge.py b/trezorlib/transport/bridge.py index 898e6a62c4..4b20780abe 100644 --- a/trezorlib/transport/bridge.py +++ b/trezorlib/transport/bridge.py @@ -106,3 +106,6 @@ class BridgeTransport(Transport): extra={'protobuf': msg}) self.response = None return msg + + +TRANSPORT = BridgeTransport diff --git a/trezorlib/transport/hid.py b/trezorlib/transport/hid.py index 9a47354ab8..991998a826 100644 --- a/trezorlib/transport/hid.py +++ b/trezorlib/transport/hid.py @@ -180,3 +180,6 @@ def is_wirelink(dev): def is_debuglink(dev): return (dev['usage_page'] == 0xFF01 or dev['interface_number'] == 1) + + +TRANSPORT = HidTransport diff --git a/trezorlib/transport/udp.py b/trezorlib/transport/udp.py index ffcfb250f6..631cd0a18e 100644 --- a/trezorlib/transport/udp.py +++ b/trezorlib/transport/udp.py @@ -124,3 +124,6 @@ class UdpTransport(Transport): if len(chunk) != 64: raise TransportException('Unexpected chunk size: %d' % len(chunk)) return bytearray(chunk) + + +TRANSPORT = UdpTransport diff --git a/trezorlib/transport/webusb.py b/trezorlib/transport/webusb.py index b9782f748d..81ea75a713 100644 --- a/trezorlib/transport/webusb.py +++ b/trezorlib/transport/webusb.py @@ -188,3 +188,6 @@ def is_vendor_class(dev): def dev_to_str(dev): return ':'.join(str(x) for x in ['%03i' % (dev.getBusNumber(), )] + dev.getPortNumberList()) + + +TRANSPORT = WebUsbTransport