1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 16:00:57 +00:00

feat(python): introduce Trezor models

This keeps information about vendors and USB IDs in one place, and
allows us to extend with model-specific information later.

By default, this should be backwards-compatible -- TrezorClient can
optionally accept model information, and if not, it will try to guess
based on Features.

It is possible to specify which models to look for in transport
enumeration. Bridge and UDP transports ignore the parameter, because
they can't know what model is on the other side.

supersedes #1448 and #1449
This commit is contained in:
matejcik 2021-11-26 16:31:35 +01:00 committed by matejcik
parent 38fca4a83d
commit a4bcc95deb
11 changed files with 119 additions and 43 deletions

View File

@ -0,0 +1 @@
Introduce Trezor models as an abstraction over USB IDs, vendor strings, and possibly protobuf mappings.

View File

@ -15,10 +15,3 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
__version__ = "0.13.0"
# fmt: off
MINIMUM_FIRMWARE_VERSION = {
"1": (1, 8, 0),
"T": (2, 1, 0),
}
# fmt: on

View File

@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Optional
from mnemonic import Mnemonic
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES
from .messages import Capability
from .tools import expect, parse_path, session
@ -33,7 +33,6 @@ if TYPE_CHECKING:
LOG = logging.getLogger(__name__)
VENDORS = ("bitcointrezor.com", "trezor.io")
MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50
@ -85,6 +84,7 @@ class TrezorClient:
ui: "TrezorClientUI",
session_id: Optional[bytes] = None,
derive_cardano: Optional[bool] = None,
model: Optional[models.TrezorModel] = None,
_init_device: bool = True,
) -> None:
"""Create a TrezorClient instance.
@ -101,6 +101,9 @@ class TrezorClient:
You can supply a `session_id` you might have saved in the previous session. If
you do, the user might not need to enter their passphrase again.
You can provide Trezor model information. If not provided, it is detected from
the model name reported at initialization time.
By default, the instance will open a connection to the Trezor device, send an
`Initialize` message, set up the `features` field from the response, and connect
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
@ -110,7 +113,11 @@ class TrezorClient:
might be removed at any time.
"""
LOG.info(f"creating client instance for device: {transport.get_path()}")
self.mapping = mapping.DEFAULT_MAPPING
self.model = model
if self.model:
self.mapping = self.model.default_mapping
else:
self.mapping = mapping.DEFAULT_MAPPING
self.transport = transport
self.ui = ui
self.session_counter = 0
@ -254,7 +261,14 @@ class TrezorClient:
def _refresh_features(self, features: messages.Features) -> None:
"""Update internal fields based on passed-in Features message."""
if features.vendor not in VENDORS:
if not self.model:
# Trezor Model One bootloader 1.8.0 or older does not send model name
self.model = models.by_name(features.model or "1")
if self.model is None:
raise RuntimeError("Unsupported Trezor model")
if features.vendor not in self.model.vendors:
raise RuntimeError("Unsupported device")
self.features = features
@ -353,9 +367,9 @@ class TrezorClient:
def is_outdated(self) -> bool:
if self.features.bootloader_mode:
return False
model = self.features.model or "1"
required_version = MINIMUM_FIRMWARE_VERSION[model]
return self.version < required_version
assert self.model is not None # should happen in _refresh_features
return self.version < self.model.minimum_version
def check_firmware_version(self, warn_only: bool = False) -> None:
if self.is_outdated():

View File

@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Collection, Optional, Tuple
from . import mapping
UsbId = Tuple[int, int]
VENDORS = ("bitcointrezor.com", "trezor.io")
@dataclass(eq=True, frozen=True)
class TrezorModel:
name: str
minimum_version: Tuple[int, int, int]
vendors: Collection[str]
usb_ids: Collection[UsbId]
default_mapping: mapping.ProtobufMapping
TREZOR_ONE = TrezorModel(
name="1",
minimum_version=(1, 8, 0),
vendors=VENDORS,
usb_ids=((0x534C, 0x0001),),
default_mapping=mapping.DEFAULT_MAPPING,
)
TREZOR_T = TrezorModel(
name="T",
minimum_version=(2, 1, 0),
vendors=VENDORS,
usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)),
default_mapping=mapping.DEFAULT_MAPPING,
)
TREZORS = {TREZOR_ONE, TREZOR_T}
def by_name(name: str) -> Optional[TrezorModel]:
for model in TREZORS:
if model.name == name:
return model
return None

View File

