1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-15 00:52:02 +00:00
trezor-firmware/python/src/trezorlib/client.py
2025-01-31 14:57:30 +01:00

256 lines
8.5 KiB
Python

# This file is part of the Trezor project.
#
# Copyright (C) 2012-2022 SatoshiLabs and contributors
#
# This library is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3
# as published by the Free Software Foundation.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations

import logging
import os
import typing as t
from enum import IntEnum

from . import mapping, messages, models
from .mapping import ProtobufMapping
from .tools import parse_path
from .transport import Transport, get_transport
from .transport.thp.channel_data import ChannelData
from .transport.thp.protocol_and_channel import ProtocolAndChannel
from .transport.thp.protocol_v1 import ProtocolV1
from .transport.thp.protocol_v2 import ProtocolV2

if t.TYPE_CHECKING:
    from .transport.session import Session
LOG = logging.getLogger(__name__)

# Length limits for passphrase and PIN input (not enforced in this module;
# presumably checked by callers -- confirm).
MAX_PASSPHRASE_LENGTH = 50
MAX_PIN_LENGTH = 50

# Unique sentinel object, compared by identity, that may be passed where a
# passphrase string is otherwise expected. Presumably it requests passphrase
# entry on the device itself -- confirm against the session code.
PASSPHRASE_ON_DEVICE = object()

# BIP-32 path "44h/1h/0h/0/0", parsed once at import time.
PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0")

# Text telling the user how to update outdated firmware.
OUTDATED_FIRMWARE_ERROR = """
Your Trezor firmware is out of date. Update it with the following command:
trezorctl firmware update
Or visit https://suite.trezor.io/
""".strip()
class ProtocolVersion(IntEnum):
    """Identifies the communication protocol a device speaks.

    ``PROTOCOL_V1`` is the Codec protocol and ``PROTOCOL_V2`` is THP;
    ``UNKNOWN`` covers anything else.
    """

    UNKNOWN = 0
    PROTOCOL_V1 = 1  # Codec
    PROTOCOL_V2 = 2  # THP
