# This file is part of the Trezor project. # # Copyright (C) 2012-2022 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 # as published by the Free Software Foundation. # # This library is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Lesser General Public License for more details. # # You should have received a copy of the License along with this library. # If not, see . from __future__ import annotations import logging import os import typing as t from enum import IntEnum from . import mapping, messages, models from .mapping import ProtobufMapping from .tools import parse_path from .transport import Transport, get_transport from .transport.thp.channel_data import ChannelData from .transport.thp.protocol_and_channel import ProtocolAndChannel from .transport.thp.protocol_v1 import ProtocolV1 from .transport.thp.protocol_v2 import ProtocolV2 if t.TYPE_CHECKING: from .transport.session import Session LOG = logging.getLogger(__name__) MAX_PASSPHRASE_LENGTH = 50 MAX_PIN_LENGTH = 50 PASSPHRASE_ON_DEVICE = object() PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") OUTDATED_FIRMWARE_ERROR = """ Your Trezor firmware is out of date. Update it with the following command: trezorctl firmware update Or visit https://suite.trezor.io/ """.strip() LOG = logging.getLogger(__name__) class ProtocolVersion(IntEnum): UNKNOWN = 0x00 PROTOCOL_V1 = 0x01 # Codec PROTOCOL_V2 = 0x02 # THP class TrezorClient: button_callback: t.Callable[[Session, t.Any], t.Any] | None = None passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None _management_session: Session | None = None _features: messages.Features | None = None _protocol_version: int _has_setup_pin: bool = False # Should by used only by conftest def __init__( self, transport: Transport, protobuf_mapping: ProtobufMapping | None = None, protocol: ProtocolAndChannel | None = None, ) -> None: self.transport = transport if protobuf_mapping is None: self.mapping = mapping.DEFAULT_MAPPING else: self.mapping = protobuf_mapping if protocol is None: self.protocol = self._get_protocol() else: self.protocol = protocol self.protocol.mapping = self.mapping if isinstance(self.protocol, ProtocolV1): self._protocol_version = ProtocolVersion.PROTOCOL_V1 elif isinstance(self.protocol, ProtocolV2): self._protocol_version = ProtocolVersion.PROTOCOL_V2 else: self._protocol_version = ProtocolVersion.UNKNOWN @classmethod def resume( cls, transport: Transport, channel_data: ChannelData, protobuf_mapping: ProtobufMapping | None = None, ) -> TrezorClient: if protobuf_mapping is None: protobuf_mapping = mapping.DEFAULT_MAPPING protocol_v1 = ProtocolV1(transport, protobuf_mapping) if channel_data.protocol_version == 2: try: protocol_v1.write(messages.Ping(message="Sanity check - to resume")) except Exception as e: print(type(e)) response = protocol_v1.read() if ( isinstance(response, messages.Failure) and response.code == messages.FailureType.InvalidProtocol ): protocol = ProtocolV2(transport, protobuf_mapping, channel_data) protocol.write(0, messages.Ping()) response = protocol.read(0) if not isinstance(response, messages.Success): LOG.debug("Failed to resume ProtocolV2") raise Exception("Failed to resume ProtocolV2") LOG.debug("Protocol V2 detected - can be resumed") else: LOG.debug("Failed to resume ProtocolV2") raise Exception("Failed to resume ProtocolV2") else: protocol = ProtocolV1(transport, protobuf_mapping, channel_data) return TrezorClient(transport, protobuf_mapping, protocol) def get_session( self, passphrase: str | object | None = None, derive_cardano: bool = False, ) -> Session: """ Returns initialized session (with derived seed). Will fail if the device is not initialized """ from .transport.session import SessionV1, SessionV2 if isinstance(self.protocol, ProtocolV1): if passphrase is None: passphrase = "" return SessionV1.new(self, passphrase, derive_cardano) if isinstance(self.protocol, ProtocolV2): assert isinstance(passphrase, str) or passphrase is None return SessionV2.new(self, passphrase, derive_cardano) raise NotImplementedError # TODO def resume_session(self, session: Session): """ Note: this function potentially modifies the input session. """ from .debuglink import SessionDebugWrapper from .transport.session import SessionV1, SessionV2 if isinstance(session, SessionDebugWrapper): session = session._session if isinstance(session, SessionV2): return session elif isinstance(session, SessionV1): session.init_session() return session else: raise NotImplementedError def get_management_session(self, new_session: bool = False) -> Session: from .transport.session import SessionV1, SessionV2 if not new_session and self._management_session is not None: return self._management_session if isinstance(self.protocol, ProtocolV1): self._management_session = SessionV1.new( client=self, passphrase="", derive_cardano=False, ) elif isinstance(self.protocol, ProtocolV2): self._management_session = SessionV2(client=self, id=b"\x00") assert self._management_session is not None return self._management_session @property def features(self) -> messages.Features: if self._features is None: self._features = self.protocol.get_features() assert self._features is not None return self._features @property def protocol_version(self) -> int: return self._protocol_version @property def model(self) -> models.TrezorModel: f = self.features model = models.by_name(f.model or "1") if model is None: raise RuntimeError( "Unsupported Trezor model" f" (internal_model: {f.internal_model}, model: {f.model})" ) return model @property def version(self) -> tuple[int, int, int]: f = self.features ver = ( f.major_version, f.minor_version, f.patch_version, ) return ver def refresh_features(self) -> None: self.protocol.update_features() self._features = self.protocol.get_features() def _get_protocol(self) -> ProtocolAndChannel: self.transport.open() protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) protocol.write(messages.Initialize()) response = protocol.read() self.transport.close() if isinstance(response, messages.Failure): if response.code == messages.FailureType.InvalidProtocol: LOG.debug("Protocol V2 detected") protocol = ProtocolV2(self.transport, self.mapping) return protocol def get_default_client( path: t.Optional[str] = None, **kwargs: t.Any, ) -> "TrezorClient": """Get a client for a connected Trezor device. Returns a TrezorClient instance with minimum fuss. If path is specified, does a prefix-search for the specified device. Otherwise, uses the value of TREZOR_PATH env variable, or finds first connected Trezor. If no UI is supplied, instantiates the default CLI UI. """ if path is None: path = os.getenv("TREZOR_PATH") transport = get_transport(path, prefix_search=True) return TrezorClient(transport, **kwargs)