@ -29,17 +29,12 @@ from typing import (
from ..exceptions import TrezorException
if TYPE_CHECKING:
from ..models import TrezorModel
T = TypeVar("T", bound="Transport")
LOG = logging.getLogger(__name__)
# USB vendor/product IDs for Trezors
DEV_TREZOR1 = (0x534C, 0x0001)
DEV_TREZOR2 = (0x1209, 0x53C1)
DEV_TREZOR2_BL = (0x1209, 0x53C0)
TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL}
UDEV_RULES_STR = """
Do you have udev rules installed?
https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
@ -95,7 +90,9 @@ class Transport:
raise NotImplementedError
@classmethod
def enumerate(cls: Type["T"]) -> Iterable["T"]:
def enumerate(
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["T"]:
raise NotImplementedError
@classmethod
@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]:
return set(t for t in transports if t.ENABLED)
def enumerate_devices() -> Sequence["Transport"]:
def enumerate_devices(
models: Optional[Iterable["TrezorModel"]] = None,
) -> Sequence["Transport"]:
devices: List["Transport"] = []
for transport in all_transports():
name = transport.__name__
try:
found = list(transport.enumerate())
found = list(transport.enumerate(models))
LOG.info(f"Enumerating {name}: found {len(found)} devices")
devices.extend(found)
except NotImplementedError:

View File

@ -16,13 +16,16 @@
import logging
import struct
from typing import Any, Dict, Iterable, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
import requests
from ..log import DUMP_PACKETS
from . import MessagePayload, Transport, TransportException
if TYPE_CHECKING:
from ..models import TrezorModel
LOG = logging.getLogger(__name__)
TREZORD_HOST = "http://127.0.0.1:21325"
@ -135,7 +138,9 @@ class BridgeTransport(Transport):
return call_bridge(uri, data=data)
@classmethod
def enumerate(cls) -> Iterable["BridgeTransport"]:
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["BridgeTransport"]:
try:
legacy = is_legacy_bridge()
return [

View File

@ -17,10 +17,11 @@
import logging
import sys
import time
from typing import Any, Dict, Iterable, List
from typing import Any, Dict, Iterable, List, Optional
from ..log import DUMP_PACKETS
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
from ..models import TREZOR_ONE, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
@ -132,11 +133,17 @@ class HidTransport(ProtocolBasedTransport):
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
@classmethod
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
) -> Iterable["HidTransport"]:
if models is None:
models = {TREZOR_ONE}
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["HidTransport"] = []
for dev in hid.enumerate(0, 0):
usb_id = (dev["vendor_id"], dev["product_id"])
if usb_id != DEV_TREZOR1:
if usb_id not in usb_ids:
continue
if debug:
if not is_debuglink(dev):

View File

@ -17,12 +17,15 @@
import logging
import socket
import time
from typing import Iterable, Optional
from typing import TYPE_CHECKING, Iterable, Optional
from ..log import DUMP_PACKETS
from . import TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
if TYPE_CHECKING:
from ..models import TrezorModel
SOCKET_TIMEOUT = 10
LOG = logging.getLogger(__name__)
@ -70,7 +73,9 @@ class UdpTransport(ProtocolBasedTransport):
d.close()
@classmethod
def enumerate(cls) -> Iterable["UdpTransport"]:
def enumerate(
cls, _models: Optional[Iterable["TrezorModel"]] = None
) -> Iterable["UdpTransport"]:
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
try:
return [cls._try_path(default_path)]

View File

@ -21,7 +21,8 @@ import time
from typing import Iterable, List, Optional
from ..log import DUMP_PACKETS
from . import TREZORS, UDEV_RULES_STR, TransportException
from ..models import TREZORS, TrezorModel
from . import UDEV_RULES_STR, TransportException
from .protocol import ProtocolBasedTransport, ProtocolV1
LOG = logging.getLogger(__name__)
@ -114,15 +115,21 @@ class WebUsbTransport(ProtocolBasedTransport):
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
@classmethod
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
def enumerate(
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
) -> Iterable["WebUsbTransport"]:
if cls.context is None:
cls.context = usb1.USBContext()
cls.context.open()
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
if models is None:
models = TREZORS
usb_ids = [id for model in models for id in model.usb_ids]
devices: List["WebUsbTransport"] = []
for dev in cls.context.getDeviceIterator(skip_on_error=True):
usb_id = (dev.getVendorID(), dev.getProductID())
if usb_id not in TREZORS:
if usb_id not in usb_ids:
continue
if not is_vendor_class(dev):
continue

View File

@ -14,9 +14,11 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
import dataclasses
import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device, exceptions, fido
from trezorlib import btc, debuglink, device, exceptions, fido, models
from trezorlib.messages import BackupType
from trezorlib.tools import H_
@ -26,9 +28,9 @@ from ..device_handler import BackgroundDeviceHandler
from ..emulators import ALL_TAGS, EmulatorWrapper
from . import for_all, for_tags
MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0)
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0)
models.TREZOR_ONE = dataclasses.replace(models.TREZOR_ONE, minimum_version=(1, 0, 0))
models.TREZOR_T = dataclasses.replace(models.TREZOR_T, minimum_version=(2, 0, 0))
models.TREZORS = {models.TREZOR_ONE, models.TREZOR_T}
# **** COMMON DEFINITIONS ****

View File

@ -16,7 +16,7 @@
import pytest
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, device, mapping, messages, protobuf
from trezorlib import btc, device, mapping, messages, models, protobuf
from trezorlib.tools import parse_path
from ..emulators import EmulatorWrapper
@ -57,8 +57,8 @@ def emulator(gen, tag):
@for_all(
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"],
legacy_minimum_version=MINIMUM_FIRMWARE_VERSION["1"],
core_minimum_version=models.TREZOR_T.minimum_version,
legacy_minimum_version=models.TREZOR_ONE.minimum_version,
)
def test_passphrase_works(emulator):
"""Check that passphrase handling in trezorlib works correctly in all versions."""
@ -92,7 +92,7 @@ def test_passphrase_works(emulator):
@for_all(
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"],
core_minimum_version=models.TREZOR_T.minimum_version,
legacy_minimum_version=(1, 9, 0),
)
def test_init_device(emulator):