feat(python): introduce Trezor models

This keeps information about vendors and USB IDs in one place, and
allows us to extend with model-specific information later.

By default, this should be backwards-compatible -- TrezorClient can
optionally accept model information, and if not, it will try to guess
based on Features.

It is possible to specify which models to look for in transport
enumeration. Bridge and UDP transports ignore the parameter, because
they can't know what model is on the other side.

supersedes #1448 and #1449
pull/1989/head
matejcik 3 years ago committed by matejcik
parent 38fca4a83d
commit a4bcc95deb

@ -0,0 +1 @@
Introduce Trezor models as an abstraction over USB IDs, vendor strings, and possibly protobuf mappings.

@ -15,10 +15,3 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
__version__ = "0.13.0" __version__ = "0.13.0"
# fmt: off
MINIMUM_FIRMWARE_VERSION = {
"1": (1, 8, 0),
"T": (2, 1, 0),
}
# fmt: on

@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Optional
from mnemonic import Mnemonic from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import Capability from .messages import Capability
from .tools import expect, parse_path, session from .tools import expect, parse_path, session
@ -33,7 +33,6 @@ if TYPE_CHECKING:
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
VENDORS = ("bitcointrezor.com", "trezor.io")
MAX_PASSPHRASE_LENGTH = 50 MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50 MAX_PIN_LENGTH = 50
@ -85,6 +84,7 @@ class TrezorClient:
ui: "TrezorClientUI", ui: "TrezorClientUI",
session_id: Optional[bytes] = None, session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None, derive_cardano: Optional[bool] = None,
model: Optional[models.TrezorModel] = None,
_init_device: bool = True, _init_device: bool = True,
) -> None: ) -> None:
"""Create a TrezorClient instance. """Create a TrezorClient instance.
@ -101,6 +101,9 @@ class TrezorClient:
You can supply a `session_id` you might have saved in the previous session. If You can supply a `session_id` you might have saved in the previous session. If
you do, the user might not need to enter their passphrase again. you do, the user might not need to enter their passphrase again.
You can provide Trezor model information. If not provided, it is detected from
the model name reported at initialization time.
By default, the instance will open a connection to the Trezor device, send an By default, the instance will open a connection to the Trezor device, send an
`Initialize` message, set up the `features` field from the response, and connect `Initialize` message, set up the `features` field from the response, and connect
to a session. By specifying `_init_device=False`, this step is skipped. Notably, to a session. By specifying `_init_device=False`, this step is skipped. Notably,
@ -110,7 +113,11 @@ class TrezorClient:
might be removed at any time. might be removed at any time.
""" """
LOG.info(f"creating client instance for device: {transport.get_path()}") LOG.info(f"creating client instance for device: {transport.get_path()}")
self.mapping = mapping.DEFAULT_MAPPING self.model = model
if self.model:
self.mapping = self.model.default_mapping
else:
self.mapping = mapping.DEFAULT_MAPPING
self.transport = transport self.transport = transport
self.ui = ui self.ui = ui
self.session_counter = 0 self.session_counter = 0
@ -254,7 +261,14 @@ class TrezorClient:
def _refresh_features(self, features: messages.Features) -> None: def _refresh_features(self, features: messages.Features) -> None:
"""Update internal fields based on passed-in Features message.""" """Update internal fields based on passed-in Features message."""
if features.vendor not in VENDORS:
if not self.model:
# Trezor Model One bootloader 1.8.0 or older does not send model name
self.model = models.by_name(features.model or "1")
if self.model is None:
raise RuntimeError("Unsupported Trezor model")
if features.vendor not in self.model.vendors:
raise RuntimeError("Unsupported device") raise RuntimeError("Unsupported device")
self.features = features self.features = features
@ -353,9 +367,9 @@ class TrezorClient:
def is_outdated(self) -> bool: def is_outdated(self) -> bool:
if self.features.bootloader_mode: if self.features.bootloader_mode:
return False return False
model = self.features.model or "1"
required_version = MINIMUM_FIRMWARE_VERSION[model] assert self.model is not None # should happen in _refresh_features
return self.version < required_version return self.version < self.model.minimum_version
def check_firmware_version(self, warn_only: bool = False) -> None: def check_firmware_version(self, warn_only: bool = False) -> None:
if self.is_outdated(): if self.is_outdated():

@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Collection, Optional, Tuple
from . import mapping
UsbId = Tuple[int, int]
VENDORS = ("bitcointrezor.com", "trezor.io")
@dataclass(eq=True, frozen=True)
class TrezorModel:
name: str
minimum_version: Tuple[int, int, int]
vendors: Collection[str]
usb_ids: Collection[UsbId]
default_mapping: mapping.ProtobufMapping
TREZOR_ONE = TrezorModel(
name="1",
minimum_version=(1, 8, 0),
vendors=VENDORS,
usb_ids=((0x534C, 0x0001),),
default_mapping=mapping.DEFAULT_MAPPING,
)
TREZOR_T = TrezorModel(
name="T",
minimum_version=(2, 1, 0),
vendors=VENDORS,
usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)),
default_mapping=mapping.DEFAULT_MAPPING,
)
TREZORS = {TREZOR_ONE, TREZOR_T}
def by_name(name: str) -> Optional[TrezorModel]:
for model in TREZORS:
if model.name == name:
return model
return None

