diff --git a/python/.changelog.d/1967.changed b/python/.changelog.d/1967.changed new file mode 100644 index 000000000..db2572399 --- /dev/null +++ b/python/.changelog.d/1967.changed @@ -0,0 +1 @@ +Introduce Trezor models as an abstraction over USB IDs, vendor strings, and possibly protobuf mappings. diff --git a/python/src/trezorlib/__init__.py b/python/src/trezorlib/__init__.py index c3c862f89..0b1e89c14 100644 --- a/python/src/trezorlib/__init__.py +++ b/python/src/trezorlib/__init__.py @@ -15,10 +15,3 @@ # If not, see . __version__ = "0.13.0" - -# fmt: off -MINIMUM_FIRMWARE_VERSION = { - "1": (1, 8, 0), - "T": (2, 1, 0), -} -# fmt: on diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 18d860fde..684093b03 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Optional from mnemonic import Mnemonic -from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages +from . import exceptions, mapping, messages, models from .log import DUMP_BYTES from .messages import Capability from .tools import expect, parse_path, session @@ -33,7 +33,6 @@ if TYPE_CHECKING: LOG = logging.getLogger(__name__) -VENDORS = ("bitcointrezor.com", "trezor.io") MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 @@ -85,6 +84,7 @@ class TrezorClient: ui: "TrezorClientUI", session_id: Optional[bytes] = None, derive_cardano: Optional[bool] = None, + model: Optional[models.TrezorModel] = None, _init_device: bool = True, ) -> None: """Create a TrezorClient instance. @@ -101,6 +101,9 @@ class TrezorClient: You can supply a `session_id` you might have saved in the previous session. If you do, the user might not need to enter their passphrase again. + You can provide Trezor model information. If not provided, it is detected from + the model name reported at initialization time. + By default, the instance will open a connection to the Trezor device, send an `Initialize` message, set up the `features` field from the response, and connect to a session. By specifying `_init_device=False`, this step is skipped. Notably, @@ -110,7 +113,11 @@ class TrezorClient: might be removed at any time. """ LOG.info(f"creating client instance for device: {transport.get_path()}") - self.mapping = mapping.DEFAULT_MAPPING + self.model = model + if self.model: + self.mapping = self.model.default_mapping + else: + self.mapping = mapping.DEFAULT_MAPPING self.transport = transport self.ui = ui self.session_counter = 0 @@ -254,7 +261,14 @@ class TrezorClient: def _refresh_features(self, features: messages.Features) -> None: """Update internal fields based on passed-in Features message.""" - if features.vendor not in VENDORS: + + if not self.model: + # Trezor Model One bootloader 1.8.0 or older does not send model name + self.model = models.by_name(features.model or "1") + if self.model is None: + raise RuntimeError("Unsupported Trezor model") + + if features.vendor not in self.model.vendors: raise RuntimeError("Unsupported device") self.features = features @@ -353,9 +367,9 @@ class TrezorClient: def is_outdated(self) -> bool: if self.features.bootloader_mode: return False - model = self.features.model or "1" - required_version = MINIMUM_FIRMWARE_VERSION[model] - return self.version < required_version + + assert self.model is not None # should happen in _refresh_features + return self.version < self.model.minimum_version def check_firmware_version(self, warn_only: bool = False) -> None: if self.is_outdated(): diff --git a/python/src/trezorlib/models.py b/python/src/trezorlib/models.py new file mode 100644 index 000000000..183dfb3eb --- /dev/null +++ b/python/src/trezorlib/models.py @@ -0,0 +1,43 @@ +from dataclasses import dataclass +from typing import Collection, Optional, Tuple + +from . import mapping + +UsbId = Tuple[int, int] + +VENDORS = ("bitcointrezor.com", "trezor.io") + + +@dataclass(eq=True, frozen=True) +class TrezorModel: + name: str + minimum_version: Tuple[int, int, int] + vendors: Collection[str] + usb_ids: Collection[UsbId] + default_mapping: mapping.ProtobufMapping + + +TREZOR_ONE = TrezorModel( + name="1", + minimum_version=(1, 8, 0), + vendors=VENDORS, + usb_ids=((0x534C, 0x0001),), + default_mapping=mapping.DEFAULT_MAPPING, +) + +TREZOR_T = TrezorModel( + name="T", + minimum_version=(2, 1, 0), + vendors=VENDORS, + usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)), + default_mapping=mapping.DEFAULT_MAPPING, +) + +TREZORS = {TREZOR_ONE, TREZOR_T} + + +def by_name(name: str) -> Optional[TrezorModel]: + for model in TREZORS: + if model.name == name: + return model + return None diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index 8508e42b2..0828c6ed9 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -29,17 +29,12 @@ from typing import ( from ..exceptions import TrezorException if TYPE_CHECKING: + from ..models import TrezorModel + T = TypeVar("T", bound="Transport") LOG = logging.getLogger(__name__) -# USB vendor/product IDs for Trezors -DEV_TREZOR1 = (0x534C, 0x0001) -DEV_TREZOR2 = (0x1209, 0x53C1) -DEV_TREZOR2_BL = (0x1209, 0x53C0) - -TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL} - UDEV_RULES_STR = """ Do you have udev rules installed? https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules @@ -95,7 +90,9 @@ class Transport: raise NotImplementedError @classmethod - def enumerate(cls: Type["T"]) -> Iterable["T"]: + def enumerate( + cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["T"]: raise NotImplementedError @classmethod @@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]: return set(t for t in transports if t.ENABLED) -def enumerate_devices() -> Sequence["Transport"]: +def enumerate_devices( + models: Optional[Iterable["TrezorModel"]] = None, +) -> Sequence["Transport"]: devices: List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: - found = list(transport.enumerate()) + found = list(transport.enumerate(models)) LOG.info(f"Enumerating {name}: found {len(found)} devices") devices.extend(found) except NotImplementedError: diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e38551607..d77e3693d 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -16,13 +16,16 @@ import logging import struct -from typing import Any, Dict, Iterable, Optional +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional import requests from ..log import DUMP_PACKETS from . import MessagePayload, Transport, TransportException +if TYPE_CHECKING: + from ..models import TrezorModel + LOG = logging.getLogger(__name__) TREZORD_HOST = "http://127.0.0.1:21325" @@ -135,7 +138,9 @@ class BridgeTransport(Transport): return call_bridge(uri, data=data) @classmethod - def enumerate(cls) -> Iterable["BridgeTransport"]: + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() return [ diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 06e22afb8..65fa08ccd 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -17,10 +17,11 @@ import logging import sys import time -from typing import Any, Dict, Iterable, List +from typing import Any, Dict, Iterable, List, Optional from ..log import DUMP_PACKETS -from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException +from ..models import TREZOR_ONE, TrezorModel +from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -132,11 +133,17 @@ class HidTransport(ProtocolBasedTransport): return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" @classmethod - def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: + def enumerate( + cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False + ) -> Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["HidTransport"] = [] for dev in hid.enumerate(0, 0): usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id != DEV_TREZOR1: + if usb_id not in usb_ids: continue if debug: if not is_debuglink(dev): diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 5f4225945..0bd3e43bc 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -17,12 +17,15 @@ import logging import socket import time -from typing import Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Optional from ..log import DUMP_PACKETS from . import TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 +if TYPE_CHECKING: + from ..models import TrezorModel + SOCKET_TIMEOUT = 10 LOG = logging.getLogger(__name__) @@ -70,7 +73,9 @@ class UdpTransport(ProtocolBasedTransport): d.close() @classmethod - def enumerate(cls) -> Iterable["UdpTransport"]: + def enumerate( + cls, _models: Optional[Iterable["TrezorModel"]] = None + ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: return [cls._try_path(default_path)] diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index cde54c08d..cf71f0883 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -21,7 +21,8 @@ import time from typing import Iterable, List, Optional from ..log import DUMP_PACKETS -from . import TREZORS, UDEV_RULES_STR, TransportException +from ..models import TREZORS, TrezorModel +from . import UDEV_RULES_STR, TransportException from .protocol import ProtocolBasedTransport, ProtocolV1 LOG = logging.getLogger(__name__) @@ -114,15 +115,21 @@ class WebUsbTransport(ProtocolBasedTransport): return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" @classmethod - def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]: + def enumerate( + cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: if cls.context is None: cls.context = usb1.USBContext() cls.context.open() atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value] + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] devices: List["WebUsbTransport"] = [] for dev in cls.context.getDeviceIterator(skip_on_error=True): usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in TREZORS: + if usb_id not in usb_ids: continue if not is_vendor_class(dev): continue diff --git a/tests/upgrade_tests/test_firmware_upgrades.py b/tests/upgrade_tests/test_firmware_upgrades.py index 7b4fa888f..76c8228c9 100644 --- a/tests/upgrade_tests/test_firmware_upgrades.py +++ b/tests/upgrade_tests/test_firmware_upgrades.py @@ -14,9 +14,11 @@ # You should have received a copy of the License along with this library. # If not, see . +import dataclasses + import pytest -from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device, exceptions, fido +from trezorlib import btc, debuglink, device, exceptions, fido, models from trezorlib.messages import BackupType from trezorlib.tools import H_ @@ -26,9 +28,9 @@ from ..device_handler import BackgroundDeviceHandler from ..emulators import ALL_TAGS, EmulatorWrapper from . import for_all, for_tags -MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0) -MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0) - +models.TREZOR_ONE = dataclasses.replace(models.TREZOR_ONE, minimum_version=(1, 0, 0)) +models.TREZOR_T = dataclasses.replace(models.TREZOR_T, minimum_version=(2, 0, 0)) +models.TREZORS = {models.TREZOR_ONE, models.TREZOR_T} # **** COMMON DEFINITIONS **** diff --git a/tests/upgrade_tests/test_passphrase_consistency.py b/tests/upgrade_tests/test_passphrase_consistency.py index 2974bdf55..63262acc3 100644 --- a/tests/upgrade_tests/test_passphrase_consistency.py +++ b/tests/upgrade_tests/test_passphrase_consistency.py @@ -16,7 +16,7 @@ import pytest -from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, device, mapping, messages, protobuf +from trezorlib import btc, device, mapping, messages, models, protobuf from trezorlib.tools import parse_path from ..emulators import EmulatorWrapper @@ -57,8 +57,8 @@ def emulator(gen, tag): @for_all( - core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"], - legacy_minimum_version=MINIMUM_FIRMWARE_VERSION["1"], + core_minimum_version=models.TREZOR_T.minimum_version, + legacy_minimum_version=models.TREZOR_ONE.minimum_version, ) def test_passphrase_works(emulator): """Check that passphrase handling in trezorlib works correctly in all versions.""" @@ -92,7 +92,7 @@ def test_passphrase_works(emulator): @for_all( - core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"], + core_minimum_version=models.TREZOR_T.minimum_version, legacy_minimum_version=(1, 9, 0), ) def test_init_device(emulator):