mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-19 11:50:58 +00:00
transport: cleaner Transport list instantiation
Previously if an import of a dependent module (usb1, hid) failed, import of the whole transport module would fail. This was resolved by catching ImportErrors in the all_transports method. This had two drawbacks: - if something other than ImportError happened - e.g., libusb would raise OSError if it couldn't find libusb.so - all_transports would crash anyway - at the same time, if a legitimately needed dependency (typing_extensions) was missing, this would be masked by the ImportError handling. Instead, we unconditionally import the modules, and inside each one, wrap dependencies in a try-except. As an added benefit, it is now possible to disable a transport just by setting SomeTransport.ENABLED = False
This commit is contained in:
parent
f04458d6ea
commit
69ef1f0acd
@ -61,6 +61,7 @@ class Transport:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX = None # type: str
|
PATH_PREFIX = None # type: str
|
||||||
|
ENABLED = False
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return self.get_path()
|
return self.get_path()
|
||||||
@ -100,21 +101,16 @@ class Transport:
|
|||||||
|
|
||||||
|
|
||||||
def all_transports() -> Iterable[Type[Transport]]:
|
def all_transports() -> Iterable[Type[Transport]]:
|
||||||
transports = set() # type: Set[Type[Transport]]
|
from .bridge import BridgeTransport
|
||||||
for modname in ("bridge", "hid", "udp", "webusb"):
|
from .hid import HidTransport
|
||||||
try:
|
from .udp import UdpTransport
|
||||||
# Import the module and find the Transport class.
|
from .webusb import WebUsbTransport
|
||||||
# To avoid iterating over every item, the module should assign its Transport class
|
|
||||||
# to a constant named TRANSPORT.
|
|
||||||
module = importlib.import_module("." + modname, __name__)
|
|
||||||
try:
|
|
||||||
transports.add(getattr(module, "TRANSPORT"))
|
|
||||||
except AttributeError:
|
|
||||||
LOG.warning("Skipping broken module {}".format(modname))
|
|
||||||
except ImportError as e:
|
|
||||||
LOG.info("Failed to import module {}: {}".format(modname, e))
|
|
||||||
|
|
||||||
return transports
|
return set(
|
||||||
|
cls
|
||||||
|
for cls in (BridgeTransport, HidTransport, UdpTransport, WebUsbTransport)
|
||||||
|
if cls.ENABLED
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def enumerate_devices() -> Iterable[Transport]:
|
def enumerate_devices() -> Iterable[Transport]:
|
||||||
|
@ -102,6 +102,7 @@ class BridgeTransport(Transport):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX = "bridge"
|
PATH_PREFIX = "bridge"
|
||||||
|
ENABLED = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
self, device: Dict[str, Any], legacy: bool, debug: bool = False
|
||||||
@ -177,6 +178,3 @@ class BridgeTransport(Transport):
|
|||||||
extra={"protobuf": msg},
|
extra={"protobuf": msg},
|
||||||
)
|
)
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
|
|
||||||
TRANSPORT = BridgeTransport
|
|
||||||
|
@ -14,15 +14,23 @@
|
|||||||
# You should have received a copy of the License along with this library.
|
# You should have received a copy of the License along with this library.
|
||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Iterable
|
from typing import Any, Dict, Iterable
|
||||||
|
|
||||||
import hid
|
|
||||||
|
|
||||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import hid
|
||||||
|
except Exception as e:
|
||||||
|
LOG.info("HID transport is disabled: {}".format(e))
|
||||||
|
hid = None
|
||||||
|
|
||||||
|
|
||||||
HidDevice = Dict[str, Any]
|
HidDevice = Dict[str, Any]
|
||||||
HidDeviceHandle = Any
|
HidDeviceHandle = Any
|
||||||
|
|
||||||
@ -87,6 +95,7 @@ class HidTransport(ProtocolBasedTransport):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX = "hid"
|
PATH_PREFIX = "hid"
|
||||||
|
ENABLED = hid is not None
|
||||||
|
|
||||||
def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
|
def __init__(self, device: HidDevice, hid_handle: HidHandle = None) -> None:
|
||||||
if hid_handle is None:
|
if hid_handle is None:
|
||||||
@ -135,6 +144,3 @@ def is_wirelink(dev: HidDevice) -> bool:
|
|||||||
|
|
||||||
def is_debuglink(dev: HidDevice) -> bool:
|
def is_debuglink(dev: HidDevice) -> bool:
|
||||||
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1
|
return dev["usage_page"] == 0xFF01 or dev["interface_number"] == 1
|
||||||
|
|
||||||
|
|
||||||
TRANSPORT = HidTransport
|
|
||||||
|
@ -30,6 +30,7 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
DEFAULT_HOST = "127.0.0.1"
|
DEFAULT_HOST = "127.0.0.1"
|
||||||
DEFAULT_PORT = 21324
|
DEFAULT_PORT = 21324
|
||||||
PATH_PREFIX = "udp"
|
PATH_PREFIX = "udp"
|
||||||
|
ENABLED = True
|
||||||
|
|
||||||
def __init__(self, device: str = None) -> None:
|
def __init__(self, device: str = None) -> None:
|
||||||
if not device:
|
if not device:
|
||||||
@ -123,6 +124,3 @@ class UdpTransport(ProtocolBasedTransport):
|
|||||||
if len(chunk) != 64:
|
if len(chunk) != 64:
|
||||||
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
raise TransportException("Unexpected chunk size: %d" % len(chunk))
|
||||||
return bytearray(chunk)
|
return bytearray(chunk)
|
||||||
|
|
||||||
|
|
||||||
TRANSPORT = UdpTransport
|
|
||||||
|
@ -15,15 +15,22 @@
|
|||||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from typing import Iterable, Optional
|
from typing import Iterable, Optional
|
||||||
|
|
||||||
import usb1
|
|
||||||
|
|
||||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
try:
|
||||||
|
import usb1
|
||||||
|
except Exception as e:
|
||||||
|
LOG.warning("WebUSB transport is disabled: {}".format(e))
|
||||||
|
usb1 = None
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
# mark Optional as used, otherwise it only exists in comments
|
# mark Optional as used, otherwise it only exists in comments
|
||||||
Optional
|
Optional
|
||||||
@ -84,6 +91,7 @@ class WebUsbTransport(ProtocolBasedTransport):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
PATH_PREFIX = "webusb"
|
PATH_PREFIX = "webusb"
|
||||||
|
ENABLED = usb1 is not None
|
||||||
context = None
|
context = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -150,6 +158,3 @@ def dev_to_str(dev: usb1.USBDevice) -> str:
|
|||||||
return ":".join(
|
return ":".join(
|
||||||
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
str(x) for x in ["%03i" % (dev.getBusNumber(),)] + dev.getPortNumberList()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
TRANSPORT = WebUsbTransport
|
|
||||||
|
Loading…
Reference in New Issue
Block a user