mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 16:00:57 +00:00
feat(python): introduce Trezor models
This keeps information about vendors and USB IDs in one place, and allows us to extend with model-specific information later. By default, this should be backwards-compatible -- TrezorClient can optionally accept model information, and if not, it will try to guess based on Features. It is possible to specify which models to look for in transport enumeration. Bridge and UDP transports ignore the parameter, because they can't know what model is on the other side. supersedes #1448 and #1449
This commit is contained in:
parent
38fca4a83d
commit
a4bcc95deb
1
python/.changelog.d/1967.changed
Normal file
1
python/.changelog.d/1967.changed
Normal file
@ -0,0 +1 @@
|
||||
Introduce Trezor models as an abstraction over USB IDs, vendor strings, and possibly protobuf mappings.
|
@ -15,10 +15,3 @@
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
__version__ = "0.13.0"
|
||||
|
||||
# fmt: off
|
||||
MINIMUM_FIRMWARE_VERSION = {
|
||||
"1": (1, 8, 0),
|
||||
"T": (2, 1, 0),
|
||||
}
|
||||
# fmt: on
|
||||
|
@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from mnemonic import Mnemonic
|
||||
|
||||
from . import MINIMUM_FIRMWARE_VERSION, exceptions, mapping, messages
|
||||
from . import exceptions, mapping, messages, models
|
||||
from .log import DUMP_BYTES
|
||||
from .messages import Capability
|
||||
from .tools import expect, parse_path, session
|
||||
@ -33,7 +33,6 @@ if TYPE_CHECKING:
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
VENDORS = ("bitcointrezor.com", "trezor.io")
|
||||
MAX_PASSPHRASE_LENGTH = 50
|
||||
MAX_PIN_LENGTH = 50
|
||||
|
||||
@ -85,6 +84,7 @@ class TrezorClient:
|
||||
ui: "TrezorClientUI",
|
||||
session_id: Optional[bytes] = None,
|
||||
derive_cardano: Optional[bool] = None,
|
||||
model: Optional[models.TrezorModel] = None,
|
||||
_init_device: bool = True,
|
||||
) -> None:
|
||||
"""Create a TrezorClient instance.
|
||||
@ -101,6 +101,9 @@ class TrezorClient:
|
||||
You can supply a `session_id` you might have saved in the previous session. If
|
||||
you do, the user might not need to enter their passphrase again.
|
||||
|
||||
You can provide Trezor model information. If not provided, it is detected from
|
||||
the model name reported at initialization time.
|
||||
|
||||
By default, the instance will open a connection to the Trezor device, send an
|
||||
`Initialize` message, set up the `features` field from the response, and connect
|
||||
to a session. By specifying `_init_device=False`, this step is skipped. Notably,
|
||||
@ -110,6 +113,10 @@ class TrezorClient:
|
||||
might be removed at any time.
|
||||
"""
|
||||
LOG.info(f"creating client instance for device: {transport.get_path()}")
|
||||
self.model = model
|
||||
if self.model:
|
||||
self.mapping = self.model.default_mapping
|
||||
else:
|
||||
self.mapping = mapping.DEFAULT_MAPPING
|
||||
self.transport = transport
|
||||
self.ui = ui
|
||||
@ -254,7 +261,14 @@ class TrezorClient:
|
||||
|
||||
def _refresh_features(self, features: messages.Features) -> None:
|
||||
"""Update internal fields based on passed-in Features message."""
|
||||
if features.vendor not in VENDORS:
|
||||
|
||||
if not self.model:
|
||||
# Trezor Model One bootloader 1.8.0 or older does not send model name
|
||||
self.model = models.by_name(features.model or "1")
|
||||
if self.model is None:
|
||||
raise RuntimeError("Unsupported Trezor model")
|
||||
|
||||
if features.vendor not in self.model.vendors:
|
||||
raise RuntimeError("Unsupported device")
|
||||
|
||||
self.features = features
|
||||
@ -353,9 +367,9 @@ class TrezorClient:
|
||||
def is_outdated(self) -> bool:
|
||||
if self.features.bootloader_mode:
|
||||
return False
|
||||
model = self.features.model or "1"
|
||||
required_version = MINIMUM_FIRMWARE_VERSION[model]
|
||||
return self.version < required_version
|
||||
|
||||
assert self.model is not None # should happen in _refresh_features
|
||||
return self.version < self.model.minimum_version
|
||||
|
||||
def check_firmware_version(self, warn_only: bool = False) -> None:
|
||||
if self.is_outdated():
|
||||
|
43
python/src/trezorlib/models.py
Normal file
43
python/src/trezorlib/models.py
Normal file
@ -0,0 +1,43 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Collection, Optional, Tuple
|
||||
|
||||
from . import mapping
|
||||
|
||||
UsbId = Tuple[int, int]
|
||||
|
||||
VENDORS = ("bitcointrezor.com", "trezor.io")
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class TrezorModel:
|
||||
name: str
|
||||
minimum_version: Tuple[int, int, int]
|
||||
vendors: Collection[str]
|
||||
usb_ids: Collection[UsbId]
|
||||
default_mapping: mapping.ProtobufMapping
|
||||
|
||||
|
||||
TREZOR_ONE = TrezorModel(
|
||||
name="1",
|
||||
minimum_version=(1, 8, 0),
|
||||
vendors=VENDORS,
|
||||
usb_ids=((0x534C, 0x0001),),
|
||||
default_mapping=mapping.DEFAULT_MAPPING,
|
||||
)
|
||||
|
||||
TREZOR_T = TrezorModel(
|
||||
name="T",
|
||||
minimum_version=(2, 1, 0),
|
||||
vendors=VENDORS,
|
||||
usb_ids=((0x1209, 0x53C1), (0x1209, 0x53C0)),
|
||||
default_mapping=mapping.DEFAULT_MAPPING,
|
||||
)
|
||||
|
||||
TREZORS = {TREZOR_ONE, TREZOR_T}
|
||||
|
||||
|
||||
def by_name(name: str) -> Optional[TrezorModel]:
|
||||
for model in TREZORS:
|
||||
if model.name == name:
|
||||
return model
|
||||
return None
|
@ -29,17 +29,12 @@ from typing import (
|
||||
from ..exceptions import TrezorException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
T = TypeVar("T", bound="Transport")
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
# USB vendor/product IDs for Trezors
|
||||
DEV_TREZOR1 = (0x534C, 0x0001)
|
||||
DEV_TREZOR2 = (0x1209, 0x53C1)
|
||||
DEV_TREZOR2_BL = (0x1209, 0x53C0)
|
||||
|
||||
TREZORS = {DEV_TREZOR1, DEV_TREZOR2, DEV_TREZOR2_BL}
|
||||
|
||||
UDEV_RULES_STR = """
|
||||
Do you have udev rules installed?
|
||||
https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules
|
||||
@ -95,7 +90,9 @@ class Transport:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls: Type["T"]) -> Iterable["T"]:
|
||||
def enumerate(
|
||||
cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["T"]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@ -126,12 +123,14 @@ def all_transports() -> Iterable[Type["Transport"]]:
|
||||
return set(t for t in transports if t.ENABLED)
|
||||
|
||||
|
||||
def enumerate_devices() -> Sequence["Transport"]:
|
||||
def enumerate_devices(
|
||||
models: Optional[Iterable["TrezorModel"]] = None,
|
||||
) -> Sequence["Transport"]:
|
||||
devices: List["Transport"] = []
|
||||
for transport in all_transports():
|
||||
name = transport.__name__
|
||||
try:
|
||||
found = list(transport.enumerate())
|
||||
found = list(transport.enumerate(models))
|
||||
LOG.info(f"Enumerating {name}: found {len(found)} devices")
|
||||
devices.extend(found)
|
||||
except NotImplementedError:
|
||||
|
@ -16,13 +16,16 @@
|
||||
|
||||
import logging
|
||||
import struct
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import MessagePayload, Transport, TransportException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
TREZORD_HOST = "http://127.0.0.1:21325"
|
||||
@ -135,7 +138,9 @@ class BridgeTransport(Transport):
|
||||
return call_bridge(uri, data=data)
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls) -> Iterable["BridgeTransport"]:
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["BridgeTransport"]:
|
||||
try:
|
||||
legacy = is_legacy_bridge()
|
||||
return [
|
||||
|
@ -17,10 +17,11 @@
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Dict, Iterable, List
|
||||
from typing import Any, Dict, Iterable, List, Optional
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import DEV_TREZOR1, UDEV_RULES_STR, TransportException
|
||||
from ..models import TREZOR_ONE, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -132,11 +133,17 @@ class HidTransport(ProtocolBasedTransport):
|
||||
return f"{self.PATH_PREFIX}:{self.device['path'].decode()}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls, debug: bool = False) -> Iterable["HidTransport"]:
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False
|
||||
) -> Iterable["HidTransport"]:
|
||||
if models is None:
|
||||
models = {TREZOR_ONE}
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
|
||||
devices: List["HidTransport"] = []
|
||||
for dev in hid.enumerate(0, 0):
|
||||
usb_id = (dev["vendor_id"], dev["product_id"])
|
||||
if usb_id != DEV_TREZOR1:
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if debug:
|
||||
if not is_debuglink(dev):
|
||||
|
@ -17,12 +17,15 @@
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
from typing import Iterable, Optional
|
||||
from typing import TYPE_CHECKING, Iterable, Optional
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..models import TrezorModel
|
||||
|
||||
SOCKET_TIMEOUT = 10
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -70,7 +73,9 @@ class UdpTransport(ProtocolBasedTransport):
|
||||
d.close()
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls) -> Iterable["UdpTransport"]:
|
||||
def enumerate(
|
||||
cls, _models: Optional[Iterable["TrezorModel"]] = None
|
||||
) -> Iterable["UdpTransport"]:
|
||||
default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}"
|
||||
try:
|
||||
return [cls._try_path(default_path)]
|
||||
|
@ -21,7 +21,8 @@ import time
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
from ..log import DUMP_PACKETS
|
||||
from . import TREZORS, UDEV_RULES_STR, TransportException
|
||||
from ..models import TREZORS, TrezorModel
|
||||
from . import UDEV_RULES_STR, TransportException
|
||||
from .protocol import ProtocolBasedTransport, ProtocolV1
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
@ -114,15 +115,21 @@ class WebUsbTransport(ProtocolBasedTransport):
|
||||
return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}"
|
||||
|
||||
@classmethod
|
||||
def enumerate(cls, usb_reset: bool = False) -> Iterable["WebUsbTransport"]:
|
||||
def enumerate(
|
||||
cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False
|
||||
) -> Iterable["WebUsbTransport"]:
|
||||
if cls.context is None:
|
||||
cls.context = usb1.USBContext()
|
||||
cls.context.open()
|
||||
atexit.register(cls.context.close) # type: ignore [Param spec "_P@register" has no bound value]
|
||||
|
||||
if models is None:
|
||||
models = TREZORS
|
||||
usb_ids = [id for model in models for id in model.usb_ids]
|
||||
devices: List["WebUsbTransport"] = []
|
||||
for dev in cls.context.getDeviceIterator(skip_on_error=True):
|
||||
usb_id = (dev.getVendorID(), dev.getProductID())
|
||||
if usb_id not in TREZORS:
|
||||
if usb_id not in usb_ids:
|
||||
continue
|
||||
if not is_vendor_class(dev):
|
||||
continue
|
||||
|
@ -14,9 +14,11 @@
|
||||
# You should have received a copy of the License along with this library.
|
||||
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
||||
|
||||
import dataclasses
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, debuglink, device, exceptions, fido
|
||||
from trezorlib import btc, debuglink, device, exceptions, fido, models
|
||||
from trezorlib.messages import BackupType
|
||||
from trezorlib.tools import H_
|
||||
|
||||
@ -26,9 +28,9 @@ from ..device_handler import BackgroundDeviceHandler
|
||||
from ..emulators import ALL_TAGS, EmulatorWrapper
|
||||
from . import for_all, for_tags
|
||||
|
||||
MINIMUM_FIRMWARE_VERSION["1"] = (1, 0, 0)
|
||||
MINIMUM_FIRMWARE_VERSION["T"] = (2, 0, 0)
|
||||
|
||||
models.TREZOR_ONE = dataclasses.replace(models.TREZOR_ONE, minimum_version=(1, 0, 0))
|
||||
models.TREZOR_T = dataclasses.replace(models.TREZOR_T, minimum_version=(2, 0, 0))
|
||||
models.TREZORS = {models.TREZOR_ONE, models.TREZOR_T}
|
||||
|
||||
# **** COMMON DEFINITIONS ****
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import MINIMUM_FIRMWARE_VERSION, btc, device, mapping, messages, protobuf
|
||||
from trezorlib import btc, device, mapping, messages, models, protobuf
|
||||
from trezorlib.tools import parse_path
|
||||
|
||||
from ..emulators import EmulatorWrapper
|
||||
@ -57,8 +57,8 @@ def emulator(gen, tag):
|
||||
|
||||
|
||||
@for_all(
|
||||
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"],
|
||||
legacy_minimum_version=MINIMUM_FIRMWARE_VERSION["1"],
|
||||
core_minimum_version=models.TREZOR_T.minimum_version,
|
||||
legacy_minimum_version=models.TREZOR_ONE.minimum_version,
|
||||
)
|
||||
def test_passphrase_works(emulator):
|
||||
"""Check that passphrase handling in trezorlib works correctly in all versions."""
|
||||
@ -92,7 +92,7 @@ def test_passphrase_works(emulator):
|
||||
|
||||
|
||||
@for_all(
|
||||
core_minimum_version=MINIMUM_FIRMWARE_VERSION["T"],
|
||||
core_minimum_version=models.TREZOR_T.minimum_version,
|
||||
legacy_minimum_version=(1, 9, 0),
|
||||
)
|
||||
def test_init_device(emulator):
|
||||
|
Loading…
Reference in New Issue
Block a user