diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py index 39e26f569f..2e4a530af5 100644 --- a/python/src/trezorlib/authentication.py +++ b/python/src/trezorlib/authentication.py @@ -7,7 +7,7 @@ import typing as t from importlib import metadata from . import device -from .client import TrezorClient +from .transport.session import Session try: cryptography_version = metadata.version("cryptography") @@ -361,7 +361,7 @@ def verify_authentication_response( def authenticate_device( - client: TrezorClient, + session: Session, challenge: bytes | None = None, *, whitelist: t.Collection[bytes] | None = None, @@ -371,7 +371,7 @@ def authenticate_device( if challenge is None: challenge = secrets.token_bytes(16) - resp = device.authenticate(client, challenge) + resp = device.authenticate(session, challenge) return verify_authentication_response( challenge, diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py index 6587e2a3ab..64218b7aad 100644 --- a/python/src/trezorlib/benchmark.py +++ b/python/src/trezorlib/benchmark.py @@ -19,16 +19,16 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session def list_names( - client: "TrezorClient", + session: "Session", ) -> messages.BenchmarkNames: - return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) + return session.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) -def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult: - return client.call( +def run(session: "Session", name: str) -> messages.BenchmarkResult: + return session.call( messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult ) diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index 938092a2df..6b35db0446 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,20 +18,19 @@ from typing import TYPE_CHECKING from . import messages from .protobuf import dict_to_proto -from .tools import session if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -40,17 +39,16 @@ def get_address( def get_public_key( - client: "TrezorClient", address_n: "Address", show_display: bool = False + session: "Session", address_n: "Address", show_display: bool = False ) -> bytes: - return client.call( + return session.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display), expect=messages.BinancePublicKey, ).public_key -@session def sign_tx( - client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False + session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False ) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] tx_msg = tx_json.copy() @@ -59,7 +57,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - client.call(envelope, expect=messages.BinanceTxRequest) + session.call(envelope, expect=messages.BinanceTxRequest) if "refid" in msg: msg = dict_to_proto(messages.BinanceCancelMsg, msg) @@ -70,4 +68,4 @@ def sign_tx( else: raise ValueError("can not determine msg type") - return client.call(msg, expect=messages.BinanceSignedTx) + return session.call(msg, expect=messages.BinanceSignedTx) diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index 078f486d9e..e3980055fc 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -25,11 +25,11 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import _return_success, prepare_message_bytes, session +from .tools import _return_success, prepare_message_bytes if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session class ScriptSig(TypedDict): asm: str @@ -105,7 +105,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType: def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,12 +116,12 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> messages.PublicKey: if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -139,7 +139,7 @@ def get_address(*args: Any, **kwargs: Any) -> str: def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -151,12 +151,12 @@ def get_authenticated_address( chunkify: bool = False, ) -> messages.Address: if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -171,13 +171,13 @@ def get_authenticated_address( def get_ownership_id( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> bytes: - return client.call( + return session.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, @@ -188,8 +188,9 @@ def get_ownership_id( ).ownership_id +# TODO this is used by tests only def get_ownership_proof( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, @@ -200,9 +201,9 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) + session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call( + res = session.call( messages.GetOwnershipProof( address_n=n, coin_name=coin_name, @@ -219,7 +220,7 @@ def get_ownership_proof( def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -227,7 +228,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> messages.MessageSignature: - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -241,7 +242,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -249,7 +250,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - client.call( + session.call( messages.VerifyMessage( address=address, signature=signature, @@ -264,9 +265,9 @@ def verify_message( return False -@session +# @session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -314,14 +315,14 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - client.call( + session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), expect=messages.UnlockedPathRequest, ) elif preauthorized: - client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) + session.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call(signtx, expect=messages.TxRequest) + res = session.call(signtx, expect=messages.TxRequest) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -380,7 +381,7 @@ def sign_tx( if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg, expect=messages.TxRequest) + res = session.call(msg, expect=messages.TxRequest) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -410,7 +411,7 @@ def sign_tx( f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest) + res = session.call(messages.TxAck(tx=msg), expect=messages.TxRequest) for i, sig in zip(inputs, signatures): if i.script_type != messages.InputScriptType.EXTERNAL and sig is None: @@ -420,7 +421,7 @@ def sign_tx( def authorize_coinjoin( - client: "TrezorClient", + session: "Session", coordinator: str, max_rounds: int, max_coordinator_fee_rate: int, @@ -429,7 +430,7 @@ def authorize_coinjoin( coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> str | None: - resp = client.call( + resp = session.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 4cbc635f1f..a945cc9b10 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -35,7 +35,7 @@ from . import messages as m from . import tools if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -818,7 +818,7 @@ def _get_collateral_inputs_items( def get_address( - client: "TrezorClient", + session: "Session", address_parameters: m.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], @@ -826,7 +826,7 @@ def get_address( derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, chunkify: bool = False, ) -> str: - return client.call( + return session.call( m.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, @@ -840,12 +840,12 @@ def get_address( def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, show_display: bool = False, ) -> m.CardanoPublicKey: - return client.call( + return session.call( m.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, @@ -856,12 +856,12 @@ def get_public_key( def get_native_script_hash( - client: "TrezorClient", + session: "Session", native_script: m.CardanoNativeScript, display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, ) -> m.CardanoNativeScriptHash: - return client.call( + return session.call( m.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, @@ -872,7 +872,7 @@ def get_native_script_hash( def sign_tx( - client: "TrezorClient", + session: "Session", signing_mode: m.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], @@ -907,7 +907,7 @@ def sign_tx( signing_mode, ) - response = client.call( + response = session.call( m.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), @@ -942,12 +942,12 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call( + auxiliary_data_supplement = session.call( auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement ) if ( @@ -958,25 +958,25 @@ def sign_tx( auxiliary_data_supplement.__dict__ ) - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) for tx_item in chain( _get_mint_items(mint), _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item, expect=m.CardanoTxItemAck) + response = session.call(tx_item, expect=m.CardanoTxItemAck) for reference_input in reference_inputs: - response = client.call(reference_input, expect=m.CardanoTxItemAck) + response = session.call(reference_input, expect=m.CardanoTxItemAck) sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request, expect=m.CardanoTxWitnessResponse) + response = session.call(witness_request, expect=m.CardanoTxWitnessResponse) sign_tx_response["witnesses"].append( { "type": response.type, @@ -986,9 +986,9 @@ def sign_tx( } ) - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) + response = session.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) return sign_tx_response diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 4e432bd012..9d4a9c0f39 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -13,28 +13,24 @@ # # You should have received a copy of the License along with this library. # If not, see . - from __future__ import annotations import logging import os -import warnings -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import typing as t +from enum import IntEnum -from mnemonic import Mnemonic +from . import mapping, messages, models +from .mapping import ProtobufMapping +from .tools import parse_path +from .transport import Transport, get_transport +from .transport.thp.channel_data import ChannelData +from .transport.thp.protocol_and_channel import ProtocolAndChannel +from .transport.thp.protocol_v1 import ProtocolV1 +from .transport.thp.protocol_v2 import ProtocolV2 -from . import exceptions, mapping, messages, models -from .log import DUMP_BYTES -from .messages import Capability -from .protobuf import MessageType -from .tools import parse_path, session - -if TYPE_CHECKING: - from .transport import Transport - from .ui import TrezorClientUI - -UI = TypeVar("UI", bound="TrezorClientUI") -MT = TypeVar("MT", bound=MessageType) +if t.TYPE_CHECKING: + from .transport.session import Session LOG = logging.getLogger(__name__) @@ -51,8 +47,205 @@ Or visit https://suite.trezor.io/ """.strip() +LOG = logging.getLogger(__name__) + + +class ProtocolVersion(IntEnum): + UNKNOWN = 0x00 + PROTOCOL_V1 = 0x01 # Codec + PROTOCOL_V2 = 0x02 # THP + + +class TrezorClient: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + _seedless_session: Session | None = None + _features: messages.Features | None = None + _protocol_version: int + _setup_pin: str | None = None # Should by used only by conftest + + def __init__( + self, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: ProtocolAndChannel | None = None, + ) -> None: + self._is_invalidated: bool = False + self.transport = transport + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING + else: + self.mapping = protobuf_mapping + if protocol is None: + self.protocol = self._get_protocol() + else: + self.protocol = protocol + self.protocol.mapping = self.mapping + + if isinstance(self.protocol, ProtocolV1): + self._protocol_version = ProtocolVersion.PROTOCOL_V1 + elif isinstance(self.protocol, ProtocolV2): + self._protocol_version = ProtocolVersion.PROTOCOL_V2 + else: + self._protocol_version = ProtocolVersion.UNKNOWN + + @classmethod + def resume( + cls, + transport: Transport, + channel_data: ChannelData, + protobuf_mapping: ProtobufMapping | None = None, + ) -> TrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + protocol_v1 = ProtocolV1(transport, protobuf_mapping) + if channel_data.protocol_version_major == 2: + try: + protocol_v1.write(messages.Ping(message="Sanity check - to resume")) + except Exception as e: + print(type(e)) + response = protocol_v1.read() + if ( + isinstance(response, messages.Failure) + and response.code == messages.FailureType.InvalidProtocol + ): + protocol = ProtocolV2(transport, protobuf_mapping, channel_data) + protocol.write(0, messages.Ping()) + response = protocol.read(0) + if not isinstance(response, messages.Success): + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + LOG.debug("Protocol V2 detected - can be resumed") + else: + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + else: + protocol = ProtocolV1(transport, protobuf_mapping, channel_data) + return TrezorClient(transport, protobuf_mapping, protocol) + + def get_session( + self, + passphrase: str | object | None = None, + derive_cardano: bool = False, + session_id: int = 0, + ) -> Session: + """ + Returns initialized session (with derived seed). + + Will fail if the device is not initialized + """ + from .transport.session import SessionV1, SessionV2 + + if isinstance(self.protocol, ProtocolV1): + if passphrase is None: + passphrase = "" + return SessionV1.new(self, passphrase, derive_cardano) + if isinstance(self.protocol, ProtocolV2): + assert isinstance(passphrase, str) or passphrase is None + return SessionV2.new(self, passphrase, derive_cardano, session_id) + raise NotImplementedError # TODO + + def resume_session(self, session: Session): + """ + Note: this function potentially modifies the input session. + """ + from .debuglink import SessionDebugWrapper + from .transport.session import SessionV1, SessionV2 + + if isinstance(session, SessionDebugWrapper): + session = session._session + + if isinstance(session, SessionV2): + return session + elif isinstance(session, SessionV1): + session.init_session() + return session + + else: + raise NotImplementedError + + def get_seedless_session(self, new_session: bool = False) -> Session: + from .transport.session import SessionV1, SessionV2 + + if not new_session and self._seedless_session is not None: + return self._seedless_session + if isinstance(self.protocol, ProtocolV1): + self._seedless_session = SessionV1.new( + client=self, + passphrase="", + derive_cardano=False, + ) + elif isinstance(self.protocol, ProtocolV2): + self._seedless_session = SessionV2(client=self, id=b"\x00") + assert self._seedless_session is not None + return self._seedless_session + + def invalidate(self) -> None: + self._is_invalidated = True + + @property + def features(self) -> messages.Features: + if self._features is None: + self._features = self.protocol.get_features() + assert self._features is not None + return self._features + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def model(self) -> models.TrezorModel: + f = self.features + model = models.by_name(f.model or "1") + + if model is None: + raise RuntimeError( + "Unsupported Trezor model" + f" (internal_model: {f.internal_model}, model: {f.model})" + ) + return model + + @property + def version(self) -> tuple[int, int, int]: + f = self.features + ver = ( + f.major_version, + f.minor_version, + f.patch_version, + ) + return ver + + @property + def is_invalidated(self) -> bool: + return self._is_invalidated + + def refresh_features(self) -> None: + self.protocol.update_features() + self._features = self.protocol.get_features() + + def _get_protocol(self) -> ProtocolAndChannel: + self.transport.open() + + protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) + + protocol.write(messages.Initialize()) + + response = protocol.read() + self.transport.close() + if isinstance(response, messages.Failure): + if response.code == messages.FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol = ProtocolV2(self.transport, self.mapping) + return protocol + + def get_default_client( - path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any + path: t.Optional[str] = None, + **kwargs: t.Any, ) -> "TrezorClient": """Get a client for a connected Trezor device. @@ -62,436 +255,10 @@ def get_default_client( the value of TREZOR_PATH env variable, or finds first connected Trezor. If no UI is supplied, instantiates the default CLI UI. """ - from .transport import get_transport - from .ui import ClickUI if path is None: path = os.getenv("TREZOR_PATH") transport = get_transport(path, prefix_search=True) - if ui is None: - ui = ClickUI() - return TrezorClient(transport, ui, **kwargs) - - -class TrezorClient(Generic[UI]): - """Trezor client, a connection to a Trezor device. - - This class allows you to manage connection state, send and receive protobuf - messages, handle user interactions, and perform some generic tasks - (send a cancel message, initialize or clear a session, ping the device). - """ - - model: models.TrezorModel - transport: "Transport" - session_id: Optional[bytes] - ui: UI - features: messages.Features - - def __init__( - self, - transport: "Transport", - ui: UI, - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, - model: Optional[models.TrezorModel] = None, - _init_device: bool = True, - ) -> None: - """Create a TrezorClient instance. - - You have to provide a `transport`, i.e., a raw connection to the device. You can - use `trezorlib.transport.get_transport` to find one. - - You have to provide a UI implementation for the three kinds of interaction: - - button request (notify the user that their interaction is needed) - - PIN request (on T1, ask the user to input numbers for a PIN matrix) - - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for - details. - - You can supply a `session_id` you might have saved in the previous session. If - you do, the user might not need to enter their passphrase again. - - You can provide Trezor model information. If not provided, it is detected from - the model name reported at initialization time. - - By default, the instance will open a connection to the Trezor device, send an - `Initialize` message, set up the `features` field from the response, and connect - to a session. By specifying `_init_device=False`, this step is skipped. Notably, - this means that `client.features` is unset. Use `client.init_device()` or - `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. - Only use this if you are _sure_ that you know what you are doing. This feature - might be removed at any time. - """ - LOG.info(f"creating client instance for device: {transport.get_path()}") - # Here, self.model could be set to None. Unless _init_device is False, it will - # get correctly reconfigured as part of the init_device flow. - self.model = model # type: ignore ["None" is incompatible with "TrezorModel"] - if self.model: - self.mapping = self.model.default_mapping - else: - self.mapping = mapping.DEFAULT_MAPPING - self.transport = transport - self.ui = ui - self.session_counter = 0 - self.session_id = session_id - if _init_device: - self.init_device(session_id=session_id, derive_cardano=derive_cardano) - - def open(self) -> None: - if self.session_counter == 0: - self.transport.begin_session() - self.session_counter += 1 - - def close(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - # TODO call EndSession here? - self.transport.end_session() - - def cancel(self) -> None: - self._raw_write(messages.Cancel()) - - def call_raw(self, msg: MessageType) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self._raw_write(msg) - return self._raw_read() - - def _raw_write(self, msg: MessageType) -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - msg_type, msg_bytes = self.mapping.encode(msg) - LOG.log( - DUMP_BYTES, - f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - self.transport.write(msg_type, msg_bytes) - - def _raw_read(self) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - msg_type, msg_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - msg = self.mapping.decode(msg_type, msg_bytes) - LOG.debug( - f"received message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - return msg - - def _callback_pin(self, msg: messages.PinMatrixRequest) -> MessageType: - try: - pin = self.ui.get_pin(msg.type) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - self.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - - resp = self.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise exceptions.PinException(resp.code, resp.message) - else: - return resp - - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> MessageType: - available_on_device = Capability.PassphraseEntry in self.features.capabilities - - def send_passphrase( - passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> MessageType: - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = self.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - self.session_id = resp.state - resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = self.ui.get_passphrase(available_on_device=available_on_device) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - self.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - self.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def _callback_button(self, msg: messages.ButtonRequest) -> MessageType: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - self._raw_write(messages.ButtonAck()) - self.ui.button_request(msg) - return self._raw_read() - - @session - def call(self, msg: MessageType, expect: type[MT] = MessageType) -> MT: - self.check_firmware_version() - resp = self.call_raw(msg) - while True: - if isinstance(resp, messages.PinMatrixRequest): - resp = self._callback_pin(resp) - elif isinstance(resp, messages.PassphraseRequest): - resp = self._callback_passphrase(resp) - elif isinstance(resp, messages.ButtonRequest): - resp = self._callback_button(resp) - elif isinstance(resp, messages.Failure): - if resp.code == messages.FailureType.ActionCancelled: - raise exceptions.Cancelled - raise exceptions.TrezorFailure(resp) - elif not isinstance(resp, expect): - raise exceptions.UnexpectedMessageError(expect, resp) - else: - return resp - - def _refresh_features(self, features: messages.Features) -> None: - """Update internal fields based on passed-in Features message.""" - - if not self.model: - # Trezor Model One bootloader 1.8.0 or older does not send model name - model = models.by_internal_name(features.internal_model) - if model is None: - model = models.by_name(features.model or "1") - if model is None: - raise RuntimeError( - "Unsupported Trezor model" - f" (internal_model: {features.internal_model}, model: {features.model})" - ) - self.model = model - - if features.vendor not in self.model.vendors: - raise RuntimeError("Unsupported device") - - self.features = features - self.version = ( - self.features.major_version, - self.features.minor_version, - self.features.patch_version, - ) - self.check_firmware_version(warn_only=True) - if self.features.session_id is not None: - self.session_id = self.features.session_id - self.features.session_id = None - - @session - def refresh_features(self) -> messages.Features: - """Reload features from the device. - - Should be called after changing settings or performing operations that affect - device state. - """ - resp = self.call_raw(messages.GetFeatures()) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to GetFeatures") - self._refresh_features(resp) - return resp - - @session - def init_device( - self, - *, - session_id: Optional[bytes] = None, - new_session: bool = False, - derive_cardano: Optional[bool] = None, - ) -> Optional[bytes]: - """Initialize the device and return a session ID. - - You can optionally specify a session ID. If the session still exists on the - device, the same session ID will be returned and the session is resumed. - Otherwise a different session ID is returned. - - Specify `new_session=True` to open a fresh session. Since firmware version - 1.9.0/2.3.0, the previous session will remain cached on the device, and can be - resumed by calling `init_device` again with the appropriate session ID. - - If neither `new_session` nor `session_id` is specified, the current session ID - will be reused. If no session ID was cached, a new session ID will be allocated - and returned. - - # Version notes: - - Trezor One older than 1.9.0 does not have session management. Optional arguments - have no effect and the function returns None - - Trezor T older than 2.3.0 does not have session cache. Requesting a new session - will overwrite the old one. In addition, this function will always return None. - A valid session_id can be obtained from the `session_id` attribute, but only - after a passphrase-protected call is performed. You can use the following code: - - >>> client.init_device() - >>> client.ensure_unlocked() - >>> valid_session_id = client.session_id - """ - if new_session: - self.session_id = None - elif session_id is not None: - self.session_id = session_id - - resp = self.call_raw( - messages.Initialize( - session_id=self.session_id, - derive_cardano=derive_cardano, - ) - ) - if isinstance(resp, messages.Failure): - # can happen if `derive_cardano` does not match the current session - raise exceptions.TrezorFailure(resp) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to Initialize") - - if self.session_id is not None and resp.session_id == self.session_id: - LOG.info("Successfully resumed session") - elif session_id is not None: - LOG.info("Failed to resume session") - - # TT < 2.3.0 compatibility: - # _refresh_features will clear out the session_id field. We want this function - # to return its value, so that callers can rely on it being either a valid - # session_id, or None if we can't do that. - # Older TT FW does not report session_id in Features and self.session_id might - # be invalid because TT will not allocate a session_id until a passphrase - # exchange happens. - reported_session_id = resp.session_id - self._refresh_features(resp) - return reported_session_id - - def is_outdated(self) -> bool: - if self.features.bootloader_mode: - return False - return self.version < self.model.minimum_version - - def check_firmware_version(self, warn_only: bool = False) -> None: - if self.is_outdated(): - if warn_only: - warnings.warn("Firmware is out of date", stacklevel=2) - else: - raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - - def ping(self, msg: str, button_protection: bool = False) -> str: - # We would like ping to work on any valid TrezorClient instance, but - # due to the protection modes, we need to go through self.call, and that will - # raise an exception if the firmware is too old. - # So we short-circuit the simplest variant of ping with call_raw. - if not button_protection: - # XXX this should be: `with self:` - try: - self.open() - resp = self.call_raw(messages.Ping(message=msg)) - if isinstance(resp, messages.ButtonRequest): - # device is PIN-locked. - # respond and hope for the best - resp = self._callback_button(resp) - resp = messages.Success.ensure_isinstance(resp) - assert resp.message is not None - return resp.message - finally: - self.close() - - resp = self.call( - messages.Ping(message=msg, button_protection=button_protection), - expect=messages.Success, - ) - assert resp.message is not None - return resp.message - - def get_device_id(self) -> Optional[str]: - return self.features.device_id - - @session - def lock(self, *, _refresh_features: bool = True) -> None: - """Lock the device. - - If the device does not have a PIN configured, this will do nothing. - Otherwise, a lock screen will be shown and the device will prompt for PIN - before further actions. - - This call does _not_ invalidate passphrase cache. If passphrase is in use, - the device will not prompt for it after unlocking. - - To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate - passphrase cache, use `clear_session()`. - """ - # Private argument _refresh_features can be used internally to avoid - # refreshing in cases where we will refresh soon anyway. This is used - # in TrezorClient.clear_session() - self.call(messages.LockDevice()) - if _refresh_features: - self.refresh_features() - - @session - def ensure_unlocked(self) -> None: - """Ensure the device is unlocked and a passphrase is cached. - - If the device is locked, this will prompt for PIN. If passphrase is enabled - and no passphrase is cached for the current session, the device will also - prompt for passphrase. - - After calling this method, further actions on the device will not prompt for - PIN or passphrase until the device is locked or the session becomes invalid. - """ - from .btc import get_address - - get_address(self, "Testnet", PASSPHRASE_TEST_PATH) - self.refresh_features() - - def end_session(self) -> None: - """Close the current session and clear cached passphrase. - - The session will become invalid until `init_device()` is called again. - If passphrase is enabled, further actions will prompt for it again. - - This is a no-op in bootloader mode, as it does not support session management. - """ - # since: 2.3.4, 1.9.4 - try: - if not self.features.bootloader_mode: - self.call(messages.EndSession()) - except exceptions.TrezorFailure: - # A failure most likely means that the FW version does not support - # the EndSession call. We ignore the failure and clear the local session_id. - # The client-side end result is identical. - pass - self.session_id = None - - @session - def clear_session(self) -> None: - """Lock the device and present a fresh session. - - The current session will be invalidated and a new one will be started. If the - device has PIN enabled, it will become locked. - - Equivalent to calling `lock()`, `end_session()` and `init_device()`. - """ - self.lock(_refresh_features=False) - self.end_session() - self.init_device(new_session=True) + return TrezorClient(transport, **kwargs) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 0a2096993b..0c63c30b65 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -21,55 +21,55 @@ import logging import re import textwrap import time +import typing as t from contextlib import contextmanager from copy import deepcopy from datetime import datetime from enum import Enum, IntEnum, auto from itertools import zip_longest from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - Sequence, - Tuple, - Union, -) from mnemonic import Mnemonic -from . import mapping, messages, models, protobuf -from .client import TrezorClient -from .exceptions import TrezorFailure, UnexpectedMessageError +from . import btc, mapping, messages, models, protobuf +from .client import ( + MAX_PASSPHRASE_LENGTH, + MAX_PIN_LENGTH, + PASSPHRASE_ON_DEVICE, + TrezorClient, +) +from .exceptions import Cancelled, PinException, TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES -from .messages import DebugWaitType +from .messages import Capability, DebugWaitType +from .protobuf import MessageType +from .tools import parse_path +from .transport.session import Session, SessionV1 +from .transport.thp.protocol_v1 import ProtocolV1 -if TYPE_CHECKING: +if t.TYPE_CHECKING: from typing_extensions import Protocol from .messages import PinMatrixRequestType from .transport import Transport - ExpectedMessage = Union[ - protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" + ExpectedMessage = t.Union[ + protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter" ] - AnyDict = Dict[str, Any] + AnyDict = t.Dict[str, t.Any] class InputFunc(Protocol): + def __call__( self, hold_ms: int | None = None, ) -> "None": ... - InputFlowType = Generator[None, messages.ButtonRequest, None] + InputFlowType = t.Generator[None, messages.ButtonRequest, None] EXPECTED_RESPONSES_CONTEXT_LINES = 3 +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") LOG = logging.getLogger(__name__) @@ -107,11 +107,11 @@ class UnstructuredJSONReader: except json.JSONDecodeError: self.dict = {} - def top_level_value(self, key: str) -> Any: + def top_level_value(self, key: str) -> t.Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_objects_with_key_and_value(self, key: str, value: t.Any) -> list[AnyDict]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if data.get(key) == value: yield data @@ -124,7 +124,7 @@ class UnstructuredJSONReader: return list(recursively_find(self.dict)) def find_unique_object_with_key_and_value( - self, key: str, value: Any + self, key: str, value: t.Any ) -> AnyDict | None: objects = self.find_objects_with_key_and_value(key, value) if not objects: @@ -132,8 +132,10 @@ class UnstructuredJSONReader: assert len(objects) == 1 return objects[0] - def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_values_by_key( + self, key: str, only_type: type | None = None + ) -> list[t.Any]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if key in data: yield data[key] @@ -151,8 +153,8 @@ class UnstructuredJSONReader: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: type | None = None - ) -> Any: + self, key: str, default: t.Any, only_type: type | None = None + ) -> t.Any: values = self.find_values_by_key(key, only_type=only_type) if not values: return default @@ -163,7 +165,7 @@ class UnstructuredJSONReader: class LayoutContent(UnstructuredJSONReader): """Contains helper functions to extract specific parts of the layout.""" - def __init__(self, json_tokens: Sequence[str]) -> None: + def __init__(self, json_tokens: t.Sequence[str]) -> None: json_str = "".join(json_tokens) super().__init__(json_str) @@ -429,6 +431,7 @@ class DebugLink: self.allow_interactions = auto_interact self.mapping = mapping.DEFAULT_MAPPING + self.protocol = ProtocolV1(self.transport, self.mapping) # To be set by TrezorClientDebugLink (is not known during creation time) self.model: models.TrezorModel | None = None self.version: tuple[int, int, int] = (0, 0, 0) @@ -471,10 +474,16 @@ class DebugLink: return LayoutType.from_model(self.model) def open(self) -> None: - self.transport.begin_session() + self.transport.open() + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_begin_session() def close(self) -> None: - self.transport.end_session() + pass + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_end_session() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -491,15 +500,10 @@ class DebugLink: DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - self.transport.write(msg_type, msg_bytes) + self.protocol.write(msg) def _read(self) -> protobuf.MessageType: - ret_type, ret_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}", - ) - msg = self.mapping.decode(ret_type, ret_bytes) + msg = self.protocol.read() # Collapse tokens to make log use less lines. msg_for_log = msg @@ -513,7 +517,7 @@ class DebugLink: ) return msg - def _call(self, msg: protobuf.MessageType) -> Any: + def _call(self, msg: protobuf.MessageType) -> t.Any: self._write(msg) return self._read() @@ -531,6 +535,25 @@ class DebugLink: raise TrezorFailure(result) return result + def pairing_info( + self, + thp_channel_id: bytes | None = None, + handshake_hash: bytes | None = None, + nfc_secret_host: bytes | None = None, + ) -> messages.DebugLinkPairingInfo: + result = self._call( + messages.DebugLinkGetPairingInfo( + channel_id=thp_channel_id, + handshake_hash=handshake_hash, + nfc_secret_host=nfc_secret_host, + ) + ) + while not isinstance(result, (messages.Failure, messages.DebugLinkPairingInfo)): + result = self._read() + if isinstance(result, messages.Failure): + raise TrezorFailure(result) + return result + def read_layout(self, wait: bool | None = None) -> LayoutContent: """ Force waiting for the layout by setting `wait=True`. Force not waiting by @@ -547,7 +570,7 @@ class DebugLink: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: # Next layout change will be caused by external event - # (e.g. device being auto-locked or as a result of device_handler.run(xxx)) + # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx)) # and not by our debug actions/decisions. # Resetting the debug state so we wait for the next layout change # (and do not return the current state). @@ -562,7 +585,7 @@ class DebugLink: return LayoutContent(obj.tokens) @contextmanager - def wait_for_layout_change(self) -> Iterator[None]: + def wait_for_layout_change(self) -> t.Iterator[None]: # make sure some current layout is up by issuing a dummy GetState self.state() @@ -615,7 +638,7 @@ class DebugLink: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[str | None, int | None]: + def read_recovery_word(self) -> t.Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -671,7 +694,7 @@ class DebugLink: """Send text input to the device. See `_decision` for more details.""" self._decision(messages.DebugLinkDecision(input=word)) - def click(self, click: Tuple[int, int], hold_ms: int | None = None) -> None: + def click(self, click: t.Tuple[int, int], hold_ms: int | None = None) -> None: """Send a click to the device. See `_decision` for more details.""" x, y = click self._decision(messages.DebugLinkDecision(x=x, y=y, hold_ms=hold_ms)) @@ -794,10 +817,10 @@ class DebugUI: self.clear() def clear(self) -> None: - self.pins: Iterator[str] | None = None + self.pins: t.Iterator[str] | None = None self.passphrase = "" - self.input_flow: Union[ - Generator[None, messages.ButtonRequest, None], object, None + self.input_flow: t.Union[ + t.Generator[None, messages.ButtonRequest, None], object, None ] = None def _default_input_flow(self, br: messages.ButtonRequest) -> None: @@ -829,7 +852,7 @@ class DebugUI: raise AssertionError("input flow ended prematurely") else: try: - assert isinstance(self.input_flow, Generator) + assert isinstance(self.input_flow, t.Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE @@ -851,12 +874,15 @@ class DebugUI: class MessageFilter: - def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None: + + def __init__( + self, message_type: t.Type[protobuf.MessageType], **fields: t.Any + ) -> None: self.message_type = message_type - self.fields: Dict[str, Any] = {} + self.fields: t.Dict[str, t.Any] = {} self.update_fields(**fields) - def update_fields(self, **fields: Any) -> "MessageFilter": + def update_fields(self, **fields: t.Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -904,7 +930,7 @@ class MessageFilter: return True def to_string(self, maxwidth: int = 80) -> str: - fields: list[Tuple[str, str]] = [] + fields: list[t.Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -934,7 +960,7 @@ class MessageFilter: class MessageFilterGenerator: - def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: + def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -942,6 +968,245 @@ class MessageFilterGenerator: message_filters = MessageFilterGenerator() +class SessionDebugWrapper(Session): + def __init__(self, session: Session) -> None: + self._session = session + self.reset_debug_features() + if isinstance(session, SessionDebugWrapper): + raise Exception("Cannot wrap already wrapped session!") + + @property + def protocol_version(self) -> int: + return self.client.protocol_version + + @property + def client(self) -> TrezorClientDebugLink: + assert isinstance(self._session.client, TrezorClientDebugLink) + return self._session.client + + @property + def id(self) -> bytes: + return self._session.id + + def _write(self, msg: t.Any) -> None: + print("writing message:", msg.__class__.__name__) + self._session._write(self._filter_message(msg)) + + def _read(self) -> t.Any: + resp = self._filter_message(self._session._read()) + print("reading message:", resp.__class__.__name__) + if self.actual_responses is not None: + self.actual_responses.append(resp) + return resp + + def set_expected_responses( + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> session.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + + def lock(self, *, _refresh_features: bool = True) -> None: + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate + passphrase cache, use `clear_session()`. + """ + # TODO update the documentation above + # Private argument _refresh_features can be used internally to avoid + # refreshing in cases where we will refresh soon anyway. This is used + # in TrezorClient.clear_session() + self.call(messages.LockDevice()) + if _refresh_features: + self.refresh_features() + + def cancel(self) -> None: + self._write(messages.Cancel()) + + def ensure_unlocked(self) -> None: + btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) + self.refresh_features() + + def set_filter( + self, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ) -> None: + """Configure a filter function for a specified message type. + + The `callback` must be a function that accepts a protobuf message, and returns + a (possibly modified) protobuf message of the same type. Whenever a message + is sent or received that matches `message_type`, `callback` is invoked on the + message and its result is substituted for the original. + + Useful for test scenarios with an active malicious actor on the wire. + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + self.filters[message_type] = callback + + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: + message_type = msg.__class__ + callback = self.filters.get(message_type) + if callable(callback): + return callback(deepcopy(msg)) + else: + return msg + + def reset_debug_features(self) -> None: + """Prepare the debugging session for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.in_with_statement = False + self.expected_responses: list[MessageFilter] | None = None + self.actual_responses: list[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ] = {} + self.button_callback = self.client.button_callback + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self._session.passphrase_callback + self.passphrase = self._session.passphrase + + def __enter__(self) -> "SessionDebugWrapper": + # For usage in with/expected_responses + if self.in_with_statement: + raise RuntimeError("Do not nest!") + self.in_with_statement = True + return self + + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + + # grab a copy of the inputflow generator to raise an exception through it + if isinstance(self.client.ui, DebugUI): + input_flow = self.client.ui.input_flow + else: + input_flow = None + + self.reset_debug_features() + + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) + if isinstance(input_flow, t.Generator): + # Ensure that the input flow is exhausted + try: + input_flow.throw( + AssertionError("input flow continues past end of test") + ) + except StopIteration: + pass + + elif isinstance(input_flow, t.Generator): + # Propagate the exception through the input flow, so that we see in + # traceback where it is stuck. + input_flow.throw(exc_type, value, traceback) + + @classmethod + def _verify_responses( + cls, + expected: list[MessageFilter] | None, + actual: list[protobuf.MessageType] | None, + ) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + if expected is None and actual is None: + return + + assert expected is not None + assert actual is not None + + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = cls._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) + + if act is None: + output = cls._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) + + if not exp.match(act): + output = cls._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) + + @staticmethod + def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) + output: list[str] = [] + output.append("Expected responses:") + if start_at > 0: + output.append(f" (...{start_at} previous responses omitted)") + for i in range(start_at, stop_at): + exp = expected[i] + prefix = " " if i != current else ">>> " + output.append(textwrap.indent(exp.to_string(), prefix)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") + + output.append("") + return output + + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses # and other functionality for unit tests @@ -967,23 +1232,34 @@ class TrezorClientDebugLink(TrezorClient): raise # set transport explicitly so that sync_responses can work - self.transport = transport + super().__init__(transport) - self.reset_debug_features() + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) + + self.reset_debug_features(new_seedless_session=True) self.sync_responses() - super().__init__(transport, ui=self.ui) # So that we can choose right screenshotting logic (T1 vs TT) # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version + self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: return self.debug.layout_type - def reset_debug_features(self) -> None: - """Prepare the debugging client for a new testcase. + def get_new_client(self) -> TrezorClientDebugLink: + new_client = TrezorClientDebugLink( + self.transport, self.debug.allow_interactions + ) + new_client.debug.screenshot_recording_dir = self.debug.screenshot_recording_dir + return new_client + + def reset_debug_features(self, new_seedless_session: bool = False) -> None: + """ + Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ @@ -991,30 +1267,139 @@ class TrezorClientDebugLink(TrezorClient): self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None - self.filters: dict[ - type[protobuf.MessageType], - Callable[[protobuf.MessageType], protobuf.MessageType] | None, + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} + if new_seedless_session: + self._seedless_session = self.get_seedless_session(new_session=True) + + @property + def button_callback(self): + + def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() + + return _callback_button + + @property + def pin_callback(self): + + def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp + + return _callback_pin + + @property + def passphrase_callback(self): + def _callback_passphrase( + session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> MessageType: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + # session.session_id = resp.state + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if session.passphrase is None and isinstance(session, SessionV1): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + else: + passphrase = session.passphrase + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + return _callback_passphrase def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" - if self.session_counter == 0: - self.open() + # if self.session_counter == 0: + # self.open() + # TODO check if is this needed def open(self) -> None: - super().open() - if self.session_counter == 1: - self.debug.open() + pass + # TODO is this needed? + # self.debug.open() def close(self) -> None: - if self.session_counter == 1: - self.debug.close() - super().close() + pass + # TODO is this needed? + # self.debug.close() + + def lock(self) -> None: + s = SessionDebugWrapper(self.get_seedless_session()) + s.lock() + + def get_session( + self, + passphrase: str | object | None = "", + derive_cardano: bool = False, + session_id: int = 0, + ) -> Session: + if isinstance(passphrase, str): + passphrase = Mnemonic.normalize_string(passphrase) + return super().get_session(passphrase, derive_cardano, session_id) def set_filter( self, - message_type: type[protobuf.MessageType], - callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ) -> None: """Configure a filter function for a specified message type. @@ -1039,7 +1424,7 @@ class TrezorClientDebugLink(TrezorClient): return msg def set_input_flow( - self, input_flow: InputFlowType | Callable[[], InputFlowType] + self, input_flow: InputFlowType | t.Callable[[], InputFlowType] ) -> None: """Configure a sequence of input events for the current with-block. @@ -1095,7 +1480,7 @@ class TrezorClientDebugLink(TrezorClient): self.in_with_statement = True return self - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # copy expected/actual responses before clearing them @@ -1108,21 +1493,23 @@ class TrezorClientDebugLink(TrezorClient): else: input_flow = None - self.reset_debug_features() + self.reset_debug_features(new_seedless_session=False) if exc_type is None: # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, Generator): + elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) def set_expected_responses( self, - expected: Sequence[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]], + expected: t.Sequence[ + t.Union["ExpectedMessage", t.Tuple[bool, "ExpectedMessage"]] + ], ) -> None: """Set a sequence of expected responses to client calls. @@ -1161,7 +1548,7 @@ class TrezorClientDebugLink(TrezorClient): ] self.actual_responses = [] - def use_pin_sequence(self, pins: Iterable[str]) -> None: + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ @@ -1169,6 +1556,7 @@ class TrezorClientDebugLink(TrezorClient): def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" + self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1178,15 +1566,14 @@ class TrezorClientDebugLink(TrezorClient): def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - resp = super()._raw_read() + resp = self.get_seedless_session()._read() resp = self._filter_message(resp) if self.actual_responses is not None: self.actual_responses.append(resp) return resp def _raw_write(self, msg: protobuf.MessageType) -> None: - return super()._raw_write(self._filter_message(msg)) + return self.get_seedless_session()._write(self._filter_message(msg)) @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: @@ -1256,23 +1643,25 @@ class TrezorClientDebugLink(TrezorClient): # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() + # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) + self.transport.open() try: - self.transport.write(*cancel_msg) - + # self.protocol.write(messages.Cancel()) message = "SYNC" + secrets.token_hex(8) - ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) - self.transport.write(*ping_msg) + self.get_seedless_session()._write(messages.Ping(message=message)) resp = None while resp != messages.Success(message=message): - msg_id, msg_bytes = self.transport.read() try: - resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) + resp = self.get_seedless_session()._read() + + raise Exception + except Exception: pass + finally: - self.transport.end_session() + pass # TODO fix + # self.transport.end_session(self.session_id or b"") def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() @@ -1285,8 +1674,8 @@ class TrezorClientDebugLink(TrezorClient): def load_device( - client: "TrezorClient", - mnemonic: Union[str, Iterable[str]], + session: "Session", + mnemonic: str | t.Iterable[str], pin: str | None, passphrase_protection: bool, label: str | None, @@ -1299,12 +1688,12 @@ def load_device( mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call device.wipe() and try again." ) - client.call( + session.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1316,18 +1705,18 @@ def load_device( ), expect=messages.Success, ) - client.init_device() + session.refresh_features() # keep the old name for compatibility load_device_by_mnemonic = load_device -def prodtest_t1(client: "TrezorClient") -> None: - if client.features.bootloader_mode is not True: +def prodtest_t1(session: "Session") -> None: + if session.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - client.call( + session.call( messages.ProdTestT1( payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" ), @@ -1337,8 +1726,8 @@ def prodtest_t1(client: "TrezorClient") -> None: def record_screen( debug_client: "TrezorClientDebugLink", - directory: Union[str, None], - report_func: Union[Callable[[str], None], None] = None, + directory: str | None, + report_func: t.Callable[[str], None] | None = None, ) -> None: """Record screen changes into a specified directory. @@ -1383,5 +1772,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: return debug_client.features.fw_vendor == "EMULATOR" -def optiga_set_sec_max(client: "TrezorClient") -> None: - client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) +def optiga_set_sec_max(session: "Session") -> None: + session.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index c08d485ed0..a3b24c247d 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -28,16 +28,10 @@ from slip10 import SLIP10 from . import messages from .exceptions import Cancelled, TrezorException -from .tools import ( - Address, - _deprecation_retval_helper, - _return_success, - parse_path, - session, -) +from .tools import Address, _deprecation_retval_helper, _return_success, parse_path if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session RECOVERY_BACK = "\x08" # backspace character, sent literally @@ -46,9 +40,8 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1) ENTROPY_CHECK_MIN_VERSION = (2, 8, 7) -@session def apply_settings( - client: "TrezorClient", + session: "Session", label: Optional[str] = None, language: Optional[str] = None, use_passphrase: Optional[bool] = None, @@ -79,13 +72,13 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings, expect=messages.Success) - client.refresh_features() + out = session.call(settings, expect=messages.Success) + session.refresh_features() return _return_success(out) def _send_language_data( - client: "TrezorClient", + session: "Session", request: "messages.TranslationDataRequest", language_data: bytes, ) -> None: @@ -95,69 +88,63 @@ def _send_language_data( data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] - response = client.call(messages.TranslationDataAck(data_chunk=chunk)) + response = session.call(messages.TranslationDataAck(data_chunk=chunk)) -@session def change_language( - client: "TrezorClient", + session: "Session", language_data: bytes, show_display: bool | None = None, ) -> str | None: data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) - response = client.call(msg) + response = session.call(msg) if data_length > 0: response = messages.TranslationDataRequest.ensure_isinstance(response) - _send_language_data(client, response, language_data) + _send_language_data(session, response, language_data) else: messages.Success.ensure_isinstance(response) - client.refresh_features() # changing the language in features + session.refresh_features() # changing the language in features return _return_success(messages.Success(message="Language changed.")) -@session -def apply_flags(client: "TrezorClient", flags: int) -> str | None: - out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success) - client.refresh_features() +def apply_flags(session: "Session", flags: int) -> str | None: + out = session.call(messages.ApplyFlags(flags=flags), expect=messages.Success) + session.refresh_features() return _return_success(out) -@session -def change_pin(client: "TrezorClient", remove: bool = False) -> str | None: - ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success) - client.refresh_features() +def change_pin(session: "Session", remove: bool = False) -> str | None: + ret = session.call(messages.ChangePin(remove=remove), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None: - ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) - client.refresh_features() +def change_wipe_code(session: "Session", remove: bool = False) -> str | None: + ret = session.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType + session: "Session", operation: messages.SdProtectOperationType ) -> str | None: - ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success) - client.refresh_features() + ret = session.call(messages.SdProtect(operation=operation), expect=messages.Success) + session.refresh_features() return _return_success(ret) -@session -def wipe(client: "TrezorClient") -> str | None: - ret = client.call(messages.WipeDevice(), expect=messages.Success) - if not client.features.bootloader_mode: - client.init_device() +def wipe(session: "Session") -> str | None: + ret = session.call(messages.WipeDevice(), expect=messages.Success) + session.invalidate() + # if not session.features.bootloader_mode: + # session.refresh_features() return _return_success(ret) -@session def recover( - client: "TrezorClient", + session: "Session", word_count: int = 24, passphrase_protection: bool = False, pin_protection: bool = True, @@ -193,13 +180,13 @@ def recover( if type is None: type = messages.RecoveryType.NormalRecovery - if client.features.model == "1" and input_callback is None: + if session.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") if word_count not in (12, 18, 24): raise ValueError("Invalid word count. Use 12/18/24") - if client.features.initialized and type == messages.RecoveryType.NormalRecovery: + if session.features.initialized and type == messages.RecoveryType.NormalRecovery: raise RuntimeError( "Device already initialized. Call device.wipe() and try again." ) @@ -221,20 +208,20 @@ def recover( msg.label = label msg.u2f_counter = u2f_counter - res = client.call(msg) + res = session.call(msg) while isinstance(res, messages.WordRequest): try: assert input_callback is not None inp = input_callback(res.type) - res = client.call(messages.WordAck(word=inp)) + res = session.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(messages.Cancel()) + res = session.call(messages.Cancel()) # check that the result is a Success res = messages.Success.ensure_isinstance(res) # reinitialize the device - client.init_device() + session.refresh_features() return _deprecation_retval_helper(res) @@ -280,7 +267,7 @@ def _seed_from_entropy( def reset( - client: "TrezorClient", + session: "Session", display_random: bool = False, strength: Optional[int] = None, passphrase_protection: bool = False, @@ -313,7 +300,7 @@ def reset( ) setup( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -331,9 +318,8 @@ def _get_external_entropy() -> bytes: return secrets.token_bytes(32) -@session def setup( - client: "TrezorClient", + session: "Session", *, strength: Optional[int] = None, passphrase_protection: bool = True, @@ -388,19 +374,19 @@ def setup( check. """ - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." ) if strength is None: - if client.features.model == "1": + if session.features.model == "1": strength = 256 else: strength = 128 if backup_type is None: - if client.version < SLIP39_EXTENDABLE_MIN_VERSION: + if session.version < SLIP39_EXTENDABLE_MIN_VERSION: # includes Trezor One 1.x.x backup_type = messages.BackupType.Bip39 else: @@ -411,7 +397,7 @@ def setup( paths = [parse_path("m/84h/0h/0h"), parse_path("m/44h/60h/0h")] if entropy_check_count is None: - if client.version < ENTROPY_CHECK_MIN_VERSION: + if session.version < ENTROPY_CHECK_MIN_VERSION: # includes Trezor One 1.x.x entropy_check_count = 0 else: @@ -431,18 +417,18 @@ def setup( ) if entropy_check_count > 0: xpubs = _reset_with_entropycheck( - client, msg, entropy_check_count, paths, _get_entropy + session, msg, entropy_check_count, paths, _get_entropy ) else: - _reset_no_entropycheck(client, msg, _get_entropy) + _reset_no_entropycheck(session, msg, _get_entropy) xpubs = [] - client.init_device() + session.refresh_features() return xpubs def _reset_no_entropycheck( - client: "TrezorClient", + session: "Session", msg: messages.ResetDevice, get_entropy: Callable[[], bytes], ) -> None: @@ -454,12 +440,12 @@ def _reset_no_entropycheck( << Success """ assert msg.entropy_check is False - client.call(msg, expect=messages.EntropyRequest) - client.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) + session.call(msg, expect=messages.EntropyRequest) + session.call(messages.EntropyAck(entropy=get_entropy()), expect=messages.Success) def _reset_with_entropycheck( - client: "TrezorClient", + session: "Session", reset_msg: messages.ResetDevice, entropy_check_count: int, paths: Iterable[Address], @@ -495,7 +481,7 @@ def _reset_with_entropycheck( def get_xpubs() -> list[tuple[Address, str]]: xpubs = [] for path in paths: - resp = client.call( + resp = session.call( messages.GetPublicKey(address_n=path), expect=messages.PublicKey ) xpubs.append((path, resp.xpub)) @@ -524,13 +510,13 @@ def _reset_with_entropycheck( raise TrezorException("Invalid XPUB in entropy check") xpubs = [] - resp = client.call(reset_msg, expect=messages.EntropyRequest) + resp = session.call(reset_msg, expect=messages.EntropyRequest) entropy_commitment = resp.entropy_commitment while True: # provide external entropy for this round external_entropy = get_entropy() - client.call( + session.call( messages.EntropyAck(entropy=external_entropy), expect=messages.EntropyCheckReady, ) @@ -540,7 +526,7 @@ def _reset_with_entropycheck( if entropy_check_count <= 0: # last round, wait for a Success and exit the loop - client.call( + session.call( messages.EntropyCheckContinue(finish=True), expect=messages.Success, ) @@ -549,7 +535,7 @@ def _reset_with_entropycheck( entropy_check_count -= 1 # Next round starts. - resp = client.call( + resp = session.call( messages.EntropyCheckContinue(finish=False), expect=messages.EntropyRequest, ) @@ -570,13 +556,12 @@ def _reset_with_entropycheck( return xpubs -@session def backup( - client: "TrezorClient", + session: "Session", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), ) -> str | None: - ret = client.call( + ret = session.call( messages.BackupDevice( group_threshold=group_threshold, groups=[ @@ -586,37 +571,36 @@ def backup( ), expect=messages.Success, ) - client.refresh_features() + session.refresh_features() return _return_success(ret) -def cancel_authorization(client: "TrezorClient") -> str | None: - ret = client.call(messages.CancelAuthorization(), expect=messages.Success) +def cancel_authorization(session: "Session") -> str | None: + ret = session.call(messages.CancelAuthorization(), expect=messages.Success) return _return_success(ret) -def unlock_path(client: "TrezorClient", n: "Address") -> bytes: - resp = client.call( +def unlock_path(session: "Session", n: "Address") -> bytes: + resp = session.call( messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest ) # Cancel the UnlockPath workflow now that we have the authentication code. try: - client.call(messages.Cancel()) + session.call(messages.Cancel()) except Cancelled: return resp.mac else: raise TrezorException("Unexpected response in UnlockPath flow") -@session def reboot_to_bootloader( - client: "TrezorClient", + session: "Session", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", ) -> str | None: - response = client.call( + response = session.call( messages.RebootToBootloader( boot_command=boot_command, firmware_header=firmware_header, @@ -624,43 +608,38 @@ def reboot_to_bootloader( ) ) if isinstance(response, messages.TranslationDataRequest): - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) return _return_success(messages.Success(message="")) -@session -def show_device_tutorial(client: "TrezorClient") -> str | None: - ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success) +def show_device_tutorial(session: "Session") -> str | None: + ret = session.call(messages.ShowDeviceTutorial(), expect=messages.Success) return _return_success(ret) -@session -def unlock_bootloader(client: "TrezorClient") -> str | None: - ret = client.call(messages.UnlockBootloader(), expect=messages.Success) +def unlock_bootloader(session: "Session") -> str | None: + ret = session.call(messages.UnlockBootloader(), expect=messages.Success) return _return_success(ret) -@session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None: +def set_busy(session: "Session", expiry_ms: Optional[int]) -> str | None: """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) - client.refresh_features() + ret = session.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) + session.refresh_features() return _return_success(ret) -def authenticate( - client: "TrezorClient", challenge: bytes -) -> messages.AuthenticityProof: - return client.call( +def authenticate(session: "Session", challenge: bytes) -> messages.AuthenticityProof: + return session.call( messages.AuthenticateDevice(challenge=challenge), expect=messages.AuthenticityProof, ) -def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None: - ret = client.call(messages.SetBrightness(value=value), expect=messages.Success) +def set_brightness(session: "Session", value: Optional[int] = None) -> str | None: + ret = session.call(messages.SetBrightness(value=value), expect=messages.Success) return _return_success(ret) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index eb491f204c..990adf3855 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,11 +18,11 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, session +from .tools import b58decode if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def name_to_number(name: str) -> int: @@ -319,17 +319,16 @@ def parse_transaction_json( def get_public_key( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> messages.EosPublicKey: - return client.call( + return session.call( messages.EosGetPublicKey(address_n=n, show_display=show_display), expect=messages.EosPublicKey, ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", address: "Address", transaction: dict, chain_id: str, @@ -345,11 +344,11 @@ def sign_tx( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) try: while isinstance(response, messages.EosTxActionRequest): - response = client.call(actions.pop(0)) + response = session.call(actions.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 96ce4d1066..77b071f6b7 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,11 +18,11 @@ import re from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import prepare_message_bytes, session, unharden +from .tools import prepare_message_bytes, unharden if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def int_to_big_endian(value: int) -> bytes: @@ -161,13 +161,13 @@ def network_from_address_n( def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> str: - resp = client.call( + resp = session.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, @@ -181,17 +181,16 @@ def get_address( def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> messages.EthereumPublicKey: - return client.call( + return session.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display), expect=messages.EthereumPublicKey, ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", n: "Address", nonce: int, gas_price: int, @@ -227,13 +226,13 @@ def sign_tx( data, chunk = data[1024:], data[:1024] msg.data_initial_chunk = chunk - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -248,9 +247,8 @@ def sign_tx( return response.signature_v, response.signature_r, response.signature_s -@session def sign_tx_eip1559( - client: "TrezorClient", + session: "Session", n: "Address", *, nonce: int, @@ -283,13 +281,13 @@ def sign_tx_eip1559( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -299,13 +297,13 @@ def sign_tx_eip1559( def sign_message( - client: "TrezorClient", + session: "Session", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> messages.EthereumMessageSignature: - return client.call( + return session.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), @@ -317,7 +315,7 @@ def sign_message( def sign_typed_data( - client: "TrezorClient", + session: "Session", n: "Address", data: Dict[str, Any], *, @@ -333,7 +331,7 @@ def sign_typed_data( metamask_v4_compat=metamask_v4_compat, definitions=definitions, ) - response = client.call(request) + response = session.call(request) # Sending all the types while isinstance(response, messages.EthereumTypedDataStructRequest): @@ -349,7 +347,7 @@ def sign_typed_data( members.append(struct_member) request = messages.EthereumTypedDataStructAck(members=members) - response = client.call(request) + response = session.call(request) # Sending the whole message that should be signed while isinstance(response, messages.EthereumTypedDataValueRequest): @@ -362,7 +360,7 @@ def sign_typed_data( member_typename = data["primaryType"] member_data = data["message"] else: - client.cancel() + # TODO session.cancel() raise exceptions.TrezorException("Root index can only be 0 or 1") # It can be asking for a nested structure (the member path being [X, Y, Z, ...]) @@ -385,20 +383,20 @@ def sign_typed_data( encoded_data = encode_data(member_data, member_typename) request = messages.EthereumTypedDataValueAck(value=encoded_data) - response = client.call(request) + response = session.call(request) return messages.EthereumTypedDataSignature.ensure_isinstance(response) def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: bytes, message: AnyStr, chunkify: bool = False, ) -> bool: try: - client.call( + session.call( messages.EthereumVerifyMessage( address=address, signature=signature, @@ -413,13 +411,13 @@ def verify_message( def sign_typed_data_hash( - client: "TrezorClient", + session: "Session", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, ) -> messages.EthereumTypedDataSignature: - return client.call( + return session.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, diff --git a/python/src/trezorlib/exceptions.py b/python/src/trezorlib/exceptions.py index 99f0048dd3..44d25d7088 100644 --- a/python/src/trezorlib/exceptions.py +++ b/python/src/trezorlib/exceptions.py @@ -65,3 +65,7 @@ class UnexpectedMessageError(TrezorException): self.expected = expected self.actual = actual super().__init__(f"Expected {expected.__name__} but Trezor sent {actual}") + + +class DeviceLockedException(TrezorException): + pass diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index a2618b72db..aaa3b084bf 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -22,37 +22,37 @@ from . import messages from .tools import _return_success if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session -def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]: - return client.call( +def list_credentials(session: "Session") -> Sequence[messages.WebAuthnCredential]: + return session.call( messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials ).credentials -def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None: - ret = client.call( +def add_credential(session: "Session", credential_id: bytes) -> str | None: + ret = session.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id), expect=messages.Success, ) return _return_success(ret) -def remove_credential(client: "TrezorClient", index: int) -> str | None: - ret = client.call( +def remove_credential(session: "Session", index: int) -> str | None: + ret = session.call( messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success ) return _return_success(ret) -def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None: - ret = client.call( +def set_counter(session: "Session", u2f_counter: int) -> str | None: + ret = session.call( messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success ) return _return_success(ret) -def get_next_counter(client: "TrezorClient") -> int: - ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) +def get_next_counter(session: "Session") -> int: + ret = session.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) return ret.u2f_counter diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 4cfc11dd40..56168306bb 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -20,7 +20,6 @@ from hashlib import blake2s from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import session from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware @@ -38,7 +37,7 @@ if True: from .vendor import * # noqa: F401, F403 if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session T = t.TypeVar("T", bound="FirmwareType") @@ -72,20 +71,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]: # ====== Client functions ====== # -@session def update( - client: "TrezorClient", + session: "Session", data: bytes, progress_update: t.Callable[[int], t.Any] = lambda _: None, ): - if client.features.bootloader_mode is False: + if session.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") - resp = client.call(messages.FirmwareErase(length=len(data))) + resp = session.call(messages.FirmwareErase(length=len(data))) # TREZORv1 method if isinstance(resp, messages.Success): - resp = client.call(messages.FirmwareUpload(payload=data)) + resp = session.call(messages.FirmwareUpload(payload=data)) progress_update(len(data)) if isinstance(resp, messages.Success): return @@ -97,7 +95,7 @@ def update( length = resp.length payload = data[resp.offset : resp.offset + length] digest = blake2s(payload).digest() - resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) + resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest)) progress_update(length) if isinstance(resp, messages.Success): @@ -106,7 +104,7 @@ def update( raise RuntimeError(f"Unexpected message {resp}") -def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes: - return client.call( +def get_hash(session: "Session", challenge: t.Optional[bytes]) -> bytes: + return session.call( messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash ).hash diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index d50324d586..04b75f0aa5 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,6 +17,7 @@ from __future__ import annotations import io +import logging from types import ModuleType from typing import Dict, Optional, Tuple, Type, TypeVar @@ -25,6 +26,7 @@ from typing_extensions import Self from . import messages, protobuf T = TypeVar("T") +LOG = logging.getLogger(__name__) class ProtobufMapping: @@ -63,11 +65,21 @@ class ProtobufMapping: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - + LOG.debug("encoding wire type %d", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) return wire_type, buf.getvalue() + def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: + """Serialize a Python protobuf class. + + Returns the byte representation of the protobuf message. + """ + + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return buf.getvalue() + def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: """Deserialize a protobuf message into a Python class.""" cls = self.type_to_class[msg_wire_type] @@ -83,7 +95,9 @@ class ProtobufMapping: mapping = cls() message_types = getattr(module, "MessageType") - for entry in message_types: + thp_message_types = getattr(module, "ThpMessageType") + + for entry in (*message_types, *thp_message_types): msg_class = getattr(module, entry.name, None) if msg_class is None: raise ValueError( diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index 024c3ae696..b163c34b87 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -43,6 +43,10 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 + BufferError = 17 + DeviceIsBusy = 18 FirmwareError = 99 @@ -400,6 +404,34 @@ class TezosBallotType(IntEnum): Pass = 2 +class ThpMessageType(IntEnum): + ThpCreateNewSession = 1000 + ThpPairingRequest = 1006 + ThpPairingRequestApproved = 1007 + ThpSelectMethod = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceTrezor = 1018 + ThpCodeEntryCpaceHostTag = 1019 + ThpCodeEntrySecret = 1020 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcTagHost = 1032 + ThpNfcTagTrezor = 1033 + + +class ThpPairingMethod(IntEnum): + SkipPairing = 1 + CodeEntry = 2 + QrCode = 3 + NFC = 4 + + class MessageType(IntEnum): Initialize = 0 Ping = 1 @@ -500,6 +532,8 @@ class MessageType(IntEnum): DebugLinkWatchLayout = 9006 DebugLinkResetDebugEvents = 9007 DebugLinkOptigaSetSecMax = 9008 + DebugLinkGetPairingInfo = 9009 + DebugLinkPairingInfo = 9010 EthereumGetPublicKey = 450 EthereumPublicKey = 451 EthereumGetAddress = 56 @@ -4203,6 +4237,52 @@ class DebugLinkState(protobuf.MessageType): self.mnemonic_type = mnemonic_type +class DebugLinkGetPairingInfo(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 9009 + FIELDS = { + 1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None), + 3: protobuf.Field("nfc_secret_host", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + channel_id: Optional["bytes"] = None, + handshake_hash: Optional["bytes"] = None, + nfc_secret_host: Optional["bytes"] = None, + ) -> None: + self.channel_id = channel_id + self.handshake_hash = handshake_hash + self.nfc_secret_host = nfc_secret_host + + +class DebugLinkPairingInfo(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 9010 + FIELDS = { + 1: protobuf.Field("channel_id", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("handshake_hash", "bytes", repeated=False, required=False, default=None), + 3: protobuf.Field("code_entry_code", "uint32", repeated=False, required=False, default=None), + 4: protobuf.Field("code_qr_code", "bytes", repeated=False, required=False, default=None), + 5: protobuf.Field("nfc_secret_trezor", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + channel_id: Optional["bytes"] = None, + handshake_hash: Optional["bytes"] = None, + code_entry_code: Optional["int"] = None, + code_qr_code: Optional["bytes"] = None, + nfc_secret_trezor: Optional["bytes"] = None, + ) -> None: + self.channel_id = channel_id + self.handshake_hash = handshake_hash + self.code_entry_code = code_entry_code + self.code_qr_code = code_qr_code + self.nfc_secret_trezor = nfc_secret_trezor + + class DebugLinkStop(protobuf.MessageType): MESSAGE_WIRE_TYPE = 103 @@ -7863,8 +7943,68 @@ class TezosManagerTransfer(protobuf.MessageType): self.amount = amount -class ThpCredentialMetadata(protobuf.MessageType): +class ThpDeviceProperties(protobuf.MessageType): MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None), + 3: protobuf.Field("protocol_version_major", "uint32", repeated=False, required=False, default=None), + 4: protobuf.Field("protocol_version_minor", "uint32", repeated=False, required=False, default=None), + 5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + internal_model: Optional["str"] = None, + model_variant: Optional["int"] = None, + protocol_version_major: Optional["int"] = None, + protocol_version_minor: Optional["int"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.internal_model = internal_model + self.model_variant = model_variant + self.protocol_version_major = protocol_version_major + self.protocol_version_minor = protocol_version_minor + + +class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_pairing_credential: Optional["bytes"] = None, + ) -> None: + self.host_pairing_credential = host_pairing_credential + + +class ThpCreateNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1000 + FIELDS = { + 1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None), + 3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + passphrase: Optional["str"] = None, + on_device: Optional["bool"] = None, + derive_cardano: Optional["bool"] = None, + ) -> None: + self.passphrase = passphrase + self.on_device = on_device + self.derive_cardano = derive_cardano + + +class ThpPairingRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1006 FIELDS = { 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), } @@ -7877,6 +8017,216 @@ class ThpCredentialMetadata(protobuf.MessageType): self.host_name = host_name +class ThpPairingRequestApproved(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1007 + + +class ThpSelectMethod(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1008 + FIELDS = { + 1: protobuf.Field("selected_pairing_method", "ThpPairingMethod", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + selected_pairing_method: Optional["ThpPairingMethod"] = None, + ) -> None: + self.selected_pairing_method = selected_pairing_method + + +class ThpPairingPreparationsFinished(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1009 + + +class ThpCodeEntryCommitment(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1016 + FIELDS = { + 1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + commitment: Optional["bytes"] = None, + ) -> None: + self.commitment = commitment + + +class ThpCodeEntryChallenge(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1017 + FIELDS = { + 1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + challenge: Optional["bytes"] = None, + ) -> None: + self.challenge = challenge + + +class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1018 + FIELDS = { + 1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_trezor_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_trezor_public_key = cpace_trezor_public_key + + +class ThpCodeEntryCpaceHostTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1019 + FIELDS = { + 1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_host_public_key: Optional["bytes"] = None, + tag: Optional["bytes"] = None, + ) -> None: + self.cpace_host_public_key = cpace_host_public_key + self.tag = tag + + +class ThpCodeEntrySecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1020 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpQrCodeTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1024 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpQrCodeSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1025 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpNfcTagHost(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1032 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpNfcTagTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1033 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpCredentialRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1010 + FIELDS = { + 1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_static_pubkey: Optional["bytes"] = None, + autoconnect: Optional["bool"] = None, + ) -> None: + self.host_static_pubkey = host_static_pubkey + self.autoconnect = autoconnect + + +class ThpCredentialResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1011 + FIELDS = { + 1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + trezor_static_pubkey: Optional["bytes"] = None, + credential: Optional["bytes"] = None, + ) -> None: + self.trezor_static_pubkey = trezor_static_pubkey + self.credential = credential + + +class ThpEndRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1012 + + +class ThpEndResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1013 + + +class ThpCredentialMetadata(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("autoconnect", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_name: Optional["str"] = None, + autoconnect: Optional["bool"] = None, + ) -> None: + self.host_name = host_name + self.autoconnect = autoconnect + + class ThpPairingCredential(protobuf.MessageType): MESSAGE_WIRE_TYPE = None FIELDS = { diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 578c1fa19f..eeaea26872 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -19,22 +19,22 @@ from typing import TYPE_CHECKING, Optional from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session -def get_entropy(client: "TrezorClient", size: int) -> bytes: - return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy +def get_entropy(session: "Session", size: int) -> bytes: + return session.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy def sign_identity( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, ) -> messages.SignedIdentity: - return client.call( + return session.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, @@ -46,12 +46,12 @@ def sign_identity( def get_ecdh_session_key( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, ) -> messages.ECDHSessionKey: - return client.call( + return session.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, @@ -62,7 +62,7 @@ def get_ecdh_session_key( def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -70,7 +70,7 @@ def encrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> bytes: - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -85,7 +85,7 @@ def encrypt_keyvalue( def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -93,7 +93,7 @@ def decrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> bytes: - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -107,5 +107,5 @@ def decrypt_keyvalue( ).value -def get_nonce(client: "TrezorClient") -> bytes: - return client.call(messages.GetNonce(), expect=messages.Nonce).nonce +def get_nonce(session: "Session") -> bytes: + return session.call(messages.GetNonce(), expect=messages.Nonce).nonce diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index b2e3214fb9..9e32346156 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -19,8 +19,8 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session # MAINNET = 0 @@ -30,13 +30,13 @@ if TYPE_CHECKING: def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, ) -> bytes: - return client.call( + return session.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, @@ -48,11 +48,11 @@ def get_address( def get_watch_key( - client: "TrezorClient", + session: "Session", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, ) -> messages.MoneroWatchKey: - return client.call( + return session.call( messages.MoneroGetWatchKey(address_n=n, network_type=network_type), expect=messages.MoneroWatchKey, ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 744dc3205f..357de145ad 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -20,8 +20,8 @@ from typing import TYPE_CHECKING from . import exceptions, messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -195,13 +195,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig def get_address( - client: "TrezorClient", + session: "Session", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify ), @@ -210,7 +210,7 @@ def get_address( def sign_tx( - client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False + session: "Session", n: "Address", transaction: dict, chunkify: bool = False ) -> messages.NEMSignedTx: try: msg = create_sign_tx(transaction, chunkify=chunkify) @@ -219,4 +219,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg, expect=messages.NEMSignedTx) + return session.call(msg, expect=messages.NEMSignedTx) diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 00a027c6d9..e5e0f524cc 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -21,20 +21,20 @@ from .protobuf import dict_to_proto from .tools import dict_from_camelcase if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -43,14 +43,14 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, ) -> messages.RippleSignedTx: msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg, expect=messages.RippleSignedTx) + return session.call(msg, expect=messages.RippleSignedTx) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index 0054e0fd92..3d0ee75549 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -3,27 +3,27 @@ from typing import TYPE_CHECKING, List, Optional from . import messages if TYPE_CHECKING: - from .client import TrezorClient + from .transport.session import Session def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, ) -> bytes: - return client.call( + return session.call( messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display), expect=messages.SolanaPublicKey, ).public_key def get_address( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, @@ -34,12 +34,12 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], ) -> bytes: - return client.call( + return session.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index 5bd0a749e4..843a2e0c39 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -20,8 +20,8 @@ from typing import TYPE_CHECKING, List, Tuple, Union from . import exceptions, messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session StellarMessageType = Union[ messages.StellarAccountMergeOp, @@ -322,12 +322,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -336,7 +336,7 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", tx: messages.StellarSignTx, operations: List["StellarMessageType"], address_n: "Address", @@ -352,10 +352,10 @@ def sign_tx( # 3. Receive a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 5. The final message received will be StellarSignedTx which is returned from this method - resp = client.call(tx) + resp = session.call(tx) try: while isinstance(resp, messages.StellarTxOpRequest): - resp = client.call(operations.pop(0)) + resp = session.call(operations.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index 9319aa1eaa..06bcafe759 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -19,17 +19,17 @@ from typing import TYPE_CHECKING from . import messages if TYPE_CHECKING: - from .client import TrezorClient from .tools import Address + from .transport.session import Session def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -38,12 +38,12 @@ def get_address( def get_public_key( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> str: - return client.call( + return session.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify ), @@ -52,11 +52,11 @@ def get_public_key( def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, ) -> messages.TezosSignedTx: sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg, expect=messages.TezosSignedTx) + return session.call(sign_tx_msg, expect=messages.TezosSignedTx) diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 6ba8c64dba..f753e68a33 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -45,7 +45,7 @@ if TYPE_CHECKING: # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from . import client from .messages import Success @@ -389,23 +389,6 @@ def _return_success(msg: "Success") -> str | None: return _deprecation_retval_helper(msg.message, stacklevel=1) -def session( - f: "Callable[Concatenate[TrezorClient, P], R]", -) -> "Callable[Concatenate[TrezorClient, P], R]": - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - client.open() - try: - return f(client, *args, **kwargs) - finally: - client.close() - - return wrapped_f - - # de-camelcasifier # https://stackoverflow.com/a/1176023/222189