From 6a5836708f393d9addb0e8ed0749fadc4b2e4f34 Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 3 Jan 2025 16:36:40 +0100 Subject: [PATCH] refactor(python): replace usages of @expect --- python/.changelog.d/4464.deprecated.1 | 1 + python/.changelog.d/4464.deprecated.3 | 1 + python/.changelog.d/4464.incompatible | 1 + python/.changelog.d/4464.incompatible.1 | 1 + python/src/trezorlib/benchmark.py | 14 +- python/src/trezorlib/binance.py | 37 +-- python/src/trezorlib/btc.py | 98 ++++--- python/src/trezorlib/cardano.py | 297 ++++++++++------------ python/src/trezorlib/cli/debug.py | 8 +- python/src/trezorlib/cli/device.py | 62 ++--- python/src/trezorlib/cli/fido.py | 12 +- python/src/trezorlib/cli/settings.py | 64 ++--- python/src/trezorlib/cli/solana.py | 6 +- python/src/trezorlib/client.py | 20 +- python/src/trezorlib/debuglink.py | 38 +-- python/src/trezorlib/device.py | 153 ++++++----- python/src/trezorlib/eos.py | 12 +- python/src/trezorlib/ethereum.py | 43 ++-- python/src/trezorlib/fido.py | 51 ++-- python/src/trezorlib/firmware/__init__.py | 9 +- python/src/trezorlib/misc.py | 40 ++- python/src/trezorlib/monero.py | 16 +- python/src/trezorlib/nem.py | 15 +- python/src/trezorlib/protobuf.py | 23 ++ python/src/trezorlib/ripple.py | 16 +- python/src/trezorlib/solana.py | 26 +- python/src/trezorlib/stellar.py | 15 +- python/src/trezorlib/tezos.py | 23 +- 28 files changed, 518 insertions(+), 584 deletions(-) create mode 100644 python/.changelog.d/4464.deprecated.1 create mode 100644 python/.changelog.d/4464.deprecated.3 create mode 100644 python/.changelog.d/4464.incompatible create mode 100644 python/.changelog.d/4464.incompatible.1 diff --git a/python/.changelog.d/4464.deprecated.1 b/python/.changelog.d/4464.deprecated.1 new file mode 100644 index 0000000000..12dda3d95f --- /dev/null +++ b/python/.changelog.d/4464.deprecated.1 @@ -0,0 +1 @@ +String return values are deprecated in functions where the semantic result is a success (specifically those that were returning the message from Trezor's `Success` response). Type annotations are updated to `str | None`, and in a future release those functions will be returning `None` on success, or raise an exception on a failure. diff --git a/python/.changelog.d/4464.deprecated.3 b/python/.changelog.d/4464.deprecated.3 new file mode 100644 index 0000000000..a7dcbca4a0 --- /dev/null +++ b/python/.changelog.d/4464.deprecated.3 @@ -0,0 +1 @@ +Return value of `device.recover()` is deprecated. In the future, this function will return `None`. diff --git a/python/.changelog.d/4464.incompatible b/python/.changelog.d/4464.incompatible new file mode 100644 index 0000000000..2c8d4b814b --- /dev/null +++ b/python/.changelog.d/4464.incompatible @@ -0,0 +1 @@ +Return values in `solana` module were changed from the wrapping protobuf messages to the raw inner values (`str` for address, `bytes` for pubkey / signature). diff --git a/python/.changelog.d/4464.incompatible.1 b/python/.changelog.d/4464.incompatible.1 new file mode 100644 index 0000000000..f96056584f --- /dev/null +++ b/python/.changelog.d/4464.incompatible.1 @@ -0,0 +1 @@ +`trezorctl device` commands whose default result is a success will not print anything to stdout anymore, in line with Unix philosophy. diff --git a/python/src/trezorlib/benchmark.py b/python/src/trezorlib/benchmark.py index f96ef7970e..6587e2a3ab 100644 --- a/python/src/trezorlib/benchmark.py +++ b/python/src/trezorlib/benchmark.py @@ -17,20 +17,18 @@ from typing import TYPE_CHECKING from . import messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType -@expect(messages.BenchmarkNames) def list_names( client: "TrezorClient", -) -> "MessageType": - return client.call(messages.BenchmarkListNames()) +) -> messages.BenchmarkNames: + return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames) -@expect(messages.BenchmarkResult) -def run(client: "TrezorClient", name: str) -> "MessageType": - return client.call(messages.BenchmarkRun(name=name)) +def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult: + return client.call( + messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult + ) diff --git a/python/src/trezorlib/binance.py b/python/src/trezorlib/binance.py index d2e4b97912..938092a2df 100644 --- a/python/src/trezorlib/binance.py +++ b/python/src/trezorlib/binance.py @@ -18,35 +18,34 @@ from typing import TYPE_CHECKING from . import messages from .protobuf import dict_to_proto -from .tools import expect, session +from .tools import session if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address -@expect(messages.BinanceAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", address_n: "Address", show_display: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.BinanceGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify - ) - ) + ), + expect=messages.BinanceAddress, + ).address -@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes) def get_public_key( client: "TrezorClient", address_n: "Address", show_display: bool = False -) -> "MessageType": +) -> bytes: return client.call( - messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) - ) + messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display), + expect=messages.BinancePublicKey, + ).public_key @session @@ -60,13 +59,7 @@ def sign_tx( tx_msg["chunkify"] = chunkify envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) - response = client.call(envelope) - - if not isinstance(response, messages.BinanceTxRequest): - raise RuntimeError( - "Invalid response, expected BinanceTxRequest, received " - + type(response).__name__ - ) + client.call(envelope, expect=messages.BinanceTxRequest) if "refid" in msg: msg = dict_to_proto(messages.BinanceCancelMsg, msg) @@ -77,12 +70,4 @@ def sign_tx( else: raise ValueError("can not determine msg type") - response = client.call(msg) - - if not isinstance(response, messages.BinanceSignedTx): - raise RuntimeError( - "Invalid response, expected BinanceSignedTx, received " - + type(response).__name__ - ) - - return response + return client.call(msg, expect=messages.BinanceSignedTx) diff --git a/python/src/trezorlib/btc.py b/python/src/trezorlib/btc.py index a71ead2adc..078f486d9e 100644 --- a/python/src/trezorlib/btc.py +++ b/python/src/trezorlib/btc.py @@ -14,6 +14,8 @@ # You should have received a copy of the License along with this library. # If not, see . +from __future__ import annotations + import warnings from copy import copy from decimal import Decimal @@ -23,11 +25,10 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple from typing_extensions import Protocol, TypedDict from . import exceptions, messages -from .tools import expect, prepare_message_bytes, session +from .tools import _return_success, prepare_message_bytes, session if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address class ScriptSig(TypedDict): @@ -103,7 +104,6 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType: ) -@expect(messages.PublicKey) def get_public_node( client: "TrezorClient", n: "Address", @@ -114,13 +114,12 @@ def get_public_node( ignore_xpub_magic: bool = False, unlock_path: Optional[List[int]] = None, unlock_path_mac: Optional[bytes] = None, -) -> "MessageType": +) -> messages.PublicKey: if unlock_path: - res = client.call( - messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) + client.call( + messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), + expect=messages.UnlockedPathRequest, ) - if not isinstance(res, messages.UnlockedPathRequest): - raise exceptions.TrezorException("Unexpected message") return client.call( messages.GetPublicKey( @@ -130,16 +129,15 @@ def get_public_node( coin_name=coin_name, script_type=script_type, ignore_xpub_magic=ignore_xpub_magic, - ) + ), + expect=messages.PublicKey, ) -@expect(messages.Address, field="address", ret_type=str) -def get_address(*args: Any, **kwargs: Any): - return get_authenticated_address(*args, **kwargs) +def get_address(*args: Any, **kwargs: Any) -> str: + return get_authenticated_address(*args, **kwargs).address -@expect(messages.Address) def get_authenticated_address( client: "TrezorClient", coin_name: str, @@ -151,13 +149,12 @@ def get_authenticated_address( unlock_path: Optional[List[int]] = None, unlock_path_mac: Optional[bytes] = None, chunkify: bool = False, -) -> "MessageType": +) -> messages.Address: if unlock_path: - res = client.call( - messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) + client.call( + messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), + expect=messages.UnlockedPathRequest, ) - if not isinstance(res, messages.UnlockedPathRequest): - raise exceptions.TrezorException("Unexpected message") return client.call( messages.GetAddress( @@ -168,26 +165,27 @@ def get_authenticated_address( script_type=script_type, ignore_xpub_magic=ignore_xpub_magic, chunkify=chunkify, - ) + ), + expect=messages.Address, ) -@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes) def get_ownership_id( client: "TrezorClient", coin_name: str, n: "Address", multisig: Optional[messages.MultisigRedeemScriptType] = None, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, -) -> "MessageType": +) -> bytes: return client.call( messages.GetOwnershipId( address_n=n, coin_name=coin_name, multisig=multisig, script_type=script_type, - ) - ) + ), + expect=messages.OwnershipId, + ).ownership_id def get_ownership_proof( @@ -202,9 +200,7 @@ def get_ownership_proof( preauthorized: bool = False, ) -> Tuple[bytes, bytes]: if preauthorized: - res = client.call(messages.DoPreauthorized()) - if not isinstance(res, messages.PreauthorizedRequest): - raise exceptions.TrezorException("Unexpected message") + client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) res = client.call( messages.GetOwnershipProof( @@ -215,16 +211,13 @@ def get_ownership_proof( user_confirmation=user_confirmation, ownership_ids=ownership_ids, commitment_data=commitment_data, - ) + ), + expect=messages.OwnershipProof, ) - if not isinstance(res, messages.OwnershipProof): - raise exceptions.TrezorException("Unexpected message") - return res.ownership_proof, res.signature -@expect(messages.MessageSignature) def sign_message( client: "TrezorClient", coin_name: str, @@ -233,7 +226,7 @@ def sign_message( script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, no_script_type: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> messages.MessageSignature: return client.call( messages.SignMessage( coin_name=coin_name, @@ -242,7 +235,8 @@ def sign_message( script_type=script_type, no_script_type=no_script_type, chunkify=chunkify, - ) + ), + expect=messages.MessageSignature, ) @@ -255,18 +249,19 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - resp = client.call( + client.call( messages.VerifyMessage( address=address, signature=signature, message=prepare_message_bytes(message), coin_name=coin_name, chunkify=chunkify, - ) + ), + expect=messages.Success, ) + return True except exceptions.TrezorFailure: return False - return isinstance(resp, messages.Success) @session @@ -319,17 +314,14 @@ def sign_tx( setattr(signtx, name, value) if unlock_path: - res = client.call( - messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) + client.call( + messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac), + expect=messages.UnlockedPathRequest, ) - if not isinstance(res, messages.UnlockedPathRequest): - raise exceptions.TrezorException("Unexpected message") elif preauthorized: - res = client.call(messages.DoPreauthorized()) - if not isinstance(res, messages.PreauthorizedRequest): - raise exceptions.TrezorException("Unexpected message") + client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest) - res = client.call(signtx) + res = client.call(signtx, expect=messages.TxRequest) # Prepare structure for signatures signatures: List[Optional[bytes]] = [None] * len(inputs) @@ -357,7 +349,7 @@ def sign_tx( ) R = messages.RequestType - while isinstance(res, messages.TxRequest): + while True: # If there's some part of signed transaction, let's add it if res.serialized: if res.serialized.serialized_tx: @@ -388,7 +380,7 @@ def sign_tx( if res.request_type == R.TXPAYMENTREQ: assert res.details.request_index is not None msg = payment_reqs[res.details.request_index] - res = client.call(msg) + res = client.call(msg, expect=messages.TxRequest) else: msg = messages.TransactionType() if res.request_type == R.TXMETA: @@ -418,10 +410,7 @@ def sign_tx( f"Unknown request type - {res.request_type}." ) - res = client.call(messages.TxAck(tx=msg)) - - if not isinstance(res, messages.TxRequest): - raise exceptions.TrezorException("Unexpected message") + res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest) for i, sig in zip(inputs, signatures): if i.script_type != messages.InputScriptType.EXTERNAL and sig is None: @@ -430,7 +419,6 @@ def sign_tx( return signatures, serialized_tx -@expect(messages.Success, field="message", ret_type=str) def authorize_coinjoin( client: "TrezorClient", coordinator: str, @@ -440,8 +428,8 @@ def authorize_coinjoin( n: "Address", coin_name: str, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, -) -> "MessageType": - return client.call( +) -> str | None: + resp = client.call( messages.AuthorizeCoinJoin( coordinator=coordinator, max_rounds=max_rounds, @@ -450,5 +438,7 @@ def authorize_coinjoin( address_n=n, coin_name=coin_name, script_type=script_type, - ) + ), + expect=messages.Success, ) + return _return_success(resp) diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index 49d2c6463f..4cbc635f1f 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -31,12 +31,11 @@ from typing import ( Union, ) -from . import exceptions, messages, tools -from .tools import expect +from . import messages as m +from . import tools if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType PROTOCOL_MAGICS = { "mainnet": 764824073, @@ -72,35 +71,33 @@ INCOMPLETE_OUTPUT_ERROR_MESSAGE = "The output is missing some fields" INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY = "The output's token_bundle entry is invalid" INVALID_MINT_TOKEN_BUNDLE_ENTRY = "The mint token_bundle entry is invalid" -InputWithPath = Tuple[messages.CardanoTxInput, List[int]] -CollateralInputWithPath = Tuple[messages.CardanoTxCollateralInput, List[int]] -AssetGroupWithTokens = Tuple[messages.CardanoAssetGroup, List[messages.CardanoToken]] +InputWithPath = Tuple[m.CardanoTxInput, List[int]] +CollateralInputWithPath = Tuple[m.CardanoTxCollateralInput, List[int]] +AssetGroupWithTokens = Tuple[m.CardanoAssetGroup, List[m.CardanoToken]] OutputWithData = Tuple[ - messages.CardanoTxOutput, + m.CardanoTxOutput, List[AssetGroupWithTokens], - List[messages.CardanoTxInlineDatumChunk], - List[messages.CardanoTxReferenceScriptChunk], + List[m.CardanoTxInlineDatumChunk], + List[m.CardanoTxReferenceScriptChunk], ] OutputItem = Union[ - messages.CardanoTxOutput, - messages.CardanoAssetGroup, - messages.CardanoToken, - messages.CardanoTxInlineDatumChunk, - messages.CardanoTxReferenceScriptChunk, + m.CardanoTxOutput, + m.CardanoAssetGroup, + m.CardanoToken, + m.CardanoTxInlineDatumChunk, + m.CardanoTxReferenceScriptChunk, ] CertificateItem = Union[ - messages.CardanoTxCertificate, - messages.CardanoPoolOwner, - messages.CardanoPoolRelayParameters, -] -MintItem = Union[ - messages.CardanoTxMint, messages.CardanoAssetGroup, messages.CardanoToken + m.CardanoTxCertificate, + m.CardanoPoolOwner, + m.CardanoPoolRelayParameters, ] +MintItem = Union[m.CardanoTxMint, m.CardanoAssetGroup, m.CardanoToken] PoolOwnersAndRelays = Tuple[ - List[messages.CardanoPoolOwner], List[messages.CardanoPoolRelayParameters] + List[m.CardanoPoolOwner], List[m.CardanoPoolRelayParameters] ] CertificateWithPoolOwnersAndRelays = Tuple[ - messages.CardanoTxCertificate, Optional[PoolOwnersAndRelays] + m.CardanoTxCertificate, Optional[PoolOwnersAndRelays] ] Path = List[int] Witness = Tuple[Path, bytes] @@ -108,9 +105,7 @@ AuxiliaryDataSupplement = Dict[str, Union[int, bytes]] SignTxResponse = Dict[str, Union[bytes, List[Witness], AuxiliaryDataSupplement]] Chunk = TypeVar( "Chunk", - bound=Union[ - messages.CardanoTxInlineDatumChunk, messages.CardanoTxReferenceScriptChunk - ], + bound=Union[m.CardanoTxInlineDatumChunk, m.CardanoTxReferenceScriptChunk], ) @@ -123,7 +118,7 @@ def parse_optional_int(value: Optional[str]) -> Optional[int]: def create_address_parameters( - address_type: messages.CardanoAddressType, + address_type: m.CardanoAddressType, address_n: List[int], address_n_staking: Optional[List[int]] = None, staking_key_hash: Optional[bytes] = None, @@ -132,18 +127,18 @@ def create_address_parameters( certificate_index: Optional[int] = None, script_payment_hash: Optional[bytes] = None, script_staking_hash: Optional[bytes] = None, -) -> messages.CardanoAddressParametersType: +) -> m.CardanoAddressParametersType: certificate_pointer = None if address_type in ( - messages.CardanoAddressType.POINTER, - messages.CardanoAddressType.POINTER_SCRIPT, + m.CardanoAddressType.POINTER, + m.CardanoAddressType.POINTER_SCRIPT, ): certificate_pointer = _create_certificate_pointer( block_index, tx_index, certificate_index ) - return messages.CardanoAddressParametersType( + return m.CardanoAddressParametersType( address_type=address_type, address_n=address_n, address_n_staking=address_n_staking, @@ -158,11 +153,11 @@ def _create_certificate_pointer( block_index: Optional[int], tx_index: Optional[int], certificate_index: Optional[int], -) -> messages.CardanoBlockchainPointerType: +) -> m.CardanoBlockchainPointerType: if block_index is None or tx_index is None or certificate_index is None: raise ValueError("Invalid pointer parameters") - return messages.CardanoBlockchainPointerType( + return m.CardanoBlockchainPointerType( block_index=block_index, tx_index=tx_index, certificate_index=certificate_index ) @@ -173,7 +168,7 @@ def parse_input(tx_input: dict) -> InputWithPath: path = tools.parse_path(tx_input.get("path", "")) return ( - messages.CardanoTxInput( + m.CardanoTxInput( prev_hash=bytes.fromhex(tx_input["prev_hash"]), prev_index=tx_input["prev_index"], ), @@ -204,22 +199,22 @@ def parse_output(output: dict) -> OutputWithData: datum_hash = parse_optional_bytes(output.get("datum_hash")) - serialization_format = messages.CardanoTxOutputSerializationFormat.ARRAY_LEGACY + serialization_format = m.CardanoTxOutputSerializationFormat.ARRAY_LEGACY if "format" in output: serialization_format = output["format"] inline_datum_size, inline_datum_chunks = _parse_chunkable_data( parse_optional_bytes(output.get("inline_datum")), - messages.CardanoTxInlineDatumChunk, + m.CardanoTxInlineDatumChunk, ) reference_script_size, reference_script_chunks = _parse_chunkable_data( parse_optional_bytes(output.get("reference_script")), - messages.CardanoTxReferenceScriptChunk, + m.CardanoTxReferenceScriptChunk, ) return ( - messages.CardanoTxOutput( + m.CardanoTxOutput( address=address, address_parameters=address_parameters, amount=int(output["amount"]), @@ -253,7 +248,7 @@ def _parse_token_bundle( result.append( ( - messages.CardanoAssetGroup( + m.CardanoAssetGroup( policy_id=bytes.fromhex(token_group["policy_id"]), tokens_count=len(tokens), ), @@ -264,7 +259,7 @@ def _parse_token_bundle( return result -def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]: +def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[m.CardanoToken]: error_message: str if is_mint: error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY @@ -288,7 +283,7 @@ def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.Cardan amount = int(token["amount"]) result.append( - messages.CardanoToken( + m.CardanoToken( asset_name_bytes=bytes.fromhex(token["asset_name_bytes"]), amount=amount, mint_amount=mint_amount, @@ -300,7 +295,7 @@ def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.Cardan def _parse_address_parameters( address_parameters: dict, error_message: str -) -> messages.CardanoAddressParametersType: +) -> m.CardanoAddressParametersType: if "addressType" not in address_parameters: raise ValueError(error_message) @@ -317,7 +312,7 @@ def _parse_address_parameters( ) return create_address_parameters( - messages.CardanoAddressType(address_parameters["addressType"]), + m.CardanoAddressType(address_parameters["addressType"]), payment_path, staking_path, staking_key_hash_bytes, @@ -346,7 +341,7 @@ def _create_data_chunks(data: bytes) -> Iterator[bytes]: processed_size += MAX_CHUNK_SIZE -def parse_native_script(native_script: dict) -> messages.CardanoNativeScript: +def parse_native_script(native_script: dict) -> m.CardanoNativeScript: if "type" not in native_script: raise ValueError("Script is missing some fields") @@ -364,7 +359,7 @@ def parse_native_script(native_script: dict) -> messages.CardanoNativeScript: invalid_before = parse_optional_int(native_script.get("invalid_before")) invalid_hereafter = parse_optional_int(native_script.get("invalid_hereafter")) - return messages.CardanoNativeScript( + return m.CardanoNativeScript( type=type, scripts=scripts, key_hash=key_hash, @@ -385,7 +380,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: certificate_type = certificate["type"] - if certificate_type == messages.CardanoCertificateType.STAKE_DELEGATION: + if certificate_type == m.CardanoCertificateType.STAKE_DELEGATION: if "pool" not in certificate: raise CERTIFICATE_MISSING_FIELDS_ERROR @@ -394,7 +389,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ) return ( - messages.CardanoTxCertificate( + m.CardanoTxCertificate( type=certificate_type, path=path, pool=bytes.fromhex(certificate["pool"]), @@ -404,15 +399,15 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: None, ) elif certificate_type in ( - messages.CardanoCertificateType.STAKE_REGISTRATION, - messages.CardanoCertificateType.STAKE_DEREGISTRATION, + m.CardanoCertificateType.STAKE_REGISTRATION, + m.CardanoCertificateType.STAKE_DEREGISTRATION, ): path, script_hash, key_hash = _parse_credential( certificate, CERTIFICATE_MISSING_FIELDS_ERROR ) return ( - messages.CardanoTxCertificate( + m.CardanoTxCertificate( type=certificate_type, path=path, script_hash=script_hash, @@ -421,8 +416,8 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: None, ) elif certificate_type in ( - messages.CardanoCertificateType.STAKE_REGISTRATION_CONWAY, - messages.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY, + m.CardanoCertificateType.STAKE_REGISTRATION_CONWAY, + m.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY, ): if "deposit" not in certificate: raise CERTIFICATE_MISSING_FIELDS_ERROR @@ -432,7 +427,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ) return ( - messages.CardanoTxCertificate( + m.CardanoTxCertificate( type=certificate_type, path=path, script_hash=script_hash, @@ -441,7 +436,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ), None, ) - elif certificate_type == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION: + elif certificate_type == m.CardanoCertificateType.STAKE_POOL_REGISTRATION: pool_parameters = certificate["pool_parameters"] if any( @@ -450,9 +445,9 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ): raise CERTIFICATE_MISSING_FIELDS_ERROR - pool_metadata: Optional[messages.CardanoPoolMetadataType] + pool_metadata: Optional[m.CardanoPoolMetadataType] if pool_parameters.get("metadata") is not None: - pool_metadata = messages.CardanoPoolMetadataType( + pool_metadata = m.CardanoPoolMetadataType( url=pool_parameters["metadata"]["url"], hash=bytes.fromhex(pool_parameters["metadata"]["hash"]), ) @@ -469,9 +464,9 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ] return ( - messages.CardanoTxCertificate( + m.CardanoTxCertificate( type=certificate_type, - pool_parameters=messages.CardanoPoolParametersType( + pool_parameters=m.CardanoPoolParametersType( pool_id=bytes.fromhex(pool_parameters["pool_id"]), vrf_key_hash=bytes.fromhex(pool_parameters["vrf_key_hash"]), pledge=int(pool_parameters["pledge"]), @@ -486,7 +481,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ), (owners, relays), ) - if certificate_type == messages.CardanoCertificateType.VOTE_DELEGATION: + if certificate_type == m.CardanoCertificateType.VOTE_DELEGATION: if "drep" not in certificate: raise CERTIFICATE_MISSING_FIELDS_ERROR @@ -495,13 +490,13 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays: ) return ( - messages.CardanoTxCertificate( + m.CardanoTxCertificate( type=certificate_type, path=path, script_hash=script_hash, key_hash=key_hash, - drep=messages.CardanoDRep( - type=messages.CardanoDRepType(certificate["drep"]["type"]), + drep=m.CardanoDRep( + type=m.CardanoDRepType(certificate["drep"]["type"]), key_hash=parse_optional_bytes(certificate["drep"].get("key_hash")), script_hash=parse_optional_bytes( certificate["drep"].get("script_hash") @@ -527,21 +522,21 @@ def _parse_credential( return path, script_hash, key_hash -def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner: +def _parse_pool_owner(pool_owner: dict) -> m.CardanoPoolOwner: if "staking_key_path" in pool_owner: - return messages.CardanoPoolOwner( + return m.CardanoPoolOwner( staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) ) - return messages.CardanoPoolOwner( + return m.CardanoPoolOwner( staking_key_hash=bytes.fromhex(pool_owner["staking_key_hash"]) ) -def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters: - pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"]) +def _parse_pool_relay(pool_relay: dict) -> m.CardanoPoolRelayParameters: + pool_relay_type = m.CardanoPoolRelayType(pool_relay["type"]) - if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: + if pool_relay_type == m.CardanoPoolRelayType.SINGLE_HOST_IP: ipv4_address_packed = ( ip_address(pool_relay["ipv4_address"]).packed if "ipv4_address" in pool_relay @@ -553,20 +548,20 @@ def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters: else None ) - return messages.CardanoPoolRelayParameters( + return m.CardanoPoolRelayParameters( type=pool_relay_type, port=int(pool_relay["port"]), ipv4_address=ipv4_address_packed, ipv6_address=ipv6_address_packed, ) - elif pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_NAME: - return messages.CardanoPoolRelayParameters( + elif pool_relay_type == m.CardanoPoolRelayType.SINGLE_HOST_NAME: + return m.CardanoPoolRelayParameters( type=pool_relay_type, port=int(pool_relay["port"]), host_name=pool_relay["host_name"], ) - elif pool_relay_type == messages.CardanoPoolRelayType.MULTIPLE_HOST_NAME: - return messages.CardanoPoolRelayParameters( + elif pool_relay_type == m.CardanoPoolRelayType.MULTIPLE_HOST_NAME: + return m.CardanoPoolRelayParameters( type=pool_relay_type, host_name=pool_relay["host_name"], ) @@ -574,7 +569,7 @@ def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters: raise ValueError("Unknown pool relay type") -def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal: +def parse_withdrawal(withdrawal: dict) -> m.CardanoTxWithdrawal: WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError( "The withdrawal is missing some fields" ) @@ -586,7 +581,7 @@ def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal: withdrawal, WITHDRAWAL_MISSING_FIELDS_ERROR ) - return messages.CardanoTxWithdrawal( + return m.CardanoTxWithdrawal( path=path, amount=int(withdrawal["amount"]), script_hash=script_hash, @@ -596,7 +591,7 @@ def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal: def parse_auxiliary_data( auxiliary_data: Optional[dict], -) -> Optional[messages.CardanoTxAuxiliaryData]: +) -> Optional[m.CardanoTxAuxiliaryData]: if auxiliary_data is None: return None @@ -620,17 +615,17 @@ def parse_auxiliary_data( if not all(k in delegation for k in REQUIRED_FIELDS_CVOTE_DELEGATION): raise AUXILIARY_DATA_MISSING_FIELDS_ERROR delegations.append( - messages.CardanoCVoteRegistrationDelegation( + m.CardanoCVoteRegistrationDelegation( vote_public_key=bytes.fromhex(delegation["vote_public_key"]), weight=int(delegation["weight"]), ) ) voting_purpose = None - if serialization_format == messages.CardanoCVoteRegistrationFormat.CIP36: + if serialization_format == m.CardanoCVoteRegistrationFormat.CIP36: voting_purpose = cvote_registration.get("voting_purpose") - cvote_registration_parameters = messages.CardanoCVoteRegistrationParametersType( + cvote_registration_parameters = m.CardanoCVoteRegistrationParametersType( vote_public_key=parse_optional_bytes( cvote_registration.get("vote_public_key") ), @@ -653,7 +648,7 @@ def parse_auxiliary_data( if hash is None and cvote_registration_parameters is None: raise AUXILIARY_DATA_MISSING_FIELDS_ERROR - return messages.CardanoTxAuxiliaryData( + return m.CardanoTxAuxiliaryData( hash=hash, cvote_registration_parameters=cvote_registration_parameters, ) @@ -673,7 +668,7 @@ def parse_collateral_input(collateral_input: dict) -> CollateralInputWithPath: path = tools.parse_path(collateral_input.get("path", "")) return ( - messages.CardanoTxCollateralInput( + m.CardanoTxCollateralInput( prev_hash=bytes.fromhex(collateral_input["prev_hash"]), prev_index=collateral_input["prev_index"], ), @@ -681,20 +676,20 @@ def parse_collateral_input(collateral_input: dict) -> CollateralInputWithPath: ) -def parse_required_signer(required_signer: dict) -> messages.CardanoTxRequiredSigner: +def parse_required_signer(required_signer: dict) -> m.CardanoTxRequiredSigner: key_hash = parse_optional_bytes(required_signer.get("key_hash")) key_path = tools.parse_path(required_signer.get("key_path", "")) - return messages.CardanoTxRequiredSigner( + return m.CardanoTxRequiredSigner( key_hash=key_hash, key_path=key_path, ) -def parse_reference_input(reference_input: dict) -> messages.CardanoTxReferenceInput: +def parse_reference_input(reference_input: dict) -> m.CardanoTxReferenceInput: if not all(k in reference_input for k in REQUIRED_FIELDS_INPUT): raise ValueError("The reference input is missing some fields") - return messages.CardanoTxReferenceInput( + return m.CardanoTxReferenceInput( prev_hash=bytes.fromhex(reference_input["prev_hash"]), prev_index=reference_input["prev_index"], ) @@ -712,16 +707,16 @@ def parse_additional_witness_request( def _get_witness_requests( inputs: Sequence[InputWithPath], certificates: Sequence[CertificateWithPoolOwnersAndRelays], - withdrawals: Sequence[messages.CardanoTxWithdrawal], + withdrawals: Sequence[m.CardanoTxWithdrawal], collateral_inputs: Sequence[CollateralInputWithPath], - required_signers: Sequence[messages.CardanoTxRequiredSigner], + required_signers: Sequence[m.CardanoTxRequiredSigner], additional_witness_requests: Sequence[Path], - signing_mode: messages.CardanoTxSigningMode, -) -> List[messages.CardanoTxWitnessRequest]: + signing_mode: m.CardanoTxSigningMode, +) -> List[m.CardanoTxWitnessRequest]: paths = set() # don't gather paths from tx elements in MULTISIG_TRANSACTION signing mode - if signing_mode != messages.CardanoTxSigningMode.MULTISIG_TRANSACTION: + if signing_mode != m.CardanoTxSigningMode.MULTISIG_TRANSACTION: for _, path in inputs: if path: paths.add(tuple(path)) @@ -729,18 +724,17 @@ def _get_witness_requests( if ( certificate.type in ( - messages.CardanoCertificateType.STAKE_DEREGISTRATION, - messages.CardanoCertificateType.STAKE_DELEGATION, - messages.CardanoCertificateType.STAKE_REGISTRATION_CONWAY, - messages.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY, - messages.CardanoCertificateType.VOTE_DELEGATION, + m.CardanoCertificateType.STAKE_DEREGISTRATION, + m.CardanoCertificateType.STAKE_DELEGATION, + m.CardanoCertificateType.STAKE_REGISTRATION_CONWAY, + m.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY, + m.CardanoCertificateType.VOTE_DELEGATION, ) and certificate.path ): paths.add(tuple(certificate.path)) elif ( - certificate.type - == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION + certificate.type == m.CardanoCertificateType.STAKE_POOL_REGISTRATION and pool_owners_and_relays is not None ): owners, _ = pool_owners_and_relays @@ -752,7 +746,7 @@ def _get_witness_requests( paths.add(tuple(withdrawal.path)) # gather Plutus-related paths - if signing_mode == messages.CardanoTxSigningMode.PLUTUS_TRANSACTION: + if signing_mode == m.CardanoTxSigningMode.PLUTUS_TRANSACTION: for _, path in collateral_inputs: if path: paths.add(tuple(path)) @@ -765,10 +759,10 @@ def _get_witness_requests( paths.add(tuple(additional_witness_request)) sorted_paths = sorted([list(path) for path in paths]) - return [messages.CardanoTxWitnessRequest(path=path) for path in sorted_paths] + return [m.CardanoTxWitnessRequest(path=path) for path in sorted_paths] -def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[messages.CardanoTxInput]: +def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[m.CardanoTxInput]: for input, _ in inputs: yield input @@ -807,7 +801,7 @@ def _get_certificates_items( def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]: if not mint: return - yield messages.CardanoTxMint(asset_groups_count=len(mint)) + yield m.CardanoTxMint(asset_groups_count=len(mint)) for asset_group, tokens in mint: yield asset_group yield from tokens @@ -815,7 +809,7 @@ def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]: def _get_collateral_inputs_items( collateral_inputs: Sequence[CollateralInputWithPath], -) -> Iterator[messages.CardanoTxCollateralInput]: +) -> Iterator[m.CardanoTxCollateralInput]: for collateral_input, _ in collateral_inputs: yield collateral_input @@ -823,88 +817,86 @@ def _get_collateral_inputs_items( # ====== Client functions ====== # -@expect(messages.CardanoAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", - address_parameters: messages.CardanoAddressParametersType, + address_parameters: m.CardanoAddressParametersType, protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], show_display: bool = False, - derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, + derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( - messages.CardanoGetAddress( + m.CardanoGetAddress( address_parameters=address_parameters, protocol_magic=protocol_magic, network_id=network_id, show_display=show_display, derivation_type=derivation_type, chunkify=chunkify, - ) - ) + ), + expect=m.CardanoAddress, + ).address -@expect(messages.CardanoPublicKey) def get_public_key( client: "TrezorClient", address_n: List[int], - derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, + derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, show_display: bool = False, -) -> "MessageType": +) -> m.CardanoPublicKey: return client.call( - messages.CardanoGetPublicKey( + m.CardanoGetPublicKey( address_n=address_n, derivation_type=derivation_type, show_display=show_display, - ) + ), + expect=m.CardanoPublicKey, ) -@expect(messages.CardanoNativeScriptHash) def get_native_script_hash( client: "TrezorClient", - native_script: messages.CardanoNativeScript, - display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, - derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, -) -> "MessageType": + native_script: m.CardanoNativeScript, + display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE, + derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, +) -> m.CardanoNativeScriptHash: return client.call( - messages.CardanoGetNativeScriptHash( + m.CardanoGetNativeScriptHash( script=native_script, display_format=display_format, derivation_type=derivation_type, - ) + ), + expect=m.CardanoNativeScriptHash, ) def sign_tx( client: "TrezorClient", - signing_mode: messages.CardanoTxSigningMode, + signing_mode: m.CardanoTxSigningMode, inputs: List[InputWithPath], outputs: List[OutputWithData], fee: int, ttl: Optional[int], validity_interval_start: Optional[int], certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (), - withdrawals: Sequence[messages.CardanoTxWithdrawal] = (), + withdrawals: Sequence[m.CardanoTxWithdrawal] = (), protocol_magic: int = PROTOCOL_MAGICS["mainnet"], network_id: int = NETWORK_IDS["mainnet"], - auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None, + auxiliary_data: Optional[m.CardanoTxAuxiliaryData] = None, mint: Sequence[AssetGroupWithTokens] = (), script_data_hash: Optional[bytes] = None, collateral_inputs: Sequence[CollateralInputWithPath] = (), - required_signers: Sequence[messages.CardanoTxRequiredSigner] = (), + required_signers: Sequence[m.CardanoTxRequiredSigner] = (), collateral_return: Optional[OutputWithData] = None, total_collateral: Optional[int] = None, - reference_inputs: Sequence[messages.CardanoTxReferenceInput] = (), + reference_inputs: Sequence[m.CardanoTxReferenceInput] = (), additional_witness_requests: Sequence[Path] = (), - derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, + derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS, include_network_id: bool = False, chunkify: bool = False, tag_cbor_sets: bool = False, ) -> Dict[str, Any]: - UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response") - witness_requests = _get_witness_requests( inputs, certificates, @@ -916,7 +908,7 @@ def sign_tx( ) response = client.call( - messages.CardanoSignTxInit( + m.CardanoSignTxInit( signing_mode=signing_mode, inputs_count=len(inputs), outputs_count=len(outputs), @@ -940,10 +932,9 @@ def sign_tx( include_network_id=include_network_id, chunkify=chunkify, tag_cbor_sets=tag_cbor_sets, - ) + ), + expect=m.CardanoTxItemAck, ) - if not isinstance(response, messages.CardanoTxItemAck): - raise UNEXPECTED_RESPONSE_ERROR for tx_item in chain( _get_inputs_items(inputs), @@ -951,55 +942,41 @@ def sign_tx( _get_certificates_items(certificates), withdrawals, ): - response = client.call(tx_item) - if not isinstance(response, messages.CardanoTxItemAck): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(tx_item, expect=m.CardanoTxItemAck) sign_tx_response: Dict[str, Any] = {} if auxiliary_data is not None: - auxiliary_data_supplement = client.call(auxiliary_data) - if not isinstance( - auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement - ): - raise UNEXPECTED_RESPONSE_ERROR + auxiliary_data_supplement = client.call( + auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement + ) if ( auxiliary_data_supplement.type - != messages.CardanoTxAuxiliaryDataSupplementType.NONE + != m.CardanoTxAuxiliaryDataSupplementType.NONE ): sign_tx_response["auxiliary_data_supplement"] = ( auxiliary_data_supplement.__dict__ ) - response = client.call(messages.CardanoTxHostAck()) - if not isinstance(response, messages.CardanoTxItemAck): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck) for tx_item in chain( _get_mint_items(mint), _get_collateral_inputs_items(collateral_inputs), required_signers, ): - response = client.call(tx_item) - if not isinstance(response, messages.CardanoTxItemAck): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(tx_item, expect=m.CardanoTxItemAck) if collateral_return is not None: for tx_item in _get_output_items(collateral_return): - response = client.call(tx_item) - if not isinstance(response, messages.CardanoTxItemAck): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(tx_item, expect=m.CardanoTxItemAck) for reference_input in reference_inputs: - response = client.call(reference_input) - if not isinstance(response, messages.CardanoTxItemAck): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(reference_input, expect=m.CardanoTxItemAck) sign_tx_response["witnesses"] = [] for witness_request in witness_requests: - response = client.call(witness_request) - if not isinstance(response, messages.CardanoTxWitnessResponse): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(witness_request, expect=m.CardanoTxWitnessResponse) sign_tx_response["witnesses"].append( { "type": response.type, @@ -1009,13 +986,9 @@ def sign_tx( } ) - response = client.call(messages.CardanoTxHostAck()) - if not isinstance(response, messages.CardanoTxBodyHash): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash) sign_tx_response["tx_hash"] = response.tx_hash - response = client.call(messages.CardanoTxHostAck()) - if not isinstance(response, messages.CardanoSignTxFinished): - raise UNEXPECTED_RESPONSE_ERROR + response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished) return sign_tx_response diff --git a/python/src/trezorlib/cli/debug.py b/python/src/trezorlib/cli/debug.py index 50613a04ee..d9d936c7ab 100644 --- a/python/src/trezorlib/cli/debug.py +++ b/python/src/trezorlib/cli/debug.py @@ -107,16 +107,16 @@ def record_screen_from_connection( @cli.command() @with_client -def prodtest_t1(client: "TrezorClient") -> str: +def prodtest_t1(client: "TrezorClient") -> None: """Perform a prodtest on Model One. Only available on PRODTEST firmware and on T1B1. Formerly named self-test. """ - return debuglink_prodtest_t1(client) + debuglink_prodtest_t1(client) @cli.command() @with_client -def optiga_set_sec_max(client: "TrezorClient") -> str: +def optiga_set_sec_max(client: "TrezorClient") -> None: """Set Optiga's security event counter to maximum.""" - return debuglink_optiga_set_sec_max(client) + debuglink_optiga_set_sec_max(client) diff --git a/python/src/trezorlib/cli/device.py b/python/src/trezorlib/cli/device.py index 07dfcd7524..0803b85a69 100644 --- a/python/src/trezorlib/cli/device.py +++ b/python/src/trezorlib/cli/device.py @@ -29,7 +29,6 @@ from . import ChoiceType, with_client if t.TYPE_CHECKING: from ..client import TrezorClient - from ..protobuf import MessageType from . import TrezorConnection RECOVERY_DEVICE_INPUT_METHOD = { @@ -66,7 +65,7 @@ def cli() -> None: is_flag=True, ) @with_client -def wipe(client: "TrezorClient", bootloader: bool) -> str: +def wipe(client: "TrezorClient", bootloader: bool) -> None: """Reset device to factory defaults and remove all private data.""" if bootloader: if not client.features.bootloader_mode: @@ -87,11 +86,7 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str: else: click.echo("Wiping user data!") - try: - return device.wipe(client) - except exceptions.TrezorFailure as e: - click.echo("Action failed: {} {}".format(*e.args)) - sys.exit(3) + device.wipe(client) @cli.command() @@ -116,7 +111,7 @@ def load( academic: bool, needs_backup: bool, no_backup: bool, -) -> str: +) -> None: """Upload seed and custom configuration to the device. This functionality is only available in debug mode. @@ -136,7 +131,7 @@ def load( label = "ACADEMIC" try: - return debuglink.load_device( + debuglink.load_device( client, mnemonic=list(mnemonic), pin=pin, @@ -184,7 +179,7 @@ def recover( input_method: messages.RecoveryDeviceInputMethod, dry_run: bool, unlock_repeated_backup: bool, -) -> "MessageType": +) -> None: """Start safe recovery workflow.""" if input_method == messages.RecoveryDeviceInputMethod.ScrambledWords: input_callback = ui.mnemonic_words(expand) @@ -201,7 +196,7 @@ def recover( if unlock_repeated_backup: type = messages.RecoveryType.UnlockRepeatedBackup - return device.recover( + device.recover( client, word_count=int(words), passphrase_protection=passphrase_protection, @@ -236,21 +231,13 @@ def setup( no_backup: bool, backup_type: messages.BackupType | None, entropy_check_count: int | None, -) -> str: +) -> None: """Perform device setup and generate new seed.""" if strength: strength = int(strength) BT = messages.BackupType - if backup_type is None: - if client.version >= (2, 7, 1): - # SLIP39 extendable was introduced in 2.7.1 - backup_type = BT.Slip39_Single_Extendable - else: - # this includes both T1 and older trezor-cores - backup_type = BT.Bip39 - if ( backup_type in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) @@ -264,7 +251,7 @@ def setup( "backup type. Traditional BIP39 backup may be generated instead." ) - resp, path_xpubs = device.reset_entropy_check( + path_xpubs = device.setup( client, strength=strength, passphrase_protection=passphrase_protection, @@ -277,13 +264,10 @@ def setup( entropy_check_count=entropy_check_count, ) - if isinstance(resp, messages.Success): + if path_xpubs: click.echo("XPUBs for the generated seed") for path, xpub in path_xpubs: click.echo(f"{format_path(path)}: {xpub}") - return resp.message or "" - else: - raise RuntimeError(f"Received {resp.__class__}") @cli.command() @@ -294,10 +278,9 @@ def backup( client: "TrezorClient", group_threshold: int | None = None, groups: t.Sequence[tuple[int, int]] = (), -) -> str: +) -> None: """Perform device seed backup.""" - - return device.backup(client, group_threshold, groups) + device.backup(client, group_threshold, groups) @cli.command() @@ -305,7 +288,7 @@ def backup( @with_client def sd_protect( client: "TrezorClient", operation: messages.SdProtectOperationType -) -> str: +) -> None: """Secure the device with SD card protection. When SD card protection is enabled, a randomly generated secret is stored @@ -321,12 +304,12 @@ def sd_protect( """ if client.features.model == "1": raise click.ClickException("Trezor One does not support SD card protection.") - return device.sd_protect(client, operation) + device.sd_protect(client, operation) @cli.command() @click.pass_obj -def reboot_to_bootloader(obj: "TrezorConnection") -> str: +def reboot_to_bootloader(obj: "TrezorConnection") -> None: """Reboot device into bootloader mode. Currently only supported on Trezor Model One. @@ -334,21 +317,21 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str: # avoid using @with_client because it closes the session afterwards, # which triggers double prompt on device with obj.client_context() as client: - return device.reboot_to_bootloader(client) + device.reboot_to_bootloader(client) @cli.command() @with_client -def tutorial(client: "TrezorClient") -> str: +def tutorial(client: "TrezorClient") -> None: """Show on-device tutorial.""" - return device.show_device_tutorial(client) + device.show_device_tutorial(client) @cli.command() @with_client -def unlock_bootloader(client: "TrezorClient") -> str: +def unlock_bootloader(client: "TrezorClient") -> None: """Unlocks bootloader. Irreversible.""" - return device.unlock_bootloader(client) + device.unlock_bootloader(client) @cli.command() @@ -360,10 +343,11 @@ def unlock_bootloader(client: "TrezorClient") -> str: help="Dialog expiry in seconds.", ) @with_client -def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str: +def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None: """Show a "Do not disconnect" dialog.""" if enable is False: - return device.set_busy(client, None) + device.set_busy(client, None) + return if expiry is None: raise click.ClickException("Missing option '-e' / '--expiry'.") @@ -373,7 +357,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." ) - return device.set_busy(client, expiry * 1000) + device.set_busy(client, expiry * 1000) PUBKEY_WHITELIST_URL_TEMPLATE = ( diff --git a/python/src/trezorlib/cli/fido.py b/python/src/trezorlib/cli/fido.py index 5983c57249..b51bb74e12 100644 --- a/python/src/trezorlib/cli/fido.py +++ b/python/src/trezorlib/cli/fido.py @@ -80,12 +80,12 @@ def credentials_list(client: "TrezorClient") -> None: @credentials.command(name="add") @click.argument("hex_credential_id") @with_client -def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: +def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None: """Add the credential with the given ID as a resident credential. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. """ - return fido.add_credential(client, bytes.fromhex(hex_credential_id)) + fido.add_credential(client, bytes.fromhex(hex_credential_id)) @credentials.command(name="remove") @@ -93,9 +93,9 @@ def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." ) @with_client -def credentials_remove(client: "TrezorClient", index: int) -> str: +def credentials_remove(client: "TrezorClient", index: int) -> None: """Remove the resident credential at the given index.""" - return fido.remove_credential(client, index) + fido.remove_credential(client, index) # @@ -111,9 +111,9 @@ def counter() -> None: @counter.command(name="set") @click.argument("counter", type=int) @with_client -def counter_set(client: "TrezorClient", counter: int) -> str: +def counter_set(client: "TrezorClient", counter: int) -> None: """Set FIDO/U2F counter value.""" - return fido.set_counter(client, counter) + fido.set_counter(client, counter) @counter.command(name="get-next") diff --git a/python/src/trezorlib/cli/settings.py b/python/src/trezorlib/cli/settings.py index eac93eb796..946f3fbffc 100644 --- a/python/src/trezorlib/cli/settings.py +++ b/python/src/trezorlib/cli/settings.py @@ -181,17 +181,17 @@ def cli() -> None: @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @with_client -def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: """Set, change or remove PIN.""" # Remove argument is there for backwards compatibility - return device.change_pin(client, remove=_should_remove(enable, remove)) + device.change_pin(client, remove=_should_remove(enable, remove)) @cli.command() @click.option("-r", "--remove", is_flag=True, hidden=True) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @with_client -def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: +def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None: """Set or remove the wipe code. The wipe code functions as a "self-destruct PIN". If the wipe code is ever @@ -199,7 +199,7 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s removed and the device will be reset to factory defaults. """ # Remove argument is there for backwards compatibility - return device.change_wipe_code(client, remove=_should_remove(enable, remove)) + device.change_wipe_code(client, remove=_should_remove(enable, remove)) @cli.command() @@ -207,24 +207,24 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.argument("label") @with_client -def label(client: "TrezorClient", label: str) -> str: +def label(client: "TrezorClient", label: str) -> None: """Set new device label.""" - return device.apply_settings(client, label=label) + device.apply_settings(client, label=label) @cli.command() @with_client -def brightness(client: "TrezorClient") -> str: +def brightness(client: "TrezorClient") -> None: """Set display brightness.""" - return device.set_brightness(client) + device.set_brightness(client) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) @with_client -def haptic_feedback(client: "TrezorClient", enable: bool) -> str: +def haptic_feedback(client: "TrezorClient", enable: bool) -> None: """Enable or disable haptic feedback.""" - return device.apply_settings(client, haptic_feedback=enable) + device.apply_settings(client, haptic_feedback=enable) @cli.command() @@ -236,7 +236,7 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str: @with_client def language( client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None -) -> str: +) -> None: """Set new language with translations.""" if remove != (path_or_url is None): raise click.ClickException("Either provide a path or URL or use --remove") @@ -259,27 +259,27 @@ def language( raise click.ClickException( f"Failed to load translations from {path_or_url}" ) from None - return device.change_language( - client, language_data=language_data, show_display=display - ) + device.change_language(client, language_data=language_data, show_display=display) @cli.command() @click.argument("rotation", type=ChoiceType(ROTATION)) @with_client -def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str: +def display_rotation( + client: "TrezorClient", rotation: messages.DisplayRotation +) -> None: """Set display rotation. Configure display rotation for Trezor Model T. The options are north, east, south or west. """ - return device.apply_settings(client, display_rotation=rotation) + device.apply_settings(client, display_rotation=rotation) @cli.command() @click.argument("delay", type=str) @with_client -def auto_lock_delay(client: "TrezorClient", delay: str) -> str: +def auto_lock_delay(client: "TrezorClient", delay: str) -> None: """Set auto-lock delay (in seconds).""" if not client.features.pin_protection: @@ -291,13 +291,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str: seconds = float(value) * units[unit] else: seconds = float(delay) # assume seconds if no unit is specified - return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) + device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) @cli.command() @click.argument("flags") @with_client -def flags(client: "TrezorClient", flags: str) -> str: +def flags(client: "TrezorClient", flags: str) -> None: """Set device flags.""" if flags.lower().startswith("0b"): flags_int = int(flags, 2) @@ -305,7 +305,7 @@ def flags(client: "TrezorClient", flags: str) -> str: flags_int = int(flags, 16) else: flags_int = int(flags) - return device.apply_flags(client, flags=flags_int) + device.apply_flags(client, flags=flags_int) @cli.command() @@ -315,7 +315,7 @@ def flags(client: "TrezorClient", flags: str) -> str: ) @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") @with_client -def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: +def homescreen(client: "TrezorClient", filename: str, quality: int) -> None: """Set new homescreen. To revert to default homescreen, use 'trezorctl set homescreen default' @@ -369,7 +369,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: "Unknown image format requested by the device." ) - return device.apply_settings(client, homescreen=img) + device.apply_settings(client, homescreen=img) @cli.command() @@ -380,7 +380,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: @with_client def safety_checks( client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel -) -> str: +) -> None: """Set safety check level. Set to "strict" to get the full Trezor security (default setting). @@ -392,18 +392,18 @@ def safety_checks( """ if always and level == messages.SafetyCheckLevel.PromptTemporarily: level = messages.SafetyCheckLevel.PromptAlways - return device.apply_settings(client, safety_checks=level) + device.apply_settings(client, safety_checks=level) @cli.command() @click.argument("enable", type=ChoiceType({"on": True, "off": False})) @with_client -def experimental_features(client: "TrezorClient", enable: bool) -> str: +def experimental_features(client: "TrezorClient", enable: bool) -> None: """Enable or disable experimental message types. This is a developer feature. Use with caution. """ - return device.apply_settings(client, experimental_features=enable) + device.apply_settings(client, experimental_features=enable) # @@ -427,13 +427,13 @@ passphrase = cast(AliasedGroup, passphrase_main) @passphrase.command(name="on") @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @with_client -def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str: +def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None: """Enable passphrase.""" if client.features.passphrase_protection is not True: use_passphrase = True else: use_passphrase = None - return device.apply_settings( + device.apply_settings( client, use_passphrase=use_passphrase, passphrase_always_on_device=force_on_device, @@ -442,9 +442,9 @@ def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> st @passphrase.command(name="off") @with_client -def passphrase_off(client: "TrezorClient") -> str: +def passphrase_off(client: "TrezorClient") -> None: """Disable passphrase.""" - return device.apply_settings(client, use_passphrase=False) + device.apply_settings(client, use_passphrase=False) # Registering the aliases for backwards compatibility @@ -458,9 +458,9 @@ passphrase.aliases = { @passphrase.command(name="hide") @click.argument("hide", type=ChoiceType({"on": True, "off": False})) @with_client -def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str: +def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None: """Enable or disable hiding passphrase coming from host. This is a developer feature. Use with caution. """ - return device.apply_settings(client, hide_passphrase_from_host=hide) + device.apply_settings(client, hide_passphrase_from_host=hide) diff --git a/python/src/trezorlib/cli/solana.py b/python/src/trezorlib/cli/solana.py index 3fe80a5164..590b4f7914 100644 --- a/python/src/trezorlib/cli/solana.py +++ b/python/src/trezorlib/cli/solana.py @@ -26,7 +26,7 @@ def get_public_key( client: "TrezorClient", address: str, show_display: bool, -) -> messages.SolanaPublicKey: +) -> bytes: """Get Solana public key.""" address_n = tools.parse_path(address) return solana.get_public_key(client, address_n, show_display) @@ -42,7 +42,7 @@ def get_address( address: str, show_display: bool, chunkify: bool, -) -> messages.SolanaAddress: +) -> str: """Get Solana address.""" address_n = tools.parse_path(address) return solana.get_address(client, address_n, show_display, chunkify) @@ -58,7 +58,7 @@ def sign_tx( address: str, serialized_tx: str, additional_information_file: Optional[TextIO], -) -> messages.SolanaTxSignature: +) -> bytes: """Sign Solana transaction.""" address_n = tools.parse_path(address) diff --git a/python/src/trezorlib/client.py b/python/src/trezorlib/client.py index fa7992ab0e..4e432bd012 100644 --- a/python/src/trezorlib/client.py +++ b/python/src/trezorlib/client.py @@ -27,7 +27,7 @@ from . import exceptions, mapping, messages, models from .log import DUMP_BYTES from .messages import Capability from .protobuf import MessageType -from .tools import expect, parse_path, session +from .tools import parse_path, session if TYPE_CHECKING: from .transport import Transport @@ -397,12 +397,7 @@ class TrezorClient(Generic[UI]): else: raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) - @expect(messages.Success, field="message", ret_type=str) - def ping( - self, - msg: str, - button_protection: bool = False, - ) -> MessageType: + def ping(self, msg: str, button_protection: bool = False) -> str: # We would like ping to work on any valid TrezorClient instance, but # due to the protection modes, we need to go through self.call, and that will # raise an exception if the firmware is too old. @@ -416,13 +411,18 @@ class TrezorClient(Generic[UI]): # device is PIN-locked. # respond and hope for the best resp = self._callback_button(resp) - return resp + resp = messages.Success.ensure_isinstance(resp) + assert resp.message is not None + return resp.message finally: self.close() - return self.call( - messages.Ping(message=msg, button_protection=button_protection) + resp = self.call( + messages.Ping(message=msg, button_protection=button_protection), + expect=messages.Success, ) + assert resp.message is not None + return resp.message def get_device_id(self) -> Optional[str]: return self.features.device_id diff --git a/python/src/trezorlib/debuglink.py b/python/src/trezorlib/debuglink.py index d1f89db35f..56c1dfa344 100644 --- a/python/src/trezorlib/debuglink.py +++ b/python/src/trezorlib/debuglink.py @@ -44,10 +44,9 @@ from mnemonic import Mnemonic from . import mapping, messages, models, protobuf from .client import TrezorClient -from .exceptions import TrezorFailure +from .exceptions import TrezorFailure, UnexpectedMessageError from .log import DUMP_BYTES from .messages import DebugWaitType -from .tools import expect if TYPE_CHECKING: from typing_extensions import Protocol @@ -775,9 +774,10 @@ class DebugLink: else: self.t1_take_screenshots = False - @expect(messages.DebugLinkMemory, field="memory", ret_type=bytes) - def memory_read(self, address: int, length: int) -> protobuf.MessageType: - return self._call(messages.DebugLinkMemoryRead(address=address, length=length)) + def memory_read(self, address: int, length: int) -> bytes: + return self._call( + messages.DebugLinkMemoryRead(address=address, length=length) + ).memory def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None: self._write( @@ -787,9 +787,11 @@ class DebugLink: def flash_erase(self, sector: int) -> None: self._write(messages.DebugLinkFlashErase(sector=sector)) - @expect(messages.Success) def erase_sd_card(self, format: bool = True) -> messages.Success: - return self._call(messages.DebugLinkEraseSdCard(format=format)) + res = self._call(messages.DebugLinkEraseSdCard(format=format)) + if not isinstance(res, messages.Success): + raise UnexpectedMessageError(messages.Success, res) + return res def snapshot_legacy(self) -> None: """Snapshot the current state of the device.""" @@ -1350,7 +1352,6 @@ class TrezorClientDebugLink(TrezorClient): raise RuntimeError("Unexpected call") -@expect(messages.Success, field="message", ret_type=str) def load_device( client: "TrezorClient", mnemonic: Union[str, Iterable[str]], @@ -1360,7 +1361,7 @@ def load_device( skip_checksum: bool = False, needs_backup: bool = False, no_backup: bool = False, -) -> protobuf.MessageType: +) -> None: if isinstance(mnemonic, str): mnemonic = [mnemonic] @@ -1371,7 +1372,7 @@ def load_device( "Device is initialized already. Call device.wipe() and try again." ) - resp = client.call( + client.call( messages.LoadDevice( mnemonics=mnemonics, pin=pin, @@ -1380,25 +1381,25 @@ def load_device( skip_checksum=skip_checksum, needs_backup=needs_backup, no_backup=no_backup, - ) + ), + expect=messages.Success, ) client.init_device() - return resp # keep the old name for compatibility load_device_by_mnemonic = load_device -@expect(messages.Success, field="message", ret_type=str) -def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType: +def prodtest_t1(client: "TrezorClient") -> None: if client.features.bootloader_mode is not True: raise RuntimeError("Device must be in bootloader mode") - return client.call( + client.call( messages.ProdTestT1( payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" - ) + ), + expect=messages.Success, ) @@ -1450,6 +1451,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool: return debug_client.features.fw_vendor == "EMULATOR" -@expect(messages.Success, field="message", ret_type=str) -def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: - return client.call(messages.DebugLinkOptigaSetSecMax()) +def optiga_set_sec_max(client: "TrezorClient") -> None: + client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success) diff --git a/python/src/trezorlib/device.py b/python/src/trezorlib/device.py index fcf948f523..c08d485ed0 100644 --- a/python/src/trezorlib/device.py +++ b/python/src/trezorlib/device.py @@ -18,7 +18,6 @@ from __future__ import annotations import hashlib import hmac -import os import random import secrets import time @@ -29,11 +28,16 @@ from slip10 import SLIP10 from . import messages from .exceptions import Cancelled, TrezorException -from .tools import Address, expect, parse_path, session +from .tools import ( + Address, + _deprecation_retval_helper, + _return_success, + parse_path, + session, +) if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType RECOVERY_BACK = "\x08" # backspace character, sent literally @@ -42,7 +46,6 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1) ENTROPY_CHECK_MIN_VERSION = (2, 8, 7) -@expect(messages.Success, field="message", ret_type=str) @session def apply_settings( client: "TrezorClient", @@ -57,7 +60,7 @@ def apply_settings( experimental_features: Optional[bool] = None, hide_passphrase_from_host: Optional[bool] = None, haptic_feedback: Optional[bool] = None, -) -> "MessageType": +) -> str | None: if language is not None: warnings.warn( "language ignored. Use change_language() to set device language.", @@ -76,87 +79,80 @@ def apply_settings( haptic_feedback=haptic_feedback, ) - out = client.call(settings) + out = client.call(settings, expect=messages.Success) client.refresh_features() - return out + return _return_success(out) def _send_language_data( client: "TrezorClient", request: "messages.TranslationDataRequest", language_data: bytes, -) -> "MessageType": - response: MessageType = request +) -> None: + response = request while not isinstance(response, messages.Success): - assert isinstance(response, messages.TranslationDataRequest) + response = messages.TranslationDataRequest.ensure_isinstance(response) data_length = response.data_length data_offset = response.data_offset chunk = language_data[data_offset : data_offset + data_length] response = client.call(messages.TranslationDataAck(data_chunk=chunk)) - return response - -@expect(messages.Success, field="message", ret_type=str) @session def change_language( client: "TrezorClient", language_data: bytes, show_display: bool | None = None, -) -> "MessageType": +) -> str | None: data_length = len(language_data) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) response = client.call(msg) if data_length > 0: - assert isinstance(response, messages.TranslationDataRequest) - response = _send_language_data(client, response, language_data) - assert isinstance(response, messages.Success) + response = messages.TranslationDataRequest.ensure_isinstance(response) + _send_language_data(client, response, language_data) + else: + messages.Success.ensure_isinstance(response) client.refresh_features() # changing the language in features - return response + return _return_success(messages.Success(message="Language changed.")) -@expect(messages.Success, field="message", ret_type=str) @session -def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": - out = client.call(messages.ApplyFlags(flags=flags)) +def apply_flags(client: "TrezorClient", flags: int) -> str | None: + out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success) client.refresh_features() - return out + return _return_success(out) -@expect(messages.Success, field="message", ret_type=str) @session -def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangePin(remove=remove)) +def change_pin(client: "TrezorClient", remove: bool = False) -> str | None: + ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success) client.refresh_features() - return ret + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) @session -def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": - ret = client.call(messages.ChangeWipeCode(remove=remove)) +def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None: + ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success) client.refresh_features() - return ret + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) @session def sd_protect( client: "TrezorClient", operation: messages.SdProtectOperationType -) -> "MessageType": - ret = client.call(messages.SdProtect(operation=operation)) +) -> str | None: + ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success) client.refresh_features() - return ret + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) @session -def wipe(client: "TrezorClient") -> "MessageType": - ret = client.call(messages.WipeDevice()) +def wipe(client: "TrezorClient") -> str | None: + ret = client.call(messages.WipeDevice(), expect=messages.Success) if not client.features.bootloader_mode: client.init_device() - return ret + return _return_success(ret) @session @@ -173,7 +169,7 @@ def recover( u2f_counter: Optional[int] = None, *, type: Optional[messages.RecoveryType] = None, -) -> "MessageType": +) -> messages.Success | None: if language is not None: warnings.warn( "language ignored. Use change_language() to set device language.", @@ -235,8 +231,12 @@ def recover( except Cancelled: res = client.call(messages.Cancel()) + # check that the result is a Success + res = messages.Success.ensure_isinstance(res) + # reinitialize the device client.init_device() - return res + + return _deprecation_retval_helper(res) def is_slip39_backup_type(backup_type: messages.BackupType): @@ -279,13 +279,7 @@ def _seed_from_entropy( return seed -@expect(messages.Success, field="message", ret_type=str) -def reset(*args: Any, **kwargs: Any) -> "MessageType": - return reset_entropy_check(*args, **kwargs)[0] - - -@session -def reset_entropy_check( +def reset( client: "TrezorClient", display_random: bool = False, strength: Optional[int] = None, @@ -576,13 +570,12 @@ def _reset_with_entropycheck( return xpubs -@expect(messages.Success, field="message", ret_type=str) @session def backup( client: "TrezorClient", group_threshold: Optional[int] = None, groups: Iterable[tuple[int, int]] = (), -) -> "MessageType": +) -> str | None: ret = client.call( messages.BackupDevice( group_threshold=group_threshold, @@ -590,38 +583,39 @@ def backup( messages.Slip39Group(member_threshold=t, member_count=c) for t, c in groups ], - ) + ), + expect=messages.Success, ) client.refresh_features() - return ret + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) -def cancel_authorization(client: "TrezorClient") -> "MessageType": - return client.call(messages.CancelAuthorization()) +def cancel_authorization(client: "TrezorClient") -> str | None: + ret = client.call(messages.CancelAuthorization(), expect=messages.Success) + return _return_success(ret) -@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) -def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": - resp = client.call(messages.UnlockPath(address_n=n)) +def unlock_path(client: "TrezorClient", n: "Address") -> bytes: + resp = client.call( + messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest + ) # Cancel the UnlockPath workflow now that we have the authentication code. try: client.call(messages.Cancel()) except Cancelled: - return resp + return resp.mac else: raise TrezorException("Unexpected response in UnlockPath flow") @session -@expect(messages.Success, field="message", ret_type=str) def reboot_to_bootloader( client: "TrezorClient", boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, firmware_header: Optional[bytes] = None, language_data: bytes = b"", -) -> "MessageType": +) -> str | None: response = client.call( messages.RebootToBootloader( boot_command=boot_command, @@ -631,41 +625,42 @@ def reboot_to_bootloader( ) if isinstance(response, messages.TranslationDataRequest): response = _send_language_data(client, response, language_data) - return response + return _return_success(messages.Success(message="")) @session -@expect(messages.Success, field="message", ret_type=str) -def show_device_tutorial(client: "TrezorClient") -> "MessageType": - return client.call(messages.ShowDeviceTutorial()) +def show_device_tutorial(client: "TrezorClient") -> str | None: + ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success) + return _return_success(ret) @session -@expect(messages.Success, field="message", ret_type=str) -def unlock_bootloader(client: "TrezorClient") -> "MessageType": - return client.call(messages.UnlockBootloader()) +def unlock_bootloader(client: "TrezorClient") -> str | None: + ret = client.call(messages.UnlockBootloader(), expect=messages.Success) + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) @session -def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": +def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None: """Sets or clears the busy state of the device. In the busy state the device shows a "Do not disconnect" message instead of the homescreen. Setting `expiry_ms=None` clears the busy state. """ - ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) + ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success) client.refresh_features() - return ret + return _return_success(ret) -@expect(messages.AuthenticityProof) -def authenticate(client: "TrezorClient", challenge: bytes): - return client.call(messages.AuthenticateDevice(challenge=challenge)) +def authenticate( + client: "TrezorClient", challenge: bytes +) -> messages.AuthenticityProof: + return client.call( + messages.AuthenticateDevice(challenge=challenge), + expect=messages.AuthenticityProof, + ) -@expect(messages.Success, field="message", ret_type=str) -def set_brightness( - client: "TrezorClient", value: Optional[int] = None -) -> "MessageType": - return client.call(messages.SetBrightness(value=value)) +def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None: + ret = client.call(messages.SetBrightness(value=value), expect=messages.Success) + return _return_success(ret) diff --git a/python/src/trezorlib/eos.py b/python/src/trezorlib/eos.py index 1ffaafb4ab..eb491f204c 100644 --- a/python/src/trezorlib/eos.py +++ b/python/src/trezorlib/eos.py @@ -18,11 +18,10 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Tuple from . import exceptions, messages -from .tools import b58decode, expect, session +from .tools import b58decode, session if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address @@ -319,14 +318,13 @@ def parse_transaction_json( # ====== Client functions ====== # -@expect(messages.EosPublicKey) def get_public_key( client: "TrezorClient", n: "Address", show_display: bool = False -) -> "MessageType": - response = client.call( - messages.EosGetPublicKey(address_n=n, show_display=show_display) +) -> messages.EosPublicKey: + return client.call( + messages.EosGetPublicKey(address_n=n, show_display=show_display), + expect=messages.EosPublicKey, ) - return response @session diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 1cf2eeeaed..96ce4d1066 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -18,11 +18,10 @@ import re from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from . import definitions, exceptions, messages -from .tools import expect, prepare_message_bytes, session, unharden +from .tools import prepare_message_bytes, session, unharden if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address @@ -161,30 +160,32 @@ def network_from_address_n( # ====== Client functions ====== # -@expect(messages.EthereumAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", n: "Address", show_display: bool = False, encoded_network: Optional[bytes] = None, chunkify: bool = False, -) -> "MessageType": - return client.call( +) -> str: + resp = client.call( messages.EthereumGetAddress( address_n=n, show_display=show_display, encoded_network=encoded_network, chunkify=chunkify, - ) + ), + expect=messages.EthereumAddress, ) + assert resp.address is not None + return resp.address -@expect(messages.EthereumPublicKey) def get_public_node( client: "TrezorClient", n: "Address", show_display: bool = False -) -> "MessageType": +) -> messages.EthereumPublicKey: return client.call( - messages.EthereumGetPublicKey(address_n=n, show_display=show_display) + messages.EthereumGetPublicKey(address_n=n, show_display=show_display), + expect=messages.EthereumPublicKey, ) @@ -297,25 +298,24 @@ def sign_tx_eip1559( return response.signature_v, response.signature_r, response.signature_s -@expect(messages.EthereumMessageSignature) def sign_message( client: "TrezorClient", n: "Address", message: AnyStr, encoded_network: Optional[bytes] = None, chunkify: bool = False, -) -> "MessageType": +) -> messages.EthereumMessageSignature: return client.call( messages.EthereumSignMessage( address_n=n, message=prepare_message_bytes(message), encoded_network=encoded_network, chunkify=chunkify, - ) + ), + expect=messages.EthereumMessageSignature, ) -@expect(messages.EthereumTypedDataSignature) def sign_typed_data( client: "TrezorClient", n: "Address", @@ -323,7 +323,7 @@ def sign_typed_data( *, metamask_v4_compat: bool = True, definitions: Optional[messages.EthereumDefinitions] = None, -) -> "MessageType": +) -> messages.EthereumTypedDataSignature: data = sanitize_typed_data(data) types = data["types"] @@ -387,7 +387,7 @@ def sign_typed_data( request = messages.EthereumTypedDataValueAck(value=encoded_data) response = client.call(request) - return response + return messages.EthereumTypedDataSignature.ensure_isinstance(response) def verify_message( @@ -398,32 +398,33 @@ def verify_message( chunkify: bool = False, ) -> bool: try: - resp = client.call( + client.call( messages.EthereumVerifyMessage( address=address, signature=signature, message=prepare_message_bytes(message), chunkify=chunkify, - ) + ), + expect=messages.Success, ) + return True except exceptions.TrezorFailure: return False - return isinstance(resp, messages.Success) -@expect(messages.EthereumTypedDataSignature) def sign_typed_data_hash( client: "TrezorClient", n: "Address", domain_hash: bytes, message_hash: Optional[bytes], encoded_network: Optional[bytes] = None, -) -> "MessageType": +) -> messages.EthereumTypedDataSignature: return client.call( messages.EthereumSignTypedHash( address_n=n, domain_separator_hash=domain_hash, message_hash=message_hash, encoded_network=encoded_network, - ) + ), + expect=messages.EthereumTypedDataSignature, ) diff --git a/python/src/trezorlib/fido.py b/python/src/trezorlib/fido.py index 4ed6f22951..a2618b72db 100644 --- a/python/src/trezorlib/fido.py +++ b/python/src/trezorlib/fido.py @@ -14,42 +14,45 @@ # You should have received a copy of the License along with this library. # If not, see . -from typing import TYPE_CHECKING, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence from . import messages -from .tools import expect +from .tools import _return_success if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType -@expect( - messages.WebAuthnCredentials, - field="credentials", - ret_type=List[messages.WebAuthnCredential], -) -def list_credentials(client: "TrezorClient") -> "MessageType": - return client.call(messages.WebAuthnListResidentCredentials()) - - -@expect(messages.Success, field="message", ret_type=str) -def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType": +def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]: return client.call( - messages.WebAuthnAddResidentCredential(credential_id=credential_id) + messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials + ).credentials + + +def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None: + ret = client.call( + messages.WebAuthnAddResidentCredential(credential_id=credential_id), + expect=messages.Success, ) + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) -def remove_credential(client: "TrezorClient", index: int) -> "MessageType": - return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) +def remove_credential(client: "TrezorClient", index: int) -> str | None: + ret = client.call( + messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success + ) + return _return_success(ret) -@expect(messages.Success, field="message", ret_type=str) -def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": - return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) +def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None: + ret = client.call( + messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success + ) + return _return_success(ret) -@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) -def get_next_counter(client: "TrezorClient") -> "MessageType": - return client.call(messages.GetNextU2FCounter()) +def get_next_counter(client: "TrezorClient") -> int: + ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter) + return ret.u2f_counter diff --git a/python/src/trezorlib/firmware/__init__.py b/python/src/trezorlib/firmware/__init__.py index 5cc5d8830c..4cfc11dd40 100644 --- a/python/src/trezorlib/firmware/__init__.py +++ b/python/src/trezorlib/firmware/__init__.py @@ -20,7 +20,7 @@ from hashlib import blake2s from typing_extensions import Protocol, TypeGuard from .. import messages -from ..tools import expect, session +from ..tools import session from .core import VendorFirmware from .legacy import LegacyFirmware, LegacyV2Firmware @@ -106,6 +106,7 @@ def update( raise RuntimeError(f"Unexpected message {resp}") -@expect(messages.FirmwareHash, field="hash", ret_type=bytes) -def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): - return client.call(messages.GetFirmwareHash(challenge=challenge)) +def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes: + return client.call( + messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash + ).hash diff --git a/python/src/trezorlib/misc.py b/python/src/trezorlib/misc.py index 4ed6f5aa81..578c1fa19f 100644 --- a/python/src/trezorlib/misc.py +++ b/python/src/trezorlib/misc.py @@ -17,54 +17,50 @@ from typing import TYPE_CHECKING, Optional from . import messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address -@expect(messages.Entropy, field="entropy", ret_type=bytes) -def get_entropy(client: "TrezorClient", size: int) -> "MessageType": - return client.call(messages.GetEntropy(size=size)) +def get_entropy(client: "TrezorClient", size: int) -> bytes: + return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy -@expect(messages.SignedIdentity) def sign_identity( client: "TrezorClient", identity: messages.IdentityType, challenge_hidden: bytes, challenge_visual: str, ecdsa_curve_name: Optional[str] = None, -) -> "MessageType": +) -> messages.SignedIdentity: return client.call( messages.SignIdentity( identity=identity, challenge_hidden=challenge_hidden, challenge_visual=challenge_visual, ecdsa_curve_name=ecdsa_curve_name, - ) + ), + expect=messages.SignedIdentity, ) -@expect(messages.ECDHSessionKey) def get_ecdh_session_key( client: "TrezorClient", identity: messages.IdentityType, peer_public_key: bytes, ecdsa_curve_name: Optional[str] = None, -) -> "MessageType": +) -> messages.ECDHSessionKey: return client.call( messages.GetECDHSessionKey( identity=identity, peer_public_key=peer_public_key, ecdsa_curve_name=ecdsa_curve_name, - ) + ), + expect=messages.ECDHSessionKey, ) -@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def encrypt_keyvalue( client: "TrezorClient", n: "Address", @@ -73,7 +69,7 @@ def encrypt_keyvalue( ask_on_encrypt: bool = True, ask_on_decrypt: bool = True, iv: bytes = b"", -) -> "MessageType": +) -> bytes: return client.call( messages.CipherKeyValue( address_n=n, @@ -83,11 +79,11 @@ def encrypt_keyvalue( ask_on_encrypt=ask_on_encrypt, ask_on_decrypt=ask_on_decrypt, iv=iv, - ) - ) + ), + expect=messages.CipheredKeyValue, + ).value -@expect(messages.CipheredKeyValue, field="value", ret_type=bytes) def decrypt_keyvalue( client: "TrezorClient", n: "Address", @@ -96,7 +92,7 @@ def decrypt_keyvalue( ask_on_encrypt: bool = True, ask_on_decrypt: bool = True, iv: bytes = b"", -) -> "MessageType": +) -> bytes: return client.call( messages.CipherKeyValue( address_n=n, @@ -106,10 +102,10 @@ def decrypt_keyvalue( ask_on_encrypt=ask_on_encrypt, ask_on_decrypt=ask_on_decrypt, iv=iv, - ) - ) + ), + expect=messages.CipheredKeyValue, + ).value -@expect(messages.Nonce, field="nonce", ret_type=bytes) -def get_nonce(client: "TrezorClient"): - return client.call(messages.GetNonce()) +def get_nonce(client: "TrezorClient") -> bytes: + return client.call(messages.GetNonce(), expect=messages.Nonce).nonce diff --git a/python/src/trezorlib/monero.py b/python/src/trezorlib/monero.py index 5bce7574e8..b2e3214fb9 100644 --- a/python/src/trezorlib/monero.py +++ b/python/src/trezorlib/monero.py @@ -17,11 +17,9 @@ from typing import TYPE_CHECKING from . import messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address @@ -31,30 +29,30 @@ if TYPE_CHECKING: # FAKECHAIN = 3 -@expect(messages.MoneroAddress, field="address", ret_type=bytes) def get_address( client: "TrezorClient", n: "Address", show_display: bool = False, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, chunkify: bool = False, -) -> "MessageType": +) -> bytes: return client.call( messages.MoneroGetAddress( address_n=n, show_display=show_display, network_type=network_type, chunkify=chunkify, - ) - ) + ), + expect=messages.MoneroAddress, + ).address -@expect(messages.MoneroWatchKey) def get_watch_key( client: "TrezorClient", n: "Address", network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, -) -> "MessageType": +) -> messages.MoneroWatchKey: return client.call( - messages.MoneroGetWatchKey(address_n=n, network_type=network_type) + messages.MoneroGetWatchKey(address_n=n, network_type=network_type), + expect=messages.MoneroWatchKey, ) diff --git a/python/src/trezorlib/nem.py b/python/src/trezorlib/nem.py index 3a67aec72c..744dc3205f 100644 --- a/python/src/trezorlib/nem.py +++ b/python/src/trezorlib/nem.py @@ -18,11 +18,9 @@ import json from typing import TYPE_CHECKING from . import exceptions, messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address TYPE_TRANSACTION_TRANSFER = 0x0101 @@ -196,25 +194,24 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig # ====== Client functions ====== # -@expect(messages.NEMAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", n: "Address", network: int, show_display: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.NEMGetAddress( address_n=n, network=network, show_display=show_display, chunkify=chunkify - ) - ) + ), + expect=messages.NEMAddress, + ).address -@expect(messages.NEMSignedTx) def sign_tx( client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False -) -> "MessageType": +) -> messages.NEMSignedTx: try: msg = create_sign_tx(transaction, chunkify=chunkify) except ValueError as e: @@ -222,4 +219,4 @@ def sign_tx( assert msg.transaction is not None msg.transaction.address_n = n - return client.call(msg) + return client.call(msg, expect=messages.NEMSignedTx) diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index 14f61dff87..5a5315f186 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -35,6 +35,8 @@ from itertools import zip_longest import typing_extensions as tx +from .exceptions import UnexpectedMessageError + if t.TYPE_CHECKING: from IPython.lib.pretty import RepresentationPrinter # noqa: I900 @@ -312,6 +314,27 @@ class MessageType: dump_message(data, self) return len(data.getvalue()) + @classmethod + def ensure_isinstance(cls, msg: t.Any) -> tx.Self: + """Ensure that the received `msg` is an instance of this class. + + If `msg` is not an instance of this class, raise an `UnexpectedMessageError`. + otherwise, return it. This is useful for type-checking like so: + + >>> msg = client.call(SomeMessage()) + >>> if isinstance(msg, Foo): + >>> return msg.foo_attr # attribute of Foo, type-checks OK + >>> else: + >>> msg = Bar.ensure_isinstance(msg) # raises if msg is something else + >>> return msg.bar_attr # attribute of Bar, type-checks OK + + If there is just one expected message, you should use the `expect` parameter of + `Client.call` instead. + """ + if not isinstance(msg, cls): + raise UnexpectedMessageError(cls, msg) + return msg + class LimitedReader: def __init__(self, reader: Reader, limit: int) -> None: diff --git a/python/src/trezorlib/ripple.py b/python/src/trezorlib/ripple.py index 7a953b8fac..00a027c6d9 100644 --- a/python/src/trezorlib/ripple.py +++ b/python/src/trezorlib/ripple.py @@ -18,41 +18,39 @@ from typing import TYPE_CHECKING from . import messages from .protobuf import dict_to_proto -from .tools import dict_from_camelcase, expect +from .tools import dict_from_camelcase if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") -@expect(messages.RippleAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", address_n: "Address", show_display: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.RippleGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify - ) - ) + ), + expect=messages.RippleAddress, + ).address -@expect(messages.RippleSignedTx) def sign_tx( client: "TrezorClient", address_n: "Address", msg: messages.RippleSignTx, chunkify: bool = False, -) -> "MessageType": +) -> messages.RippleSignedTx: msg.address_n = address_n msg.chunkify = chunkify - return client.call(msg) + return client.call(msg, expect=messages.RippleSignedTx) def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: diff --git a/python/src/trezorlib/solana.py b/python/src/trezorlib/solana.py index be7f2e5fcb..0054e0fd92 100644 --- a/python/src/trezorlib/solana.py +++ b/python/src/trezorlib/solana.py @@ -1,51 +1,49 @@ from typing import TYPE_CHECKING, List, Optional from . import messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType -@expect(messages.SolanaPublicKey) def get_public_key( client: "TrezorClient", address_n: List[int], show_display: bool, -) -> "MessageType": +) -> bytes: return client.call( - messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) - ) + messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display), + expect=messages.SolanaPublicKey, + ).public_key -@expect(messages.SolanaAddress) def get_address( client: "TrezorClient", address_n: List[int], show_display: bool, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.SolanaGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify, - ) - ) + ), + expect=messages.SolanaAddress, + ).address -@expect(messages.SolanaTxSignature) def sign_tx( client: "TrezorClient", address_n: List[int], serialized_tx: bytes, additional_info: Optional[messages.SolanaTxAdditionalInfo], -) -> "MessageType": +) -> bytes: return client.call( messages.SolanaSignTx( address_n=address_n, serialized_tx=serialized_tx, additional_info=additional_info, - ) - ) + ), + expect=messages.SolanaTxSignature, + ).signature diff --git a/python/src/trezorlib/stellar.py b/python/src/trezorlib/stellar.py index ebf81e4fd0..5bd0a749e4 100644 --- a/python/src/trezorlib/stellar.py +++ b/python/src/trezorlib/stellar.py @@ -18,11 +18,9 @@ from decimal import Decimal from typing import TYPE_CHECKING, List, Tuple, Union from . import exceptions, messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address StellarMessageType = Union[ @@ -323,18 +321,18 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset: # ====== Client functions ====== # -@expect(messages.StellarAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", address_n: "Address", show_display: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.StellarGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify - ) - ) + ), + expect=messages.StellarAddress, + ).address def sign_tx( @@ -364,10 +362,7 @@ def sign_tx( "Reached end of operations without a signature." ) from None - if not isinstance(resp, messages.StellarSignedTx): - raise exceptions.TrezorException( - f"Unexpected message: {resp.__class__.__name__}" - ) + resp = messages.StellarSignedTx.ensure_isinstance(resp) if operations: raise exceptions.TrezorException( diff --git a/python/src/trezorlib/tezos.py b/python/src/trezorlib/tezos.py index cff06ed6c8..9319aa1eaa 100644 --- a/python/src/trezorlib/tezos.py +++ b/python/src/trezorlib/tezos.py @@ -17,49 +17,46 @@ from typing import TYPE_CHECKING from . import messages -from .tools import expect if TYPE_CHECKING: from .client import TrezorClient - from .protobuf import MessageType from .tools import Address -@expect(messages.TezosAddress, field="address", ret_type=str) def get_address( client: "TrezorClient", address_n: "Address", show_display: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.TezosGetAddress( address_n=address_n, show_display=show_display, chunkify=chunkify - ) - ) + ), + expect=messages.TezosAddress, + ).address -@expect(messages.TezosPublicKey, field="public_key", ret_type=str) def get_public_key( client: "TrezorClient", address_n: "Address", show_display: bool = False, chunkify: bool = False, -) -> "MessageType": +) -> str: return client.call( messages.TezosGetPublicKey( address_n=address_n, show_display=show_display, chunkify=chunkify - ) - ) + ), + expect=messages.TezosPublicKey, + ).public_key -@expect(messages.TezosSignedTx) def sign_tx( client: "TrezorClient", address_n: "Address", sign_tx_msg: messages.TezosSignTx, chunkify: bool = False, -) -> "MessageType": +) -> messages.TezosSignedTx: sign_tx_msg.address_n = address_n sign_tx_msg.chunkify = chunkify - return client.call(sign_tx_msg) + return client.call(sign_tx_msg, expect=messages.TezosSignedTx)