1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-19 11:50:58 +00:00

transport: cleaner Transport list instantiation

Previously if an import of a dependent module (usb1, hid) failed, import
of the whole transport module would fail. This was resolved by catching
ImportErrors in the all_transports method.

This had two drawbacks:
- if something other than ImportError happened - e.g., libusb would
raise OSError if it couldn't find libusb.so - all_transports would crash
anyway
- at the same time, if a legitimately needed dependency
(typing_extensions) was missing, this would be masked by the ImportError
handling.

Instead, we unconditionally import the modules, and inside each one,
wrap dependencies in a try-except.

As an added benefit, it is now possible to disable a transport just by
setting SomeTransport.ENABLED = False
This commit is contained in:
matejcik 2018-11-23 12:09:49 +01:00
parent f04458d6ea
commit 69ef1f0acd
5 changed files with 33 additions and 30 deletions

View File

@ -61,6 +61,7 @@ class Transport:
""" """
PATH_PREFIX = None # type: str PATH_PREFIX = None # type: str
ENABLED = False
def __str__(self) -> str: def __str__(self) -> str:
return self.get_path() return self.get_path()
@ -100,21 +101,16 @@ class Transport:
def all_transports() -> Iterable[Type[Transport]]: def all_transports() -> Iterable[Type[Transport]]:
transports = set() # type: Set[Type[Transport]] from .bridge import BridgeTransport
for modname in ("bridge", "hid", "udp", "webusb"): from .hid import HidTransport
try: from .udp import UdpTransport
# Import the module and find the Transport class. from .webusb import WebUsbTransport
# To avoid iterating over every item, the module should assign its Transport class
# to a constant named TRANSPORT.
module = importlib.import_module("." + modname, __name__)
try:
transports.add(getattr(module, "TRANSPORT"))
except AttributeError:
LOG.warning("Skipping broken module {}".format(modname))
except ImportError as e:
LOG.info("Failed to import module {}: {}".format(modname, e))
return transports return set(
cls
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport)
if cls.ENABLED
)
def enumerate_devices() -> Iterable[Transport]: def enumerate_devices() -> Iterable[Transport]:

View File

@ -102,6 +102,7 @@ class BridgeTransport(Transport):
""" """
PATH_PREFIX = "bridge" PATH_PREFIX = "bridge"
ENABLED = True
def __init__( def __init__(
self, device: Dict[str, Any], legacy: bool, debug: bool = False self, device: Dict[str, Any], legacy: bool, debug: bool = False
@ -177,6 +178,3 @@ class BridgeTransport(Transport):
extra={"protobuf": msg}, extra={"protobuf": msg},
) )
return msg return msg
TRANSPORT = BridgeTransport

View File

@ -14,15 +14,23 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable from typing import Any, Dict, Iterable
import hid
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
try:
import hid
except Exception as e:
LOG.info("HID transport is disabled: {}".format(e))
hid = None
HidDevice = Dict[str, Any] HidDevice = Dict[str, Any]
HidDeviceHandle = Any HidDeviceHandle = Any
@ -87,6 +95,7 @@ class HidTransport(ProtocolBasedTransport):
""" """
PATH_PREFIX = "hid" PATH_PREFIX = "hid"
ENABLED = hid is not None
def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None: def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
if hid_handle is None: if hid_handle is None:
@ -135,6 +144,3 @@ def is_wirelink(dev: HidDevice) -> bool:
def is_debuglink(dev: HidDevice) -> bool: def is_debuglink(dev: HidDevice) -> bool:
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1 return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1
TRANSPORT = HidTransport

View File

@ -30,6 +30,7 @@ class UdpTransport(ProtocolBasedTransport):
DEFAULT_HOST = "127.0.0.1" DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 21324 DEFAULT_PORT = 21324
PATH_PREFIX = "udp" PATH_PREFIX = "udp"
ENABLED = True
def __init__(self, device: str = None) -> None: def __init__(self, device: str = None) -> None:
if not device: if not device:
@ -123,6 +124,3 @@ class UdpTransport(ProtocolBasedTransport):
if len(chunk) != 64: if len(chunk) != 64:
raise TransportException("Unexpected chunk size: %d" % len(chunk)) raise TransportException("Unexpected chunk size: %d" % len(chunk))
return bytearray(chunk) return bytearray(chunk)
TRANSPORT = UdpTransport

View File

@ -15,15 +15,22 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import atexit import atexit
import logging
import sys import sys
import time import time
from typing import Iterable, Optional from typing import Iterable, Optional
import usb1
from . import TREZORS, UDEV_RULES_STR, TransportException from . import TREZORS, UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
try:
import usb1
except Exception as e:
LOG.warning("WebUSB transport is disabled: {}".format(e))
usb1 = None
if False: if False:
# mark Optional as used, otherwise it only exists in comments # mark Optional as used, otherwise it only exists in comments
Optional Optional
@ -84,6 +91,7 @@ class WebUsbTransport(ProtocolBasedTransport):
""" """
PATH_PREFIX = "webusb" PATH_PREFIX = "webusb"
ENABLED = usb1 is not None
context = None context = None
def __init__( def __init__(
@ -150,6 +158,3 @@ def dev_to_str(dev: usb1.USBDevice) -> str:
return ":".join( return ":".join(
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList() str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
) )
TRANSPORT = WebUsbTransport