diff --git a/python/src/trezorlib/_internal/emulator.py b/python/src/trezorlib/_internal/emulator.py index 4f6d56f8ed..d171f58200 100644 --- a/python/src/trezorlib/_internal/emulator.py +++ b/python/src/trezorlib/_internal/emulator.py @@ -95,6 +95,15 @@ class Emulator: raise RuntimeError return self._client + @client.setter + def client(self, new_client: TrezorClientDebugLink) -> None: + """Setter for the client property to update _client.""" + if not isinstance(new_client, TrezorClientDebugLink): + raise TypeError( + f"Expected a TrezorClientDebugLink, got {type(new_client).__name__}." + ) + self._client = new_client + def make_args(self) -> List[str]: return [] @@ -112,7 +121,7 @@ class Emulator: start = time.monotonic() try: while True: - if transport._ping(): + if transport.ping(): break if self.process.poll() is not None: raise RuntimeError("Emulator process died") diff --git a/python/src/trezorlib/authentication.py b/python/src/trezorlib/authentication.py index 39e26f569f..2e4a530af5 100644 --- a/python/src/trezorlib/authentication.py +++ b/python/src/trezorlib/authentication.py @@ -7,7 +7,7 @@ import typing as t from importlib import metadata from . import device -from .client import TrezorClient +from .transport.session import Session try: cryptography_version = metadata.version("cryptography") @@ -361,7 +361,7 @@ def verify_authentication_response( def authenticate_device( - client: TrezorClient, + session: Session, challenge: bytes | None = None, *, whitelist: t.Collection[bytes] | None = None, @@ -371,7 +371,7 @@ def authenticate_device( if challenge is None: challenge = secrets.token_bytes(16) - resp = device.authenticate(client, challenge) + resp = device.authenticate(session, challenge) return verify_authentication_response( challenge, diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py index f96ef7970e..b961dda426 100644 --- a/python/src/trezorlib/benchmark.py +++ b/python/src/trezorlib/benchmark.py @@ -20,17 +20,17 @@ from . import messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect(messages.BenchmarkNames) def list_names( - client: "TrezorClient", + session: "Session", ) -> "MessageType": - return client.call(messages.BenchmarkListNames()) + return session.call(messages.BenchmarkListNames()) @expect(messages.BenchmarkResult) -def run(client: "TrezorClient", name: str) -> "MessageType": - return client.call(messages.BenchmarkRun(name=name)) +def run(session: "Session", name: str) -> "MessageType": + return session.call(messages.BenchmarkRun(name=name)) diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index d2e4b97912..afe251a06c 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,22 +18,22 @@ from typing import TYPE_CHECKING from . import messages from .protobuf import dict_to_proto -from .tools import expect, session +from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.BinanceAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -42,16 +42,15 @@ def get_address( @expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) def get_public_key( - client: "TrezorClient", address_n: "Address", show_display: bool = False + session: "Session", address_n: "Address", show_display: bool = False ) -> "MessageType": - return client.call( + return session.call( messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) ) -@session def sign_tx( - client: "TrezorClient", address_n: "Address", tx_json: dict, chunkify: bool = False + session: "Session", address_n: "Address", tx_json: dict, chunkify: bool = False ) -> messages.BinanceSignedTx: msg = tx_json["msgs"][0] tx_msg = tx_json.copy() @@ -60,7 +59,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - response = client.call(envelope) + response = session.call(envelope) if not isinstance(response, messages.BinanceTxRequest): raise RuntimeError( @@ -77,7 +76,7 @@ def sign_tx( else: raise ValueError("can not determine msg type") - response = client.call(msg) + response = session.call(msg) if not isinstance(response, messages.BinanceSignedTx): raise RuntimeError( diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index a71ead2adc..3ccb1a9595 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -13,7 +13,6 @@ # # You should have received a copy of the License along with this library. # If not, see . - import warnings from copy import copy from decimal import Decimal @@ -23,12 +22,12 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import expect, prepare_message_bytes, session +from .tools import expect, prepare_message_bytes if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session class ScriptSig(TypedDict): asm: str @@ -105,7 +104,7 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType: @expect(messages.PublicKey) def get_public_node( - client: "TrezorClient", + session: "Session", n: "Address", ecdsa_curve_name: Optional[str] = None, show_display: bool = False, @@ -116,13 +115,13 @@ def get_public_node( unlock_path_mac: Optional[bytes] = None, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetPublicKey( address_n=n, ecdsa_curve_name=ecdsa_curve_name, @@ -141,7 +140,7 @@ def get_address(*args: Any, **kwargs: Any): @expect(messages.Address) def get_authenticated_address( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", show_display: bool = False, @@ -153,13 +152,13 @@ def get_authenticated_address( chunkify: bool = False, ) -> "MessageType": if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") - return client.call( + return session.call( messages.GetAddress( address_n=n, coin_name=coin_name, @@ -172,15 +171,16 @@ def get_authenticated_address( ) +# TODO this is used by tests only @expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> "MessageType": - return client.call( + return session.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, @@ -190,8 +190,9 @@ def get_ownership_id( ) +# TODO this is used by tests only def get_ownership_proof( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, @@ -202,11 +203,11 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call( + res = session.call( messages.GetOwnershipProof( address_n=n, coin_name=coin_name, @@ -226,7 +227,7 @@ def get_ownership_proof( @expect(messages.MessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", coin_name: str, n: "Address", message: AnyStr, @@ -234,7 +235,7 @@ def sign_message( no_script_type: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SignMessage( coin_name=coin_name, address_n=n, @@ -247,7 +248,7 @@ def sign_message( def verify_message( - client: "TrezorClient", + session: "Session", coin_name: str, address: str, signature: bytes, @@ -255,7 +256,7 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.VerifyMessage( address=address, signature=signature, @@ -269,9 +270,9 @@ def verify_message( return isinstance(resp, messages.Success) -@session +# @session def sign_tx( - client: "TrezorClient", + session: "Session", coin_name: str, inputs: Sequence[messages.TxInputType], outputs: Sequence[messages.TxOutputType], @@ -319,17 +320,17 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - res = client.call( + res = session.call( messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) ) if not isinstance(res, messages.UnlockedPathRequest): raise exceptions.TrezorException("Unexpected message") elif preauthorized: - res = client.call(messages.DoPreauthorized()) + res = session.call(messages.DoPreauthorized()) if not isinstance(res, messages.PreauthorizedRequest): raise exceptions.TrezorException("Unexpected message") - res = client.call(signtx) + res = session.call(signtx) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -388,7 +389,7 @@ def sign_tx( if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg) + res = session.call(msg) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -418,7 +419,7 @@ def sign_tx( f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg)) + res = session.call(messages.TxAck(tx=msg)) if not isinstance(res, messages.TxRequest): raise exceptions.TrezorException("Unexpected message") @@ -432,7 +433,7 @@ def sign_tx( @expect(messages.Success, field="message", ret_type=str) def authorize_coinjoin( - client: "TrezorClient", + session: "Session", coordinator: str, max_rounds: int, max_coordinator_fee_rate: int, @@ -441,7 +442,7 @@ def authorize_coinjoin( coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, ) -> "MessageType": - return client.call( + return session.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 49d2c6463f..f39cfb4222 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -35,8 +35,8 @@ from . import exceptions, messages, tools from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -825,7 +825,7 @@ def _get_collateral_inputs_items( @expect(messages.CardanoAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_parameters: messages.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], @@ -833,7 +833,7 @@ def get_address( derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, @@ -847,12 +847,12 @@ def get_address( @expect(messages.CardanoPublicKey) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, show_display: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, @@ -863,12 +863,12 @@ def get_public_key( @expect(messages.CardanoNativeScriptHash) def get_native_script_hash( - client: "TrezorClient", + session: "Session", native_script: messages.CardanoNativeScript, display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, ) -> "MessageType": - return client.call( + return session.call( messages.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, @@ -878,7 +878,7 @@ def get_native_script_hash( def sign_tx( - client: "TrezorClient", + session: "Session", signing_mode: messages.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], @@ -915,7 +915,7 @@ def sign_tx( signing_mode, ) - response = client.call( + response = session.call( messages.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), @@ -951,14 +951,14 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call(auxiliary_data) + auxiliary_data_supplement = session.call(auxiliary_data) if not isinstance( auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement ): @@ -971,7 +971,7 @@ def sign_tx( auxiliary_data_supplement.__dict__ ) - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR @@ -980,24 +980,24 @@ def sign_tx( _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item) + response = session.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR for reference_input in reference_inputs: - response = client.call(reference_input) + response = session.call(reference_input) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request) + response = session.call(witness_request) if not isinstance(response, messages.CardanoTxWitnessResponse): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["witnesses"].append( @@ -1009,12 +1009,12 @@ def sign_tx( } ) - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoTxBodyHash): raise UNEXPECTED_RESPONSE_ERROR sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(messages.CardanoTxHostAck()) + response = session.call(messages.CardanoTxHostAck()) if not isinstance(response, messages.CardanoSignTxFinished): raise UNEXPECTED_RESPONSE_ERROR diff --git a/python/src/trezorlib/cli/__init__.py b/python/src/trezorlib/cli/__init__.py index 6db335a7ad..0b14778ed7 100644 --- a/python/src/trezorlib/cli/__init__.py +++ b/python/src/trezorlib/cli/__init__.py @@ -14,33 +14,42 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import functools +import logging +import os import sys +import typing as t from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional import click -from .. import exceptions, transport -from ..client import TrezorClient -from ..ui import ClickUI, ScriptUI +from .. import exceptions, transport, ui +from ..client import ProtocolVersion, TrezorClient +from ..messages import Capability +from ..transport import Transport +from ..transport.session import Session, SessionV1, SessionV2 +from ..transport.thp.channel_database import get_channel_db -if TYPE_CHECKING: +LOG = logging.getLogger(__name__) + +if t.TYPE_CHECKING: # Needed to enforce a return value from decorators # More details: https://www.python.org/dev/peps/pep-0612/ - from typing import TypeVar from typing_extensions import Concatenate, ParamSpec - from ..transport import Transport - from ..ui import TrezorClientUI - P = ParamSpec("P") - R = TypeVar("R") + R = t.TypeVar("R") + FuncWithSession = t.Callable[Concatenate[Session, P], R] class ChoiceType(click.Choice): - def __init__(self, typemap: Dict[str, Any], case_sensitive: bool = True) -> None: + + def __init__( + self, typemap: t.Dict[str, t.Any], case_sensitive: bool = True + ) -> None: super().__init__(list(typemap.keys())) self.case_sensitive = case_sensitive if case_sensitive: @@ -48,7 +57,7 @@ class ChoiceType(click.Choice): else: self.typemap = {k.lower(): v for k, v in typemap.items()} - def convert(self, value: Any, param: Any, ctx: click.Context) -> Any: + def convert(self, value: t.Any, param: t.Any, ctx: click.Context) -> t.Any: if value in self.typemap.values(): return value value = super().convert(value, param, ctx) @@ -57,11 +66,69 @@ class ChoiceType(click.Choice): return self.typemap[value] +def get_passphrase( + passphrase_on_host: bool, available_on_device: bool +) -> t.Union[str, object]: + if available_on_device and not passphrase_on_host: + return ui.PASSPHRASE_ON_DEVICE + + env_passphrase = os.getenv("PASSPHRASE") + if env_passphrase is not None: + ui.echo("Passphrase required. Using PASSPHRASE environment variable.") + return env_passphrase + + while True: + try: + passphrase = ui.prompt( + "Passphrase required", + hide_input=True, + default="", + show_default=False, + ) + # In case user sees the input on the screen, we do not need confirmation + if not ui.CAN_HANDLE_HIDDEN_INPUT: + return passphrase + second = ui.prompt( + "Confirm your passphrase", + hide_input=True, + default="", + show_default=False, + ) + if passphrase == second: + return passphrase + else: + ui.echo("Passphrase did not match. Please try again.") + except click.Abort: + raise exceptions.Cancelled from None + + +def get_client(transport: Transport) -> TrezorClient: + stored_channels = get_channel_db().load_stored_channels() + stored_transport_paths = [ch.transport_path for ch in stored_channels] + path = transport.get_path() + if path in stored_transport_paths: + stored_channel_with_correct_transport_path = next( + ch for ch in stored_channels if ch.transport_path == path + ) + try: + client = TrezorClient.resume( + transport, stored_channel_with_correct_transport_path + ) + except Exception: + LOG.debug("Failed to resume a channel. Replacing by a new one.") + get_channel_db().remove_channel(path) + client = TrezorClient(transport) + else: + client = TrezorClient(transport) + return client + + class TrezorConnection: + def __init__( self, path: str, - session_id: Optional[bytes], + session_id: bytes | None, passphrase_on_host: bool, script: bool, ) -> None: @@ -70,6 +137,54 @@ class TrezorConnection: self.passphrase_on_host = passphrase_on_host self.script = script + def get_session( + self, + derive_cardano: bool = False, + empty_passphrase: bool = False, + must_resume: bool = False, + ) -> Session: + client = self.get_client() + if must_resume and self.session_id is None: + click.echo("Failed to resume session - no session id provided") + raise RuntimeError("Failed to resume session - no session id provided") + + # Try resume session from id + if self.session_id is not None: + if client.protocol_version is ProtocolVersion.PROTOCOL_V1: + session = SessionV1.resume_from_id( + client=client, session_id=self.session_id + ) + elif client.protocol_version is ProtocolVersion.PROTOCOL_V2: + session = SessionV2(client, self.session_id) + # TODO fix resumption on THP + else: + raise Exception("Unsupported client protocol", client.protocol_version) + if must_resume: + if session.id != self.session_id or session.id is None: + click.echo("Failed to resume session") + RuntimeError("Failed to resume session - no session id provided") + return session + + features = client.protocol.get_features() + + passphrase_enabled = True # TODO what to do here? + + if not passphrase_enabled: + return client.get_session(derive_cardano=derive_cardano) + + if empty_passphrase: + passphrase = "" + else: + available_on_device = Capability.PassphraseEntry in features.capabilities + passphrase = get_passphrase(available_on_device, self.passphrase_on_host) + # TODO handle case when PASSPHRASE_ON_DEVICE is returned from get_passphrase func + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + return session + def get_transport(self) -> "Transport": try: # look for transport without prefix search @@ -82,19 +197,13 @@ class TrezorConnection: # if this fails, we want the exception to bubble up to the caller return transport.get_transport(self.path, prefix_search=True) - def get_ui(self) -> "TrezorClientUI": - if self.script: - # It is alright to return just the class object instead of instance, - # as the ScriptUI class object itself is the implementation of TrezorClientUI - # (ScriptUI is just a set of staticmethods) - return ScriptUI - else: - return ClickUI(passphrase_on_host=self.passphrase_on_host) - def get_client(self) -> TrezorClient: - transport = self.get_transport() - ui = self.get_ui() - return TrezorClient(transport, ui=ui, session_id=self.session_id) + return get_client(self.get_transport()) + + def get_management_session(self) -> Session: + client = self.get_client() + management_session = client.get_management_session() + return management_session @contextmanager def client_context(self): @@ -128,7 +237,57 @@ class TrezorConnection: # other exceptions may cause a traceback -def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[P, R]": +def with_session( + func: "t.Callable[Concatenate[Session, P], R]|None" = None, + *, + empty_passphrase: bool = False, + derive_cardano: bool = False, + management: bool = False, + must_resume: bool = False, +) -> t.Callable[[FuncWithSession], t.Callable[P, R]]: + """Provides a Click command with parameter `session=obj.get_session(...)` or + `session=obj.get_management_session()` based on the parameters provided. + + If default parameters are ok, this decorator can be used without parentheses. + + TODO: handle resumption of sessions and their (potential) closure. + """ + + def decorator( + func: FuncWithSession, + ) -> "t.Callable[P, R]": + + @click.pass_obj + @functools.wraps(func) + def function_with_session( + obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" + ) -> "R": + if management: + session = obj.get_management_session() + else: + session = obj.get_session( + derive_cardano=derive_cardano, + empty_passphrase=empty_passphrase, + must_resume=must_resume, + ) + try: + return func(session, *args, **kwargs) + finally: + pass + # TODO try end session if not resumed + + return function_with_session + + # If the decorator @get_session is used without parentheses + if func and callable(func): + return decorator(func) # type: ignore [Function return type] + + return decorator + + +def with_client( + func: "t.Callable[Concatenate[TrezorClient, P], R]", +) -> "t.Callable[P, R]": """Wrap a Click command in `with obj.client_context() as client`. Sessions are handled transparently. The user is warned when session did not resume @@ -142,23 +301,62 @@ def with_client(func: "Callable[Concatenate[TrezorClient, P], R]") -> "Callable[ obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" ) -> "R": with obj.client_context() as client: - session_was_resumed = obj.session_id == client.session_id - if not session_was_resumed and obj.session_id is not None: - # tried to resume but failed - click.echo("Warning: failed to resume session.", err=True) - + # session_was_resumed = obj.session_id == client.session_id + # if not session_was_resumed and obj.session_id is not None: + # # tried to resume but failed + # click.echo("Warning: failed to resume session.", err=True) + click.echo( + "Warning: resume session detection is not implemented yet!", err=True + ) try: return func(client, *args, **kwargs) finally: - if not session_was_resumed: - try: - client.end_session() - except Exception: - pass + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) + # if not session_was_resumed: + # try: + # client.end_session() + # except Exception: + # pass return trezorctl_command_with_client +# def with_client( +# func: "t.Callable[Concatenate[TrezorClient, P], R]", +# ) -> "t.Callable[P, R]": +# """Wrap a Click command in `with obj.client_context() as client`. + +# Sessions are handled transparently. The user is warned when session did not resume +# cleanly. The session is closed after the command completes - unless the session +# was resumed, in which case it should remain open. +# """ + +# @click.pass_obj +# @functools.wraps(func) +# def trezorctl_command_with_client( +# obj: TrezorConnection, *args: "P.args", **kwargs: "P.kwargs" +# ) -> "R": +# with obj.client_context() as client: +# session_was_resumed = obj.session_id == client.session_id +# if not session_was_resumed and obj.session_id is not None: +# # tried to resume but failed +# click.echo("Warning: failed to resume session.", err=True) + +# try: +# return func(client, *args, **kwargs) +# finally: +# if not session_was_resumed: +# try: +# client.end_session() +# except Exception: +# pass + +# # the return type of @click.pass_obj is improperly specified and pyright doesn't +# # understand that it converts f(obj, *args, **kwargs) to f(*args, **kwargs) +# return trezorctl_command_with_client + + class AliasedGroup(click.Group): """Command group that handles aliases and Click 6.x compatibility. @@ -188,14 +386,14 @@ class AliasedGroup(click.Group): def __init__( self, - aliases: Optional[Dict[str, click.Command]] = None, - *args: Any, - **kwargs: Any, + aliases: t.Dict[str, click.Command] | None = None, + *args: t.Any, + **kwargs: t.Any, ) -> None: super().__init__(*args, **kwargs) self.aliases = aliases or {} - def get_command(self, ctx: click.Context, cmd_name: str) -> Optional[click.Command]: + def get_command(self, ctx: click.Context, cmd_name: str) -> click.Command | None: cmd_name = cmd_name.replace("_", "-") # try to look up the real name cmd = super().get_command(ctx, cmd_name) diff --git a/python/src/trezorlib/cli/benchmark.py b/python/src/trezorlib/cli/benchmark.py index e445089815..7908223881 100644 --- a/python/src/trezorlib/cli/benchmark.py +++ b/python/src/trezorlib/cli/benchmark.py @@ -20,17 +20,15 @@ from typing import TYPE_CHECKING, List, Optional import click from .. import benchmark -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session -def list_names_patern( - client: "TrezorClient", pattern: Optional[str] = None -) -> List[str]: - names = list(benchmark.list_names(client).names) +def list_names_patern(session: "Session", pattern: Optional[str] = None) -> List[str]: + names = list(benchmark.list_names(session).names) if pattern is None: return names return [name for name in names if fnmatch(name, pattern)] @@ -43,10 +41,10 @@ def cli() -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: +@with_session(empty_passphrase=True) +def list_names(session: "Session", pattern: Optional[str] = None) -> None: """List names of all supported benchmarks""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: @@ -56,13 +54,13 @@ def list_names(client: "TrezorClient", pattern: Optional[str] = None) -> None: @cli.command() @click.argument("pattern", required=False) -@with_client -def run(client: "TrezorClient", pattern: Optional[str]) -> None: +@with_session(empty_passphrase=True) +def run(session: "Session", pattern: Optional[str]) -> None: """Run benchmark""" - names = list_names_patern(client, pattern) + names = list_names_patern(session, pattern) if len(names) == 0: click.echo("No benchmark satisfies the pattern.") else: for name in names: - result = benchmark.run(client, name) + result = benchmark.run(session, name) click.echo(f"{name}: {result.value} {result.unit}") diff --git a/python/src/trezorlib/cli/binance.py b/python/src/trezorlib/cli/binance.py index a3139fb271..d8097b3e90 100644 --- a/python/src/trezorlib/cli/binance.py +++ b/python/src/trezorlib/cli/binance.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO import click from .. import binance, tools -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient PATH_HELP = "BIP-32 path to key, e.g. m/44h/714h/0h/0/0" @@ -39,23 +39,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Binance address for specified path.""" address_n = tools.parse_path(address) - return binance.get_address(client, address_n, show_display, chunkify) + return binance.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Binance public key.""" address_n = tools.parse_path(address) - return binance.get_public_key(client, address_n, show_display).hex() + return binance.get_public_key(session, address_n, show_display).hex() @cli.command() @@ -63,13 +63,13 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.BinanceSignedTx": """Sign Binance transaction. Transaction must be provided as a JSON file. """ address_n = tools.parse_path(address) - return binance.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + return binance.sign_tx(session, address_n, json.load(file), chunkify=chunkify) diff --git a/python/src/trezorlib/cli/btc.py b/python/src/trezorlib/cli/btc.py index d6a9867215..77bbe83f81 100644 --- a/python/src/trezorlib/cli/btc.py +++ b/python/src/trezorlib/cli/btc.py @@ -13,6 +13,7 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import base64 import json @@ -22,10 +23,10 @@ import click import construct as c from .. import btc, messages, protobuf, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PURPOSE_BIP44 = 44 PURPOSE_BIP48 = 48 @@ -174,15 +175,15 @@ def cli() -> None: help="Sort pubkeys lexicographically using BIP-67", ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", coin: str, address: str, - script_type: Optional[messages.InputScriptType], + script_type: messages.InputScriptType | None, show_display: bool, multisig_xpub: List[str], - multisig_threshold: Optional[int], + multisig_threshold: int | None, multisig_suffix_length: int, multisig_sort_pubkeys: bool, chunkify: bool, @@ -235,7 +236,7 @@ def get_address( multisig = None return btc.get_address( - client, + session, coin, address_n, show_display, @@ -252,9 +253,9 @@ def get_address( @click.option("-e", "--curve") @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_node( - client: "TrezorClient", + session: "Session", coin: str, address: str, curve: Optional[str], @@ -266,7 +267,7 @@ def get_public_node( if script_type is None: script_type = guess_script_type_from_path(address_n) result = btc.get_public_node( - client, + session, address_n, ecdsa_curve_name=curve, show_display=show_display, @@ -292,7 +293,7 @@ def _append_descriptor_checksum(desc: str) -> str: def _get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, purpose: Optional[int], @@ -326,7 +327,7 @@ def _get_descriptor( n = tools.parse_path(path) pub = btc.get_public_node( - client, + session, n, show_display=show_display, coin_name=coin, @@ -363,9 +364,9 @@ def _get_descriptor( @click.option("-a", "--account-type", type=ChoiceType(ACCOUNT_TYPE_TO_BIP_PURPOSE)) @click.option("-t", "--script-type", type=ChoiceType(INPUT_SCRIPTS)) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_descriptor( - client: "TrezorClient", + session: "Session", coin: Optional[str], account: int, account_type: Optional[int], @@ -375,7 +376,7 @@ def get_descriptor( """Get descriptor of given account.""" try: return _get_descriptor( - client, coin, account, account_type, script_type, show_display + session, coin, account, account_type, script_type, show_display ) except ValueError as e: raise click.ClickException(str(e)) @@ -390,8 +391,8 @@ def get_descriptor( @click.option("-c", "--coin", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) @click.argument("json_file", type=click.File()) -@with_client -def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", json_file: TextIO, chunkify: bool) -> None: """Sign transaction. Transaction data must be provided in a JSON file. See `transaction-format.md` for @@ -416,7 +417,7 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: } _, serialized_tx = btc.sign_tx( - client, + session, coin, inputs, outputs, @@ -447,9 +448,9 @@ def sign_tx(client: "TrezorClient", json_file: TextIO, chunkify: bool) -> None: ) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, message: str, @@ -462,7 +463,7 @@ def sign_message( if script_type is None: script_type = guess_script_type_from_path(address_n) res = btc.sign_message( - client, + session, coin, address_n, message, @@ -483,9 +484,9 @@ def sign_message( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", coin: str, address: str, signature: str, @@ -495,7 +496,7 @@ def verify_message( """Verify message.""" signature_bytes = base64.b64decode(signature) return btc.verify_message( - client, coin, address, signature_bytes, message, chunkify=chunkify + session, coin, address, signature_bytes, message, chunkify=chunkify ) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index 26d4eab5b9..1e6935d6d9 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import cardano, messages, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/1815h/0h/0/0" @@ -62,9 +62,9 @@ def cli() -> None: @click.option("-i", "--include-network-id", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) @click.option("-T", "--tag-cbor-sets", is_flag=True) -@with_client +@with_session(derive_cardano=True) def sign_tx( - client: "TrezorClient", + session: "Session", file: TextIO, signing_mode: messages.CardanoTxSigningMode, protocol_magic: int, @@ -123,9 +123,8 @@ def sign_tx( for p in transaction["additional_witness_requests"] ] - client.init_device(derive_cardano=True) sign_tx_response = cardano.sign_tx( - client, + session, signing_mode, inputs, outputs, @@ -209,9 +208,9 @@ def sign_tx( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_address( - client: "TrezorClient", + session: "Session", address: str, address_type: messages.CardanoAddressType, staking_address: str, @@ -262,9 +261,8 @@ def get_address( script_staking_hash_bytes, ) - client.init_device(derive_cardano=True) return cardano.get_address( - client, + session, address_parameters, protocol_magic, network_id, @@ -283,18 +281,17 @@ def get_address( default=messages.CardanoDerivationType.ICARUS, ) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session(derive_cardano=True) def get_public_key( - client: "TrezorClient", + session: "Session", address: str, derivation_type: messages.CardanoDerivationType, show_display: bool, ) -> messages.CardanoPublicKey: """Get Cardano public key.""" address_n = tools.parse_path(address) - client.init_device(derive_cardano=True) return cardano.get_public_key( - client, address_n, derivation_type=derivation_type, show_display=show_display + session, address_n, derivation_type=derivation_type, show_display=show_display ) @@ -312,9 +309,9 @@ def get_public_key( type=ChoiceType({m.name: m for m in messages.CardanoDerivationType}), default=messages.CardanoDerivationType.ICARUS, ) -@with_client +@with_session(derive_cardano=True) def get_native_script_hash( - client: "TrezorClient", + session: "Session", file: TextIO, display_format: messages.CardanoNativeScriptHashDisplayFormat, derivation_type: messages.CardanoDerivationType, @@ -323,7 +320,6 @@ def get_native_script_hash( native_script_json = json.load(file) native_script = cardano.parse_native_script(native_script_json) - client.init_device(derive_cardano=True) return cardano.get_native_script_hash( - client, native_script, display_format, derivation_type=derivation_type + session, native_script, display_format, derivation_type=derivation_type ) diff --git a/python/src/trezorlib/cli/crypto.py b/python/src/trezorlib/cli/crypto.py index a58b80d4b6..469bc719a4 100644 --- a/python/src/trezorlib/cli/crypto.py +++ b/python/src/trezorlib/cli/crypto.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Tuple import click from .. import misc, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PROMPT_TYPE = ChoiceType( @@ -42,10 +42,10 @@ def cli() -> None: @cli.command() @click.argument("size", type=int) -@with_client -def get_entropy(client: "TrezorClient", size: int) -> str: +@with_session(empty_passphrase=True) +def get_entropy(session: "Session", size: int) -> str: """Get random bytes from device.""" - return misc.get_entropy(client, size).hex() + return misc.get_entropy(session, size).hex() @cli.command() @@ -55,9 +55,9 @@ def get_entropy(client: "TrezorClient", size: int) -> str: ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -75,7 +75,7 @@ def encrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.encrypt_keyvalue( - client, + session, address_n, key, value.encode(), @@ -91,9 +91,9 @@ def encrypt_keyvalue( ) @click.argument("key") @click.argument("value") -@with_client +@with_session(empty_passphrase=True) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", address: str, key: str, value: str, @@ -112,7 +112,7 @@ def decrypt_keyvalue( ask_on_encrypt, ask_on_decrypt = prompt address_n = tools.parse_path(address) return misc.decrypt_keyvalue( - client, + session, address_n, key, bytes.fromhex(value), diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index 50613a04ee..1670117eb8 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -18,13 +18,12 @@ from typing import TYPE_CHECKING, Union import click -from .. import mapping, messages, protobuf -from ..client import TrezorClient from ..debuglink import TrezorClientDebugLink from ..debuglink import optiga_set_sec_max as debuglink_optiga_set_sec_max from ..debuglink import prodtest_t1 as debuglink_prodtest_t1 from ..debuglink import record_screen -from . import with_client +from ..transport.session import Session +from . import with_session if TYPE_CHECKING: from . import TrezorConnection @@ -35,51 +34,51 @@ def cli() -> None: """Miscellaneous debug features.""" -@cli.command() -@click.argument("message_name_or_type") -@click.argument("hex_data") -@click.pass_obj -def send_bytes( - obj: "TrezorConnection", message_name_or_type: str, hex_data: str -) -> None: - """Send raw bytes to Trezor. +# @cli.command() +# @click.argument("message_name_or_type") +# @click.argument("hex_data") +# @click.pass_obj +# def send_bytes( +# obj: "NewTrezorConnection", message_name_or_type: str, hex_data: str +# ) -> None: +# """Send raw bytes to Trezor. - Message type and message data must be specified separately, due to how message - chunking works on the transport level. Message length is calculated and sent - automatically, and it is currently impossible to explicitly specify invalid length. +# Message type and message data must be specified separately, due to how message +# chunking works on the transport level. Message length is calculated and sent +# automatically, and it is currently impossible to explicitly specify invalid length. - MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, - in which case the value of that enum is used. - """ - if message_name_or_type.isdigit(): - message_type = int(message_name_or_type) - else: - message_type = getattr(messages.MessageType, message_name_or_type) +# MESSAGE_NAME_OR_TYPE can either be a number, or a name from the MessageType enum, +# in which case the value of that enum is used. +# """ +# if message_name_or_type.isdigit(): +# message_type = int(message_name_or_type) +# else: +# message_type = getattr(messages.MessageType, message_name_or_type) - if not isinstance(message_type, int): - raise click.ClickException("Invalid message type.") +# if not isinstance(message_type, int): +# raise click.ClickException("Invalid message type.") - try: - message_data = bytes.fromhex(hex_data) - except Exception as e: - raise click.ClickException("Invalid hex data.") from e +# try: +# message_data = bytes.fromhex(hex_data) +# except Exception as e: +# raise click.ClickException("Invalid hex data.") from e - transport = obj.get_transport() - transport.begin_session() - transport.write(message_type, message_data) +# transport = obj.get_transport() +# transport.deprecated_begin_session() +# transport.write(message_type, message_data) - response_type, response_data = transport.read() - transport.end_session() +# response_type, response_data = transport.read() +# transport.deprecated_end_session() - click.echo(f"Response type: {response_type}") - click.echo(f"Response data: {response_data.hex()}") +# click.echo(f"Response type: {response_type}") +# click.echo(f"Response data: {response_data.hex()}") - try: - msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) - click.echo("Parsed message:") - click.echo(protobuf.format_message(msg)) - except Exception as e: - click.echo(f"Could not parse response: {e}") +# try: +# msg = mapping.DEFAULT_MAPPING.decode(response_type, response_data) +# click.echo("Parsed message:") +# click.echo(protobuf.format_message(msg)) +# except Exception as e: +# click.echo(f"Could not parse response: {e}") @cli.command() @@ -106,17 +105,17 @@ def record_screen_from_connection( @cli.command() -@with_client -def prodtest_t1(client: "TrezorClient") -> str: +@with_session(management=True) +def prodtest_t1(session: "Session") -> str: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - return debuglink_prodtest_t1(client) + return debuglink_prodtest_t1(session) @cli.command() -@with_client -def optiga_set_sec_max(client: "TrezorClient") -> str: +@with_session(management=True) +def optiga_set_sec_max(session: "Session") -> str: """Set Optiga's security event counter to maximum.""" - return debuglink_optiga_set_sec_max(client) + return debuglink_optiga_set_sec_max(session) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 52c0bd3961..d53aad1993 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -24,11 +24,11 @@ import click import requests from .. import debuglink, device, exceptions, messages, ui -from . import ChoiceType, with_client +from . import ChoiceType, with_session if t.TYPE_CHECKING: - from ..client import TrezorClient from ..protobuf import MessageType + from ..transport.session import Session from . import TrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { @@ -64,17 +64,18 @@ def cli() -> None: help="Wipe device in bootloader mode. This also erases the firmware.", is_flag=True, ) -@with_client -def wipe(client: "TrezorClient", bootloader: bool) -> str: +@with_session(management=True) +def wipe(session: "Session", bootloader: bool) -> str: """Reset device to factory defaults and remove all private data.""" + features = session.features if bootloader: - if not client.features.bootloader_mode: + if not features.bootloader_mode: click.echo("Please switch your device to bootloader mode.") sys.exit(1) else: click.echo("Wiping user data and firmware!") else: - if client.features.bootloader_mode: + if features.bootloader_mode: click.echo( "Your device is in bootloader mode. This operation would also erase firmware." ) @@ -87,7 +88,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: click.echo("Wiping user data!") try: - return device.wipe(client) + return device.wipe( + session + ) # TODO decide where the wipe should happen - management or regular session except exceptions.TrezorFailure as e: click.echo("Action failed: {} {}".format(*e.args)) sys.exit(3) @@ -103,9 +106,9 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: @click.option("-a", "--academic", is_flag=True) @click.option("-b", "--needs-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) -@with_client +@with_session(management=True) def load( - client: "TrezorClient", + session: "Session", mnemonic: t.Sequence[str], pin: str, passphrase_protection: bool, @@ -136,7 +139,7 @@ def load( try: return debuglink.load_device( - client, + session, mnemonic=list(mnemonic), pin=pin, passphrase_protection=passphrase_protection, @@ -171,9 +174,9 @@ def load( ) @click.option("-d", "--dry-run", is_flag=True) @click.option("-b", "--unlock-repeated-backup", is_flag=True) -@with_client +@with_session(management=True) def recover( - client: "TrezorClient", + session: "Session", words: str, expand: bool, pin_protection: bool, @@ -201,7 +204,7 @@ def recover( type = messages.RecoveryType.UnlockRepeatedBackup return device.recover( - client, + session, word_count=int(words), passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -222,9 +225,9 @@ def recover( @click.option("-s", "--skip-backup", is_flag=True) @click.option("-n", "--no-backup", is_flag=True) @click.option("-b", "--backup-type", type=ChoiceType(BACKUP_TYPE)) -@with_client +@with_session(management=True) def setup( - client: "TrezorClient", + session: "Session", strength: int | None, passphrase_protection: bool, pin_protection: bool, @@ -241,7 +244,7 @@ def setup( BT = messages.BackupType if backup_type is None: - if client.version >= (2, 7, 1): + if session.version >= (2, 7, 1): # SLIP39 extendable was introduced in 2.7.1 backup_type = BT.Slip39_Single_Extendable else: @@ -251,10 +254,10 @@ def setup( if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) - and messages.Capability.Shamir not in client.features.capabilities + and messages.Capability.Shamir not in session.features.capabilities ) or ( backup_type in (BT.Slip39_Advanced, BT.Slip39_Advanced_Extendable) - and messages.Capability.ShamirGroups not in client.features.capabilities + and messages.Capability.ShamirGroups not in session.features.capabilities ): click.echo( "WARNING: Your Trezor device does not indicate support for the requested\n" @@ -262,7 +265,7 @@ def setup( ) return device.reset( - client, + session, strength=strength, passphrase_protection=passphrase_protection, pin_protection=pin_protection, @@ -277,23 +280,21 @@ def setup( @cli.command() @click.option("-t", "--group-threshold", type=int) @click.option("-g", "--group", "groups", type=(int, int), multiple=True, metavar="T N") -@with_client +@with_session(management=True) def backup( - client: "TrezorClient", + session: "Session", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), ) -> str: """Perform device seed backup.""" - return device.backup(client, group_threshold, groups) + return device.backup(session, group_threshold, groups) @cli.command() @click.argument("operation", type=ChoiceType(SD_PROTECT_OPERATIONS)) -@with_client -def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType -) -> str: +@with_session(management=True) +def sd_protect(session: "Session", operation: messages.SdProtectOperationType) -> str: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -307,9 +308,9 @@ def sd_protect( off - Remove SD card secret protection. refresh - Replace the current SD card secret with a new one. """ - if client.features.model == "1": + if session.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - return device.sd_protect(client, operation) + return device.sd_protect(session, operation) @cli.command() @@ -319,24 +320,24 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str: Currently only supported on Trezor Model One. """ - # avoid using @with_client because it closes the session afterwards, + # avoid using @with_management_session because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - return device.reboot_to_bootloader(client) + return device.reboot_to_bootloader(client.get_management_session()) @cli.command() -@with_client -def tutorial(client: "TrezorClient") -> str: +@with_session(management=True) +def tutorial(session: "Session") -> str: """Show on-device tutorial.""" - return device.show_device_tutorial(client) + return device.show_device_tutorial(session) @cli.command() -@with_client -def unlock_bootloader(client: "TrezorClient") -> str: +@with_session(management=True) +def unlock_bootloader(session: "Session") -> str: """Unlocks bootloader. Irreversible.""" - return device.unlock_bootloader(client) + return device.unlock_bootloader(session) @cli.command() @@ -347,11 +348,11 @@ def unlock_bootloader(client: "TrezorClient") -> str: type=int, help="Dialog expiry in seconds.", ) -@with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str: +@with_session(management=True) +def set_busy(session: "Session", enable: bool | None, expiry: int | None) -> str: """Show a "Do not disconnect" dialog.""" if enable is False: - return device.set_busy(client, None) + return device.set_busy(session, None) if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -361,7 +362,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - return device.set_busy(client, expiry * 1000) + return device.set_busy(session, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( @@ -381,9 +382,9 @@ PUBKEY_WHITELIST_URL_TEMPLATE = ( is_flag=True, help="Do not check intermediate certificates against the whitelist.", ) -@with_client +@with_session(management=True) def authenticate( - client: "TrezorClient", + session: "Session", hex_challenge: str | None, root: t.BinaryIO | None, raw: bool | None, @@ -408,7 +409,7 @@ def authenticate( challenge = bytes.fromhex(hex_challenge) if raw: - msg = device.authenticate(client, challenge) + msg = device.authenticate(session, challenge) click.echo(f"Challenge: {hex_challenge}") click.echo(f"Signature of challenge: {msg.signature.hex()}") @@ -456,14 +457,14 @@ def authenticate( else: whitelist_json = requests.get( PUBKEY_WHITELIST_URL_TEMPLATE.format( - model=client.model.internal_name.lower() + model=session.model.internal_name.lower() ) ).json() whitelist = [bytes.fromhex(pk) for pk in whitelist_json["ca_pubkeys"]] try: authentication.authenticate_device( - client, challenge, root_pubkey=root_bytes, whitelist=whitelist + session, challenge, root_pubkey=root_bytes, whitelist=whitelist ) except authentication.DeviceNotAuthentic: click.echo("Device is not authentic.") diff --git a/python/src/trezorlib/cli/eos.py b/python/src/trezorlib/cli/eos.py index 84c248c4a4..27d461d8b0 100644 --- a/python/src/trezorlib/cli/eos.py +++ b/python/src/trezorlib/cli/eos.py @@ -20,11 +20,11 @@ from typing import TYPE_CHECKING, TextIO import click from .. import eos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: from .. import messages - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/194h/0h/0/0" @@ -37,11 +37,11 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Eos public key in base58 encoding.""" address_n = tools.parse_path(address) - res = eos.get_public_key(client, address_n, show_display) + res = eos.get_public_key(session, address_n, show_display) return f"WIF: {res.wif_public_key}\nRaw: {res.raw_public_key.hex()}" @@ -50,16 +50,16 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_transaction( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> "messages.EosSignedTx": """Sign EOS transaction.""" tx_json = json.load(file) address_n = tools.parse_path(address) return eos.sign_tx( - client, + session, address_n, tx_json["transaction"], tx_json["chain_id"], diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index 6bbfc0d356..d810d2bf2d 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -26,14 +26,14 @@ import click from .. import _rlp, definitions, ethereum, tools from ..messages import EthereumDefinitions -from . import with_client +from . import with_session if TYPE_CHECKING: import web3 from eth_typing import ChecksumAddress # noqa: I900 from web3.types import Wei - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/60h/0h/0/0" @@ -268,24 +268,24 @@ def cli( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - return ethereum.get_address(client, address_n, show_display, network, chunkify) + return ethereum.get_address(session, address_n, show_display, network, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> dict: +@with_session +def get_public_node(session: "Session", address: str, show_display: bool) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) - result = ethereum.get_public_node(client, address_n, show_display=show_display) + result = ethereum.get_public_node(session, address_n, show_display=show_display) return { "node": { "depth": result.node.depth, @@ -344,9 +344,9 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-C", "--chunkify", is_flag=True) @click.argument("to_address") @click.argument("amount", callback=_amount_to_int) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", chain_id: int, address: str, amount: int, @@ -400,7 +400,7 @@ def sign_tx( encoded_network = DEFINITIONS_SOURCE.get_network(chain_id) address_n = tools.parse_path(address) from_address = ethereum.get_address( - client, address_n, encoded_network=encoded_network + session, address_n, encoded_network=encoded_network ) if token: @@ -446,7 +446,7 @@ def sign_tx( assert max_gas_fee is not None assert max_priority_fee is not None sig = ethereum.sign_tx_eip1559( - client, + session, n=address_n, nonce=nonce, gas_limit=gas_limit, @@ -465,7 +465,7 @@ def sign_tx( gas_price = _get_web3().eth.gas_price assert gas_price is not None sig = ethereum.sign_tx( - client, + session, n=address_n, tx_type=tx_type, nonce=nonce, @@ -526,14 +526,14 @@ def sign_tx( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-C", "--chunkify", is_flag=True) @click.argument("message") -@with_client +@with_session def sign_message( - client: "TrezorClient", address: str, message: str, chunkify: bool + session: "Session", address: str, message: str, chunkify: bool ) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) - ret = ethereum.sign_message(client, address_n, message, network, chunkify=chunkify) + ret = ethereum.sign_message(session, address_n, message, network, chunkify=chunkify) output = { "message": message, "address": ret.address, @@ -550,9 +550,9 @@ def sign_message( help="Be compatible with Metamask's signTypedData_v4 implementation", ) @click.argument("file", type=click.File("r")) -@with_client +@with_session def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO + session: "Session", address: str, metamask_v4_compat: bool, file: TextIO ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -565,7 +565,7 @@ def sign_typed_data( defs = EthereumDefinitions(encoded_network=network) data = json.loads(file.read()) ret = ethereum.sign_typed_data( - client, + session, address_n, data, metamask_v4_compat=metamask_v4_compat, @@ -583,9 +583,9 @@ def sign_typed_data( @click.argument("address") @click.argument("signature") @click.argument("message") -@with_client +@with_session def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: str, message: str, @@ -594,7 +594,7 @@ def verify_message( """Verify message signed with Ethereum address.""" signature_bytes = ethereum.decode_hex(signature) return ethereum.verify_message( - client, address, signature_bytes, message, chunkify=chunkify + session, address, signature_bytes, message, chunkify=chunkify ) @@ -602,9 +602,9 @@ def verify_message( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.argument("domain_hash_hex") @click.argument("message_hash_hex") -@with_client +@with_session def sign_typed_data_hash( - client: "TrezorClient", address: str, domain_hash_hex: str, message_hash_hex: str + session: "Session", address: str, domain_hash_hex: str, message_hash_hex: str ) -> Dict[str, str]: """ Sign hash of typed data (EIP-712) with Ethereum address. @@ -618,7 +618,7 @@ def sign_typed_data_hash( message_hash = ethereum.decode_hex(message_hash_hex) if message_hash_hex else None network = ethereum.network_from_address_n(address_n, DEFINITIONS_SOURCE) ret = ethereum.sign_typed_data_hash( - client, address_n, domain_hash, message_hash, network + session, address_n, domain_hash, message_hash, network ) output = { "domain_hash": domain_hash_hex, diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 5983c57249..024a0bf63f 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING import click from .. import fido -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session ALGORITHM_NAME = {-7: "ES256 (ECDSA w/ SHA-256)", -8: "EdDSA"} @@ -40,10 +40,10 @@ def credentials() -> None: @credentials.command(name="list") -@with_client -def credentials_list(client: "TrezorClient") -> None: +@with_session(empty_passphrase=True) +def credentials_list(session: "Session") -> None: """List all resident credentials on the device.""" - creds = fido.list_credentials(client) + creds = fido.list_credentials(session) for cred in creds: click.echo("") click.echo(f"WebAuthn credential at index {cred.index}:") @@ -79,23 +79,23 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") -@with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: +@with_session(empty_passphrase=True) +def credentials_add(session: "Session", hex_credential_id: str) -> str: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - return fido.add_credential(client, bytes.fromhex(hex_credential_id)) + return fido.add_credential(session, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @click.option( "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) -@with_client -def credentials_remove(client: "TrezorClient", index: int) -> str: +@with_session(empty_passphrase=True) +def credentials_remove(session: "Session", index: int) -> str: """Remove the resident credential at the given index.""" - return fido.remove_credential(client, index) + return fido.remove_credential(session, index) # @@ -110,19 +110,19 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) -@with_client -def counter_set(client: "TrezorClient", counter: int) -> str: +@with_session(empty_passphrase=True) +def counter_set(session: "Session", counter: int) -> str: """Set FIDO/U2F counter value.""" - return fido.set_counter(client, counter) + return fido.set_counter(session, counter) @counter.command(name="get-next") -@with_client -def counter_get_next(client: "TrezorClient") -> int: +@with_session(empty_passphrase=True) +def counter_get_next(session: "Session") -> int: """Get-and-increase value of FIDO/U2F counter. FIDO counter value cannot be read directly. On each U2F exchange, the counter value is returned and atomically increased. This command performs the same operation and returns the counter value. """ - return fido.get_next_counter(client) + return fido.get_next_counter(session) diff --git a/python/src/trezorlib/cli/firmware.py b/python/src/trezorlib/cli/firmware.py index 4376a4f283..37a393cb4c 100644 --- a/python/src/trezorlib/cli/firmware.py +++ b/python/src/trezorlib/cli/firmware.py @@ -37,10 +37,11 @@ import requests from .. import device, exceptions, firmware, messages, models from ..firmware import models as fw_models from ..models import TrezorModel -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: from ..client import TrezorClient + from ..transport.session import Session from . import TrezorConnection MODEL_CHOICE = ChoiceType( @@ -74,9 +75,9 @@ def _is_bootloader_onev2(client: "TrezorClient") -> bool: This is the case from bootloader version 1.8.0, and also holds for firmware version 1.8.0 because that installs the appropriate bootloader. """ - f = client.features - version = (f.major_version, f.minor_version, f.patch_version) - bootloader_onev2 = f.major_version == 1 and version >= (1, 8, 0) + features = client.features + version = client.version + bootloader_onev2 = features.major_version == 1 and version >= (1, 8, 0) return bootloader_onev2 @@ -306,25 +307,26 @@ def find_best_firmware_version( If the specified version is not found, prints the closest available version (higher than the specified one, if existing). """ + features = client.features + model = client.model + if bitcoin_only is None: - bitcoin_only = _should_use_bitcoin_only(client.features) + bitcoin_only = _should_use_bitcoin_only(features) def version_str(version: Iterable[int]) -> str: return ".".join(map(str, version)) - f = client.features - - releases = get_all_firmware_releases(client.model, bitcoin_only, beta) + releases = get_all_firmware_releases(model, bitcoin_only, beta) highest_version = releases[0]["version"] if version: want_version = [int(x) for x in version.split(".")] if len(want_version) != 3: click.echo("Please use the 'X.Y.Z' version format.") - if want_version[0] != f.major_version: + if want_version[0] != features.major_version: click.echo( - f"Warning: Trezor {client.model.name} firmware version should be " - f"{f.major_version}.X.Y (requested: {version})" + f"Warning: Trezor {model.name} firmware version should be " + f"{features.major_version}.X.Y (requested: {version})" ) else: want_version = highest_version @@ -359,8 +361,8 @@ def find_best_firmware_version( # to the newer one, in that case update to the minimal # compatible version first # Choosing the version key to compare based on (not) being in BL mode - client_version = [f.major_version, f.minor_version, f.patch_version] - if f.bootloader_mode: + client_version = client.version + if features.bootloader_mode: key_to_compare = "min_bootloader_version" else: key_to_compare = "min_firmware_version" @@ -447,11 +449,11 @@ def extract_embedded_fw( def upload_firmware_into_device( - client: "TrezorClient", + session: "Session", firmware_data: bytes, ) -> None: """Perform the final act of loading the firmware into Trezor.""" - f = client.features + f = session.features try: if f.major_version == 1 and f.firmware_present is not False: # Trezor One does not send ButtonRequest @@ -461,7 +463,7 @@ def upload_firmware_into_device( with click.progressbar( label="Uploading", length=len(firmware_data), show_eta=False ) as bar: - firmware.update(client, firmware_data, bar.update) + firmware.update(session, firmware_data, bar.update) except exceptions.Cancelled: click.echo("Update aborted on device.") except exceptions.TrezorException as e: @@ -654,6 +656,7 @@ def update( against data.trezor.io information, if available. """ with obj.client_context() as client: + management_session = client.get_management_session() if sum(bool(x) for x in (filename, url, version)) > 1: click.echo("You can use only one of: filename, url, version.") sys.exit(1) @@ -709,7 +712,7 @@ def update( if _is_strict_update(client, firmware_data): header_size = _get_firmware_header_size(firmware_data) device.reboot_to_bootloader( - client, + management_session, boot_command=messages.BootCommand.INSTALL_UPGRADE, firmware_header=firmware_data[:header_size], language_data=language_data, @@ -719,7 +722,7 @@ def update( click.echo( "WARNING: Seamless installation not possible, language data will not be uploaded." ) - device.reboot_to_bootloader(client) + device.reboot_to_bootloader(management_session) click.echo("Waiting for bootloader...") while True: @@ -735,13 +738,15 @@ def update( click.echo("Please switch your device to bootloader mode.") sys.exit(1) - upload_firmware_into_device(client=client, firmware_data=firmware_data) + upload_firmware_into_device( + session=client.get_management_session(), firmware_data=firmware_data + ) @cli.command() @click.argument("hex_challenge", required=False) -@with_client -def get_hash(client: "TrezorClient", hex_challenge: Optional[str]) -> str: +@with_session(management=True) +def get_hash(session: "Session", hex_challenge: Optional[str]) -> str: """Get a hash of the installed firmware combined with the optional challenge.""" challenge = bytes.fromhex(hex_challenge) if hex_challenge else None - return firmware.get_hash(client, challenge).hex() + return firmware.get_hash(session, challenge).hex() diff --git a/python/src/trezorlib/cli/monero.py b/python/src/trezorlib/cli/monero.py index 355c562ae3..0441ebc09b 100644 --- a/python/src/trezorlib/cli/monero.py +++ b/python/src/trezorlib/cli/monero.py @@ -19,10 +19,10 @@ from typing import TYPE_CHECKING, Dict import click from .. import messages, monero, tools -from . import ChoiceType, with_client +from . import ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/128h/0h" @@ -42,9 +42,9 @@ def cli() -> None: default=messages.MoneroNetworkType.MAINNET, ) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, network_type: messages.MoneroNetworkType, @@ -52,7 +52,7 @@ def get_address( ) -> bytes: """Get Monero address for specified path.""" address_n = tools.parse_path(address) - return monero.get_address(client, address_n, show_display, network_type, chunkify) + return monero.get_address(session, address_n, show_display, network_type, chunkify) @cli.command() @@ -63,13 +63,13 @@ def get_address( type=ChoiceType({m.name: m for m in messages.MoneroNetworkType}), default=messages.MoneroNetworkType.MAINNET, ) -@with_client +@with_session def get_watch_key( - client: "TrezorClient", address: str, network_type: messages.MoneroNetworkType + session: "Session", address: str, network_type: messages.MoneroNetworkType ) -> Dict[str, str]: """Get Monero watch key for specified path.""" address_n = tools.parse_path(address) - res = monero.get_watch_key(client, address_n, network_type) + res = monero.get_watch_key(session, address_n, network_type) # TODO: could be made required in MoneroWatchKey assert res.address is not None assert res.watch_key is not None diff --git a/python/src/trezorlib/cli/nem.py b/python/src/trezorlib/cli/nem.py index 746ad18723..eac16c2d8c 100644 --- a/python/src/trezorlib/cli/nem.py +++ b/python/src/trezorlib/cli/nem.py @@ -21,10 +21,10 @@ import click import requests from .. import nem, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/134h/0h/0h" @@ -39,9 +39,9 @@ def cli() -> None: @click.option("-N", "--network", type=int, default=0x68) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, network: int, show_display: bool, @@ -49,7 +49,7 @@ def get_address( ) -> str: """Get NEM address for specified path.""" address_n = tools.parse_path(address) - return nem.get_address(client, address_n, network, show_display, chunkify) + return nem.get_address(session, address_n, network, show_display, chunkify) @cli.command() @@ -58,9 +58,9 @@ def get_address( @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-b", "--broadcast", help="NIS to announce transaction to") @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, file: TextIO, broadcast: Optional[str], @@ -71,7 +71,7 @@ def sign_tx( Transaction file is expected in the NIS (RequestPrepareAnnounce) format. """ address_n = tools.parse_path(address) - transaction = nem.sign_tx(client, address_n, json.load(file), chunkify=chunkify) + transaction = nem.sign_tx(session, address_n, json.load(file), chunkify=chunkify) payload = {"data": transaction.data.hex(), "signature": transaction.signature.hex()} diff --git a/python/src/trezorlib/cli/ripple.py b/python/src/trezorlib/cli/ripple.py index e4bcc0b350..634a92028e 100644 --- a/python/src/trezorlib/cli/ripple.py +++ b/python/src/trezorlib/cli/ripple.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO import click from .. import ripple, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/144h/0h/0/0" @@ -37,13 +37,13 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Ripple address""" address_n = tools.parse_path(address) - return ripple.get_address(client, address_n, show_display, chunkify) + return ripple.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -51,13 +51,13 @@ def get_address( @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client -def sign_tx(client: "TrezorClient", address: str, file: TextIO, chunkify: bool) -> None: +@with_session +def sign_tx(session: "Session", address: str, file: TextIO, chunkify: bool) -> None: """Sign Ripple transaction""" address_n = tools.parse_path(address) msg = ripple.create_sign_tx_msg(json.load(file)) - result = ripple.sign_tx(client, address_n, msg, chunkify=chunkify) + result = ripple.sign_tx(session, address_n, msg, chunkify=chunkify) click.echo("Signature:") click.echo(result.signature.hex()) click.echo() diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index eac93eb796..d5e615750d 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -24,10 +24,11 @@ import click import requests from .. import device, messages, toif -from . import AliasedGroup, ChoiceType, with_client +from ..transport.session import Session +from . import AliasedGroup, ChoiceType, with_session if TYPE_CHECKING: - from ..client import TrezorClient + pass try: from PIL import Image @@ -180,18 +181,18 @@ def cli() -> None: @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +@with_session(management=True) +def pin(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - return device.change_pin(client, remove=_should_remove(enable, remove)) + return device.change_pin(session, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) -@with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +@with_session(management=True) +def wipe_code(session: "Session", enable: Optional[bool], remove: bool) -> str: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -199,32 +200,32 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - return device.change_wipe_code(client, remove=_should_remove(enable, remove)) + return device.change_wipe_code(session, remove=_should_remove(enable, remove)) @cli.command() # keep the deprecated -l/--label option, make it do nothing @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") -@with_client -def label(client: "TrezorClient", label: str) -> str: +@with_session(management=True) +def label(session: "Session", label: str) -> str: """Set new device label.""" - return device.apply_settings(client, label=label) + return device.apply_settings(session, label=label) @cli.command() -@with_client -def brightness(client: "TrezorClient") -> str: +@with_session(management=True) +def brightness(session: "Session") -> str: """Set display brightness.""" - return device.set_brightness(client) + return device.set_brightness(session) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> str: +@with_session(management=True) +def haptic_feedback(session: "Session", enable: bool) -> str: """Enable or disable haptic feedback.""" - return device.apply_settings(client, haptic_feedback=enable) + return device.apply_settings(session, haptic_feedback=enable) @cli.command() @@ -233,9 +234,9 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str: "-r", "--remove", is_flag=True, default=False, help="Switch back to english." ) @click.option("-d/-D", "--display/--no-display", default=None) -@with_client +@with_session(management=True) def language( - client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None + session: "Session", path_or_url: str | None, remove: bool, display: bool | None ) -> str: """Set new language with translations.""" if remove != (path_or_url is None): @@ -260,29 +261,29 @@ def language( f"Failed to load translations from {path_or_url}" ) from None return device.change_language( - client, language_data=language_data, show_display=display + session, language_data=language_data, show_display=display ) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) -@with_client -def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str: +@with_session(management=True) +def display_rotation(session: "Session", rotation: messages.DisplayRotation) -> str: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - return device.apply_settings(client, display_rotation=rotation) + return device.apply_settings(session, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) -@with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> str: +@with_session(management=True) +def auto_lock_delay(session: "Session", delay: str) -> str: """Set auto-lock delay (in seconds).""" - if not client.features.pin_protection: + if not session.features.pin_protection: raise click.ClickException("Set up a PIN first") value, unit = delay[:-1], delay[-1:] @@ -291,13 +292,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + return device.apply_settings(session, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") -@with_client -def flags(client: "TrezorClient", flags: str) -> str: +@with_session(management=True) +def flags(session: "Session", flags: str) -> str: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -305,7 +306,7 @@ def flags(client: "TrezorClient", flags: str) -> str: flags_int = int(flags, 16) else: flags_int = int(flags) - return device.apply_flags(client, flags=flags_int) + return device.apply_flags(session, flags=flags_int) @cli.command() @@ -314,8 +315,8 @@ def flags(client: "TrezorClient", flags: str) -> str: "-f", "--filename", "_ignore", is_flag=True, hidden=True, expose_value=False ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") -@with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: +@with_session(management=True) +def homescreen(session: "Session", filename: str, quality: int) -> str: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -327,39 +328,39 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: if not path.exists() or not path.is_file(): raise click.ClickException("Cannot open file") - if client.features.model == "1": + if session.features.model == "1": img = image_to_t1(path) else: - if client.features.homescreen_format == messages.HomescreenFormat.Jpeg: + if session.features.homescreen_format == messages.HomescreenFormat.Jpeg: width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 240 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 240 ) img = image_to_jpeg(path, width, height, quality) - elif client.features.homescreen_format == messages.HomescreenFormat.ToiG: - width = client.features.homescreen_width - height = client.features.homescreen_height + elif session.features.homescreen_format == messages.HomescreenFormat.ToiG: + width = session.features.homescreen_width + height = session.features.homescreen_height if width is None or height is None: raise click.ClickException("Device did not report homescreen size.") img = image_to_toif(path, width, height, True) elif ( - client.features.homescreen_format == messages.HomescreenFormat.Toif - or client.features.homescreen_format is None + session.features.homescreen_format == messages.HomescreenFormat.Toif + or session.features.homescreen_format is None ): width = ( - client.features.homescreen_width - if client.features.homescreen_width is not None + session.features.homescreen_width + if session.features.homescreen_width is not None else 144 ) height = ( - client.features.homescreen_height - if client.features.homescreen_height is not None + session.features.homescreen_height + if session.features.homescreen_height is not None else 144 ) img = image_to_toif(path, width, height, False) @@ -369,7 +370,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "Unknown image format requested by the device." ) - return device.apply_settings(client, homescreen=img) + return device.apply_settings(session, homescreen=img) @cli.command() @@ -377,9 +378,9 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "--always", is_flag=True, help='Persist the "prompt" setting across Trezor reboots.' ) @click.argument("level", type=ChoiceType(SAFETY_LEVELS)) -@with_client +@with_session(management=True) def safety_checks( - client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel + session: "Session", always: bool, level: messages.SafetyCheckLevel ) -> str: """Set safety check level. @@ -392,18 +393,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - return device.apply_settings(client, safety_checks=level) + return device.apply_settings(session, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) -@with_client -def experimental_features(client: "TrezorClient", enable: bool) -> str: +@with_session(management=True) +def experimental_features(session: "Session", enable: bool) -> str: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - return device.apply_settings(client, experimental_features=enable) + return device.apply_settings(session, experimental_features=enable) # @@ -426,25 +427,25 @@ passphrase = cast(AliasedGroup, passphrase_main) @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) -@with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str: +@with_session(management=True) +def passphrase_on(session: "Session", force_on_device: Optional[bool]) -> str: """Enable passphrase.""" - if client.features.passphrase_protection is not True: + if session.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None return device.apply_settings( - client, + session, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, ) @passphrase.command(name="off") -@with_client -def passphrase_off(client: "TrezorClient") -> str: +@with_session(management=True) +def passphrase_off(session: "Session") -> str: """Disable passphrase.""" - return device.apply_settings(client, use_passphrase=False) + return device.apply_settings(session, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -457,10 +458,10 @@ passphrase.aliases = { @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) -@with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str: +@with_session(management=True) +def hide_passphrase_from_host(session: "Session", hide: bool) -> str: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - return device.apply_settings(client, hide_passphrase_from_host=hide) + return device.apply_settings(session, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 3fe80a5164..8152116b55 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -4,10 +4,10 @@ from typing import TYPE_CHECKING, Optional, TextIO import click from .. import messages, solana, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path to key, e.g. m/44h/501h/0h/0h" DEFAULT_PATH = "m/44h/501h/0h/0h" @@ -21,40 +21,40 @@ def cli() -> None: @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client +@with_session def get_public_key( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, ) -> messages.SolanaPublicKey: """Get Solana public key.""" address_n = tools.parse_path(address) - return solana.get_public_key(client, address_n, show_display) + return solana.get_public_key(session, address_n, show_display) @cli.command() @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", + session: "Session", address: str, show_display: bool, chunkify: bool, ) -> messages.SolanaAddress: """Get Solana address.""" address_n = tools.parse_path(address) - return solana.get_address(client, address_n, show_display, chunkify) + return solana.get_address(session, address_n, show_display, chunkify) @cli.command() @click.argument("serialized_tx", type=str) @click.option("-n", "--address", default=DEFAULT_PATH, help=PATH_HELP) @click.option("-a", "--additional-information-file", type=click.File("r")) -@with_client +@with_session def sign_tx( - client: "TrezorClient", + session: "Session", address: str, serialized_tx: str, additional_information_file: Optional[TextIO], @@ -78,7 +78,7 @@ def sign_tx( ) return solana.sign_tx( - client, + session, address_n, bytes.fromhex(serialized_tx), additional_information, diff --git a/python/src/trezorlib/cli/stellar.py b/python/src/trezorlib/cli/stellar.py index 77ce700ee5..9acb6a57ed 100644 --- a/python/src/trezorlib/cli/stellar.py +++ b/python/src/trezorlib/cli/stellar.py @@ -21,10 +21,10 @@ from typing import TYPE_CHECKING import click from .. import stellar, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session try: from stellar_sdk import ( @@ -52,13 +52,13 @@ def cli() -> None: ) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Stellar public address.""" address_n = tools.parse_path(address) - return stellar.get_address(client, address_n, show_display, chunkify) + return stellar.get_address(session, address_n, show_display, chunkify) @cli.command() @@ -77,9 +77,9 @@ def get_address( help="Network passphrase (blank for public network).", ) @click.argument("b64envelope") -@with_client +@with_session def sign_transaction( - client: "TrezorClient", b64envelope: str, address: str, network_passphrase: str + session: "Session", b64envelope: str, address: str, network_passphrase: str ) -> bytes: """Sign a base64-encoded transaction envelope. @@ -109,6 +109,6 @@ def sign_transaction( address_n = tools.parse_path(address) tx, operations = stellar.from_envelope(envelope) - resp = stellar.sign_tx(client, tx, operations, address_n, network_passphrase) + resp = stellar.sign_tx(session, tx, operations, address_n, network_passphrase) return base64.b64encode(resp.signature) diff --git a/python/src/trezorlib/cli/tezos.py b/python/src/trezorlib/cli/tezos.py index 7dcd1ab9db..e4f0c1a877 100644 --- a/python/src/trezorlib/cli/tezos.py +++ b/python/src/trezorlib/cli/tezos.py @@ -20,10 +20,10 @@ from typing import TYPE_CHECKING, TextIO import click from .. import messages, protobuf, tezos, tools -from . import with_client +from . import with_session if TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session PATH_HELP = "BIP-32 path, e.g. m/44h/1729h/0h" @@ -37,23 +37,23 @@ def cli() -> None: @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def get_address( - client: "TrezorClient", address: str, show_display: bool, chunkify: bool + session: "Session", address: str, show_display: bool, chunkify: bool ) -> str: """Get Tezos address for specified path.""" address_n = tools.parse_path(address) - return tezos.get_address(client, address_n, show_display, chunkify) + return tezos.get_address(session, address_n, show_display, chunkify) @cli.command() @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-d", "--show-display", is_flag=True) -@with_client -def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> str: +@with_session +def get_public_key(session: "Session", address: str, show_display: bool) -> str: """Get Tezos public key.""" address_n = tools.parse_path(address) - return tezos.get_public_key(client, address_n, show_display) + return tezos.get_public_key(session, address_n, show_display) @cli.command() @@ -61,11 +61,11 @@ def get_public_key(client: "TrezorClient", address: str, show_display: bool) -> @click.option("-n", "--address", required=True, help=PATH_HELP) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-C", "--chunkify", is_flag=True) -@with_client +@with_session def sign_tx( - client: "TrezorClient", address: str, file: TextIO, chunkify: bool + session: "Session", address: str, file: TextIO, chunkify: bool ) -> messages.TezosSignedTx: """Sign Tezos transaction.""" address_n = tools.parse_path(address) msg = protobuf.dict_to_proto(messages.TezosSignTx, json.load(file)) - return tezos.sign_tx(client, address_n, msg, chunkify=chunkify) + return tezos.sign_tx(session, address_n, msg, chunkify=chunkify) diff --git a/python/src/trezorlib/cli/trezorctl.py b/python/src/trezorlib/cli/trezorctl.py index 60f8e8d309..b3a885e4c8 100755 --- a/python/src/trezorlib/cli/trezorctl.py +++ b/python/src/trezorlib/cli/trezorctl.py @@ -24,9 +24,12 @@ from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar, ca import click -from .. import __version__, log, messages, protobuf, ui -from ..client import TrezorClient +from .. import __version__, log, messages, protobuf +from ..client import ProtocolVersion, TrezorClient from ..transport import DeviceIsBusy, enumerate_devices +from ..transport.session import Session +from ..transport.thp import channel_database +from ..transport.thp.channel_database import get_channel_db from ..transport.udp import UdpTransport from . import ( AliasedGroup, @@ -50,6 +53,7 @@ from . import ( stellar, tezos, with_client, + with_session, ) F = TypeVar("F", bound=Callable) @@ -193,6 +197,13 @@ def configure_logging(verbose: int) -> None: "--record", help="Record screen changes into a specified directory.", ) +@click.option( + "-n", + "--no-store", + is_flag=True, + help="Do not store channels data between commands.", + default=False, +) @click.version_option(version=__version__) @click.pass_context def cli_main( @@ -204,9 +215,10 @@ def cli_main( script: bool, session_id: Optional[str], record: Optional[str], + no_store: bool, ) -> None: configure_logging(verbose) - + channel_database.set_channel_database(should_not_store=no_store) bytes_session_id: Optional[bytes] = None if session_id is not None: try: @@ -214,6 +226,7 @@ def cli_main( except ValueError: raise click.ClickException(f"Not a valid session id: {session_id}") + # ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) ctx.obj = TrezorConnection(path, bytes_session_id, passphrase_on_host, script) # Optionally record the screen into a specified directory. @@ -285,18 +298,23 @@ def format_device_name(features: messages.Features) -> str: def list_devices(no_resolve: bool) -> Optional[Iterable["Transport"]]: """List connected Trezor devices.""" if no_resolve: - return enumerate_devices() + for d in enumerate_devices(): + print(d.get_path()) + return + + from . import get_client for transport in enumerate_devices(): try: - client = TrezorClient(transport, ui=ui.ClickUI()) + client = get_client(transport) description = format_device_name(client.features) - client.end_session() + if client.protocol_version == ProtocolVersion.PROTOCOL_V2: + get_channel_db().save_channel(client.protocol) except DeviceIsBusy: description = "Device is in use by another process" - except Exception: - description = "Failed to read details" - click.echo(f"{transport} - {description}") + except Exception as e: + description = "Failed to read details " + str(type(e)) + click.echo(f"{transport.get_path()} - {description}") return None @@ -314,15 +332,19 @@ def version() -> str: @cli.command() @click.argument("message") @click.option("-b", "--button-protection", is_flag=True) -@with_client -def ping(client: "TrezorClient", message: str, button_protection: bool) -> str: +@with_session(empty_passphrase=True) +def ping(session: "Session", message: str, button_protection: bool) -> str: """Send ping message.""" - return client.ping(message, button_protection=button_protection) + + # TODO return short-circuit from old client for old Trezors + return session.ping(message, button_protection) @cli.command() @click.pass_obj -def get_session(obj: TrezorConnection) -> str: +def get_session( + obj: TrezorConnection, passphrase: str = "", derive_cardano: bool = False +) -> str: """Get a session ID for subsequent commands. Unlocks Trezor with a passphrase and returns a session ID. Use this session ID with @@ -336,23 +358,44 @@ def get_session(obj: TrezorConnection) -> str: obj.session_id = None with obj.client_context() as client: + if client.features.model == "1" and client.version < (1, 9, 0): raise click.ClickException( "Upgrade your firmware to enable session support." ) - client.ensure_unlocked() - if client.session_id is None: + # client.ensure_unlocked() + session = client.get_session( + passphrase=passphrase, derive_cardano=derive_cardano + ) + if session.id is None: raise click.ClickException("Passphrase not enabled or firmware too old.") else: - return client.session_id.hex() + return session.id.hex() @cli.command() -@with_client -def clear_session(client: "TrezorClient") -> None: +@with_session(must_resume=True, empty_passphrase=True) +def clear_session(session: "Session") -> None: """Clear session (remove cached PIN, passphrase, etc.).""" - return client.clear_session() + if session is None: + click.echo("Cannot clear session as it was not properly resumed.") + return + session.call(messages.LockDevice()) + session.end() + # TODO different behaviour than main, not sure if ok + + +@cli.command() +def delete_channels() -> None: + """ + Delete cached channels. + + Do not use together with the `-n` (`--no-store`) flag, + as the JSON database will not be deleted in that case. + """ + get_channel_db().clear_stored_channels() + click.echo("Deleted stored channels") @cli.command() diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index 2ec853dfd3..d82554dd93 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -13,25 +13,24 @@ # # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations import logging import os -import warnings -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar +import typing as t +from enum import IntEnum -from mnemonic import Mnemonic +from . import mapping, messages, models +from .mapping import ProtobufMapping +from .tools import parse_path +from .transport import Transport, get_transport +from .transport.thp.channel_data import ChannelData +from .transport.thp.protocol_and_channel import ProtocolAndChannel +from .transport.thp.protocol_v1 import ProtocolV1 +from .transport.thp.protocol_v2 import ProtocolV2 -from . import exceptions, mapping, messages, models -from .log import DUMP_BYTES -from .messages import Capability -from .tools import expect, parse_path, session - -if TYPE_CHECKING: - from .protobuf import MessageType - from .transport import Transport - from .ui import TrezorClientUI - -UI = TypeVar("UI", bound="TrezorClientUI") +if t.TYPE_CHECKING: + from .transport.session import Session LOG = logging.getLogger(__name__) @@ -48,8 +47,196 @@ Or visit https://suite.trezor.io/ """.strip() +LOG = logging.getLogger(__name__) + + +class ProtocolVersion(IntEnum): + UNKNOWN = 0x00 + PROTOCOL_V1 = 0x01 # Codec + PROTOCOL_V2 = 0x02 # THP + + +class TrezorClient: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + _management_session: Session | None = None + _features: messages.Features | None = None + _protocol_version: int + _has_setup_pin: bool = False # Should by used only by conftest + + def __init__( + self, + transport: Transport, + protobuf_mapping: ProtobufMapping | None = None, + protocol: ProtocolAndChannel | None = None, + ) -> None: + self.transport = transport + + if protobuf_mapping is None: + self.mapping = mapping.DEFAULT_MAPPING + else: + self.mapping = protobuf_mapping + if protocol is None: + self.protocol = self._get_protocol() + else: + self.protocol = protocol + self.protocol.mapping = self.mapping + + if isinstance(self.protocol, ProtocolV1): + self._protocol_version = ProtocolVersion.PROTOCOL_V1 + elif isinstance(self.protocol, ProtocolV2): + self._protocol_version = ProtocolVersion.PROTOCOL_V2 + else: + self._protocol_version = ProtocolVersion.UNKNOWN + + @classmethod + def resume( + cls, + transport: Transport, + channel_data: ChannelData, + protobuf_mapping: ProtobufMapping | None = None, + ) -> TrezorClient: + if protobuf_mapping is None: + protobuf_mapping = mapping.DEFAULT_MAPPING + protocol_v1 = ProtocolV1(transport, protobuf_mapping) + if channel_data.protocol_version == 2: + try: + protocol_v1.write(messages.Ping(message="Sanity check - to resume")) + except Exception as e: + print(type(e)) + response = protocol_v1.read() + if ( + isinstance(response, messages.Failure) + and response.code == messages.FailureType.InvalidProtocol + ): + protocol = ProtocolV2(transport, protobuf_mapping, channel_data) + protocol.write(0, messages.Ping()) + response = protocol.read(0) + if not isinstance(response, messages.Success): + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + LOG.debug("Protocol V2 detected - can be resumed") + else: + LOG.debug("Failed to resume ProtocolV2") + raise Exception("Failed to resume ProtocolV2") + else: + protocol = ProtocolV1(transport, protobuf_mapping, channel_data) + return TrezorClient(transport, protobuf_mapping, protocol) + + def get_session( + self, + passphrase: str | object | None = None, + derive_cardano: bool = False, + ) -> Session: + """ + Returns initialized session (with derived seed). + + Will fail if the device is not initialized + """ + from .transport.session import SessionV1, SessionV2 + + if isinstance(self.protocol, ProtocolV1): + if passphrase is None: + passphrase = "" + return SessionV1.new(self, passphrase, derive_cardano) + if isinstance(self.protocol, ProtocolV2): + assert isinstance(passphrase, str) or passphrase is None + return SessionV2.new(self, passphrase, derive_cardano) + raise NotImplementedError # TODO + + def resume_session(self, session: Session): + """ + Note: this function potentially modifies the input session. + """ + from .debuglink import SessionDebugWrapper + from .transport.session import SessionV1, SessionV2 + + if isinstance(session, SessionDebugWrapper): + session = session._session + + if isinstance(session, SessionV2): + return session + elif isinstance(session, SessionV1): + session.init_session() + return session + + else: + raise NotImplementedError + + def get_management_session(self, new_session: bool = False) -> Session: + from .transport.session import SessionV1, SessionV2 + + if not new_session and self._management_session is not None: + return self._management_session + if isinstance(self.protocol, ProtocolV1): + self._management_session = SessionV1.new( + client=self, + passphrase="", + derive_cardano=False, + ) + elif isinstance(self.protocol, ProtocolV2): + self._management_session = SessionV2(client=self, id=b"\x00") + assert self._management_session is not None + return self._management_session + + @property + def features(self) -> messages.Features: + if self._features is None: + self._features = self.protocol.get_features() + assert self._features is not None + return self._features + + @property + def protocol_version(self) -> int: + return self._protocol_version + + @property + def model(self) -> models.TrezorModel: + f = self.features + model = models.by_name(f.model or "1") + + if model is None: + raise RuntimeError( + "Unsupported Trezor model" + f" (internal_model: {f.internal_model}, model: {f.model})" + ) + return model + + @property + def version(self) -> tuple[int, int, int]: + f = self.features + ver = ( + f.major_version, + f.minor_version, + f.patch_version, + ) + return ver + + def refresh_features(self) -> None: + self.protocol.update_features() + self._features = self.protocol.get_features() + + def _get_protocol(self) -> ProtocolAndChannel: + self.transport.open() + + protocol = ProtocolV1(self.transport, mapping.DEFAULT_MAPPING) + + protocol.write(messages.Initialize()) + + response = protocol.read() + self.transport.close() + if isinstance(response, messages.Failure): + if response.code == messages.FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol = ProtocolV2(self.transport, self.mapping) + return protocol + + def get_default_client( - path: Optional[str] = None, ui: Optional["TrezorClientUI"] = None, **kwargs: Any + path: t.Optional[str] = None, + **kwargs: t.Any, ) -> "TrezorClient": """Get a client for a connected Trezor device. @@ -59,434 +246,454 @@ def get_default_client( the value of TREZOR_PATH env variable, or finds first connected Trezor. If no UI is supplied, instantiates the default CLI UI. """ - from .transport import get_transport - from .ui import ClickUI if path is None: path = os.getenv("TREZOR_PATH") transport = get_transport(path, prefix_search=True) - if ui is None: - ui = ClickUI() - - return TrezorClient(transport, ui, **kwargs) - - -class TrezorClient(Generic[UI]): - """Trezor client, a connection to a Trezor device. - - This class allows you to manage connection state, send and receive protobuf - messages, handle user interactions, and perform some generic tasks - (send a cancel message, initialize or clear a session, ping the device). - """ - - model: models.TrezorModel - transport: "Transport" - session_id: Optional[bytes] - ui: UI - features: messages.Features - - def __init__( - self, - transport: "Transport", - ui: UI, - session_id: Optional[bytes] = None, - derive_cardano: Optional[bool] = None, - model: Optional[models.TrezorModel] = None, - _init_device: bool = True, - ) -> None: - """Create a TrezorClient instance. - - You have to provide a `transport`, i.e., a raw connection to the device. You can - use `trezorlib.transport.get_transport` to find one. - - You have to provide a UI implementation for the three kinds of interaction: - - button request (notify the user that their interaction is needed) - - PIN request (on T1, ask the user to input numbers for a PIN matrix) - - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for - details. - - You can supply a `session_id` you might have saved in the previous session. If - you do, the user might not need to enter their passphrase again. - - You can provide Trezor model information. If not provided, it is detected from - the model name reported at initialization time. - - By default, the instance will open a connection to the Trezor device, send an - `Initialize` message, set up the `features` field from the response, and connect - to a session. By specifying `_init_device=False`, this step is skipped. Notably, - this means that `client.features` is unset. Use `client.init_device()` or - `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. - Only use this if you are _sure_ that you know what you are doing. This feature - might be removed at any time. - """ - LOG.info(f"creating client instance for device: {transport.get_path()}") - # Here, self.model could be set to None. Unless _init_device is False, it will - # get correctly reconfigured as part of the init_device flow. - self.model = model # type: ignore ["None" is incompatible with "TrezorModel"] - if self.model: - self.mapping = self.model.default_mapping - else: - self.mapping = mapping.DEFAULT_MAPPING - self.transport = transport - self.ui = ui - self.session_counter = 0 - self.session_id = session_id - if _init_device: - self.init_device(session_id=session_id, derive_cardano=derive_cardano) - - def open(self) -> None: - if self.session_counter == 0: - self.transport.begin_session() - self.session_counter += 1 - - def close(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - # TODO call EndSession here? - self.transport.end_session() - - def cancel(self) -> None: - self._raw_write(messages.Cancel()) - - def call_raw(self, msg: "MessageType") -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - self._raw_write(msg) - return self._raw_read() - - def _raw_write(self, msg: "MessageType") -> None: - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - LOG.debug( - f"sending message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - msg_type, msg_bytes = self.mapping.encode(msg) - LOG.log( - DUMP_BYTES, - f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - self.transport.write(msg_type, msg_bytes) - - def _raw_read(self) -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - msg_type, msg_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", - ) - msg = self.mapping.decode(msg_type, msg_bytes) - LOG.debug( - f"received message: {msg.__class__.__name__}", - extra={"protobuf": msg}, - ) - return msg - - def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": - try: - pin = self.ui.get_pin(msg.type) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if any(d not in "123456789" for d in pin) or not ( - 1 <= len(pin) <= MAX_PIN_LENGTH - ): - self.call_raw(messages.Cancel()) - raise ValueError("Invalid PIN provided") - - resp = self.call_raw(messages.PinMatrixAck(pin=pin)) - if isinstance(resp, messages.Failure) and resp.code in ( - messages.FailureType.PinInvalid, - messages.FailureType.PinCancelled, - messages.FailureType.PinExpected, - ): - raise exceptions.PinException(resp.code, resp.message) - else: - return resp - - def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": - available_on_device = Capability.PassphraseEntry in self.features.capabilities - - def send_passphrase( - passphrase: Optional[str] = None, on_device: Optional[bool] = None - ) -> "MessageType": - msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) - resp = self.call_raw(msg) - if isinstance(resp, messages.Deprecated_PassphraseStateRequest): - self.session_id = resp.state - resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) - return resp - - # short-circuit old style entry - if msg._on_device is True: - return send_passphrase(None, None) - - try: - passphrase = self.ui.get_passphrase(available_on_device=available_on_device) - except exceptions.Cancelled: - self.call_raw(messages.Cancel()) - raise - - if passphrase is PASSPHRASE_ON_DEVICE: - if not available_on_device: - self.call_raw(messages.Cancel()) - raise RuntimeError("Device is not capable of entering passphrase") - else: - return send_passphrase(on_device=True) - - # else process host-entered passphrase - if not isinstance(passphrase, str): - raise RuntimeError("Passphrase must be a str") - passphrase = Mnemonic.normalize_string(passphrase) - if len(passphrase) > MAX_PASSPHRASE_LENGTH: - self.call_raw(messages.Cancel()) - raise ValueError("Passphrase too long") - - return send_passphrase(passphrase, on_device=False) - - def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - # do this raw - send ButtonAck first, notify UI later - self._raw_write(messages.ButtonAck()) - self.ui.button_request(msg) - return self._raw_read() - - @session - def call(self, msg: "MessageType") -> "MessageType": - self.check_firmware_version() - resp = self.call_raw(msg) - while True: - if isinstance(resp, messages.PinMatrixRequest): - resp = self._callback_pin(resp) - elif isinstance(resp, messages.PassphraseRequest): - resp = self._callback_passphrase(resp) - elif isinstance(resp, messages.ButtonRequest): - resp = self._callback_button(resp) - elif isinstance(resp, messages.Failure): - if resp.code == messages.FailureType.ActionCancelled: - raise exceptions.Cancelled - raise exceptions.TrezorFailure(resp) - else: - return resp - - def _refresh_features(self, features: messages.Features) -> None: - """Update internal fields based on passed-in Features message.""" - - if not self.model: - # Trezor Model One bootloader 1.8.0 or older does not send model name - model = models.by_internal_name(features.internal_model) - if model is None: - model = models.by_name(features.model or "1") - if model is None: - raise RuntimeError( - "Unsupported Trezor model" - f" (internal_model: {features.internal_model}, model: {features.model})" - ) - self.model = model - - if features.vendor not in self.model.vendors: - raise RuntimeError("Unsupported device") - - self.features = features - self.version = ( - self.features.major_version, - self.features.minor_version, - self.features.patch_version, - ) - self.check_firmware_version(warn_only=True) - if self.features.session_id is not None: - self.session_id = self.features.session_id - self.features.session_id = None - - @session - def refresh_features(self) -> messages.Features: - """Reload features from the device. - - Should be called after changing settings or performing operations that affect - device state. - """ - resp = self.call_raw(messages.GetFeatures()) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to GetFeatures") - self._refresh_features(resp) - return resp - - @session - def init_device( - self, - *, - session_id: Optional[bytes] = None, - new_session: bool = False, - derive_cardano: Optional[bool] = None, - ) -> Optional[bytes]: - """Initialize the device and return a session ID. - - You can optionally specify a session ID. If the session still exists on the - device, the same session ID will be returned and the session is resumed. - Otherwise a different session ID is returned. - - Specify `new_session=True` to open a fresh session. Since firmware version - 1.9.0/2.3.0, the previous session will remain cached on the device, and can be - resumed by calling `init_device` again with the appropriate session ID. - - If neither `new_session` nor `session_id` is specified, the current session ID - will be reused. If no session ID was cached, a new session ID will be allocated - and returned. - - # Version notes: - - Trezor One older than 1.9.0 does not have session management. Optional arguments - have no effect and the function returns None - - Trezor T older than 2.3.0 does not have session cache. Requesting a new session - will overwrite the old one. In addition, this function will always return None. - A valid session_id can be obtained from the `session_id` attribute, but only - after a passphrase-protected call is performed. You can use the following code: - - >>> client.init_device() - >>> client.ensure_unlocked() - >>> valid_session_id = client.session_id - """ - if new_session: - self.session_id = None - elif session_id is not None: - self.session_id = session_id - - resp = self.call_raw( - messages.Initialize( - session_id=self.session_id, - derive_cardano=derive_cardano, - ) - ) - if isinstance(resp, messages.Failure): - # can happen if `derive_cardano` does not match the current session - raise exceptions.TrezorFailure(resp) - if not isinstance(resp, messages.Features): - raise exceptions.TrezorException("Unexpected response to Initialize") - - if self.session_id is not None and resp.session_id == self.session_id: - LOG.info("Successfully resumed session") - elif session_id is not None: - LOG.info("Failed to resume session") - - # TT < 2.3.0 compatibility: - # _refresh_features will clear out the session_id field. We want this function - # to return its value, so that callers can rely on it being either a valid - # session_id, or None if we can't do that. - # Older TT FW does not report session_id in Features and self.session_id might - # be invalid because TT will not allocate a session_id until a passphrase - # exchange happens. - reported_session_id = resp.session_id - self._refresh_features(resp) - return reported_session_id - - def is_outdated(self) -> bool: - if self.features.bootloader_mode: - return False - return self.version < self.model.minimum_version - - def check_firmware_version(self, warn_only: bool = False) -> None: - if self.is_outdated(): - if warn_only: - warnings.warn("Firmware is out of date", stacklevel=2) - else: - raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - - @expect(messages.Success, field="message", ret_type=str) - def ping( - self, - msg: str, - button_protection: bool = False, - ) -> "MessageType": - # We would like ping to work on any valid TrezorClient instance, but - # due to the protection modes, we need to go through self.call, and that will - # raise an exception if the firmware is too old. - # So we short-circuit the simplest variant of ping with call_raw. - if not button_protection: - # XXX this should be: `with self:` - try: - self.open() - resp = self.call_raw(messages.Ping(message=msg)) - if isinstance(resp, messages.ButtonRequest): - # device is PIN-locked. - # respond and hope for the best - resp = self._callback_button(resp) - return resp - finally: - self.close() - - return self.call( - messages.Ping(message=msg, button_protection=button_protection) - ) - - def get_device_id(self) -> Optional[str]: - return self.features.device_id - - @session - def lock(self, *, _refresh_features: bool = True) -> None: - """Lock the device. - - If the device does not have a PIN configured, this will do nothing. - Otherwise, a lock screen will be shown and the device will prompt for PIN - before further actions. - - This call does _not_ invalidate passphrase cache. If passphrase is in use, - the device will not prompt for it after unlocking. - - To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate - passphrase cache, use `clear_session()`. - """ - # Private argument _refresh_features can be used internally to avoid - # refreshing in cases where we will refresh soon anyway. This is used - # in TrezorClient.clear_session() - self.call(messages.LockDevice()) - if _refresh_features: - self.refresh_features() - - @session - def ensure_unlocked(self) -> None: - """Ensure the device is unlocked and a passphrase is cached. - - If the device is locked, this will prompt for PIN. If passphrase is enabled - and no passphrase is cached for the current session, the device will also - prompt for passphrase. - - After calling this method, further actions on the device will not prompt for - PIN or passphrase until the device is locked or the session becomes invalid. - """ - from .btc import get_address - - get_address(self, "Testnet", PASSPHRASE_TEST_PATH) - self.refresh_features() - - def end_session(self) -> None: - """Close the current session and clear cached passphrase. - - The session will become invalid until `init_device()` is called again. - If passphrase is enabled, further actions will prompt for it again. - - This is a no-op in bootloader mode, as it does not support session management. - """ - # since: 2.3.4, 1.9.4 - try: - if not self.features.bootloader_mode: - self.call(messages.EndSession()) - except exceptions.TrezorFailure: - # A failure most likely means that the FW version does not support - # the EndSession call. We ignore the failure and clear the local session_id. - # The client-side end result is identical. - pass - self.session_id = None - - @session - def clear_session(self) -> None: - """Lock the device and present a fresh session. - - The current session will be invalidated and a new one will be started. If the - device has PIN enabled, it will become locked. - - Equivalent to calling `lock()`, `end_session()` and `init_device()`. - """ - self.lock(_refresh_features=False) - self.end_session() - self.init_device(new_session=True) + + return TrezorClient(transport, **kwargs) + + +# class TrezorClient(t.Generic[UI]): +# """Trezor client, a connection to a Trezor device. + +# This class allows you to manage connection state, send and receive protobuf +# messages, handle user interactions, and perform some generic tasks +# (send a cancel message, initialize or clear a session, ping the device). +# """ + +# model: models.TrezorModel +# transport: "Transport" +# session_id: t.Optional[bytes] +# ui: UI +# features: messages.Features + +# def __init__( +# self, +# transport: "Transport", +# ui: UI, +# session_id: t.Optional[bytes] = None, +# derive_cardano: t.Optional[bool] = None, +# model: t.Optional[models.TrezorModel] = None, +# _init_device: bool = True, +# ) -> None: +# """Create a TrezorClient instance. + +# You have to provide a `transport`, i.e., a raw connection to the device. You can +# use `trezorlib.transport.get_transport` to find one. + +# You have to provide a UI implementation for the three kinds of interaction: +# - button request (notify the user that their interaction is needed) +# - PIN request (on T1, ask the user to input numbers for a PIN matrix) +# - passphrase request (ask the user to enter a passphrase) See `trezorlib.ui` for +# details. + +# You can supply a `session_id` you might have saved in the previous session. If +# you do, the user might not need to enter their passphrase again. + +# You can provide Trezor model information. If not provided, it is detected from +# the model name reported at initialization time. + +# By default, the instance will open a connection to the Trezor device, send an +# `Initialize` message, set up the `features` field from the response, and connect +# to a session. By specifying `_init_device=False`, this step is skipped. Notably, +# this means that `client.features` is unset. Use `client.init_device()` or +# `client.refresh_features()` to fix that, otherwise A LOT OF THINGS will break. +# Only use this if you are _sure_ that you know what you are doing. This feature +# might be removed at any time. +# """ +# LOG.info(f"creating client instance for device: {transport.get_path()}") +# # Here, self.model could be set to None. Unless _init_device is False, it will +# # get correctly reconfigured as part of the init_device flow. +# self.model = model # type: ignre ["None" is incompatible with "TrezorModel"] +# if self.model: +# self.mapping = self.model.default_mapping +# else: +# self.mapping = mapping.DEFAULT_MAPPING +# self.transport = transport +# self.ui = ui +# self.session_counter = 0 +# self.session_id = session_id +# if _init_device: +# self.init_device(session_id=session_id, derive_cardano=derive_cardano) +# self.resume_session() + +# def open(self) -> None: +# if self.session_counter == 0: +# session_id = self.transport.resume_session(b"") +# if self.session_id != session_id: +# print("Failed to resume session, allocated a new session") +# self.session_id = session_id +# self.transport.deprecated_begin_session() +# self.session_counter += 1 + +# def resume_session(self) -> None: +# new_id = self.transport.resume_session(self.session_id or b"") +# if self.session_id != new_id: +# print("Failed to resume session, allocated a new session") +# self.session_id = new_id + +# def close(self) -> None: +# self.session_counter = max(self.session_counter - 1, 0) +# if self.session_counter == 0: +# # TODO call EndSession here? +# self.transport.deprecated_end_session() + +# def cancel(self) -> None: +# self._raw_write(messages.Cancel()) + +# def call_raw(self, msg: "MessageType") -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 + +# self._raw_write(msg) +# x = self._raw_read() +# return x + +# def _raw_write(self, msg: "MessageType") -> None: +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# LOG.debug( +# f"sending message: {msg.__class__.__name__}", +# extra={"protobuf": msg}, +# ) +# msg_type, msg_bytes = self.mapping.encode(msg) +# LOG.log( +# DUMP_BYTES, +# f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", +# ) +# self.transport.write(msg_type, msg_bytes) + +# def _raw_read(self) -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# msg_type, msg_bytes = self.transport.read() +# print("type/data", msg_type, msg_bytes) +# LOG.log( +# DUMP_BYTES, +# f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", +# ) +# msg = self.mapping.decode(msg_type, msg_bytes) +# LOG.debug( +# f"received message: {msg.__class__.__name__}", +# extra={"protobuf": msg}, +# ) +# return msg + +# def _callback_pin(self, msg: messages.PinMatrixRequest) -> "MessageType": +# try: +# pin = self.ui.get_pin(msg.type) +# except exceptions.Cancelled: +# self.call_raw(messages.Cancel()) +# raise + +# if any(d not in "123456789" for d in pin) or not ( +# 1 <= len(pin) <= MAX_PIN_LENGTH +# ): +# self.call_raw(messages.Cancel()) +# raise ValueError("Invalid PIN provided") + +# resp = self.call_raw(messages.PinMatrixAck(pin=pin)) +# if isinstance(resp, messages.Failure) and resp.code in ( +# messages.FailureType.PinInvalid, +# messages.FailureType.PinCancelled, +# messages.FailureType.PinExpected, +# ): +# raise exceptions.PinException(resp.code, resp.message) +# else: +# return resp + +# def _callback_passphrase(self, msg: messages.PassphraseRequest) -> "MessageType": +# available_on_device = Capability.PassphraseEntry in self.features.capabilities + +# def send_passphrase( +# passphrase: t.Optional[str] = None, on_device: t.Optional[bool] = None +# ) -> "MessageType": +# msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) +# resp = self.call_raw(msg) +# if isinstance(resp, messages.Deprecated_PassphraseStateRequest): +# self.session_id = resp.state +# resp = self.call_raw(messages.Deprecated_PassphraseStateAck()) +# return resp + +# # short-circuit old style entry +# if msg._on_device is True: +# return send_passphrase(None, None) + +# try: +# passphrase = self.ui.get_passphrase(available_on_device=available_on_device) +# except exceptions.Cancelled: +# self.call_raw(messages.Cancel()) +# raise + +# if passphrase is PASSPHRASE_ON_DEVICE: +# if not available_on_device: +# self.call_raw(messages.Cancel()) +# raise RuntimeError("Device is not capable of entering passphrase") +# else: +# return send_passphrase(on_device=True) + +# # else process host-entered passphrase +# if not isinstance(passphrase, str): +# raise RuntimeError("Passphrase must be a str") +# passphrase = Mnemonic.normalize_string(passphrase) +# if len(passphrase) > MAX_PASSPHRASE_LENGTH: +# self.call_raw(messages.Cancel()) +# raise ValueError("Passphrase too long") + +# return send_passphrase(passphrase, on_device=False) + +# def _callback_button(self, msg: messages.ButtonRequest) -> "MessageType": +# __tracebackhide__ = True # for pytest # pylint: disable=W0612 +# # do this raw - send ButtonAck first, notify UI later +# self._raw_write(messages.ButtonAck()) +# self.ui.button_request(msg) +# return self._raw_read() + +# @session +# def call(self, msg: "MessageType") -> "MessageType": +# self.check_firmware_version() +# resp = self.call_raw(msg) +# while True: +# if isinstance(resp, messages.PinMatrixRequest): +# resp = self._callback_pin(resp) +# elif isinstance(resp, messages.PassphraseRequest): +# resp = self._callback_passphrase(resp) +# elif isinstance(resp, messages.ButtonRequest): +# resp = self._callback_button(resp) +# elif isinstance(resp, messages.Failure): +# print("self.call-failure") + +# if resp.code == messages.FailureType.ActionCancelled: +# raise exceptions.Cancelled +# raise exceptions.TrezorFailure(resp) +# else: +# print("self.call-end") +# return resp + +# def _refresh_features(self, features: messages.Features) -> None: +# """Update internal fields based on passed-in Features message.""" + +# if not self.model: +# # Trezor Model One bootloader 1.8.0 or older does not send model name +# model = models.by_internal_name(features.internal_model) +# if model is None: +# model = models.by_name(features.model or "1") +# if model is None: +# raise RuntimeError( +# "Unsupported Trezor model" +# f" (internal_model: {features.internal_model}, model: {features.model})" +# ) +# self.model = model + +# if features.vendor not in self.model.vendors: +# raise RuntimeError("Unsupported device") + +# self.features = features +# self.version = ( +# self.features.major_version, +# self.features.minor_version, +# self.features.patch_version, +# ) +# self.check_firmware_version(warn_only=True) +# if self.features.session_id is not None: +# self.session_id = self.features.session_id +# self.features.session_id = None + +# @session +# def refresh_features(self) -> messages.Features: +# """Reload features from the device. + +# Should be called after changing settings or performing operations that affect +# device state. +# """ +# resp = self.call_raw(messages.GetFeatures()) +# if not isinstance(resp, messages.Features): +# raise exceptions.TrezorException("Unexpected response to GetFeatures") +# self._refresh_features(resp) +# return resp + +# def init_device( +# self, +# *, +# session_id: t.Optional[bytes] = None, +# new_session: bool = False, +# derive_cardano: t.Optional[bool] = None, +# ) -> t.Optional[bytes]: +# """Initialize the device and return a session ID. + +# You can optionally specify a session ID. If the session still exists on the +# device, the same session ID will be returned and the session is resumed. +# Otherwise a different session ID is returned. + +# Specify `new_session=True` to open a fresh session. Since firmware version +# 1.9.0/2.3.0, the previous session will remain cached on the device, and can be +# resumed by calling `init_device` again with the appropriate session ID. + +# If neither `new_session` nor `session_id` is specified, the current session ID +# will be reused. If no session ID was cached, a new session ID will be allocated +# and returned. + +# # Version notes: + +# Trezor One older than 1.9.0 does not have session management. Optional arguments +# have no effect and the function returns None + +# Trezor T older than 2.3.0 does not have session cache. Requesting a new session +# will overwrite the old one. In addition, this function will always return None. +# A valid session_id can be obtained from the `session_id` attribute, but only +# after a passphrase-protected call is performed. You can use the following code: + +# >>> client.init_device() +# >>> client.ensure_unlocked() +# >>> valid_session_id = client.session_id +# """ +# if new_session: +# self.session_id = None +# elif session_id is not None: +# self.session_id = session_id + +# print("before init conn") + +# resp = self.transport.initialize_connection( +# mapping=self.mapping, +# session_id=session_id, +# derive_cardano=derive_cardano, +# ) +# print("here") +# if isinstance(resp, messages.Failure): +# # can happen if `derive_cardano` does not match the current session +# raise exceptions.TrezorFailure(resp) +# if not isinstance(resp, messages.Features): +# raise exceptions.TrezorException("Unexpected response to Initialize") + +# if self.session_id is not None and resp.session_id == self.session_id: +# LOG.info("Successfully resumed session") +# elif session_id is not None: +# LOG.info("Failed to resume session") + +# # TT < 2.3.0 compatibility: +# # _refresh_features will clear out the session_id field. We want this function +# # to return its value, so that callers can rely on it being either a valid +# # session_id, or None if we can't do that. +# # Older TT FW does not report session_id in Features and self.session_id might +# # be invalid because TT will not allocate a session_id until a passphrase +# # exchange happens. +# reported_session_id = resp.session_id +# self._refresh_features(resp) +# print("there:", reported_session_id) +# return reported_session_id + +# def is_outdated(self) -> bool: +# if self.features.bootloader_mode: +# return False +# return self.version < self.model.minimum_version + +# def check_firmware_version(self, warn_only: bool = False) -> None: +# if self.is_outdated(): +# if warn_only: +# warnings.warn("Firmware is out of date", stacklevel=2) +# else: +# raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) + +# @expect(messages.Success, field="message", ret_type=str) +# def ping( +# self, +# msg: str, +# button_protection: bool = False, +# ) -> "MessageType": +# # We would like ping to work on any valid TrezorClient instance, but +# # due to the protection modes, we need to go through self.call, and that will +# # raise an exception if the firmware is too old. +# # So we short-circuit the simplest variant of ping with call_raw. +# if not button_protection: +# # XXX this should be: `with self:` +# try: +# self.open() +# resp = self.call_raw(messages.Ping(message=msg)) +# if isinstance(resp, messages.ButtonRequest): +# # device is PIN-locked. +# # respond and hope for the best +# resp = self._callback_button(resp) +# return resp +# finally: +# self.close() + +# return self.call( +# messages.Ping(message=msg, button_protection=button_protection) +# ) + +# def get_device_id(self) -> t.Optional[str]: +# return self.features.device_id + +# @session +# def lock(self, *, _refresh_features: bool = True) -> None: +# """Lock the device. + +# If the device does not have a PIN configured, this will do nothing. +# Otherwise, a lock screen will be shown and the device will prompt for PIN +# before further actions. + +# This call does _not_ invalidate passphrase cache. If passphrase is in use, +# the device will not prompt for it after unlocking. + +# To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate +# passphrase cache, use `clear_session()`. +# """ +# # Private argument _refresh_features can be used internally to avoid +# # refreshing in cases where we will refresh soon anyway. This is used +# # in TrezorClient.clear_session() +# self.call(messages.LockDevice()) +# if _refresh_features: +# self.refresh_features() + +# @session +# def ensure_unlocked(self) -> None: +# """Ensure the device is unlocked and a passphrase is cached. + +# If the device is locked, this will prompt for PIN. If passphrase is enabled +# and no passphrase is cached for the current session, the device will also +# prompt for passphrase. + +# After calling this method, further actions on the device will not prompt for +# PIN or passphrase until the device is locked or the session becomes invalid. +# """ +# from .btc import get_address + +# get_address(self, "Testnet", PASSPHRASE_TEST_PATH) +# self.refresh_features() + +# def end_session(self) -> None: +# """Close the current session and clear cached passphrase. + +# The session will become invalid until `init_device()` is called again. +# If passphrase is enabled, further actions will prompt for it again. + +# This is a no-op in bootloader mode, as it does not support session management. +# """ +# # since: 2.3.4, 1.9.4 +# print("end session") +# try: +# if not self.features.bootloader_mode: +# self.transport.end_session(self.session_id or b"") +# # self.call(messages.EndSession()) +# except exceptions.TrezorFailure: +# # A failure most likely means that the FW version does not support +# # the EndSession call. We ignore the failure and clear the local session_id. +# # The client-side end result is identical. +# pass +# except ValueError as e: +# print(e) +# print(e.args) +# self.session_id = None + +# @session +# def clear_session(self) -> None: +# """Lock the device and present a fresh session. + +# The current session will be invalidated and a new one will be started. If the +# device has PIN enabled, it will become locked. + +# Equivalent to calling `lock()`, `end_session()` and `init_device()`. +# """ +# self.lock(_refresh_features=False) +# self.end_session() +# self.init_device(new_session=True) diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index 11fac1bc22..707401cf1b 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -21,47 +21,44 @@ import logging import re import textwrap import time +import typing as t from contextlib import contextmanager from copy import deepcopy from datetime import datetime from enum import Enum, IntEnum, auto from itertools import zip_longest from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Generator, - Iterable, - Iterator, - Sequence, - Tuple, - Union, -) from mnemonic import Mnemonic -from . import mapping, messages, models, protobuf -from .client import TrezorClient -from .exceptions import TrezorFailure +from . import btc, mapping, messages, models, protobuf +from .client import ( + MAX_PASSPHRASE_LENGTH, + MAX_PIN_LENGTH, + PASSPHRASE_ON_DEVICE, + TrezorClient, +) +from .exceptions import Cancelled, PinException, TrezorFailure from .log import DUMP_BYTES -from .messages import DebugWaitType -from .tools import expect +from .messages import Capability, DebugWaitType +from .tools import expect, parse_path +from .transport.session import Session, SessionV1 +from .transport.thp.protocol_v1 import ProtocolV1 -if TYPE_CHECKING: +if t.TYPE_CHECKING: from typing_extensions import Protocol from .messages import PinMatrixRequestType from .transport import Transport - ExpectedMessage = Union[ - protobuf.MessageType, type[protobuf.MessageType], "MessageFilter" + ExpectedMessage = t.Union[ + protobuf.MessageType, t.Type[protobuf.MessageType], "MessageFilter" ] - AnyDict = Dict[str, Any] + AnyDict = t.Dict[str, t.Any] class InputFunc(Protocol): + def __call__( self, hold_ms: int | None = None, @@ -70,6 +67,7 @@ if TYPE_CHECKING: EXPECTED_RESPONSES_CONTEXT_LINES = 3 +PASSPHRASE_TEST_PATH = parse_path("44h/1h/0h/0/0") LOG = logging.getLogger(__name__) @@ -104,11 +102,13 @@ class UnstructuredJSONReader: except json.JSONDecodeError: self.dict = {} - def top_level_value(self, key: str) -> Any: + def top_level_value(self, key: str) -> t.Any: return self.dict.get(key) - def find_objects_with_key_and_value(self, key: str, value: Any) -> list[AnyDict]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_objects_with_key_and_value( + self, key: str, value: t.Any + ) -> list["AnyDict"]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if data.get(key) == value: yield data @@ -121,7 +121,7 @@ class UnstructuredJSONReader: return list(recursively_find(self.dict)) def find_unique_object_with_key_and_value( - self, key: str, value: Any + self, key: str, value: t.Any ) -> AnyDict | None: objects = self.find_objects_with_key_and_value(key, value) if not objects: @@ -129,8 +129,10 @@ class UnstructuredJSONReader: assert len(objects) == 1 return objects[0] - def find_values_by_key(self, key: str, only_type: type | None = None) -> list[Any]: - def recursively_find(data: Any) -> Iterator[Any]: + def find_values_by_key( + self, key: str, only_type: type | None = None + ) -> list[t.Any]: + def recursively_find(data: t.Any) -> t.Iterator[t.Any]: if isinstance(data, dict): if key in data: yield data[key] @@ -148,8 +150,8 @@ class UnstructuredJSONReader: return values def find_unique_value_by_key( - self, key: str, default: Any, only_type: type | None = None - ) -> Any: + self, key: str, default: t.Any, only_type: type | None = None + ) -> t.Any: values = self.find_values_by_key(key, only_type=only_type) if not values: return default @@ -160,7 +162,7 @@ class UnstructuredJSONReader: class LayoutContent(UnstructuredJSONReader): """Contains helper functions to extract specific parts of the layout.""" - def __init__(self, json_tokens: Sequence[str]) -> None: + def __init__(self, json_tokens: t.Sequence[str]) -> None: json_str = "".join(json_tokens) super().__init__(json_str) @@ -422,11 +424,13 @@ def _make_input_func( class DebugLink: + def __init__(self, transport: "Transport", auto_interact: bool = True) -> None: self.transport = transport self.allow_interactions = auto_interact self.mapping = mapping.DEFAULT_MAPPING + self.protocol = ProtocolV1(self.transport, self.mapping) # To be set by TrezorClientDebugLink (is not known during creation time) self.model: models.TrezorModel | None = None self.version: tuple[int, int, int] = (0, 0, 0) @@ -479,10 +483,16 @@ class DebugLink: self.screen_text_file = file_path def open(self) -> None: - self.transport.begin_session() + self.transport.open() + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_begin_session() def close(self) -> None: - self.transport.end_session() + pass + # raise NotImplementedError + # TODO is this needed? + # self.transport.deprecated_end_session() def _write(self, msg: protobuf.MessageType) -> None: if self.waiting_for_layout_change: @@ -499,15 +509,10 @@ class DebugLink: DUMP_BYTES, f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", ) - self.transport.write(msg_type, msg_bytes) + self.protocol.write(msg) def _read(self) -> protobuf.MessageType: - ret_type, ret_bytes = self.transport.read() - LOG.log( - DUMP_BYTES, - f"received type {ret_type} ({len(ret_bytes)} bytes): {ret_bytes.hex()}", - ) - msg = self.mapping.decode(ret_type, ret_bytes) + msg = self.protocol.read() # Collapse tokens to make log use less lines. msg_for_log = msg @@ -521,18 +526,27 @@ class DebugLink: ) return msg - def _call(self, msg: protobuf.MessageType) -> Any: + def _call(self, msg: protobuf.MessageType) -> t.Any: self._write(msg) return self._read() - def state(self, wait_type: DebugWaitType | None = None) -> messages.DebugLinkState: + def state( + self, + wait_type: DebugWaitType | None = None, + thp_channel_id: bytes | None = None, + ) -> messages.DebugLinkState: if wait_type is None: wait_type = ( DebugWaitType.CURRENT_LAYOUT if self.has_global_layout else DebugWaitType.IMMEDIATE ) - result = self._call(messages.DebugLinkGetState(wait_layout=wait_type)) + result = self._call( + messages.DebugLinkGetState( + wait_layout=wait_type, + thp_channel_id=thp_channel_id, + ) + ) while not isinstance(result, (messages.Failure, messages.DebugLinkState)): result = self._read() if isinstance(result, messages.Failure): @@ -544,7 +558,7 @@ class DebugLink: def wait_layout(self, wait_for_external_change: bool = False) -> LayoutContent: # Next layout change will be caused by external event - # (e.g. device being auto-locked or as a result of device_handler.run(xxx)) + # (e.g. device being auto-locked or as a result of device_handler.run_with_session(xxx)) # and not by our debug actions/decisions. # Resetting the debug state so we wait for the next layout change # (and do not return the current state). @@ -560,7 +574,7 @@ class DebugLink: return LayoutContent(obj.tokens) @contextmanager - def wait_for_layout_change(self) -> Iterator[LayoutContent]: + def wait_for_layout_change(self) -> t.Iterator[LayoutContent]: # set up a dummy layout content object to be yielded layout_content = LayoutContent( ["DUMMY CONTENT, WAIT UNTIL THE END OF THE BLOCK :("] @@ -622,7 +636,7 @@ class DebugLink: return "".join([str(matrix.index(p) + 1) for p in pin]) - def read_recovery_word(self) -> Tuple[str | None, int | None]: + def read_recovery_word(self) -> t.Tuple[str | None, int | None]: state = self.state() return (state.recovery_fake_word, state.recovery_word_pos) @@ -700,7 +714,7 @@ class DebugLink: def click( self, - click: Tuple[int, int], + click: t.Tuple[int, int], hold_ms: int | None = None, wait: bool | None = None, ) -> LayoutContent: @@ -862,10 +876,10 @@ class DebugUI: self.clear() def clear(self) -> None: - self.pins: Iterator[str] | None = None + self.pins: t.Iterator[str] | None = None self.passphrase = "" - self.input_flow: Union[ - Generator[None, messages.ButtonRequest, None], object, None + self.input_flow: t.Union[ + t.Generator[None, messages.ButtonRequest, None], object, None ] = None def _default_input_flow(self, br: messages.ButtonRequest) -> None: @@ -896,7 +910,7 @@ class DebugUI: raise AssertionError("input flow ended prematurely") else: try: - assert isinstance(self.input_flow, Generator) + assert isinstance(self.input_flow, t.Generator) self.input_flow.send(br) except StopIteration: self.input_flow = self.INPUT_FLOW_DONE @@ -918,12 +932,15 @@ class DebugUI: class MessageFilter: - def __init__(self, message_type: type[protobuf.MessageType], **fields: Any) -> None: + + def __init__( + self, message_type: t.Type[protobuf.MessageType], **fields: t.Any + ) -> None: self.message_type = message_type - self.fields: Dict[str, Any] = {} + self.fields: t.Dict[str, t.Any] = {} self.update_fields(**fields) - def update_fields(self, **fields: Any) -> "MessageFilter": + def update_fields(self, **fields: t.Any) -> "MessageFilter": for name, value in fields.items(): try: self.fields[name] = self.from_message_or_type(value) @@ -971,7 +988,7 @@ class MessageFilter: return True def to_string(self, maxwidth: int = 80) -> str: - fields: list[Tuple[str, str]] = [] + fields: list[t.Tuple[str, str]] = [] for field in self.message_type.FIELDS.values(): if field.name not in self.fields: continue @@ -1001,7 +1018,8 @@ class MessageFilter: class MessageFilterGenerator: - def __getattr__(self, key: str) -> Callable[..., "MessageFilter"]: + + def __getattr__(self, key: str) -> t.Callable[..., "MessageFilter"]: message_type = getattr(messages, key) return MessageFilter(message_type).update_fields @@ -1009,6 +1027,245 @@ class MessageFilterGenerator: message_filters = MessageFilterGenerator() +class SessionDebugWrapper(Session): + def __init__(self, session: Session) -> None: + self._session = session + self.reset_debug_features() + if isinstance(session, SessionDebugWrapper): + raise Exception("Cannot wrap already wrapped session!") + + @property + def protocol_version(self) -> int: + return self.client.protocol_version + + @property + def client(self) -> TrezorClientDebugLink: + assert isinstance(self._session.client, TrezorClientDebugLink) + return self._session.client + + @property + def id(self) -> bytes: + return self._session.id + + def _write(self, msg: t.Any) -> None: + print("writing message:", msg.__class__.__name__) + self._session._write(self._filter_message(msg)) + + def _read(self) -> t.Any: + resp = self._filter_message(self._session._read()) + print("reading message:", resp.__class__.__name__) + if self.actual_responses is not None: + self.actual_responses.append(resp) + return resp + + def set_expected_responses( + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], + ) -> None: + """Set a sequence of expected responses to session calls. + + Within a given with-block, the list of received responses from device must + match the list of expected responses, otherwise an ``AssertionError`` is raised. + + If an expected response is given a field value other than ``None``, that field value + must exactly match the received field value. If a given field is ``None`` + (or unspecified) in the expected response, the received field value is not + checked. + + Each expected response can also be a tuple ``(bool, message)``. In that case, the + expected response is only evaluated if the first field is ``True``. + This is useful for differentiating sequences between Trezor models: + + >>> trezor_one = session.features.model == "1" + >>> session.set_expected_responses([ + >>> messages.ButtonRequest(code=ConfirmOutput), + >>> (trezor_one, messages.ButtonRequest(code=ConfirmOutput)), + >>> messages.Success(), + >>> ]) + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + # make sure all items are (bool, message) tuples + expected_with_validity = ( + e if isinstance(e, tuple) else (True, e) for e in expected + ) + + # only apply those items that are (True, message) + self.expected_responses = [ + MessageFilter.from_message_or_type(expected) + for valid, expected in expected_with_validity + if valid + ] + self.actual_responses = [] + + def lock(self, *, _refresh_features: bool = True) -> None: + """Lock the device. + + If the device does not have a PIN configured, this will do nothing. + Otherwise, a lock screen will be shown and the device will prompt for PIN + before further actions. + + This call does _not_ invalidate passphrase cache. If passphrase is in use, + the device will not prompt for it after unlocking. + + To invalidate passphrase cache, use `end_session()`. To lock _and_ invalidate + passphrase cache, use `clear_session()`. + """ + # TODO update the documentation above + # Private argument _refresh_features can be used internally to avoid + # refreshing in cases where we will refresh soon anyway. This is used + # in TrezorClient.clear_session() + self.call(messages.LockDevice()) + if _refresh_features: + self.refresh_features() + + def cancel(self) -> None: + self._write(messages.Cancel()) + + def ensure_unlocked(self) -> None: + btc.get_address(self, "Testnet", PASSPHRASE_TEST_PATH) + self.refresh_features() + + def set_filter( + self, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ) -> None: + """Configure a filter function for a specified message type. + + The `callback` must be a function that accepts a protobuf message, and returns + a (possibly modified) protobuf message of the same type. Whenever a message + is sent or received that matches `message_type`, `callback` is invoked on the + message and its result is substituted for the original. + + Useful for test scenarios with an active malicious actor on the wire. + """ + if not self.in_with_statement: + raise RuntimeError("Must be called inside 'with' statement") + + self.filters[message_type] = callback + + def _filter_message(self, msg: protobuf.MessageType) -> protobuf.MessageType: + message_type = msg.__class__ + callback = self.filters.get(message_type) + if callable(callback): + return callback(deepcopy(msg)) + else: + return msg + + def reset_debug_features(self) -> None: + """Prepare the debugging session for a new testcase. + + Clears all debugging state that might have been modified by a testcase. + """ + self.in_with_statement = False + self.expected_responses: list[MessageFilter] | None = None + self.actual_responses: list[protobuf.MessageType] | None = None + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, + ] = {} + self.button_callback = self.client.button_callback + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self._session.passphrase_callback + self.passphrase = self._session.passphrase + + def __enter__(self) -> "SessionDebugWrapper": + # For usage in with/expected_responses + if self.in_with_statement: + raise RuntimeError("Do not nest!") + self.in_with_statement = True + return self + + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + # copy expected/actual responses before clearing them + expected_responses = self.expected_responses + actual_responses = self.actual_responses + + # grab a copy of the inputflow generator to raise an exception through it + if isinstance(self.client.ui, DebugUI): + input_flow = self.client.ui.input_flow + else: + input_flow = None + + self.reset_debug_features() + + if exc_type is None: + # If no other exception was raised, evaluate missed responses + # (raises AssertionError on mismatch) + self._verify_responses(expected_responses, actual_responses) + if isinstance(input_flow, t.Generator): + # Ensure that the input flow is exhausted + try: + input_flow.throw( + AssertionError("input flow continues past end of test") + ) + except StopIteration: + pass + + elif isinstance(input_flow, t.Generator): + # Propagate the exception through the input flow, so that we see in + # traceback where it is stuck. + input_flow.throw(exc_type, value, traceback) + + @classmethod + def _verify_responses( + cls, + expected: list[MessageFilter] | None, + actual: list[protobuf.MessageType] | None, + ) -> None: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + + if expected is None and actual is None: + return + + assert expected is not None + assert actual is not None + + for i, (exp, act) in enumerate(zip_longest(expected, actual)): + if exp is None: + output = cls._expectation_lines(expected, i) + output.append("No more messages were expected, but we got:") + for resp in actual[i:]: + output.append( + textwrap.indent(protobuf.format_message(resp), " ") + ) + raise AssertionError("\n".join(output)) + + if act is None: + output = cls._expectation_lines(expected, i) + output.append("This and the following message was not received.") + raise AssertionError("\n".join(output)) + + if not exp.match(act): + output = cls._expectation_lines(expected, i) + output.append("Actually received:") + output.append(textwrap.indent(protobuf.format_message(act), " ")) + raise AssertionError("\n".join(output)) + + @staticmethod + def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: + start_at = max(current - EXPECTED_RESPONSES_CONTEXT_LINES, 0) + stop_at = min(current + EXPECTED_RESPONSES_CONTEXT_LINES + 1, len(expected)) + output: list[str] = [] + output.append("Expected responses:") + if start_at > 0: + output.append(f" (...{start_at} previous responses omitted)") + for i in range(start_at, stop_at): + exp = expected[i] + prefix = " " if i != current else ">>> " + output.append(textwrap.indent(exp.to_string(), prefix)) + if stop_at < len(expected): + omitted = len(expected) - stop_at + output.append(f" (...{omitted} following responses omitted)") + + output.append("") + return output + + class TrezorClientDebugLink(TrezorClient): # This class implements automatic responses # and other functionality for unit tests @@ -1034,54 +1291,165 @@ class TrezorClientDebugLink(TrezorClient): raise # set transport explicitly so that sync_responses can work + super().__init__(transport) + self.transport = transport + self.ui: DebugUI = DebugUI(self.debug) - self.reset_debug_features() + self.reset_debug_features(new_management_session=True) self.sync_responses() - super().__init__(transport, ui=self.ui) - # So that we can choose right screenshotting logic (T1 vs TT) # and know the supported debug capabilities self.debug.model = self.model self.debug.version = self.version + self.passphrase: str | None = None @property def layout_type(self) -> LayoutType: return self.debug.layout_type - def reset_debug_features(self) -> None: - """Prepare the debugging client for a new testcase. + def get_new_client(self) -> TrezorClientDebugLink: + return TrezorClientDebugLink(self.transport, self.debug.allow_interactions) + + def reset_debug_features(self, new_management_session: bool = False) -> None: + """ + Prepare the debugging client for a new testcase. Clears all debugging state that might have been modified by a testcase. """ self.ui: DebugUI = DebugUI(self.debug) + # self.pin_callback = self.ui.debug_callback_button self.in_with_statement = False self.expected_responses: list[MessageFilter] | None = None self.actual_responses: list[protobuf.MessageType] | None = None - self.filters: dict[ - type[protobuf.MessageType], - Callable[[protobuf.MessageType], protobuf.MessageType] | None, + self.filters: t.Dict[ + t.Type[protobuf.MessageType], + t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ] = {} + if new_management_session: + self._management_session = self.get_management_session(new_session=True) + + @property + def button_callback(self): + + def _callback_button(session: Session, msg: messages.ButtonRequest) -> t.Any: + __tracebackhide__ = True # for pytest # pylint: disable=W0612 + # do this raw - send ButtonAck first, notify UI later + session._write(messages.ButtonAck()) + self.ui.button_request(msg) + return session._read() + + return _callback_button + + @property + def pin_callback(self): + + def _callback_pin(session: Session, msg: messages.PinMatrixRequest) -> t.Any: + try: + pin = self.ui.get_pin(msg.type) + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if any(d not in "123456789" for d in pin) or not ( + 1 <= len(pin) <= MAX_PIN_LENGTH + ): + session.call_raw(messages.Cancel()) + raise ValueError("Invalid PIN provided") + resp = session.call_raw(messages.PinMatrixAck(pin=pin)) + if isinstance(resp, messages.Failure) and resp.code in ( + messages.FailureType.PinInvalid, + messages.FailureType.PinCancelled, + messages.FailureType.PinExpected, + ): + raise PinException(resp.code, resp.message) + else: + return resp + + return _callback_pin + + @property + def passphrase_callback(self): + def _callback_passphrase( + session: Session, msg: messages.PassphraseRequest + ) -> t.Any: + available_on_device = ( + Capability.PassphraseEntry in session.features.capabilities + ) + + def send_passphrase( + passphrase: str | None = None, on_device: bool | None = None + ) -> t.Any: + msg = messages.PassphraseAck(passphrase=passphrase, on_device=on_device) + resp = session.call_raw(msg) + if isinstance(resp, messages.Deprecated_PassphraseStateRequest): + # session.session_id = resp.state + resp = session.call_raw(messages.Deprecated_PassphraseStateAck()) + return resp + + # short-circuit old style entry + if msg._on_device is True: + return send_passphrase(None, None) + + try: + if session.passphrase is None and isinstance(session, SessionV1): + passphrase = self.ui.get_passphrase( + available_on_device=available_on_device + ) + else: + passphrase = session.passphrase + except Cancelled: + session.call_raw(messages.Cancel()) + raise + + if passphrase is PASSPHRASE_ON_DEVICE: + if not available_on_device: + session.call_raw(messages.Cancel()) + raise RuntimeError("Device is not capable of entering passphrase") + else: + return send_passphrase(on_device=True) + + # else process host-entered passphrase + if not isinstance(passphrase, str): + raise RuntimeError("Passphrase must be a str") + passphrase = Mnemonic.normalize_string(passphrase) + if len(passphrase) > MAX_PASSPHRASE_LENGTH: + session.call_raw(messages.Cancel()) + raise ValueError("Passphrase too long") + + return send_passphrase(passphrase, on_device=False) + + return _callback_passphrase def ensure_open(self) -> None: """Only open session if there isn't already an open one.""" - if self.session_counter == 0: - self.open() + # if self.session_counter == 0: + # self.open() + # TODO check if is this needed def open(self) -> None: - super().open() - if self.session_counter == 1: - self.debug.open() + pass + # TODO is this needed? + # self.debug.open() def close(self) -> None: - if self.session_counter == 1: - self.debug.close() - super().close() + pass + # TODO is this needed? + # self.debug.close() + + def get_session( + self, + passphrase: str | object | None = "", + derive_cardano: bool = False, + ) -> Session: + if isinstance(passphrase, str): + passphrase = Mnemonic.normalize_string(passphrase) + return super().get_session(passphrase, derive_cardano) def set_filter( self, - message_type: type[protobuf.MessageType], - callback: Callable[[protobuf.MessageType], protobuf.MessageType] | None, + message_type: t.Type[protobuf.MessageType], + callback: t.Callable[[protobuf.MessageType], protobuf.MessageType] | None, ) -> None: """Configure a filter function for a specified message type. @@ -1106,7 +1474,8 @@ class TrezorClientDebugLink(TrezorClient): return msg def set_input_flow( - self, input_flow: Generator[None, messages.ButtonRequest | None, None] + self, + input_flow: t.Generator[None, messages.ButtonRequest | None, None], ) -> None: """Configure a sequence of input events for the current with-block. @@ -1140,6 +1509,7 @@ class TrezorClientDebugLink(TrezorClient): if not hasattr(input_flow, "send"): raise RuntimeError("input_flow should be a generator function") self.ui.input_flow = input_flow + assert input_flow is not None input_flow.send(None) # start the generator def watch_layout(self, watch: bool = True) -> None: @@ -1162,7 +1532,7 @@ class TrezorClientDebugLink(TrezorClient): self.in_with_statement = True return self - def __exit__(self, exc_type: Any, value: Any, traceback: Any) -> None: + def __exit__(self, exc_type: t.Any, value: t.Any, traceback: t.Any) -> None: __tracebackhide__ = True # for pytest # pylint: disable=W0612 # copy expected/actual responses before clearing them @@ -1175,20 +1545,21 @@ class TrezorClientDebugLink(TrezorClient): else: input_flow = None - self.reset_debug_features() + self.reset_debug_features(new_management_session=False) if exc_type is None: # If no other exception was raised, evaluate missed responses # (raises AssertionError on mismatch) self._verify_responses(expected_responses, actual_responses) - elif isinstance(input_flow, Generator): + elif isinstance(input_flow, t.Generator): # Propagate the exception through the input flow, so that we see in # traceback where it is stuck. input_flow.throw(exc_type, value, traceback) def set_expected_responses( - self, expected: list[Union["ExpectedMessage", Tuple[bool, "ExpectedMessage"]]] + self, + expected: list["ExpectedMessage" | t.Tuple[bool, "ExpectedMessage"]], ) -> None: """Set a sequence of expected responses to client calls. @@ -1227,7 +1598,7 @@ class TrezorClientDebugLink(TrezorClient): ] self.actual_responses = [] - def use_pin_sequence(self, pins: Iterable[str]) -> None: + def use_pin_sequence(self, pins: t.Iterable[str]) -> None: """Respond to PIN prompts from device with the provided PINs. The sequence must be at least as long as the expected number of PIN prompts. """ @@ -1235,6 +1606,7 @@ class TrezorClientDebugLink(TrezorClient): def use_passphrase(self, passphrase: str) -> None: """Respond to passphrase prompts from device with the provided passphrase.""" + self.passphrase = passphrase self.ui.passphrase = Mnemonic.normalize_string(passphrase) def use_mnemonic(self, mnemonic: str) -> None: @@ -1244,15 +1616,14 @@ class TrezorClientDebugLink(TrezorClient): def _raw_read(self) -> protobuf.MessageType: __tracebackhide__ = True # for pytest # pylint: disable=W0612 - - resp = super()._raw_read() + resp = self.get_management_session()._read() resp = self._filter_message(resp) if self.actual_responses is not None: self.actual_responses.append(resp) return resp def _raw_write(self, msg: protobuf.MessageType) -> None: - return super()._raw_write(self._filter_message(msg)) + return self.get_management_session()._write(self._filter_message(msg)) @staticmethod def _expectation_lines(expected: list[MessageFilter], current: int) -> list[str]: @@ -1322,23 +1693,25 @@ class TrezorClientDebugLink(TrezorClient): # Start by canceling whatever is on screen. This will work to cancel T1 PIN # prompt, which is in TINY mode and does not respond to `Ping`. - cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) - self.transport.begin_session() + # TODO REMOVE: cancel_msg = mapping.DEFAULT_MAPPING.encode(messages.Cancel()) + self.transport.open() try: - self.transport.write(*cancel_msg) - + # self.protocol.write(messages.Cancel()) message = "SYNC" + secrets.token_hex(8) - ping_msg = mapping.DEFAULT_MAPPING.encode(messages.Ping(message=message)) - self.transport.write(*ping_msg) + self.get_management_session()._write(messages.Ping(message=message)) resp = None while resp != messages.Success(message=message): - msg_id, msg_bytes = self.transport.read() try: - resp = mapping.DEFAULT_MAPPING.decode(msg_id, msg_bytes) + resp = self.get_management_session()._read() + + raise Exception + except Exception: pass + finally: - self.transport.end_session() + pass # TODO fix + # self.transport.end_session(self.session_id or b"") def mnemonic_callback(self, _) -> str: word, pos = self.debug.read_recovery_word() @@ -1352,8 +1725,8 @@ class TrezorClientDebugLink(TrezorClient): @expect(messages.Success, field="message", ret_type=str) def load_device( - client: "TrezorClient", - mnemonic: Union[str, Iterable[str]], + session: "Session", + mnemonic: str | t.Iterable[str], pin: str | None, passphrase_protection: bool, label: str | None, @@ -1366,12 +1739,12 @@ def load_device( mnemonics = [Mnemonic.normalize_string(m) for m in mnemonic] - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call device.wipe() and try again." ) - resp = client.call( + resp = session.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1382,7 +1755,7 @@ def load_device( no_backup=no_backup, ) ) - client.init_device() + session.refresh_features() return resp @@ -1391,11 +1764,11 @@ load_device_by_mnemonic = load_device @expect(messages.Success, field="message", ret_type=str) -def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: - if client.features.bootloader_mode is not True: +def prodtest_t1(session: "Session") -> protobuf.MessageType: + if session.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - return client.call( + return session.call( messages.ProdTestT1( payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" ) @@ -1404,8 +1777,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: def record_screen( debug_client: "TrezorClientDebugLink", - directory: Union[str, None], - report_func: Union[Callable[[str], None], None] = None, + directory: str | None, + report_func: t.Callable[[str], None] | None = None, ) -> None: """Record screen changes into a specified directory. @@ -1451,5 +1824,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: @expect(messages.Success, field="message", ret_type=str) -def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: - return client.call(messages.DebugLinkOptigaSetSecMax()) +def optiga_set_sec_max(session: "Session") -> protobuf.MessageType: + return session.call(messages.DebugLinkOptigaSetSecMax()) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index ebd7ca85f5..2542f00dde 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -23,20 +23,19 @@ from typing import TYPE_CHECKING, Callable, Iterable, Optional from . import messages from .exceptions import Cancelled, TrezorException -from .tools import Address, expect, session +from .tools import Address, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session RECOVERY_BACK = "\x08" # backspace character, sent literally @expect(messages.Success, field="message", ret_type=str) -@session def apply_settings( - client: "TrezorClient", + session: "Session", label: Optional[str] = None, language: Optional[str] = None, use_passphrase: Optional[bool] = None, @@ -67,13 +66,13 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings) - client.refresh_features() + out = session.call(settings) + session.refresh_features() return out def _send_language_data( - client: "TrezorClient", + session: "Session", request: "messages.TranslationDataRequest", language_data: bytes, ) -> "MessageType": @@ -83,76 +82,70 @@ def _send_language_data( data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] - response = client.call(messages.TranslationDataAck(data_chunk=chunk)) + response = session.call(messages.TranslationDataAck(data_chunk=chunk)) return response @expect(messages.Success, field="message", ret_type=str) -@session def change_language( - client: "TrezorClient", + session: "Session", language_data: bytes, show_display: bool | None = None, ) -> "MessageType": data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) - response = client.call(msg) + response = session.call(msg) if data_length > 0: assert isinstance(response, messages.TranslationDataRequest) - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) assert isinstance(response, messages.Success) - client.refresh_features() # changing the language in features + session.refresh_features() # changing the language in features return response @expect(messages.Success, field="message", ret_type=str) -@session -def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": - out = client.call(messages.ApplyFlags(flags=flags)) - client.refresh_features() +def apply_flags(session: "Session", flags: int) -> "MessageType": + out = session.call(messages.ApplyFlags(flags=flags)) + session.refresh_features() return out @expect(messages.Success, field="message", ret_type=str) -@session -def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangePin(remove=remove)) - client.refresh_features() +def change_pin(session: "Session", remove: bool = False) -> "MessageType": + ret = session.call(messages.ChangePin(remove=remove)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangeWipeCode(remove=remove)) - client.refresh_features() +def change_wipe_code(session: "Session", remove: bool = False) -> "MessageType": + ret = session.call(messages.ChangeWipeCode(remove=remove)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session def sd_protect( - client: "TrezorClient", operation: messages.SdProtectOperationType + session: "Session", operation: messages.SdProtectOperationType ) -> "MessageType": - ret = client.call(messages.SdProtect(operation=operation)) - client.refresh_features() + ret = session.call(messages.SdProtect(operation=operation)) + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -@session -def wipe(client: "TrezorClient") -> "MessageType": - ret = client.call(messages.WipeDevice()) - if not client.features.bootloader_mode: - client.init_device() +def wipe(session: "Session") -> "MessageType": + + ret = session.call(messages.WipeDevice()) + # if not session.features.bootloader_mode: + # session.refresh_features() return ret -@session def recover( - client: "TrezorClient", + session: "Session", word_count: int = 24, passphrase_protection: bool = False, pin_protection: bool = True, @@ -188,13 +181,13 @@ def recover( if type is None: type = messages.RecoveryType.NormalRecovery - if client.features.model == "1" and input_callback is None: + if session.features.model == "1" and input_callback is None: raise RuntimeError("Input callback required for Trezor One") if word_count not in (12, 18, 24): raise ValueError("Invalid word count. Use 12/18/24") - if client.features.initialized and type == messages.RecoveryType.NormalRecovery: + if session.features.initialized and type == messages.RecoveryType.NormalRecovery: raise RuntimeError( "Device already initialized. Call device.wipe() and try again." ) @@ -216,24 +209,23 @@ def recover( msg.label = label msg.u2f_counter = u2f_counter - res = client.call(msg) + res = session.call(msg) while isinstance(res, messages.WordRequest): try: assert input_callback is not None inp = input_callback(res.type) - res = client.call(messages.WordAck(word=inp)) + res = session.call(messages.WordAck(word=inp)) except Cancelled: - res = client.call(messages.Cancel()) + res = session.call(messages.Cancel()) - client.init_device() + session.refresh_features() return res @expect(messages.Success, field="message", ret_type=str) -@session def reset( - client: "TrezorClient", + session: "Session", display_random: bool = False, strength: Optional[int] = None, passphrase_protection: bool = False, @@ -257,13 +249,13 @@ def reset( DeprecationWarning, ) - if client.features.initialized: + if session.features.initialized: raise RuntimeError( "Device is initialized already. Call wipe_device() and try again." ) if strength is None: - if client.features.model == "1": + if session.features.model == "1": strength = 256 else: strength = 128 @@ -280,25 +272,24 @@ def reset( backup_type=backup_type, ) - resp = client.call(msg) + resp = session.call(msg) if not isinstance(resp, messages.EntropyRequest): raise RuntimeError("Invalid response, expected EntropyRequest") external_entropy = os.urandom(32) # LOG.debug("Computer generated entropy: " + external_entropy.hex()) - ret = client.call(messages.EntropyAck(entropy=external_entropy)) - client.init_device() + ret = session.call(messages.EntropyAck(entropy=external_entropy)) + session.refresh_features() # TODO is necessary? return ret @expect(messages.Success, field="message", ret_type=str) -@session def backup( - client: "TrezorClient", + session: "Session", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), ) -> "MessageType": - ret = client.call( + ret = session.call( messages.BackupDevice( group_threshold=group_threshold, groups=[ @@ -307,37 +298,36 @@ def backup( ], ) ) - client.refresh_features() + session.refresh_features() return ret @expect(messages.Success, field="message", ret_type=str) -def cancel_authorization(client: "TrezorClient") -> "MessageType": - return client.call(messages.CancelAuthorization()) +def cancel_authorization(session: "Session") -> "MessageType": + return session.call(messages.CancelAuthorization()) @expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) -def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": - resp = client.call(messages.UnlockPath(address_n=n)) +def unlock_path(session: "Session", n: "Address") -> "MessageType": + resp = session.call(messages.UnlockPath(address_n=n)) # Cancel the UnlockPath workflow now that we have the authentication code. try: - client.call(messages.Cancel()) + session.call(messages.Cancel()) except Cancelled: return resp else: raise TrezorException("Unexpected response in UnlockPath flow") -@session @expect(messages.Success, field="message", ret_type=str) def reboot_to_bootloader( - client: "TrezorClient", + session: "Session", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", ) -> "MessageType": - response = client.call( + response = session.call( messages.RebootToBootloader( boot_command=boot_command, firmware_header=firmware_header, @@ -345,42 +335,37 @@ def reboot_to_bootloader( ) ) if isinstance(response, messages.TranslationDataRequest): - response = _send_language_data(client, response, language_data) + response = _send_language_data(session, response, language_data) return response -@session @expect(messages.Success, field="message", ret_type=str) -def show_device_tutorial(client: "TrezorClient") -> "MessageType": - return client.call(messages.ShowDeviceTutorial()) - - -@session -@expect(messages.Success, field="message", ret_type=str) -def unlock_bootloader(client: "TrezorClient") -> "MessageType": - return client.call(messages.UnlockBootloader()) +def show_device_tutorial(session: "Session") -> "MessageType": + return session.call(messages.ShowDeviceTutorial()) @expect(messages.Success, field="message", ret_type=str) -@session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": +def unlock_bootloader(session: "Session") -> "MessageType": + return session.call(messages.UnlockBootloader()) + + +@expect(messages.Success, field="message", ret_type=str) +def set_busy(session: "Session", expiry_ms: Optional[int]) -> "MessageType": """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) - client.refresh_features() + ret = session.call(messages.SetBusy(expiry_ms=expiry_ms)) + session.refresh_features() return ret @expect(messages.AuthenticityProof) -def authenticate(client: "TrezorClient", challenge: bytes): - return client.call(messages.AuthenticateDevice(challenge=challenge)) +def authenticate(session: "Session", challenge: bytes): + return session.call(messages.AuthenticateDevice(challenge=challenge)) @expect(messages.Success, field="message", ret_type=str) -def set_brightness( - client: "TrezorClient", value: Optional[int] = None -) -> "MessageType": - return client.call(messages.SetBrightness(value=value)) +def set_brightness(session: "Session", value: Optional[int] = None) -> "MessageType": + return session.call(messages.SetBrightness(value=value)) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index 1ffaafb4ab..fffe6f0adc 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,12 +18,12 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, expect, session +from .tools import b58decode, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session def name_to_number(name: str) -> int: @@ -321,17 +321,16 @@ def parse_transaction_json( @expect(messages.EosPublicKey) def get_public_key( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> "MessageType": - response = client.call( + response = session.call( messages.EosGetPublicKey(address_n=n, show_display=show_display) ) return response -@session def sign_tx( - client: "TrezorClient", + session: "Session", address: "Address", transaction: dict, chain_id: str, @@ -347,11 +346,11 @@ def sign_tx( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) try: while isinstance(response, messages.EosTxActionRequest): - response = client.call(actions.pop(0)) + response = session.call(actions.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 1cf2eeeaed..60eaa3366b 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,12 +18,12 @@ import re from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import expect, prepare_message_bytes, session, unharden +from .tools import expect, prepare_message_bytes, unharden if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session def int_to_big_endian(value: int) -> bytes: @@ -163,13 +163,13 @@ def network_from_address_n( @expect(messages.EthereumAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, @@ -181,16 +181,15 @@ def get_address( @expect(messages.EthereumPublicKey) def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False + session: "Session", n: "Address", show_display: bool = False ) -> "MessageType": - return client.call( + return session.call( messages.EthereumGetPublicKey(address_n=n, show_display=show_display) ) -@session def sign_tx( - client: "TrezorClient", + session: "Session", n: "Address", nonce: int, gas_price: int, @@ -226,13 +225,13 @@ def sign_tx( data, chunk = data[1024:], data[:1024] msg.data_initial_chunk = chunk - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -247,9 +246,8 @@ def sign_tx( return response.signature_v, response.signature_r, response.signature_s -@session def sign_tx_eip1559( - client: "TrezorClient", + session: "Session", n: "Address", *, nonce: int, @@ -282,13 +280,13 @@ def sign_tx_eip1559( chunkify=chunkify, ) - response = client.call(msg) + response = session.call(msg) assert isinstance(response, messages.EthereumTxRequest) while response.data_length is not None: data_length = response.data_length data, chunk = data[data_length:], data[:data_length] - response = client.call(messages.EthereumTxAck(data_chunk=chunk)) + response = session.call(messages.EthereumTxAck(data_chunk=chunk)) assert isinstance(response, messages.EthereumTxRequest) assert response.signature_v is not None @@ -299,13 +297,13 @@ def sign_tx_eip1559( @expect(messages.EthereumMessageSignature) def sign_message( - client: "TrezorClient", + session: "Session", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), @@ -317,7 +315,7 @@ def sign_message( @expect(messages.EthereumTypedDataSignature) def sign_typed_data( - client: "TrezorClient", + session: "Session", n: "Address", data: Dict[str, Any], *, @@ -333,7 +331,7 @@ def sign_typed_data( metamask_v4_compat=metamask_v4_compat, definitions=definitions, ) - response = client.call(request) + response = session.call(request) # Sending all the types while isinstance(response, messages.EthereumTypedDataStructRequest): @@ -349,7 +347,7 @@ def sign_typed_data( members.append(struct_member) request = messages.EthereumTypedDataStructAck(members=members) - response = client.call(request) + response = session.call(request) # Sending the whole message that should be signed while isinstance(response, messages.EthereumTypedDataValueRequest): @@ -362,7 +360,7 @@ def sign_typed_data( member_typename = data["primaryType"] member_data = data["message"] else: - client.cancel() + # TODO session.cancel() raise exceptions.TrezorException("Root index can only be 0 or 1") # It can be asking for a nested structure (the member path being [X, Y, Z, ...]) @@ -385,20 +383,20 @@ def sign_typed_data( encoded_data = encode_data(member_data, member_typename) request = messages.EthereumTypedDataValueAck(value=encoded_data) - response = client.call(request) + response = session.call(request) return response def verify_message( - client: "TrezorClient", + session: "Session", address: str, signature: bytes, message: AnyStr, chunkify: bool = False, ) -> bool: try: - resp = client.call( + resp = session.call( messages.EthereumVerifyMessage( address=address, signature=signature, @@ -413,13 +411,13 @@ def verify_message( @expect(messages.EthereumTypedDataSignature) def sign_typed_data_hash( - client: "TrezorClient", + session: "Session", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, ) -> "MessageType": - return client.call( + return session.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index 4ed6f22951..90064bb238 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -20,8 +20,8 @@ from . import messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect( @@ -29,27 +29,27 @@ if TYPE_CHECKING: field="credentials", ret_type=List[messages.WebAuthnCredential], ) -def list_credentials(client: "TrezorClient") -> "MessageType": - return client.call(messages.WebAuthnListResidentCredentials()) +def list_credentials(session: "Session") -> "MessageType": + return session.call(messages.WebAuthnListResidentCredentials()) @expect(messages.Success, field="message", ret_type=str) -def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": - return client.call( +def add_credential(session: "Session", credential_id: bytes) -> "MessageType": + return session.call( messages.WebAuthnAddResidentCredential(credential_id=credential_id) ) @expect(messages.Success, field="message", ret_type=str) -def remove_credential(client: "TrezorClient", index: int) -> "MessageType": - return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) +def remove_credential(session: "Session", index: int) -> "MessageType": + return session.call(messages.WebAuthnRemoveResidentCredential(index=index)) @expect(messages.Success, field="message", ret_type=str) -def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": - return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) +def set_counter(session: "Session", u2f_counter: int) -> "MessageType": + return session.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) @expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) -def get_next_counter(client: "TrezorClient") -> "MessageType": - return client.call(messages.GetNextU2FCounter()) +def get_next_counter(session: "Session") -> "MessageType": + return session.call(messages.GetNextU2FCounter()) diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 5cc5d8830c..a588b160e1 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -20,7 +20,7 @@ from hashlib import blake2s from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import expect, session +from ..tools import expect from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware @@ -38,7 +38,7 @@ if True: from .vendor import * # noqa: F401, F403 if t.TYPE_CHECKING: - from ..client import TrezorClient + from ..transport.session import Session T = t.TypeVar("T", bound="FirmwareType") @@ -72,20 +72,19 @@ def is_onev2(fw: "FirmwareType") -> TypeGuard[LegacyFirmware]: # ====== Client functions ====== # -@session def update( - client: "TrezorClient", + session: "Session", data: bytes, progress_update: t.Callable[[int], t.Any] = lambda _: None, ): - if client.features.bootloader_mode is False: + if session.features.bootloader_mode is False: raise RuntimeError("Device must be in bootloader mode") - resp = client.call(messages.FirmwareErase(length=len(data))) + resp = session.call(messages.FirmwareErase(length=len(data))) # TREZORv1 method if isinstance(resp, messages.Success): - resp = client.call(messages.FirmwareUpload(payload=data)) + resp = session.call(messages.FirmwareUpload(payload=data)) progress_update(len(data)) if isinstance(resp, messages.Success): return @@ -97,7 +96,7 @@ def update( length = resp.length payload = data[resp.offset : resp.offset + length] digest = blake2s(payload).digest() - resp = client.call(messages.FirmwareUpload(payload=payload, hash=digest)) + resp = session.call(messages.FirmwareUpload(payload=payload, hash=digest)) progress_update(length) if isinstance(resp, messages.Success): @@ -107,5 +106,5 @@ def update( @expect(messages.FirmwareHash, field="hash", ret_type=bytes) -def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): - return client.call(messages.GetFirmwareHash(challenge=challenge)) +def get_hash(session: "Session", challenge: t.Optional[bytes]): + return session.call(messages.GetFirmwareHash(challenge=challenge)) diff --git a/python/src/trezorlib/mapping.py b/python/src/trezorlib/mapping.py index d50324d586..04b75f0aa5 100644 --- a/python/src/trezorlib/mapping.py +++ b/python/src/trezorlib/mapping.py @@ -17,6 +17,7 @@ from __future__ import annotations import io +import logging from types import ModuleType from typing import Dict, Optional, Tuple, Type, TypeVar @@ -25,6 +26,7 @@ from typing_extensions import Self from . import messages, protobuf T = TypeVar("T") +LOG = logging.getLogger(__name__) class ProtobufMapping: @@ -63,11 +65,21 @@ class ProtobufMapping: wire_type = self.class_to_type_override.get(type(msg), msg.MESSAGE_WIRE_TYPE) if wire_type is None: raise ValueError("Cannot encode class without wire type") - + LOG.debug("encoding wire type %d", wire_type) buf = io.BytesIO() protobuf.dump_message(buf, msg) return wire_type, buf.getvalue() + def encode_without_wire_type(self, msg: protobuf.MessageType) -> bytes: + """Serialize a Python protobuf class. + + Returns the byte representation of the protobuf message. + """ + + buf = io.BytesIO() + protobuf.dump_message(buf, msg) + return buf.getvalue() + def decode(self, msg_wire_type: int, msg_bytes: bytes) -> protobuf.MessageType: """Deserialize a protobuf message into a Python class.""" cls = self.type_to_class[msg_wire_type] @@ -83,7 +95,9 @@ class ProtobufMapping: mapping = cls() message_types = getattr(module, "MessageType") - for entry in message_types: + thp_message_types = getattr(module, "ThpMessageType") + + for entry in (*message_types, *thp_message_types): msg_class = getattr(module, entry.name, None) if msg_class is None: raise ValueError( diff --git a/python/src/trezorlib/messages.py b/python/src/trezorlib/messages.py index b52119311f..86fd70dfd8 100644 --- a/python/src/trezorlib/messages.py +++ b/python/src/trezorlib/messages.py @@ -43,6 +43,8 @@ class FailureType(IntEnum): PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 + ThpUnallocatedSession = 15 + InvalidProtocol = 16 FirmwareError = 99 @@ -400,6 +402,34 @@ class TezosBallotType(IntEnum): Pass = 2 +class ThpMessageType(IntEnum): + ThpCreateNewSession = 1000 + ThpNewSession = 1001 + ThpStartPairingRequest = 1008 + ThpPairingPreparationsFinished = 1009 + ThpCredentialRequest = 1010 + ThpCredentialResponse = 1011 + ThpEndRequest = 1012 + ThpEndResponse = 1013 + ThpCodeEntryCommitment = 1016 + ThpCodeEntryChallenge = 1017 + ThpCodeEntryCpaceHost = 1018 + ThpCodeEntryCpaceTrezor = 1019 + ThpCodeEntryTag = 1020 + ThpCodeEntrySecret = 1021 + ThpQrCodeTag = 1024 + ThpQrCodeSecret = 1025 + ThpNfcUnidirectionalTag = 1032 + ThpNfcUnidirectionalSecret = 1033 + + +class ThpPairingMethod(IntEnum): + NoMethod = 1 + CodeEntry = 2 + QrCode = 3 + NFC_Unidirectional = 4 + + class MessageType(IntEnum): Initialize = 0 Ping = 1 @@ -4100,6 +4130,7 @@ class DebugLinkGetState(protobuf.MessageType): 1: protobuf.Field("wait_word_list", "bool", repeated=False, required=False, default=None), 2: protobuf.Field("wait_word_pos", "bool", repeated=False, required=False, default=None), 3: protobuf.Field("wait_layout", "DebugWaitType", repeated=False, required=False, default=DebugWaitType.IMMEDIATE), + 4: protobuf.Field("thp_channel_id", "bytes", repeated=False, required=False, default=None), } def __init__( @@ -4108,10 +4139,12 @@ class DebugLinkGetState(protobuf.MessageType): wait_word_list: Optional["bool"] = None, wait_word_pos: Optional["bool"] = None, wait_layout: Optional["DebugWaitType"] = DebugWaitType.IMMEDIATE, + thp_channel_id: Optional["bytes"] = None, ) -> None: self.wait_word_list = wait_word_list self.wait_word_pos = wait_word_pos self.wait_layout = wait_layout + self.thp_channel_id = thp_channel_id class DebugLinkState(protobuf.MessageType): @@ -4130,6 +4163,9 @@ class DebugLinkState(protobuf.MessageType): 11: protobuf.Field("reset_word_pos", "uint32", repeated=False, required=False, default=None), 12: protobuf.Field("mnemonic_type", "BackupType", repeated=False, required=False, default=None), 13: protobuf.Field("tokens", "string", repeated=True, required=False, default=None), + 14: protobuf.Field("thp_pairing_code_entry_code", "uint32", repeated=False, required=False, default=None), + 15: protobuf.Field("thp_pairing_code_qr_code", "bytes", repeated=False, required=False, default=None), + 16: protobuf.Field("thp_pairing_code_nfc_unidirectional", "bytes", repeated=False, required=False, default=None), } def __init__( @@ -4148,6 +4184,9 @@ class DebugLinkState(protobuf.MessageType): recovery_word_pos: Optional["int"] = None, reset_word_pos: Optional["int"] = None, mnemonic_type: Optional["BackupType"] = None, + thp_pairing_code_entry_code: Optional["int"] = None, + thp_pairing_code_qr_code: Optional["bytes"] = None, + thp_pairing_code_nfc_unidirectional: Optional["bytes"] = None, ) -> None: self.tokens: Sequence["str"] = tokens if tokens is not None else [] self.layout = layout @@ -4162,6 +4201,9 @@ class DebugLinkState(protobuf.MessageType): self.recovery_word_pos = recovery_word_pos self.reset_word_pos = reset_word_pos self.mnemonic_type = mnemonic_type + self.thp_pairing_code_entry_code = thp_pairing_code_entry_code + self.thp_pairing_code_qr_code = thp_pairing_code_qr_code + self.thp_pairing_code_nfc_unidirectional = thp_pairing_code_nfc_unidirectional class DebugLinkStop(protobuf.MessageType): @@ -7824,6 +7866,280 @@ class TezosManagerTransfer(protobuf.MessageType): self.amount = amount +class ThpDeviceProperties(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("internal_model", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("model_variant", "uint32", repeated=False, required=False, default=None), + 3: protobuf.Field("bootloader_mode", "bool", repeated=False, required=False, default=None), + 4: protobuf.Field("protocol_version", "uint32", repeated=False, required=False, default=None), + 5: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + internal_model: Optional["str"] = None, + model_variant: Optional["int"] = None, + bootloader_mode: Optional["bool"] = None, + protocol_version: Optional["int"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.internal_model = internal_model + self.model_variant = model_variant + self.bootloader_mode = bootloader_mode + self.protocol_version = protocol_version + + +class ThpHandshakeCompletionReqNoisePayload(protobuf.MessageType): + MESSAGE_WIRE_TYPE = None + FIELDS = { + 1: protobuf.Field("host_pairing_credential", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("pairing_methods", "ThpPairingMethod", repeated=True, required=False, default=None), + } + + def __init__( + self, + *, + pairing_methods: Optional[Sequence["ThpPairingMethod"]] = None, + host_pairing_credential: Optional["bytes"] = None, + ) -> None: + self.pairing_methods: Sequence["ThpPairingMethod"] = pairing_methods if pairing_methods is not None else [] + self.host_pairing_credential = host_pairing_credential + + +class ThpCreateNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1000 + FIELDS = { + 1: protobuf.Field("passphrase", "string", repeated=False, required=False, default=None), + 2: protobuf.Field("on_device", "bool", repeated=False, required=False, default=None), + 3: protobuf.Field("derive_cardano", "bool", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + passphrase: Optional["str"] = None, + on_device: Optional["bool"] = None, + derive_cardano: Optional["bool"] = None, + ) -> None: + self.passphrase = passphrase + self.on_device = on_device + self.derive_cardano = derive_cardano + + +class ThpNewSession(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1001 + FIELDS = { + 1: protobuf.Field("new_session_id", "uint32", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + new_session_id: Optional["int"] = None, + ) -> None: + self.new_session_id = new_session_id + + +class ThpStartPairingRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1008 + FIELDS = { + 1: protobuf.Field("host_name", "string", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_name: Optional["str"] = None, + ) -> None: + self.host_name = host_name + + +class ThpPairingPreparationsFinished(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1009 + + +class ThpCodeEntryCommitment(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1016 + FIELDS = { + 1: protobuf.Field("commitment", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + commitment: Optional["bytes"] = None, + ) -> None: + self.commitment = commitment + + +class ThpCodeEntryChallenge(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1017 + FIELDS = { + 1: protobuf.Field("challenge", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + challenge: Optional["bytes"] = None, + ) -> None: + self.challenge = challenge + + +class ThpCodeEntryCpaceHost(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1018 + FIELDS = { + 1: protobuf.Field("cpace_host_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_host_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_host_public_key = cpace_host_public_key + + +class ThpCodeEntryCpaceTrezor(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1019 + FIELDS = { + 1: protobuf.Field("cpace_trezor_public_key", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + cpace_trezor_public_key: Optional["bytes"] = None, + ) -> None: + self.cpace_trezor_public_key = cpace_trezor_public_key + + +class ThpCodeEntryTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1020 + FIELDS = { + 2: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpCodeEntrySecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1021 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpQrCodeTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1024 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpQrCodeSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1025 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpNfcUnidirectionalTag(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1032 + FIELDS = { + 1: protobuf.Field("tag", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + tag: Optional["bytes"] = None, + ) -> None: + self.tag = tag + + +class ThpNfcUnidirectionalSecret(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1033 + FIELDS = { + 1: protobuf.Field("secret", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + secret: Optional["bytes"] = None, + ) -> None: + self.secret = secret + + +class ThpCredentialRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1010 + FIELDS = { + 1: protobuf.Field("host_static_pubkey", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + host_static_pubkey: Optional["bytes"] = None, + ) -> None: + self.host_static_pubkey = host_static_pubkey + + +class ThpCredentialResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1011 + FIELDS = { + 1: protobuf.Field("trezor_static_pubkey", "bytes", repeated=False, required=False, default=None), + 2: protobuf.Field("credential", "bytes", repeated=False, required=False, default=None), + } + + def __init__( + self, + *, + trezor_static_pubkey: Optional["bytes"] = None, + credential: Optional["bytes"] = None, + ) -> None: + self.trezor_static_pubkey = trezor_static_pubkey + self.credential = credential + + +class ThpEndRequest(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1012 + + +class ThpEndResponse(protobuf.MessageType): + MESSAGE_WIRE_TYPE = 1013 + + class ThpCredentialMetadata(protobuf.MessageType): MESSAGE_WIRE_TYPE = None FIELDS = { diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 4ed6f5aa81..d951c52d7c 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -20,25 +20,25 @@ from . import messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.Entropy, field="entropy", ret_type=bytes) -def get_entropy(client: "TrezorClient", size: int) -> "MessageType": - return client.call(messages.GetEntropy(size=size)) +def get_entropy(session: "Session", size: int) -> "MessageType": + return session.call(messages.GetEntropy(size=size)) @expect(messages.SignedIdentity) def sign_identity( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, ) -> "MessageType": - return client.call( + return session.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, @@ -50,12 +50,12 @@ def sign_identity( @expect(messages.ECDHSessionKey) def get_ecdh_session_key( - client: "TrezorClient", + session: "Session", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, ) -> "MessageType": - return client.call( + return session.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, @@ -66,7 +66,7 @@ def get_ecdh_session_key( @expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def encrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -74,7 +74,7 @@ def encrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> "MessageType": - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -89,7 +89,7 @@ def encrypt_keyvalue( @expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def decrypt_keyvalue( - client: "TrezorClient", + session: "Session", n: "Address", key: str, value: bytes, @@ -97,7 +97,7 @@ def decrypt_keyvalue( ask_on_decrypt: bool = True, iv: bytes = b"", ) -> "MessageType": - return client.call( + return session.call( messages.CipherKeyValue( address_n=n, key=key, @@ -111,5 +111,5 @@ def decrypt_keyvalue( @expect(messages.Nonce, field="nonce", ret_type=bytes) -def get_nonce(client: "TrezorClient"): - return client.call(messages.GetNonce()) +def get_nonce(session: "Session"): + return session.call(messages.GetNonce()) diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index 5bce7574e8..5b071626b4 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -20,9 +20,9 @@ from . import messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session # MAINNET = 0 @@ -33,13 +33,13 @@ if TYPE_CHECKING: @expect(messages.MoneroAddress, field="address", ret_type=bytes) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, @@ -51,10 +51,10 @@ def get_address( @expect(messages.MoneroWatchKey) def get_watch_key( - client: "TrezorClient", + session: "Session", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, ) -> "MessageType": - return client.call( + return session.call( messages.MoneroGetWatchKey(address_n=n, network_type=network_type) ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 3a67aec72c..6aa087757a 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -21,9 +21,9 @@ from . import exceptions, messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_IMPORTANCE_TRANSFER = 0x0801 @@ -198,13 +198,13 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig @expect(messages.NEMAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify ) @@ -213,7 +213,7 @@ def get_address( @expect(messages.NEMSignedTx) def sign_tx( - client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False + session: "Session", n: "Address", transaction: dict, chunkify: bool = False ) -> "MessageType": try: msg = create_sign_tx(transaction, chunkify=chunkify) @@ -222,4 +222,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg) + return session.call(msg) diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 7a953b8fac..f026236c07 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -21,9 +21,9 @@ from .protobuf import dict_to_proto from .tools import dict_from_camelcase, expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") @@ -31,12 +31,12 @@ REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") @expect(messages.RippleAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -45,14 +45,14 @@ def get_address( @expect(messages.RippleSignedTx) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, ) -> "MessageType": msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg) + return session.call(msg) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index be7f2e5fcb..1a228b2f95 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -4,29 +4,29 @@ from . import messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType + from .transport.session import Session @expect(messages.SolanaPublicKey) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, ) -> "MessageType": - return client.call( + return session.call( messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) ) @expect(messages.SolanaAddress) def get_address( - client: "TrezorClient", + session: "Session", address_n: List[int], show_display: bool, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, @@ -37,12 +37,12 @@ def get_address( @expect(messages.SolanaTxSignature) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], ) -> "MessageType": - return client.call( + return session.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index ebf81e4fd0..12a75ca5d8 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -21,9 +21,9 @@ from . import exceptions, messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session StellarMessageType = Union[ messages.StellarAccountMergeOp, @@ -325,12 +325,12 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: @expect(messages.StellarAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -338,7 +338,7 @@ def get_address( def sign_tx( - client: "TrezorClient", + session: "Session", tx: messages.StellarSignTx, operations: List["StellarMessageType"], address_n: "Address", @@ -354,10 +354,10 @@ def sign_tx( # 3. Receive a StellarTxOpRequest message # 4. Send operations one by one until all operations have been sent. If there are more operations to sign, the device will send a StellarTxOpRequest message # 5. The final message received will be StellarSignedTx which is returned from this method - resp = client.call(tx) + resp = session.call(tx) try: while isinstance(resp, messages.StellarTxOpRequest): - resp = client.call(operations.pop(0)) + resp = session.call(operations.pop(0)) except IndexError: # pop from empty list raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index cff06ed6c8..b74dc56259 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -20,19 +20,19 @@ from . import messages from .tools import expect if TYPE_CHECKING: - from .client import TrezorClient from .protobuf import MessageType from .tools import Address + from .transport.session import Session @expect(messages.TezosAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -41,12 +41,12 @@ def get_address( @expect(messages.TezosPublicKey, field="public_key", ret_type=str) def get_public_key( - client: "TrezorClient", + session: "Session", address_n: "Address", show_display: bool = False, chunkify: bool = False, ) -> "MessageType": - return client.call( + return session.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify ) @@ -55,11 +55,11 @@ def get_public_key( @expect(messages.TezosSignedTx) def sign_tx( - client: "TrezorClient", + session: "Session", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, ) -> "MessageType": sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg) + return session.call(sign_tx_msg) diff --git a/python/src/trezorlib/tools.py b/python/src/trezorlib/tools.py index 4fd1558ec2..3e9bd1c560 100644 --- a/python/src/trezorlib/tools.py +++ b/python/src/trezorlib/tools.py @@ -40,7 +40,7 @@ if TYPE_CHECKING: # More details: https://www.python.org/dev/peps/pep-0612/ from typing import TypeVar - from typing_extensions import Concatenate, ParamSpec + from typing_extensions import ParamSpec from . import client from .protobuf import MessageType @@ -284,23 +284,6 @@ def expect( return decorator -def session( - f: "Callable[Concatenate[TrezorClient, P], R]", -) -> "Callable[Concatenate[TrezorClient, P], R]": - # Decorator wraps a BaseClient method - # with session activation / deactivation - @functools.wraps(f) - def wrapped_f(client: "TrezorClient", *args: "P.args", **kwargs: "P.kwargs") -> "R": - __tracebackhide__ = True # for pytest # pylint: disable=W0612 - client.open() - try: - return f(client, *args, **kwargs) - finally: - client.close() - - return wrapped_f - - # de-camelcasifier # https://stackoverflow.com/a/1176023/222189 diff --git a/python/src/trezorlib/transport/__init__.py b/python/src/trezorlib/transport/__init__.py index b04876b6b7..45d05150c2 100644 --- a/python/src/trezorlib/transport/__init__.py +++ b/python/src/trezorlib/transport/__init__.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,24 +14,18 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging -from typing import ( - TYPE_CHECKING, - Iterable, - List, - Optional, - Sequence, - Tuple, - Type, - TypeVar, -) +import typing as t from ..exceptions import TrezorException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel - T = TypeVar("T", bound="Transport") + T = t.TypeVar("T", bound="Transport") + LOG = logging.getLogger(__name__) @@ -41,7 +35,7 @@ https://github.com/trezor/trezor-common/blob/master/udev/51-trezor.rules """.strip() -MessagePayload = Tuple[int, bytes] +MessagePayload = t.Tuple[int, bytes] class TransportException(TrezorException): @@ -53,72 +47,54 @@ class DeviceIsBusy(TransportException): class Transport: - """Raw connection to a Trezor device. - - Transport subclass represents a kind of communication link: Trezor Bridge, WebUSB - or USB-HID connection, or UDP socket of listening emulator(s). - It can also enumerate devices available over this communication link, and return - them as instances. - - Transport instance is a thing that: - - can be identified and requested by a string URI-like path - - can open and close sessions, which enclose related operations - - can read and write protobuf messages - - You need to implement a new Transport subclass if you invent a new way to connect - a Trezor device to a computer. - """ - PATH_PREFIX: str - ENABLED = False - def __str__(self) -> str: - return self.get_path() + @classmethod + def enumerate( + cls: t.Type["T"], models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["T"]: + raise NotImplementedError + + @classmethod + def find_by_path(cls: t.Type["T"], path: str, prefix_search: bool = False) -> "T": + for device in cls.enumerate(): + + if device.get_path() == path: + return device + + if prefix_search and device.get_path().startswith(path): + return device + + raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") def get_path(self) -> str: raise NotImplementedError - def begin_session(self) -> None: - raise NotImplementedError - - def end_session(self) -> None: - raise NotImplementedError - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - def find_debug(self: "T") -> "T": raise NotImplementedError - @classmethod - def enumerate( - cls: Type["T"], models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["T"]: + def open(self) -> None: raise NotImplementedError - @classmethod - def find_by_path(cls: Type["T"], path: str, prefix_search: bool = False) -> "T": - for device in cls.enumerate(): - if ( - path is None - or device.get_path() == path - or (prefix_search and device.get_path().startswith(path)) - ): - return device + def close(self) -> None: + raise NotImplementedError - raise TransportException(f"{cls.PATH_PREFIX} device not found: {path}") + def write_chunk(self, chunk: bytes) -> None: + raise NotImplementedError + + def read_chunk(self) -> bytes: + raise NotImplementedError + + CHUNK_SIZE: t.ClassVar[int] -def all_transports() -> Iterable[Type["Transport"]]: +def all_transports() -> t.Iterable[t.Type["Transport"]]: from .bridge import BridgeTransport from .hid import HidTransport from .udp import UdpTransport from .webusb import WebUsbTransport - transports: Tuple[Type["Transport"], ...] = ( + transports: t.Tuple[t.Type["Transport"], ...] = ( BridgeTransport, HidTransport, UdpTransport, @@ -128,9 +104,9 @@ def all_transports() -> Iterable[Type["Transport"]]: def enumerate_devices( - models: Optional[Iterable["TrezorModel"]] = None, -) -> Sequence["Transport"]: - devices: List["Transport"] = [] + models: t.Iterable["TrezorModel"] | None = None, +) -> t.Sequence["Transport"]: + devices: t.List["Transport"] = [] for transport in all_transports(): name = transport.__name__ try: @@ -145,9 +121,7 @@ def enumerate_devices( return devices -def get_transport( - path: Optional[str] = None, prefix_search: bool = False -) -> "Transport": +def get_transport(path: str | None = None, prefix_search: bool = False) -> "Transport": if path is None: try: return next(iter(enumerate_devices())) diff --git a/python/src/trezorlib/transport/bridge.py b/python/src/trezorlib/transport/bridge.py index e0c34a8f70..8d69e5b253 100644 --- a/python/src/trezorlib/transport/bridge.py +++ b/python/src/trezorlib/transport/bridge.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,24 +14,30 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import struct -from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional +import typing as t import requests from ..log import DUMP_PACKETS from . import DeviceIsBusy, MessagePayload, Transport, TransportException -if TYPE_CHECKING: +if t.TYPE_CHECKING: from ..models import TrezorModel LOG = logging.getLogger(__name__) +PROTOCOL_VERSION_1 = 1 +PROTOCOL_VERSION_2 = 2 + TREZORD_HOST = "http://127.0.0.1:21325" TREZORD_ORIGIN_HEADER = {"Origin": "https://python.trezor.io"} TREZORD_VERSION_MODERN = (2, 0, 25) +TREZORD_VERSION_THP_SUPPORT = (2, 0, 31) # TODO add correct value CONNECTION = requests.Session() CONNECTION.headers.update(TREZORD_ORIGIN_HEADER) @@ -45,7 +51,7 @@ class BridgeException(TransportException): super().__init__(f"trezord: {path} failed with code {status}: {message}") -def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: +def call_bridge(path: str, data: str | None = None) -> requests.Response: url = TREZORD_HOST + "/" + path r = CONNECTION.post(url, data=data) if r.status_code != 200: @@ -53,10 +59,54 @@ def call_bridge(path: str, data: Optional[str] = None) -> requests.Response: return r -def is_legacy_bridge() -> bool: +def get_bridge_version() -> t.Tuple[int, ...]: config = call_bridge("configure").json() - version_tuple = tuple(map(int, config["version"].split("."))) - return version_tuple < TREZORD_VERSION_MODERN + return tuple(map(int, config["version"].split("."))) + + +def is_legacy_bridge() -> bool: + return get_bridge_version() < TREZORD_VERSION_MODERN + + +def supports_protocolV2() -> bool: + return get_bridge_version() >= TREZORD_VERSION_THP_SUPPORT + + +def detect_protocol_version(transport: "BridgeTransport") -> int: + from .. import mapping, messages + from ..messages import FailureType + + protocol_version = PROTOCOL_VERSION_1 + request_type, request_data = mapping.DEFAULT_MAPPING.encode(messages.Initialize()) + transport.deprecated_begin_session() + transport.deprecated_write(request_type, request_data) + + response_type, response_data = transport.deprecated_read() + response = mapping.DEFAULT_MAPPING.decode(response_type, response_data) + transport.deprecated_begin_session() + if isinstance(response, messages.Failure): + if response.code == FailureType.InvalidProtocol: + LOG.debug("Protocol V2 detected") + protocol_version = PROTOCOL_VERSION_2 + + return protocol_version + + +def _is_transport_valid(transport: "BridgeTransport") -> bool: + is_valid = ( + supports_protocolV2() + or detect_protocol_version(transport) == PROTOCOL_VERSION_1 + ) + if not is_valid: + LOG.warning("Detected unsupported Bridge transport!") + return is_valid + + +def filter_invalid_bridge_transports( + transports: t.Iterable["BridgeTransport"], +) -> t.Sequence["BridgeTransport"]: + """Filters out invalid bridge transports. Keeps only valid ones.""" + return [t for t in transports if _is_transport_valid(t)] class BridgeHandle: @@ -84,7 +134,7 @@ class BridgeHandleModern(BridgeHandle): class BridgeHandleLegacy(BridgeHandle): def __init__(self, transport: "BridgeTransport") -> None: super().__init__(transport) - self.request: Optional[str] = None + self.request: str | None = None def write_buf(self, buf: bytes) -> None: if self.request is not None: @@ -112,13 +162,12 @@ class BridgeTransport(Transport): ENABLED: bool = True def __init__( - self, device: Dict[str, Any], legacy: bool, debug: bool = False + self, device: t.Dict[str, t.Any], legacy: bool, debug: bool = False ) -> None: if legacy and debug: raise TransportException("Debugging not supported on legacy Bridge") - self.device = device - self.session: Optional[str] = None + self.session: str | None = device["session"] self.debug = debug self.legacy = legacy @@ -135,7 +184,7 @@ class BridgeTransport(Transport): raise TransportException("Debug device not available") return BridgeTransport(self.device, self.legacy, debug=True) - def _call(self, action: str, data: Optional[str] = None) -> requests.Response: + def _call(self, action: str, data: str | None = None) -> requests.Response: session = self.session or "null" uri = action + "/" + str(session) if self.debug: @@ -144,17 +193,20 @@ class BridgeTransport(Transport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None - ) -> Iterable["BridgeTransport"]: + cls, _models: t.Iterable["TrezorModel"] | None = None + ) -> t.Iterable["BridgeTransport"]: try: legacy = is_legacy_bridge() - return [ - BridgeTransport(dev, legacy) for dev in call_bridge("enumerate").json() - ] + return filter_invalid_bridge_transports( + [ + BridgeTransport(dev, legacy) + for dev in call_bridge("enumerate").json() + ] + ) except Exception: return [] - def begin_session(self) -> None: + def deprecated_begin_session(self) -> None: try: data = self._call("acquire/" + self.device["path"]) except BridgeException as e: @@ -163,18 +215,32 @@ class BridgeTransport(Transport): raise self.session = data.json()["session"] - def end_session(self) -> None: + def deprecated_end_session(self) -> None: if not self.session: return self._call("release") self.session = None - def write(self, message_type: int, message_data: bytes) -> None: + def deprecated_write(self, message_type: int, message_data: bytes) -> None: header = struct.pack(">HL", message_type, len(message_data)) self.handle.write_buf(header + message_data) - def read(self) -> MessagePayload: + def deprecated_read(self) -> MessagePayload: data = self.handle.read_buf() headerlen = struct.calcsize(">HL") msg_type, datalen = struct.unpack(">HL", data[:headerlen]) return msg_type, data[headerlen : headerlen + datalen] + + def open(self) -> None: + pass + # TODO self.handle.open() + + def close(self) -> None: + pass + # TODO self.handle.close() + + def write_chunk(self, chunk: bytes) -> None: # TODO check if it works :) + self.handle.write_buf(chunk) + + def read_chunk(self) -> bytes: # TODO check if it works :) + return self.handle.read_buf() diff --git a/python/src/trezorlib/transport/hid.py b/python/src/trezorlib/transport/hid.py index 65fa08ccd7..995fd6960c 100644 --- a/python/src/trezorlib/transport/hid.py +++ b/python/src/trezorlib/transport/hid.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,15 +14,16 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import sys import time -from typing import Any, Dict, Iterable, List, Optional +import typing as t from ..log import DUMP_PACKETS from ..models import TREZOR_ONE, TrezorModel -from . import UDEV_RULES_STR, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, Transport, TransportException LOG = logging.getLogger(__name__) @@ -35,23 +36,61 @@ except Exception as e: HID_IMPORTED = False -HidDevice = Dict[str, Any] -HidDeviceHandle = Any +HidDevice = t.Dict[str, t.Any] +HidDeviceHandle = t.Any -class HidHandle: - def __init__( - self, path: bytes, serial: str, probe_hid_version: bool = False - ) -> None: - self.path = path - self.serial = serial +class HidTransport(Transport): + """ + HidTransport implements transport over USB HID interface. + """ + + PATH_PREFIX = "hid" + ENABLED = HID_IMPORTED + + def __init__(self, device: HidDevice, probe_hid_version: bool = False) -> None: + self.device = device + self.device_path = device["path"] + self.device_serial_number = device["serial_number"] self.handle: HidDeviceHandle = None self.hid_version = None if probe_hid_version else 2 + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" + + @classmethod + def enumerate( + cls, models: t.Iterable["TrezorModel"] | None = None, debug: bool = False + ) -> t.Iterable["HidTransport"]: + if models is None: + models = {TREZOR_ONE} + usb_ids = [id for model in models for id in model.usb_ids] + + devices: t.List["HidTransport"] = [] + for dev in hid.enumerate(0, 0): + usb_id = (dev["vendor_id"], dev["product_id"]) + if usb_id not in usb_ids: + continue + if debug: + if not is_debuglink(dev): + continue + else: + if not is_wirelink(dev): + continue + devices.append(HidTransport(dev)) + return devices + + def find_debug(self) -> "HidTransport": + # For v1 protocol, find debug USB interface for the same serial number + for debug in HidTransport.enumerate(debug=True): + if debug.device["serial_number"] == self.device["serial_number"]: + return debug + raise TransportException("Debug HID device not found") + def open(self) -> None: self.handle = hid.device() try: - self.handle.open_path(self.path) + self.handle.open_path(self.device_path) except (IOError, OSError) as e: if sys.platform.startswith("linux"): e.args = e.args + (UDEV_RULES_STR,) @@ -62,11 +101,11 @@ class HidHandle: # and we wouldn't even know. # So we check that the serial matches what we expect. serial = self.handle.get_serial_number_string() - if serial != self.serial: + if serial != self.device_serial_number: self.handle.close() self.handle = None raise TransportException( - f"Unexpected device {serial} on path {self.path.decode()}" + f"Unexpected device {serial} on path {self.device_path.decode()}" ) self.handle.set_nonblocking(True) @@ -77,7 +116,7 @@ class HidHandle: def close(self) -> None: if self.handle is not None: # reload serial, because device.wipe() can reset it - self.serial = self.handle.get_serial_number_string() + self.device_serial_number = self.handle.get_serial_number_string() self.handle.close() self.handle = None @@ -115,53 +154,6 @@ class HidHandle: raise TransportException("Unknown HID version") -class HidTransport(ProtocolBasedTransport): - """ - HidTransport implements transport over USB HID interface. - """ - - PATH_PREFIX = "hid" - ENABLED = HID_IMPORTED - - def __init__(self, device: HidDevice) -> None: - self.device = device - self.handle = HidHandle(device["path"], device["serial_number"]) - - super().__init__(protocol=ProtocolV1(self.handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{self.device['path'].decode()}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, debug: bool = False - ) -> Iterable["HidTransport"]: - if models is None: - models = {TREZOR_ONE} - usb_ids = [id for model in models for id in model.usb_ids] - - devices: List["HidTransport"] = [] - for dev in hid.enumerate(0, 0): - usb_id = (dev["vendor_id"], dev["product_id"]) - if usb_id not in usb_ids: - continue - if debug: - if not is_debuglink(dev): - continue - else: - if not is_wirelink(dev): - continue - devices.append(HidTransport(dev)) - return devices - - def find_debug(self) -> "HidTransport": - # For v1 protocol, find debug USB interface for the same serial number - for debug in HidTransport.enumerate(debug=True): - if debug.device["serial_number"] == self.device["serial_number"]: - return debug - raise TransportException("Debug HID device not found") - - def is_wirelink(dev: HidDevice) -> bool: return dev["usage_page"] == 0xFF00 or dev["interface_number"] == 0 diff --git a/python/src/trezorlib/transport/protocol.py b/python/src/trezorlib/transport/protocol.py deleted file mode 100644 index a5a0ee6be4..0000000000 --- a/python/src/trezorlib/transport/protocol.py +++ /dev/null @@ -1,165 +0,0 @@ -# This file is part of the Trezor project. -# -# Copyright (C) 2012-2022 SatoshiLabs and contributors -# -# This library is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License version 3 -# as published by the Free Software Foundation. -# -# This library is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the License along with this library. -# If not, see . - -import logging -import struct -from typing import Tuple - -from typing_extensions import Protocol as StructuralType - -from . import MessagePayload, Transport - -REPLEN = 64 - -V2_FIRST_CHUNK = 0x01 -V2_NEXT_CHUNK = 0x02 -V2_BEGIN_SESSION = 0x03 -V2_END_SESSION = 0x04 - -LOG = logging.getLogger(__name__) - - -class Handle(StructuralType): - """PEP 544 structural type for Handle functionality. - (called a "Protocol" in the proposed PEP, name which is impractical here) - - Handle is a "physical" layer for a protocol. - It can open/close a connection and read/write bare data in 64-byte chunks. - - Functionally we gain nothing from making this an (abstract) base class for handle - implementations, so this definition is for type hinting purposes only. You can, - but don't have to, inherit from it. - """ - - def open(self) -> None: ... - - def close(self) -> None: ... - - def read_chunk(self) -> bytes: ... - - def write_chunk(self, chunk: bytes) -> None: ... - - -class Protocol: - """Wire protocol that can communicate with a Trezor device, given a Handle. - - A Protocol implements the part of the Transport API that relates to communicating - logical messages over a physical layer. It is a thing that can: - - open and close sessions, - - send and receive protobuf messages, - given the ability to: - - open and close physical connections, - - and send and receive binary chunks. - - For now, the class also handles session counting and opening the underlying Handle. - This will probably be removed in the future. - - We will need a new Protocol class if we change the way a Trezor device encapsulates - its messages. - """ - - def __init__(self, handle: Handle) -> None: - self.handle = handle - self.session_counter = 0 - - # XXX we might be able to remove this now that TrezorClient does session handling - def begin_session(self) -> None: - if self.session_counter == 0: - self.handle.open() - self.session_counter += 1 - - def end_session(self) -> None: - self.session_counter = max(self.session_counter - 1, 0) - if self.session_counter == 0: - self.handle.close() - - def read(self) -> MessagePayload: - raise NotImplementedError - - def write(self, message_type: int, message_data: bytes) -> None: - raise NotImplementedError - - -class ProtocolBasedTransport(Transport): - """Transport that implements its communications through a Protocol. - - Intended as a base class for implementations that proxy their communication - operations to a Protocol. - """ - - def __init__(self, protocol: Protocol) -> None: - self.protocol = protocol - - def write(self, message_type: int, message_data: bytes) -> None: - self.protocol.write(message_type, message_data) - - def read(self) -> MessagePayload: - return self.protocol.read() - - def begin_session(self) -> None: - self.protocol.begin_session() - - def end_session(self) -> None: - self.protocol.end_session() - - -class ProtocolV1(Protocol): - """Protocol version 1. Currently (11/2018) in use on all Trezors. - Does not understand sessions. - """ - - HEADER_LEN = struct.calcsize(">HL") - - def write(self, message_type: int, message_data: bytes) -> None: - header = struct.pack(">HL", message_type, len(message_data)) - buffer = bytearray(b"##" + header + message_data) - - while buffer: - # Report ID, data padded to 63 bytes - chunk = b"?" + buffer[: REPLEN - 1] - chunk = chunk.ljust(REPLEN, b"\x00") - self.handle.write_chunk(chunk) - buffer = buffer[63:] - - def read(self) -> MessagePayload: - buffer = bytearray() - # Read header with first part of message data - msg_type, datalen, first_chunk = self.read_first() - buffer.extend(first_chunk) - - # Read the rest of the message - while len(buffer) < datalen: - buffer.extend(self.read_next()) - - return msg_type, buffer[:datalen] - - def read_first(self) -> Tuple[int, int, bytes]: - chunk = self.handle.read_chunk() - if chunk[:3] != b"?##": - raise RuntimeError("Unexpected magic characters") - try: - msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) - except Exception: - raise RuntimeError("Cannot parse header") - - data = chunk[3 + self.HEADER_LEN :] - return msg_type, datalen, data - - def read_next(self) -> bytes: - chunk = self.handle.read_chunk() - if chunk[:1] != b"?": - raise RuntimeError("Unexpected magic characters") - return chunk[1:] diff --git a/python/src/trezorlib/transport/session.py b/python/src/trezorlib/transport/session.py new file mode 100644 index 0000000000..6b6f4cce2c --- /dev/null +++ b/python/src/trezorlib/transport/session.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import logging +import typing as t + +from .. import exceptions, messages, models +from .thp.protocol_v1 import ProtocolV1 +from .thp.protocol_v2 import ProtocolV2 + +if t.TYPE_CHECKING: + from ..client import TrezorClient + +LOG = logging.getLogger(__name__) + + +class Session: + button_callback: t.Callable[[Session, t.Any], t.Any] | None = None + pin_callback: t.Callable[[Session, t.Any], t.Any] | None = None + passphrase_callback: t.Callable[[Session, t.Any], t.Any] | None = None + + def __init__( + self, client: TrezorClient, id: bytes, passphrase: str | object | None = None + ) -> None: + self.client = client + self._id = id + self.passphrase = passphrase + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | object | None, derive_cardano: bool + ) -> Session: + raise NotImplementedError + + def call(self, msg: t.Any) -> t.Any: + # TODO self.check_firmware_version() + resp = self.call_raw(msg) + + while True: + if isinstance(resp, messages.PinMatrixRequest): + if self.pin_callback is None: + raise Exception # TODO + resp = self.pin_callback(self, resp) + elif isinstance(resp, messages.PassphraseRequest): + if self.passphrase_callback is None: + raise Exception # TODO + resp = self.passphrase_callback(self, resp) + elif isinstance(resp, messages.ButtonRequest): + if self.button_callback is None: + raise Exception # TODO + resp = self.button_callback(self, resp) + elif isinstance(resp, messages.Failure): + if resp.code == messages.FailureType.ActionCancelled: + raise exceptions.Cancelled + raise exceptions.TrezorFailure(resp) + else: + return resp + + def call_raw(self, msg: t.Any) -> t.Any: + self._write(msg) + return self._read() + + def _write(self, msg: t.Any) -> None: + raise NotImplementedError + + def _read(self) -> t.Any: + raise NotImplementedError + + def refresh_features(self) -> None: + self.client.refresh_features() + + def end(self) -> t.Any: + return self.call(messages.EndSession()) + + def ping(self, message: str, button_protection: bool | None = None) -> str: + resp: messages.Success = self.call( + messages.Ping(message=message, button_protection=button_protection) + ) + return resp.message or "" + + @property + def features(self) -> messages.Features: + return self.client.features + + @property + def model(self) -> models.TrezorModel: + return self.client.model + + @property + def version(self) -> t.Tuple[int, int, int]: + return self.client.version + + @property + def id(self) -> bytes: + return self._id + + @id.setter + def id(self, value: bytes) -> None: + if not isinstance(value, bytes): + raise ValueError("id must be of type bytes") + self._id = value + + +class SessionV1(Session): + derive_cardano: bool | None = False + + @classmethod + def new( + cls, + client: TrezorClient, + passphrase: str | object = "", + derive_cardano: bool = False, + session_id: bytes | None = None, + ) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, id=session_id or b"") + + session._init_callbacks() + session.passphrase = passphrase + session.derive_cardano = derive_cardano + session.init_session(session.derive_cardano) + return session + + @classmethod + def resume_from_id(cls, client: TrezorClient, session_id: bytes) -> SessionV1: + assert isinstance(client.protocol, ProtocolV1) + session = SessionV1(client, session_id) + session.init_session() + return session + + def _init_callbacks(self) -> None: + self.button_callback = self.client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.pin_callback = self.client.pin_callback + self.passphrase_callback = self.client.passphrase_callback + + def _write(self, msg: t.Any) -> None: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + self.client.protocol.write(msg) + + def _read(self) -> t.Any: + if t.TYPE_CHECKING: + assert isinstance(self.client.protocol, ProtocolV1) + return self.client.protocol.read() + + def init_session(self, derive_cardano: bool | None = None): + if self.id == b"": + session_id = None + else: + session_id = self.id + resp: messages.Features = self.call_raw( + messages.Initialize(session_id=session_id, derive_cardano=derive_cardano) + ) + if isinstance(self.passphrase, str): + self.passphrase_callback = _send_passphrase + self._id = resp.session_id + + +def _send_passphrase(session: Session, resp: t.Any) -> None: + assert isinstance(session.passphrase, str) + return session.call(messages.PassphraseAck(passphrase=session.passphrase)) + + +def _callback_button(session: Session, msg: t.Any) -> t.Any: + print("Please confirm action on your Trezor device.") # TODO how to handle UI? + return session.call(messages.ButtonAck()) + + +class SessionV2(Session): + + @classmethod + def new( + cls, client: TrezorClient, passphrase: str | None, derive_cardano: bool + ) -> SessionV2: + assert isinstance(client.protocol, ProtocolV2) + session = cls(client, b"\x00") + new_session: messages.ThpNewSession = session.call( + messages.ThpCreateNewSession( + passphrase=passphrase, derive_cardano=derive_cardano + ) + ) + assert new_session.new_session_id is not None + session_id = new_session.new_session_id + session.update_id_and_sid(session_id.to_bytes(1, "big")) + return session + + def __init__(self, client: TrezorClient, id: bytes) -> None: + super().__init__(client, id) + assert isinstance(client.protocol, ProtocolV2) + + self.pin_callback = client.pin_callback + self.button_callback = client.button_callback + if self.button_callback is None: + self.button_callback = _callback_button + self.channel: ProtocolV2 = client.protocol.get_channel() + self.update_id_and_sid(id) + + def _write(self, msg: t.Any) -> None: + LOG.debug("writing message %s", type(msg)) + self.channel.write(self.sid, msg) + + def _read(self) -> t.Any: + msg = self.channel.read(self.sid) + LOG.debug("reading message %s", type(msg)) + return msg + + def update_id_and_sid(self, id: bytes) -> None: + self._id = id + self.sid = int.from_bytes(id, "big") # TODO update to extract only sid diff --git a/python/src/trezorlib/transport/thp/alternating_bit_protocol.py b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py new file mode 100644 index 0000000000..62fb650fab --- /dev/null +++ b/python/src/trezorlib/transport/thp/alternating_bit_protocol.py @@ -0,0 +1,102 @@ +# from storage.cache_thp import ChannelCache +# from trezor import log +# from trezor.wire.thp import ThpError + + +# def is_ack_valid(cache: ChannelCache, ack_bit: int) -> bool: +# """ +# Checks if: +# - an ACK message is expected +# - the received ACK message acknowledges correct sequence number (bit) +# """ +# if not _is_ack_expected(cache): +# return False + +# if not _has_ack_correct_sync_bit(cache, ack_bit): +# return False + +# return True + + +# def _is_ack_expected(cache: ChannelCache) -> bool: +# is_expected: bool = not is_sending_allowed(cache) +# if __debug__ and not is_expected: +# log.debug(__name__, "Received unexpected ACK message") +# return is_expected + + +# def _has_ack_correct_sync_bit(cache: ChannelCache, sync_bit: int) -> bool: +# is_correct: bool = get_send_seq_bit(cache) == sync_bit +# if __debug__ and not is_correct: +# log.debug(__name__, "Received ACK message with wrong ack bit") +# return is_correct + + +# def is_sending_allowed(cache: ChannelCache) -> bool: +# """ +# Checks whether sending a message in the provided channel is allowed. + +# Note: Sending a message in a channel before receipt of ACK message for the previously +# sent message (in the channel) is prohibited, as it can lead to desynchronization. +# """ +# return bool(cache.sync >> 7) + + +# def get_send_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the sequential number (bit) of the next message to be sent +# in the provided channel. +# """ +# return (cache.sync & 0x20) >> 5 + + +# def get_expected_receive_seq_bit(cache: ChannelCache) -> int: +# """ +# Returns the (expected) sequential number (bit) of the next message +# to be received in the provided channel. +# """ +# return (cache.sync & 0x40) >> 6 + + +# def set_sending_allowed(cache: ChannelCache, sending_allowed: bool) -> None: +# """ +# Set the flag whether sending a message in this channel is allowed or not. +# """ +# cache.sync &= 0x7F +# if sending_allowed: +# cache.sync |= 0x80 + + +# def set_expected_receive_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# """ +# Set the expected sequential number (bit) of the next message to be received +# in the provided channel +# """ +# if __debug__: +# log.debug(__name__, "Set sync receive expected seq bit to %d", seq_bit) +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected receive sync bit") + +# # set second bit to "seq_bit" value +# cache.sync &= 0xBF +# if seq_bit: +# cache.sync |= 0x40 + + +# def _set_send_seq_bit(cache: ChannelCache, seq_bit: int) -> None: +# if seq_bit not in (0, 1): +# raise ThpError("Unexpected send seq bit") +# if __debug__: +# log.debug(__name__, "setting sync send seq bit to %d", seq_bit) +# # set third bit to "seq_bit" value +# cache.sync &= 0xDF +# if seq_bit: +# cache.sync |= 0x20 + + +# def set_send_seq_bit_to_opposite(cache: ChannelCache) -> None: +# """ +# Set the sequential bit of the "next message to be send" to the opposite value, +# i.e. 1 -> 0 and 0 -> 1 +# """ +# _set_send_seq_bit(cache=cache, seq_bit=1 - get_send_seq_bit(cache)) diff --git a/python/src/trezorlib/transport/thp/channel_data.py b/python/src/trezorlib/transport/thp/channel_data.py new file mode 100644 index 0000000000..3d70deecaf --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_data.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from binascii import hexlify + + +class ChannelData: + def __init__( + self, + protocol_version: int, + transport_path: str, + channel_id: int, + key_request: bytes, + key_response: bytes, + nonce_request: int, + nonce_response: int, + sync_bit_send: int, + sync_bit_receive: int, + ) -> None: + self.protocol_version: int = protocol_version + self.transport_path: str = transport_path + self.channel_id: int = channel_id + self.key_request: str = hexlify(key_request).decode() + self.key_response: str = hexlify(key_response).decode() + self.nonce_request: int = nonce_request + self.nonce_response: int = nonce_response + self.sync_bit_receive: int = sync_bit_receive + self.sync_bit_send: int = sync_bit_send + + def to_dict(self): + return { + "protocol_version": self.protocol_version, + "transport_path": self.transport_path, + "channel_id": self.channel_id, + "key_request": self.key_request, + "key_response": self.key_response, + "nonce_request": self.nonce_request, + "nonce_response": self.nonce_response, + "sync_bit_send": self.sync_bit_send, + "sync_bit_receive": self.sync_bit_receive, + } diff --git a/python/src/trezorlib/transport/thp/channel_database.py b/python/src/trezorlib/transport/thp/channel_database.py new file mode 100644 index 0000000000..143430069f --- /dev/null +++ b/python/src/trezorlib/transport/thp/channel_database.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import json +import logging +import os +import typing as t + +from ..thp.channel_data import ChannelData +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +db: "ChannelDatabase | None" = None + + +def get_channel_db() -> ChannelDatabase: + if db is None: + set_channel_database(should_not_store=True) + assert db is not None + return db + + +class ChannelDatabase: + + def load_stored_channels(self) -> t.List[ChannelData]: ... + def clear_stored_channels(self) -> None: ... + def read_all_channels(self) -> t.List: ... + def save_all_channels(self, channels: t.List[t.Dict]) -> None: ... + def save_channel(self, new_channel: ProtocolAndChannel): ... + def remove_channel(self, transport_path: str) -> None: ... + + +class DummyChannelDatabase(ChannelDatabase): + + def load_stored_channels(self) -> t.List[ChannelData]: + return [] + + def clear_stored_channels(self) -> None: + pass + + def read_all_channels(self) -> t.List: + return [] + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + return + + def save_channel(self, new_channel: ProtocolAndChannel): + pass + + def remove_channel(self, transport_path: str) -> None: + pass + + +class JsonChannelDatabase(ChannelDatabase): + def __init__(self, data_path: str) -> None: + self.data_path = data_path + super().__init__() + + def load_stored_channels(self) -> t.List[ChannelData]: + dicts = self.read_all_channels() + return [dict_to_channel_data(d) for d in dicts] + + def clear_stored_channels(self) -> None: + LOG.debug("Clearing contents of %s", self.data_path) + with open(self.data_path, "w") as f: + json.dump([], f) + try: + os.remove(self.data_path) + except Exception as e: + LOG.exception("Failed to delete %s (%s)", self.data_path, str(type(e))) + + def read_all_channels(self) -> t.List: + ensure_file_exists(self.data_path) + with open(self.data_path, "r") as f: + return json.load(f) + + def save_all_channels(self, channels: t.List[t.Dict]) -> None: + LOG.debug("saving all channels") + with open(self.data_path, "w") as f: + json.dump(channels, f, indent=4) + + def save_channel(self, new_channel: ProtocolAndChannel): + + LOG.debug("save channel") + channels = self.read_all_channels() + transport_path = new_channel.transport.get_path() + + # If the channel is found in database: replace the old entry by the new + for i, channel in enumerate(channels): + if channel["transport_path"] == transport_path: + LOG.debug("Modified channel entry for %s", transport_path) + channels[i] = new_channel.get_channel_data().to_dict() + self.save_all_channels(channels) + return + + # Channel was not found: add a new channel entry + LOG.debug("Created a new channel entry on path %s", transport_path) + channels.append(new_channel.get_channel_data().to_dict()) + self.save_all_channels(channels) + + def remove_channel(self, transport_path: str) -> None: + LOG.debug( + "Removing channel with path %s from the channel database.", + transport_path, + ) + channels = self.read_all_channels() + remaining_channels = [ + ch for ch in channels if ch["transport_path"] != transport_path + ] + self.save_all_channels(remaining_channels) + + +def dict_to_channel_data(dict: t.Dict) -> ChannelData: + return ChannelData( + protocol_version=dict["protocol_version"], + transport_path=dict["transport_path"], + channel_id=dict["channel_id"], + key_request=bytes.fromhex(dict["key_request"]), + key_response=bytes.fromhex(dict["key_response"]), + nonce_request=dict["nonce_request"], + nonce_response=dict["nonce_response"], + sync_bit_send=dict["sync_bit_send"], + sync_bit_receive=dict["sync_bit_receive"], + ) + + +def ensure_file_exists(file_path: str) -> None: + LOG.debug("checking if file %s exists", file_path) + if not os.path.exists(file_path): + os.makedirs(os.path.dirname(file_path), exist_ok=True) + LOG.debug("File %s does not exist. Creating a new one.", file_path) + with open(file_path, "w") as f: + json.dump([], f) + + +def set_channel_database(should_not_store: bool): + global db + if should_not_store: + db = DummyChannelDatabase() + else: + from platformdirs import user_cache_dir + + APP_NAME = "@trezor" # TODO + DATA_PATH = os.path.join(user_cache_dir(appname=APP_NAME), "channel_data.json") + + db = JsonChannelDatabase(DATA_PATH) diff --git a/python/src/trezorlib/transport/thp/checksum.py b/python/src/trezorlib/transport/thp/checksum.py new file mode 100644 index 0000000000..8e0f32f013 --- /dev/null +++ b/python/src/trezorlib/transport/thp/checksum.py @@ -0,0 +1,19 @@ +import zlib + +CHECKSUM_LENGTH = 4 + + +def compute(data: bytes) -> bytes: + """ + Returns a CRC-32 checksum of the provided `data`. + """ + return zlib.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + + +def is_valid(checksum: bytes, data: bytes) -> bool: + """ + Checks whether the CRC-32 checksum of the `data` is the same + as the checksum provided in `checksum`. + """ + data_checksum = compute(data) + return checksum == data_checksum diff --git a/python/src/trezorlib/transport/thp/control_byte.py b/python/src/trezorlib/transport/thp/control_byte.py new file mode 100644 index 0000000000..ce7f6066f9 --- /dev/null +++ b/python/src/trezorlib/transport/thp/control_byte.py @@ -0,0 +1,59 @@ +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + + +def add_seq_bit_to_ctrl_byte(ctrl_byte: int, seq_bit: int) -> int: + if seq_bit == 0: + return ctrl_byte & 0xEF + if seq_bit == 1: + return ctrl_byte | 0x10 + raise Exception("Unexpected sequence bit") + + +def add_ack_bit_to_ctrl_byte(ctrl_byte: int, ack_bit: int) -> int: + if ack_bit == 0: + return ctrl_byte & 0xF7 + if ack_bit == 1: + return ctrl_byte | 0x08 + raise Exception("Unexpected acknowledgement bit") + + +def get_seq_bit(ctrl_byte: int) -> int: + return (ctrl_byte & 0x10) >> 4 + + +def is_ack(ctrl_byte: int) -> bool: + return ctrl_byte & ACK_MASK == ACK_MESSAGE + + +def is_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & CONTINUATION_PACKET_MASK == CONTINUATION_PACKET + + +def is_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + +def is_handshake_init_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_INIT_REQ + + +def is_handshake_comp_req(ctrl_byte: int) -> bool: + return ctrl_byte & DATA_MASK == HANDSHAKE_COMP_REQ diff --git a/python/src/trezorlib/transport/thp/curve25519.py b/python/src/trezorlib/transport/thp/curve25519.py new file mode 100644 index 0000000000..43127c49e5 --- /dev/null +++ b/python/src/trezorlib/transport/thp/curve25519.py @@ -0,0 +1,116 @@ +from typing import Tuple + +p = 2**255 - 19 +J = 486662 + +c3 = 19681161376707505956807079304988542015446066515923890162744021073123829784752 # sqrt(-1) +c4 = 7237005577332262213973186563042994240829374041602535252466099000494570602493 # (p - 5) // 8 +a24 = 121666 # (J + 2) // 4 + + +def decode_scalar(scalar: bytes) -> int: + # decodeScalar25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + if len(scalar) != 32: + raise ValueError("Invalid length of scalar") + + array = bytearray(scalar) + array[0] &= 248 + array[31] &= 127 + array[31] |= 64 + + return int.from_bytes(array, "little") + + +def decode_coordinate(coordinate: bytes) -> int: + # decodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + if len(coordinate) != 32: + raise ValueError("Invalid length of coordinate") + + array = bytearray(coordinate) + array[-1] &= 0x7F + return int.from_bytes(array, "little") % p + + +def encode_coordinate(coordinate: int) -> bytes: + # encodeUCoordinate from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + return coordinate.to_bytes(32, "little") + + +def get_private_key(secret: bytes) -> bytes: + return decode_scalar(secret).to_bytes(32, "little") + + +def get_public_key(private_key: bytes) -> bytes: + base_point = int.to_bytes(9, 32, "little") + return multiply(private_key, base_point) + + +def multiply(private_scalar: bytes, public_point: bytes): + # X25519 from + # https://datatracker.ietf.org/doc/html/rfc7748#section-5 + + def ladder_operation( + x1: int, x2: int, z2: int, x3: int, z3: int + ) -> Tuple[int, int, int, int]: + # https://hyperelliptic.org/EFD/g1p/auto-montgom-xz.html#ladder-ladd-1987-m-3 + # (x4, z4) = 2 * (x2, z2) + # (x5, z5) = (x2, z2) + (x3, z3) + # where (x1, 1) = (x3, z3) - (x2, z2) + + a = (x2 + z2) % p + aa = (a * a) % p + b = (x2 - z2) % p + bb = (b * b) % p + e = (aa - bb) % p + c = (x3 + z3) % p + d = (x3 - z3) % p + da = (d * a) % p + cb = (c * b) % p + t0 = (da + cb) % p + x5 = (t0 * t0) % p + t1 = (da - cb) % p + t2 = (t1 * t1) % p + z5 = (x1 * t2) % p + x4 = (aa * bb) % p + t3 = (a24 * e) % p + t4 = (bb + t3) % p + z4 = (e * t4) % p + + return x4, z4, x5, z5 + + def conditional_swap(first: int, second: int, condition: int): + # Returns (second, first) if condition is true and (first, second) otherwise + # Must be implemented in a way that it is constant time + true_mask = -condition + false_mask = ~true_mask + return (first & false_mask) | (second & true_mask), (second & false_mask) | ( + first & true_mask + ) + + k = decode_scalar(private_scalar) + u = decode_coordinate(public_point) + + x_1 = u + x_2 = 1 + z_2 = 0 + x_3 = u + z_3 = 1 + swap = 0 + + for i in reversed(range(256)): + bit = (k >> i) & 1 + swap = bit ^ swap + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + swap = bit + x_2, z_2, x_3, z_3 = ladder_operation(x_1, x_2, z_2, x_3, z_3) + + (x_2, x_3) = conditional_swap(x_2, x_3, swap) + (z_2, z_3) = conditional_swap(z_2, z_3, swap) + + x = pow(z_2, p - 2, p) * x_2 % p + return encode_coordinate(x) diff --git a/python/src/trezorlib/transport/thp/message_header.py b/python/src/trezorlib/transport/thp/message_header.py new file mode 100644 index 0000000000..d2ff002d63 --- /dev/null +++ b/python/src/trezorlib/transport/thp/message_header.py @@ -0,0 +1,82 @@ +import struct + +CODEC_V1 = 0x3F +CONTINUATION_PACKET = 0x80 +HANDSHAKE_INIT_REQ = 0x00 +HANDSHAKE_INIT_RES = 0x01 +HANDSHAKE_COMP_REQ = 0x02 +HANDSHAKE_COMP_RES = 0x03 +ENCRYPTED_TRANSPORT = 0x04 + +CONTINUATION_PACKET_MASK = 0x80 +ACK_MASK = 0xF7 +DATA_MASK = 0xE7 + +ACK_MESSAGE = 0x20 +_ERROR = 0x42 +CHANNEL_ALLOCATION_REQ = 0x40 +_CHANNEL_ALLOCATION_RES = 0x41 + +TREZOR_STATE_UNPAIRED = b"\x00" +TREZOR_STATE_PAIRED = b"\x01" + +BROADCAST_CHANNEL_ID = 0xFFFF + + +class MessageHeader: + format_str_init = ">BHH" + format_str_cont = ">BH" + + def __init__(self, ctrl_byte: int, cid: int, length: int) -> None: + self.ctrl_byte = ctrl_byte + self.cid = cid + self.data_length = length + + def to_bytes_init(self) -> bytes: + return struct.pack( + self.format_str_init, self.ctrl_byte, self.cid, self.data_length + ) + + def to_bytes_cont(self) -> bytes: + return struct.pack(self.format_str_cont, CONTINUATION_PACKET, self.cid) + + def pack_to_init_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_init, + buffer, + buffer_offset, + self.ctrl_byte, + self.cid, + self.data_length, + ) + + def pack_to_cont_buffer(self, buffer: bytearray, buffer_offset: int = 0) -> None: + struct.pack_into( + self.format_str_cont, buffer, buffer_offset, CONTINUATION_PACKET, self.cid + ) + + def is_ack(self) -> bool: + return self.ctrl_byte & ACK_MASK == ACK_MESSAGE + + def is_channel_allocation_response(self): + return ( + self.cid == BROADCAST_CHANNEL_ID + and self.ctrl_byte == _CHANNEL_ALLOCATION_RES + ) + + def is_handshake_init_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_INIT_RES + + def is_handshake_comp_response(self) -> bool: + return self.ctrl_byte & DATA_MASK == HANDSHAKE_COMP_RES + + def is_encrypted_transport(self) -> bool: + return self.ctrl_byte & DATA_MASK == ENCRYPTED_TRANSPORT + + @classmethod + def get_error_header(cls, cid: int, length: int): + return cls(_ERROR, cid, length) + + @classmethod + def get_channel_allocation_request_header(cls, length: int): + return cls(CHANNEL_ALLOCATION_REQ, BROADCAST_CHANNEL_ID, length) diff --git a/python/src/trezorlib/transport/thp/protocol_and_channel.py b/python/src/trezorlib/transport/thp/protocol_and_channel.py new file mode 100644 index 0000000000..fa420ac0af --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_and_channel.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import logging + +from ... import messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp.channel_data import ChannelData + +LOG = logging.getLogger(__name__) + + +class ProtocolAndChannel: + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.transport = transport + self.mapping = mapping + self.channel_keys = channel_data + + def get_features(self) -> messages.Features: + raise NotImplementedError() + + def get_channel_data(self) -> ChannelData: + raise NotImplementedError + + def update_features(self) -> None: + raise NotImplementedError diff --git a/python/src/trezorlib/transport/thp/protocol_v1.py b/python/src/trezorlib/transport/thp/protocol_v1.py new file mode 100644 index 0000000000..baea7e7401 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v1.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import logging +import struct +import typing as t + +from ... import exceptions, messages +from ...log import DUMP_BYTES +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + + +class ProtocolV1(ProtocolAndChannel): + HEADER_LEN = struct.calcsize(">HL") + _features: messages.Features | None = None + + def get_features(self) -> messages.Features: + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + self.write(messages.GetFeatures()) + resp = self.read() + if not isinstance(resp, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = resp + + def read(self) -> t.Any: + msg_type, msg_bytes = self._read() + LOG.log( + DUMP_BYTES, + f"received type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + msg = self.mapping.decode(msg_type, msg_bytes) + LOG.debug( + f"received message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + self.transport.close() + return msg + + def write(self, msg: t.Any) -> None: + LOG.debug( + f"sending message: {msg.__class__.__name__}", + extra={"protobuf": msg}, + ) + msg_type, msg_bytes = self.mapping.encode(msg) + LOG.log( + DUMP_BYTES, + f"encoded as type {msg_type} ({len(msg_bytes)} bytes): {msg_bytes.hex()}", + ) + self._write(msg_type, msg_bytes) + + def _write(self, message_type: int, message_data: bytes) -> None: + chunk_size = self.transport.CHUNK_SIZE + header = struct.pack(">HL", message_type, len(message_data)) + buffer = bytearray(b"##" + header + message_data) + + while buffer: + # Report ID, data padded to 63 bytes + chunk = b"?" + buffer[: chunk_size - 1] + chunk = chunk.ljust(chunk_size, b"\x00") + self.transport.write_chunk(chunk) + buffer = buffer[63:] + + def _read(self) -> t.Tuple[int, bytes]: + buffer = bytearray() + # Read header with first part of message data + msg_type, datalen, first_chunk = self.read_first() + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < datalen: + buffer.extend(self.read_next()) + + return msg_type, buffer[:datalen] + + def read_first(self) -> t.Tuple[int, int, bytes]: + chunk = self.transport.read_chunk() + if chunk[:3] != b"?##": + raise RuntimeError("Unexpected magic characters") + try: + msg_type, datalen = struct.unpack(">HL", chunk[3 : 3 + self.HEADER_LEN]) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[3 + self.HEADER_LEN :] + return msg_type, datalen, data + + def read_next(self) -> bytes: + chunk = self.transport.read_chunk() + if chunk[:1] != b"?": + raise RuntimeError("Unexpected magic characters") + return chunk[1:] diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py new file mode 100644 index 0000000000..07ff2cadd4 --- /dev/null +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -0,0 +1,404 @@ +from __future__ import annotations + +import hashlib +import hmac +import logging +import os +import typing as t +from binascii import hexlify +from enum import IntEnum + +import click +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from ... import exceptions, messages +from ...mapping import ProtobufMapping +from .. import Transport +from ..thp import checksum, curve25519, thp_io +from ..thp.channel_data import ChannelData +from ..thp.checksum import CHECKSUM_LENGTH +from ..thp.message_header import MessageHeader +from . import control_byte +from .channel_database import ChannelDatabase, get_channel_db +from .protocol_and_channel import ProtocolAndChannel + +LOG = logging.getLogger(__name__) + +MANAGEMENT_SESSION_ID: int = 0 + + +def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: + hash = hashlib.sha256(val_1) + hash.update(val_2) + return hash.digest() + + +def _hkdf(chaining_key: bytes, input: bytes): + temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() + output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() + ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) + ctx_output_2.update(b"\x02") + output_2 = ctx_output_2.digest() + return (output_1, output_2) + + +def _get_iv_from_nonce(nonce: int) -> bytes: + if not nonce <= 0xFFFFFFFFFFFFFFFF: + raise ValueError("Nonce overflow, terminate the channel") + return bytes(4) + nonce.to_bytes(8, "big") + + +class ProtocolV2(ProtocolAndChannel): + channel_id: int + channel_database: ChannelDatabase + key_request: bytes + key_response: bytes + nonce_request: int + nonce_response: int + sync_bit_send: int + sync_bit_receive: int + + _has_valid_channel: bool = False + _features: messages.Features | None = None + + def __init__( + self, + transport: Transport, + mapping: ProtobufMapping, + channel_data: ChannelData | None = None, + ) -> None: + self.channel_database: ChannelDatabase = get_channel_db() + super().__init__(transport, mapping, channel_data) + if channel_data is not None: + self.channel_id = channel_data.channel_id + self.key_request = bytes.fromhex(channel_data.key_request) + self.key_response = bytes.fromhex(channel_data.key_response) + self.nonce_request = channel_data.nonce_request + self.nonce_response = channel_data.nonce_response + self.sync_bit_receive = channel_data.sync_bit_receive + self.sync_bit_send = channel_data.sync_bit_send + self._has_valid_channel = True + + def get_channel(self) -> ProtocolV2: + if not self._has_valid_channel: + self._establish_new_channel() + return self + + def get_channel_data(self) -> ChannelData: + return ChannelData( + protocol_version=2, + transport_path=self.transport.get_path(), + channel_id=self.channel_id, + key_request=self.key_request, + key_response=self.key_response, + nonce_request=self.nonce_request, + nonce_response=self.nonce_response, + sync_bit_receive=self.sync_bit_receive, + sync_bit_send=self.sync_bit_send, + ) + + def read(self, session_id: int) -> t.Any: + sid, msg_type, msg_data = self.read_and_decrypt() + if sid != session_id: + raise Exception("Received messsage on a different session.") + self.channel_database.save_channel(self) + return self.mapping.decode(msg_type, msg_data) + + def write(self, session_id: int, msg: t.Any) -> None: + msg_type, msg_data = self.mapping.encode(msg) + self._encrypt_and_write(session_id, msg_type, msg_data) + self.channel_database.save_channel(self) + + def get_features(self) -> messages.Features: + if not self._has_valid_channel: + self._establish_new_channel() + if self._features is None: + self.update_features() + assert self._features is not None + return self._features + + def update_features(self) -> None: + message = messages.GetFeatures() + message_type, message_data = self.mapping.encode(message) + self.session_id: int = 0 + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + _ = self._read_until_valid_crc_check() # TODO check ACK + _, msg_type, msg_data = self.read_and_decrypt() + features = self.mapping.decode(msg_type, msg_data) + if not isinstance(features, messages.Features): + raise exceptions.TrezorException("Unexpected response to GetFeatures") + self._features = features + + def _establish_new_channel(self) -> None: + self.sync_bit_send = 0 + self.sync_bit_receive = 0 + # Send channel allocation request + # Note that [:8] on the following line is required when tests use + # WITH_MOCK_URANDOM. Without [:8] such tests will (almost always) fail. + channel_id_request_nonce = os.urandom(8)[:8] + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + MessageHeader.get_channel_allocation_request_header(12), + channel_id_request_nonce, + ) + + # Read channel allocation response + header, payload = self._read_until_valid_crc_check() + if not self._is_valid_channel_allocation_response( + header, payload, channel_id_request_nonce + ): + # TODO raise exception here, I guess + raise Exception("Invalid channel allocation response.") + + self.channel_id = int.from_bytes(payload[8:10], "big") + self.device_properties = payload[10:] + + # Send handshake init request + ha_init_req_header = MessageHeader(0, self.channel_id, 36) + # Note that [:32] on the following line is required when tests use + # WITH_MOCK_URANDOM. Without [:32] such tests will (almost always) fail. + host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)[:32]) + host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, ha_init_req_header, host_ephemeral_pubkey + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + # Read handshake init response + header, payload = self._read_until_valid_crc_check() + self._send_ack_0() + + if not header.is_handshake_init_response(): + click.echo( + "Received message is not a valid handshake init response message", + err=True, + ) + + trezor_ephemeral_pubkey = payload[:32] + encrypted_trezor_static_pubkey = payload[32:80] + noise_tag = payload[80:96] + + # TODO check noise tag + LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) + + # Prepare and send handshake completion request + PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" + IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) + h = _sha256_of_two(h, host_ephemeral_pubkey) + h = _sha256_of_two(h, trezor_ephemeral_pubkey) + ck, k = _hkdf( + PROTOCOL_NAME, + curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), + ) + + aes_ctx = AESGCM(k) + try: + trezor_masked_static_pubkey = aes_ctx.decrypt( + IV_1, encrypted_trezor_static_pubkey, h + ) + except Exception as e: + click.echo( + f"Exception of type{type(e)}", err=True + ) # TODO how to handle potential exceptions? Q for Matejcik + h = _sha256_of_two(h, encrypted_trezor_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) + ) + aes_ctx = AESGCM(k) + + tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) + h = _sha256_of_two(h, tag_of_empty_string) + # TODO: search for saved credentials (or possibly not, as we skip pairing phase) + + zeroes_32 = int.to_bytes(0, 32, "little") + temp_host_static_privkey = curve25519.get_private_key(zeroes_32) + temp_host_static_pubkey = curve25519.get_public_key(temp_host_static_privkey) + aes_ctx = AESGCM(k) + encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, temp_host_static_pubkey, h) + h = _sha256_of_two(h, encrypted_host_static_pubkey) + ck, k = _hkdf( + ck, curve25519.multiply(temp_host_static_privkey, trezor_ephemeral_pubkey) + ) + msg_data = self.mapping.encode_without_wire_type( + messages.ThpHandshakeCompletionReqNoisePayload( + pairing_methods=[ + messages.ThpPairingMethod.NoMethod, + ] + ) + ) + + aes_ctx = AESGCM(k) + + encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) + h = _sha256_of_two(h, encrypted_payload) + ha_completion_req_header = MessageHeader( + 0x12, + self.channel_id, + len(encrypted_host_static_pubkey) + + len(encrypted_payload) + + CHECKSUM_LENGTH, + ) + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, + ha_completion_req_header, + encrypted_host_static_pubkey + encrypted_payload, + ) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + # Read handshake completion response, ignore payload as we do not care about the state + header, _ = self._read_until_valid_crc_check() + if not header.is_handshake_comp_response(): + click.echo( + "Received message is not a valid handshake completion response", + err=True, + ) + self._send_ack_1() + + self.key_request, self.key_response = _hkdf(ck, b"") + self.nonce_request = 0 + self.nonce_response = 1 + + # Send StartPairingReqest message + message = messages.ThpStartPairingRequest() + message_type, message_data = self.mapping.encode(message) + + self._encrypt_and_write(MANAGEMENT_SESSION_ID, message_type, message_data) + + # Read ACK + header, payload = self._read_until_valid_crc_check() + if not header.is_ack() or len(payload) > 0: + click.echo("Received message is not a valid ACK", err=True) + + # Read + _, msg_type, msg_data = self.read_and_decrypt() + maaa = self.mapping.decode(msg_type, msg_data) + + assert isinstance(maaa, messages.ThpEndResponse) + self._has_valid_channel = True + + def _send_ack_0(self): + LOG.debug("sending ack 0") + header = MessageHeader(0x20, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _send_ack_1(self): + LOG.debug("sending ack 1") + header = MessageHeader(0x28, self.channel_id, 4) + thp_io.write_payload_to_wire_and_add_checksum(self.transport, header, b"") + + def _encrypt_and_write( + self, + session_id: int, + message_type: int, + message_data: bytes, + ctrl_byte: int | None = None, + ) -> None: + assert self.key_request is not None + aes_ctx = AESGCM(self.key_request) + + if ctrl_byte is None: + ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) + self.sync_bit_send = 1 - self.sync_bit_send + + sid = session_id.to_bytes(1, "big") + msg_type = message_type.to_bytes(2, "big") + data = sid + msg_type + message_data + nonce = _get_iv_from_nonce(self.nonce_request) + self.nonce_request += 1 + encrypted_message = aes_ctx.encrypt(nonce, data, b"") + header = MessageHeader( + ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH + ) + + thp_io.write_payload_to_wire_and_add_checksum( + self.transport, header, encrypted_message + ) + + def read_and_decrypt(self) -> t.Tuple[int, int, bytes]: + header, raw_payload = self._read_until_valid_crc_check() + if control_byte.is_ack(header.ctrl_byte): + return self.read_and_decrypt() + if not header.is_encrypted_transport(): + click.echo( + "Trying to decrypt not encrypted message!" + + hexlify(header.to_bytes_init() + raw_payload).decode(), + err=True, + ) + + if not control_byte.is_ack(header.ctrl_byte): + LOG.debug( + "--> Get sequence bit %d %s %s", + control_byte.get_seq_bit(header.ctrl_byte), + "from control byte", + hexlify(header.ctrl_byte.to_bytes(1, "big")).decode(), + ) + if control_byte.get_seq_bit(header.ctrl_byte): + self._send_ack_1() + else: + self._send_ack_0() + aes_ctx = AESGCM(self.key_response) + nonce = _get_iv_from_nonce(self.nonce_response) + self.nonce_response += 1 + + message = aes_ctx.decrypt(nonce, raw_payload, b"") + session_id = message[0] + message_type = message[1:3] + message_data = message[3:] + return ( + session_id, + int.from_bytes(message_type, "big"), + message_data, + ) + + def _read_until_valid_crc_check( + self, + ) -> t.Tuple[MessageHeader, bytes]: + is_valid = False + header, payload, chksum = thp_io.read(self.transport) + while not is_valid: + is_valid = checksum.is_valid(chksum, header.to_bytes_init() + payload) + if not is_valid: + click.echo( + "Received a message with an invalid checksum:" + + hexlify(header.to_bytes_init() + payload + chksum).decode(), + err=True, + ) + header, payload, chksum = thp_io.read(self.transport) + + return header, payload + + def _is_valid_channel_allocation_response( + self, header: MessageHeader, payload: bytes, original_nonce: bytes + ) -> bool: + if not header.is_channel_allocation_response(): + click.echo( + "Received message is not a channel allocation response", err=True + ) + return False + if len(payload) < 10: + click.echo("Invalid channel allocation response payload", err=True) + return False + if payload[:8] != original_nonce: + click.echo( + "Invalid channel allocation response payload (nonce mismatch)", err=True + ) + return False + return True + + class ControlByteType(IntEnum): + CHANNEL_ALLOCATION_RES = 1 + HANDSHAKE_INIT_RES = 2 + HANDSHAKE_COMP_RES = 3 + ACK = 4 + ENCRYPTED_TRANSPORT = 5 diff --git a/python/src/trezorlib/transport/thp/thp_io.py b/python/src/trezorlib/transport/thp/thp_io.py new file mode 100644 index 0000000000..d0237f9e36 --- /dev/null +++ b/python/src/trezorlib/transport/thp/thp_io.py @@ -0,0 +1,93 @@ +import struct +from typing import Tuple + +from .. import Transport +from ..thp import checksum +from .message_header import MessageHeader + +INIT_HEADER_LENGTH = 5 +CONT_HEADER_LENGTH = 3 +MAX_PAYLOAD_LEN = 60000 +MESSAGE_TYPE_LENGTH = 2 + +CONTINUATION_PACKET = 0x80 + + +def write_payload_to_wire_and_add_checksum( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + chksum: bytes = checksum.compute(header.to_bytes_init() + transport_payload) + data = transport_payload + chksum + write_payload_to_wire(transport, header, data) + + +def write_payload_to_wire( + transport: Transport, header: MessageHeader, transport_payload: bytes +): + transport.open() + buffer = bytearray(transport_payload) + chunk = header.to_bytes_init() + buffer[: transport.CHUNK_SIZE - INIT_HEADER_LENGTH] + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + + buffer = buffer[transport.CHUNK_SIZE - INIT_HEADER_LENGTH :] + while buffer: + chunk = ( + header.to_bytes_cont() + buffer[: transport.CHUNK_SIZE - CONT_HEADER_LENGTH] + ) + chunk = chunk.ljust(transport.CHUNK_SIZE, b"\x00") + transport.write_chunk(chunk) + buffer = buffer[transport.CHUNK_SIZE - CONT_HEADER_LENGTH :] + + +def read(transport: Transport) -> Tuple[MessageHeader, bytes, bytes]: + """ + Reads from the given wire transport. + + Returns `Tuple[MessageHeader, bytes, bytes]`: + 1. `header` (`MessageHeader`): Header of the message. + 2. `data` (`bytes`): Contents of the message (if any). + 3. `checksum` (`bytes`): crc32 checksum of the header + data. + + """ + buffer = bytearray() + + # Read header with first part of message data + header, first_chunk = read_first(transport) + buffer.extend(first_chunk) + + # Read the rest of the message + while len(buffer) < header.data_length: + buffer.extend(read_next(transport, header.cid)) + + data_len = header.data_length - checksum.CHECKSUM_LENGTH + msg_data = buffer[:data_len] + chksum = buffer[data_len : data_len + checksum.CHECKSUM_LENGTH] + + return (header, msg_data, chksum) + + +def read_first(transport: Transport) -> Tuple[MessageHeader, bytes]: + chunk = transport.read_chunk() + try: + ctrl_byte, cid, data_length = struct.unpack( + MessageHeader.format_str_init, chunk[:INIT_HEADER_LENGTH] + ) + except Exception: + raise RuntimeError("Cannot parse header") + + data = chunk[INIT_HEADER_LENGTH:] + return MessageHeader(ctrl_byte, cid, data_length), data + + +def read_next(transport: Transport, cid: int) -> bytes: + chunk = transport.read_chunk() + ctrl_byte, read_cid = struct.unpack( + MessageHeader.format_str_cont, chunk[:CONT_HEADER_LENGTH] + ) + if ctrl_byte != CONTINUATION_PACKET: + raise RuntimeError("Continuation packet with incorrect control byte") + if read_cid != cid: + raise RuntimeError("Continuation packet for different channel") + + return chunk[CONT_HEADER_LENGTH:] diff --git a/python/src/trezorlib/transport/udp.py b/python/src/trezorlib/transport/udp.py index 7e4c4614c6..2960df8994 100644 --- a/python/src/trezorlib/transport/udp.py +++ b/python/src/trezorlib/transport/udp.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,14 +14,15 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import logging import socket import time -from typing import TYPE_CHECKING, Iterable, Optional +from typing import TYPE_CHECKING, Iterable, Tuple from ..log import DUMP_PACKETS -from . import TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import Transport, TransportException if TYPE_CHECKING: from ..models import TrezorModel @@ -31,14 +32,18 @@ SOCKET_TIMEOUT = 10 LOG = logging.getLogger(__name__) -class UdpTransport(ProtocolBasedTransport): +class UdpTransport(Transport): DEFAULT_HOST = "127.0.0.1" DEFAULT_PORT = 21324 PATH_PREFIX = "udp" ENABLED: bool = True + CHUNK_SIZE = 64 - def __init__(self, device: Optional[str] = None) -> None: + def __init__( + self, + device: str | None = None, + ) -> None: if not device: host = UdpTransport.DEFAULT_HOST port = UdpTransport.DEFAULT_PORT @@ -46,24 +51,17 @@ class UdpTransport(ProtocolBasedTransport): devparts = device.split(":") host = devparts[0] port = int(devparts[1]) if len(devparts) > 1 else UdpTransport.DEFAULT_PORT - self.device = (host, port) - self.socket: Optional[socket.socket] = None + self.device: Tuple[str, int] = (host, port) - super().__init__(protocol=ProtocolV1(self)) - - def get_path(self) -> str: - return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) - - def find_debug(self) -> "UdpTransport": - host, port = self.device - return UdpTransport(f"{host}:{port + 1}") + self.socket: socket.socket | None = None + super().__init__() @classmethod def _try_path(cls, path: str) -> "UdpTransport": d = cls(path) try: d.open() - if d._ping(): + if d.ping(): return d else: raise TransportException( @@ -77,7 +75,7 @@ class UdpTransport(ProtocolBasedTransport): @classmethod def enumerate( - cls, _models: Optional[Iterable["TrezorModel"]] = None + cls, _models: Iterable["TrezorModel"] | None = None ) -> Iterable["UdpTransport"]: default_path = f"{cls.DEFAULT_HOST}:{cls.DEFAULT_PORT}" try: @@ -99,20 +97,8 @@ class UdpTransport(ProtocolBasedTransport): else: raise TransportException(f"No UDP device at {path}") - def wait_until_ready(self, timeout: float = 10) -> None: - try: - self.open() - start = time.monotonic() - while True: - if self._ping(): - break - elapsed = time.monotonic() - start - if elapsed >= timeout: - raise TransportException("Timed out waiting for connection.") - - time.sleep(0.05) - finally: - self.close() + def get_path(self) -> str: + return "{}:{}:{}".format(self.PATH_PREFIX, *self.device) def open(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -124,18 +110,9 @@ class UdpTransport(ProtocolBasedTransport): self.socket.close() self.socket = None - def _ping(self) -> bool: - """Test if the device is listening.""" - assert self.socket is not None - resp = None - try: - self.socket.sendall(b"PINGPING") - resp = self.socket.recv(8) - except Exception: - pass - return resp == b"PONGPONG" - def write_chunk(self, chunk: bytes) -> None: + if self.socket is None: + self.open() assert self.socket is not None if len(chunk) != 64: raise TransportException("Unexpected data length") @@ -143,6 +120,8 @@ class UdpTransport(ProtocolBasedTransport): self.socket.sendall(chunk) def read_chunk(self) -> bytes: + if self.socket is None: + self.open() assert self.socket is not None while True: try: @@ -154,3 +133,33 @@ class UdpTransport(ProtocolBasedTransport): if len(chunk) != 64: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return bytearray(chunk) + + def find_debug(self) -> "UdpTransport": + host, port = self.device + return UdpTransport(f"{host}:{port + 1}") + + def wait_until_ready(self, timeout: float = 10) -> None: + try: + self.open() + start = time.monotonic() + while True: + if self.ping(): + break + elapsed = time.monotonic() - start + if elapsed >= timeout: + raise TransportException("Timed out waiting for connection.") + + time.sleep(0.05) + finally: + self.close() + + def ping(self) -> bool: + """Test if the device is listening.""" + assert self.socket is not None + resp = None + try: + self.socket.sendall(b"PINGPING") + resp = self.socket.recv(8) + except Exception: + pass + return resp == b"PONGPONG" diff --git a/python/src/trezorlib/transport/webusb.py b/python/src/trezorlib/transport/webusb.py index 8e2d08147a..023ed5f245 100644 --- a/python/src/trezorlib/transport/webusb.py +++ b/python/src/trezorlib/transport/webusb.py @@ -1,6 +1,6 @@ # This file is part of the Trezor project. # -# Copyright (C) 2012-2022 SatoshiLabs and contributors +# Copyright (C) 2012-2024 SatoshiLabs and contributors # # This library is free software: you can redistribute it and/or modify # it under the terms of the GNU Lesser General Public License version 3 @@ -14,16 +14,17 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import atexit import logging import sys import time -from typing import Iterable, List, Optional +from typing import Iterable, List from ..log import DUMP_PACKETS from ..models import TREZORS, TrezorModel -from . import UDEV_RULES_STR, DeviceIsBusy, TransportException -from .protocol import ProtocolBasedTransport, ProtocolV1 +from . import UDEV_RULES_STR, DeviceIsBusy, Transport, TransportException LOG = logging.getLogger(__name__) @@ -44,13 +45,69 @@ USB_COMM_TIMEOUT_MS = 300 WEBUSB_CHUNK_SIZE = 64 -class WebUsbHandle: - def __init__(self, device: "usb1.USBDevice", debug: bool = False) -> None: +class WebUsbTransport(Transport): + """ + WebUsbTransport implements transport over WebUSB interface. + """ + + PATH_PREFIX = "webusb" + ENABLED = USB_IMPORTED + context = None + CHUNK_SIZE = 64 + + def __init__( + self, + device: "usb1.USBDevice", + debug: bool = False, + ) -> None: + self.device = device + self.debug = debug + self.interface = DEBUG_INTERFACE if debug else INTERFACE self.endpoint = DEBUG_ENDPOINT if debug else ENDPOINT - self.count = 0 - self.handle: Optional["usb1.USBDeviceHandle"] = None + self.handle: usb1.USBDeviceHandle | None = None + + super().__init__() + + @classmethod + def enumerate( + cls, models: Iterable["TrezorModel"] | None = None, usb_reset: bool = False + ) -> Iterable["WebUsbTransport"]: + if cls.context is None: + cls.context = usb1.USBContext() + cls.context.open() + atexit.register(cls.context.close) + + if models is None: + models = TREZORS + usb_ids = [id for model in models for id in model.usb_ids] + devices: List["WebUsbTransport"] = [] + for dev in cls.context.getDeviceIterator(skip_on_error=True): + usb_id = (dev.getVendorID(), dev.getProductID()) + if usb_id not in usb_ids: + continue + if not is_vendor_class(dev): + continue + if usb_reset: + handle = dev.open() + handle.resetDevice() + handle.close() + continue + try: + # workaround for issue #223: + # on certain combinations of Windows USB drivers and libusb versions, + # Trezor is returned twice (possibly because Windows know it as both + # a HID and a WebUSB device), and one of the returned devices is + # non-functional. + dev.getProduct() + devices.append(WebUsbTransport(dev)) + except usb1.USBErrorNotSupported: + pass + return devices + + def get_path(self) -> str: + return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" def open(self) -> None: self.handle = self.device.open() @@ -64,6 +121,8 @@ class WebUsbHandle: self.handle.claimInterface(self.interface) except usb1.USBErrorAccess as e: raise DeviceIsBusy(self.device) from e + except usb1.USBErrorBusy as e: + raise DeviceIsBusy(self.device) from e def close(self) -> None: if self.handle is not None: @@ -75,6 +134,8 @@ class WebUsbHandle: self.handle = None def write_chunk(self, chunk: bytes) -> None: + if self.handle is None: + self.open() assert self.handle is not None if len(chunk) != WEBUSB_CHUNK_SIZE: raise TransportException(f"Unexpected chunk size: {len(chunk)}") @@ -97,6 +158,8 @@ class WebUsbHandle: return def read_chunk(self) -> bytes: + if self.handle is None: + self.open() assert self.handle is not None endpoint = 0x80 | self.endpoint while True: @@ -117,70 +180,6 @@ class WebUsbHandle: raise TransportException(f"Unexpected chunk size: {len(chunk)}") return chunk - -class WebUsbTransport(ProtocolBasedTransport): - """ - WebUsbTransport implements transport over WebUSB interface. - """ - - PATH_PREFIX = "webusb" - ENABLED = USB_IMPORTED - context = None - - def __init__( - self, - device: "usb1.USBDevice", - handle: Optional[WebUsbHandle] = None, - debug: bool = False, - ) -> None: - if handle is None: - handle = WebUsbHandle(device, debug) - - self.device = device - self.handle = handle - self.debug = debug - - super().__init__(protocol=ProtocolV1(handle)) - - def get_path(self) -> str: - return f"{self.PATH_PREFIX}:{dev_to_str(self.device)}" - - @classmethod - def enumerate( - cls, models: Optional[Iterable["TrezorModel"]] = None, usb_reset: bool = False - ) -> Iterable["WebUsbTransport"]: - if cls.context is None: - cls.context = usb1.USBContext() - cls.context.open() - atexit.register(cls.context.close) - - if models is None: - models = TREZORS - usb_ids = [id for model in models for id in model.usb_ids] - devices: List["WebUsbTransport"] = [] - for dev in cls.context.getDeviceIterator(skip_on_error=True): - usb_id = (dev.getVendorID(), dev.getProductID()) - if usb_id not in usb_ids: - continue - if not is_vendor_class(dev): - continue - try: - # workaround for issue #223: - # on certain combinations of Windows USB drivers and libusb versions, - # Trezor is returned twice (possibly because Windows know it as both - # a HID and a WebUSB device), and one of the returned devices is - # non-functional. - dev.getProduct() - devices.append(WebUsbTransport(dev)) - except usb1.USBErrorNotSupported: - pass - except usb1.USBErrorPipe: - if usb_reset: - handle = dev.open() - handle.resetDevice() - handle.close() - return devices - def find_debug(self) -> "WebUsbTransport": # For v1 protocol, find debug USB interface for the same serial number return WebUsbTransport(self.device, debug=True)