mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-15 00:52:02 +00:00
256 lines
8.5 KiB
Python
256 lines
8.5 KiB
Python
# This file is part of the Trezor project.
|
|
#
|
|
# Copyright (C) 2012-2022 SatoshiLabs and contributors
|
|
#
|
|
# This library is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Lesser General Public License version 3
|
|
# as published by the Free Software Foundation.
|
|
#
|
|
# This library is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Lesser General Public License for more details.
|
|
#
|
|
# You should have received a copy of the License along with this library.
|
|
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import typing as t
|
|
from enum import IntEnum
|
|
|
|
from . import mapping, messages, models
|
|
from .mapping import ProtobufMapping
|
|
from .tools import parse_path
|
|
from .transport import Transport, get_transport
|
|
from .transport.thp.channel_data import ChannelData
|
|
from .transport.thp.protocol_and_channel import ProtocolAndChannel
|
|
from .transport.thp.protocol_v1 import ProtocolV1
|
|
from .transport.thp.protocol_v2 import ProtocolV2
|
|
|
|
if t.TYPE_CHECKING:
|
|
from .transport.session import Session
|
|
|
|
LOG = logging.getLogger(__name__)

# Length limits for user-supplied secrets (enforcement is not visible in this
# section — presumably done by callers; verify before relying on it).
MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50

# Unique sentinel object; presumably passed as `passphrase` to request
# on-device passphrase entry — TODO confirm against session code.
PASSPHRASE_ON_DEVICE = object()
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")

OUTDATED_FIRMWARE_ERROR = """
Your Trezor firmware is out of date. Update it with the following command:
  trezorctl firmware update
Or visit https://suite.trezor.io/
""".strip()
|
|
|
|
|
|
class ProtocolVersion(IntEnum):
    """Identifier of the wire protocol a client talks to the device with."""

    UNKNOWN = 0
    PROTOCOL_V1 = 1  # Codec
    PROTOCOL_V2 = 2  # THP