@ -29,17 +29,12 @@ from typing import (
from ..exceptions import TrezorException from ..exceptions import TrezorException
if TYPE_CHECKING: if TYPE_CHECKING:
from ..models import TrezorModel
T = TypeVar("T", bound="Transport") T = TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
# USB vendor/product IDs for Trezors
DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1)
DEV_TREZOR2_BL = (0x1209, 0x53C0)
TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL}
UDEV_RULES_STR = """ UDEV_RULES_STR = """
Do you have udev rules installed? Do you have udev rules installed?
https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
@ -95,7 +90,9 @@ class Transport:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def enumerate(cls: Type["T"]) -> Iterable["T"]: def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]:
return set(t for t in transports if t.ENABLED) return set(t for t in transports if t.ENABLED)
def enumerate_devices() -> Sequence["Transport"]: def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = [] devices: List["Transport"] = []
for transport in all_transports(): for transport in all_transports():
name = transport.__name__ name = transport.__name__
try: try:
found = list(transport.enumerate()) found = list(transport.enumerate(models))
LOG.info(f"Enumerating {name}: found {len(found)} devices") LOG.info(f"Enumerating {name}: found {len(found)} devices")
devices.extend(found) devices.extend(found)
except NotImplementedError: except NotImplementedError:

@ -16,13 +16,16 @@
import logging import logging
import struct import struct
from typing import Any, Dict, Iterable, Optional from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
import requests import requests
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import MessagePayload, Transport, TransportException from . import MessagePayload, Transport, TransportException
if TYPE_CHECKING:
from ..models import TrezorModel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_HOST = "http://127.0.0.1:21325"
@ -135,7 +138,9 @@ class BridgeTransport(Transport):
return call_bridge(uri, data=data) return call_bridge(uri, data=data)
@classmethod @classmethod
def enumerate(cls) -> Iterable["BridgeTransport"]: def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BridgeTransport"]:
try: try:
legacy = is_legacy_bridge() legacy = is_legacy_bridge()
return [ return [

@ -17,10 +17,11 @@
import logging import logging
import sys import sys
import time import time
from typing import Any, Dict, Iterable, List from typing import Any, Dict, Iterable, List, Optional
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -132,11 +133,17 @@ class HidTransport(ProtocolBasedTransport):
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod @classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]: def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = [] devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0): for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"]) usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id != DEV_TREZOR1: if usb_id not in usb_ids:
continue continue
if debug: if debug:
if not is_debuglink(dev): if not is_debuglink(dev):

@ -17,12 +17,15 @@
import logging import logging
import socket import socket
import time import time
from typing import Iterable, Optional from typing import TYPE_CHECKING, Iterable, Optional
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TransportException from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING:
from ..models import TrezorModel
SOCKET_TIMEOUT = 10 SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -70,7 +73,9 @@ class UdpTransport(ProtocolBasedTransport):
d.close() d.close()
@classmethod @classmethod
def enumerate(cls) -> Iterable["UdpTransport"]: def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try: try:
return [cls._try_path(default_path)] return [cls._try_path(default_path)]

@ -21,7 +21,8 @@ import time
from typing import Iterable, List, Optional from typing import Iterable, List, Optional
from ..log import DUMP_PACKETS from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1 from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -114,15 +115,21 @@ class WebUsbTransport(ProtocolBasedTransport):
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod @classmethod
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]: def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None: if cls.context is None:
cls.context = usb1.USBContext() cls.context = usb1.USBContext()
cls.context.open() cls.context.open()
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value] atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = [] devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True): for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID()) usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in TREZORS: if usb_id not in usb_ids:
continue continue
if not is_vendor_class(dev): if not is_vendor_class(dev):
continue continue

@ -14,9 +14,11 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import dataclasses
import pytest import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device, exceptions, fido from trezorlib import btc, debuglink, device, exceptions, fido, models
from trezorlib.messages import BackupType from trezorlib.messages import BackupType
from trezorlib.tools import H_ from trezorlib.tools import H_
@ -26,9 +28,9 @@ from ..device_handler import BackgroundDeviceHandler
from ..emulators import ALL_TAGS, EmulatorWrapper from ..emulators import ALL_TAGS, EmulatorWrapper
from . import for_all, for_tags from . import for_all, for_tags
MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0) models.TREZOR_ONE = dataclasses.replace(models.TREZOR_ONE, minimum_version=(1, 0, 0))
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0) models.TREZOR_T = dataclasses.replace(models.TREZOR_T, minimum_version=(2, 0, 0))
models.TREZORS = {models.TREZOR_ONE, models.TREZOR_T}
# **** COMMON DEFINITIONS **** # **** COMMON DEFINITIONS ****

@ -16,7 +16,7 @@
import pytest import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, device, mapping, messages, protobuf from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ..emulators import EmulatorWrapper from ..emulators import EmulatorWrapper
@ -57,8 +57,8 @@ def emulator(gen, tag):
@for_all( @for_all(
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"], core_minimum_version=models.TREZOR_T.minimum_version,
legacy_minimum_version=MINIMUM_FIRMWARE_VERSION["1"], legacy_minimum_version=models.TREZOR_ONE.minimum_version,
) )
def test_passphrase_works(emulator): def test_passphrase_works(emulator):
"""Check that passphrase handling in trezorlib works correctly in all versions.""" """Check that passphrase handling in trezorlib works correctly in all versions."""
@ -92,7 +92,7 @@ def test_passphrase_works(emulator):
@for_all( @for_all(
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"], core_minimum_version=models.TREZOR_T.minimum_version,
legacy_minimum_version=(1, 9, 0), legacy_minimum_version=(1, 9, 0),
) )
def test_init_device(emulator): def test_init_device(emulator):

Loading…
Cancel
Save