1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-29 16:51:30 +00:00

refactor(python): replace usages of @expect

This commit is contained in:
matejcik 2025-01-03 16:36:40 +01:00 committed by matejcik
parent 53bdef5bb4
commit 6a5836708f
28 changed files with 518 additions and 584 deletions

View File

@ -0,0 +1 @@
String return values are deprecated in functions where the semantic result is a success (specifically those that were returning the message from Trezor's `Success` response). Type annotations are updated to `str | None`, and in a future release those functions will be returning `None` on success, or raise an exception on a failure.

View File

@ -0,0 +1 @@
Return value of `device.recover()` is deprecated. In the future, this function will return `None`.

View File

@ -0,0 +1 @@
Return values in `solana` module were changed from the wrapping protobuf messages to the raw inner values (`str` for address, `bytes` for pubkey / signature).

View File

@ -0,0 +1 @@
`trezorctl device` commands whose default result is a success will not print anything to stdout anymore, in line with Unix philosophy.

View File

@ -17,20 +17,18 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from . import messages from . import messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.BenchmarkNames)
def list_names( def list_names(
client: "TrezorClient", client: "TrezorClient",
) -> "MessageType": ) -> messages.BenchmarkNames:
return client.call(messages.BenchmarkListNames()) return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
@expect(messages.BenchmarkResult) def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
def run(client: "TrezorClient", name: str) -> "MessageType": return client.call(
return client.call(messages.BenchmarkRun(name=name)) messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
)

View File

