1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-18 03:10:58 +00:00

refactor(python): replace usages of @expect

This commit is contained in:
matejcik 2025-01-03 16:36:40 +01:00 committed by matejcik
parent 53bdef5bb4
commit 6a5836708f
28 changed files with 518 additions and 584 deletions

View File

@ -0,0 +1 @@
String return values are deprecated in functions where the semantic result is a success (specifically those that were returning the message from Trezor's `Success` response). Type annotations are updated to `str | None`, and in a future release those functions will be returning `None` on success, or raise an exception on a failure.

View File

@ -0,0 +1 @@
Return value of `device.recover()` is deprecated. In the future, this function will return `None`.

View File

@ -0,0 +1 @@
Return values in `solana` module were changed from the wrapping protobuf messages to the raw inner values (`str` for address, `bytes` for pubkey / signature).

View File

@ -0,0 +1 @@
`trezorctl device` commands whose default result is a success will not print anything to stdout anymore, in line with Unix philosophy.

View File

@ -17,20 +17,18 @@
from typing import TYPE_CHECKING
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.BenchmarkNames)
def list_names(
client: "TrezorClient",
) -> "MessageType":
return client.call(messages.BenchmarkListNames())
) -> messages.BenchmarkNames:
return client.call(messages.BenchmarkListNames(), expect=messages.BenchmarkNames)
@expect(messages.BenchmarkResult)
def run(client: "TrezorClient", name: str) -> "MessageType":
return client.call(messages.BenchmarkRun(name=name))
def run(client: "TrezorClient", name: str) -> messages.BenchmarkResult:
return client.call(
messages.BenchmarkRun(name=name), expect=messages.BenchmarkResult
)

View File

