diff --git a/trezorlib/transport/__init__.py b/trezorlib/transport/__init__.py index 4d8d5e86f7..22e1847afa 100644 --- a/trezorlib/transport/__init__.py +++ b/trezorlib/transport/__init__.py @@ -61,6 +61,7 @@ class Transport: """ PATH_PREFIX = None # type: str + ENABLED = False def __str__(self) -> str: return self.get_path() @@ -100,21 +101,16 @@ class Transport: def all_transports() -> Iterable[Type[Transport]]: - transports = set() # type: Set[Type[Transport]] - for modname in ("bridge", "hid", "udp", "webusb"): - try: - # Import the module and find the Transport class. - # To avoid iterating over every item, the module should assign its Transport class - # to a constant named TRANSPORT. - module = importlib.import_module("." + modname, __name__) - try: - transports.add(getattr(module, "TRANSPORT")) - except AttributeError: - LOG.warning("Skipping broken module {}".format(modname)) - except ImportError as e: - LOG.info("Failed to import module {}: {}".format(modname, e)) + from .bridge import BridgeTransport + from .hid import HidTransport + from .udp import UdpTransport + from .webusb import WebUsbTransport - return transports + return set( + cls + for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport) + if cls.ENABLED + ) def enumerate_devices() -> Iterable[Transport]: diff --git a/trezorlib/transport/bridge.py b/trezorlib/transport/bridge.py index 29615485cc..51446d02e6 100644 --- a/trezorlib/transport/bridge.py +++ b/trezorlib/transport/bridge.py @@ -102,6 +102,7 @@ class BridgeTransport(Transport): """ PATH_PREFIX = "bridge" + ENABLED = True def __init__( self, device: Dict[str, Any], legacy: bool, debug: bool = False @@ -177,6 +178,3 @@ class BridgeTransport(Transport): extra={"protobuf": msg}, ) return msg - - -TRANSPORT = BridgeTransport diff --git a/trezorlib/transport/hid.py b/trezorlib/transport/hid.py index 9c4aa385ee..656e415906 100644 --- a/trezorlib/transport/hid.py +++ b/trezorlib/transport/hid.py @@ -14,15 +14,23 @@ # You should have received a copy of the License along with this library. # If not, see . +import logging import sys import time from typing import Any, Dict, Iterable -import hid - from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 +LOG = logging.getLogger(__name__) + +try: + import hid +except Exception as e: + LOG.info("HID transport is disabled: {}".format(e)) + hid = None + + HidDevice = Dict[str, Any] HidDeviceHandle = Any @@ -87,6 +95,7 @@ class HidTransport(ProtocolBasedTransport): """ PATH_PREFIX = "hid" + ENABLED = hid is not None def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None: if hid_handle is None: @@ -135,6 +144,3 @@ def is_wirelink(dev: HidDevice) -> bool: def is_debuglink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1 - - -TRANSPORT = HidTransport diff --git a/trezorlib/transport/udp.py b/trezorlib/transport/udp.py index fd2ee5f790..bb0ccfe782 100644 --- a/trezorlib/transport/udp.py +++ b/trezorlib/transport/udp.py @@ -30,6 +30,7 @@ class UdpTransport(ProtocolBasedTransport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" + ENABLED = True def __init__(self, device: str = None) -> None: if not device: @@ -123,6 +124,3 @@ class UdpTransport(ProtocolBasedTransport): if len(chunk) != 64: raise TransportException("Unexpected chunk size: %d" % len(chunk)) return bytearray(chunk) - - -TRANSPORT = UdpTransport diff --git a/trezorlib/transport/webusb.py b/trezorlib/transport/webusb.py index 257131c94c..b0221aecb2 100644 --- a/trezorlib/transport/webusb.py +++ b/trezorlib/transport/webusb.py @@ -15,15 +15,22 @@ # If not, see . import atexit +import logging import sys import time from typing import Iterable, Optional -import usb1 - from . import TREZORS, UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 +LOG = logging.getLogger(__name__) + +try: + import usb1 +except Exception as e: + LOG.warning("WebUSB transport is disabled: {}".format(e)) + usb1 = None + if False: # mark Optional as used, otherwise it only exists in comments Optional @@ -84,6 +91,7 @@ class WebUsbTransport(ProtocolBasedTransport): """ PATH_PREFIX = "webusb" + ENABLED = usb1 is not None context = None def __init__( @@ -150,6 +158,3 @@ def dev_to_str(dev: usb1.USBDevice) -> str: return ":".join( str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList() ) - - -TRANSPORT = WebUsbTransport