@ -18,35 +18,34 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import expect, session from .tools import session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
@expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.BinanceGetAddress( messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) ),
) expect=messages.BinanceAddress,
).address
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key( def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display) messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
) expect=messages.BinancePublicKey,
).public_key
@session @session
@ -60,13 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg) envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope) client.call(envelope, expect=messages.BinanceTxRequest)
if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError(
"Invalid response, expected BinanceTxRequest, received "
+ type(response).__name__
)
if "refid" in msg: if "refid" in msg:
msg = dict_to_proto(messages.BinanceCancelMsg, msg) msg = dict_to_proto(messages.BinanceCancelMsg, msg)
@ -77,12 +70,4 @@ def sign_tx(
else: else:
raise ValueError("can not determine msg type") raise ValueError("can not determine msg type")
response = client.call(msg) return client.call(msg, expect=messages.BinanceSignedTx)
if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError(
"Invalid response, expected BinanceSignedTx, received "
+ type(response).__name__
)
return response

View File

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import warnings import warnings
from copy import copy from copy import copy
from decimal import Decimal from decimal import Decimal
@ -23,11 +25,10 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict from typing_extensions import Protocol, TypedDict
from . import exceptions, messages from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session from .tools import _return_success, prepare_message_bytes, session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
class ScriptSig(TypedDict): class ScriptSig(TypedDict):
@ -103,7 +104,6 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
) )
@expect(messages.PublicKey)
def get_public_node( def get_public_node(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
@ -114,13 +114,12 @@ def get_public_node(
ignore_xpub_magic: bool = False, ignore_xpub_magic: bool = False,
unlock_path: Optional[List[int]] = None, unlock_path: Optional[List[int]] = None,
unlock_path_mac: Optional[bytes] = None, unlock_path_mac: Optional[bytes] = None,
) -> "MessageType": ) -> messages.PublicKey:
if unlock_path: if unlock_path:
res = client.call( client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
) )
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call( return client.call(
messages.GetPublicKey( messages.GetPublicKey(
@ -130,16 +129,15 @@ def get_public_node(
coin_name=coin_name, coin_name=coin_name,
script_type=script_type, script_type=script_type,
ignore_xpub_magic=ignore_xpub_magic, ignore_xpub_magic=ignore_xpub_magic,
) ),
expect=messages.PublicKey,
) )
@expect(messages.Address, field="address", ret_type=str) def get_address(*args: Any, **kwargs: Any) -> str:
def get_address(*args: Any, **kwargs: Any): return get_authenticated_address(*args, **kwargs).address
return get_authenticated_address(*args, **kwargs)
@expect(messages.Address)
def get_authenticated_address( def get_authenticated_address(
client: "TrezorClient", client: "TrezorClient",
coin_name: str, coin_name: str,
@ -151,13 +149,12 @@ def get_authenticated_address(
unlock_path: Optional[List[int]] = None, unlock_path: Optional[List[int]] = None,
unlock_path_mac: Optional[bytes] = None, unlock_path_mac: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> messages.Address:
if unlock_path: if unlock_path:
res = client.call( client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
) )
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call( return client.call(
messages.GetAddress( messages.GetAddress(
@ -168,26 +165,27 @@ def get_authenticated_address(
script_type=script_type, script_type=script_type,
ignore_xpub_magic=ignore_xpub_magic, ignore_xpub_magic=ignore_xpub_magic,
chunkify=chunkify, chunkify=chunkify,
) ),
expect=messages.Address,
) )
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id( def get_ownership_id(
client: "TrezorClient", client: "TrezorClient",
coin_name: str, coin_name: str,
n: "Address", n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None, multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.GetOwnershipId( messages.GetOwnershipId(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
multisig=multisig, multisig=multisig,
script_type=script_type, script_type=script_type,
) ),
) expect=messages.OwnershipId,
).ownership_id
def get_ownership_proof( def get_ownership_proof(
@ -202,9 +200,7 @@ def get_ownership_proof(
preauthorized: bool = False, preauthorized: bool = False,
) -> Tuple[bytes, bytes]: ) -> Tuple[bytes, bytes]:
if preauthorized: if preauthorized:
res = client.call(messages.DoPreauthorized()) client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call( res = client.call(
messages.GetOwnershipProof( messages.GetOwnershipProof(
@ -215,16 +211,13 @@ def get_ownership_proof(
user_confirmation=user_confirmation, user_confirmation=user_confirmation,
ownership_ids=ownership_ids, ownership_ids=ownership_ids,
commitment_data=commitment_data, commitment_data=commitment_data,
) ),
expect=messages.OwnershipProof,
) )
if not isinstance(res, messages.OwnershipProof):
raise exceptions.TrezorException("Unexpected message")
return res.ownership_proof, res.signature return res.ownership_proof, res.signature
@expect(messages.MessageSignature)
def sign_message( def sign_message(
client: "TrezorClient", client: "TrezorClient",
coin_name: str, coin_name: str,
@ -233,7 +226,7 @@ def sign_message(
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
no_script_type: bool = False, no_script_type: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> messages.MessageSignature:
return client.call( return client.call(
messages.SignMessage( messages.SignMessage(
coin_name=coin_name, coin_name=coin_name,
@ -242,7 +235,8 @@ def sign_message(
script_type=script_type, script_type=script_type,
no_script_type=no_script_type, no_script_type=no_script_type,
chunkify=chunkify, chunkify=chunkify,
) ),
expect=messages.MessageSignature,
) )
@ -255,18 +249,19 @@ def verify_message(
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
resp = client.call( client.call(
messages.VerifyMessage( messages.VerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
message=prepare_message_bytes(message), message=prepare_message_bytes(message),
coin_name=coin_name, coin_name=coin_name,
chunkify=chunkify, chunkify=chunkify,
) ),
expect=messages.Success,
) )
return True
except exceptions.TrezorFailure: except exceptions.TrezorFailure:
return False return False
return isinstance(resp, messages.Success)
@session @session
@ -319,17 +314,14 @@ def sign_tx(
setattr(signtx, name, value) setattr(signtx, name, value)
if unlock_path: if unlock_path:
res = client.call( client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac) messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
) )
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
elif preauthorized: elif preauthorized:
res = client.call(messages.DoPreauthorized()) client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(signtx) res = client.call(signtx, expect=messages.TxRequest)
# Prepare structure for signatures # Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs) signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -357,7 +349,7 @@ def sign_tx(
) )
R = messages.RequestType R = messages.RequestType
while isinstance(res, messages.TxRequest): while True:
# If there's some part of signed transaction, let's add it # If there's some part of signed transaction, let's add it
if res.serialized: if res.serialized:
if res.serialized.serialized_tx: if res.serialized.serialized_tx:
@ -388,7 +380,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ: if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index] msg = payment_reqs[res.details.request_index]
res = client.call(msg) res = client.call(msg, expect=messages.TxRequest)
else: else:
msg = messages.TransactionType() msg = messages.TransactionType()
if res.request_type == R.TXMETA: if res.request_type == R.TXMETA:
@ -418,10 +410,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}." f"Unknown request type - {res.request_type}."
) )
res = client.call(messages.TxAck(tx=msg)) res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message")
for i, sig in zip(inputs, signatures): for i, sig in zip(inputs, signatures):
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None: if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
@ -430,7 +419,6 @@ def sign_tx(
return signatures, serialized_tx return signatures, serialized_tx
@expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin( def authorize_coinjoin(
client: "TrezorClient", client: "TrezorClient",
coordinator: str, coordinator: str,
@ -440,8 +428,8 @@ def authorize_coinjoin(
n: "Address", n: "Address",
coin_name: str, coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS, script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType": ) -> str | None:
return client.call( resp = client.call(
messages.AuthorizeCoinJoin( messages.AuthorizeCoinJoin(
coordinator=coordinator, coordinator=coordinator,
max_rounds=max_rounds, max_rounds=max_rounds,
@ -450,5 +438,7 @@ def authorize_coinjoin(
address_n=n, address_n=n,
coin_name=coin_name, coin_name=coin_name,
script_type=script_type, script_type=script_type,
) ),
expect=messages.Success,
) )
return _return_success(resp)

View File

@ -31,12 +31,11 @@ from typing import (
Union, Union,
) )
from . import exceptions, messages, tools from . import messages as m
from .tools import expect from . import tools
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
PROTOCOL_MAGICS = { PROTOCOL_MAGICS = {
"mainnet": 764824073, "mainnet": 764824073,
@ -72,35 +71,33 @@ INCOMPLETE_OUTPUT_ERROR_MESSAGE = "The output is missing some fields"
INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY = "The output's token_bundle entry is invalid" INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY = "The output's token_bundle entry is invalid"
INVALID_MINT_TOKEN_BUNDLE_ENTRY = "The mint token_bundle entry is invalid" INVALID_MINT_TOKEN_BUNDLE_ENTRY = "The mint token_bundle entry is invalid"
InputWithPath = Tuple[messages.CardanoTxInput, List[int]] InputWithPath = Tuple[m.CardanoTxInput, List[int]]
CollateralInputWithPath = Tuple[messages.CardanoTxCollateralInput, List[int]] CollateralInputWithPath = Tuple[m.CardanoTxCollateralInput, List[int]]
AssetGroupWithTokens = Tuple[messages.CardanoAssetGroup, List[messages.CardanoToken]] AssetGroupWithTokens = Tuple[m.CardanoAssetGroup, List[m.CardanoToken]]
OutputWithData = Tuple[ OutputWithData = Tuple[
messages.CardanoTxOutput, m.CardanoTxOutput,
List[AssetGroupWithTokens], List[AssetGroupWithTokens],
List[messages.CardanoTxInlineDatumChunk], List[m.CardanoTxInlineDatumChunk],
List[messages.CardanoTxReferenceScriptChunk], List[m.CardanoTxReferenceScriptChunk],
] ]
OutputItem = Union[ OutputItem = Union[
messages.CardanoTxOutput, m.CardanoTxOutput,
messages.CardanoAssetGroup, m.CardanoAssetGroup,
messages.CardanoToken, m.CardanoToken,
messages.CardanoTxInlineDatumChunk, m.CardanoTxInlineDatumChunk,
messages.CardanoTxReferenceScriptChunk, m.CardanoTxReferenceScriptChunk,
] ]
CertificateItem = Union[ CertificateItem = Union[
messages.CardanoTxCertificate, m.CardanoTxCertificate,
messages.CardanoPoolOwner, m.CardanoPoolOwner,
messages.CardanoPoolRelayParameters, m.CardanoPoolRelayParameters,
]
MintItem = Union[
messages.CardanoTxMint, messages.CardanoAssetGroup, messages.CardanoToken
] ]
MintItem = Union[m.CardanoTxMint, m.CardanoAssetGroup, m.CardanoToken]
PoolOwnersAndRelays = Tuple[ PoolOwnersAndRelays = Tuple[
List[messages.CardanoPoolOwner], List[messages.CardanoPoolRelayParameters] List[m.CardanoPoolOwner], List[m.CardanoPoolRelayParameters]
] ]
CertificateWithPoolOwnersAndRelays = Tuple[ CertificateWithPoolOwnersAndRelays = Tuple[
messages.CardanoTxCertificate, Optional[PoolOwnersAndRelays] m.CardanoTxCertificate, Optional[PoolOwnersAndRelays]
] ]
Path = List[int] Path = List[int]
Witness = Tuple[Path, bytes] Witness = Tuple[Path, bytes]
@ -108,9 +105,7 @@ AuxiliaryDataSupplement = Dict[str, Union[int, bytes]]
SignTxResponse = Dict[str, Union[bytes, List[Witness], AuxiliaryDataSupplement]] SignTxResponse = Dict[str, Union[bytes, List[Witness], AuxiliaryDataSupplement]]
Chunk = TypeVar( Chunk = TypeVar(
"Chunk", "Chunk",
bound=Union[ bound=Union[m.CardanoTxInlineDatumChunk, m.CardanoTxReferenceScriptChunk],
messages.CardanoTxInlineDatumChunk, messages.CardanoTxReferenceScriptChunk
],
) )
@ -123,7 +118,7 @@ def parse_optional_int(value: Optional[str]) -> Optional[int]:
def create_address_parameters( def create_address_parameters(
address_type: messages.CardanoAddressType, address_type: m.CardanoAddressType,
address_n: List[int], address_n: List[int],
address_n_staking: Optional[List[int]] = None, address_n_staking: Optional[List[int]] = None,
staking_key_hash: Optional[bytes] = None, staking_key_hash: Optional[bytes] = None,
@ -132,18 +127,18 @@ def create_address_parameters(
certificate_index: Optional[int] = None, certificate_index: Optional[int] = None,
script_payment_hash: Optional[bytes] = None, script_payment_hash: Optional[bytes] = None,
script_staking_hash: Optional[bytes] = None, script_staking_hash: Optional[bytes] = None,
) -> messages.CardanoAddressParametersType: ) -> m.CardanoAddressParametersType:
certificate_pointer = None certificate_pointer = None
if address_type in ( if address_type in (
messages.CardanoAddressType.POINTER, m.CardanoAddressType.POINTER,
messages.CardanoAddressType.POINTER_SCRIPT, m.CardanoAddressType.POINTER_SCRIPT,
): ):
certificate_pointer = _create_certificate_pointer( certificate_pointer = _create_certificate_pointer(
block_index, tx_index, certificate_index block_index, tx_index, certificate_index
) )
return messages.CardanoAddressParametersType( return m.CardanoAddressParametersType(
address_type=address_type, address_type=address_type,
address_n=address_n, address_n=address_n,
address_n_staking=address_n_staking, address_n_staking=address_n_staking,
@ -158,11 +153,11 @@ def _create_certificate_pointer(
block_index: Optional[int], block_index: Optional[int],
tx_index: Optional[int], tx_index: Optional[int],
certificate_index: Optional[int], certificate_index: Optional[int],
) -> messages.CardanoBlockchainPointerType: ) -> m.CardanoBlockchainPointerType:
if block_index is None or tx_index is None or certificate_index is None: if block_index is None or tx_index is None or certificate_index is None:
raise ValueError("Invalid pointer parameters") raise ValueError("Invalid pointer parameters")
return messages.CardanoBlockchainPointerType( return m.CardanoBlockchainPointerType(
block_index=block_index, tx_index=tx_index, certificate_index=certificate_index block_index=block_index, tx_index=tx_index, certificate_index=certificate_index
) )
@ -173,7 +168,7 @@ def parse_input(tx_input: dict) -> InputWithPath:
path = tools.parse_path(tx_input.get("path", "")) path = tools.parse_path(tx_input.get("path", ""))
return ( return (
messages.CardanoTxInput( m.CardanoTxInput(
prev_hash=bytes.fromhex(tx_input["prev_hash"]), prev_hash=bytes.fromhex(tx_input["prev_hash"]),
prev_index=tx_input["prev_index"], prev_index=tx_input["prev_index"],
), ),
@ -204,22 +199,22 @@ def parse_output(output: dict) -> OutputWithData:
datum_hash = parse_optional_bytes(output.get("datum_hash")) datum_hash = parse_optional_bytes(output.get("datum_hash"))
serialization_format = messages.CardanoTxOutputSerializationFormat.ARRAY_LEGACY serialization_format = m.CardanoTxOutputSerializationFormat.ARRAY_LEGACY
if "format" in output: if "format" in output:
serialization_format = output["format"] serialization_format = output["format"]
inline_datum_size, inline_datum_chunks = _parse_chunkable_data( inline_datum_size, inline_datum_chunks = _parse_chunkable_data(
parse_optional_bytes(output.get("inline_datum")), parse_optional_bytes(output.get("inline_datum")),
messages.CardanoTxInlineDatumChunk, m.CardanoTxInlineDatumChunk,
) )
reference_script_size, reference_script_chunks = _parse_chunkable_data( reference_script_size, reference_script_chunks = _parse_chunkable_data(
parse_optional_bytes(output.get("reference_script")), parse_optional_bytes(output.get("reference_script")),
messages.CardanoTxReferenceScriptChunk, m.CardanoTxReferenceScriptChunk,
) )
return ( return (
messages.CardanoTxOutput( m.CardanoTxOutput(
address=address, address=address,
address_parameters=address_parameters, address_parameters=address_parameters,
amount=int(output["amount"]), amount=int(output["amount"]),
@ -253,7 +248,7 @@ def _parse_token_bundle(
result.append( result.append(
( (
messages.CardanoAssetGroup( m.CardanoAssetGroup(
policy_id=bytes.fromhex(token_group["policy_id"]), policy_id=bytes.fromhex(token_group["policy_id"]),
tokens_count=len(tokens), tokens_count=len(tokens),
), ),
@ -264,7 +259,7 @@ def _parse_token_bundle(
return result return result
def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]: def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[m.CardanoToken]:
error_message: str error_message: str
if is_mint: if is_mint:
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
@ -288,7 +283,7 @@ def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.Cardan
amount = int(token["amount"]) amount = int(token["amount"])
result.append( result.append(
messages.CardanoToken( m.CardanoToken(
asset_name_bytes=bytes.fromhex(token["asset_name_bytes"]), asset_name_bytes=bytes.fromhex(token["asset_name_bytes"]),
amount=amount, amount=amount,
mint_amount=mint_amount, mint_amount=mint_amount,
@ -300,7 +295,7 @@ def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.Cardan
def _parse_address_parameters( def _parse_address_parameters(
address_parameters: dict, error_message: str address_parameters: dict, error_message: str
) -> messages.CardanoAddressParametersType: ) -> m.CardanoAddressParametersType:
if "addressType" not in address_parameters: if "addressType" not in address_parameters:
raise ValueError(error_message) raise ValueError(error_message)
@ -317,7 +312,7 @@ def _parse_address_parameters(
) )
return create_address_parameters( return create_address_parameters(
messages.CardanoAddressType(address_parameters["addressType"]), m.CardanoAddressType(address_parameters["addressType"]),
payment_path, payment_path,
staking_path, staking_path,
staking_key_hash_bytes, staking_key_hash_bytes,
@ -346,7 +341,7 @@ def _create_data_chunks(data: bytes) -> Iterator[bytes]:
processed_size += MAX_CHUNK_SIZE processed_size += MAX_CHUNK_SIZE
def parse_native_script(native_script: dict) -> messages.CardanoNativeScript: def parse_native_script(native_script: dict) -> m.CardanoNativeScript:
if "type" not in native_script: if "type" not in native_script:
raise ValueError("Script is missing some fields") raise ValueError("Script is missing some fields")
@ -364,7 +359,7 @@ def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
invalid_before = parse_optional_int(native_script.get("invalid_before")) invalid_before = parse_optional_int(native_script.get("invalid_before"))
invalid_hereafter = parse_optional_int(native_script.get("invalid_hereafter")) invalid_hereafter = parse_optional_int(native_script.get("invalid_hereafter"))
return messages.CardanoNativeScript( return m.CardanoNativeScript(
type=type, type=type,
scripts=scripts, scripts=scripts,
key_hash=key_hash, key_hash=key_hash,
@ -385,7 +380,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
certificate_type = certificate["type"] certificate_type = certificate["type"]
if certificate_type == messages.CardanoCertificateType.STAKE_DELEGATION: if certificate_type == m.CardanoCertificateType.STAKE_DELEGATION:
if "pool" not in certificate: if "pool" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
@ -394,7 +389,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
) )
return ( return (
messages.CardanoTxCertificate( m.CardanoTxCertificate(
type=certificate_type, type=certificate_type,
path=path, path=path,
pool=bytes.fromhex(certificate["pool"]), pool=bytes.fromhex(certificate["pool"]),
@ -404,15 +399,15 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
None, None,
) )
elif certificate_type in ( elif certificate_type in (
messages.CardanoCertificateType.STAKE_REGISTRATION, m.CardanoCertificateType.STAKE_REGISTRATION,
messages.CardanoCertificateType.STAKE_DEREGISTRATION, m.CardanoCertificateType.STAKE_DEREGISTRATION,
): ):
path, script_hash, key_hash = _parse_credential( path, script_hash, key_hash = _parse_credential(
certificate, CERTIFICATE_MISSING_FIELDS_ERROR certificate, CERTIFICATE_MISSING_FIELDS_ERROR
) )
return ( return (
messages.CardanoTxCertificate( m.CardanoTxCertificate(
type=certificate_type, type=certificate_type,
path=path, path=path,
script_hash=script_hash, script_hash=script_hash,
@ -421,8 +416,8 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
None, None,
) )
elif certificate_type in ( elif certificate_type in (
messages.CardanoCertificateType.STAKE_REGISTRATION_CONWAY, m.CardanoCertificateType.STAKE_REGISTRATION_CONWAY,
messages.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY, m.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY,
): ):
if "deposit" not in certificate: if "deposit" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
@ -432,7 +427,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
) )
return ( return (
messages.CardanoTxCertificate( m.CardanoTxCertificate(
type=certificate_type, type=certificate_type,
path=path, path=path,
script_hash=script_hash, script_hash=script_hash,
@ -441,7 +436,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
), ),
None, None,
) )
elif certificate_type == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION: elif certificate_type == m.CardanoCertificateType.STAKE_POOL_REGISTRATION:
pool_parameters = certificate["pool_parameters"] pool_parameters = certificate["pool_parameters"]
if any( if any(
@ -450,9 +445,9 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
): ):
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
pool_metadata: Optional[messages.CardanoPoolMetadataType] pool_metadata: Optional[m.CardanoPoolMetadataType]
if pool_parameters.get("metadata") is not None: if pool_parameters.get("metadata") is not None:
pool_metadata = messages.CardanoPoolMetadataType( pool_metadata = m.CardanoPoolMetadataType(
url=pool_parameters["metadata"]["url"], url=pool_parameters["metadata"]["url"],
hash=bytes.fromhex(pool_parameters["metadata"]["hash"]), hash=bytes.fromhex(pool_parameters["metadata"]["hash"]),
) )
@ -469,9 +464,9 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
] ]
return ( return (
messages.CardanoTxCertificate( m.CardanoTxCertificate(
type=certificate_type, type=certificate_type,
pool_parameters=messages.CardanoPoolParametersType( pool_parameters=m.CardanoPoolParametersType(
pool_id=bytes.fromhex(pool_parameters["pool_id"]), pool_id=bytes.fromhex(pool_parameters["pool_id"]),
vrf_key_hash=bytes.fromhex(pool_parameters["vrf_key_hash"]), vrf_key_hash=bytes.fromhex(pool_parameters["vrf_key_hash"]),
pledge=int(pool_parameters["pledge"]), pledge=int(pool_parameters["pledge"]),
@ -486,7 +481,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
), ),
(owners, relays), (owners, relays),
) )
if certificate_type == messages.CardanoCertificateType.VOTE_DELEGATION: if certificate_type == m.CardanoCertificateType.VOTE_DELEGATION:
if "drep" not in certificate: if "drep" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
@ -495,13 +490,13 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
) )
return ( return (
messages.CardanoTxCertificate( m.CardanoTxCertificate(
type=certificate_type, type=certificate_type,
path=path, path=path,
script_hash=script_hash, script_hash=script_hash,
key_hash=key_hash, key_hash=key_hash,
drep=messages.CardanoDRep( drep=m.CardanoDRep(
type=messages.CardanoDRepType(certificate["drep"]["type"]), type=m.CardanoDRepType(certificate["drep"]["type"]),
key_hash=parse_optional_bytes(certificate["drep"].get("key_hash")), key_hash=parse_optional_bytes(certificate["drep"].get("key_hash")),
script_hash=parse_optional_bytes( script_hash=parse_optional_bytes(
certificate["drep"].get("script_hash") certificate["drep"].get("script_hash")
@ -527,21 +522,21 @@ def _parse_credential(
return path, script_hash, key_hash return path, script_hash, key_hash
def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner: def _parse_pool_owner(pool_owner: dict) -> m.CardanoPoolOwner:
if "staking_key_path" in pool_owner: if "staking_key_path" in pool_owner:
return messages.CardanoPoolOwner( return m.CardanoPoolOwner(
staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
) )
return messages.CardanoPoolOwner( return m.CardanoPoolOwner(
staking_key_hash=bytes.fromhex(pool_owner["staking_key_hash"]) staking_key_hash=bytes.fromhex(pool_owner["staking_key_hash"])
) )
def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters: def _parse_pool_relay(pool_relay: dict) -> m.CardanoPoolRelayParameters:
pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"]) pool_relay_type = m.CardanoPoolRelayType(pool_relay["type"])
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: if pool_relay_type == m.CardanoPoolRelayType.SINGLE_HOST_IP:
ipv4_address_packed = ( ipv4_address_packed = (
ip_address(pool_relay["ipv4_address"]).packed ip_address(pool_relay["ipv4_address"]).packed
if "ipv4_address" in pool_relay if "ipv4_address" in pool_relay
@ -553,20 +548,20 @@ def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
else None else None
) )
return messages.CardanoPoolRelayParameters( return m.CardanoPoolRelayParameters(
type=pool_relay_type, type=pool_relay_type,
port=int(pool_relay["port"]), port=int(pool_relay["port"]),
ipv4_address=ipv4_address_packed, ipv4_address=ipv4_address_packed,
ipv6_address=ipv6_address_packed, ipv6_address=ipv6_address_packed,
) )
elif pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_NAME: elif pool_relay_type == m.CardanoPoolRelayType.SINGLE_HOST_NAME:
return messages.CardanoPoolRelayParameters( return m.CardanoPoolRelayParameters(
type=pool_relay_type, type=pool_relay_type,
port=int(pool_relay["port"]), port=int(pool_relay["port"]),
host_name=pool_relay["host_name"], host_name=pool_relay["host_name"],
) )
elif pool_relay_type == messages.CardanoPoolRelayType.MULTIPLE_HOST_NAME: elif pool_relay_type == m.CardanoPoolRelayType.MULTIPLE_HOST_NAME:
return messages.CardanoPoolRelayParameters( return m.CardanoPoolRelayParameters(
type=pool_relay_type, type=pool_relay_type,
host_name=pool_relay["host_name"], host_name=pool_relay["host_name"],
) )
@ -574,7 +569,7 @@ def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
raise ValueError("Unknown pool relay type") raise ValueError("Unknown pool relay type")
def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal: def parse_withdrawal(withdrawal: dict) -> m.CardanoTxWithdrawal:
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError( WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
"The withdrawal is missing some fields" "The withdrawal is missing some fields"
) )
@ -586,7 +581,7 @@ def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
withdrawal, WITHDRAWAL_MISSING_FIELDS_ERROR withdrawal, WITHDRAWAL_MISSING_FIELDS_ERROR
) )
return messages.CardanoTxWithdrawal( return m.CardanoTxWithdrawal(
path=path, path=path,
amount=int(withdrawal["amount"]), amount=int(withdrawal["amount"]),
script_hash=script_hash, script_hash=script_hash,
@ -596,7 +591,7 @@ def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
def parse_auxiliary_data( def parse_auxiliary_data(
auxiliary_data: Optional[dict], auxiliary_data: Optional[dict],
) -> Optional[messages.CardanoTxAuxiliaryData]: ) -> Optional[m.CardanoTxAuxiliaryData]:
if auxiliary_data is None: if auxiliary_data is None:
return None return None
@ -620,17 +615,17 @@ def parse_auxiliary_data(
if not all(k in delegation for k in REQUIRED_FIELDS_CVOTE_DELEGATION): if not all(k in delegation for k in REQUIRED_FIELDS_CVOTE_DELEGATION):
raise AUXILIARY_DATA_MISSING_FIELDS_ERROR raise AUXILIARY_DATA_MISSING_FIELDS_ERROR
delegations.append( delegations.append(
messages.CardanoCVoteRegistrationDelegation( m.CardanoCVoteRegistrationDelegation(
vote_public_key=bytes.fromhex(delegation["vote_public_key"]), vote_public_key=bytes.fromhex(delegation["vote_public_key"]),
weight=int(delegation["weight"]), weight=int(delegation["weight"]),
) )
) )
voting_purpose = None voting_purpose = None
if serialization_format == messages.CardanoCVoteRegistrationFormat.CIP36: if serialization_format == m.CardanoCVoteRegistrationFormat.CIP36:
voting_purpose = cvote_registration.get("voting_purpose") voting_purpose = cvote_registration.get("voting_purpose")
cvote_registration_parameters = messages.CardanoCVoteRegistrationParametersType( cvote_registration_parameters = m.CardanoCVoteRegistrationParametersType(
vote_public_key=parse_optional_bytes( vote_public_key=parse_optional_bytes(
cvote_registration.get("vote_public_key") cvote_registration.get("vote_public_key")
), ),
@ -653,7 +648,7 @@ def parse_auxiliary_data(
if hash is None and cvote_registration_parameters is None: if hash is None and cvote_registration_parameters is None:
raise AUXILIARY_DATA_MISSING_FIELDS_ERROR raise AUXILIARY_DATA_MISSING_FIELDS_ERROR
return messages.CardanoTxAuxiliaryData( return m.CardanoTxAuxiliaryData(
hash=hash, hash=hash,
cvote_registration_parameters=cvote_registration_parameters, cvote_registration_parameters=cvote_registration_parameters,
) )
@ -673,7 +668,7 @@ def parse_collateral_input(collateral_input: dict) -> CollateralInputWithPath:
path = tools.parse_path(collateral_input.get("path", "")) path = tools.parse_path(collateral_input.get("path", ""))
return ( return (
messages.CardanoTxCollateralInput( m.CardanoTxCollateralInput(
prev_hash=bytes.fromhex(collateral_input["prev_hash"]), prev_hash=bytes.fromhex(collateral_input["prev_hash"]),
prev_index=collateral_input["prev_index"], prev_index=collateral_input["prev_index"],
), ),
@ -681,20 +676,20 @@ def parse_collateral_input(collateral_input: dict) -> CollateralInputWithPath:
) )
def parse_required_signer(required_signer: dict) -> messages.CardanoTxRequiredSigner: def parse_required_signer(required_signer: dict) -> m.CardanoTxRequiredSigner:
key_hash = parse_optional_bytes(required_signer.get("key_hash")) key_hash = parse_optional_bytes(required_signer.get("key_hash"))
key_path = tools.parse_path(required_signer.get("key_path", "")) key_path = tools.parse_path(required_signer.get("key_path", ""))
return messages.CardanoTxRequiredSigner( return m.CardanoTxRequiredSigner(
key_hash=key_hash, key_hash=key_hash,
key_path=key_path, key_path=key_path,
) )
def parse_reference_input(reference_input: dict) -> messages.CardanoTxReferenceInput: def parse_reference_input(reference_input: dict) -> m.CardanoTxReferenceInput:
if not all(k in reference_input for k in REQUIRED_FIELDS_INPUT): if not all(k in reference_input for k in REQUIRED_FIELDS_INPUT):
raise ValueError("The reference input is missing some fields") raise ValueError("The reference input is missing some fields")
return messages.CardanoTxReferenceInput( return m.CardanoTxReferenceInput(
prev_hash=bytes.fromhex(reference_input["prev_hash"]), prev_hash=bytes.fromhex(reference_input["prev_hash"]),
prev_index=reference_input["prev_index"], prev_index=reference_input["prev_index"],
) )
@ -712,16 +707,16 @@ def parse_additional_witness_request(
def _get_witness_requests( def _get_witness_requests(
inputs: Sequence[InputWithPath], inputs: Sequence[InputWithPath],
certificates: Sequence[CertificateWithPoolOwnersAndRelays], certificates: Sequence[CertificateWithPoolOwnersAndRelays],
withdrawals: Sequence[messages.CardanoTxWithdrawal], withdrawals: Sequence[m.CardanoTxWithdrawal],
collateral_inputs: Sequence[CollateralInputWithPath], collateral_inputs: Sequence[CollateralInputWithPath],
required_signers: Sequence[messages.CardanoTxRequiredSigner], required_signers: Sequence[m.CardanoTxRequiredSigner],
additional_witness_requests: Sequence[Path], additional_witness_requests: Sequence[Path],
signing_mode: messages.CardanoTxSigningMode, signing_mode: m.CardanoTxSigningMode,
) -> List[messages.CardanoTxWitnessRequest]: ) -> List[m.CardanoTxWitnessRequest]:
paths = set() paths = set()
# don't gather paths from tx elements in MULTISIG_TRANSACTION signing mode # don't gather paths from tx elements in MULTISIG_TRANSACTION signing mode
if signing_mode != messages.CardanoTxSigningMode.MULTISIG_TRANSACTION: if signing_mode != m.CardanoTxSigningMode.MULTISIG_TRANSACTION:
for _, path in inputs: for _, path in inputs:
if path: if path:
paths.add(tuple(path)) paths.add(tuple(path))
@ -729,18 +724,17 @@ def _get_witness_requests(
if ( if (
certificate.type certificate.type
in ( in (
messages.CardanoCertificateType.STAKE_DEREGISTRATION, m.CardanoCertificateType.STAKE_DEREGISTRATION,
messages.CardanoCertificateType.STAKE_DELEGATION, m.CardanoCertificateType.STAKE_DELEGATION,
messages.CardanoCertificateType.STAKE_REGISTRATION_CONWAY, m.CardanoCertificateType.STAKE_REGISTRATION_CONWAY,
messages.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY, m.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY,
messages.CardanoCertificateType.VOTE_DELEGATION, m.CardanoCertificateType.VOTE_DELEGATION,
) )
and certificate.path and certificate.path
): ):
paths.add(tuple(certificate.path)) paths.add(tuple(certificate.path))
elif ( elif (
certificate.type certificate.type == m.CardanoCertificateType.STAKE_POOL_REGISTRATION
== messages.CardanoCertificateType.STAKE_POOL_REGISTRATION
and pool_owners_and_relays is not None and pool_owners_and_relays is not None
): ):
owners, _ = pool_owners_and_relays owners, _ = pool_owners_and_relays
@ -752,7 +746,7 @@ def _get_witness_requests(
paths.add(tuple(withdrawal.path)) paths.add(tuple(withdrawal.path))
# gather Plutus-related paths # gather Plutus-related paths
if signing_mode == messages.CardanoTxSigningMode.PLUTUS_TRANSACTION: if signing_mode == m.CardanoTxSigningMode.PLUTUS_TRANSACTION:
for _, path in collateral_inputs: for _, path in collateral_inputs:
if path: if path:
paths.add(tuple(path)) paths.add(tuple(path))
@ -765,10 +759,10 @@ def _get_witness_requests(
paths.add(tuple(additional_witness_request)) paths.add(tuple(additional_witness_request))
sorted_paths = sorted([list(path) for path in paths]) sorted_paths = sorted([list(path) for path in paths])
return [messages.CardanoTxWitnessRequest(path=path) for path in sorted_paths] return [m.CardanoTxWitnessRequest(path=path) for path in sorted_paths]
def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[messages.CardanoTxInput]: def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[m.CardanoTxInput]:
for input, _ in inputs: for input, _ in inputs:
yield input yield input
@ -807,7 +801,7 @@ def _get_certificates_items(
def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]: def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
if not mint: if not mint:
return return
yield messages.CardanoTxMint(asset_groups_count=len(mint)) yield m.CardanoTxMint(asset_groups_count=len(mint))
for asset_group, tokens in mint: for asset_group, tokens in mint:
yield asset_group yield asset_group
yield from tokens yield from tokens
@ -815,7 +809,7 @@ def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
def _get_collateral_inputs_items( def _get_collateral_inputs_items(
collateral_inputs: Sequence[CollateralInputWithPath], collateral_inputs: Sequence[CollateralInputWithPath],
) -> Iterator[messages.CardanoTxCollateralInput]: ) -> Iterator[m.CardanoTxCollateralInput]:
for collateral_input, _ in collateral_inputs: for collateral_input, _ in collateral_inputs:
yield collateral_input yield collateral_input
@ -823,88 +817,86 @@ def _get_collateral_inputs_items(
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
address_parameters: messages.CardanoAddressParametersType, address_parameters: m.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
show_display: bool = False, show_display: bool = False,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.CardanoGetAddress( m.CardanoGetAddress(
address_parameters=address_parameters, address_parameters=address_parameters,
protocol_magic=protocol_magic, protocol_magic=protocol_magic,
network_id=network_id, network_id=network_id,
show_display=show_display, show_display=show_display,
derivation_type=derivation_type, derivation_type=derivation_type,
chunkify=chunkify, chunkify=chunkify,
) ),
) expect=m.CardanoAddress,
).address
@expect(messages.CardanoPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", client: "TrezorClient",
address_n: List[int], address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
show_display: bool = False, show_display: bool = False,
) -> "MessageType": ) -> m.CardanoPublicKey:
return client.call( return client.call(
messages.CardanoGetPublicKey( m.CardanoGetPublicKey(
address_n=address_n, address_n=address_n,
derivation_type=derivation_type, derivation_type=derivation_type,
show_display=show_display, show_display=show_display,
) ),
expect=m.CardanoPublicKey,
) )
@expect(messages.CardanoNativeScriptHash)
def get_native_script_hash( def get_native_script_hash(
client: "TrezorClient", client: "TrezorClient",
native_script: messages.CardanoNativeScript, native_script: m.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE, display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
) -> "MessageType": ) -> m.CardanoNativeScriptHash:
return client.call( return client.call(
messages.CardanoGetNativeScriptHash( m.CardanoGetNativeScriptHash(
script=native_script, script=native_script,
display_format=display_format, display_format=display_format,
derivation_type=derivation_type, derivation_type=derivation_type,
) ),
expect=m.CardanoNativeScriptHash,
) )
def sign_tx( def sign_tx(
client: "TrezorClient", client: "TrezorClient",
signing_mode: messages.CardanoTxSigningMode, signing_mode: m.CardanoTxSigningMode,
inputs: List[InputWithPath], inputs: List[InputWithPath],
outputs: List[OutputWithData], outputs: List[OutputWithData],
fee: int, fee: int,
ttl: Optional[int], ttl: Optional[int],
validity_interval_start: Optional[int], validity_interval_start: Optional[int],
certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (), certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (),
withdrawals: Sequence[messages.CardanoTxWithdrawal] = (), withdrawals: Sequence[m.CardanoTxWithdrawal] = (),
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None, auxiliary_data: Optional[m.CardanoTxAuxiliaryData] = None,
mint: Sequence[AssetGroupWithTokens] = (), mint: Sequence[AssetGroupWithTokens] = (),
script_data_hash: Optional[bytes] = None, script_data_hash: Optional[bytes] = None,
collateral_inputs: Sequence[CollateralInputWithPath] = (), collateral_inputs: Sequence[CollateralInputWithPath] = (),
required_signers: Sequence[messages.CardanoTxRequiredSigner] = (), required_signers: Sequence[m.CardanoTxRequiredSigner] = (),
collateral_return: Optional[OutputWithData] = None, collateral_return: Optional[OutputWithData] = None,
total_collateral: Optional[int] = None, total_collateral: Optional[int] = None,
reference_inputs: Sequence[messages.CardanoTxReferenceInput] = (), reference_inputs: Sequence[m.CardanoTxReferenceInput] = (),
additional_witness_requests: Sequence[Path] = (), additional_witness_requests: Sequence[Path] = (),
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
include_network_id: bool = False, include_network_id: bool = False,
chunkify: bool = False, chunkify: bool = False,
tag_cbor_sets: bool = False, tag_cbor_sets: bool = False,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
witness_requests = _get_witness_requests( witness_requests = _get_witness_requests(
inputs, inputs,
certificates, certificates,
@ -916,7 +908,7 @@ def sign_tx(
) )
response = client.call( response = client.call(
messages.CardanoSignTxInit( m.CardanoSignTxInit(
signing_mode=signing_mode, signing_mode=signing_mode,
inputs_count=len(inputs), inputs_count=len(inputs),
outputs_count=len(outputs), outputs_count=len(outputs),
@ -940,10 +932,9 @@ def sign_tx(
include_network_id=include_network_id, include_network_id=include_network_id,
chunkify=chunkify, chunkify=chunkify,
tag_cbor_sets=tag_cbor_sets, tag_cbor_sets=tag_cbor_sets,
) ),
expect=m.CardanoTxItemAck,
) )
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for tx_item in chain( for tx_item in chain(
_get_inputs_items(inputs), _get_inputs_items(inputs),
@ -951,55 +942,41 @@ def sign_tx(
_get_certificates_items(certificates), _get_certificates_items(certificates),
withdrawals, withdrawals,
): ):
response = client.call(tx_item) response = client.call(tx_item, expect=m.CardanoTxItemAck)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response: Dict[str, Any] = {} sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None: if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data) auxiliary_data_supplement = client.call(
if not isinstance( auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement )
):
raise UNEXPECTED_RESPONSE_ERROR
if ( if (
auxiliary_data_supplement.type auxiliary_data_supplement.type
!= messages.CardanoTxAuxiliaryDataSupplementType.NONE != m.CardanoTxAuxiliaryDataSupplementType.NONE
): ):
sign_tx_response["auxiliary_data_supplement"] = ( sign_tx_response["auxiliary_data_supplement"] = (
auxiliary_data_supplement.__dict__ auxiliary_data_supplement.__dict__
) )
response = client.call(messages.CardanoTxHostAck()) response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for tx_item in chain( for tx_item in chain(
_get_mint_items(mint), _get_mint_items(mint),
_get_collateral_inputs_items(collateral_inputs), _get_collateral_inputs_items(collateral_inputs),
required_signers, required_signers,
): ):
response = client.call(tx_item) response = client.call(tx_item, expect=m.CardanoTxItemAck)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
if collateral_return is not None: if collateral_return is not None:
for tx_item in _get_output_items(collateral_return): for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item) response = client.call(tx_item, expect=m.CardanoTxItemAck)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for reference_input in reference_inputs: for reference_input in reference_inputs:
response = client.call(reference_input) response = client.call(reference_input, expect=m.CardanoTxItemAck)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = [] sign_tx_response["witnesses"] = []
for witness_request in witness_requests: for witness_request in witness_requests:
response = client.call(witness_request) response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append( sign_tx_response["witnesses"].append(
{ {
"type": response.type, "type": response.type,
@ -1009,13 +986,9 @@ def sign_tx(
} }
) )
response = client.call(messages.CardanoTxHostAck()) response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck()) response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR
return sign_tx_response return sign_tx_response

View File

@ -107,16 +107,16 @@ def record_screen_from_connection(
@cli.command() @cli.command()
@with_client @with_client
def prodtest_t1(client: "TrezorClient") -> str: def prodtest_t1(client: "TrezorClient") -> None:
"""Perform a prodtest on Model One. """Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test. Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
""" """
return debuglink_prodtest_t1(client) debuglink_prodtest_t1(client)
@cli.command() @cli.command()
@with_client @with_client
def optiga_set_sec_max(client: "TrezorClient") -> str: def optiga_set_sec_max(client: "TrezorClient") -> None:
"""Set Optiga's security event counter to maximum.""" """Set Optiga's security event counter to maximum."""
return debuglink_optiga_set_sec_max(client) debuglink_optiga_set_sec_max(client)

View File

@ -29,7 +29,6 @@ from . import ChoiceType, with_client
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from ..client import TrezorClient from ..client import TrezorClient
from ..protobuf import MessageType
from . import TrezorConnection from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = { RECOVERY_DEVICE_INPUT_METHOD = {
@ -66,7 +65,7 @@ def cli() -> None:
is_flag=True, is_flag=True,
) )
@with_client @with_client
def wipe(client: "TrezorClient", bootloader: bool) -> str: def wipe(client: "TrezorClient", bootloader: bool) -> None:
"""Reset device to factory defaults and remove all private data.""" """Reset device to factory defaults and remove all private data."""
if bootloader: if bootloader:
if not client.features.bootloader_mode: if not client.features.bootloader_mode:
@ -87,11 +86,7 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
else: else:
click.echo("Wiping user data!") click.echo("Wiping user data!")
try: device.wipe(client)
return device.wipe(client)
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
@cli.command() @cli.command()
@ -116,7 +111,7 @@ def load(
academic: bool, academic: bool,
needs_backup: bool, needs_backup: bool,
no_backup: bool, no_backup: bool,
) -> str: ) -> None:
"""Upload seed and custom configuration to the device. """Upload seed and custom configuration to the device.
This functionality is only available in debug mode. This functionality is only available in debug mode.
@ -136,7 +131,7 @@ def load(
label = "ACADEMIC" label = "ACADEMIC"
try: try:
return debuglink.load_device( debuglink.load_device(
client, client,
mnemonic=list(mnemonic), mnemonic=list(mnemonic),
pin=pin, pin=pin,
@ -184,7 +179,7 @@ def recover(
input_method: messages.RecoveryDeviceInputMethod, input_method: messages.RecoveryDeviceInputMethod,
dry_run: bool, dry_run: bool,
unlock_repeated_backup: bool, unlock_repeated_backup: bool,
) -> "MessageType": ) -> None:
"""Start safe recovery workflow.""" """Start safe recovery workflow."""
if input_method == messages.RecoveryDeviceInputMethod.ScrambledWords: if input_method == messages.RecoveryDeviceInputMethod.ScrambledWords:
input_callback = ui.mnemonic_words(expand) input_callback = ui.mnemonic_words(expand)
@ -201,7 +196,7 @@ def recover(
if unlock_repeated_backup: if unlock_repeated_backup:
type = messages.RecoveryType.UnlockRepeatedBackup type = messages.RecoveryType.UnlockRepeatedBackup
return device.recover( device.recover(
client, client,
word_count=int(words), word_count=int(words),
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
@ -236,21 +231,13 @@ def setup(
no_backup: bool, no_backup: bool,
backup_type: messages.BackupType | None, backup_type: messages.BackupType | None,
entropy_check_count: int | None, entropy_check_count: int | None,
) -> str: ) -> None:
"""Perform device setup and generate new seed.""" """Perform device setup and generate new seed."""
if strength: if strength:
strength = int(strength) strength = int(strength)
BT = messages.BackupType BT = messages.BackupType
if backup_type is None:
if client.version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable
else:
# this includes both T1 and older trezor-cores
backup_type = BT.Bip39
if ( if (
backup_type backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable) in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
@ -264,7 +251,7 @@ def setup(
"backup type. Traditional BIP39 backup may be generated instead." "backup type. Traditional BIP39 backup may be generated instead."
) )
resp, path_xpubs = device.reset_entropy_check( path_xpubs = device.setup(
client, client,
strength=strength, strength=strength,
passphrase_protection=passphrase_protection, passphrase_protection=passphrase_protection,
@ -277,13 +264,10 @@ def setup(
entropy_check_count=entropy_check_count, entropy_check_count=entropy_check_count,
) )
if isinstance(resp, messages.Success): if path_xpubs:
click.echo("XPUBs for the generated seed") click.echo("XPUBs for the generated seed")
for path, xpub in path_xpubs: for path, xpub in path_xpubs:
click.echo(f"{format_path(path)}: {xpub}") click.echo(f"{format_path(path)}: {xpub}")
return resp.message or ""
else:
raise RuntimeError(f"Received {resp.__class__}")
@cli.command() @cli.command()
@ -294,10 +278,9 @@ def backup(
client: "TrezorClient", client: "TrezorClient",
group_threshold: int | None = None, group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (), groups: t.Sequence[tuple[int, int]] = (),
) -> str: ) -> None:
"""Perform device seed backup.""" """Perform device seed backup."""
device.backup(client, group_threshold, groups)
return device.backup(client, group_threshold, groups)
@cli.command() @cli.command()
@ -305,7 +288,7 @@ def backup(
@with_client @with_client
def sd_protect( def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str: ) -> None:
"""Secure the device with SD card protection. """Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored When SD card protection is enabled, a randomly generated secret is stored
@ -321,12 +304,12 @@ def sd_protect(
""" """
if client.features.model == "1": if client.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.") raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation) device.sd_protect(client, operation)
@cli.command() @cli.command()
@click.pass_obj @click.pass_obj
def reboot_to_bootloader(obj: "TrezorConnection") -> str: def reboot_to_bootloader(obj: "TrezorConnection") -> None:
"""Reboot device into bootloader mode. """Reboot device into bootloader mode.
Currently only supported on Trezor Model One. Currently only supported on Trezor Model One.
@ -334,21 +317,21 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str:
# avoid using @with_client because it closes the session afterwards, # avoid using @with_client because it closes the session afterwards,
# which triggers double prompt on device # which triggers double prompt on device
with obj.client_context() as client: with obj.client_context() as client:
return device.reboot_to_bootloader(client) device.reboot_to_bootloader(client)
@cli.command() @cli.command()
@with_client @with_client
def tutorial(client: "TrezorClient") -> str: def tutorial(client: "TrezorClient") -> None:
"""Show on-device tutorial.""" """Show on-device tutorial."""
return device.show_device_tutorial(client) device.show_device_tutorial(client)
@cli.command() @cli.command()
@with_client @with_client
def unlock_bootloader(client: "TrezorClient") -> str: def unlock_bootloader(client: "TrezorClient") -> None:
"""Unlocks bootloader. Irreversible.""" """Unlocks bootloader. Irreversible."""
return device.unlock_bootloader(client) device.unlock_bootloader(client)
@cli.command() @cli.command()
@ -360,10 +343,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
help="Dialog expiry in seconds.", help="Dialog expiry in seconds.",
) )
@with_client @with_client
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str: def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None:
"""Show a "Do not disconnect" dialog.""" """Show a "Do not disconnect" dialog."""
if enable is False: if enable is False:
return device.set_busy(client, None) device.set_busy(client, None)
return
if expiry is None: if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.") raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -373,7 +357,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer." f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
) )
return device.set_busy(client, expiry * 1000) device.set_busy(client, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = ( PUBKEY_WHITELIST_URL_TEMPLATE = (

View File

@ -80,12 +80,12 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add") @credentials.command(name="add")
@click.argument("hex_credential_id") @click.argument("hex_credential_id")
@with_client @with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str: def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None:
"""Add the credential with the given ID as a resident credential. """Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string. HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
""" """
return fido.add_credential(client, bytes.fromhex(hex_credential_id)) fido.add_credential(client, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove") @credentials.command(name="remove")
@ -93,9 +93,9 @@ def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index." "-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
) )
@with_client @with_client
def credentials_remove(client: "TrezorClient", index: int) -> str: def credentials_remove(client: "TrezorClient", index: int) -> None:
"""Remove the resident credential at the given index.""" """Remove the resident credential at the given index."""
return fido.remove_credential(client, index) fido.remove_credential(client, index)
# #
@ -111,9 +111,9 @@ def counter() -> None:
@counter.command(name="set") @counter.command(name="set")
@click.argument("counter", type=int) @click.argument("counter", type=int)
@with_client @with_client
def counter_set(client: "TrezorClient", counter: int) -> str: def counter_set(client: "TrezorClient", counter: int) -> None:
"""Set FIDO/U2F counter value.""" """Set FIDO/U2F counter value."""
return fido.set_counter(client, counter) fido.set_counter(client, counter)
@counter.command(name="get-next") @counter.command(name="get-next")

View File

@ -181,17 +181,17 @@ def cli() -> None:
@click.option("-r", "--remove", is_flag=True, hidden=True) @click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client @with_client
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
"""Set, change or remove PIN.""" """Set, change or remove PIN."""
# Remove argument is there for backwards compatibility # Remove argument is there for backwards compatibility
return device.change_pin(client, remove=_should_remove(enable, remove)) device.change_pin(client, remove=_should_remove(enable, remove))
@cli.command() @cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True) @click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False) @click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client @with_client
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str: def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
"""Set or remove the wipe code. """Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -199,7 +199,7 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
removed and the device will be reset to factory defaults. removed and the device will be reset to factory defaults.
""" """
# Remove argument is there for backwards compatibility # Remove argument is there for backwards compatibility
return device.change_wipe_code(client, remove=_should_remove(enable, remove)) device.change_wipe_code(client, remove=_should_remove(enable, remove))
@cli.command() @cli.command()
@ -207,24 +207,24 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label") @click.argument("label")
@with_client @with_client
def label(client: "TrezorClient", label: str) -> str: def label(client: "TrezorClient", label: str) -> None:
"""Set new device label.""" """Set new device label."""
return device.apply_settings(client, label=label) device.apply_settings(client, label=label)
@cli.command() @cli.command()
@with_client @with_client
def brightness(client: "TrezorClient") -> str: def brightness(client: "TrezorClient") -> None:
"""Set display brightness.""" """Set display brightness."""
return device.set_brightness(client) device.set_brightness(client)
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_client
def haptic_feedback(client: "TrezorClient", enable: bool) -> str: def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
"""Enable or disable haptic feedback.""" """Enable or disable haptic feedback."""
return device.apply_settings(client, haptic_feedback=enable) device.apply_settings(client, haptic_feedback=enable)
@cli.command() @cli.command()
@ -236,7 +236,7 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
@with_client @with_client
def language( def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
) -> str: ) -> None:
"""Set new language with translations.""" """Set new language with translations."""
if remove != (path_or_url is None): if remove != (path_or_url is None):
raise click.ClickException("Either provide a path or URL or use --remove") raise click.ClickException("Either provide a path or URL or use --remove")
@ -259,27 +259,27 @@ def language(
raise click.ClickException( raise click.ClickException(
f"Failed to load translations from {path_or_url}" f"Failed to load translations from {path_or_url}"
) from None ) from None
return device.change_language( device.change_language(client, language_data=language_data, show_display=display)
client, language_data=language_data, show_display=display
)
@cli.command() @cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION)) @click.argument("rotation", type=ChoiceType(ROTATION))
@with_client @with_client
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str: def display_rotation(
client: "TrezorClient", rotation: messages.DisplayRotation
) -> None:
"""Set display rotation. """Set display rotation.
Configure display rotation for Trezor Model T. The options are Configure display rotation for Trezor Model T. The options are
north, east, south or west. north, east, south or west.
""" """
return device.apply_settings(client, display_rotation=rotation) device.apply_settings(client, display_rotation=rotation)
@cli.command() @cli.command()
@click.argument("delay", type=str) @click.argument("delay", type=str)
@with_client @with_client
def auto_lock_delay(client: "TrezorClient", delay: str) -> str: def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
"""Set auto-lock delay (in seconds).""" """Set auto-lock delay (in seconds)."""
if not client.features.pin_protection: if not client.features.pin_protection:
@ -291,13 +291,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
seconds = float(value) * units[unit] seconds = float(value) * units[unit]
else: else:
seconds = float(delay) # assume seconds if no unit is specified seconds = float(delay) # assume seconds if no unit is specified
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000)) device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
@cli.command() @cli.command()
@click.argument("flags") @click.argument("flags")
@with_client @with_client
def flags(client: "TrezorClient", flags: str) -> str: def flags(client: "TrezorClient", flags: str) -> None:
"""Set device flags.""" """Set device flags."""
if flags.lower().startswith("0b"): if flags.lower().startswith("0b"):
flags_int = int(flags, 2) flags_int = int(flags, 2)
@ -305,7 +305,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
flags_int = int(flags, 16) flags_int = int(flags, 16)
else: else:
flags_int = int(flags) flags_int = int(flags)
return device.apply_flags(client, flags=flags_int) device.apply_flags(client, flags=flags_int)
@cli.command() @cli.command()
@ -315,7 +315,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
) )
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)") @click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client @with_client
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str: def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"""Set new homescreen. """Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default' To revert to default homescreen, use 'trezorctl set homescreen default'
@ -369,7 +369,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"Unknown image format requested by the device." "Unknown image format requested by the device."
) )
return device.apply_settings(client, homescreen=img) device.apply_settings(client, homescreen=img)
@cli.command() @cli.command()
@ -380,7 +380,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
@with_client @with_client
def safety_checks( def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
) -> str: ) -> None:
"""Set safety check level. """Set safety check level.
Set to "strict" to get the full Trezor security (default setting). Set to "strict" to get the full Trezor security (default setting).
@ -392,18 +392,18 @@ def safety_checks(
""" """
if always and level == messages.SafetyCheckLevel.PromptTemporarily: if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways level = messages.SafetyCheckLevel.PromptAlways
return device.apply_settings(client, safety_checks=level) device.apply_settings(client, safety_checks=level)
@cli.command() @cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False})) @click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client @with_client
def experimental_features(client: "TrezorClient", enable: bool) -> str: def experimental_features(client: "TrezorClient", enable: bool) -> None:
"""Enable or disable experimental message types. """Enable or disable experimental message types.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
""" """
return device.apply_settings(client, experimental_features=enable) device.apply_settings(client, experimental_features=enable)
# #
@ -427,13 +427,13 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on") @passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None) @click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client @with_client
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str: def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None:
"""Enable passphrase.""" """Enable passphrase."""
if client.features.passphrase_protection is not True: if client.features.passphrase_protection is not True:
use_passphrase = True use_passphrase = True
else: else:
use_passphrase = None use_passphrase = None
return device.apply_settings( device.apply_settings(
client, client,
use_passphrase=use_passphrase, use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device, passphrase_always_on_device=force_on_device,
@ -442,9 +442,9 @@ def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> st
@passphrase.command(name="off") @passphrase.command(name="off")
@with_client @with_client
def passphrase_off(client: "TrezorClient") -> str: def passphrase_off(client: "TrezorClient") -> None:
"""Disable passphrase.""" """Disable passphrase."""
return device.apply_settings(client, use_passphrase=False) device.apply_settings(client, use_passphrase=False)
# Registering the aliases for backwards compatibility # Registering the aliases for backwards compatibility
@ -458,9 +458,9 @@ passphrase.aliases = {
@passphrase.command(name="hide") @passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False})) @click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client @with_client
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str: def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None:
"""Enable or disable hiding passphrase coming from host. """Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution. This is a developer feature. Use with caution.
""" """
return device.apply_settings(client, hide_passphrase_from_host=hide) device.apply_settings(client, hide_passphrase_from_host=hide)

View File

@ -26,7 +26,7 @@ def get_public_key(
client: "TrezorClient", client: "TrezorClient",
address: str, address: str,
show_display: bool, show_display: bool,
) -> messages.SolanaPublicKey: ) -> bytes:
"""Get Solana public key.""" """Get Solana public key."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display) return solana.get_public_key(client, address_n, show_display)
@ -42,7 +42,7 @@ def get_address(
address: str, address: str,
show_display: bool, show_display: bool,
chunkify: bool, chunkify: bool,
) -> messages.SolanaAddress: ) -> str:
"""Get Solana address.""" """Get Solana address."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify) return solana.get_address(client, address_n, show_display, chunkify)
@ -58,7 +58,7 @@ def sign_tx(
address: str, address: str,
serialized_tx: str, serialized_tx: str,
additional_information_file: Optional[TextIO], additional_information_file: Optional[TextIO],
) -> messages.SolanaTxSignature: ) -> bytes:
"""Sign Solana transaction.""" """Sign Solana transaction."""
address_n = tools.parse_path(address) address_n = tools.parse_path(address)

View File

@ -27,7 +27,7 @@ from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import Capability from .messages import Capability
from .protobuf import MessageType from .protobuf import MessageType
from .tools import expect, parse_path, session from .tools import parse_path, session
if TYPE_CHECKING: if TYPE_CHECKING:
from .transport import Transport from .transport import Transport
@ -397,12 +397,7 @@ class TrezorClient(Generic[UI]):
else: else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR) raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
@expect(messages.Success, field="message", ret_type=str) def ping(self, msg: str, button_protection: bool = False) -> str:
def ping(
self,
msg: str,
button_protection: bool = False,
) -> MessageType:
# We would like ping to work on any valid TrezorClient instance, but # We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will # due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old. # raise an exception if the firmware is too old.
@ -416,13 +411,18 @@ class TrezorClient(Generic[UI]):
# device is PIN-locked. # device is PIN-locked.
# respond and hope for the best # respond and hope for the best
resp = self._callback_button(resp) resp = self._callback_button(resp)
return resp resp = messages.Success.ensure_isinstance(resp)
assert resp.message is not None
return resp.message
finally: finally:
self.close() self.close()
return self.call( resp = self.call(
messages.Ping(message=msg, button_protection=button_protection) messages.Ping(message=msg, button_protection=button_protection),
expect=messages.Success,
) )
assert resp.message is not None
return resp.message
def get_device_id(self) -> Optional[str]: def get_device_id(self) -> Optional[str]:
return self.features.device_id return self.features.device_id

View File

@ -44,10 +44,9 @@ from mnemonic import Mnemonic
from . import mapping, messages, models, protobuf from . import mapping, messages, models, protobuf
from .client import TrezorClient from .client import TrezorClient
from .exceptions import TrezorFailure from .exceptions import TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES from .log import DUMP_BYTES
from .messages import DebugWaitType from .messages import DebugWaitType
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import Protocol from typing_extensions import Protocol
@ -775,9 +774,10 @@ class DebugLink:
else: else:
self.t1_take_screenshots = False self.t1_take_screenshots = False
@expect(messages.DebugLinkMemory, field="memory", ret_type=bytes) def memory_read(self, address: int, length: int) -> bytes:
def memory_read(self, address: int, length: int) -> protobuf.MessageType: return self._call(
return self._call(messages.DebugLinkMemoryRead(address=address, length=length)) messages.DebugLinkMemoryRead(address=address, length=length)
).memory
def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None: def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
self._write( self._write(
@ -787,9 +787,11 @@ class DebugLink:
def flash_erase(self, sector: int) -> None: def flash_erase(self, sector: int) -> None:
self._write(messages.DebugLinkFlashErase(sector=sector)) self._write(messages.DebugLinkFlashErase(sector=sector))
@expect(messages.Success)
def erase_sd_card(self, format: bool = True) -> messages.Success: def erase_sd_card(self, format: bool = True) -> messages.Success:
return self._call(messages.DebugLinkEraseSdCard(format=format)) res = self._call(messages.DebugLinkEraseSdCard(format=format))
if not isinstance(res, messages.Success):
raise UnexpectedMessageError(messages.Success, res)
return res
def snapshot_legacy(self) -> None: def snapshot_legacy(self) -> None:
"""Snapshot the current state of the device.""" """Snapshot the current state of the device."""
@ -1350,7 +1352,6 @@ class TrezorClientDebugLink(TrezorClient):
raise RuntimeError("Unexpected call") raise RuntimeError("Unexpected call")
@expect(messages.Success, field="message", ret_type=str)
def load_device( def load_device(
client: "TrezorClient", client: "TrezorClient",
mnemonic: Union[str, Iterable[str]], mnemonic: Union[str, Iterable[str]],
@ -1360,7 +1361,7 @@ def load_device(
skip_checksum: bool = False, skip_checksum: bool = False,
needs_backup: bool = False, needs_backup: bool = False,
no_backup: bool = False, no_backup: bool = False,
) -> protobuf.MessageType: ) -> None:
if isinstance(mnemonic, str): if isinstance(mnemonic, str):
mnemonic = [mnemonic] mnemonic = [mnemonic]
@ -1371,7 +1372,7 @@ def load_device(
"Device is initialized already. Call device.wipe() and try again." "Device is initialized already. Call device.wipe() and try again."
) )
resp = client.call( client.call(
messages.LoadDevice( messages.LoadDevice(
mnemonics=mnemonics, mnemonics=mnemonics,
pin=pin, pin=pin,
@ -1380,25 +1381,25 @@ def load_device(
skip_checksum=skip_checksum, skip_checksum=skip_checksum,
needs_backup=needs_backup, needs_backup=needs_backup,
no_backup=no_backup, no_backup=no_backup,
) ),
expect=messages.Success,
) )
client.init_device() client.init_device()
return resp
# keep the old name for compatibility # keep the old name for compatibility
load_device_by_mnemonic = load_device load_device_by_mnemonic = load_device
@expect(messages.Success, field="message", ret_type=str) def prodtest_t1(client: "TrezorClient") -> None:
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
if client.features.bootloader_mode is not True: if client.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode") raise RuntimeError("Device must be in bootloader mode")
return client.call( client.call(
messages.ProdTestT1( messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC" payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
) ),
expect=messages.Success,
) )
@ -1450,6 +1451,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
return debug_client.features.fw_vendor == "EMULATOR" return debug_client.features.fw_vendor == "EMULATOR"
@expect(messages.Success, field="message", ret_type=str) def optiga_set_sec_max(client: "TrezorClient") -> None:
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType: client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)
return client.call(messages.DebugLinkOptigaSetSecMax())

View File

@ -18,7 +18,6 @@ from __future__ import annotations
import hashlib import hashlib
import hmac import hmac
import os
import random import random
import secrets import secrets
import time import time
@ -29,11 +28,16 @@ from slip10 import SLIP10
from . import messages from . import messages
from .exceptions import Cancelled, TrezorException from .exceptions import Cancelled, TrezorException
from .tools import Address, expect, parse_path, session from .tools import (
Address,
_deprecation_retval_helper,
_return_success,
parse_path,
session,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
RECOVERY_BACK = "\x08" # backspace character, sent literally RECOVERY_BACK = "\x08" # backspace character, sent literally
@ -42,7 +46,6 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7) ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
@expect(messages.Success, field="message", ret_type=str)
@session @session
def apply_settings( def apply_settings(
client: "TrezorClient", client: "TrezorClient",
@ -57,7 +60,7 @@ def apply_settings(
experimental_features: Optional[bool] = None, experimental_features: Optional[bool] = None,
hide_passphrase_from_host: Optional[bool] = None, hide_passphrase_from_host: Optional[bool] = None,
haptic_feedback: Optional[bool] = None, haptic_feedback: Optional[bool] = None,
) -> "MessageType": ) -> str | None:
if language is not None: if language is not None:
warnings.warn( warnings.warn(
"language ignored. Use change_language() to set device language.", "language ignored. Use change_language() to set device language.",
@ -76,87 +79,80 @@ def apply_settings(
haptic_feedback=haptic_feedback, haptic_feedback=haptic_feedback,
) )
out = client.call(settings) out = client.call(settings, expect=messages.Success)
client.refresh_features() client.refresh_features()
return out return _return_success(out)
def _send_language_data( def _send_language_data(
client: "TrezorClient", client: "TrezorClient",
request: "messages.TranslationDataRequest", request: "messages.TranslationDataRequest",
language_data: bytes, language_data: bytes,
) -> "MessageType": ) -> None:
response: MessageType = request response = request
while not isinstance(response, messages.Success): while not isinstance(response, messages.Success):
assert isinstance(response, messages.TranslationDataRequest) response = messages.TranslationDataRequest.ensure_isinstance(response)
data_length = response.data_length data_length = response.data_length
data_offset = response.data_offset data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length] chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk)) response = client.call(messages.TranslationDataAck(data_chunk=chunk))
return response
@expect(messages.Success, field="message", ret_type=str)
@session @session
def change_language( def change_language(
client: "TrezorClient", client: "TrezorClient",
language_data: bytes, language_data: bytes,
show_display: bool | None = None, show_display: bool | None = None,
) -> "MessageType": ) -> str | None:
data_length = len(language_data) data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display) msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg) response = client.call(msg)
if data_length > 0: if data_length > 0:
assert isinstance(response, messages.TranslationDataRequest) response = messages.TranslationDataRequest.ensure_isinstance(response)
response = _send_language_data(client, response, language_data) _send_language_data(client, response, language_data)
assert isinstance(response, messages.Success) else:
messages.Success.ensure_isinstance(response)
client.refresh_features() # changing the language in features client.refresh_features() # changing the language in features
return response return _return_success(messages.Success(message="Language changed."))
@expect(messages.Success, field="message", ret_type=str)
@session @session
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType": def apply_flags(client: "TrezorClient", flags: int) -> str | None:
out = client.call(messages.ApplyFlags(flags=flags)) out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
client.refresh_features() client.refresh_features()
return out return _return_success(out)
@expect(messages.Success, field="message", ret_type=str)
@session @session
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType": def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
ret = client.call(messages.ChangePin(remove=remove)) ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
client.refresh_features() client.refresh_features()
return ret return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session @session
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType": def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
ret = client.call(messages.ChangeWipeCode(remove=remove)) ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
client.refresh_features() client.refresh_features()
return ret return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session @session
def sd_protect( def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType client: "TrezorClient", operation: messages.SdProtectOperationType
) -> "MessageType": ) -> str | None:
ret = client.call(messages.SdProtect(operation=operation)) ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
client.refresh_features() client.refresh_features()
return ret return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session @session
def wipe(client: "TrezorClient") -> "MessageType": def wipe(client: "TrezorClient") -> str | None:
ret = client.call(messages.WipeDevice()) ret = client.call(messages.WipeDevice(), expect=messages.Success)
if not client.features.bootloader_mode: if not client.features.bootloader_mode:
client.init_device() client.init_device()
return ret return _return_success(ret)
@session @session
@ -173,7 +169,7 @@ def recover(
u2f_counter: Optional[int] = None, u2f_counter: Optional[int] = None,
*, *,
type: Optional[messages.RecoveryType] = None, type: Optional[messages.RecoveryType] = None,
) -> "MessageType": ) -> messages.Success | None:
if language is not None: if language is not None:
warnings.warn( warnings.warn(
"language ignored. Use change_language() to set device language.", "language ignored. Use change_language() to set device language.",
@ -235,8 +231,12 @@ def recover(
except Cancelled: except Cancelled:
res = client.call(messages.Cancel()) res = client.call(messages.Cancel())
# check that the result is a Success
res = messages.Success.ensure_isinstance(res)
# reinitialize the device
client.init_device() client.init_device()
return res
return _deprecation_retval_helper(res)
def is_slip39_backup_type(backup_type: messages.BackupType): def is_slip39_backup_type(backup_type: messages.BackupType):
@ -279,13 +279,7 @@ def _seed_from_entropy(
return seed return seed
@expect(messages.Success, field="message", ret_type=str) def reset(
def reset(*args: Any, **kwargs: Any) -> "MessageType":
return reset_entropy_check(*args, **kwargs)[0]
@session
def reset_entropy_check(
client: "TrezorClient", client: "TrezorClient",
display_random: bool = False, display_random: bool = False,
strength: Optional[int] = None, strength: Optional[int] = None,
@ -576,13 +570,12 @@ def _reset_with_entropycheck(
return xpubs return xpubs
@expect(messages.Success, field="message", ret_type=str)
@session @session
def backup( def backup(
client: "TrezorClient", client: "TrezorClient",
group_threshold: Optional[int] = None, group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (), groups: Iterable[tuple[int, int]] = (),
) -> "MessageType": ) -> str | None:
ret = client.call( ret = client.call(
messages.BackupDevice( messages.BackupDevice(
group_threshold=group_threshold, group_threshold=group_threshold,
@ -590,38 +583,39 @@ def backup(
messages.Slip39Group(member_threshold=t, member_count=c) messages.Slip39Group(member_threshold=t, member_count=c)
for t, c in groups for t, c in groups
], ],
) ),
expect=messages.Success,
) )
client.refresh_features() client.refresh_features()
return ret return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str) def cancel_authorization(client: "TrezorClient") -> str | None:
def cancel_authorization(client: "TrezorClient") -> "MessageType": ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
return client.call(messages.CancelAuthorization()) return _return_success(ret)
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes) def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType": resp = client.call(
resp = client.call(messages.UnlockPath(address_n=n)) messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
)
# Cancel the UnlockPath workflow now that we have the authentication code. # Cancel the UnlockPath workflow now that we have the authentication code.
try: try:
client.call(messages.Cancel()) client.call(messages.Cancel())
except Cancelled: except Cancelled:
return resp return resp.mac
else: else:
raise TrezorException("Unexpected response in UnlockPath flow") raise TrezorException("Unexpected response in UnlockPath flow")
@session @session
@expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader( def reboot_to_bootloader(
client: "TrezorClient", client: "TrezorClient",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT, boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None, firmware_header: Optional[bytes] = None,
language_data: bytes = b"", language_data: bytes = b"",
) -> "MessageType": ) -> str | None:
response = client.call( response = client.call(
messages.RebootToBootloader( messages.RebootToBootloader(
boot_command=boot_command, boot_command=boot_command,
@ -631,41 +625,42 @@ def reboot_to_bootloader(
) )
if isinstance(response, messages.TranslationDataRequest): if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data) response = _send_language_data(client, response, language_data)
return response return _return_success(messages.Success(message=""))
@session @session
@expect(messages.Success, field="message", ret_type=str) def show_device_tutorial(client: "TrezorClient") -> str | None:
def show_device_tutorial(client: "TrezorClient") -> "MessageType": ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
return client.call(messages.ShowDeviceTutorial()) return _return_success(ret)
@session @session
@expect(messages.Success, field="message", ret_type=str) def unlock_bootloader(client: "TrezorClient") -> str | None:
def unlock_bootloader(client: "TrezorClient") -> "MessageType": ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
return client.call(messages.UnlockBootloader()) return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session @session
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType": def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
"""Sets or clears the busy state of the device. """Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen. In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state. Setting `expiry_ms=None` clears the busy state.
""" """
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms)) ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
client.refresh_features() client.refresh_features()
return ret return _return_success(ret)
@expect(messages.AuthenticityProof) def authenticate(
def authenticate(client: "TrezorClient", challenge: bytes): client: "TrezorClient", challenge: bytes
return client.call(messages.AuthenticateDevice(challenge=challenge)) ) -> messages.AuthenticityProof:
return client.call(
messages.AuthenticateDevice(challenge=challenge),
expect=messages.AuthenticityProof,
)
@expect(messages.Success, field="message", ret_type=str) def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
def set_brightness( ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
client: "TrezorClient", value: Optional[int] = None return _return_success(ret)
) -> "MessageType":
return client.call(messages.SetBrightness(value=value))

View File

@ -18,11 +18,10 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages from . import exceptions, messages
from .tools import b58decode, expect, session from .tools import b58decode, session
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
@ -319,14 +318,13 @@ def parse_transaction_json(
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.EosPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType": ) -> messages.EosPublicKey:
response = client.call( return client.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display) messages.EosGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EosPublicKey,
) )
return response
@session @session

View File

@ -18,11 +18,10 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages from . import definitions, exceptions, messages
from .tools import expect, prepare_message_bytes, session, unharden from .tools import prepare_message_bytes, session, unharden
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
@ -161,30 +160,32 @@ def network_from_address_n(
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( resp = client.call(
messages.EthereumGetAddress( messages.EthereumGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
encoded_network=encoded_network, encoded_network=encoded_network,
chunkify=chunkify, chunkify=chunkify,
) ),
expect=messages.EthereumAddress,
) )
assert resp.address is not None
return resp.address
@expect(messages.EthereumPublicKey)
def get_public_node( def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType": ) -> messages.EthereumPublicKey:
return client.call( return client.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display) messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EthereumPublicKey,
) )
@ -297,25 +298,24 @@ def sign_tx_eip1559(
return response.signature_v, response.signature_r, response.signature_s return response.signature_v, response.signature_r, response.signature_s
@expect(messages.EthereumMessageSignature)
def sign_message( def sign_message(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
message: AnyStr, message: AnyStr,
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> messages.EthereumMessageSignature:
return client.call( return client.call(
messages.EthereumSignMessage( messages.EthereumSignMessage(
address_n=n, address_n=n,
message=prepare_message_bytes(message), message=prepare_message_bytes(message),
encoded_network=encoded_network, encoded_network=encoded_network,
chunkify=chunkify, chunkify=chunkify,
) ),
expect=messages.EthereumMessageSignature,
) )
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data( def sign_typed_data(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
@ -323,7 +323,7 @@ def sign_typed_data(
*, *,
metamask_v4_compat: bool = True, metamask_v4_compat: bool = True,
definitions: Optional[messages.EthereumDefinitions] = None, definitions: Optional[messages.EthereumDefinitions] = None,
) -> "MessageType": ) -> messages.EthereumTypedDataSignature:
data = sanitize_typed_data(data) data = sanitize_typed_data(data)
types = data["types"] types = data["types"]
@ -387,7 +387,7 @@ def sign_typed_data(
request = messages.EthereumTypedDataValueAck(value=encoded_data) request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request) response = client.call(request)
return response return messages.EthereumTypedDataSignature.ensure_isinstance(response)
def verify_message( def verify_message(
@ -398,32 +398,33 @@ def verify_message(
chunkify: bool = False, chunkify: bool = False,
) -> bool: ) -> bool:
try: try:
resp = client.call( client.call(
messages.EthereumVerifyMessage( messages.EthereumVerifyMessage(
address=address, address=address,
signature=signature, signature=signature,
message=prepare_message_bytes(message), message=prepare_message_bytes(message),
chunkify=chunkify, chunkify=chunkify,
) ),
expect=messages.Success,
) )
return True
except exceptions.TrezorFailure: except exceptions.TrezorFailure:
return False return False
return isinstance(resp, messages.Success)
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data_hash( def sign_typed_data_hash(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
domain_hash: bytes, domain_hash: bytes,
message_hash: Optional[bytes], message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None, encoded_network: Optional[bytes] = None,
) -> "MessageType": ) -> messages.EthereumTypedDataSignature:
return client.call( return client.call(
messages.EthereumSignTypedHash( messages.EthereumSignTypedHash(
address_n=n, address_n=n,
domain_separator_hash=domain_hash, domain_separator_hash=domain_hash,
message_hash=message_hash, message_hash=message_hash,
encoded_network=encoded_network, encoded_network=encoded_network,
) ),
expect=messages.EthereumTypedDataSignature,
) )

View File

@ -14,42 +14,45 @@
# You should have received a copy of the License along with this library. # You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, List from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from . import messages from . import messages
from .tools import expect from .tools import _return_success
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
@expect( def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
messages.WebAuthnCredentials,
field="credentials",
ret_type=List[messages.WebAuthnCredential],
)
def list_credentials(client: "TrezorClient") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
return client.call( return client.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id) messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
).credentials
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
ret = client.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
expect=messages.Success,
) )
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str) def remove_credential(client: "TrezorClient", index: int) -> str | None:
def remove_credential(client: "TrezorClient", index: int) -> "MessageType": ret = client.call(
return client.call(messages.WebAuthnRemoveResidentCredential(index=index)) messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
)
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str) def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType": ret = client.call(
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter)) messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
)
return _return_success(ret)
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int) def get_next_counter(client: "TrezorClient") -> int:
def get_next_counter(client: "TrezorClient") -> "MessageType": ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
return client.call(messages.GetNextU2FCounter()) return ret.u2f_counter

View File

@ -20,7 +20,7 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard from typing_extensions import Protocol, TypeGuard
from .. import messages from .. import messages
from ..tools import expect, session from ..tools import session
from .core import VendorFirmware from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware from .legacy import LegacyFirmware, LegacyV2Firmware
@ -106,6 +106,7 @@ def update(
raise RuntimeError(f"Unexpected message {resp}") raise RuntimeError(f"Unexpected message {resp}")
@expect(messages.FirmwareHash, field="hash", ret_type=bytes) def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes:
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]): return client.call(
return client.call(messages.GetFirmwareHash(challenge=challenge)) messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
).hash

View File

@ -17,54 +17,50 @@
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from . import messages from . import messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
@expect(messages.Entropy, field="entropy", ret_type=bytes) def get_entropy(client: "TrezorClient", size: int) -> bytes:
def get_entropy(client: "TrezorClient", size: int) -> "MessageType": return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
return client.call(messages.GetEntropy(size=size))
@expect(messages.SignedIdentity)
def sign_identity( def sign_identity(
client: "TrezorClient", client: "TrezorClient",
identity: messages.IdentityType, identity: messages.IdentityType,
challenge_hidden: bytes, challenge_hidden: bytes,
challenge_visual: str, challenge_visual: str,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> "MessageType": ) -> messages.SignedIdentity:
return client.call( return client.call(
messages.SignIdentity( messages.SignIdentity(
identity=identity, identity=identity,
challenge_hidden=challenge_hidden, challenge_hidden=challenge_hidden,
challenge_visual=challenge_visual, challenge_visual=challenge_visual,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
) ),
expect=messages.SignedIdentity,
) )
@expect(messages.ECDHSessionKey)
def get_ecdh_session_key( def get_ecdh_session_key(
client: "TrezorClient", client: "TrezorClient",
identity: messages.IdentityType, identity: messages.IdentityType,
peer_public_key: bytes, peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None, ecdsa_curve_name: Optional[str] = None,
) -> "MessageType": ) -> messages.ECDHSessionKey:
return client.call( return client.call(
messages.GetECDHSessionKey( messages.GetECDHSessionKey(
identity=identity, identity=identity,
peer_public_key=peer_public_key, peer_public_key=peer_public_key,
ecdsa_curve_name=ecdsa_curve_name, ecdsa_curve_name=ecdsa_curve_name,
) ),
expect=messages.ECDHSessionKey,
) )
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue( def encrypt_keyvalue(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
@ -73,7 +69,7 @@ def encrypt_keyvalue(
ask_on_encrypt: bool = True, ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
@ -83,11 +79,11 @@ def encrypt_keyvalue(
ask_on_encrypt=ask_on_encrypt, ask_on_encrypt=ask_on_encrypt,
ask_on_decrypt=ask_on_decrypt, ask_on_decrypt=ask_on_decrypt,
iv=iv, iv=iv,
) ),
) expect=messages.CipheredKeyValue,
).value
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue( def decrypt_keyvalue(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
@ -96,7 +92,7 @@ def decrypt_keyvalue(
ask_on_encrypt: bool = True, ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True, ask_on_decrypt: bool = True,
iv: bytes = b"", iv: bytes = b"",
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.CipherKeyValue( messages.CipherKeyValue(
address_n=n, address_n=n,
@ -106,10 +102,10 @@ def decrypt_keyvalue(
ask_on_encrypt=ask_on_encrypt, ask_on_encrypt=ask_on_encrypt,
ask_on_decrypt=ask_on_decrypt, ask_on_decrypt=ask_on_decrypt,
iv=iv, iv=iv,
) ),
) expect=messages.CipheredKeyValue,
).value
@expect(messages.Nonce, field="nonce", ret_type=bytes) def get_nonce(client: "TrezorClient") -> bytes:
def get_nonce(client: "TrezorClient"): return client.call(messages.GetNonce(), expect=messages.Nonce).nonce
return client.call(messages.GetNonce())

View File

@ -17,11 +17,9 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from . import messages from . import messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
@ -31,30 +29,30 @@ if TYPE_CHECKING:
# FAKECHAIN = 3 # FAKECHAIN = 3
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
show_display: bool = False, show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.MoneroGetAddress( messages.MoneroGetAddress(
address_n=n, address_n=n,
show_display=show_display, show_display=show_display,
network_type=network_type, network_type=network_type,
chunkify=chunkify, chunkify=chunkify,
) ),
) expect=messages.MoneroAddress,
).address
@expect(messages.MoneroWatchKey)
def get_watch_key( def get_watch_key(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET, network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> "MessageType": ) -> messages.MoneroWatchKey:
return client.call( return client.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type) messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
expect=messages.MoneroWatchKey,
) )

View File

@ -18,11 +18,9 @@ import json
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from . import exceptions, messages from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
TYPE_TRANSACTION_TRANSFER = 0x0101 TYPE_TRANSACTION_TRANSFER = 0x0101
@ -196,25 +194,24 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.NEMAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
n: "Address", n: "Address",
network: int, network: int,
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.NEMGetAddress( messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify address_n=n, network=network, show_display=show_display, chunkify=chunkify
) ),
) expect=messages.NEMAddress,
).address
@expect(messages.NEMSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
) -> "MessageType": ) -> messages.NEMSignedTx:
try: try:
msg = create_sign_tx(transaction, chunkify=chunkify) msg = create_sign_tx(transaction, chunkify=chunkify)
except ValueError as e: except ValueError as e:
@ -222,4 +219,4 @@ def sign_tx(
assert msg.transaction is not None assert msg.transaction is not None
msg.transaction.address_n = n msg.transaction.address_n = n
return client.call(msg) return client.call(msg, expect=messages.NEMSignedTx)

View File

@ -35,6 +35,8 @@ from itertools import zip_longest
import typing_extensions as tx import typing_extensions as tx
from .exceptions import UnexpectedMessageError
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from IPython.lib.pretty import RepresentationPrinter # noqa: I900 from IPython.lib.pretty import RepresentationPrinter # noqa: I900
@ -312,6 +314,27 @@ class MessageType:
dump_message(data, self) dump_message(data, self)
return len(data.getvalue()) return len(data.getvalue())
@classmethod
def ensure_isinstance(cls, msg: t.Any) -> tx.Self:
"""Ensure that the received `msg` is an instance of this class.
If `msg` is not an instance of this class, raise an `UnexpectedMessageError`.
otherwise, return it. This is useful for type-checking like so:
>>> msg = client.call(SomeMessage())
>>> if isinstance(msg, Foo):
>>> return msg.foo_attr # attribute of Foo, type-checks OK
>>> else:
>>> msg = Bar.ensure_isinstance(msg) # raises if msg is something else
>>> return msg.bar_attr # attribute of Bar, type-checks OK
If there is just one expected message, you should use the `expect` parameter of
`Client.call` instead.
"""
if not isinstance(msg, cls):
raise UnexpectedMessageError(cls, msg)
return msg
class LimitedReader: class LimitedReader:
def __init__(self, reader: Reader, limit: int) -> None: def __init__(self, reader: Reader, limit: int) -> None:

View File

@ -18,41 +18,39 @@ from typing import TYPE_CHECKING
from . import messages from . import messages
from .protobuf import dict_to_proto from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect from .tools import dict_from_camelcase
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment") REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination") REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.RippleGetAddress( messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) ),
) expect=messages.RippleAddress,
).address
@expect(messages.RippleSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
msg: messages.RippleSignTx, msg: messages.RippleSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> messages.RippleSignedTx:
msg.address_n = address_n msg.address_n = address_n
msg.chunkify = chunkify msg.chunkify = chunkify
return client.call(msg) return client.call(msg, expect=messages.RippleSignedTx)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx: def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -1,51 +1,49 @@
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, List, Optional
from . import messages from . import messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.SolanaPublicKey)
def get_public_key( def get_public_key(
client: "TrezorClient", client: "TrezorClient",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display) messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
) expect=messages.SolanaPublicKey,
).public_key
@expect(messages.SolanaAddress)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
address_n: List[int], address_n: List[int],
show_display: bool, show_display: bool,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.SolanaGetAddress( messages.SolanaGetAddress(
address_n=address_n, address_n=address_n,
show_display=show_display, show_display=show_display,
chunkify=chunkify, chunkify=chunkify,
) ),
) expect=messages.SolanaAddress,
).address
@expect(messages.SolanaTxSignature)
def sign_tx( def sign_tx(
client: "TrezorClient", client: "TrezorClient",
address_n: List[int], address_n: List[int],
serialized_tx: bytes, serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo], additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> "MessageType": ) -> bytes:
return client.call( return client.call(
messages.SolanaSignTx( messages.SolanaSignTx(
address_n=address_n, address_n=address_n,
serialized_tx=serialized_tx, serialized_tx=serialized_tx,
additional_info=additional_info, additional_info=additional_info,
) ),
) expect=messages.SolanaTxSignature,
).signature

View File

@ -18,11 +18,9 @@ from decimal import Decimal
from typing import TYPE_CHECKING, List, Tuple, Union from typing import TYPE_CHECKING, List, Tuple, Union
from . import exceptions, messages from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
StellarMessageType = Union[ StellarMessageType = Union[
@ -323,18 +321,18 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
# ====== Client functions ====== # # ====== Client functions ====== #
@expect(messages.StellarAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.StellarGetAddress( messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) ),
) expect=messages.StellarAddress,
).address
def sign_tx( def sign_tx(
@ -364,10 +362,7 @@ def sign_tx(
"Reached end of operations without a signature." "Reached end of operations without a signature."
) from None ) from None
if not isinstance(resp, messages.StellarSignedTx): resp = messages.StellarSignedTx.ensure_isinstance(resp)
raise exceptions.TrezorException(
f"Unexpected message: {resp.__class__.__name__}"
)
if operations: if operations:
raise exceptions.TrezorException( raise exceptions.TrezorException(

View File

@ -17,49 +17,46 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from . import messages from . import messages
from .tools import expect
if TYPE_CHECKING: if TYPE_CHECKING:
from .client import TrezorClient from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address from .tools import Address
@expect(messages.TezosAddress, field="address", ret_type=str)
def get_address( def get_address(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.TezosGetAddress( messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) ),
) expect=messages.TezosAddress,
).address
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key( def get_public_key(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
show_display: bool = False, show_display: bool = False,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> str:
return client.call( return client.call(
messages.TezosGetPublicKey( messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify address_n=address_n, show_display=show_display, chunkify=chunkify
) ),
) expect=messages.TezosPublicKey,
).public_key
@expect(messages.TezosSignedTx)
def sign_tx( def sign_tx(
client: "TrezorClient", client: "TrezorClient",
address_n: "Address", address_n: "Address",
sign_tx_msg: messages.TezosSignTx, sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False, chunkify: bool = False,
) -> "MessageType": ) -> messages.TezosSignedTx:
sign_tx_msg.address_n = address_n sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg) return client.call(sign_tx_msg, expect=messages.TezosSignedTx)