From a4bcc95deb467e6042fdc3886e564beef6a6e186 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 26 Nov 2021 16:31:35 +0100 Subject: [PATCH] feat(python): introduce Trezor models This keeps information about vendors and USB IDs in one place, and allows us to extend with model-specific information later. By default, this should be backwards-compatible -- TrezorClient can optionally accept model information, and if not, it will try to guess based on Features. It is possible to specify which models to look for in transport enumeration. Bridge and UDP transports ignore the parameter, because they can't know what model is on the other side. supersedes #1448 and #1449 --- python/.changelog.d/1967.changed | 1 + python/src/trezorlib/__init__.py | 7 --- python/src/trezorlib/client.py | 28 +++++++++--- python/src/trezorlib/models.py | 43 +++++++++++++++++++ python/src/trezorlib/transport/__init__.py | 19 ++++---- python/src/trezorlib/transport/bridge.py | 9 +++- python/src/trezorlib/transport/hid.py | 15 +++++-- python/src/trezorlib/transport/udp.py | 9 +++- python/src/trezorlib/transport/webusb.py | 13 ++++-- tests/upgrade_tests/test_firmware_upgrades.py | 10 +++-- .../test_passphrase_consistency.py | 8 ++-- 11 files changed, 119 insertions(+), 43 deletions(-) create mode 100644 python/.changelog.d/1967.changed create mode 100644 python/src/trezorlib/models.py diff --git a/python/.changelog.d/1967.changed b/python/.changelog.d/1967.changed new file mode 100644 index 000000000..db2572399 --- /dev/null +++ b/python/.changelog.d/1967.changed @@ -0,0 +1 @@ +Introduce Trezor models as an abstraction over USB IDs, vendor strings, and possibly protobuf mappings. diff --git a/python/src/trezorlib/__init__.py b/python/src/trezorlib/__init__.py index c3c862f89..0b1e89c14 100644 --- a/python/src/trezorlib/__init__.py +++ b/python/src/trezorlib/__init__.py @@ -15,10 +15,3 @@ # If not, see . __version__ = "0.13.0" - -# fmt: off -MINIMUM_FIRMWARE_VERSION = { - "1": (1, 8, 0), - "T": (2, 1, 0), -} -# fmt: on diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 18d860fde..684093b03 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Optional from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages +from . import exceptions, mapping, messages, models from .log import DUMP_BYTES from .messages import Capability from .tools import expect, parse_path, session @@ -33,7 +33,6 @@ if TYPE_CHECKING: LOG = logging.getLogger(__name__) -VENDORS = ("bitcointrezor.com", "trezor.io") MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 @@ -85,6 +84,7 @@ class TrezorClient: ui: "TrezorClientUI", session_id: Optional[bytes] = None, derive_cardano: Optional[bool] = None, + model: Optional[models.TrezorModel] = None, _init_device: bool = True, ) -> None: """Create a TrezorClient instance. @@ -101,6 +101,9 @@ class TrezorClient: You can supply a `session_id` you might have saved in the previous session. If you do, the user might not need to enter their passphrase again. + You can provide Trezor model information. If not provided, it is detected from + the model name reported at initialization time. + By default, the instance will open a connection to the Trezor device, send an `Initialize` message, set up the `features` field from the response, and connect to a session. By specifying `_init_device=False`, this step is skipped. Notably, @@ -110,7 +113,11 @@ class TrezorClient: might be removed at any time. """ LOG.info(f"creating client instance for device: {transport.get_path()}") - self.mapping = mapping.DEFAULT_MAPPING + self.model = model + if self.model: + self.mapping = self.model.default_mapping + else: + self.mapping = mapping.DEFAULT_MAPPING self.transport = transport self.ui = ui self.session_counter = 0 @@ -254,7 +261,14 @@ class TrezorClient: def _refresh_features(self, features: messages.Features) -> None: """Update internal fields based on passed-in Features message.""" - if features.vendor not in VENDORS: + + if not self.model: + # Trezor Model One bootloader 1.8.0 or older does not send model name + self.model = models.by_name(features.model or "1") + if self.model is None: + raise RuntimeError("Unsupported Trezor model") + + if features.vendor not in self.model.vendors: raise RuntimeError("Unsupported device") self.features = features @@ -353,9 +367,9 @@ class TrezorClient: def is_outdated(self) -> bool: if self.features.bootloader_mode: return False - model = self.features.model or "1" - required_version = MINIMUM_FIRMWARE_VERSION[model] - return self.version < required_version + + assert self.model is not None # should happen in _refresh_features + return self.version < self.model.minimum_version def check_firmware_version(self, warn_only: bool = False) -> None: if self.is_outdated(): diff --git a/python/src/trezorlib/models.py b/python/src/trezorlib/models.py new file mode 100644 index 000000000..183dfb3eb --- /dev/null +++ b/python/src/trezorlib/models.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import Collection, Optional, Tuple + +from . import mapping + +UsbId = Tuple[int, int] + +VENDORS = ("bitcointrezor.com", "trezor.io") + + +@dataclass(eq=True, frozen=True) +class TrezorModel: + name: str + minimum_version: Tuple[int, int, int] + vendors: Collection[str] + usb_ids: Collection[UsbId] + default_mapping: mapping.ProtobufMapping + + +TREZOR_ONE = TrezorModel( + name="1", + minimum_version=(1, 8, 0), + vendors=VENDORS, + usb_ids=((0x534C, 0x0001),), + default_mapping=mapping.DEFAULT_MAPPING, +) + +TREZOR_T = TrezorModel( + name="T", + minimum_version=(2, 1, 0), + vendors=VENDORS, + usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)), + default_mapping=mapping.DEFAULT_MAPPING, +) + +TREZORS = {TREZOR_ONE, TREZOR_T} + + +def by_name(name: str) -> Optional[TrezorModel]: + for model in TREZORS: + if model.name == name: + return model + return None diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 8508e42b2..0828c6ed9 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -29,17 +29,12 @@ from typing import ( from ..exceptions import TrezorException if TYPE_CHECKING: + from ..models import TrezorModel + T = TypeVar("T", bound="Transport") LOG = logging.getLogger(__name__) -# USB vendor/product IDs for Trezors -DEV_TREZOR1 = (0x534C, 0x0001) -DEV_TREZOR2 = (0x1209, 0x53C1) -DEV_TREZOR2_BL = (0x1209, 0x53C0) - -TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL} - UDEV_RULES_STR = """ Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules @@ -95,7 +90,9 @@ class Transport: raise NotImplementedError @classmethod - def enumerate(cls: Type["T"]) -> Iterable["T"]: + def enumerate( + cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["T"]: raise NotImplementedError @classmethod @@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]: return set(t for t in transports if t.ENABLED) -def enumerate_devices() -> Sequence["Transport"]: +def enumerate_devices( + models: Optional[Iterable["TrezorModel"]] = None, +) -> Sequence["Transport"]: devices: List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: - found = list(transport.enumerate()) + found = list(transport.enumerate(models)) LOG.info(f"Enumerating {name}: found {len(found)} devices") devices.extend(found) except NotImplementedError: diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e38551607..d77e3693d 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -16,13 +16,16 @@ import logging import struct -from typing import Any, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional import requests from ..log import DUMP_PACKETS from . import MessagePayload, Transport, TransportException +if TYPE_CHECKING: + from ..models import TrezorModel + LOG = logging.getLogger(__name__) TREZORD_HOST = "http://127.0.0.1:21325" @@ -135,7 +138,9 @@ class BridgeTransport(Transport): return call_bridge(uri, data=data) @classmethod - def enumerate(cls) -> Iterable["BridgeTransport"]: + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() return [ diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 06e22afb8..65fa08ccd 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -17,10 +17,11 @@ import logging import sys import time -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Optional from ..log import DUMP_PACKETS -from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException +from ..models import TREZOR_ONE, TrezorModel +from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -132,11 +133,17 @@ class HidTransport(ProtocolBasedTransport): return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" @classmethod - def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: + def enumerate( + cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False + ) -> Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["HidTransport"] = [] for dev in hid.enumerate(0, 0): usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id != DEV_TREZOR1: + if usb_id not in usb_ids: continue if debug: if not is_debuglink(dev): diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 5f4225945..0bd3e43bc 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -17,12 +17,15 @@ import logging import socket import time -from typing import Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Optional from ..log import DUMP_PACKETS from . import TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 +if TYPE_CHECKING: + from ..models import TrezorModel + SOCKET_TIMEOUT = 10 LOG = logging.getLogger(__name__) @@ -70,7 +73,9 @@ class UdpTransport(ProtocolBasedTransport): d.close() @classmethod - def enumerate(cls) -> Iterable["UdpTransport"]: + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: return [cls._try_path(default_path)] diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index cde54c08d..cf71f0883 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -21,7 +21,8 @@ import time from typing import Iterable, List, Optional from ..log import DUMP_PACKETS -from . import TREZORS, UDEV_RULES_STR, TransportException +from ..models import TREZORS, TrezorModel +from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -114,15 +115,21 @@ class WebUsbTransport(ProtocolBasedTransport): return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" @classmethod - def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]: + def enumerate( + cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: if cls.context is None: cls.context = usb1.USBContext() cls.context.open() atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value] + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] devices: List["WebUsbTransport"] = [] for dev in cls.context.getDeviceIterator(skip_on_error=True): usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in TREZORS: + if usb_id not in usb_ids: continue if not is_vendor_class(dev): continue diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 7b4fa888f..76c8228c9 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -14,9 +14,11 @@ # You should have received a copy of the License along with this library. # If not, see . +import dataclasses + import pytest -from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device, exceptions, fido +from trezorlib import btc, debuglink, device, exceptions, fido, models from trezorlib.messages import BackupType from trezorlib.tools import H_ @@ -26,9 +28,9 @@ from ..device_handler import BackgroundDeviceHandler from ..emulators import ALL_TAGS, EmulatorWrapper from . import for_all, for_tags -MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0) -MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0) - +models.TREZOR_ONE = dataclasses.replace(models.TREZOR_ONE, minimum_version=(1, 0, 0)) +models.TREZOR_T = dataclasses.replace(models.TREZOR_T, minimum_version=(2, 0, 0)) +models.TREZORS = {models.TREZOR_ONE, models.TREZOR_T} # **** COMMON DEFINITIONS **** diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index 2974bdf55..63262acc3 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -16,7 +16,7 @@ import pytest -from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, device, mapping, messages, protobuf +from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper @@ -57,8 +57,8 @@ def emulator(gen, tag): @for_all( - core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"], - legacy_minimum_version=MINIMUM_FIRMWARE_VERSION["1"], + core_minimum_version=models.TREZOR_T.minimum_version, + legacy_minimum_version=models.TREZOR_ONE.minimum_version, ) def test_passphrase_works(emulator): """Check that passphrase handling in trezorlib works correctly in all versions.""" @@ -92,7 +92,7 @@ def test_passphrase_works(emulator): @for_all( - core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"], + core_minimum_version=models.TREZOR_T.minimum_version, legacy_minimum_version=(1, 9, 0), ) def test_init_device(emulator):