class TrezorClient:
    """Client for one Trezor device reached through a `Transport`.

    The client holds a protocol object (`ProtocolV1` for the Codec protocol
    or `ProtocolV2` for THP). The caller may supply it, or the client detects
    it by probing the device. Sessions are created on top of that protocol
    via `get_session` and `get_management_session`.
    """

    # Optional user-interaction hooks; not invoked in this module (presumably
    # used by session code -- confirm).
    button_callback: t.Callable[[Session, t.Any], t.Any] | None = None
    passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None
    pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None
    # Cached by get_management_session().
    _management_session: Session | None = None
    # Cached by the `features` property; replaced by refresh_features().
    _features: messages.Features | None = None
    # One of the ProtocolVersion values, set in __init__.
    _protocol_version: int
    _has_setup_pin: bool = False  # Should be used only by conftest

    def __init__(
        self,
        transport: Transport,
        protobuf_mapping: ProtobufMapping | None = None,
        protocol: ProtocolAndChannel | None = None,
    ) -> None:
        """Create a client talking to the device over ``transport``.

        :param transport: transport used to reach the device.
        :param protobuf_mapping: message mapping; ``mapping.DEFAULT_MAPPING``
            when omitted.
        :param protocol: an already established protocol (e.g. from
            `resume`). When omitted, `_get_protocol` probes the device,
            which opens and closes ``transport``.
        """
        self.transport = transport
        if protobuf_mapping is None:
            self.mapping = mapping.DEFAULT_MAPPING
        else:
            self.mapping = protobuf_mapping

        if protocol is None:
            self.protocol = self._get_protocol()
        else:
            self.protocol = protocol
        # The protocol may have been built with another mapping (the probe
        # uses the default one); force it to use this client's mapping.
        self.protocol.mapping = self.mapping

        if isinstance(self.protocol, ProtocolV1):
            self._protocol_version = ProtocolVersion.PROTOCOL_V1
        elif isinstance(self.protocol, ProtocolV2):
            self._protocol_version = ProtocolVersion.PROTOCOL_V2
        else:
            self._protocol_version = ProtocolVersion.UNKNOWN

    @classmethod
    def resume(
        cls,
        transport: Transport,
        channel_data: ChannelData,
        protobuf_mapping: ProtobufMapping | None = None,
    ) -> TrezorClient:
        """Recreate a client from previously saved ``channel_data``.

        When ``channel_data.protocol_version == 2``:

        - A ProtocolV1 Ping is sent.
        - The reply must be a ``Failure`` with code ``InvalidProtocol``,
          which shows the device expects ProtocolV2.
        - A `ProtocolV2` is then built from ``channel_data``, and a Ping
          sent through it must be answered with ``Success``.

        For any other protocol version, a `ProtocolV1` is built from
        ``channel_data`` without sending any message from this method.

        :raises Exception: if the ProtocolV2 channel cannot be resumed.
        """
        if protobuf_mapping is None:
            protobuf_mapping = mapping.DEFAULT_MAPPING

        protocol_v1 = ProtocolV1(transport, protobuf_mapping)
        if channel_data.protocol_version == 2:
            try:
                protocol_v1.write(messages.Ping(message="Sanity check - to resume"))
            except Exception:
                # Best effort: the reply is still read below. Log instead of
                # printing, so library users do not get output on stdout.
                LOG.debug("Sanity-check Ping could not be written", exc_info=True)
            response = protocol_v1.read()
            is_v2_device = (
                isinstance(response, messages.Failure)
                and response.code == messages.FailureType.InvalidProtocol
            )
            if not is_v2_device:
                LOG.debug("Failed to resume ProtocolV2")
                raise Exception("Failed to resume ProtocolV2")

            protocol = ProtocolV2(transport, protobuf_mapping, channel_data)
            protocol.write(0, messages.Ping())
            response = protocol.read(0)
            if not isinstance(response, messages.Success):
                LOG.debug("Failed to resume ProtocolV2")
                raise Exception("Failed to resume ProtocolV2")
            LOG.debug("Protocol V2 detected - can be resumed")
        else:
            protocol = ProtocolV1(transport, protobuf_mapping, channel_data)
        return TrezorClient(transport, protobuf_mapping, protocol)

    def get_session(
        self,
        passphrase: str | object | None = None,
        derive_cardano: bool = False,
    ) -> Session:
        """
        Returns initialized session (with derived seed).

        Will fail if the device is not initialized.

        :param passphrase: passphrase for the session. With ProtocolV1,
            ``None`` is replaced by the empty string. With ProtocolV2 only a
            ``str`` or ``None`` is accepted (checked by assertion).
        :param derive_cardano: forwarded to the session's ``new`` constructor.
        :raises NotImplementedError: for a protocol other than V1/V2.
        """
        from .transport.session import SessionV1, SessionV2

        if isinstance(self.protocol, ProtocolV1):
            if passphrase is None:
                passphrase = ""
            return SessionV1.new(self, passphrase, derive_cardano)
        if isinstance(self.protocol, ProtocolV2):
            assert isinstance(passphrase, str) or passphrase is None
            return SessionV2.new(self, passphrase, derive_cardano)
        raise NotImplementedError  # TODO

    def resume_session(self, session: Session) -> Session:
        """Resume ``session`` on this client.

        A `SessionDebugWrapper` is unwrapped, and the inner session (not the
        wrapper) is what gets returned. A `SessionV2` is returned unchanged;
        a `SessionV1` has ``init_session()`` called on it first.

        Note: this function potentially modifies the input session.

        :raises NotImplementedError: for any other session type.
        """
        from .debuglink import SessionDebugWrapper
        from .transport.session import SessionV1, SessionV2

        if isinstance(session, SessionDebugWrapper):
            session = session._session

        if isinstance(session, SessionV2):
            return session
        elif isinstance(session, SessionV1):
            session.init_session()
            return session
        else:
            raise NotImplementedError

    def get_management_session(self, new_session: bool = False) -> Session:
        """Return the cached management session, creating it if needed.

        With ProtocolV1 this is a new `SessionV1` with an empty passphrase
        and ``derive_cardano=False``. With ProtocolV2 it is a `SessionV2`
        with session id byte ``0x00``.

        :param new_session: create a fresh session even if one is cached.
        """
        from .transport.session import SessionV1, SessionV2

        if not new_session and self._management_session is not None:
            return self._management_session

        if isinstance(self.protocol, ProtocolV1):
            self._management_session = SessionV1.new(
                client=self,
                passphrase="",
                derive_cardano=False,
            )
        elif isinstance(self.protocol, ProtocolV2):
            self._management_session = SessionV2(client=self, id=b"\x00")

        assert self._management_session is not None
        return self._management_session

    @property
    def features(self) -> messages.Features:
        """Device features, fetched from the protocol on first access and cached."""
        if self._features is None:
            self._features = self.protocol.get_features()
        assert self._features is not None
        return self._features

    @property
    def protocol_version(self) -> int:
        """The `ProtocolVersion` determined in ``__init__``."""
        return self._protocol_version

    @property
    def model(self) -> models.TrezorModel:
        """Device model, looked up by the ``model`` field of `features`.

        An empty or missing ``model`` field is looked up as ``"1"``.

        :raises RuntimeError: if `models.by_name` does not know the name.
        """
        f = self.features
        model = models.by_name(f.model or "1")
        if model is None:
            raise RuntimeError(
                "Unsupported Trezor model"
                f" (internal_model: {f.internal_model}, model: {f.model})"
            )
        return model

    @property
    def version(self) -> tuple[int, int, int]:
        """Version tuple ``(major, minor, patch)`` taken from `features`."""
        f = self.features
        ver = (
            f.major_version,
            f.minor_version,
            f.patch_version,
        )
        return ver

    def refresh_features(self) -> None:
        """Ask the protocol to update its features, then re-cache them."""
        self.protocol.update_features()
        self._features = self.protocol.get_features()

    def _get_protocol(self) -> ProtocolAndChannel:
        """Detect which protocol the device speaks.

        Sends ``Initialize`` over ProtocolV1. A ``Failure`` reply with code
        ``InvalidProtocol`` means the device expects ProtocolV2, so a
        `ProtocolV2` is returned. Any other reply keeps the ProtocolV1
        object.

        The transport is opened for the probe and always closed again, even
        if writing or reading raises.
        """
        self.transport.open()
        try:
            # The probe deliberately uses the default mapping; __init__
            # installs the client's own mapping on the returned protocol.
            protocol: ProtocolAndChannel = ProtocolV1(
                self.transport, mapping.DEFAULT_MAPPING
            )
            protocol.write(messages.Initialize())
            response = protocol.read()
        finally:
            # Previously the transport stayed open if write/read raised.
            self.transport.close()

        if isinstance(response, messages.Failure):
            if response.code == messages.FailureType.InvalidProtocol:
                LOG.debug("Protocol V2 detected")
                protocol = ProtocolV2(self.transport, self.mapping)
        return protocol
def get_default_client(
    path: str | None = None,
    **kwargs: t.Any,
) -> TrezorClient:
    """Get a client for a connected Trezor device.

    Returns a TrezorClient instance with minimum fuss.

    If *path* is not given, the ``TREZOR_PATH`` environment variable is used
    instead. If neither is set, the first connected Trezor is used. A path
    from either source is matched with a prefix search.

    Any extra keyword arguments are passed unchanged to the `TrezorClient`
    constructor (e.g. ``protobuf_mapping``).
    """
    if path is None:
        path = os.getenv("TREZOR_PATH")

    transport = get_transport(path, prefix_search=True)
    return TrezorClient(transport, **kwargs)