|
|
|
|
|
|
class TrezorClient:
    """Client for a Trezor device reachable through a ``Transport``.

    The client holds a protocol object — ProtocolV1 (Codec) or ProtocolV2
    (THP) — either supplied by the caller or detected by probing the device,
    and creates sessions on top of it.
    """

    button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
    passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
    pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None

    _management_session: Session | None = None
    _features: messages.Features | None = None
    _protocol_version: int
    _has_setup_pin: bool = False  # Should be used only by conftest

    def __init__(
        self,
        transport: Transport,
        protobuf_mapping: ProtobufMapping | None = None,
        protocol: ProtocolAndChannel | None = None,
    ) -> None:
        """Create a client bound to ``transport``.

        :param transport: transport used to talk to the device
        :param protobuf_mapping: message mapping; ``mapping.DEFAULT_MAPPING``
            when omitted
        :param protocol: an already-established protocol; when omitted the
            device is probed (see ``_get_protocol``) to pick one
        """
        self.transport = transport

        if protobuf_mapping is None:
            self.mapping = mapping.DEFAULT_MAPPING
        else:
            self.mapping = protobuf_mapping
        if protocol is None:
            self.protocol = self._get_protocol()
        else:
            self.protocol = protocol
        # Whichever protocol object we ended up with, it uses our mapping.
        self.protocol.mapping = self.mapping

        if isinstance(self.protocol, ProtocolV1):
            self._protocol_version = ProtocolVersion.PROTOCOL_V1
        elif isinstance(self.protocol, ProtocolV2):
            self._protocol_version = ProtocolVersion.PROTOCOL_V2
        else:
            self._protocol_version = ProtocolVersion.UNKNOWN

    @classmethod
    def resume(
        cls,
        transport: Transport,
        channel_data: ChannelData,
        protobuf_mapping: ProtobufMapping | None = None,
    ) -> TrezorClient:
        """Re-create a client from previously saved ``channel_data``.

        When the saved channel used protocol version 2 (THP), a V1 ``Ping`` is
        sent first; only a ``Failure(InvalidProtocol)`` reply (meaning the
        device is speaking THP) allows the THP channel to be restored, which is
        then verified with a V2 ``Ping``. Otherwise a ProtocolV1 is rebuilt
        from ``channel_data``.

        :raises Exception: if the ProtocolV2 channel cannot be resumed
        """
        if protobuf_mapping is None:
            protobuf_mapping = mapping.DEFAULT_MAPPING
        protocol_v1 = ProtocolV1(transport, protobuf_mapping)

        if channel_data.protocol_version != ProtocolVersion.PROTOCOL_V2:
            protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
            # Explicit class (not `cls`): subclasses may have a different
            # __init__ signature.
            return TrezorClient(transport, protobuf_mapping, protocol)

        try:
            protocol_v1.write(messages.Ping(message="Sanity check - to resume"))
        except Exception as e:
            # Best effort: log instead of printing from library code, and
            # still try to read the device's answer below.
            LOG.debug("Sanity-check ping failed while resuming: %r", e)
        response = protocol_v1.read()
        if not (
            isinstance(response, messages.Failure)
            and response.code == messages.FailureType.InvalidProtocol
        ):
            LOG.debug("Failed to resume ProtocolV2")
            raise Exception("Failed to resume ProtocolV2")

        protocol_v2 = ProtocolV2(transport, protobuf_mapping, channel_data)
        protocol_v2.write(0, messages.Ping())
        response = protocol_v2.read(0)
        if not isinstance(response, messages.Success):
            LOG.debug("Failed to resume ProtocolV2")
            raise Exception("Failed to resume ProtocolV2")
        LOG.debug("Protocol V2 detected - can be resumed")
        return TrezorClient(transport, protobuf_mapping, protocol_v2)

    def get_session(
        self,
        passphrase: str | object | None = None,
        derive_cardano: bool = False,
    ) -> Session:
        """
        Returns initialized session (with derived seed).

        Will fail if the device is not initialized.

        With ProtocolV1 a ``None`` passphrase is replaced by ``""``; with
        ProtocolV2 the passphrase must be a ``str`` or ``None`` (checked with
        ``assert``).

        :raises NotImplementedError: if the client's protocol is neither V1
            nor V2
        """
        from .transport.session import SessionV1, SessionV2

        if isinstance(self.protocol, ProtocolV1):
            if passphrase is None:
                passphrase = ""
            return SessionV1.new(self, passphrase, derive_cardano)
        if isinstance(self.protocol, ProtocolV2):
            assert isinstance(passphrase, str) or passphrase is None
            return SessionV2.new(self, passphrase, derive_cardano)
        raise NotImplementedError  # TODO

    def resume_session(self, session: Session) -> Session:
        """
        Make ``session`` usable again and return it.

        A ``SessionDebugWrapper`` is unwrapped first, and the *inner* session
        is returned. A ``SessionV2`` is returned unchanged; a ``SessionV1`` is
        re-initialized through ``init_session()``.

        Note: this function potentially modifies the input session.

        :raises NotImplementedError: for any other session type
        """
        from .debuglink import SessionDebugWrapper
        from .transport.session import SessionV1, SessionV2

        if isinstance(session, SessionDebugWrapper):
            session = session._session

        if isinstance(session, SessionV2):
            return session
        elif isinstance(session, SessionV1):
            session.init_session()
            return session
        else:
            raise NotImplementedError

    def get_management_session(self, new_session: bool = False) -> Session:
        """Return the cached management session, creating it if needed.

        A new one is created when none is cached or ``new_session`` is true:
        with ProtocolV1 a ``SessionV1`` with empty passphrase and no Cardano
        derivation, with ProtocolV2 a ``SessionV2`` with id ``b"\\x00"``.
        """
        from .transport.session import SessionV1, SessionV2

        if not new_session and self._management_session is not None:
            return self._management_session
        if isinstance(self.protocol, ProtocolV1):
            self._management_session = SessionV1.new(
                client=self,
                passphrase="",
                derive_cardano=False,
            )
        elif isinstance(self.protocol, ProtocolV2):
            self._management_session = SessionV2(client=self, id=b"\x00")
        assert self._management_session is not None
        return self._management_session

    @property
    def features(self) -> messages.Features:
        """Device features, fetched from the protocol on first access and cached."""
        if self._features is None:
            self._features = self.protocol.get_features()
        assert self._features is not None
        return self._features

    @property
    def protocol_version(self) -> int:
        """The ``ProtocolVersion`` determined when the client was created."""
        return self._protocol_version

    @property
    def model(self) -> models.TrezorModel:
        """Trezor model matching ``features.model`` (``"1"`` when unset).

        :raises RuntimeError: if the model name is not known to ``models``
        """
        f = self.features
        model = models.by_name(f.model or "1")

        if model is None:
            raise RuntimeError(
                "Unsupported Trezor model"
                f" (internal_model: {f.internal_model}, model: {f.model})"
            )
        return model

    @property
    def version(self) -> tuple[int, int, int]:
        """Firmware version as ``(major, minor, patch)`` from the features."""
        f = self.features
        ver = (
            f.major_version,
            f.minor_version,
            f.patch_version,
        )
        return ver

    def refresh_features(self) -> None:
        """Have the protocol update its features, then re-cache them."""
        self.protocol.update_features()
        self._features = self.protocol.get_features()

    def _get_protocol(self) -> ProtocolAndChannel:
        """Probe the device to decide which protocol to use.

        A ProtocolV1 ``Initialize`` is sent; a ``Failure(InvalidProtocol)``
        reply means the device expects ProtocolV2, otherwise the ProtocolV1
        object is kept. The transport is opened only for the probe and is
        closed again even if the exchange raises.
        """
        self.transport.open()
        try:
            protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING)
            protocol.write(messages.Initialize())
            response = protocol.read()
        finally:
            # Previously skipped on error, leaving the transport open.
            self.transport.close()

        if (
            isinstance(response, messages.Failure)
            and response.code == messages.FailureType.InvalidProtocol
        ):
            LOG.debug("Protocol V2 detected")
            return ProtocolV2(self.transport, self.mapping)
        return protocol
|
|
|
|
|
|
def get_default_client(
    path: str | None = None,
    **kwargs: t.Any,
) -> TrezorClient:
    """Get a client for a connected Trezor device.

    Returns a TrezorClient instance with minimum fuss.

    The device is located with ``get_transport(..., prefix_search=True)``
    using ``path`` or, when ``path`` is not given, the value of the
    ``TREZOR_PATH`` environment variable. If neither is set, the choice of
    device is left to ``get_transport`` (presumably the first connected
    Trezor).

    Any extra keyword arguments are forwarded unchanged to the
    ``TrezorClient`` constructor.
    """
    if path is None:
        path = os.getenv("TREZOR_PATH")

    transport = get_transport(path, prefix_search=True)
    return TrezorClient(transport, **kwargs)
|