mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-19 03:40:59 +00:00
transport: cleaner Transport list instantiation
Previously if an import of a dependent module (usb1, hid) failed, import of the whole transport module would fail. This was resolved by catching ImportErrors in the all_transports method. This had two drawbacks: - if something other than ImportError happened - e.g., libusb would raise OSError if it couldn't find libusb.so - all_transports would crash anyway - at the same time, if a legitimately needed dependency (typing_extensions) was missing, this would be masked by the ImportError handling. Instead, we unconditionally import the modules, and inside each one, wrap dependencies in a try-except. As an added benefit, it is now possible to disable a transport just by setting SomeTransport.ENABLED = False
This commit is contained in:
parent
f04458d6ea
commit
69ef1f0acd
@ -61,6 +61,7 @@ class Transport:
|
||||
"""
|
||||
|
||||
PATH_PREFIX = None # type: str
|
||||
ENABLED = False
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.get_path()
|
||||
@ -100,21 +101,16 @@ class Transport:
|
||||
|
||||
|
||||
def all_transports() -> Iterable[Type[Transport]]:
|
||||
transports = set() # type: Set[Type[Transport]]
|
||||
for modname in ("bridge", "hid", "udp", "webusb"):
|
||||
try:
|
||||
# Import the module and find the Transport class.
|
||||
# To avoid iterating over every item, the module should assign its Transport class
|
||||
# to a constant named TRANSPORT.
|
||||
module = importlib.import_module("." + modname, __name__)
|
||||
try:
|
||||
transports.add(getattr(module, "TRANSPORT"))
|
||||
except AttributeError:
|
||||
LOG.warning("Skipping broken module {}".format(modname))
|
||||
except ImportError as e:
|
||||
LOG.info("Failed to import module {}: {}".format(modname, e))
|
||||
from .bridge import BridgeTransport
|
||||
from .hid import HidTransport
|
||||
from .udp import UdpTransport
|
||||
from .webusb import WebUsbTransport
|
||||
|
||||
return transports
|
||||
return set(
|
||||
cls
|
||||
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport)
|
||||
if cls.ENABLED
|
||||
)
|
||||
|
||||
|
||||
def enumerate_devices() -> Iterable[Transport]:
|
||||
|
@ -102,6 +102,7 @@ class BridgeTransport(Transport):
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "bridge"
|
||||
ENABLED = True
|
||||
|
||||
def __init__(
|
||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
||||
@ -177,6 +178,3 @@ class BridgeTransport(Transport):
|
||||
extra={"protobuf": msg},
|
||||
)
|
||||
return msg
|
||||
|
||||
|
||||
TRANSPORT = BridgeTransport
|
||||
|
@ -14,15 +14,23 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable
|
||||
|
||||
import hid
|
||||
|
||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import hid
|
||||
except Exception as e:
|
||||
LOG.info("HID transport is disabled: {}".format(e))
|
||||
hid = None
|
||||
|
||||
|
||||
HidDevice = Dict[str, Any]
|
||||
HidDeviceHandle = Any
|
||||
|
||||
@ -87,6 +95,7 @@ class HidTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "hid"
|
||||
ENABLED = hid is not None
|
||||
|
||||
def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
|
||||
if hid_handle is None:
|
||||
@ -135,6 +144,3 @@ def is_wirelink(dev: HidDevice) -> bool:
|
||||
|
||||
def is_debuglink(dev: HidDevice) -> bool:
|
||||
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1
|
||||
|
||||
|
||||
TRANSPORT = HidTransport
|
||||
|
@ -30,6 +30,7 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 21324
|
||||
PATH_PREFIX = "udp"
|
||||
ENABLED = True
|
||||
|
||||
def __init__(self, device: str = None) -> None:
|
||||
if not device:
|
||||
@ -123,6 +124,3 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
if len(chunk) != 64:
|
||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||
return bytearray(chunk)
|
||||
|
||||
|
||||
TRANSPORT = UdpTransport
|
||||
|
@ -15,15 +15,22 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
|
||||
import usb1
|
||||
|
||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import usb1
|
||||
except Exception as e:
|
||||
LOG.warning("WebUSB transport is disabled: {}".format(e))
|
||||
usb1 = None
|
||||
|
||||
if False:
|
||||
# mark Optional as used, otherwise it only exists in comments
|
||||
Optional
|
||||
@ -84,6 +91,7 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
"""
|
||||
|
||||
PATH_PREFIX = "webusb"
|
||||
ENABLED = usb1 is not None
|
||||
context = None
|
||||
|
||||
def __init__(
|
||||
@ -150,6 +158,3 @@ def dev_to_str(dev: usb1.USBDevice) -> str:
|
||||
return ":".join(
|
||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||
)
|
||||
|
||||
|
||||
TRANSPORT = WebUsbTransport
|
||||
|
Loading…
Reference in New Issue
Block a user