@ -18,35 +18,34 @@ from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import expect, session
from .tools import session
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
@expect(messages.BinanceAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.BinanceGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
)
),
expect=messages.BinanceAddress,
).address
@expect(messages.BinancePublicKey, field="public_key", ret_type=bytes)
def get_public_key(
client: "TrezorClient", address_n: "Address", show_display: bool = False
) -> "MessageType":
) -> bytes:
return client.call(
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display)
)
messages.BinanceGetPublicKey(address_n=address_n, show_display=show_display),
expect=messages.BinancePublicKey,
).public_key
@session
@ -60,13 +59,7 @@ def sign_tx(
tx_msg["chunkify"] = chunkify
envelope = dict_to_proto(messages.BinanceSignTx, tx_msg)
response = client.call(envelope)
if not isinstance(response, messages.BinanceTxRequest):
raise RuntimeError(
"Invalid response, expected BinanceTxRequest, received "
+ type(response).__name__
)
client.call(envelope, expect=messages.BinanceTxRequest)
if "refid" in msg:
msg = dict_to_proto(messages.BinanceCancelMsg, msg)
@ -77,12 +70,4 @@ def sign_tx(
else:
raise ValueError("can not determine msg type")
response = client.call(msg)
if not isinstance(response, messages.BinanceSignedTx):
raise RuntimeError(
"Invalid response, expected BinanceSignedTx, received "
+ type(response).__name__
)
return response
return client.call(msg, expect=messages.BinanceSignedTx)

View File

@ -14,6 +14,8 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from __future__ import annotations
import warnings
from copy import copy
from decimal import Decimal
@ -23,11 +25,10 @@ from typing import TYPE_CHECKING, Any, AnyStr, List, Optional, Sequence, Tuple
from typing_extensions import Protocol, TypedDict
from . import exceptions, messages
from .tools import expect, prepare_message_bytes, session
from .tools import _return_success, prepare_message_bytes, session
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
class ScriptSig(TypedDict):
@ -103,7 +104,6 @@ def from_json(json_dict: "Transaction") -> messages.TransactionType:
)
@expect(messages.PublicKey)
def get_public_node(
client: "TrezorClient",
n: "Address",
@ -114,13 +114,12 @@ def get_public_node(
ignore_xpub_magic: bool = False,
unlock_path: Optional[List[int]] = None,
unlock_path_mac: Optional[bytes] = None,
) -> "MessageType":
) -> messages.PublicKey:
if unlock_path:
res = client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
messages.GetPublicKey(
@ -130,16 +129,15 @@ def get_public_node(
coin_name=coin_name,
script_type=script_type,
ignore_xpub_magic=ignore_xpub_magic,
)
),
expect=messages.PublicKey,
)
@expect(messages.Address, field="address", ret_type=str)
def get_address(*args: Any, **kwargs: Any):
return get_authenticated_address(*args, **kwargs)
def get_address(*args: Any, **kwargs: Any) -> str:
return get_authenticated_address(*args, **kwargs).address
@expect(messages.Address)
def get_authenticated_address(
client: "TrezorClient",
coin_name: str,
@ -151,13 +149,12 @@ def get_authenticated_address(
unlock_path: Optional[List[int]] = None,
unlock_path_mac: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
) -> messages.Address:
if unlock_path:
res = client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
return client.call(
messages.GetAddress(
@ -168,26 +165,27 @@ def get_authenticated_address(
script_type=script_type,
ignore_xpub_magic=ignore_xpub_magic,
chunkify=chunkify,
)
),
expect=messages.Address,
)
@expect(messages.OwnershipId, field="ownership_id", ret_type=bytes)
def get_ownership_id(
client: "TrezorClient",
coin_name: str,
n: "Address",
multisig: Optional[messages.MultisigRedeemScriptType] = None,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
) -> bytes:
return client.call(
messages.GetOwnershipId(
address_n=n,
coin_name=coin_name,
multisig=multisig,
script_type=script_type,
)
)
),
expect=messages.OwnershipId,
).ownership_id
def get_ownership_proof(
@ -202,9 +200,7 @@ def get_ownership_proof(
preauthorized: bool = False,
) -> Tuple[bytes, bytes]:
if preauthorized:
res = client.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
res = client.call(
messages.GetOwnershipProof(
@ -215,16 +211,13 @@ def get_ownership_proof(
user_confirmation=user_confirmation,
ownership_ids=ownership_ids,
commitment_data=commitment_data,
),
expect=messages.OwnershipProof,
)
)
if not isinstance(res, messages.OwnershipProof):
raise exceptions.TrezorException("Unexpected message")
return res.ownership_proof, res.signature
@expect(messages.MessageSignature)
def sign_message(
client: "TrezorClient",
coin_name: str,
@ -233,7 +226,7 @@ def sign_message(
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
no_script_type: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> messages.MessageSignature:
return client.call(
messages.SignMessage(
coin_name=coin_name,
@ -242,7 +235,8 @@ def sign_message(
script_type=script_type,
no_script_type=no_script_type,
chunkify=chunkify,
)
),
expect=messages.MessageSignature,
)
@ -255,18 +249,19 @@ def verify_message(
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
client.call(
messages.VerifyMessage(
address=address,
signature=signature,
message=prepare_message_bytes(message),
coin_name=coin_name,
chunkify=chunkify,
),
expect=messages.Success,
)
)
return True
except exceptions.TrezorFailure:
return False
return isinstance(resp, messages.Success)
@session
@ -319,17 +314,14 @@ def sign_tx(
setattr(signtx, name, value)
if unlock_path:
res = client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac)
client.call(
messages.UnlockPath(address_n=unlock_path, mac=unlock_path_mac),
expect=messages.UnlockedPathRequest,
)
if not isinstance(res, messages.UnlockedPathRequest):
raise exceptions.TrezorException("Unexpected message")
elif preauthorized:
res = client.call(messages.DoPreauthorized())
if not isinstance(res, messages.PreauthorizedRequest):
raise exceptions.TrezorException("Unexpected message")
client.call(messages.DoPreauthorized(), expect=messages.PreauthorizedRequest)
res = client.call(signtx)
res = client.call(signtx, expect=messages.TxRequest)
# Prepare structure for signatures
signatures: List[Optional[bytes]] = [None] * len(inputs)
@ -357,7 +349,7 @@ def sign_tx(
)
R = messages.RequestType
while isinstance(res, messages.TxRequest):
while True:
# If there's some part of signed transaction, let's add it
if res.serialized:
if res.serialized.serialized_tx:
@ -388,7 +380,7 @@ def sign_tx(
if res.request_type == R.TXPAYMENTREQ:
assert res.details.request_index is not None
msg = payment_reqs[res.details.request_index]
res = client.call(msg)
res = client.call(msg, expect=messages.TxRequest)
else:
msg = messages.TransactionType()
if res.request_type == R.TXMETA:
@ -418,10 +410,7 @@ def sign_tx(
f"Unknown request type - {res.request_type}."
)
res = client.call(messages.TxAck(tx=msg))
if not isinstance(res, messages.TxRequest):
raise exceptions.TrezorException("Unexpected message")
res = client.call(messages.TxAck(tx=msg), expect=messages.TxRequest)
for i, sig in zip(inputs, signatures):
if i.script_type != messages.InputScriptType.EXTERNAL and sig is None:
@ -430,7 +419,6 @@ def sign_tx(
return signatures, serialized_tx
@expect(messages.Success, field="message", ret_type=str)
def authorize_coinjoin(
client: "TrezorClient",
coordinator: str,
@ -440,8 +428,8 @@ def authorize_coinjoin(
n: "Address",
coin_name: str,
script_type: messages.InputScriptType = messages.InputScriptType.SPENDADDRESS,
) -> "MessageType":
return client.call(
) -> str | None:
resp = client.call(
messages.AuthorizeCoinJoin(
coordinator=coordinator,
max_rounds=max_rounds,
@ -450,5 +438,7 @@ def authorize_coinjoin(
address_n=n,
coin_name=coin_name,
script_type=script_type,
),
expect=messages.Success,
)
)
return _return_success(resp)

View File

@ -31,12 +31,11 @@ from typing import (
Union,
)
from . import exceptions, messages, tools
from .tools import expect
from . import messages as m
from . import tools
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
PROTOCOL_MAGICS = {
"mainnet": 764824073,
@ -72,35 +71,33 @@ INCOMPLETE_OUTPUT_ERROR_MESSAGE = "The output is missing some fields"
INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY = "The output's token_bundle entry is invalid"
INVALID_MINT_TOKEN_BUNDLE_ENTRY = "The mint token_bundle entry is invalid"
InputWithPath = Tuple[messages.CardanoTxInput, List[int]]
CollateralInputWithPath = Tuple[messages.CardanoTxCollateralInput, List[int]]
AssetGroupWithTokens = Tuple[messages.CardanoAssetGroup, List[messages.CardanoToken]]
InputWithPath = Tuple[m.CardanoTxInput, List[int]]
CollateralInputWithPath = Tuple[m.CardanoTxCollateralInput, List[int]]
AssetGroupWithTokens = Tuple[m.CardanoAssetGroup, List[m.CardanoToken]]
OutputWithData = Tuple[
messages.CardanoTxOutput,
m.CardanoTxOutput,
List[AssetGroupWithTokens],
List[messages.CardanoTxInlineDatumChunk],
List[messages.CardanoTxReferenceScriptChunk],
List[m.CardanoTxInlineDatumChunk],
List[m.CardanoTxReferenceScriptChunk],
]
OutputItem = Union[
messages.CardanoTxOutput,
messages.CardanoAssetGroup,
messages.CardanoToken,
messages.CardanoTxInlineDatumChunk,
messages.CardanoTxReferenceScriptChunk,
m.CardanoTxOutput,
m.CardanoAssetGroup,
m.CardanoToken,
m.CardanoTxInlineDatumChunk,
m.CardanoTxReferenceScriptChunk,
]
CertificateItem = Union[
messages.CardanoTxCertificate,
messages.CardanoPoolOwner,
messages.CardanoPoolRelayParameters,
]
MintItem = Union[
messages.CardanoTxMint, messages.CardanoAssetGroup, messages.CardanoToken
m.CardanoTxCertificate,
m.CardanoPoolOwner,
m.CardanoPoolRelayParameters,
]
MintItem = Union[m.CardanoTxMint, m.CardanoAssetGroup, m.CardanoToken]
PoolOwnersAndRelays = Tuple[
List[messages.CardanoPoolOwner], List[messages.CardanoPoolRelayParameters]
List[m.CardanoPoolOwner], List[m.CardanoPoolRelayParameters]
]
CertificateWithPoolOwnersAndRelays = Tuple[
messages.CardanoTxCertificate, Optional[PoolOwnersAndRelays]
m.CardanoTxCertificate, Optional[PoolOwnersAndRelays]
]
Path = List[int]
Witness = Tuple[Path, bytes]
@ -108,9 +105,7 @@ AuxiliaryDataSupplement = Dict[str, Union[int, bytes]]
SignTxResponse = Dict[str, Union[bytes, List[Witness], AuxiliaryDataSupplement]]
Chunk = TypeVar(
"Chunk",
bound=Union[
messages.CardanoTxInlineDatumChunk, messages.CardanoTxReferenceScriptChunk
],
bound=Union[m.CardanoTxInlineDatumChunk, m.CardanoTxReferenceScriptChunk],
)
@ -123,7 +118,7 @@ def parse_optional_int(value: Optional[str]) -> Optional[int]:
def create_address_parameters(
address_type: messages.CardanoAddressType,
address_type: m.CardanoAddressType,
address_n: List[int],
address_n_staking: Optional[List[int]] = None,
staking_key_hash: Optional[bytes] = None,
@ -132,18 +127,18 @@ def create_address_parameters(
certificate_index: Optional[int] = None,
script_payment_hash: Optional[bytes] = None,
script_staking_hash: Optional[bytes] = None,
) -> messages.CardanoAddressParametersType:
) -> m.CardanoAddressParametersType:
certificate_pointer = None
if address_type in (
messages.CardanoAddressType.POINTER,
messages.CardanoAddressType.POINTER_SCRIPT,
m.CardanoAddressType.POINTER,
m.CardanoAddressType.POINTER_SCRIPT,
):
certificate_pointer = _create_certificate_pointer(
block_index, tx_index, certificate_index
)
return messages.CardanoAddressParametersType(
return m.CardanoAddressParametersType(
address_type=address_type,
address_n=address_n,
address_n_staking=address_n_staking,
@ -158,11 +153,11 @@ def _create_certificate_pointer(
block_index: Optional[int],
tx_index: Optional[int],
certificate_index: Optional[int],
) -> messages.CardanoBlockchainPointerType:
) -> m.CardanoBlockchainPointerType:
if block_index is None or tx_index is None or certificate_index is None:
raise ValueError("Invalid pointer parameters")
return messages.CardanoBlockchainPointerType(
return m.CardanoBlockchainPointerType(
block_index=block_index, tx_index=tx_index, certificate_index=certificate_index
)
@ -173,7 +168,7 @@ def parse_input(tx_input: dict) -> InputWithPath:
path = tools.parse_path(tx_input.get("path", ""))
return (
messages.CardanoTxInput(
m.CardanoTxInput(
prev_hash=bytes.fromhex(tx_input["prev_hash"]),
prev_index=tx_input["prev_index"],
),
@ -204,22 +199,22 @@ def parse_output(output: dict) -> OutputWithData:
datum_hash = parse_optional_bytes(output.get("datum_hash"))
serialization_format = messages.CardanoTxOutputSerializationFormat.ARRAY_LEGACY
serialization_format = m.CardanoTxOutputSerializationFormat.ARRAY_LEGACY
if "format" in output:
serialization_format = output["format"]
inline_datum_size, inline_datum_chunks = _parse_chunkable_data(
parse_optional_bytes(output.get("inline_datum")),
messages.CardanoTxInlineDatumChunk,
m.CardanoTxInlineDatumChunk,
)
reference_script_size, reference_script_chunks = _parse_chunkable_data(
parse_optional_bytes(output.get("reference_script")),
messages.CardanoTxReferenceScriptChunk,
m.CardanoTxReferenceScriptChunk,
)
return (
messages.CardanoTxOutput(
m.CardanoTxOutput(
address=address,
address_parameters=address_parameters,
amount=int(output["amount"]),
@ -253,7 +248,7 @@ def _parse_token_bundle(
result.append(
(
messages.CardanoAssetGroup(
m.CardanoAssetGroup(
policy_id=bytes.fromhex(token_group["policy_id"]),
tokens_count=len(tokens),
),
@ -264,7 +259,7 @@ def _parse_token_bundle(
return result
def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.CardanoToken]:
def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[m.CardanoToken]:
error_message: str
if is_mint:
error_message = INVALID_MINT_TOKEN_BUNDLE_ENTRY
@ -288,7 +283,7 @@ def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.Cardan
amount = int(token["amount"])
result.append(
messages.CardanoToken(
m.CardanoToken(
asset_name_bytes=bytes.fromhex(token["asset_name_bytes"]),
amount=amount,
mint_amount=mint_amount,
@ -300,7 +295,7 @@ def _parse_tokens(tokens: Iterable[dict], is_mint: bool) -> List[messages.Cardan
def _parse_address_parameters(
address_parameters: dict, error_message: str
) -> messages.CardanoAddressParametersType:
) -> m.CardanoAddressParametersType:
if "addressType" not in address_parameters:
raise ValueError(error_message)
@ -317,7 +312,7 @@ def _parse_address_parameters(
)
return create_address_parameters(
messages.CardanoAddressType(address_parameters["addressType"]),
m.CardanoAddressType(address_parameters["addressType"]),
payment_path,
staking_path,
staking_key_hash_bytes,
@ -346,7 +341,7 @@ def _create_data_chunks(data: bytes) -> Iterator[bytes]:
processed_size += MAX_CHUNK_SIZE
def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
def parse_native_script(native_script: dict) -> m.CardanoNativeScript:
if "type" not in native_script:
raise ValueError("Script is missing some fields")
@ -364,7 +359,7 @@ def parse_native_script(native_script: dict) -> messages.CardanoNativeScript:
invalid_before = parse_optional_int(native_script.get("invalid_before"))
invalid_hereafter = parse_optional_int(native_script.get("invalid_hereafter"))
return messages.CardanoNativeScript(
return m.CardanoNativeScript(
type=type,
scripts=scripts,
key_hash=key_hash,
@ -385,7 +380,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
certificate_type = certificate["type"]
if certificate_type == messages.CardanoCertificateType.STAKE_DELEGATION:
if certificate_type == m.CardanoCertificateType.STAKE_DELEGATION:
if "pool" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR
@ -394,7 +389,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
)
return (
messages.CardanoTxCertificate(
m.CardanoTxCertificate(
type=certificate_type,
path=path,
pool=bytes.fromhex(certificate["pool"]),
@ -404,15 +399,15 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
None,
)
elif certificate_type in (
messages.CardanoCertificateType.STAKE_REGISTRATION,
messages.CardanoCertificateType.STAKE_DEREGISTRATION,
m.CardanoCertificateType.STAKE_REGISTRATION,
m.CardanoCertificateType.STAKE_DEREGISTRATION,
):
path, script_hash, key_hash = _parse_credential(
certificate, CERTIFICATE_MISSING_FIELDS_ERROR
)
return (
messages.CardanoTxCertificate(
m.CardanoTxCertificate(
type=certificate_type,
path=path,
script_hash=script_hash,
@ -421,8 +416,8 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
None,
)
elif certificate_type in (
messages.CardanoCertificateType.STAKE_REGISTRATION_CONWAY,
messages.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY,
m.CardanoCertificateType.STAKE_REGISTRATION_CONWAY,
m.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY,
):
if "deposit" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR
@ -432,7 +427,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
)
return (
messages.CardanoTxCertificate(
m.CardanoTxCertificate(
type=certificate_type,
path=path,
script_hash=script_hash,
@ -441,7 +436,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
),
None,
)
elif certificate_type == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION:
elif certificate_type == m.CardanoCertificateType.STAKE_POOL_REGISTRATION:
pool_parameters = certificate["pool_parameters"]
if any(
@ -450,9 +445,9 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
):
raise CERTIFICATE_MISSING_FIELDS_ERROR
pool_metadata: Optional[messages.CardanoPoolMetadataType]
pool_metadata: Optional[m.CardanoPoolMetadataType]
if pool_parameters.get("metadata") is not None:
pool_metadata = messages.CardanoPoolMetadataType(
pool_metadata = m.CardanoPoolMetadataType(
url=pool_parameters["metadata"]["url"],
hash=bytes.fromhex(pool_parameters["metadata"]["hash"]),
)
@ -469,9 +464,9 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
]
return (
messages.CardanoTxCertificate(
m.CardanoTxCertificate(
type=certificate_type,
pool_parameters=messages.CardanoPoolParametersType(
pool_parameters=m.CardanoPoolParametersType(
pool_id=bytes.fromhex(pool_parameters["pool_id"]),
vrf_key_hash=bytes.fromhex(pool_parameters["vrf_key_hash"]),
pledge=int(pool_parameters["pledge"]),
@ -486,7 +481,7 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
),
(owners, relays),
)
if certificate_type == messages.CardanoCertificateType.VOTE_DELEGATION:
if certificate_type == m.CardanoCertificateType.VOTE_DELEGATION:
if "drep" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR
@ -495,13 +490,13 @@ def parse_certificate(certificate: dict) -> CertificateWithPoolOwnersAndRelays:
)
return (
messages.CardanoTxCertificate(
m.CardanoTxCertificate(
type=certificate_type,
path=path,
script_hash=script_hash,
key_hash=key_hash,
drep=messages.CardanoDRep(
type=messages.CardanoDRepType(certificate["drep"]["type"]),
drep=m.CardanoDRep(
type=m.CardanoDRepType(certificate["drep"]["type"]),
key_hash=parse_optional_bytes(certificate["drep"].get("key_hash")),
script_hash=parse_optional_bytes(
certificate["drep"].get("script_hash")
@ -527,21 +522,21 @@ def _parse_credential(
return path, script_hash, key_hash
def _parse_pool_owner(pool_owner: dict) -> messages.CardanoPoolOwner:
def _parse_pool_owner(pool_owner: dict) -> m.CardanoPoolOwner:
if "staking_key_path" in pool_owner:
return messages.CardanoPoolOwner(
return m.CardanoPoolOwner(
staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
)
return messages.CardanoPoolOwner(
return m.CardanoPoolOwner(
staking_key_hash=bytes.fromhex(pool_owner["staking_key_hash"])
)
def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
pool_relay_type = messages.CardanoPoolRelayType(pool_relay["type"])
def _parse_pool_relay(pool_relay: dict) -> m.CardanoPoolRelayParameters:
pool_relay_type = m.CardanoPoolRelayType(pool_relay["type"])
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
if pool_relay_type == m.CardanoPoolRelayType.SINGLE_HOST_IP:
ipv4_address_packed = (
ip_address(pool_relay["ipv4_address"]).packed
if "ipv4_address" in pool_relay
@ -553,20 +548,20 @@ def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
else None
)
return messages.CardanoPoolRelayParameters(
return m.CardanoPoolRelayParameters(
type=pool_relay_type,
port=int(pool_relay["port"]),
ipv4_address=ipv4_address_packed,
ipv6_address=ipv6_address_packed,
)
elif pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_NAME:
return messages.CardanoPoolRelayParameters(
elif pool_relay_type == m.CardanoPoolRelayType.SINGLE_HOST_NAME:
return m.CardanoPoolRelayParameters(
type=pool_relay_type,
port=int(pool_relay["port"]),
host_name=pool_relay["host_name"],
)
elif pool_relay_type == messages.CardanoPoolRelayType.MULTIPLE_HOST_NAME:
return messages.CardanoPoolRelayParameters(
elif pool_relay_type == m.CardanoPoolRelayType.MULTIPLE_HOST_NAME:
return m.CardanoPoolRelayParameters(
type=pool_relay_type,
host_name=pool_relay["host_name"],
)
@ -574,7 +569,7 @@ def _parse_pool_relay(pool_relay: dict) -> messages.CardanoPoolRelayParameters:
raise ValueError("Unknown pool relay type")
def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
def parse_withdrawal(withdrawal: dict) -> m.CardanoTxWithdrawal:
WITHDRAWAL_MISSING_FIELDS_ERROR = ValueError(
"The withdrawal is missing some fields"
)
@ -586,7 +581,7 @@ def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
withdrawal, WITHDRAWAL_MISSING_FIELDS_ERROR
)
return messages.CardanoTxWithdrawal(
return m.CardanoTxWithdrawal(
path=path,
amount=int(withdrawal["amount"]),
script_hash=script_hash,
@ -596,7 +591,7 @@ def parse_withdrawal(withdrawal: dict) -> messages.CardanoTxWithdrawal:
def parse_auxiliary_data(
auxiliary_data: Optional[dict],
) -> Optional[messages.CardanoTxAuxiliaryData]:
) -> Optional[m.CardanoTxAuxiliaryData]:
if auxiliary_data is None:
return None
@ -620,17 +615,17 @@ def parse_auxiliary_data(
if not all(k in delegation for k in REQUIRED_FIELDS_CVOTE_DELEGATION):
raise AUXILIARY_DATA_MISSING_FIELDS_ERROR
delegations.append(
messages.CardanoCVoteRegistrationDelegation(
m.CardanoCVoteRegistrationDelegation(
vote_public_key=bytes.fromhex(delegation["vote_public_key"]),
weight=int(delegation["weight"]),
)
)
voting_purpose = None
if serialization_format == messages.CardanoCVoteRegistrationFormat.CIP36:
if serialization_format == m.CardanoCVoteRegistrationFormat.CIP36:
voting_purpose = cvote_registration.get("voting_purpose")
cvote_registration_parameters = messages.CardanoCVoteRegistrationParametersType(
cvote_registration_parameters = m.CardanoCVoteRegistrationParametersType(
vote_public_key=parse_optional_bytes(
cvote_registration.get("vote_public_key")
),
@ -653,7 +648,7 @@ def parse_auxiliary_data(
if hash is None and cvote_registration_parameters is None:
raise AUXILIARY_DATA_MISSING_FIELDS_ERROR
return messages.CardanoTxAuxiliaryData(
return m.CardanoTxAuxiliaryData(
hash=hash,
cvote_registration_parameters=cvote_registration_parameters,
)
@ -673,7 +668,7 @@ def parse_collateral_input(collateral_input: dict) -> CollateralInputWithPath:
path = tools.parse_path(collateral_input.get("path", ""))
return (
messages.CardanoTxCollateralInput(
m.CardanoTxCollateralInput(
prev_hash=bytes.fromhex(collateral_input["prev_hash"]),
prev_index=collateral_input["prev_index"],
),
@ -681,20 +676,20 @@ def parse_collateral_input(collateral_input: dict) -> CollateralInputWithPath:
)
def parse_required_signer(required_signer: dict) -> messages.CardanoTxRequiredSigner:
def parse_required_signer(required_signer: dict) -> m.CardanoTxRequiredSigner:
key_hash = parse_optional_bytes(required_signer.get("key_hash"))
key_path = tools.parse_path(required_signer.get("key_path", ""))
return messages.CardanoTxRequiredSigner(
return m.CardanoTxRequiredSigner(
key_hash=key_hash,
key_path=key_path,
)
def parse_reference_input(reference_input: dict) -> messages.CardanoTxReferenceInput:
def parse_reference_input(reference_input: dict) -> m.CardanoTxReferenceInput:
if not all(k in reference_input for k in REQUIRED_FIELDS_INPUT):
raise ValueError("The reference input is missing some fields")
return messages.CardanoTxReferenceInput(
return m.CardanoTxReferenceInput(
prev_hash=bytes.fromhex(reference_input["prev_hash"]),
prev_index=reference_input["prev_index"],
)
@ -712,16 +707,16 @@ def parse_additional_witness_request(
def _get_witness_requests(
inputs: Sequence[InputWithPath],
certificates: Sequence[CertificateWithPoolOwnersAndRelays],
withdrawals: Sequence[messages.CardanoTxWithdrawal],
withdrawals: Sequence[m.CardanoTxWithdrawal],
collateral_inputs: Sequence[CollateralInputWithPath],
required_signers: Sequence[messages.CardanoTxRequiredSigner],
required_signers: Sequence[m.CardanoTxRequiredSigner],
additional_witness_requests: Sequence[Path],
signing_mode: messages.CardanoTxSigningMode,
) -> List[messages.CardanoTxWitnessRequest]:
signing_mode: m.CardanoTxSigningMode,
) -> List[m.CardanoTxWitnessRequest]:
paths = set()
# don't gather paths from tx elements in MULTISIG_TRANSACTION signing mode
if signing_mode != messages.CardanoTxSigningMode.MULTISIG_TRANSACTION:
if signing_mode != m.CardanoTxSigningMode.MULTISIG_TRANSACTION:
for _, path in inputs:
if path:
paths.add(tuple(path))
@ -729,18 +724,17 @@ def _get_witness_requests(
if (
certificate.type
in (
messages.CardanoCertificateType.STAKE_DEREGISTRATION,
messages.CardanoCertificateType.STAKE_DELEGATION,
messages.CardanoCertificateType.STAKE_REGISTRATION_CONWAY,
messages.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY,
messages.CardanoCertificateType.VOTE_DELEGATION,
m.CardanoCertificateType.STAKE_DEREGISTRATION,
m.CardanoCertificateType.STAKE_DELEGATION,
m.CardanoCertificateType.STAKE_REGISTRATION_CONWAY,
m.CardanoCertificateType.STAKE_DEREGISTRATION_CONWAY,
m.CardanoCertificateType.VOTE_DELEGATION,
)
and certificate.path
):
paths.add(tuple(certificate.path))
elif (
certificate.type
== messages.CardanoCertificateType.STAKE_POOL_REGISTRATION
certificate.type == m.CardanoCertificateType.STAKE_POOL_REGISTRATION
and pool_owners_and_relays is not None
):
owners, _ = pool_owners_and_relays
@ -752,7 +746,7 @@ def _get_witness_requests(
paths.add(tuple(withdrawal.path))
# gather Plutus-related paths
if signing_mode == messages.CardanoTxSigningMode.PLUTUS_TRANSACTION:
if signing_mode == m.CardanoTxSigningMode.PLUTUS_TRANSACTION:
for _, path in collateral_inputs:
if path:
paths.add(tuple(path))
@ -765,10 +759,10 @@ def _get_witness_requests(
paths.add(tuple(additional_witness_request))
sorted_paths = sorted([list(path) for path in paths])
return [messages.CardanoTxWitnessRequest(path=path) for path in sorted_paths]
return [m.CardanoTxWitnessRequest(path=path) for path in sorted_paths]
def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[messages.CardanoTxInput]:
def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[m.CardanoTxInput]:
for input, _ in inputs:
yield input
@ -807,7 +801,7 @@ def _get_certificates_items(
def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
if not mint:
return
yield messages.CardanoTxMint(asset_groups_count=len(mint))
yield m.CardanoTxMint(asset_groups_count=len(mint))
for asset_group, tokens in mint:
yield asset_group
yield from tokens
@ -815,7 +809,7 @@ def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]:
def _get_collateral_inputs_items(
collateral_inputs: Sequence[CollateralInputWithPath],
) -> Iterator[messages.CardanoTxCollateralInput]:
) -> Iterator[m.CardanoTxCollateralInput]:
for collateral_input, _ in collateral_inputs:
yield collateral_input
@ -823,88 +817,86 @@ def _get_collateral_inputs_items(
# ====== Client functions ====== #
@expect(messages.CardanoAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
address_parameters: messages.CardanoAddressParametersType,
address_parameters: m.CardanoAddressParametersType,
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
show_display: bool = False,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.CardanoGetAddress(
m.CardanoGetAddress(
address_parameters=address_parameters,
protocol_magic=protocol_magic,
network_id=network_id,
show_display=show_display,
derivation_type=derivation_type,
chunkify=chunkify,
)
)
),
expect=m.CardanoAddress,
).address
@expect(messages.CardanoPublicKey)
def get_public_key(
client: "TrezorClient",
address_n: List[int],
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
show_display: bool = False,
) -> "MessageType":
) -> m.CardanoPublicKey:
return client.call(
messages.CardanoGetPublicKey(
m.CardanoGetPublicKey(
address_n=address_n,
derivation_type=derivation_type,
show_display=show_display,
)
),
expect=m.CardanoPublicKey,
)
@expect(messages.CardanoNativeScriptHash)
def get_native_script_hash(
client: "TrezorClient",
native_script: messages.CardanoNativeScript,
display_format: messages.CardanoNativeScriptHashDisplayFormat = messages.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
) -> "MessageType":
native_script: m.CardanoNativeScript,
display_format: m.CardanoNativeScriptHashDisplayFormat = m.CardanoNativeScriptHashDisplayFormat.HIDE,
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
) -> m.CardanoNativeScriptHash:
return client.call(
messages.CardanoGetNativeScriptHash(
m.CardanoGetNativeScriptHash(
script=native_script,
display_format=display_format,
derivation_type=derivation_type,
)
),
expect=m.CardanoNativeScriptHash,
)
def sign_tx(
client: "TrezorClient",
signing_mode: messages.CardanoTxSigningMode,
signing_mode: m.CardanoTxSigningMode,
inputs: List[InputWithPath],
outputs: List[OutputWithData],
fee: int,
ttl: Optional[int],
validity_interval_start: Optional[int],
certificates: Sequence[CertificateWithPoolOwnersAndRelays] = (),
withdrawals: Sequence[messages.CardanoTxWithdrawal] = (),
withdrawals: Sequence[m.CardanoTxWithdrawal] = (),
protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"],
auxiliary_data: Optional[messages.CardanoTxAuxiliaryData] = None,
auxiliary_data: Optional[m.CardanoTxAuxiliaryData] = None,
mint: Sequence[AssetGroupWithTokens] = (),
script_data_hash: Optional[bytes] = None,
collateral_inputs: Sequence[CollateralInputWithPath] = (),
required_signers: Sequence[messages.CardanoTxRequiredSigner] = (),
required_signers: Sequence[m.CardanoTxRequiredSigner] = (),
collateral_return: Optional[OutputWithData] = None,
total_collateral: Optional[int] = None,
reference_inputs: Sequence[messages.CardanoTxReferenceInput] = (),
reference_inputs: Sequence[m.CardanoTxReferenceInput] = (),
additional_witness_requests: Sequence[Path] = (),
derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS,
derivation_type: m.CardanoDerivationType = m.CardanoDerivationType.ICARUS,
include_network_id: bool = False,
chunkify: bool = False,
tag_cbor_sets: bool = False,
) -> Dict[str, Any]:
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
witness_requests = _get_witness_requests(
inputs,
certificates,
@ -916,7 +908,7 @@ def sign_tx(
)
response = client.call(
messages.CardanoSignTxInit(
m.CardanoSignTxInit(
signing_mode=signing_mode,
inputs_count=len(inputs),
outputs_count=len(outputs),
@ -940,10 +932,9 @@ def sign_tx(
include_network_id=include_network_id,
chunkify=chunkify,
tag_cbor_sets=tag_cbor_sets,
),
expect=m.CardanoTxItemAck,
)
)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
for tx_item in chain(
_get_inputs_items(inputs),
@ -951,55 +942,41 @@ def sign_tx(
_get_certificates_items(certificates),
withdrawals,
):
response = client.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(tx_item, expect=m.CardanoTxItemAck)
sign_tx_response: Dict[str, Any] = {}
if auxiliary_data is not None:
auxiliary_data_supplement = client.call(auxiliary_data)
if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
):
raise UNEXPECTED_RESPONSE_ERROR
auxiliary_data_supplement = client.call(
auxiliary_data, expect=m.CardanoTxAuxiliaryDataSupplement
)
if (
auxiliary_data_supplement.type
!= messages.CardanoTxAuxiliaryDataSupplementType.NONE
!= m.CardanoTxAuxiliaryDataSupplementType.NONE
):
sign_tx_response["auxiliary_data_supplement"] = (
auxiliary_data_supplement.__dict__
)
response = client.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxItemAck)
for tx_item in chain(
_get_mint_items(mint),
_get_collateral_inputs_items(collateral_inputs),
required_signers,
):
response = client.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(tx_item, expect=m.CardanoTxItemAck)
if collateral_return is not None:
for tx_item in _get_output_items(collateral_return):
response = client.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(tx_item, expect=m.CardanoTxItemAck)
for reference_input in reference_inputs:
response = client.call(reference_input)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(reference_input, expect=m.CardanoTxItemAck)
sign_tx_response["witnesses"] = []
for witness_request in witness_requests:
response = client.call(witness_request)
if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(witness_request, expect=m.CardanoTxWitnessResponse)
sign_tx_response["witnesses"].append(
{
"type": response.type,
@ -1009,13 +986,9 @@ def sign_tx(
}
)
response = client.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoTxBodyHash)
sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR
response = client.call(m.CardanoTxHostAck(), expect=m.CardanoSignTxFinished)
return sign_tx_response

View File

@ -107,16 +107,16 @@ def record_screen_from_connection(
@cli.command()
@with_client
def prodtest_t1(client: "TrezorClient") -> str:
def prodtest_t1(client: "TrezorClient") -> None:
"""Perform a prodtest on Model One.
Only available on PRODTEST firmware and on T1B1. Formerly named self-test.
"""
return debuglink_prodtest_t1(client)
debuglink_prodtest_t1(client)
@cli.command()
@with_client
def optiga_set_sec_max(client: "TrezorClient") -> str:
def optiga_set_sec_max(client: "TrezorClient") -> None:
"""Set Optiga's security event counter to maximum."""
return debuglink_optiga_set_sec_max(client)
debuglink_optiga_set_sec_max(client)

View File

@ -29,7 +29,6 @@ from . import ChoiceType, with_client
if t.TYPE_CHECKING:
from ..client import TrezorClient
from ..protobuf import MessageType
from . import TrezorConnection
RECOVERY_DEVICE_INPUT_METHOD = {
@ -66,7 +65,7 @@ def cli() -> None:
is_flag=True,
)
@with_client
def wipe(client: "TrezorClient", bootloader: bool) -> str:
def wipe(client: "TrezorClient", bootloader: bool) -> None:
"""Reset device to factory defaults and remove all private data."""
if bootloader:
if not client.features.bootloader_mode:
@ -87,11 +86,7 @@ def wipe(client: "TrezorClient", bootloader: bool) -> str:
else:
click.echo("Wiping user data!")
try:
return device.wipe(client)
except exceptions.TrezorFailure as e:
click.echo("Action failed: {} {}".format(*e.args))
sys.exit(3)
device.wipe(client)
@cli.command()
@ -116,7 +111,7 @@ def load(
academic: bool,
needs_backup: bool,
no_backup: bool,
) -> str:
) -> None:
"""Upload seed and custom configuration to the device.
This functionality is only available in debug mode.
@ -136,7 +131,7 @@ def load(
label = "ACADEMIC"
try:
return debuglink.load_device(
debuglink.load_device(
client,
mnemonic=list(mnemonic),
pin=pin,
@ -184,7 +179,7 @@ def recover(
input_method: messages.RecoveryDeviceInputMethod,
dry_run: bool,
unlock_repeated_backup: bool,
) -> "MessageType":
) -> None:
"""Start safe recovery workflow."""
if input_method == messages.RecoveryDeviceInputMethod.ScrambledWords:
input_callback = ui.mnemonic_words(expand)
@ -201,7 +196,7 @@ def recover(
if unlock_repeated_backup:
type = messages.RecoveryType.UnlockRepeatedBackup
return device.recover(
device.recover(
client,
word_count=int(words),
passphrase_protection=passphrase_protection,
@ -236,21 +231,13 @@ def setup(
no_backup: bool,
backup_type: messages.BackupType | None,
entropy_check_count: int | None,
) -> str:
) -> None:
"""Perform device setup and generate new seed."""
if strength:
strength = int(strength)
BT = messages.BackupType
if backup_type is None:
if client.version >= (2, 7, 1):
# SLIP39 extendable was introduced in 2.7.1
backup_type = BT.Slip39_Single_Extendable
else:
# this includes both T1 and older trezor-cores
backup_type = BT.Bip39
if (
backup_type
in (BT.Slip39_Single_Extendable, BT.Slip39_Basic, BT.Slip39_Basic_Extendable)
@ -264,7 +251,7 @@ def setup(
"backup type. Traditional BIP39 backup may be generated instead."
)
resp, path_xpubs = device.reset_entropy_check(
path_xpubs = device.setup(
client,
strength=strength,
passphrase_protection=passphrase_protection,
@ -277,13 +264,10 @@ def setup(
entropy_check_count=entropy_check_count,
)
if isinstance(resp, messages.Success):
if path_xpubs:
click.echo("XPUBs for the generated seed")
for path, xpub in path_xpubs:
click.echo(f"{format_path(path)}: {xpub}")
return resp.message or ""
else:
raise RuntimeError(f"Received {resp.__class__}")
@cli.command()
@ -294,10 +278,9 @@ def backup(
client: "TrezorClient",
group_threshold: int | None = None,
groups: t.Sequence[tuple[int, int]] = (),
) -> str:
) -> None:
"""Perform device seed backup."""
return device.backup(client, group_threshold, groups)
device.backup(client, group_threshold, groups)
@cli.command()
@ -305,7 +288,7 @@ def backup(
@with_client
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> str:
) -> None:
"""Secure the device with SD card protection.
When SD card protection is enabled, a randomly generated secret is stored
@ -321,12 +304,12 @@ def sd_protect(
"""
if client.features.model == "1":
raise click.ClickException("Trezor One does not support SD card protection.")
return device.sd_protect(client, operation)
device.sd_protect(client, operation)
@cli.command()
@click.pass_obj
def reboot_to_bootloader(obj: "TrezorConnection") -> str:
def reboot_to_bootloader(obj: "TrezorConnection") -> None:
"""Reboot device into bootloader mode.
Currently only supported on Trezor Model One.
@ -334,21 +317,21 @@ def reboot_to_bootloader(obj: "TrezorConnection") -> str:
# avoid using @with_client because it closes the session afterwards,
# which triggers double prompt on device
with obj.client_context() as client:
return device.reboot_to_bootloader(client)
device.reboot_to_bootloader(client)
@cli.command()
@with_client
def tutorial(client: "TrezorClient") -> str:
def tutorial(client: "TrezorClient") -> None:
"""Show on-device tutorial."""
return device.show_device_tutorial(client)
device.show_device_tutorial(client)
@cli.command()
@with_client
def unlock_bootloader(client: "TrezorClient") -> str:
def unlock_bootloader(client: "TrezorClient") -> None:
"""Unlocks bootloader. Irreversible."""
return device.unlock_bootloader(client)
device.unlock_bootloader(client)
@cli.command()
@ -360,10 +343,11 @@ def unlock_bootloader(client: "TrezorClient") -> str:
help="Dialog expiry in seconds.",
)
@with_client
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> str:
def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) -> None:
"""Show a "Do not disconnect" dialog."""
if enable is False:
return device.set_busy(client, None)
device.set_busy(client, None)
return
if expiry is None:
raise click.ClickException("Missing option '-e' / '--expiry'.")
@ -373,7 +357,7 @@ def set_busy(client: "TrezorClient", enable: bool | None, expiry: int | None) ->
f"Invalid value for '-e' / '--expiry': '{expiry}' is not a positive integer."
)
return device.set_busy(client, expiry * 1000)
device.set_busy(client, expiry * 1000)
PUBKEY_WHITELIST_URL_TEMPLATE = (

View File

@ -80,12 +80,12 @@ def credentials_list(client: "TrezorClient") -> None:
@credentials.command(name="add")
@click.argument("hex_credential_id")
@with_client
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
def credentials_add(client: "TrezorClient", hex_credential_id: str) -> None:
"""Add the credential with the given ID as a resident credential.
HEX_CREDENTIAL_ID is the credential ID as a hexadecimal string.
"""
return fido.add_credential(client, bytes.fromhex(hex_credential_id))
fido.add_credential(client, bytes.fromhex(hex_credential_id))
@credentials.command(name="remove")
@ -93,9 +93,9 @@ def credentials_add(client: "TrezorClient", hex_credential_id: str) -> str:
"-i", "--index", required=True, type=click.IntRange(0, 99), help="Credential index."
)
@with_client
def credentials_remove(client: "TrezorClient", index: int) -> str:
def credentials_remove(client: "TrezorClient", index: int) -> None:
"""Remove the resident credential at the given index."""
return fido.remove_credential(client, index)
fido.remove_credential(client, index)
#
@ -111,9 +111,9 @@ def counter() -> None:
@counter.command(name="set")
@click.argument("counter", type=int)
@with_client
def counter_set(client: "TrezorClient", counter: int) -> str:
def counter_set(client: "TrezorClient", counter: int) -> None:
"""Set FIDO/U2F counter value."""
return fido.set_counter(client, counter)
fido.set_counter(client, counter)
@counter.command(name="get-next")

View File

@ -181,17 +181,17 @@ def cli() -> None:
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
def pin(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
"""Set, change or remove PIN."""
# Remove argument is there for backwards compatibility
return device.change_pin(client, remove=_should_remove(enable, remove))
device.change_pin(client, remove=_should_remove(enable, remove))
@cli.command()
@click.option("-r", "--remove", is_flag=True, hidden=True)
@click.argument("enable", type=ChoiceType({"on": True, "off": False}), required=False)
@with_client
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> str:
def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> None:
"""Set or remove the wipe code.
The wipe code functions as a "self-destruct PIN". If the wipe code is ever
@ -199,7 +199,7 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
removed and the device will be reset to factory defaults.
"""
# Remove argument is there for backwards compatibility
return device.change_wipe_code(client, remove=_should_remove(enable, remove))
device.change_wipe_code(client, remove=_should_remove(enable, remove))
@cli.command()
@ -207,24 +207,24 @@ def wipe_code(client: "TrezorClient", enable: Optional[bool], remove: bool) -> s
@click.option("-l", "--label", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.argument("label")
@with_client
def label(client: "TrezorClient", label: str) -> str:
def label(client: "TrezorClient", label: str) -> None:
"""Set new device label."""
return device.apply_settings(client, label=label)
device.apply_settings(client, label=label)
@cli.command()
@with_client
def brightness(client: "TrezorClient") -> str:
def brightness(client: "TrezorClient") -> None:
"""Set display brightness."""
return device.set_brightness(client)
device.set_brightness(client)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
def haptic_feedback(client: "TrezorClient", enable: bool) -> None:
"""Enable or disable haptic feedback."""
return device.apply_settings(client, haptic_feedback=enable)
device.apply_settings(client, haptic_feedback=enable)
@cli.command()
@ -236,7 +236,7 @@ def haptic_feedback(client: "TrezorClient", enable: bool) -> str:
@with_client
def language(
client: "TrezorClient", path_or_url: str | None, remove: bool, display: bool | None
) -> str:
) -> None:
"""Set new language with translations."""
if remove != (path_or_url is None):
raise click.ClickException("Either provide a path or URL or use --remove")
@ -259,27 +259,27 @@ def language(
raise click.ClickException(
f"Failed to load translations from {path_or_url}"
) from None
return device.change_language(
client, language_data=language_data, show_display=display
)
device.change_language(client, language_data=language_data, show_display=display)
@cli.command()
@click.argument("rotation", type=ChoiceType(ROTATION))
@with_client
def display_rotation(client: "TrezorClient", rotation: messages.DisplayRotation) -> str:
def display_rotation(
client: "TrezorClient", rotation: messages.DisplayRotation
) -> None:
"""Set display rotation.
Configure display rotation for Trezor Model T. The options are
north, east, south or west.
"""
return device.apply_settings(client, display_rotation=rotation)
device.apply_settings(client, display_rotation=rotation)
@cli.command()
@click.argument("delay", type=str)
@with_client
def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
def auto_lock_delay(client: "TrezorClient", delay: str) -> None:
"""Set auto-lock delay (in seconds)."""
if not client.features.pin_protection:
@ -291,13 +291,13 @@ def auto_lock_delay(client: "TrezorClient", delay: str) -> str:
seconds = float(value) * units[unit]
else:
seconds = float(delay) # assume seconds if no unit is specified
return device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
device.apply_settings(client, auto_lock_delay_ms=int(seconds * 1000))
@cli.command()
@click.argument("flags")
@with_client
def flags(client: "TrezorClient", flags: str) -> str:
def flags(client: "TrezorClient", flags: str) -> None:
"""Set device flags."""
if flags.lower().startswith("0b"):
flags_int = int(flags, 2)
@ -305,7 +305,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
flags_int = int(flags, 16)
else:
flags_int = int(flags)
return device.apply_flags(client, flags=flags_int)
device.apply_flags(client, flags=flags_int)
@cli.command()
@ -315,7 +315,7 @@ def flags(client: "TrezorClient", flags: str) -> str:
)
@click.option("-q", "--quality", type=int, default=90, help="JPEG quality (0-100)")
@with_client
def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
def homescreen(client: "TrezorClient", filename: str, quality: int) -> None:
"""Set new homescreen.
To revert to default homescreen, use 'trezorctl set homescreen default'
@ -369,7 +369,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
"Unknown image format requested by the device."
)
return device.apply_settings(client, homescreen=img)
device.apply_settings(client, homescreen=img)
@cli.command()
@ -380,7 +380,7 @@ def homescreen(client: "TrezorClient", filename: str, quality: int) -> str:
@with_client
def safety_checks(
client: "TrezorClient", always: bool, level: messages.SafetyCheckLevel
) -> str:
) -> None:
"""Set safety check level.
Set to "strict" to get the full Trezor security (default setting).
@ -392,18 +392,18 @@ def safety_checks(
"""
if always and level == messages.SafetyCheckLevel.PromptTemporarily:
level = messages.SafetyCheckLevel.PromptAlways
return device.apply_settings(client, safety_checks=level)
device.apply_settings(client, safety_checks=level)
@cli.command()
@click.argument("enable", type=ChoiceType({"on": True, "off": False}))
@with_client
def experimental_features(client: "TrezorClient", enable: bool) -> str:
def experimental_features(client: "TrezorClient", enable: bool) -> None:
"""Enable or disable experimental message types.
This is a developer feature. Use with caution.
"""
return device.apply_settings(client, experimental_features=enable)
device.apply_settings(client, experimental_features=enable)
#
@ -427,13 +427,13 @@ passphrase = cast(AliasedGroup, passphrase_main)
@passphrase.command(name="on")
@click.option("-f/-F", "--force-on-device/--no-force-on-device", default=None)
@with_client
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> str:
def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> None:
"""Enable passphrase."""
if client.features.passphrase_protection is not True:
use_passphrase = True
else:
use_passphrase = None
return device.apply_settings(
device.apply_settings(
client,
use_passphrase=use_passphrase,
passphrase_always_on_device=force_on_device,
@ -442,9 +442,9 @@ def passphrase_on(client: "TrezorClient", force_on_device: Optional[bool]) -> st
@passphrase.command(name="off")
@with_client
def passphrase_off(client: "TrezorClient") -> str:
def passphrase_off(client: "TrezorClient") -> None:
"""Disable passphrase."""
return device.apply_settings(client, use_passphrase=False)
device.apply_settings(client, use_passphrase=False)
# Registering the aliases for backwards compatibility
@ -458,9 +458,9 @@ passphrase.aliases = {
@passphrase.command(name="hide")
@click.argument("hide", type=ChoiceType({"on": True, "off": False}))
@with_client
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> str:
def hide_passphrase_from_host(client: "TrezorClient", hide: bool) -> None:
"""Enable or disable hiding passphrase coming from host.
This is a developer feature. Use with caution.
"""
return device.apply_settings(client, hide_passphrase_from_host=hide)
device.apply_settings(client, hide_passphrase_from_host=hide)

View File

@ -26,7 +26,7 @@ def get_public_key(
client: "TrezorClient",
address: str,
show_display: bool,
) -> messages.SolanaPublicKey:
) -> bytes:
"""Get Solana public key."""
address_n = tools.parse_path(address)
return solana.get_public_key(client, address_n, show_display)
@ -42,7 +42,7 @@ def get_address(
address: str,
show_display: bool,
chunkify: bool,
) -> messages.SolanaAddress:
) -> str:
"""Get Solana address."""
address_n = tools.parse_path(address)
return solana.get_address(client, address_n, show_display, chunkify)
@ -58,7 +58,7 @@ def sign_tx(
address: str,
serialized_tx: str,
additional_information_file: Optional[TextIO],
) -> messages.SolanaTxSignature:
) -> bytes:
"""Sign Solana transaction."""
address_n = tools.parse_path(address)

View File

@ -27,7 +27,7 @@ from . import exceptions, mapping, messages, models
from .log import DUMP_BYTES
from .messages import Capability
from .protobuf import MessageType
from .tools import expect, parse_path, session
from .tools import parse_path, session
if TYPE_CHECKING:
from .transport import Transport
@ -397,12 +397,7 @@ class TrezorClient(Generic[UI]):
else:
raise exceptions.OutdatedFirmwareError(OUTDATED_FIRMWARE_ERROR)
@expect(messages.Success, field="message", ret_type=str)
def ping(
self,
msg: str,
button_protection: bool = False,
) -> MessageType:
def ping(self, msg: str, button_protection: bool = False) -> str:
# We would like ping to work on any valid TrezorClient instance, but
# due to the protection modes, we need to go through self.call, and that will
# raise an exception if the firmware is too old.
@ -416,13 +411,18 @@ class TrezorClient(Generic[UI]):
# device is PIN-locked.
# respond and hope for the best
resp = self._callback_button(resp)
return resp
resp = messages.Success.ensure_isinstance(resp)
assert resp.message is not None
return resp.message
finally:
self.close()
return self.call(
messages.Ping(message=msg, button_protection=button_protection)
resp = self.call(
messages.Ping(message=msg, button_protection=button_protection),
expect=messages.Success,
)
assert resp.message is not None
return resp.message
def get_device_id(self) -> Optional[str]:
return self.features.device_id

View File

@ -44,10 +44,9 @@ from mnemonic import Mnemonic
from . import mapping, messages, models, protobuf
from .client import TrezorClient
from .exceptions import TrezorFailure
from .exceptions import TrezorFailure, UnexpectedMessageError
from .log import DUMP_BYTES
from .messages import DebugWaitType
from .tools import expect
if TYPE_CHECKING:
from typing_extensions import Protocol
@ -775,9 +774,10 @@ class DebugLink:
else:
self.t1_take_screenshots = False
@expect(messages.DebugLinkMemory, field="memory", ret_type=bytes)
def memory_read(self, address: int, length: int) -> protobuf.MessageType:
return self._call(messages.DebugLinkMemoryRead(address=address, length=length))
def memory_read(self, address: int, length: int) -> bytes:
return self._call(
messages.DebugLinkMemoryRead(address=address, length=length)
).memory
def memory_write(self, address: int, memory: bytes, flash: bool = False) -> None:
self._write(
@ -787,9 +787,11 @@ class DebugLink:
def flash_erase(self, sector: int) -> None:
self._write(messages.DebugLinkFlashErase(sector=sector))
@expect(messages.Success)
def erase_sd_card(self, format: bool = True) -> messages.Success:
return self._call(messages.DebugLinkEraseSdCard(format=format))
res = self._call(messages.DebugLinkEraseSdCard(format=format))
if not isinstance(res, messages.Success):
raise UnexpectedMessageError(messages.Success, res)
return res
def snapshot_legacy(self) -> None:
"""Snapshot the current state of the device."""
@ -1350,7 +1352,6 @@ class TrezorClientDebugLink(TrezorClient):
raise RuntimeError("Unexpected call")
@expect(messages.Success, field="message", ret_type=str)
def load_device(
client: "TrezorClient",
mnemonic: Union[str, Iterable[str]],
@ -1360,7 +1361,7 @@ def load_device(
skip_checksum: bool = False,
needs_backup: bool = False,
no_backup: bool = False,
) -> protobuf.MessageType:
) -> None:
if isinstance(mnemonic, str):
mnemonic = [mnemonic]
@ -1371,7 +1372,7 @@ def load_device(
"Device is initialized already. Call device.wipe() and try again."
)
resp = client.call(
client.call(
messages.LoadDevice(
mnemonics=mnemonics,
pin=pin,
@ -1380,25 +1381,25 @@ def load_device(
skip_checksum=skip_checksum,
needs_backup=needs_backup,
no_backup=no_backup,
)
),
expect=messages.Success,
)
client.init_device()
return resp
# keep the old name for compatibility
load_device_by_mnemonic = load_device
@expect(messages.Success, field="message", ret_type=str)
def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
def prodtest_t1(client: "TrezorClient") -> None:
if client.features.bootloader_mode is not True:
raise RuntimeError("Device must be in bootloader mode")
return client.call(
client.call(
messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
)
),
expect=messages.Success,
)
@ -1450,6 +1451,5 @@ def _is_emulator(debug_client: "TrezorClientDebugLink") -> bool:
return debug_client.features.fw_vendor == "EMULATOR"
@expect(messages.Success, field="message", ret_type=str)
def optiga_set_sec_max(client: "TrezorClient") -> protobuf.MessageType:
return client.call(messages.DebugLinkOptigaSetSecMax())
def optiga_set_sec_max(client: "TrezorClient") -> None:
client.call(messages.DebugLinkOptigaSetSecMax(), expect=messages.Success)

View File

@ -18,7 +18,6 @@ from __future__ import annotations
import hashlib
import hmac
import os
import random
import secrets
import time
@ -29,11 +28,16 @@ from slip10 import SLIP10
from . import messages
from .exceptions import Cancelled, TrezorException
from .tools import Address, expect, parse_path, session
from .tools import (
Address,
_deprecation_retval_helper,
_return_success,
parse_path,
session,
)
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
RECOVERY_BACK = "\x08" # backspace character, sent literally
@ -42,7 +46,6 @@ SLIP39_EXTENDABLE_MIN_VERSION = (2, 7, 1)
ENTROPY_CHECK_MIN_VERSION = (2, 8, 7)
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_settings(
client: "TrezorClient",
@ -57,7 +60,7 @@ def apply_settings(
experimental_features: Optional[bool] = None,
hide_passphrase_from_host: Optional[bool] = None,
haptic_feedback: Optional[bool] = None,
) -> "MessageType":
) -> str | None:
if language is not None:
warnings.warn(
"language ignored. Use change_language() to set device language.",
@ -76,87 +79,80 @@ def apply_settings(
haptic_feedback=haptic_feedback,
)
out = client.call(settings)
out = client.call(settings, expect=messages.Success)
client.refresh_features()
return out
return _return_success(out)
def _send_language_data(
client: "TrezorClient",
request: "messages.TranslationDataRequest",
language_data: bytes,
) -> "MessageType":
response: MessageType = request
) -> None:
response = request
while not isinstance(response, messages.Success):
assert isinstance(response, messages.TranslationDataRequest)
response = messages.TranslationDataRequest.ensure_isinstance(response)
data_length = response.data_length
data_offset = response.data_offset
chunk = language_data[data_offset : data_offset + data_length]
response = client.call(messages.TranslationDataAck(data_chunk=chunk))
return response
@expect(messages.Success, field="message", ret_type=str)
@session
def change_language(
client: "TrezorClient",
language_data: bytes,
show_display: bool | None = None,
) -> "MessageType":
) -> str | None:
data_length = len(language_data)
msg = messages.ChangeLanguage(data_length=data_length, show_display=show_display)
response = client.call(msg)
if data_length > 0:
assert isinstance(response, messages.TranslationDataRequest)
response = _send_language_data(client, response, language_data)
assert isinstance(response, messages.Success)
response = messages.TranslationDataRequest.ensure_isinstance(response)
_send_language_data(client, response, language_data)
else:
messages.Success.ensure_isinstance(response)
client.refresh_features() # changing the language in features
return response
return _return_success(messages.Success(message="Language changed."))
@expect(messages.Success, field="message", ret_type=str)
@session
def apply_flags(client: "TrezorClient", flags: int) -> "MessageType":
out = client.call(messages.ApplyFlags(flags=flags))
def apply_flags(client: "TrezorClient", flags: int) -> str | None:
out = client.call(messages.ApplyFlags(flags=flags), expect=messages.Success)
client.refresh_features()
return out
return _return_success(out)
@expect(messages.Success, field="message", ret_type=str)
@session
def change_pin(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangePin(remove=remove))
def change_pin(client: "TrezorClient", remove: bool = False) -> str | None:
ret = client.call(messages.ChangePin(remove=remove), expect=messages.Success)
client.refresh_features()
return ret
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> "MessageType":
ret = client.call(messages.ChangeWipeCode(remove=remove))
def change_wipe_code(client: "TrezorClient", remove: bool = False) -> str | None:
ret = client.call(messages.ChangeWipeCode(remove=remove), expect=messages.Success)
client.refresh_features()
return ret
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session
def sd_protect(
client: "TrezorClient", operation: messages.SdProtectOperationType
) -> "MessageType":
ret = client.call(messages.SdProtect(operation=operation))
) -> str | None:
ret = client.call(messages.SdProtect(operation=operation), expect=messages.Success)
client.refresh_features()
return ret
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session
def wipe(client: "TrezorClient") -> "MessageType":
ret = client.call(messages.WipeDevice())
def wipe(client: "TrezorClient") -> str | None:
ret = client.call(messages.WipeDevice(), expect=messages.Success)
if not client.features.bootloader_mode:
client.init_device()
return ret
return _return_success(ret)
@session
@ -173,7 +169,7 @@ def recover(
u2f_counter: Optional[int] = None,
*,
type: Optional[messages.RecoveryType] = None,
) -> "MessageType":
) -> messages.Success | None:
if language is not None:
warnings.warn(
"language ignored. Use change_language() to set device language.",
@ -235,8 +231,12 @@ def recover(
except Cancelled:
res = client.call(messages.Cancel())
# check that the result is a Success
res = messages.Success.ensure_isinstance(res)
# reinitialize the device
client.init_device()
return res
return _deprecation_retval_helper(res)
def is_slip39_backup_type(backup_type: messages.BackupType):
@ -279,13 +279,7 @@ def _seed_from_entropy(
return seed
@expect(messages.Success, field="message", ret_type=str)
def reset(*args: Any, **kwargs: Any) -> "MessageType":
return reset_entropy_check(*args, **kwargs)[0]
@session
def reset_entropy_check(
def reset(
client: "TrezorClient",
display_random: bool = False,
strength: Optional[int] = None,
@ -576,13 +570,12 @@ def _reset_with_entropycheck(
return xpubs
@expect(messages.Success, field="message", ret_type=str)
@session
def backup(
client: "TrezorClient",
group_threshold: Optional[int] = None,
groups: Iterable[tuple[int, int]] = (),
) -> "MessageType":
) -> str | None:
ret = client.call(
messages.BackupDevice(
group_threshold=group_threshold,
@ -590,38 +583,39 @@ def backup(
messages.Slip39Group(member_threshold=t, member_count=c)
for t, c in groups
],
)
),
expect=messages.Success,
)
client.refresh_features()
return ret
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
def cancel_authorization(client: "TrezorClient") -> "MessageType":
return client.call(messages.CancelAuthorization())
def cancel_authorization(client: "TrezorClient") -> str | None:
ret = client.call(messages.CancelAuthorization(), expect=messages.Success)
return _return_success(ret)
@expect(messages.UnlockedPathRequest, field="mac", ret_type=bytes)
def unlock_path(client: "TrezorClient", n: "Address") -> "MessageType":
resp = client.call(messages.UnlockPath(address_n=n))
def unlock_path(client: "TrezorClient", n: "Address") -> bytes:
resp = client.call(
messages.UnlockPath(address_n=n), expect=messages.UnlockedPathRequest
)
# Cancel the UnlockPath workflow now that we have the authentication code.
try:
client.call(messages.Cancel())
except Cancelled:
return resp
return resp.mac
else:
raise TrezorException("Unexpected response in UnlockPath flow")
@session
@expect(messages.Success, field="message", ret_type=str)
def reboot_to_bootloader(
client: "TrezorClient",
boot_command: messages.BootCommand = messages.BootCommand.STOP_AND_WAIT,
firmware_header: Optional[bytes] = None,
language_data: bytes = b"",
) -> "MessageType":
) -> str | None:
response = client.call(
messages.RebootToBootloader(
boot_command=boot_command,
@ -631,41 +625,42 @@ def reboot_to_bootloader(
)
if isinstance(response, messages.TranslationDataRequest):
response = _send_language_data(client, response, language_data)
return response
return _return_success(messages.Success(message=""))
@session
@expect(messages.Success, field="message", ret_type=str)
def show_device_tutorial(client: "TrezorClient") -> "MessageType":
return client.call(messages.ShowDeviceTutorial())
def show_device_tutorial(client: "TrezorClient") -> str | None:
ret = client.call(messages.ShowDeviceTutorial(), expect=messages.Success)
return _return_success(ret)
@session
@expect(messages.Success, field="message", ret_type=str)
def unlock_bootloader(client: "TrezorClient") -> "MessageType":
return client.call(messages.UnlockBootloader())
def unlock_bootloader(client: "TrezorClient") -> str | None:
ret = client.call(messages.UnlockBootloader(), expect=messages.Success)
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
@session
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> "MessageType":
def set_busy(client: "TrezorClient", expiry_ms: Optional[int]) -> str | None:
"""Sets or clears the busy state of the device.
In the busy state the device shows a "Do not disconnect" message instead of the homescreen.
Setting `expiry_ms=None` clears the busy state.
"""
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms))
ret = client.call(messages.SetBusy(expiry_ms=expiry_ms), expect=messages.Success)
client.refresh_features()
return ret
return _return_success(ret)
@expect(messages.AuthenticityProof)
def authenticate(client: "TrezorClient", challenge: bytes):
return client.call(messages.AuthenticateDevice(challenge=challenge))
def authenticate(
client: "TrezorClient", challenge: bytes
) -> messages.AuthenticityProof:
return client.call(
messages.AuthenticateDevice(challenge=challenge),
expect=messages.AuthenticityProof,
)
@expect(messages.Success, field="message", ret_type=str)
def set_brightness(
client: "TrezorClient", value: Optional[int] = None
) -> "MessageType":
return client.call(messages.SetBrightness(value=value))
def set_brightness(client: "TrezorClient", value: Optional[int] = None) -> str | None:
ret = client.call(messages.SetBrightness(value=value), expect=messages.Success)
return _return_success(ret)

View File

@ -18,11 +18,10 @@ from datetime import datetime
from typing import TYPE_CHECKING, List, Tuple
from . import exceptions, messages
from .tools import b58decode, expect, session
from .tools import b58decode, session
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
@ -319,14 +318,13 @@ def parse_transaction_json(
# ====== Client functions ====== #
@expect(messages.EosPublicKey)
def get_public_key(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
response = client.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display)
) -> messages.EosPublicKey:
return client.call(
messages.EosGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EosPublicKey,
)
return response
@session

View File

@ -18,11 +18,10 @@ import re
from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple
from . import definitions, exceptions, messages
from .tools import expect, prepare_message_bytes, session, unharden
from .tools import prepare_message_bytes, session, unharden
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
@ -161,30 +160,32 @@ def network_from_address_n(
# ====== Client functions ====== #
@expect(messages.EthereumAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
n: "Address",
show_display: bool = False,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
return client.call(
) -> str:
resp = client.call(
messages.EthereumGetAddress(
address_n=n,
show_display=show_display,
encoded_network=encoded_network,
chunkify=chunkify,
),
expect=messages.EthereumAddress,
)
)
assert resp.address is not None
return resp.address
@expect(messages.EthereumPublicKey)
def get_public_node(
client: "TrezorClient", n: "Address", show_display: bool = False
) -> "MessageType":
) -> messages.EthereumPublicKey:
return client.call(
messages.EthereumGetPublicKey(address_n=n, show_display=show_display)
messages.EthereumGetPublicKey(address_n=n, show_display=show_display),
expect=messages.EthereumPublicKey,
)
@ -297,25 +298,24 @@ def sign_tx_eip1559(
return response.signature_v, response.signature_r, response.signature_s
@expect(messages.EthereumMessageSignature)
def sign_message(
client: "TrezorClient",
n: "Address",
message: AnyStr,
encoded_network: Optional[bytes] = None,
chunkify: bool = False,
) -> "MessageType":
) -> messages.EthereumMessageSignature:
return client.call(
messages.EthereumSignMessage(
address_n=n,
message=prepare_message_bytes(message),
encoded_network=encoded_network,
chunkify=chunkify,
)
),
expect=messages.EthereumMessageSignature,
)
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data(
client: "TrezorClient",
n: "Address",
@ -323,7 +323,7 @@ def sign_typed_data(
*,
metamask_v4_compat: bool = True,
definitions: Optional[messages.EthereumDefinitions] = None,
) -> "MessageType":
) -> messages.EthereumTypedDataSignature:
data = sanitize_typed_data(data)
types = data["types"]
@ -387,7 +387,7 @@ def sign_typed_data(
request = messages.EthereumTypedDataValueAck(value=encoded_data)
response = client.call(request)
return response
return messages.EthereumTypedDataSignature.ensure_isinstance(response)
def verify_message(
@ -398,32 +398,33 @@ def verify_message(
chunkify: bool = False,
) -> bool:
try:
resp = client.call(
client.call(
messages.EthereumVerifyMessage(
address=address,
signature=signature,
message=prepare_message_bytes(message),
chunkify=chunkify,
),
expect=messages.Success,
)
)
return True
except exceptions.TrezorFailure:
return False
return isinstance(resp, messages.Success)
@expect(messages.EthereumTypedDataSignature)
def sign_typed_data_hash(
client: "TrezorClient",
n: "Address",
domain_hash: bytes,
message_hash: Optional[bytes],
encoded_network: Optional[bytes] = None,
) -> "MessageType":
) -> messages.EthereumTypedDataSignature:
return client.call(
messages.EthereumSignTypedHash(
address_n=n,
domain_separator_hash=domain_hash,
message_hash=message_hash,
encoded_network=encoded_network,
)
),
expect=messages.EthereumTypedDataSignature,
)

View File

@ -14,42 +14,45 @@
# You should have received a copy of the License along with this library.
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from typing import TYPE_CHECKING, List
from __future__ import annotations
from typing import TYPE_CHECKING, Sequence
from . import messages
from .tools import expect
from .tools import _return_success
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
@expect(
messages.WebAuthnCredentials,
field="credentials",
ret_type=List[messages.WebAuthnCredential],
)
def list_credentials(client: "TrezorClient") -> "MessageType":
return client.call(messages.WebAuthnListResidentCredentials())
@expect(messages.Success, field="message", ret_type=str)
def add_credential(client: "TrezorClient", credential_id: bytes) -> "MessageType":
def list_credentials(client: "TrezorClient") -> Sequence[messages.WebAuthnCredential]:
return client.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id)
messages.WebAuthnListResidentCredentials(), expect=messages.WebAuthnCredentials
).credentials
def add_credential(client: "TrezorClient", credential_id: bytes) -> str | None:
ret = client.call(
messages.WebAuthnAddResidentCredential(credential_id=credential_id),
expect=messages.Success,
)
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
def remove_credential(client: "TrezorClient", index: int) -> "MessageType":
return client.call(messages.WebAuthnRemoveResidentCredential(index=index))
def remove_credential(client: "TrezorClient", index: int) -> str | None:
ret = client.call(
messages.WebAuthnRemoveResidentCredential(index=index), expect=messages.Success
)
return _return_success(ret)
@expect(messages.Success, field="message", ret_type=str)
def set_counter(client: "TrezorClient", u2f_counter: int) -> "MessageType":
return client.call(messages.SetU2FCounter(u2f_counter=u2f_counter))
def set_counter(client: "TrezorClient", u2f_counter: int) -> str | None:
ret = client.call(
messages.SetU2FCounter(u2f_counter=u2f_counter), expect=messages.Success
)
return _return_success(ret)
@expect(messages.NextU2FCounter, field="u2f_counter", ret_type=int)
def get_next_counter(client: "TrezorClient") -> "MessageType":
return client.call(messages.GetNextU2FCounter())
def get_next_counter(client: "TrezorClient") -> int:
ret = client.call(messages.GetNextU2FCounter(), expect=messages.NextU2FCounter)
return ret.u2f_counter

View File

@ -20,7 +20,7 @@ from hashlib import blake2s
from typing_extensions import Protocol, TypeGuard
from .. import messages
from ..tools import expect, session
from ..tools import session
from .core import VendorFirmware
from .legacy import LegacyFirmware, LegacyV2Firmware
@ -106,6 +106,7 @@ def update(
raise RuntimeError(f"Unexpected message {resp}")
@expect(messages.FirmwareHash, field="hash", ret_type=bytes)
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]):
return client.call(messages.GetFirmwareHash(challenge=challenge))
def get_hash(client: "TrezorClient", challenge: t.Optional[bytes]) -> bytes:
return client.call(
messages.GetFirmwareHash(challenge=challenge), expect=messages.FirmwareHash
).hash

View File

@ -17,54 +17,50 @@
from typing import TYPE_CHECKING, Optional
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
@expect(messages.Entropy, field="entropy", ret_type=bytes)
def get_entropy(client: "TrezorClient", size: int) -> "MessageType":
return client.call(messages.GetEntropy(size=size))
def get_entropy(client: "TrezorClient", size: int) -> bytes:
return client.call(messages.GetEntropy(size=size), expect=messages.Entropy).entropy
@expect(messages.SignedIdentity)
def sign_identity(
client: "TrezorClient",
identity: messages.IdentityType,
challenge_hidden: bytes,
challenge_visual: str,
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
) -> messages.SignedIdentity:
return client.call(
messages.SignIdentity(
identity=identity,
challenge_hidden=challenge_hidden,
challenge_visual=challenge_visual,
ecdsa_curve_name=ecdsa_curve_name,
)
),
expect=messages.SignedIdentity,
)
@expect(messages.ECDHSessionKey)
def get_ecdh_session_key(
client: "TrezorClient",
identity: messages.IdentityType,
peer_public_key: bytes,
ecdsa_curve_name: Optional[str] = None,
) -> "MessageType":
) -> messages.ECDHSessionKey:
return client.call(
messages.GetECDHSessionKey(
identity=identity,
peer_public_key=peer_public_key,
ecdsa_curve_name=ecdsa_curve_name,
)
),
expect=messages.ECDHSessionKey,
)
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def encrypt_keyvalue(
client: "TrezorClient",
n: "Address",
@ -73,7 +69,7 @@ def encrypt_keyvalue(
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> "MessageType":
) -> bytes:
return client.call(
messages.CipherKeyValue(
address_n=n,
@ -83,11 +79,11 @@ def encrypt_keyvalue(
ask_on_encrypt=ask_on_encrypt,
ask_on_decrypt=ask_on_decrypt,
iv=iv,
)
)
),
expect=messages.CipheredKeyValue,
).value
@expect(messages.CipheredKeyValue, field="value", ret_type=bytes)
def decrypt_keyvalue(
client: "TrezorClient",
n: "Address",
@ -96,7 +92,7 @@ def decrypt_keyvalue(
ask_on_encrypt: bool = True,
ask_on_decrypt: bool = True,
iv: bytes = b"",
) -> "MessageType":
) -> bytes:
return client.call(
messages.CipherKeyValue(
address_n=n,
@ -106,10 +102,10 @@ def decrypt_keyvalue(
ask_on_encrypt=ask_on_encrypt,
ask_on_decrypt=ask_on_decrypt,
iv=iv,
)
)
),
expect=messages.CipheredKeyValue,
).value
@expect(messages.Nonce, field="nonce", ret_type=bytes)
def get_nonce(client: "TrezorClient"):
return client.call(messages.GetNonce())
def get_nonce(client: "TrezorClient") -> bytes:
return client.call(messages.GetNonce(), expect=messages.Nonce).nonce

View File

@ -17,11 +17,9 @@
from typing import TYPE_CHECKING
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
@ -31,30 +29,30 @@ if TYPE_CHECKING:
# FAKECHAIN = 3
@expect(messages.MoneroAddress, field="address", ret_type=bytes)
def get_address(
client: "TrezorClient",
n: "Address",
show_display: bool = False,
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
chunkify: bool = False,
) -> "MessageType":
) -> bytes:
return client.call(
messages.MoneroGetAddress(
address_n=n,
show_display=show_display,
network_type=network_type,
chunkify=chunkify,
)
)
),
expect=messages.MoneroAddress,
).address
@expect(messages.MoneroWatchKey)
def get_watch_key(
client: "TrezorClient",
n: "Address",
network_type: messages.MoneroNetworkType = messages.MoneroNetworkType.MAINNET,
) -> "MessageType":
) -> messages.MoneroWatchKey:
return client.call(
messages.MoneroGetWatchKey(address_n=n, network_type=network_type)
messages.MoneroGetWatchKey(address_n=n, network_type=network_type),
expect=messages.MoneroWatchKey,
)

View File

@ -18,11 +18,9 @@ import json
from typing import TYPE_CHECKING
from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
TYPE_TRANSACTION_TRANSFER = 0x0101
@ -196,25 +194,24 @@ def create_sign_tx(transaction: dict, chunkify: bool = False) -> messages.NEMSig
# ====== Client functions ====== #
@expect(messages.NEMAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
n: "Address",
network: int,
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.NEMGetAddress(
address_n=n, network=network, show_display=show_display, chunkify=chunkify
)
)
),
expect=messages.NEMAddress,
).address
@expect(messages.NEMSignedTx)
def sign_tx(
client: "TrezorClient", n: "Address", transaction: dict, chunkify: bool = False
) -> "MessageType":
) -> messages.NEMSignedTx:
try:
msg = create_sign_tx(transaction, chunkify=chunkify)
except ValueError as e:
@ -222,4 +219,4 @@ def sign_tx(
assert msg.transaction is not None
msg.transaction.address_n = n
return client.call(msg)
return client.call(msg, expect=messages.NEMSignedTx)

View File

@ -35,6 +35,8 @@ from itertools import zip_longest
import typing_extensions as tx
from .exceptions import UnexpectedMessageError
if t.TYPE_CHECKING:
from IPython.lib.pretty import RepresentationPrinter # noqa: I900
@ -312,6 +314,27 @@ class MessageType:
dump_message(data, self)
return len(data.getvalue())
@classmethod
def ensure_isinstance(cls, msg: t.Any) -> tx.Self:
"""Ensure that the received `msg` is an instance of this class.
If `msg` is not an instance of this class, raise an `UnexpectedMessageError`.
otherwise, return it. This is useful for type-checking like so:
>>> msg = client.call(SomeMessage())
>>> if isinstance(msg, Foo):
>>> return msg.foo_attr # attribute of Foo, type-checks OK
>>> else:
>>> msg = Bar.ensure_isinstance(msg) # raises if msg is something else
>>> return msg.bar_attr # attribute of Bar, type-checks OK
If there is just one expected message, you should use the `expect` parameter of
`Client.call` instead.
"""
if not isinstance(msg, cls):
raise UnexpectedMessageError(cls, msg)
return msg
class LimitedReader:
def __init__(self, reader: Reader, limit: int) -> None:

View File

@ -18,41 +18,39 @@ from typing import TYPE_CHECKING
from . import messages
from .protobuf import dict_to_proto
from .tools import dict_from_camelcase, expect
from .tools import dict_from_camelcase
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
REQUIRED_FIELDS = ("Fee", "Sequence", "TransactionType", "Payment")
REQUIRED_PAYMENT_FIELDS = ("Amount", "Destination")
@expect(messages.RippleAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.RippleGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
)
),
expect=messages.RippleAddress,
).address
@expect(messages.RippleSignedTx)
def sign_tx(
client: "TrezorClient",
address_n: "Address",
msg: messages.RippleSignTx,
chunkify: bool = False,
) -> "MessageType":
) -> messages.RippleSignedTx:
msg.address_n = address_n
msg.chunkify = chunkify
return client.call(msg)
return client.call(msg, expect=messages.RippleSignedTx)
def create_sign_tx_msg(transaction: dict) -> messages.RippleSignTx:

View File

@ -1,51 +1,49 @@
from typing import TYPE_CHECKING, List, Optional
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
@expect(messages.SolanaPublicKey)
def get_public_key(
client: "TrezorClient",
address_n: List[int],
show_display: bool,
) -> "MessageType":
) -> bytes:
return client.call(
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display)
)
messages.SolanaGetPublicKey(address_n=address_n, show_display=show_display),
expect=messages.SolanaPublicKey,
).public_key
@expect(messages.SolanaAddress)
def get_address(
client: "TrezorClient",
address_n: List[int],
show_display: bool,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.SolanaGetAddress(
address_n=address_n,
show_display=show_display,
chunkify=chunkify,
)
)
),
expect=messages.SolanaAddress,
).address
@expect(messages.SolanaTxSignature)
def sign_tx(
client: "TrezorClient",
address_n: List[int],
serialized_tx: bytes,
additional_info: Optional[messages.SolanaTxAdditionalInfo],
) -> "MessageType":
) -> bytes:
return client.call(
messages.SolanaSignTx(
address_n=address_n,
serialized_tx=serialized_tx,
additional_info=additional_info,
)
)
),
expect=messages.SolanaTxSignature,
).signature

View File

@ -18,11 +18,9 @@ from decimal import Decimal
from typing import TYPE_CHECKING, List, Tuple, Union
from . import exceptions, messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
StellarMessageType = Union[
@ -323,18 +321,18 @@ def _read_asset(asset: "Asset") -> messages.StellarAsset:
# ====== Client functions ====== #
@expect(messages.StellarAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.StellarGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
)
),
expect=messages.StellarAddress,
).address
def sign_tx(
@ -364,10 +362,7 @@ def sign_tx(
"Reached end of operations without a signature."
) from None
if not isinstance(resp, messages.StellarSignedTx):
raise exceptions.TrezorException(
f"Unexpected message: {resp.__class__.__name__}"
)
resp = messages.StellarSignedTx.ensure_isinstance(resp)
if operations:
raise exceptions.TrezorException(

View File

@ -17,49 +17,46 @@
from typing import TYPE_CHECKING
from . import messages
from .tools import expect
if TYPE_CHECKING:
from .client import TrezorClient
from .protobuf import MessageType
from .tools import Address
@expect(messages.TezosAddress, field="address", ret_type=str)
def get_address(
client: "TrezorClient",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.TezosGetAddress(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
)
),
expect=messages.TezosAddress,
).address
@expect(messages.TezosPublicKey, field="public_key", ret_type=str)
def get_public_key(
client: "TrezorClient",
address_n: "Address",
show_display: bool = False,
chunkify: bool = False,
) -> "MessageType":
) -> str:
return client.call(
messages.TezosGetPublicKey(
address_n=address_n, show_display=show_display, chunkify=chunkify
)
)
),
expect=messages.TezosPublicKey,
).public_key
@expect(messages.TezosSignedTx)
def sign_tx(
client: "TrezorClient",
address_n: "Address",
sign_tx_msg: messages.TezosSignTx,
chunkify: bool = False,
) -> "MessageType":
) -> messages.TezosSignedTx:
sign_tx_msg.address_n = address_n
sign_tx_msg.chunkify = chunkify
return client.call(sign_tx_msg)
return client.call(sign_tx_msg, expect=messages.TezosSignedTx)