1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-24 00:48:45 +00:00

chore(core): decrease cardano size by 2290 bytes

This commit is contained in:
grdddj 2022-09-20 11:52:54 +02:00 committed by matejcik
parent 26fd0de198
commit 0c8528821f
22 changed files with 1063 additions and 962 deletions

View File

@ -1,14 +1,20 @@
from micropython import const from micropython import const
from typing import Any from typing import TYPE_CHECKING
from trezor import messages, wire
from trezor.crypto import base58 from trezor.crypto import base58
from trezor.enums import CardanoAddressType from trezor.enums import CardanoAddressType
from trezor.wire import ProcessError
from . import byron_addresses, seed from . import byron_addresses
from .helpers import ADDRESS_KEY_HASH_SIZE, SCRIPT_HASH_SIZE, bech32, network_ids from .helpers import bech32
from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
from .helpers.utils import get_public_key_hash, variable_length_encode from .helpers.utils import get_public_key_hash
if TYPE_CHECKING:
from typing import Any
from trezor import messages
from .seed import Keychain
ADDRESS_TYPES_SHELLEY = ( ADDRESS_TYPES_SHELLEY = (
CardanoAddressType.BASE, CardanoAddressType.BASE,
@ -44,58 +50,64 @@ _MAX_ADDRESS_BYTES_LENGTH = const(65)
def assert_params_cond(condition: bool) -> None: def assert_params_cond(condition: bool) -> None:
if not condition: if not condition:
raise wire.ProcessError("Invalid address parameters") raise ProcessError("Invalid address parameters")
def validate_address_parameters( def validate_address_parameters(
parameters: messages.CardanoAddressParametersType, parameters: messages.CardanoAddressParametersType,
) -> None: ) -> None:
from . import seed
_validate_address_parameters_structure(parameters) _validate_address_parameters_structure(parameters)
address_type = parameters.address_type # local_cache_attribute
address_n = parameters.address_n # local_cache_attribute
address_n_staking = parameters.address_n_staking # local_cache_attribute
script_payment_hash = parameters.script_payment_hash # local_cache_attribute
is_shelley_path = seed.is_shelley_path # local_cache_attribute
CAT = CardanoAddressType # local_cache_global
if parameters.address_type == CardanoAddressType.BYRON: if address_type == CAT.BYRON:
assert_params_cond(seed.is_byron_path(parameters.address_n)) assert_params_cond(seed.is_byron_path(address_n))
elif parameters.address_type == CardanoAddressType.BASE: elif address_type == CAT.BASE:
assert_params_cond(seed.is_shelley_path(parameters.address_n)) assert_params_cond(is_shelley_path(address_n))
_validate_base_address_staking_info( _validate_base_address_staking_info(
parameters.address_n_staking, parameters.staking_key_hash address_n_staking, parameters.staking_key_hash
) )
elif parameters.address_type == CardanoAddressType.BASE_SCRIPT_KEY: elif address_type == CAT.BASE_SCRIPT_KEY:
_validate_script_hash(parameters.script_payment_hash) _validate_script_hash(script_payment_hash)
_validate_base_address_staking_info( _validate_base_address_staking_info(
parameters.address_n_staking, parameters.staking_key_hash address_n_staking, parameters.staking_key_hash
) )
elif parameters.address_type == CardanoAddressType.BASE_KEY_SCRIPT: elif address_type == CAT.BASE_KEY_SCRIPT:
assert_params_cond(seed.is_shelley_path(parameters.address_n)) assert_params_cond(is_shelley_path(address_n))
_validate_script_hash(parameters.script_staking_hash) _validate_script_hash(parameters.script_staking_hash)
elif parameters.address_type == CardanoAddressType.BASE_SCRIPT_SCRIPT: elif address_type == CAT.BASE_SCRIPT_SCRIPT:
_validate_script_hash(parameters.script_payment_hash) _validate_script_hash(script_payment_hash)
_validate_script_hash(parameters.script_staking_hash) _validate_script_hash(parameters.script_staking_hash)
elif parameters.address_type == CardanoAddressType.POINTER: elif address_type == CAT.POINTER:
assert_params_cond(seed.is_shelley_path(parameters.address_n)) assert_params_cond(is_shelley_path(address_n))
assert_params_cond(parameters.certificate_pointer is not None) assert_params_cond(parameters.certificate_pointer is not None)
elif parameters.address_type == CardanoAddressType.POINTER_SCRIPT: elif address_type == CAT.POINTER_SCRIPT:
_validate_script_hash(parameters.script_payment_hash) _validate_script_hash(script_payment_hash)
assert_params_cond(parameters.certificate_pointer is not None) assert_params_cond(parameters.certificate_pointer is not None)
elif parameters.address_type == CardanoAddressType.ENTERPRISE: elif address_type == CAT.ENTERPRISE:
assert_params_cond(seed.is_shelley_path(parameters.address_n)) assert_params_cond(is_shelley_path(address_n))
elif parameters.address_type == CardanoAddressType.ENTERPRISE_SCRIPT: elif address_type == CAT.ENTERPRISE_SCRIPT:
_validate_script_hash(parameters.script_payment_hash) _validate_script_hash(script_payment_hash)
elif parameters.address_type == CardanoAddressType.REWARD: elif address_type == CAT.REWARD:
assert_params_cond(seed.is_shelley_path(parameters.address_n_staking)) assert_params_cond(is_shelley_path(address_n_staking))
assert_params_cond( assert_params_cond(SCHEMA_STAKING_ANY_ACCOUNT.match(address_n_staking))
SCHEMA_STAKING_ANY_ACCOUNT.match(parameters.address_n_staking)
)
elif parameters.address_type == CardanoAddressType.REWARD_SCRIPT: elif address_type == CAT.REWARD_SCRIPT:
_validate_script_hash(parameters.script_staking_hash) _validate_script_hash(parameters.script_staking_hash)
else: else:
@ -105,75 +117,76 @@ def validate_address_parameters(
def _validate_address_parameters_structure( def _validate_address_parameters_structure(
parameters: messages.CardanoAddressParametersType, parameters: messages.CardanoAddressParametersType,
) -> None: ) -> None:
address_n = parameters.address_n address_n = parameters.address_n # local_cache_attribute
address_n_staking = parameters.address_n_staking address_n_staking = parameters.address_n_staking # local_cache_attribute
staking_key_hash = parameters.staking_key_hash staking_key_hash = parameters.staking_key_hash # local_cache_attribute
certificate_pointer = parameters.certificate_pointer certificate_pointer = parameters.certificate_pointer # local_cache_attribute
script_payment_hash = parameters.script_payment_hash script_payment_hash = parameters.script_payment_hash # local_cache_attribute
script_staking_hash = parameters.script_staking_hash script_staking_hash = parameters.script_staking_hash # local_cache_attribute
CAT = CardanoAddressType # local_cache_global
fields_to_be_empty: dict[CardanoAddressType, tuple[Any, ...]] = { fields_to_be_empty: dict[CAT, tuple[Any, ...]] = {
CardanoAddressType.BASE: ( CAT.BASE: (
certificate_pointer, certificate_pointer,
script_payment_hash, script_payment_hash,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.BASE_KEY_SCRIPT: ( CAT.BASE_KEY_SCRIPT: (
address_n_staking, address_n_staking,
certificate_pointer, certificate_pointer,
script_payment_hash, script_payment_hash,
), ),
CardanoAddressType.BASE_SCRIPT_KEY: ( CAT.BASE_SCRIPT_KEY: (
address_n, address_n,
certificate_pointer, certificate_pointer,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.BASE_SCRIPT_SCRIPT: ( CAT.BASE_SCRIPT_SCRIPT: (
address_n, address_n,
address_n_staking, address_n_staking,
certificate_pointer, certificate_pointer,
), ),
CardanoAddressType.POINTER: ( CAT.POINTER: (
address_n_staking, address_n_staking,
staking_key_hash, staking_key_hash,
script_payment_hash, script_payment_hash,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.POINTER_SCRIPT: ( CAT.POINTER_SCRIPT: (
address_n, address_n,
address_n_staking, address_n_staking,
staking_key_hash, staking_key_hash,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.ENTERPRISE: ( CAT.ENTERPRISE: (
address_n_staking, address_n_staking,
staking_key_hash, staking_key_hash,
certificate_pointer, certificate_pointer,
script_payment_hash, script_payment_hash,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.ENTERPRISE_SCRIPT: ( CAT.ENTERPRISE_SCRIPT: (
address_n, address_n,
address_n_staking, address_n_staking,
staking_key_hash, staking_key_hash,
certificate_pointer, certificate_pointer,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.BYRON: ( CAT.BYRON: (
address_n_staking, address_n_staking,
staking_key_hash, staking_key_hash,
certificate_pointer, certificate_pointer,
script_payment_hash, script_payment_hash,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.REWARD: ( CAT.REWARD: (
address_n, address_n,
staking_key_hash, staking_key_hash,
certificate_pointer, certificate_pointer,
script_payment_hash, script_payment_hash,
script_staking_hash, script_staking_hash,
), ),
CardanoAddressType.REWARD_SCRIPT: ( CAT.REWARD_SCRIPT: (
address_n, address_n,
address_n_staking, address_n_staking,
staking_key_hash, staking_key_hash,
@ -190,6 +203,8 @@ def _validate_base_address_staking_info(
staking_path: list[int], staking_path: list[int],
staking_key_hash: bytes | None, staking_key_hash: bytes | None,
) -> None: ) -> None:
from .helpers import ADDRESS_KEY_HASH_SIZE
assert_params_cond(not (staking_key_hash and staking_path)) assert_params_cond(not (staking_key_hash and staking_path))
if staking_key_hash: if staking_key_hash:
@ -197,10 +212,12 @@ def _validate_base_address_staking_info(
elif staking_path: elif staking_path:
assert_params_cond(SCHEMA_STAKING_ANY_ACCOUNT.match(staking_path)) assert_params_cond(SCHEMA_STAKING_ANY_ACCOUNT.match(staking_path))
else: else:
raise wire.ProcessError("Invalid address parameters") raise ProcessError("Invalid address parameters")
def _validate_script_hash(script_hash: bytes | None) -> None: def _validate_script_hash(script_hash: bytes | None) -> None:
from .helpers import SCRIPT_HASH_SIZE
assert_params_cond(script_hash is not None and len(script_hash) == SCRIPT_HASH_SIZE) assert_params_cond(script_hash is not None and len(script_hash) == SCRIPT_HASH_SIZE)
@ -216,7 +233,7 @@ def validate_output_address_parameters(
def assert_cond(condition: bool) -> None: def assert_cond(condition: bool) -> None:
if not condition: if not condition:
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
def _validate_and_get_type(address: str, protocol_magic: int, network_id: int) -> int: def _validate_and_get_type(address: str, protocol_magic: int, network_id: int) -> int:
@ -233,9 +250,24 @@ def _validate_and_get_type(address: str, protocol_magic: int, network_id: int) -
if address_type == CardanoAddressType.BYRON: if address_type == CardanoAddressType.BYRON:
byron_addresses.validate(address_bytes, protocol_magic) byron_addresses.validate(address_bytes, protocol_magic)
elif address_type in ADDRESS_TYPES_SHELLEY: elif address_type in ADDRESS_TYPES_SHELLEY:
_validate_shelley_address(address, address_bytes, network_id) # _validate_shelley_address
# _validate_size
assert_cond(
_MIN_ADDRESS_BYTES_LENGTH <= len(address_bytes) <= _MAX_ADDRESS_BYTES_LENGTH
)
# _validate_bech32_hrp
valid_hrp = _get_bech32_hrp(address_type, network_id)
# get_hrp
bech32_hrp = address.rsplit(bech32.HRP_SEPARATOR, 1)[0]
assert_cond(valid_hrp == bech32_hrp)
# _validate_network_id
if _get_network_id(address_bytes) != network_id:
raise ProcessError("Output address network mismatch")
else: else:
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
return address_type return address_type
@ -262,7 +294,7 @@ def get_bytes_unsafe(address: str) -> bytes:
try: try:
address_bytes = base58.decode(address) address_bytes = base58.decode(address)
except ValueError: except ValueError:
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
return address_bytes return address_bytes
@ -271,32 +303,9 @@ def get_type(address: bytes) -> CardanoAddressType:
return address[0] >> 4 # type: ignore [int-into-enum] return address[0] >> 4 # type: ignore [int-into-enum]
def _validate_shelley_address(
address_str: str, address_bytes: bytes, network_id: int
) -> None:
address_type = get_type(address_bytes)
_validate_size(address_bytes)
_validate_bech32_hrp(address_str, address_type, network_id)
_validate_network_id(address_bytes, network_id)
def _validate_size(address_bytes: bytes) -> None:
assert_cond(
_MIN_ADDRESS_BYTES_LENGTH <= len(address_bytes) <= _MAX_ADDRESS_BYTES_LENGTH
)
def _validate_bech32_hrp(
address_str: str, address_type: CardanoAddressType, network_id: int
) -> None:
valid_hrp = _get_bech32_hrp(address_type, network_id)
bech32_hrp = bech32.get_hrp(address_str)
assert_cond(valid_hrp == bech32_hrp)
def _get_bech32_hrp(address_type: CardanoAddressType, network_id: int) -> str: def _get_bech32_hrp(address_type: CardanoAddressType, network_id: int) -> str:
from .helpers import bech32, network_ids
if address_type == CardanoAddressType.BYRON: if address_type == CardanoAddressType.BYRON:
# Byron address uses base58 encoding # Byron address uses base58 encoding
raise ValueError raise ValueError
@ -313,17 +322,12 @@ def _get_bech32_hrp(address_type: CardanoAddressType, network_id: int) -> str:
return bech32.HRP_TESTNET_ADDRESS return bech32.HRP_TESTNET_ADDRESS
def _validate_network_id(address: bytes, network_id: int) -> None:
if _get_network_id(address) != network_id:
raise wire.ProcessError("Output address network mismatch")
def _get_network_id(address: bytes) -> int: def _get_network_id(address: bytes) -> int:
return address[0] & 0x0F return address[0] & 0x0F
def derive_human_readable( def derive_human_readable(
keychain: seed.Keychain, keychain: Keychain,
parameters: messages.CardanoAddressParametersType, parameters: messages.CardanoAddressParametersType,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
@ -344,7 +348,7 @@ def encode_human_readable(address_bytes: bytes) -> str:
def derive_bytes( def derive_bytes(
keychain: seed.Keychain, keychain: Keychain,
parameters: messages.CardanoAddressParametersType, parameters: messages.CardanoAddressParametersType,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
@ -360,11 +364,13 @@ def derive_bytes(
def _derive_shelley_address( def _derive_shelley_address(
keychain: seed.Keychain, keychain: Keychain,
parameters: messages.CardanoAddressParametersType, parameters: messages.CardanoAddressParametersType,
network_id: int, network_id: int,
) -> bytes: ) -> bytes:
header = _create_header(parameters.address_type, network_id) # _create_header
header_int = parameters.address_type << 4 | network_id
header = header_int.to_bytes(1, "little")
payment_part = _get_payment_part(keychain, parameters) payment_part = _get_payment_part(keychain, parameters)
staking_part = _get_staking_part(keychain, parameters) staking_part = _get_staking_part(keychain, parameters)
@ -372,13 +378,8 @@ def _derive_shelley_address(
return header + payment_part + staking_part return header + payment_part + staking_part
def _create_header(address_type: CardanoAddressType, network_id: int) -> bytes:
header: int = address_type << 4 | network_id
return header.to_bytes(1, "little")
def _get_payment_part( def _get_payment_part(
keychain: seed.Keychain, parameters: messages.CardanoAddressParametersType keychain: Keychain, parameters: messages.CardanoAddressParametersType
) -> bytes: ) -> bytes:
if parameters.address_n: if parameters.address_n:
return get_public_key_hash(keychain, parameters.address_n) return get_public_key_hash(keychain, parameters.address_n)
@ -389,8 +390,10 @@ def _get_payment_part(
def _get_staking_part( def _get_staking_part(
keychain: seed.Keychain, parameters: messages.CardanoAddressParametersType keychain: Keychain, parameters: messages.CardanoAddressParametersType
) -> bytes: ) -> bytes:
from .helpers.utils import variable_length_encode
if parameters.staking_key_hash: if parameters.staking_key_hash:
return parameters.staking_key_hash return parameters.staking_key_hash
elif parameters.address_n_staking: elif parameters.address_n_staking:
@ -398,16 +401,11 @@ def _get_staking_part(
elif parameters.script_staking_hash: elif parameters.script_staking_hash:
return parameters.script_staking_hash return parameters.script_staking_hash
elif parameters.certificate_pointer: elif parameters.certificate_pointer:
return _encode_certificate_pointer(parameters.certificate_pointer) # _encode_certificate_pointer
else: pointer = parameters.certificate_pointer
return bytes()
def _encode_certificate_pointer(
pointer: messages.CardanoBlockchainPointerType,
) -> bytes:
block_index_encoded = variable_length_encode(pointer.block_index) block_index_encoded = variable_length_encode(pointer.block_index)
tx_index_encoded = variable_length_encode(pointer.tx_index) tx_index_encoded = variable_length_encode(pointer.tx_index)
certificate_index_encoded = variable_length_encode(pointer.certificate_index) certificate_index_encoded = variable_length_encode(pointer.certificate_index)
return bytes(block_index_encoded + tx_index_encoded + certificate_index_encoded) return bytes(block_index_encoded + tx_index_encoded + certificate_index_encoded)
else:
return bytes()

View File

@ -1,19 +1,12 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import messages, wire
from trezor.crypto import hashlib from trezor.crypto import hashlib
from trezor.crypto.curve import ed25519 from trezor.enums import CardanoAddressType, CardanoGovernanceRegistrationFormat
from trezor.enums import (
CardanoAddressType,
CardanoGovernanceRegistrationFormat,
CardanoTxAuxiliaryDataSupplementType,
)
from apps.common import cbor from apps.common import cbor
from . import addresses, layout from . import addresses, layout
from .helpers import bech32
from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
from .helpers.utils import derive_public_key from .helpers.utils import derive_public_key
@ -26,6 +19,9 @@ if TYPE_CHECKING:
int, GovernanceRegistrationPayload | GovernanceRegistrationSignature int, GovernanceRegistrationPayload | GovernanceRegistrationSignature
] ]
from trezor import messages
from trezor.wire import Context
from . import seed from . import seed
_AUXILIARY_DATA_HASH_SIZE = const(32) _AUXILIARY_DATA_HASH_SIZE = const(32)
@ -40,6 +36,8 @@ _DEFAULT_VOTING_PURPOSE = const(0)
def assert_cond(condition: bool) -> None: def assert_cond(condition: bool) -> None:
from trezor import wire
if not condition: if not condition:
raise wire.ProcessError("Invalid auxiliary data") raise wire.ProcessError("Invalid auxiliary data")
@ -48,7 +46,8 @@ def validate(auxiliary_data: messages.CardanoTxAuxiliaryData) -> None:
fields_provided = 0 fields_provided = 0
if auxiliary_data.hash: if auxiliary_data.hash:
fields_provided += 1 fields_provided += 1
_validate_hash(auxiliary_data.hash) # _validate_hash
assert_cond(len(auxiliary_data.hash) == _AUXILIARY_DATA_HASH_SIZE)
if auxiliary_data.governance_registration_parameters: if auxiliary_data.governance_registration_parameters:
fields_provided += 1 fields_provided += 1
_validate_governance_registration_parameters( _validate_governance_registration_parameters(
@ -57,10 +56,6 @@ def validate(auxiliary_data: messages.CardanoTxAuxiliaryData) -> None:
assert_cond(fields_provided == 1) assert_cond(fields_provided == 1)
def _validate_hash(auxiliary_data_hash: bytes) -> None:
assert_cond(len(auxiliary_data_hash) == _AUXILIARY_DATA_HASH_SIZE)
def _validate_governance_registration_parameters( def _validate_governance_registration_parameters(
parameters: messages.CardanoGovernanceRegistrationParametersType, parameters: messages.CardanoGovernanceRegistrationParametersType,
) -> None: ) -> None:
@ -107,7 +102,7 @@ def _get_voting_purpose_to_serialize(
async def show( async def show(
ctx: wire.Context, ctx: Context,
keychain: seed.Keychain, keychain: seed.Keychain,
auxiliary_data_hash: bytes, auxiliary_data_hash: bytes,
parameters: messages.CardanoGovernanceRegistrationParametersType | None, parameters: messages.CardanoGovernanceRegistrationParametersType | None,
@ -130,13 +125,15 @@ async def show(
async def _show_governance_registration( async def _show_governance_registration(
ctx: wire.Context, ctx: Context,
keychain: seed.Keychain, keychain: seed.Keychain,
parameters: messages.CardanoGovernanceRegistrationParametersType, parameters: messages.CardanoGovernanceRegistrationParametersType,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
should_show_details: bool, should_show_details: bool,
) -> None: ) -> None:
from .helpers import bech32
for delegation in parameters.delegations: for delegation in parameters.delegations:
encoded_public_key = bech32.encode( encoded_public_key = bech32.encode(
bech32.HRP_GOVERNANCE_PUBLIC_KEY, delegation.voting_public_key bech32.HRP_GOVERNANCE_PUBLIC_KEY, delegation.voting_public_key
@ -178,6 +175,9 @@ def get_hash_and_supplement(
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
) -> tuple[bytes, messages.CardanoTxAuxiliaryDataSupplement]: ) -> tuple[bytes, messages.CardanoTxAuxiliaryDataSupplement]:
from trezor.enums import CardanoTxAuxiliaryDataSupplementType
from trezor import messages
if parameters := auxiliary_data.governance_registration_parameters: if parameters := auxiliary_data.governance_registration_parameters:
( (
governance_registration_payload, governance_registration_payload,
@ -205,24 +205,26 @@ def _get_governance_registration_hash(
governance_registration_payload: GovernanceRegistrationPayload, governance_registration_payload: GovernanceRegistrationPayload,
governance_registration_payload_signature: bytes, governance_registration_payload_signature: bytes,
) -> bytes: ) -> bytes:
cborized_governance_registration = _cborize_governance_registration( # _cborize_catalyst_registration
governance_registration_payload,
governance_registration_payload_signature,
)
return _get_hash(cbor.encode(_wrap_metadata(cborized_governance_registration)))
def _cborize_governance_registration(
governance_registration_payload: GovernanceRegistrationPayload,
governance_registration_payload_signature: bytes,
) -> GovernanceRegistration:
governance_registration_signature = {1: governance_registration_payload_signature} governance_registration_signature = {1: governance_registration_payload_signature}
cborized_catalyst_registration = {
return {
_METADATA_KEY_GOVERNANCE_REGISTRATION: governance_registration_payload, _METADATA_KEY_GOVERNANCE_REGISTRATION: governance_registration_payload,
_METADATA_KEY_GOVERNANCE_REGISTRATION_SIGNATURE: governance_registration_signature, _METADATA_KEY_GOVERNANCE_REGISTRATION_SIGNATURE: governance_registration_signature,
} }
# _get_hash
# _wrap_metadata
# A new structure of metadata is used after Cardano Mary era. The metadata
# is wrapped in a tuple and auxiliary_scripts may follow it. Cardano
# tooling uses this new format of "wrapped" metadata even if no
# auxiliary_scripts are included. So we do the same here.
# https://github.com/input-output-hk/cardano-ledger-specs/blob/f7deb22be14d31b535f56edc3ca542c548244c67/shelley-ma/shelley-ma-test/cddl-files/shelley-ma.cddl#L212
metadata = (cborized_catalyst_registration, ())
auxiliary_data = cbor.encode(metadata)
return hashlib.blake2b(
data=auxiliary_data, outlen=_AUXILIARY_DATA_HASH_SIZE
).digest()
def _get_signed_governance_registration_payload( def _get_signed_governance_registration_payload(
keychain: seed.Keychain, keychain: seed.Keychain,
@ -275,6 +277,8 @@ def _create_governance_registration_payload_signature(
governance_registration_payload: GovernanceRegistrationPayload, governance_registration_payload: GovernanceRegistrationPayload,
path: list[int], path: list[int],
) -> bytes: ) -> bytes:
from trezor.crypto.curve import ed25519
node = keychain.derive(path) node = keychain.derive(path)
encoded_governance_registration = cbor.encode( encoded_governance_registration = cbor.encode(
@ -289,21 +293,3 @@ def _create_governance_registration_payload_signature(
return ed25519.sign_ext( return ed25519.sign_ext(
node.private_key(), node.private_key_ext(), governance_registration_hash node.private_key(), node.private_key_ext(), governance_registration_hash
) )
def _wrap_metadata(metadata: dict) -> tuple[dict, tuple]:
"""
A new structure of metadata is used after Cardano Mary era. The metadata
is wrapped in a tuple and auxiliary_scripts may follow it. Cardano
tooling uses this new format of "wrapped" metadata even if no
auxiliary_scripts are included. So we do the same here.
https://github.com/input-output-hk/cardano-ledger-specs/blob/f7deb22be14d31b535f56edc3ca542c548244c67/shelley-ma/shelley-ma-test/cddl-files/shelley-ma.cddl#L212
"""
return metadata, ()
def _get_hash(auxiliary_data: bytes) -> bytes:
return hashlib.blake2b(
data=auxiliary_data, outlen=_AUXILIARY_DATA_HASH_SIZE
).digest()

View File

@ -1,13 +1,12 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import log, wire from trezor.crypto import crc
from trezor.crypto import crc, hashlib from trezor.wire import ProcessError
from apps.common import cbor from apps.common import cbor
from .helpers import protocol_magics from .helpers import protocol_magics
from .helpers.utils import derive_public_key
if TYPE_CHECKING: if TYPE_CHECKING:
from . import seed from . import seed
@ -23,31 +22,28 @@ with base58 encoding and all the nuances of Byron addresses.
""" """
def _encode_raw(address_data_encoded: bytes) -> bytes:
return cbor.encode(
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
)
def derive(keychain: seed.Keychain, path: list, protocol_magic: int) -> bytes: def derive(keychain: seed.Keychain, path: list, protocol_magic: int) -> bytes:
address_attributes = get_address_attributes(protocol_magic) from .helpers.utils import derive_public_key
address_root = _get_address_root(keychain, path, address_attributes) # get_address_attributes
address_type = 0
address_data = [address_root, address_attributes, address_type]
address_data_encoded = cbor.encode(address_data)
return _encode_raw(address_data_encoded)
def get_address_attributes(protocol_magic: int) -> dict:
# protocol magic is included in Byron addresses only on testnets # protocol magic is included in Byron addresses only on testnets
if protocol_magics.is_mainnet(protocol_magic): if protocol_magics.is_mainnet(protocol_magic):
address_attributes = {} address_attributes = {}
else: else:
address_attributes = {_PROTOCOL_MAGIC_KEY: cbor.encode(protocol_magic)} address_attributes = {_PROTOCOL_MAGIC_KEY: cbor.encode(protocol_magic)}
return address_attributes # _get_address_root
extpubkey = derive_public_key(keychain, path, extended=True)
address_root = _address_hash([0, [0, extpubkey], address_attributes])
address_type = 0
address_data = [address_root, address_attributes, address_type]
address_data_encoded = cbor.encode(address_data)
# _encode_raw
return cbor.encode(
[cbor.Tagged(24, address_data_encoded), crc.crc32(address_data_encoded)]
)
def validate(address: bytes, protocol_magic: int) -> None: def validate(address: bytes, protocol_magic: int) -> None:
@ -55,27 +51,38 @@ def validate(address: bytes, protocol_magic: int) -> None:
_validate_protocol_magic(address_data_encoded, protocol_magic) _validate_protocol_magic(address_data_encoded, protocol_magic)
def _address_hash(data: list) -> bytes:
from trezor.crypto import hashlib
cbor_data = cbor.encode(data)
sha_data_hash = hashlib.sha3_256(cbor_data).digest()
res = hashlib.blake2b(data=sha_data_hash, outlen=28).digest()
return res
def _decode_raw(address: bytes) -> bytes: def _decode_raw(address: bytes) -> bytes:
from trezor import log
try: try:
address_unpacked = cbor.decode(address) address_unpacked = cbor.decode(address)
except ValueError as e: except ValueError as e:
if __debug__: if __debug__:
log.exception(__name__, e) log.exception(__name__, e)
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
if not isinstance(address_unpacked, list) or len(address_unpacked) != 2: if not isinstance(address_unpacked, list) or len(address_unpacked) != 2:
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
address_data_encoded = address_unpacked[0] address_data_encoded = address_unpacked[0]
if not isinstance(address_data_encoded, bytes): if not isinstance(address_data_encoded, bytes):
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
address_crc = address_unpacked[1] address_crc = address_unpacked[1]
if not isinstance(address_crc, int): if not isinstance(address_crc, int):
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
if address_crc != crc.crc32(address_data_encoded): if address_crc != crc.crc32(address_data_encoded):
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
return address_data_encoded return address_data_encoded
@ -88,35 +95,21 @@ def _validate_protocol_magic(address_data_encoded: bytes, protocol_magic: int) -
""" """
address_data = cbor.decode(address_data_encoded) address_data = cbor.decode(address_data_encoded)
if not isinstance(address_data, list) or len(address_data) < 2: if not isinstance(address_data, list) or len(address_data) < 2:
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
attributes = address_data[1] attributes = address_data[1]
if protocol_magics.is_mainnet(protocol_magic): if protocol_magics.is_mainnet(protocol_magic):
if _PROTOCOL_MAGIC_KEY in attributes: if _PROTOCOL_MAGIC_KEY in attributes:
raise wire.ProcessError("Output address network mismatch") raise ProcessError("Output address network mismatch")
else: # testnet else: # testnet
if len(attributes) == 0 or _PROTOCOL_MAGIC_KEY not in attributes: if len(attributes) == 0 or _PROTOCOL_MAGIC_KEY not in attributes:
raise wire.ProcessError("Output address network mismatch") raise ProcessError("Output address network mismatch")
protocol_magic_cbor = attributes[_PROTOCOL_MAGIC_KEY] protocol_magic_cbor = attributes[_PROTOCOL_MAGIC_KEY]
address_protocol_magic = cbor.decode(protocol_magic_cbor) address_protocol_magic = cbor.decode(protocol_magic_cbor)
if not isinstance(address_protocol_magic, int): if not isinstance(address_protocol_magic, int):
raise wire.ProcessError("Invalid address") raise ProcessError("Invalid address")
if address_protocol_magic != protocol_magic: if address_protocol_magic != protocol_magic:
raise wire.ProcessError("Output address network mismatch") raise ProcessError("Output address network mismatch")
def _address_hash(data: list) -> bytes:
cbor_data = cbor.encode(data)
sha_data_hash = hashlib.sha3_256(cbor_data).digest()
res = hashlib.blake2b(data=sha_data_hash, outlen=28).digest()
return res
def _get_address_root(
keychain: seed.Keychain, path: list[int], address_attributes: dict
) -> bytes:
extpubkey = derive_public_key(keychain, path, extended=True)
return _address_hash([0, [0, extpubkey], address_attributes])

View File

@ -1,15 +1,11 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.enums import CardanoCertificateType, CardanoPoolRelayType from trezor.enums import CardanoCertificateType, CardanoPoolRelayType
from trezor.wire import ProcessError
from apps.common import cbor
from . import addresses from . import addresses
from .helpers import ADDRESS_KEY_HASH_SIZE, LOVELACE_MAX_SUPPLY from .helpers.utils import get_public_key_hash
from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
from .helpers.utils import get_public_key_hash, validate_stake_credential
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any
@ -36,27 +32,31 @@ def validate(
network_id: int, network_id: int,
account_path_checker: AccountPathChecker, account_path_checker: AccountPathChecker,
) -> None: ) -> None:
from .helpers.utils import validate_stake_credential
_validate_structure(certificate) _validate_structure(certificate)
CCT = CardanoCertificateType # local_cache_global
if certificate.type in ( if certificate.type in (
CardanoCertificateType.STAKE_DELEGATION, CCT.STAKE_DELEGATION,
CardanoCertificateType.STAKE_REGISTRATION, CCT.STAKE_REGISTRATION,
CardanoCertificateType.STAKE_DEREGISTRATION, CCT.STAKE_DEREGISTRATION,
): ):
validate_stake_credential( validate_stake_credential(
certificate.path, certificate.path,
certificate.script_hash, certificate.script_hash,
certificate.key_hash, certificate.key_hash,
wire.ProcessError("Invalid certificate"), ProcessError("Invalid certificate"),
) )
if certificate.type == CardanoCertificateType.STAKE_DELEGATION: if certificate.type == CCT.STAKE_DELEGATION:
if not certificate.pool or len(certificate.pool) != _POOL_HASH_SIZE: if not certificate.pool or len(certificate.pool) != _POOL_HASH_SIZE:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION: if certificate.type == CCT.STAKE_POOL_REGISTRATION:
if certificate.pool_parameters is None: if certificate.pool_parameters is None:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
_validate_pool_parameters( _validate_pool_parameters(
certificate.pool_parameters, protocol_magic, network_id certificate.pool_parameters, protocol_magic, network_id
) )
@ -65,14 +65,15 @@ def validate(
def _validate_structure(certificate: messages.CardanoTxCertificate) -> None: def _validate_structure(certificate: messages.CardanoTxCertificate) -> None:
pool = certificate.pool pool = certificate.pool # local_cache_attribute
pool_parameters = certificate.pool_parameters pool_parameters = certificate.pool_parameters # local_cache_attribute
CCT = CardanoCertificateType # local_cache_global
fields_to_be_empty: dict[CardanoCertificateType, tuple[Any, ...]] = { fields_to_be_empty: dict[CCT, tuple[Any, ...]] = {
CardanoCertificateType.STAKE_REGISTRATION: (pool, pool_parameters), CCT.STAKE_REGISTRATION: (pool, pool_parameters),
CardanoCertificateType.STAKE_DELEGATION: (pool_parameters,), CCT.STAKE_DELEGATION: (pool_parameters,),
CardanoCertificateType.STAKE_DEREGISTRATION: (pool, pool_parameters), CCT.STAKE_DEREGISTRATION: (pool, pool_parameters),
CardanoCertificateType.STAKE_POOL_REGISTRATION: ( CCT.STAKE_POOL_REGISTRATION: (
certificate.path, certificate.path,
certificate.script_hash, certificate.script_hash,
certificate.key_hash, certificate.key_hash,
@ -83,18 +84,20 @@ def _validate_structure(certificate: messages.CardanoTxCertificate) -> None:
if certificate.type not in fields_to_be_empty or any( if certificate.type not in fields_to_be_empty or any(
fields_to_be_empty[certificate.type] fields_to_be_empty[certificate.type]
): ):
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
def cborize( def cborize(
keychain: seed.Keychain, certificate: messages.CardanoTxCertificate keychain: seed.Keychain, certificate: messages.CardanoTxCertificate
) -> CborSequence: ) -> CborSequence:
if certificate.type in ( cert_type = certificate.type # local_cache_attribute
if cert_type in (
CardanoCertificateType.STAKE_REGISTRATION, CardanoCertificateType.STAKE_REGISTRATION,
CardanoCertificateType.STAKE_DEREGISTRATION, CardanoCertificateType.STAKE_DEREGISTRATION,
): ):
return ( return (
certificate.type, cert_type,
cborize_stake_credential( cborize_stake_credential(
keychain, keychain,
certificate.path, certificate.path,
@ -102,9 +105,9 @@ def cborize(
certificate.key_hash, certificate.key_hash,
), ),
) )
elif certificate.type == CardanoCertificateType.STAKE_DELEGATION: elif cert_type == CardanoCertificateType.STAKE_DELEGATION:
return ( return (
certificate.type, cert_type,
cborize_stake_credential( cborize_stake_credential(
keychain, keychain,
certificate.path, certificate.path,
@ -136,6 +139,8 @@ def cborize_stake_credential(
def cborize_pool_registration_init( def cborize_pool_registration_init(
certificate: messages.CardanoTxCertificate, certificate: messages.CardanoTxCertificate,
) -> CborSequence: ) -> CborSequence:
from apps.common import cbor
assert certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION assert certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION
pool_parameters = certificate.pool_parameters pool_parameters = certificate.pool_parameters
@ -162,7 +167,7 @@ def cborize_pool_registration_init(
def assert_cond(condition: bool) -> None: def assert_cond(condition: bool) -> None:
if not condition: if not condition:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
def _validate_pool_parameters( def _validate_pool_parameters(
@ -170,6 +175,8 @@ def _validate_pool_parameters(
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
) -> None: ) -> None:
from .helpers import LOVELACE_MAX_SUPPLY
assert_cond(len(pool_parameters.pool_id) == _POOL_HASH_SIZE) assert_cond(len(pool_parameters.pool_id) == _POOL_HASH_SIZE)
assert_cond(len(pool_parameters.vrf_key_hash) == _VRF_KEY_HASH_SIZE) assert_cond(len(pool_parameters.vrf_key_hash) == _VRF_KEY_HASH_SIZE)
assert_cond(0 <= pool_parameters.pledge <= LOVELACE_MAX_SUPPLY) assert_cond(0 <= pool_parameters.pledge <= LOVELACE_MAX_SUPPLY)
@ -182,14 +189,20 @@ def _validate_pool_parameters(
addresses.validate_reward_address( addresses.validate_reward_address(
pool_parameters.reward_account, protocol_magic, network_id pool_parameters.reward_account, protocol_magic, network_id
) )
pool_metadata = pool_parameters.metadata # local_cache_attribute
if pool_parameters.metadata: if pool_metadata:
_validate_pool_metadata(pool_parameters.metadata) # _validate_pool_metadata
assert_cond(len(pool_metadata.url) <= _MAX_URL_LENGTH)
assert_cond(len(pool_metadata.hash) == _POOL_METADATA_HASH_SIZE)
assert_cond(all((32 <= ord(c) < 127) for c in pool_metadata.url))
def validate_pool_owner( def validate_pool_owner(
owner: messages.CardanoPoolOwner, account_path_checker: AccountPathChecker owner: messages.CardanoPoolOwner, account_path_checker: AccountPathChecker
) -> None: ) -> None:
from .helpers import ADDRESS_KEY_HASH_SIZE
from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
assert_cond( assert_cond(
owner.staking_key_hash is not None or owner.staking_key_path is not None owner.staking_key_hash is not None or owner.staking_key_path is not None
) )
@ -202,6 +215,9 @@ def validate_pool_owner(
def validate_pool_relay(pool_relay: messages.CardanoPoolRelayParameters) -> None: def validate_pool_relay(pool_relay: messages.CardanoPoolRelayParameters) -> None:
port = pool_relay.port # local_cache_attribute
host_name = pool_relay.host_name # local_cache_attribute
if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP: if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP:
assert_cond( assert_cond(
pool_relay.ipv4_address is not None or pool_relay.ipv6_address is not None pool_relay.ipv4_address is not None or pool_relay.ipv6_address is not None
@ -210,32 +226,16 @@ def validate_pool_relay(pool_relay: messages.CardanoPoolRelayParameters) -> None
assert_cond(len(pool_relay.ipv4_address) == _IPV4_ADDRESS_SIZE) assert_cond(len(pool_relay.ipv4_address) == _IPV4_ADDRESS_SIZE)
if pool_relay.ipv6_address is not None: if pool_relay.ipv6_address is not None:
assert_cond(len(pool_relay.ipv6_address) == _IPV6_ADDRESS_SIZE) assert_cond(len(pool_relay.ipv6_address) == _IPV6_ADDRESS_SIZE)
assert_cond( assert_cond(port is not None and 0 <= port <= _MAX_PORT_NUMBER)
pool_relay.port is not None and 0 <= pool_relay.port <= _MAX_PORT_NUMBER
)
elif pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_NAME: elif pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_NAME:
assert_cond( assert_cond(host_name is not None and len(host_name) <= _MAX_URL_LENGTH)
pool_relay.host_name is not None assert_cond(port is not None and 0 <= port <= _MAX_PORT_NUMBER)
and len(pool_relay.host_name) <= _MAX_URL_LENGTH
)
assert_cond(
pool_relay.port is not None and 0 <= pool_relay.port <= _MAX_PORT_NUMBER
)
elif pool_relay.type == CardanoPoolRelayType.MULTIPLE_HOST_NAME: elif pool_relay.type == CardanoPoolRelayType.MULTIPLE_HOST_NAME:
assert_cond( assert_cond(host_name is not None and len(host_name) <= _MAX_URL_LENGTH)
pool_relay.host_name is not None
and len(pool_relay.host_name) <= _MAX_URL_LENGTH
)
else: else:
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
def _validate_pool_metadata(pool_metadata: messages.CardanoPoolMetadataType) -> None:
assert_cond(len(pool_metadata.url) <= _MAX_URL_LENGTH)
assert_cond(len(pool_metadata.hash) == _POOL_METADATA_HASH_SIZE)
assert_cond(all((32 <= ord(c) < 127) for c in pool_metadata.url))
def cborize_pool_owner( def cborize_pool_owner(
keychain: seed.Keychain, pool_owner: messages.CardanoPoolOwner keychain: seed.Keychain, pool_owner: messages.CardanoPoolOwner
) -> bytes: ) -> bytes:
@ -264,22 +264,24 @@ def _cborize_ipv6_address(ipv6_address: bytes | None) -> bytes | None:
def cborize_pool_relay( def cborize_pool_relay(
pool_relay: messages.CardanoPoolRelayParameters, pool_relay: messages.CardanoPoolRelayParameters,
) -> CborSequence: ) -> CborSequence:
if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP: relay_type = pool_relay.type # local_cache_attribute
if relay_type == CardanoPoolRelayType.SINGLE_HOST_IP:
return ( return (
pool_relay.type, relay_type,
pool_relay.port, pool_relay.port,
pool_relay.ipv4_address, pool_relay.ipv4_address,
_cborize_ipv6_address(pool_relay.ipv6_address), _cborize_ipv6_address(pool_relay.ipv6_address),
) )
elif pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_NAME: elif relay_type == CardanoPoolRelayType.SINGLE_HOST_NAME:
return ( return (
pool_relay.type, relay_type,
pool_relay.port, pool_relay.port,
pool_relay.host_name, pool_relay.host_name,
) )
elif pool_relay.type == CardanoPoolRelayType.MULTIPLE_HOST_NAME: elif relay_type == CardanoPoolRelayType.MULTIPLE_HOST_NAME:
return ( return (
pool_relay.type, relay_type,
pool_relay.host_name, pool_relay.host_name,
) )
else: else:

View File

@ -1,21 +1,31 @@
from trezor import log, messages, wire from typing import TYPE_CHECKING
from . import addresses, seed from . import seed
from .helpers.credential import Credential, should_show_credentials
from .helpers.utils import validate_network_info if TYPE_CHECKING:
from .layout import show_cardano_address, show_credentials from trezor.wire import Context
from trezor.messages import CardanoGetAddress, CardanoAddress
@seed.with_keychain @seed.with_keychain
async def get_address( async def get_address(
ctx: wire.Context, msg: messages.CardanoGetAddress, keychain: seed.Keychain ctx: Context, msg: CardanoGetAddress, keychain: seed.Keychain
) -> messages.CardanoAddress: ) -> CardanoAddress:
from trezor.messages import CardanoAddress
from trezor import log, wire
from .helpers.credential import Credential, should_show_credentials
from .helpers.utils import validate_network_info
from .layout import show_cardano_address, show_credentials
from . import addresses
address_parameters = msg.address_parameters # local_cache_attribute
validate_network_info(msg.network_id, msg.protocol_magic) validate_network_info(msg.network_id, msg.protocol_magic)
addresses.validate_address_parameters(msg.address_parameters) addresses.validate_address_parameters(address_parameters)
try: try:
address = addresses.derive_human_readable( address = addresses.derive_human_readable(
keychain, msg.address_parameters, msg.protocol_magic, msg.network_id keychain, address_parameters, msg.protocol_magic, msg.network_id
) )
except ValueError as e: except ValueError as e:
if __debug__: if __debug__:
@ -23,22 +33,13 @@ async def get_address(
raise wire.ProcessError("Deriving address failed") raise wire.ProcessError("Deriving address failed")
if msg.show_display: if msg.show_display:
await _display_address(ctx, msg.address_parameters, address, msg.protocol_magic) # _display_address
return messages.CardanoAddress(address=address)
async def _display_address(
ctx: wire.Context,
address_parameters: messages.CardanoAddressParametersType,
address: str,
protocol_magic: int,
) -> None:
if should_show_credentials(address_parameters): if should_show_credentials(address_parameters):
await show_credentials( await show_credentials(
ctx, ctx,
Credential.payment_credential(address_parameters), Credential.payment_credential(address_parameters),
Credential.stake_credential(address_parameters), Credential.stake_credential(address_parameters),
) )
await show_cardano_address(ctx, address_parameters, address, msg.protocol_magic)
await show_cardano_address(ctx, address_parameters, address, protocol_magic) return CardanoAddress(address=address)

View File

@ -1,13 +1,20 @@
from trezor import messages, wire from typing import TYPE_CHECKING
from trezor.enums import CardanoNativeScriptHashDisplayFormat
from . import layout, native_script, seed from . import seed
if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetNativeScriptHash, CardanoNativeScriptHash
@seed.with_keychain @seed.with_keychain
async def get_native_script_hash( async def get_native_script_hash(
ctx: wire.Context, msg: messages.CardanoGetNativeScriptHash, keychain: seed.Keychain ctx: Context, msg: CardanoGetNativeScriptHash, keychain: seed.Keychain
) -> messages.CardanoNativeScriptHash: ) -> CardanoNativeScriptHash:
from trezor.messages import CardanoNativeScriptHash
from trezor.enums import CardanoNativeScriptHashDisplayFormat
from . import layout, native_script
native_script.validate_native_script(msg.script) native_script.validate_native_script(msg.script)
script_hash = native_script.get_native_script_hash(keychain, msg.script) script_hash = native_script.get_native_script_hash(keychain, msg.script)
@ -16,4 +23,4 @@ async def get_native_script_hash(
await layout.show_native_script(ctx, msg.script) await layout.show_native_script(ctx, msg.script)
await layout.show_script_hash(ctx, script_hash, msg.display_format) await layout.show_script_hash(ctx, script_hash, msg.display_format)
return messages.CardanoNativeScriptHash(script_hash=script_hash) return CardanoNativeScriptHash(script_hash=script_hash)

View File

@ -1,29 +1,34 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify from ubinascii import hexlify
from trezor import log, messages, wire
from trezor.ui.layouts import show_pubkey
from apps.common import paths
from . import seed from . import seed
from .helpers.paths import SCHEMA_MINT, SCHEMA_PUBKEY
from .helpers.utils import derive_public_key if TYPE_CHECKING:
from trezor.wire import Context
from trezor.messages import CardanoGetPublicKey, CardanoPublicKey
@seed.with_keychain @seed.with_keychain
async def get_public_key( async def get_public_key(
ctx: wire.Context, msg: messages.CardanoGetPublicKey, keychain: seed.Keychain ctx: Context, msg: CardanoGetPublicKey, keychain: seed.Keychain
) -> messages.CardanoPublicKey: ) -> CardanoPublicKey:
from trezor import log, wire
from trezor.ui.layouts import show_pubkey
from apps.common import paths
from .helpers.paths import SCHEMA_MINT, SCHEMA_PUBKEY
address_n = msg.address_n # local_cache_attribute
await paths.validate_path( await paths.validate_path(
ctx, ctx,
keychain, keychain,
msg.address_n, address_n,
# path must match the PUBKEY schema # path must match the PUBKEY schema
SCHEMA_PUBKEY.match(msg.address_n) or SCHEMA_MINT.match(msg.address_n), SCHEMA_PUBKEY.match(address_n) or SCHEMA_MINT.match(address_n),
) )
try: try:
key = _get_public_key(keychain, msg.address_n) key = _get_public_key(keychain, address_n)
except ValueError as e: except ValueError as e:
if __debug__: if __debug__:
log.exception(__name__, e) log.exception(__name__, e)
@ -36,14 +41,17 @@ async def get_public_key(
def _get_public_key( def _get_public_key(
keychain: seed.Keychain, derivation_path: list[int] keychain: seed.Keychain, derivation_path: list[int]
) -> messages.CardanoPublicKey: ) -> CardanoPublicKey:
from .helpers.utils import derive_public_key
from trezor.messages import HDNodeType, CardanoPublicKey
node = keychain.derive(derivation_path) node = keychain.derive(derivation_path)
public_key = hexlify(derive_public_key(keychain, derivation_path)).decode() public_key = hexlify(derive_public_key(keychain, derivation_path)).decode()
chain_code = hexlify(node.chain_code()).decode() chain_code = hexlify(node.chain_code()).decode()
xpub_key = public_key + chain_code xpub_key = public_key + chain_code
node_type = messages.HDNodeType( node_type = HDNodeType(
depth=node.depth(), depth=node.depth(),
child_num=node.child_num(), child_num=node.child_num(),
fingerprint=node.fingerprint(), fingerprint=node.fingerprint(),
@ -51,4 +59,4 @@ def _get_public_key(
public_key=derive_public_key(keychain, derivation_path), public_key=derive_public_key(keychain, derivation_path),
) )
return messages.CardanoPublicKey(node=node_type, xpub=xpub_key) return CardanoPublicKey(node=node_type, xpub=xpub_key)

View File

@ -1,8 +1,6 @@
from micropython import const LOVELACE_MAX_SUPPLY = 45_000_000_000 * 1_000_000
INPUT_PREV_HASH_SIZE = 32
LOVELACE_MAX_SUPPLY = const(45_000_000_000 * 1_000_000) ADDRESS_KEY_HASH_SIZE = 28
INPUT_PREV_HASH_SIZE = const(32) SCRIPT_HASH_SIZE = 28
ADDRESS_KEY_HASH_SIZE = const(28) OUTPUT_DATUM_HASH_SIZE = 32
SCRIPT_HASH_SIZE = const(28) SCRIPT_DATA_HASH_SIZE = 32
OUTPUT_DATUM_HASH_SIZE = const(32)
SCRIPT_DATA_HASH_SIZE = const(32)

View File

@ -1,11 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire from trezor.wire import ProcessError
from ...common.paths import HARDENED
from .. import seed from .. import seed
from .paths import ACCOUNT_PATH_INDEX, ACCOUNT_PATH_LENGTH
from .utils import to_account_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ( from trezor.messages import (
@ -29,7 +26,9 @@ class AccountPathChecker:
def __init__(self) -> None: def __init__(self) -> None:
self.account_path: object | list[int] = self.UNDEFINED self.account_path: object | list[int] = self.UNDEFINED
def _add(self, path: list[int], error: wire.ProcessError) -> None: def _add(self, path: list[int], error: ProcessError) -> None:
from .utils import to_account_path
# multi-sig and minting paths are always shown and thus don't need to be checked # multi-sig and minting paths are always shown and thus don't need to be checked
if seed.is_multisig_path(path) or seed.is_minting_path(path): if seed.is_multisig_path(path) or seed.is_minting_path(path):
return return
@ -51,10 +50,15 @@ class AccountPathChecker:
from the user. This way the user can be sure that the funds are being moved between the user's from the user. This way the user can be sure that the funds are being moved between the user's
accounts without being bothered by more screens. accounts without being bothered by more screens.
""" """
assert isinstance(self.account_path, list) from ...common.paths import HARDENED
from .paths import ACCOUNT_PATH_INDEX, ACCOUNT_PATH_LENGTH
self_account_path = self.account_path # local_cache_attribute
assert isinstance(self_account_path, list)
is_control_path_byron_or_shelley = seed.is_byron_path( is_control_path_byron_or_shelley = seed.is_byron_path(
self.account_path self_account_path
) or seed.is_shelley_path(self.account_path) ) or seed.is_shelley_path(self_account_path)
is_new_path_byron_or_shelley = seed.is_byron_path( is_new_path_byron_or_shelley = seed.is_byron_path(
account_path account_path
@ -63,9 +67,9 @@ class AccountPathChecker:
return ( return (
is_control_path_byron_or_shelley is_control_path_byron_or_shelley
and is_new_path_byron_or_shelley and is_new_path_byron_or_shelley
and len(self.account_path) == ACCOUNT_PATH_LENGTH and len(self_account_path) == ACCOUNT_PATH_LENGTH
and len(account_path) == ACCOUNT_PATH_LENGTH and len(account_path) == ACCOUNT_PATH_LENGTH
and self.account_path[ACCOUNT_PATH_INDEX] == 0 | HARDENED and self_account_path[ACCOUNT_PATH_INDEX] == 0 | HARDENED
and account_path[ACCOUNT_PATH_INDEX] == 0 | HARDENED and account_path[ACCOUNT_PATH_INDEX] == 0 | HARDENED
) )
@ -76,27 +80,25 @@ class AccountPathChecker:
if not output.address_parameters.address_n: if not output.address_parameters.address_n:
return return
self._add( self._add(output.address_parameters.address_n, ProcessError("Invalid output"))
output.address_parameters.address_n, wire.ProcessError("Invalid output")
)
def add_certificate(self, certificate: CardanoTxCertificate) -> None: def add_certificate(self, certificate: CardanoTxCertificate) -> None:
if not certificate.path: if not certificate.path:
return return
self._add(certificate.path, wire.ProcessError("Invalid certificate")) self._add(certificate.path, ProcessError("Invalid certificate"))
def add_pool_owner(self, pool_owner: CardanoPoolOwner) -> None: def add_pool_owner(self, pool_owner: CardanoPoolOwner) -> None:
if not pool_owner.staking_key_path: if not pool_owner.staking_key_path:
return return
self._add(pool_owner.staking_key_path, wire.ProcessError("Invalid certificate")) self._add(pool_owner.staking_key_path, ProcessError("Invalid certificate"))
def add_withdrawal(self, withdrawal: CardanoTxWithdrawal) -> None: def add_withdrawal(self, withdrawal: CardanoTxWithdrawal) -> None:
if not withdrawal.path: if not withdrawal.path:
return return
self._add(withdrawal.path, wire.ProcessError("Invalid withdrawal")) self._add(withdrawal.path, ProcessError("Invalid withdrawal"))
def add_witness_request(self, witness_request: CardanoTxWitnessRequest) -> None: def add_witness_request(self, witness_request: CardanoTxWitnessRequest) -> None:
self._add(witness_request.path, wire.ProcessError("Invalid witness request")) self._add(witness_request.path, ProcessError("Invalid witness request"))

View File

@ -23,15 +23,11 @@ def encode(hrp: str, data: bytes) -> str:
def decode_unsafe(bech: str) -> bytes: def decode_unsafe(bech: str) -> bytes:
hrp = get_hrp(bech) hrp = bech.rsplit(HRP_SEPARATOR, 1)[0]
return decode(hrp, bech) return _decode(hrp, bech)
def get_hrp(bech: str) -> str: def _decode(hrp: str, bech: str) -> bytes:
return bech.rsplit(HRP_SEPARATOR, 1)[0]
def decode(hrp: str, bech: str) -> bytes:
decoded_hrp, data, spec = bech32.bech32_decode(bech, 130) decoded_hrp, data, spec = bech32.bech32_decode(bech, 130)
if decoded_hrp != hrp: if decoded_hrp != hrp:
raise ValueError raise ValueError

View File

@ -2,10 +2,7 @@ from typing import TYPE_CHECKING
from trezor.enums import CardanoAddressType from trezor.enums import CardanoAddressType
from ...common.paths import address_n_to_str from .paths import SCHEMA_PAYMENT
from . import bech32
from .paths import CHAIN_STAKING_KEY, SCHEMA_PAYMENT, SCHEMA_STAKING
from .utils import to_account_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor import messages from trezor import messages
@ -56,7 +53,9 @@ class Credential:
def payment_credential( def payment_credential(
cls, address_params: messages.CardanoAddressParametersType cls, address_params: messages.CardanoAddressParametersType
) -> "Credential": ) -> "Credential":
address_type = address_params.address_type address_type = address_params.address_type # local_cache_attribute
CAT = CardanoAddressType # local_cache_global
credential = cls( credential = cls(
type_name=CREDENTIAL_TYPE_PAYMENT, type_name=CREDENTIAL_TYPE_PAYMENT,
address_type=address_type, address_type=address_type,
@ -67,26 +66,26 @@ class Credential:
) )
if address_type in ( if address_type in (
CardanoAddressType.BASE, CAT.BASE,
CardanoAddressType.BASE_KEY_SCRIPT, CAT.BASE_KEY_SCRIPT,
CardanoAddressType.POINTER, CAT.POINTER,
CardanoAddressType.ENTERPRISE, CAT.ENTERPRISE,
CardanoAddressType.BYRON, CAT.BYRON,
): ):
if not SCHEMA_PAYMENT.match(address_params.address_n): if not SCHEMA_PAYMENT.match(address_params.address_n):
credential.is_unusual_path = True credential.is_unusual_path = True
elif address_type in ( elif address_type in (
CardanoAddressType.BASE_SCRIPT_KEY, CAT.BASE_SCRIPT_KEY,
CardanoAddressType.BASE_SCRIPT_SCRIPT, CAT.BASE_SCRIPT_SCRIPT,
CardanoAddressType.POINTER_SCRIPT, CAT.POINTER_SCRIPT,
CardanoAddressType.ENTERPRISE_SCRIPT, CAT.ENTERPRISE_SCRIPT,
): ):
credential.is_other_warning = True credential.is_other_warning = True
elif address_type in ( elif address_type in (
CardanoAddressType.REWARD, CAT.REWARD,
CardanoAddressType.REWARD_SCRIPT, CAT.REWARD_SCRIPT,
): ):
credential.is_reward = True credential.is_reward = True
@ -99,55 +98,58 @@ class Credential:
def stake_credential( def stake_credential(
cls, address_params: messages.CardanoAddressParametersType cls, address_params: messages.CardanoAddressParametersType
) -> "Credential": ) -> "Credential":
address_type = address_params.address_type from .paths import SCHEMA_STAKING
address_n_staking = address_params.address_n_staking # local_cache_attribute
address_type = address_params.address_type # local_cache_attribute
CAT = CardanoAddressType # local_cache_global
credential = cls( credential = cls(
type_name=CREDENTIAL_TYPE_STAKE, type_name=CREDENTIAL_TYPE_STAKE,
address_type=address_type, address_type=address_type,
path=address_params.address_n_staking, path=address_n_staking,
key_hash=address_params.staking_key_hash, key_hash=address_params.staking_key_hash,
script_hash=address_params.script_staking_hash, script_hash=address_params.script_staking_hash,
pointer=address_params.certificate_pointer, pointer=address_params.certificate_pointer,
) )
if address_type == CardanoAddressType.BASE: if address_type == CAT.BASE:
if address_params.staking_key_hash: if address_params.staking_key_hash:
credential.is_other_warning = True credential.is_other_warning = True
else: else:
if not SCHEMA_STAKING.match(address_params.address_n_staking): if not SCHEMA_STAKING.match(address_n_staking):
credential.is_unusual_path = True credential.is_unusual_path = True
if not _do_base_address_credentials_match( if not _do_base_address_credentials_match(
address_params.address_n, address_params.address_n,
address_params.address_n_staking, address_n_staking,
): ):
credential.is_mismatch = True credential.is_mismatch = True
elif address_type == CardanoAddressType.BASE_SCRIPT_KEY: elif address_type == CAT.BASE_SCRIPT_KEY:
if address_params.address_n_staking and not SCHEMA_STAKING.match( if address_n_staking and not SCHEMA_STAKING.match(address_n_staking):
address_params.address_n_staking
):
credential.is_unusual_path = True credential.is_unusual_path = True
elif address_type in ( elif address_type in (
CardanoAddressType.POINTER, CAT.POINTER,
CardanoAddressType.POINTER_SCRIPT, CAT.POINTER_SCRIPT,
): ):
credential.is_other_warning = True credential.is_other_warning = True
elif address_type == CardanoAddressType.REWARD: elif address_type == CAT.REWARD:
if not SCHEMA_STAKING.match(address_params.address_n_staking): if not SCHEMA_STAKING.match(address_n_staking):
credential.is_unusual_path = True credential.is_unusual_path = True
elif address_type in ( elif address_type in (
CardanoAddressType.BASE_KEY_SCRIPT, CAT.BASE_KEY_SCRIPT,
CardanoAddressType.BASE_SCRIPT_SCRIPT, CAT.BASE_SCRIPT_SCRIPT,
CardanoAddressType.REWARD_SCRIPT, CAT.REWARD_SCRIPT,
): ):
credential.is_other_warning = True credential.is_other_warning = True
elif address_type in ( elif address_type in (
CardanoAddressType.ENTERPRISE, CAT.ENTERPRISE,
CardanoAddressType.ENTERPRISE_SCRIPT, CAT.ENTERPRISE_SCRIPT,
CardanoAddressType.BYRON, CAT.BYRON,
): ):
credential.is_no_staking = True credential.is_no_staking = True
@ -183,6 +185,11 @@ class Credential:
return "" return ""
def format(self) -> list[PropertyType]: def format(self) -> list[PropertyType]:
from ...common.paths import address_n_to_str
from . import bech32
pointer = self.pointer # local_cache_attribute
if self.path: if self.path:
return [(None, address_n_to_str(self.path))] return [(None, address_n_to_str(self.path))]
elif self.key_hash: elif self.key_hash:
@ -194,11 +201,11 @@ class Credential:
return [(None, bech32.encode(hrp, self.key_hash))] return [(None, bech32.encode(hrp, self.key_hash))]
elif self.script_hash: elif self.script_hash:
return [(None, bech32.encode(bech32.HRP_SCRIPT_HASH, self.script_hash))] return [(None, bech32.encode(bech32.HRP_SCRIPT_HASH, self.script_hash))]
elif self.pointer: elif pointer:
return [ return [
(f"Block: {self.pointer.block_index}", None), (f"Block: {pointer.block_index}", None),
(f"Transaction: {self.pointer.tx_index}", None), (f"Transaction: {pointer.tx_index}", None),
(f"Certificate: {self.pointer.certificate_index}", None), (f"Certificate: {pointer.certificate_index}", None),
] ]
else: else:
return [] return []
@ -221,8 +228,8 @@ def _do_base_address_credentials_match(
address_n: list[int], address_n: list[int],
address_n_staking: list[int], address_n_staking: list[int],
) -> bool: ) -> bool:
return address_n_staking == _path_to_staking_path(address_n) from .paths import CHAIN_STAKING_KEY
from .utils import to_account_path
path_to_staking_path = to_account_path(address_n) + [CHAIN_STAKING_KEY, 0]
def _path_to_staking_path(path: list[int]) -> list[int]: return address_n_staking == path_to_staking_path
return to_account_path(path) + [CHAIN_STAKING_KEY, 0]

View File

@ -1,21 +1,13 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import hashlib from trezor.crypto import hashlib
from apps.common.seed import remove_ed25519_prefix from . import ADDRESS_KEY_HASH_SIZE, bech32
from .paths import ACCOUNT_PATH_INDEX
from . import (
ADDRESS_KEY_HASH_SIZE,
SCRIPT_HASH_SIZE,
bech32,
network_ids,
protocol_magics,
)
from .paths import ACCOUNT_PATH_INDEX, SCHEMA_STAKING_ANY_ACCOUNT, unharden
if TYPE_CHECKING: if TYPE_CHECKING:
from .. import seed from .. import seed
from trezor.wire import ProcessError
def variable_length_encode(number: int) -> bytes: def variable_length_encode(number: int) -> bytes:
@ -40,6 +32,8 @@ def to_account_path(path: list[int]) -> list[int]:
def format_account_number(path: list[int]) -> str: def format_account_number(path: list[int]) -> str:
from .paths import unharden
if len(path) <= ACCOUNT_PATH_INDEX: if len(path) <= ACCOUNT_PATH_INDEX:
raise ValueError("Path is too short.") raise ValueError("Path is too short.")
@ -76,6 +70,8 @@ def get_public_key_hash(keychain: seed.Keychain, path: list[int]) -> bytes:
def derive_public_key( def derive_public_key(
keychain: seed.Keychain, path: list[int], extended: bool = False keychain: seed.Keychain, path: list[int], extended: bool = False
) -> bytes: ) -> bytes:
from apps.common.seed import remove_ed25519_prefix
node = keychain.derive(path) node = keychain.derive(path)
public_key = remove_ed25519_prefix(node.public_key()) public_key = remove_ed25519_prefix(node.public_key())
return public_key if not extended else public_key + node.chain_code() return public_key if not extended else public_key + node.chain_code()
@ -85,8 +81,11 @@ def validate_stake_credential(
path: list[int], path: list[int],
script_hash: bytes | None, script_hash: bytes | None,
key_hash: bytes | None, key_hash: bytes | None,
error: wire.ProcessError, error: ProcessError,
) -> None: ) -> None:
from . import SCRIPT_HASH_SIZE
from .paths import SCHEMA_STAKING_ANY_ACCOUNT
if sum(bool(k) for k in (path, script_hash, key_hash)) != 1: if sum(bool(k) for k in (path, script_hash, key_hash)) != 1:
raise error raise error
@ -104,6 +103,9 @@ def validate_network_info(network_id: int, protocol_magic: int) -> None:
belong to the mainnet or that both belong to a testnet. We don't need to check for belong to the mainnet or that both belong to a testnet. We don't need to check for
consistency between various testnets (at least for now). consistency between various testnets (at least for now).
""" """
from trezor import wire
from . import network_ids, protocol_magics
is_mainnet_network_id = network_ids.is_mainnet(network_id) is_mainnet_network_id = network_ids.is_mainnet(network_id)
is_mainnet_protocol_magic = protocol_magics.is_mainnet(protocol_magic) is_mainnet_protocol_magic = protocol_magics.is_mainnet(protocol_magic)

View File

@ -1,42 +1,39 @@
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING
from trezor import messages, ui from trezor import ui
from trezor.enums import ( from trezor.enums import (
ButtonRequestType, ButtonRequestType,
CardanoAddressType, CardanoAddressType,
CardanoCertificateType, CardanoCertificateType,
CardanoNativeScriptHashDisplayFormat,
CardanoNativeScriptType, CardanoNativeScriptType,
) )
from trezor.strings import format_amount from trezor.strings import format_amount
from trezor.ui.layouts import ( from trezor.ui import layouts
confirm_blob,
confirm_metadata,
confirm_output,
confirm_path_warning,
confirm_properties,
confirm_text,
should_show_more,
show_address,
)
from apps.common.paths import address_n_to_str from apps.common.paths import address_n_to_str
from . import addresses, seed from . import addresses
from .helpers import bech32, network_ids, protocol_magics from .helpers import bech32, protocol_magics
from .helpers.utils import ( from .helpers.utils import (
format_account_number, format_account_number,
format_asset_fingerprint, format_asset_fingerprint,
format_optional_int, format_optional_int,
format_stake_pool_id, format_stake_pool_id,
to_account_path,
) )
if TYPE_CHECKING: confirm_metadata = layouts.confirm_metadata # global_import_cache
from trezor import wire confirm_properties = layouts.confirm_properties # global_import_cache
if TYPE_CHECKING:
from typing import Literal
from trezor.wire import Context
from trezor import messages
from trezor.enums import CardanoNativeScriptHashDisplayFormat
from trezor.ui.layouts import PropertyType from trezor.ui.layouts import PropertyType
from .helpers.credential import Credential from .helpers.credential import Credential
from .seed import Keychain
ADDRESS_TYPE_NAMES = { ADDRESS_TYPE_NAMES = {
@ -69,17 +66,27 @@ CERTIFICATE_TYPE_NAMES = {
CardanoCertificateType.STAKE_POOL_REGISTRATION: "Stakepool registration", CardanoCertificateType.STAKE_POOL_REGISTRATION: "Stakepool registration",
} }
BRT_Other = ButtonRequestType.Other # global_import_cache
def format_coin_amount(amount: int, network_id: int) -> str: def format_coin_amount(amount: int, network_id: int) -> str:
from .helpers import network_ids
currency = "ADA" if network_ids.is_mainnet(network_id) else "tADA" currency = "ADA" if network_ids.is_mainnet(network_id) else "tADA"
return f"{format_amount(amount, 6)} {currency}" return f"{format_amount(amount, 6)} {currency}"
async def show_native_script( async def show_native_script(
ctx: wire.Context, ctx: Context,
script: messages.CardanoNativeScript, script: messages.CardanoNativeScript,
indices: list[int] | None = None, indices: list[int] | None = None,
) -> None: ) -> None:
CNST = CardanoNativeScriptType # local_cache_global
script_type = script.type # local_cache_attribute
key_path = script.key_path # local_cache_attribute
key_hash = script.key_hash # local_cache_attribute
scripts = script.scripts # local_cache_attribute
script_heading = "Script" script_heading = "Script"
if indices is None: if indices is None:
indices = [] indices = []
@ -87,67 +94,68 @@ async def show_native_script(
script_heading += " " + ".".join(str(i) for i in indices) script_heading += " " + ".".join(str(i) for i in indices)
script_type_name_suffix = "" script_type_name_suffix = ""
if script.type == CardanoNativeScriptType.PUB_KEY: if script_type == CNST.PUB_KEY:
if script.key_path: if key_path:
script_type_name_suffix = "path" script_type_name_suffix = "path"
elif script.key_hash: elif key_hash:
script_type_name_suffix = "hash" script_type_name_suffix = "hash"
props: list[PropertyType] = [ props: list[PropertyType] = [
( (
f"{script_heading} - {SCRIPT_TYPE_NAMES[script.type]} {script_type_name_suffix}:", f"{script_heading} - {SCRIPT_TYPE_NAMES[script_type]} {script_type_name_suffix}:",
None, None,
) )
] ]
append = props.append # local_cache_attribute
if script.type == CardanoNativeScriptType.PUB_KEY: if script_type == CNST.PUB_KEY:
assert script.key_hash is not None or script.key_path # validate_script assert key_hash is not None or key_path # validate_script
if script.key_hash: if key_hash:
props.append( append((None, bech32.encode(bech32.HRP_SHARED_KEY_HASH, key_hash)))
(None, bech32.encode(bech32.HRP_SHARED_KEY_HASH, script.key_hash)) elif key_path:
) append((address_n_to_str(key_path), None))
elif script.key_path: elif script_type == CNST.N_OF_K:
props.append((address_n_to_str(script.key_path), None))
elif script.type == CardanoNativeScriptType.N_OF_K:
assert script.required_signatures_count is not None # validate_script assert script.required_signatures_count is not None # validate_script
props.append( append(
( (
f"Requires {script.required_signatures_count} out of {len(script.scripts)} signatures.", f"Requires {script.required_signatures_count} out of {len(scripts)} signatures.",
None, None,
) )
) )
elif script.type == CardanoNativeScriptType.INVALID_BEFORE: elif script_type == CNST.INVALID_BEFORE:
assert script.invalid_before is not None # validate_script assert script.invalid_before is not None # validate_script
props.append((str(script.invalid_before), None)) append((str(script.invalid_before), None))
elif script.type == CardanoNativeScriptType.INVALID_HEREAFTER: elif script_type == CNST.INVALID_HEREAFTER:
assert script.invalid_hereafter is not None # validate_script assert script.invalid_hereafter is not None # validate_script
props.append((str(script.invalid_hereafter), None)) append((str(script.invalid_hereafter), None))
if script.type in ( if script_type in (
CardanoNativeScriptType.ALL, CNST.ALL,
CardanoNativeScriptType.ANY, CNST.ANY,
CardanoNativeScriptType.N_OF_K, CNST.N_OF_K,
): ):
assert script.scripts # validate_script assert scripts # validate_script
props.append((f"Contains {len(script.scripts)} nested scripts.", None)) append((f"Contains {len(scripts)} nested scripts.", None))
await confirm_properties( await confirm_properties(
ctx, ctx,
"verify_script", "verify_script",
title="Verify script", "Verify script",
props=props, props,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
for i, sub_script in enumerate(script.scripts): for i, sub_script in enumerate(scripts):
await show_native_script(ctx, sub_script, indices + [i + 1]) await show_native_script(ctx, sub_script, indices + [i + 1])
async def show_script_hash( async def show_script_hash(
ctx: wire.Context, ctx: Context,
script_hash: bytes, script_hash: bytes,
display_format: CardanoNativeScriptHashDisplayFormat, display_format: CardanoNativeScriptHashDisplayFormat,
) -> None: ) -> None:
from trezor.enums import CardanoNativeScriptHashDisplayFormat
assert display_format in ( assert display_format in (
CardanoNativeScriptHashDisplayFormat.BECH32, CardanoNativeScriptHashDisplayFormat.BECH32,
CardanoNativeScriptHashDisplayFormat.POLICY_ID, CardanoNativeScriptHashDisplayFormat.POLICY_ID,
@ -157,25 +165,23 @@ async def show_script_hash(
await confirm_properties( await confirm_properties(
ctx, ctx,
"verify_script", "verify_script",
title="Verify script", "Verify script",
props=[ (("Script hash:", bech32.encode(bech32.HRP_SCRIPT_HASH, script_hash)),),
("Script hash:", bech32.encode(bech32.HRP_SCRIPT_HASH, script_hash)) br_code=BRT_Other,
],
br_code=ButtonRequestType.Other,
) )
elif display_format == CardanoNativeScriptHashDisplayFormat.POLICY_ID: elif display_format == CardanoNativeScriptHashDisplayFormat.POLICY_ID:
await confirm_blob( await layouts.confirm_blob(
ctx, ctx,
"verify_script", "verify_script",
title="Verify script", "Verify script",
data=script_hash, script_hash,
description="Policy ID:", "Policy ID:",
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def show_tx_init(ctx: wire.Context, title: str) -> bool: async def show_tx_init(ctx: Context, title: str) -> bool:
should_show_details = await should_show_more( should_show_details = await layouts.should_show_more(
ctx, ctx,
"Confirm transaction", "Confirm transaction",
( (
@ -185,7 +191,7 @@ async def show_tx_init(ctx: wire.Context, title: str) -> bool:
), ),
(ui.NORMAL, "Choose level of details:"), (ui.NORMAL, "Choose level of details:"),
), ),
button_text="Show All", "Show All",
icon=ui.ICON_SEND, icon=ui.ICON_SEND,
icon_color=ui.GREEN, icon_color=ui.GREEN,
confirm="Show Simple", confirm="Show Simple",
@ -195,21 +201,21 @@ async def show_tx_init(ctx: wire.Context, title: str) -> bool:
return should_show_details return should_show_details
async def confirm_input(ctx: wire.Context, input: messages.CardanoTxInput) -> None: async def confirm_input(ctx: Context, input: messages.CardanoTxInput) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_input", "confirm_input",
title="Confirm transaction", "Confirm transaction",
props=[ (
("Input ID:", input.prev_hash), ("Input ID:", input.prev_hash),
("Input index:", str(input.prev_index)), ("Input index:", str(input.prev_index)),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_sending( async def confirm_sending(
ctx: wire.Context, ctx: Context,
ada_amount: int, ada_amount: int,
to: str, to: str,
output_type: Literal["address", "change", "collateral-return"], output_type: Literal["address", "change", "collateral-return"],
@ -224,30 +230,30 @@ async def confirm_sending(
else: else:
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
await confirm_output( await layouts.confirm_output(
ctx, ctx,
to, to,
format_coin_amount(ada_amount, network_id), format_coin_amount(ada_amount, network_id),
title="Confirm transaction", ui.BOLD,
subtitle=f"{message}:", "Confirm transaction",
font_amount=ui.BOLD, f"{message}:",
width_paginated=17, width_paginated=17,
to_str="\nto\n", to_str="\nto\n",
to_paginated=True, to_paginated=True,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_sending_token( async def confirm_sending_token(
ctx: wire.Context, policy_id: bytes, token: messages.CardanoToken ctx: Context, policy_id: bytes, token: messages.CardanoToken
) -> None: ) -> None:
assert token.amount is not None # _validate_token assert token.amount is not None # _validate_token
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_token", "confirm_token",
title="Confirm transaction", "Confirm transaction",
props=[ (
( (
"Asset fingerprint:", "Asset fingerprint:",
format_asset_fingerprint( format_asset_fingerprint(
@ -256,28 +262,28 @@ async def confirm_sending_token(
), ),
), ),
("Amount sent:", format_amount(token.amount, 0)), ("Amount sent:", format_amount(token.amount, 0)),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_datum_hash(ctx: wire.Context, datum_hash: bytes) -> None: async def confirm_datum_hash(ctx: Context, datum_hash: bytes) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_datum_hash", "confirm_datum_hash",
title="Confirm transaction", "Confirm transaction",
props=[ (
( (
"Datum hash:", "Datum hash:",
bech32.encode(bech32.HRP_OUTPUT_DATUM_HASH, datum_hash), bech32.encode(bech32.HRP_OUTPUT_DATUM_HASH, datum_hash),
), ),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_inline_datum( async def confirm_inline_datum(
ctx: wire.Context, first_chunk: bytes, inline_datum_size: int ctx: Context, first_chunk: bytes, inline_datum_size: int
) -> None: ) -> None:
await _confirm_data_chunk( await _confirm_data_chunk(
ctx, ctx,
@ -289,7 +295,7 @@ async def confirm_inline_datum(
async def confirm_reference_script( async def confirm_reference_script(
ctx: wire.Context, first_chunk: bytes, reference_script_size: int ctx: Context, first_chunk: bytes, reference_script_size: int
) -> None: ) -> None:
await _confirm_data_chunk( await _confirm_data_chunk(
ctx, ctx,
@ -301,7 +307,7 @@ async def confirm_reference_script(
async def _confirm_data_chunk( async def _confirm_data_chunk(
ctx: wire.Context, br_type: str, title: str, first_chunk: bytes, data_size: int ctx: Context, br_type: str, title: str, first_chunk: bytes, data_size: int
) -> None: ) -> None:
MAX_DISPLAYED_SIZE = 56 MAX_DISPLAYED_SIZE = 56
displayed_bytes = first_chunk[:MAX_DISPLAYED_SIZE] displayed_bytes = first_chunk[:MAX_DISPLAYED_SIZE]
@ -319,12 +325,12 @@ async def _confirm_data_chunk(
br_type, br_type,
title="Confirm transaction", title="Confirm transaction",
props=props, props=props,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def show_credentials( async def show_credentials(
ctx: wire.Context, ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
) -> None: ) -> None:
@ -334,7 +340,7 @@ async def show_credentials(
async def show_change_output_credentials( async def show_change_output_credentials(
ctx: wire.Context, ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
) -> None: ) -> None:
@ -344,7 +350,7 @@ async def show_change_output_credentials(
async def show_device_owned_output_credentials( async def show_device_owned_output_credentials(
ctx: wire.Context, ctx: Context,
payment_credential: Credential, payment_credential: Credential,
stake_credential: Credential, stake_credential: Credential,
show_both_credentials: bool, show_both_credentials: bool,
@ -356,24 +362,26 @@ async def show_device_owned_output_credentials(
async def _show_credential( async def _show_credential(
ctx: wire.Context, ctx: Context,
credential: Credential, credential: Credential,
intro_text: str, intro_text: str,
is_output: bool, is_output: bool,
) -> None: ) -> None:
if is_output: title = (
title = "Confirm transaction" "Confirm transaction"
else: if is_output
title = f"{ADDRESS_TYPE_NAMES[credential.address_type]} address" else f"{ADDRESS_TYPE_NAMES[credential.address_type]} address"
)
props: list[PropertyType] = [] props: list[PropertyType] = []
append = props.append # local_cache_attribute
# Credential can be empty in case of enterprise address stake credential # Credential can be empty in case of enterprise address stake credential
# and reward address payment credential. In that case we don't want to # and reward address payment credential. In that case we don't want to
# show some of the "props". # show some of the "props".
if credential.is_set(): if credential.is_set():
credential_title = credential.get_title() credential_title = credential.get_title()
props.append( append(
( (
f"{intro_text} {credential.type_name} credential is a {credential_title}:", f"{intro_text} {credential.type_name} credential is a {credential_title}:",
None, None,
@ -382,13 +390,13 @@ async def _show_credential(
props.extend(credential.format()) props.extend(credential.format())
if credential.is_unusual_path: if credential.is_unusual_path:
props.append((None, "Path is unusual.")) append((None, "Path is unusual."))
if credential.is_mismatch: if credential.is_mismatch:
props.append((None, "Credential doesn't match payment credential.")) append((None, "Credential doesn't match payment credential."))
if credential.is_reward: if credential.is_reward:
props.append(("Address is a reward address.", None)) append(("Address is a reward address.", None))
if credential.is_no_staking: if credential.is_no_staking:
props.append( append(
( (
f"{ADDRESS_TYPE_NAMES[credential.address_type]} address - no staking rewards.", f"{ADDRESS_TYPE_NAMES[credential.address_type]} address - no staking rewards.",
None, None,
@ -405,90 +413,93 @@ async def _show_credential(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_credential", "confirm_credential",
title=title, title,
props=props, props,
icon=icon, icon,
icon_color=icon_color, icon_color,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_path(ctx: wire.Context, path: list[int], title: str) -> None: async def warn_path(ctx: Context, path: list[int], title: str) -> None:
await confirm_path_warning(ctx, address_n_to_str(path), path_type=title) await layouts.confirm_path_warning(ctx, address_n_to_str(path), path_type=title)
async def warn_tx_output_contains_tokens( async def warn_tx_output_contains_tokens(
ctx: wire.Context, is_collateral_return: bool = False ctx: Context, is_collateral_return: bool = False
) -> None: ) -> None:
if is_collateral_return: content = (
content = "The collateral return\noutput contains tokens." "The collateral return\noutput contains tokens."
else: if is_collateral_return
content = "The following\ntransaction output\ncontains tokens." else "The following\ntransaction output\ncontains tokens."
)
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"confirm_tokens", "confirm_tokens",
title="Confirm transaction", "Confirm transaction",
content=content, content,
larger_vspace=True, larger_vspace=True,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_tx_contains_mint(ctx: wire.Context) -> None: async def warn_tx_contains_mint(ctx: Context) -> None:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"confirm_tokens", "confirm_tokens",
title="Confirm transaction", "Confirm transaction",
content="The transaction contains minting or burning of tokens.", "The transaction contains minting or burning of tokens.",
larger_vspace=True, larger_vspace=True,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_tx_output_no_datum(ctx: wire.Context) -> None: async def warn_tx_output_no_datum(ctx: Context) -> None:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"confirm_no_datum_hash", "confirm_no_datum_hash",
title="Confirm transaction", "Confirm transaction",
content="The following transaction output contains a script address, but does not contain a datum.", "The following transaction output contains a script address, but does not contain a datum.",
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_no_script_data_hash(ctx: wire.Context) -> None: async def warn_no_script_data_hash(ctx: Context) -> None:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"confirm_no_script_data_hash", "confirm_no_script_data_hash",
title="Confirm transaction", "Confirm transaction",
content="The transaction contains no script data hash. Plutus script will not be able to run.", "The transaction contains no script data hash. Plutus script will not be able to run.",
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_no_collateral_inputs(ctx: wire.Context) -> None: async def warn_no_collateral_inputs(ctx: Context) -> None:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"confirm_no_collateral_inputs", "confirm_no_collateral_inputs",
title="Confirm transaction", "Confirm transaction",
content="The transaction contains no collateral inputs. Plutus script will not be able to run.", "The transaction contains no collateral inputs. Plutus script will not be able to run.",
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_unknown_total_collateral(ctx: wire.Context) -> None: async def warn_unknown_total_collateral(ctx: Context) -> None:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"confirm_unknown_total_collateral", "confirm_unknown_total_collateral",
title="Warning", "Warning",
content="Unknown collateral amount, check all items carefully.", "Unknown collateral amount, check all items carefully.",
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_witness_request( async def confirm_witness_request(
ctx: wire.Context, ctx: Context,
witness_path: list[int], witness_path: list[int],
) -> None: ) -> None:
from . import seed
if seed.is_multisig_path(witness_path): if seed.is_multisig_path(witness_path):
path_title = "multi-sig path" path_title = "multi-sig path"
elif seed.is_minting_path(witness_path): elif seed.is_minting_path(witness_path):
@ -496,18 +507,18 @@ async def confirm_witness_request(
else: else:
path_title = "path" path_title = "path"
await confirm_text( await layouts.confirm_text(
ctx, ctx,
"confirm_total", "confirm_total",
title="Confirm transaction", "Confirm transaction",
data=address_n_to_str(witness_path), address_n_to_str(witness_path),
description=f"Sign transaction with {path_title}:", f"Sign transaction with {path_title}:",
br_code=ButtonRequestType.Other, BRT_Other,
) )
async def confirm_tx( async def confirm_tx(
ctx: wire.Context, ctx: Context,
fee: int, fee: int,
network_id: int, network_id: int,
protocol_magic: int, protocol_magic: int,
@ -520,33 +531,32 @@ async def confirm_tx(
props: list[PropertyType] = [ props: list[PropertyType] = [
("Transaction fee:", format_coin_amount(fee, network_id)), ("Transaction fee:", format_coin_amount(fee, network_id)),
] ]
append = props.append # local_cache_attribute
if total_collateral is not None: if total_collateral is not None:
props.append( append(("Total collateral:", format_coin_amount(total_collateral, network_id)))
("Total collateral:", format_coin_amount(total_collateral, network_id))
)
if is_network_id_verifiable: if is_network_id_verifiable:
props.append((f"Network: {protocol_magics.to_ui_string(protocol_magic)}", None)) append((f"Network: {protocol_magics.to_ui_string(protocol_magic)}", None))
props.append((f"Valid since: {format_optional_int(validity_interval_start)}", None)) append((f"Valid since: {format_optional_int(validity_interval_start)}", None))
props.append((f"TTL: {format_optional_int(ttl)}", None)) append((f"TTL: {format_optional_int(ttl)}", None))
if tx_hash: if tx_hash:
props.append(("Transaction ID:", tx_hash)) append(("Transaction ID:", tx_hash))
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_total", "confirm_total",
title="Confirm transaction", "Confirm transaction",
props=props, props,
hold=True, hold=True,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_certificate( async def confirm_certificate(
ctx: wire.Context, certificate: messages.CardanoTxCertificate ctx: Context, certificate: messages.CardanoTxCertificate
) -> None: ) -> None:
# stake pool registration requires custom confirmation logic not covered # stake pool registration requires custom confirmation logic not covered
# in this call # in this call
@ -566,14 +576,14 @@ async def confirm_certificate(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_certificate", "confirm_certificate",
title="Confirm transaction", "Confirm transaction",
props=props, props,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_stake_pool_parameters( async def confirm_stake_pool_parameters(
ctx: wire.Context, ctx: Context,
pool_parameters: messages.CardanoPoolParametersType, pool_parameters: messages.CardanoPoolParametersType,
network_id: int, network_id: int,
) -> None: ) -> None:
@ -584,8 +594,8 @@ async def confirm_stake_pool_parameters(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_pool_registration", "confirm_pool_registration",
title="Confirm transaction", "Confirm transaction",
props=[ (
( (
"Stake pool registration\nPool ID:", "Stake pool registration\nPool ID:",
format_stake_pool_id(pool_parameters.pool_id), format_stake_pool_id(pool_parameters.pool_id),
@ -597,18 +607,20 @@ async def confirm_stake_pool_parameters(
+ f"Margin: {percentage_formatted}%", + f"Margin: {percentage_formatted}%",
None, None,
), ),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_stake_pool_owner( async def confirm_stake_pool_owner(
ctx: wire.Context, ctx: Context,
keychain: seed.Keychain, keychain: Keychain,
owner: messages.CardanoPoolOwner, owner: messages.CardanoPoolOwner,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
) -> None: ) -> None:
from trezor import messages
props: list[tuple[str, str | None]] = [] props: list[tuple[str, str | None]] = []
if owner.staking_key_path: if owner.staking_key_path:
props.append(("Pool owner:", address_n_to_str(owner.staking_key_path))) props.append(("Pool owner:", address_n_to_str(owner.staking_key_path)))
@ -646,40 +658,40 @@ async def confirm_stake_pool_owner(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_pool_owners", "confirm_pool_owners",
title="Confirm transaction", "Confirm transaction",
props=props, props,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_stake_pool_metadata( async def confirm_stake_pool_metadata(
ctx: wire.Context, ctx: Context,
metadata: messages.CardanoPoolMetadataType | None, metadata: messages.CardanoPoolMetadataType | None,
) -> None: ) -> None:
if metadata is None: if metadata is None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_pool_metadata", "confirm_pool_metadata",
title="Confirm transaction", "Confirm transaction",
props=[("Pool has no metadata (anonymous pool)", None)], (("Pool has no metadata (anonymous pool)", None),),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
return return
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_pool_metadata", "confirm_pool_metadata",
title="Confirm transaction", "Confirm transaction",
props=[ (
("Pool metadata url:", metadata.url), ("Pool metadata url:", metadata.url),
("Pool metadata hash:", metadata.hash), ("Pool metadata hash:", metadata.hash),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_stake_pool_registration_final( async def confirm_stake_pool_registration_final(
ctx: wire.Context, ctx: Context,
protocol_magic: int, protocol_magic: int,
ttl: int | None, ttl: int | None,
validity_interval_start: int | None, validity_interval_start: int | None,
@ -687,20 +699,20 @@ async def confirm_stake_pool_registration_final(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_pool_final", "confirm_pool_final",
title="Confirm transaction", "Confirm transaction",
props=[ (
("Confirm signing the stake pool registration as an owner.", None), ("Confirm signing the stake pool registration as an owner.", None),
("Network:", protocol_magics.to_ui_string(protocol_magic)), ("Network:", protocol_magics.to_ui_string(protocol_magic)),
("Valid since:", format_optional_int(validity_interval_start)), ("Valid since:", format_optional_int(validity_interval_start)),
("TTL:", format_optional_int(ttl)), ("TTL:", format_optional_int(ttl)),
], ),
hold=True, hold=True,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_withdrawal( async def confirm_withdrawal(
ctx: wire.Context, ctx: Context,
withdrawal: messages.CardanoTxWithdrawal, withdrawal: messages.CardanoTxWithdrawal,
address_bytes: bytes, address_bytes: bytes,
network_id: int, network_id: int,
@ -723,15 +735,17 @@ async def confirm_withdrawal(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_withdrawal", "confirm_withdrawal",
title="Confirm transaction", "Confirm transaction",
props=props, props,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
def _format_stake_credential( def _format_stake_credential(
path: list[int], script_hash: bytes | None, key_hash: bytes | None path: list[int], script_hash: bytes | None, key_hash: bytes | None
) -> tuple[str, str]: ) -> tuple[str, str]:
from .helpers.utils import to_account_path
if path: if path:
return ( return (
f"for account {format_account_number(path)}:", f"for account {format_account_number(path)}:",
@ -747,7 +761,7 @@ def _format_stake_credential(
async def confirm_governance_registration_delegation( async def confirm_governance_registration_delegation(
ctx: wire.Context, ctx: Context,
public_key: str, public_key: str,
weight: int, weight: int,
) -> None: ) -> None:
@ -768,7 +782,7 @@ async def confirm_governance_registration_delegation(
async def confirm_governance_registration( async def confirm_governance_registration(
ctx: wire.Context, ctx: Context,
public_key: str | None, public_key: str | None,
staking_path: list[int], staking_path: list[int],
reward_address: str, reward_address: str,
@ -805,101 +819,99 @@ async def confirm_governance_registration(
) )
async def show_auxiliary_data_hash( async def show_auxiliary_data_hash(ctx: Context, auxiliary_data_hash: bytes) -> None:
ctx: wire.Context, auxiliary_data_hash: bytes
) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_auxiliary_data", "confirm_auxiliary_data",
title="Confirm transaction", "Confirm transaction",
props=[("Auxiliary data hash:", auxiliary_data_hash)], (("Auxiliary data hash:", auxiliary_data_hash),),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_token_minting( async def confirm_token_minting(
ctx: wire.Context, policy_id: bytes, token: messages.CardanoToken ctx: Context, policy_id: bytes, token: messages.CardanoToken
) -> None: ) -> None:
assert token.mint_amount is not None # _validate_token assert token.mint_amount is not None # _validate_token
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_mint", "confirm_mint",
title="Confirm transaction", "Confirm transaction",
props=[ (
( (
"Asset fingerprint:", "Asset fingerprint:",
format_asset_fingerprint( format_asset_fingerprint(
policy_id=policy_id, policy_id,
asset_name_bytes=token.asset_name_bytes, token.asset_name_bytes,
), ),
), ),
( (
"Amount minted:" if token.mint_amount >= 0 else "Amount burned:", "Amount minted:" if token.mint_amount >= 0 else "Amount burned:",
format_amount(token.mint_amount, 0), format_amount(token.mint_amount, 0),
), ),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def warn_tx_network_unverifiable(ctx: wire.Context) -> None: async def warn_tx_network_unverifiable(ctx: Context) -> None:
await confirm_metadata( await confirm_metadata(
ctx, ctx,
"warning_no_outputs", "warning_no_outputs",
title="Warning", "Warning",
content="Transaction has no outputs, network cannot be verified.", "Transaction has no outputs, network cannot be verified.",
larger_vspace=True, larger_vspace=True,
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_script_data_hash(ctx: wire.Context, script_data_hash: bytes) -> None: async def confirm_script_data_hash(ctx: Context, script_data_hash: bytes) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_script_data_hash", "confirm_script_data_hash",
title="Confirm transaction", "Confirm transaction",
props=[ (
( (
"Script data hash:", "Script data hash:",
bech32.encode(bech32.HRP_SCRIPT_DATA_HASH, script_data_hash), bech32.encode(bech32.HRP_SCRIPT_DATA_HASH, script_data_hash),
) ),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_collateral_input( async def confirm_collateral_input(
ctx: wire.Context, collateral_input: messages.CardanoTxCollateralInput ctx: Context, collateral_input: messages.CardanoTxCollateralInput
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_collateral_input", "confirm_collateral_input",
title="Confirm transaction", "Confirm transaction",
props=[ (
("Collateral input ID:", collateral_input.prev_hash), ("Collateral input ID:", collateral_input.prev_hash),
("Collateral input index:", str(collateral_input.prev_index)), ("Collateral input index:", str(collateral_input.prev_index)),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_reference_input( async def confirm_reference_input(
ctx: wire.Context, reference_input: messages.CardanoTxReferenceInput ctx: Context, reference_input: messages.CardanoTxReferenceInput
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_reference_input", "confirm_reference_input",
title="Confirm transaction", "Confirm transaction",
props=[ (
("Reference input ID:", reference_input.prev_hash), ("Reference input ID:", reference_input.prev_hash),
("Reference input index:", str(reference_input.prev_index)), ("Reference input index:", str(reference_input.prev_index)),
], ),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def confirm_required_signer( async def confirm_required_signer(
ctx: wire.Context, required_signer: messages.CardanoTxRequiredSigner ctx: Context, required_signer: messages.CardanoTxRequiredSigner
) -> None: ) -> None:
assert ( assert (
required_signer.key_hash is not None or required_signer.key_path required_signer.key_hash is not None or required_signer.key_path
@ -913,18 +925,20 @@ async def confirm_required_signer(
await confirm_properties( await confirm_properties(
ctx, ctx,
"confirm_required_signer", "confirm_required_signer",
title="Confirm transaction", "Confirm transaction",
props=[("Required signer", formatted_signer)], (("Required signer", formatted_signer),),
br_code=ButtonRequestType.Other, br_code=BRT_Other,
) )
async def show_cardano_address( async def show_cardano_address(
ctx: wire.Context, ctx: Context,
address_parameters: messages.CardanoAddressParametersType, address_parameters: messages.CardanoAddressParametersType,
address: str, address: str,
protocol_magic: int, protocol_magic: int,
) -> None: ) -> None:
CAT = CardanoAddressType # local_cache_global
network_name = None network_name = None
if not protocol_magics.is_mainnet(protocol_magic): if not protocol_magics.is_mainnet(protocol_magic):
network_name = protocol_magics.to_ui_string(protocol_magic) network_name = protocol_magics.to_ui_string(protocol_magic)
@ -933,12 +947,12 @@ async def show_cardano_address(
address_extra = None address_extra = None
title_qr = title title_qr = title
if address_parameters.address_type in ( if address_parameters.address_type in (
CardanoAddressType.BYRON, CAT.BYRON,
CardanoAddressType.BASE, CAT.BASE,
CardanoAddressType.BASE_KEY_SCRIPT, CAT.BASE_KEY_SCRIPT,
CardanoAddressType.POINTER, CAT.POINTER,
CardanoAddressType.ENTERPRISE, CAT.ENTERPRISE,
CardanoAddressType.REWARD, CAT.REWARD,
): ):
if address_parameters.address_n: if address_parameters.address_n:
address_extra = address_n_to_str(address_parameters.address_n) address_extra = address_n_to_str(address_parameters.address_n)
@ -947,9 +961,9 @@ async def show_cardano_address(
address_extra = address_n_to_str(address_parameters.address_n_staking) address_extra = address_n_to_str(address_parameters.address_n_staking)
title_qr = address_n_to_str(address_parameters.address_n_staking) title_qr = address_n_to_str(address_parameters.address_n_staking)
await show_address( await layouts.show_address(
ctx, ctx,
address=address, address,
title=title, title=title,
network=network_name, network=network_name,
address_extra=address_extra, address_extra=address_extra,

View File

@ -1,105 +1,111 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import messages, wire
from trezor.crypto import hashlib
from trezor.enums import CardanoNativeScriptType from trezor.enums import CardanoNativeScriptType
from trezor.wire import ProcessError
from apps.common import cbor
from . import seed
from .helpers import ADDRESS_KEY_HASH_SIZE, SCRIPT_HASH_SIZE
from .helpers.paths import SCHEMA_MINT
from .helpers.utils import get_public_key_hash
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any from typing import Any
from trezor import messages
from apps.common.cbor import CborSequence from apps.common.cbor import CborSequence
from . import seed
def validate_native_script(script: messages.CardanoNativeScript | None) -> None: def validate_native_script(script: messages.CardanoNativeScript | None) -> None:
INVALID_NATIVE_SCRIPT = wire.ProcessError("Invalid native script") from .helpers import ADDRESS_KEY_HASH_SIZE
from .helpers.paths import SCHEMA_MINT
from . import seed
INVALID_NATIVE_SCRIPT = ProcessError("Invalid native script")
if not script: if not script:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
_validate_native_script_structure(script) _validate_native_script_structure(script)
script_type = script.type # local_cache_attribute
key_path = script.key_path # local_cache_attribute
scripts = script.scripts # local_cache_attribute
CNST = CardanoNativeScriptType # local_cache_global
if script.type == CardanoNativeScriptType.PUB_KEY: if script_type == CNST.PUB_KEY:
if script.key_hash and script.key_path: if script.key_hash and key_path:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
if script.key_hash: if script.key_hash:
if len(script.key_hash) != ADDRESS_KEY_HASH_SIZE: if len(script.key_hash) != ADDRESS_KEY_HASH_SIZE:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
elif script.key_path: elif key_path:
is_minting = SCHEMA_MINT.match(script.key_path) is_minting = SCHEMA_MINT.match(key_path)
if not seed.is_multisig_path(script.key_path) and not is_minting: if not seed.is_multisig_path(key_path) and not is_minting:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
else: else:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
elif script.type == CardanoNativeScriptType.ALL: elif script_type == CNST.ALL:
for sub_script in script.scripts: for sub_script in scripts:
validate_native_script(sub_script) validate_native_script(sub_script)
elif script.type == CardanoNativeScriptType.ANY: elif script_type == CNST.ANY:
for sub_script in script.scripts: for sub_script in scripts:
validate_native_script(sub_script) validate_native_script(sub_script)
elif script.type == CardanoNativeScriptType.N_OF_K: elif script_type == CNST.N_OF_K:
if script.required_signatures_count is None: if script.required_signatures_count is None:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
if script.required_signatures_count > len(script.scripts): if script.required_signatures_count > len(scripts):
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
for sub_script in script.scripts: for sub_script in scripts:
validate_native_script(sub_script) validate_native_script(sub_script)
elif script.type == CardanoNativeScriptType.INVALID_BEFORE: elif script_type == CNST.INVALID_BEFORE:
if script.invalid_before is None: if script.invalid_before is None:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
elif script.type == CardanoNativeScriptType.INVALID_HEREAFTER: elif script_type == CNST.INVALID_HEREAFTER:
if script.invalid_hereafter is None: if script.invalid_hereafter is None:
raise INVALID_NATIVE_SCRIPT raise INVALID_NATIVE_SCRIPT
def _validate_native_script_structure(script: messages.CardanoNativeScript) -> None: def _validate_native_script_structure(script: messages.CardanoNativeScript) -> None:
key_hash = script.key_hash key_hash = script.key_hash # local_cache_attribute
key_path = script.key_path key_path = script.key_path # local_cache_attribute
scripts = script.scripts scripts = script.scripts # local_cache_attribute
required_signatures_count = script.required_signatures_count required_signatures_count = (
invalid_before = script.invalid_before script.required_signatures_count
invalid_hereafter = script.invalid_hereafter ) # local_cache_attribute
invalid_before = script.invalid_before # local_cache_attribute
invalid_hereafter = script.invalid_hereafter # local_cache_attribute
CNST = CardanoNativeScriptType # local_cache_global
fields_to_be_empty: dict[CardanoNativeScriptType, tuple[Any, ...]] = { fields_to_be_empty: dict[CNST, tuple[Any, ...]] = {
CardanoNativeScriptType.PUB_KEY: ( CNST.PUB_KEY: (
scripts, scripts,
required_signatures_count, required_signatures_count,
invalid_before, invalid_before,
invalid_hereafter, invalid_hereafter,
), ),
CardanoNativeScriptType.ALL: ( CNST.ALL: (
key_hash, key_hash,
key_path, key_path,
required_signatures_count, required_signatures_count,
invalid_before, invalid_before,
invalid_hereafter, invalid_hereafter,
), ),
CardanoNativeScriptType.ANY: ( CNST.ANY: (
key_hash, key_hash,
key_path, key_path,
required_signatures_count, required_signatures_count,
invalid_before, invalid_before,
invalid_hereafter, invalid_hereafter,
), ),
CardanoNativeScriptType.N_OF_K: ( CNST.N_OF_K: (
key_hash, key_hash,
key_path, key_path,
invalid_before, invalid_before,
invalid_hereafter, invalid_hereafter,
), ),
CardanoNativeScriptType.INVALID_BEFORE: ( CNST.INVALID_BEFORE: (
key_hash, key_hash,
key_path, key_path,
required_signatures_count, required_signatures_count,
invalid_hereafter, invalid_hereafter,
), ),
CardanoNativeScriptType.INVALID_HEREAFTER: ( CNST.INVALID_HEREAFTER: (
key_hash, key_hash,
key_path, key_path,
required_signatures_count, required_signatures_count,
@ -108,12 +114,16 @@ def _validate_native_script_structure(script: messages.CardanoNativeScript) -> N
} }
if script.type not in fields_to_be_empty or any(fields_to_be_empty[script.type]): if script.type not in fields_to_be_empty or any(fields_to_be_empty[script.type]):
raise wire.ProcessError("Invalid native script") raise ProcessError("Invalid native script")
def get_native_script_hash( def get_native_script_hash(
keychain: seed.Keychain, script: messages.CardanoNativeScript keychain: seed.Keychain, script: messages.CardanoNativeScript
) -> bytes: ) -> bytes:
from .helpers import SCRIPT_HASH_SIZE
from trezor.crypto import hashlib
from apps.common import cbor
script_cbor = cbor.encode(cborize_native_script(keychain, script)) script_cbor = cbor.encode(cborize_native_script(keychain, script))
prefixed_script_cbor = b"\00" + script_cbor prefixed_script_cbor = b"\00" + script_cbor
return hashlib.blake2b(data=prefixed_script_cbor, outlen=SCRIPT_HASH_SIZE).digest() return hashlib.blake2b(data=prefixed_script_cbor, outlen=SCRIPT_HASH_SIZE).digest()
@ -122,29 +132,34 @@ def get_native_script_hash(
def cborize_native_script( def cborize_native_script(
keychain: seed.Keychain, script: messages.CardanoNativeScript keychain: seed.Keychain, script: messages.CardanoNativeScript
) -> CborSequence: ) -> CborSequence:
from .helpers.utils import get_public_key_hash
script_type = script.type # local_cache_attribute
CNST = CardanoNativeScriptType # local_cache_global
script_content: CborSequence script_content: CborSequence
if script.type == CardanoNativeScriptType.PUB_KEY: if script_type == CNST.PUB_KEY:
if script.key_hash: if script.key_hash:
script_content = (script.key_hash,) script_content = (script.key_hash,)
elif script.key_path: elif script.key_path:
script_content = (get_public_key_hash(keychain, script.key_path),) script_content = (get_public_key_hash(keychain, script.key_path),)
else: else:
raise wire.ProcessError("Invalid native script") raise ProcessError("Invalid native script")
elif script.type == CardanoNativeScriptType.ALL: elif script_type == CNST.ALL:
script_content = ( script_content = (
tuple( tuple(
cborize_native_script(keychain, sub_script) cborize_native_script(keychain, sub_script)
for sub_script in script.scripts for sub_script in script.scripts
), ),
) )
elif script.type == CardanoNativeScriptType.ANY: elif script_type == CNST.ANY:
script_content = ( script_content = (
tuple( tuple(
cborize_native_script(keychain, sub_script) cborize_native_script(keychain, sub_script)
for sub_script in script.scripts for sub_script in script.scripts
), ),
) )
elif script.type == CardanoNativeScriptType.N_OF_K: elif script_type == CNST.N_OF_K:
script_content = ( script_content = (
script.required_signatures_count, script.required_signatures_count,
tuple( tuple(
@ -152,11 +167,11 @@ def cborize_native_script(
for sub_script in script.scripts for sub_script in script.scripts
), ),
) )
elif script.type == CardanoNativeScriptType.INVALID_BEFORE: elif script_type == CNST.INVALID_BEFORE:
script_content = (script.invalid_before,) script_content = (script.invalid_before,)
elif script.type == CardanoNativeScriptType.INVALID_HEREAFTER: elif script_type == CNST.INVALID_HEREAFTER:
script_content = (script.invalid_hereafter,) script_content = (script.invalid_hereafter,)
else: else:
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
return (script.type,) + script_content return (script_type,) + script_content

View File

@ -2,13 +2,12 @@ from typing import TYPE_CHECKING
from storage import cache, device from storage import cache, device
from trezor import wire from trezor import wire
from trezor.crypto import bip32, cardano from trezor.crypto import cardano
from trezor.enums import CardanoDerivationType
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.seed import derive_and_store_roots, get_seed from apps.common.seed import get_seed
from .helpers import paths from .helpers.paths import BYRON_ROOT, MINTING_ROOT, MULTISIG_ROOT, SHELLEY_ROOT
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Awaitable, TypeVar from typing import Callable, Awaitable, TypeVar
@ -17,6 +16,8 @@ if TYPE_CHECKING:
from apps.common.keychain import MsgOut, Handler from apps.common.keychain import MsgOut, Handler
from trezor import messages from trezor import messages
from trezor.enums import CardanoDerivationType
from trezor.crypto import bip32
CardanoMessages = ( CardanoMessages = (
messages.CardanoGetAddress messages.CardanoGetAddress
@ -36,10 +37,10 @@ class Keychain:
""" """
def __init__(self, root: bip32.HDNode) -> None: def __init__(self, root: bip32.HDNode) -> None:
self.byron_root = self._derive_path(root, paths.BYRON_ROOT) self.byron_root = self._derive_path(root, BYRON_ROOT)
self.shelley_root = self._derive_path(root, paths.SHELLEY_ROOT) self.shelley_root = self._derive_path(root, SHELLEY_ROOT)
self.multisig_root = self._derive_path(root, paths.MULTISIG_ROOT) self.multisig_root = self._derive_path(root, MULTISIG_ROOT)
self.minting_root = self._derive_path(root, paths.MINTING_ROOT) self.minting_root = self._derive_path(root, MINTING_ROOT)
root.__del__() root.__del__()
@staticmethod @staticmethod
@ -79,11 +80,11 @@ class Keychain:
# this is true now, so for simplicity we don't branch on path type # this is true now, so for simplicity we don't branch on path type
assert ( assert (
len(paths.BYRON_ROOT) == len(paths.SHELLEY_ROOT) len(BYRON_ROOT) == len(SHELLEY_ROOT)
and len(paths.MULTISIG_ROOT) == len(paths.SHELLEY_ROOT) and len(MULTISIG_ROOT) == len(SHELLEY_ROOT)
and len(paths.MINTING_ROOT) == len(paths.SHELLEY_ROOT) and len(MINTING_ROOT) == len(SHELLEY_ROOT)
) )
suffix = node_path[len(paths.SHELLEY_ROOT) :] suffix = node_path[len(SHELLEY_ROOT) :]
# derive child node from the root # derive child node from the root
return self._derive_path(path_root, suffix) return self._derive_path(path_root, suffix)
@ -94,19 +95,19 @@ class Keychain:
def is_byron_path(path: Bip32Path) -> bool: def is_byron_path(path: Bip32Path) -> bool:
return path[: len(paths.BYRON_ROOT)] == paths.BYRON_ROOT return path[: len(BYRON_ROOT)] == BYRON_ROOT
def is_shelley_path(path: Bip32Path) -> bool: def is_shelley_path(path: Bip32Path) -> bool:
return path[: len(paths.SHELLEY_ROOT)] == paths.SHELLEY_ROOT return path[: len(SHELLEY_ROOT)] == SHELLEY_ROOT
def is_multisig_path(path: Bip32Path) -> bool: def is_multisig_path(path: Bip32Path) -> bool:
return path[: len(paths.MULTISIG_ROOT)] == paths.MULTISIG_ROOT return path[: len(MULTISIG_ROOT)] == MULTISIG_ROOT
def is_minting_path(path: Bip32Path) -> bool: def is_minting_path(path: Bip32Path) -> bool:
return path[: len(paths.MINTING_ROOT)] == paths.MINTING_ROOT return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None: def derive_and_store_secrets(passphrase: str) -> None:
@ -135,18 +136,12 @@ def derive_and_store_secrets(passphrase: str) -> None:
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_secret(ctx: wire.Context, cache_entry: int) -> bytes:
secret = cache.get(cache_entry)
if secret is None:
await derive_and_store_roots(ctx)
secret = cache.get(cache_entry)
assert secret is not None
return secret
async def _get_keychain_bip39( async def _get_keychain_bip39(
ctx: wire.Context, derivation_type: CardanoDerivationType ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain: ) -> Keychain:
from apps.common.seed import derive_and_store_roots
from trezor.enums import CardanoDerivationType
if not device.is_initialized(): if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
@ -162,12 +157,18 @@ async def _get_keychain_bip39(
else: else:
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET
secret = await _get_secret(ctx, cache_entry) # _get_secret
secret = cache.get(cache_entry)
if secret is None:
await derive_and_store_roots(ctx)
secret = cache.get(cache_entry)
assert secret is not None
root = cardano.from_secret(secret) root = cardano.from_secret(secret)
return Keychain(root) return Keychain(root)
async def get_keychain( async def _get_keychain(
ctx: wire.Context, derivation_type: CardanoDerivationType ctx: wire.Context, derivation_type: CardanoDerivationType
) -> Keychain: ) -> Keychain:
if mnemonic.is_bip39(): if mnemonic.is_bip39():
@ -180,7 +181,7 @@ async def get_keychain(
def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def with_keychain(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain = await get_keychain(ctx, msg.derivation_type) keychain = await _get_keychain(ctx, msg.derivation_type)
return await func(ctx, msg, keychain) return await func(ctx, msg, keychain)
return wrapper return wrapper

View File

@ -1,30 +1,39 @@
from typing import Type from typing import TYPE_CHECKING
from trezor import log, messages, wire
from trezor.enums import CardanoTxSigningMode
from .. import seed from .. import seed
from .signer import Signer
if TYPE_CHECKING:
from typing import Type
from trezor.wire import Context
from trezor.messages import CardanoSignTxFinished, CardanoSignTxInit
@seed.with_keychain @seed.with_keychain
async def sign_tx( async def sign_tx(
ctx: wire.Context, msg: messages.CardanoSignTxInit, keychain: seed.Keychain ctx: Context, msg: CardanoSignTxInit, keychain: seed.Keychain
) -> messages.CardanoSignTxFinished: ) -> CardanoSignTxFinished:
from trezor.messages import CardanoSignTxFinished
from trezor import log, wire
from trezor.enums import CardanoTxSigningMode
from .signer import Signer
signing_mode = msg.signing_mode # local_cache_attribute
signer_type: Type[Signer] signer_type: Type[Signer]
if msg.signing_mode == CardanoTxSigningMode.ORDINARY_TRANSACTION: if signing_mode == CardanoTxSigningMode.ORDINARY_TRANSACTION:
from .ordinary_signer import OrdinarySigner from .ordinary_signer import OrdinarySigner
signer_type = OrdinarySigner signer_type = OrdinarySigner
elif msg.signing_mode == CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER: elif signing_mode == CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER:
from .pool_owner_signer import PoolOwnerSigner from .pool_owner_signer import PoolOwnerSigner
signer_type = PoolOwnerSigner signer_type = PoolOwnerSigner
elif msg.signing_mode == CardanoTxSigningMode.MULTISIG_TRANSACTION: elif signing_mode == CardanoTxSigningMode.MULTISIG_TRANSACTION:
from .multisig_signer import MultisigSigner from .multisig_signer import MultisigSigner
signer_type = MultisigSigner signer_type = MultisigSigner
elif msg.signing_mode == CardanoTxSigningMode.PLUTUS_TRANSACTION: elif signing_mode == CardanoTxSigningMode.PLUTUS_TRANSACTION:
from .plutus_signer import PlutusSigner from .plutus_signer import PlutusSigner
signer_type = PlutusSigner signer_type = PlutusSigner
@ -40,4 +49,4 @@ async def sign_tx(
log.exception(__name__, e) log.exception(__name__, e)
raise wire.ProcessError("Signing failed") raise wire.ProcessError("Signing failed")
return messages.CardanoSignTxFinished() return CardanoSignTxFinished()

View File

@ -1,10 +1,12 @@
from trezor import messages, wire from typing import TYPE_CHECKING
from trezor.enums import CardanoCertificateType
from trezor.wire import ProcessError
from .. import layout, seed
from ..helpers.paths import SCHEMA_MINT
from .signer import Signer from .signer import Signer
if TYPE_CHECKING:
from trezor import messages
class MultisigSigner(Signer): class MultisigSigner(Signer):
""" """
@ -14,23 +16,30 @@ class MultisigSigner(Signer):
SIGNING_MODE_TITLE = "Confirming a multisig transaction." SIGNING_MODE_TITLE = "Confirming a multisig transaction."
def _validate_tx_init(self) -> None: def _validate_tx_init(self) -> None:
msg = self.msg # local_cache_attribute
_assert_tx_init_cond = self._assert_tx_init_cond # local_cache_attribute
super()._validate_tx_init() super()._validate_tx_init()
self._assert_tx_init_cond(self.msg.collateral_inputs_count == 0) _assert_tx_init_cond(msg.collateral_inputs_count == 0)
self._assert_tx_init_cond(not self.msg.has_collateral_return) _assert_tx_init_cond(not msg.has_collateral_return)
self._assert_tx_init_cond(self.msg.total_collateral is None) _assert_tx_init_cond(msg.total_collateral is None)
self._assert_tx_init_cond(self.msg.reference_inputs_count == 0) _assert_tx_init_cond(msg.reference_inputs_count == 0)
async def _confirm_tx(self, tx_hash: bytes) -> None: async def _confirm_tx(self, tx_hash: bytes) -> None:
from .. import layout
msg = self.msg # local_cache_attribute
# super() omitted intentionally # super() omitted intentionally
is_network_id_verifiable = self._is_network_id_verifiable() is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx( await layout.confirm_tx(
self.ctx, self.ctx,
self.msg.fee, msg.fee,
self.msg.network_id, msg.network_id,
self.msg.protocol_magic, msg.protocol_magic,
self.msg.ttl, msg.ttl,
self.msg.validity_interval_start, msg.validity_interval_start,
self.msg.total_collateral, msg.total_collateral,
is_network_id_verifiable, is_network_id_verifiable,
tx_hash=None, tx_hash=None,
) )
@ -38,23 +47,28 @@ class MultisigSigner(Signer):
def _validate_output(self, output: messages.CardanoTxOutput) -> None: def _validate_output(self, output: messages.CardanoTxOutput) -> None:
super()._validate_output(output) super()._validate_output(output)
if output.address_parameters is not None: if output.address_parameters is not None:
raise wire.ProcessError("Invalid output") raise ProcessError("Invalid output")
def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None: def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None:
from trezor.enums import CardanoCertificateType
super()._validate_certificate(certificate) super()._validate_certificate(certificate)
if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION: if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
if certificate.path or certificate.key_hash: if certificate.path or certificate.key_hash:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
def _validate_withdrawal(self, withdrawal: messages.CardanoTxWithdrawal) -> None: def _validate_withdrawal(self, withdrawal: messages.CardanoTxWithdrawal) -> None:
super()._validate_withdrawal(withdrawal) super()._validate_withdrawal(withdrawal)
if withdrawal.path or withdrawal.key_hash: if withdrawal.path or withdrawal.key_hash:
raise wire.ProcessError("Invalid withdrawal") raise ProcessError("Invalid withdrawal")
def _validate_witness_request( def _validate_witness_request(
self, witness_request: messages.CardanoTxWitnessRequest self, witness_request: messages.CardanoTxWitnessRequest
) -> None: ) -> None:
from .. import seed
from ..helpers.paths import SCHEMA_MINT
super()._validate_witness_request(witness_request) super()._validate_witness_request(witness_request)
is_minting = SCHEMA_MINT.match(witness_request.path) is_minting = SCHEMA_MINT.match(witness_request.path)
tx_has_token_minting = self.msg.minting_asset_groups_count > 0 tx_has_token_minting = self.msg.minting_asset_groups_count > 0
@ -63,4 +77,4 @@ class MultisigSigner(Signer):
seed.is_multisig_path(witness_request.path) seed.is_multisig_path(witness_request.path)
or (is_minting and tx_has_token_minting) or (is_minting and tx_has_token_minting)
): ):
raise wire.ProcessError("Invalid witness request") raise ProcessError("Invalid witness request")

View File

@ -1,15 +1,14 @@
from trezor import messages, wire from typing import TYPE_CHECKING
from trezor.enums import CardanoCertificateType
from .. import layout, seed from trezor.wire import ProcessError
from ..helpers.paths import (
SCHEMA_MINT, from .. import layout
SCHEMA_PAYMENT, from ..helpers.paths import SCHEMA_MINT
SCHEMA_STAKING,
WITNESS_PATH_NAME,
)
from .signer import Signer from .signer import Signer
if TYPE_CHECKING:
from trezor import messages
class OrdinarySigner(Signer): class OrdinarySigner(Signer):
""" """
@ -20,42 +19,51 @@ class OrdinarySigner(Signer):
SIGNING_MODE_TITLE = "Confirming a transaction." SIGNING_MODE_TITLE = "Confirming a transaction."
def _validate_tx_init(self) -> None: def _validate_tx_init(self) -> None:
msg = self.msg # local_cache_attribute
_assert_tx_init_cond = self._assert_tx_init_cond # local_cache_attribute
super()._validate_tx_init() super()._validate_tx_init()
self._assert_tx_init_cond(self.msg.collateral_inputs_count == 0) _assert_tx_init_cond(msg.collateral_inputs_count == 0)
self._assert_tx_init_cond(not self.msg.has_collateral_return) _assert_tx_init_cond(not msg.has_collateral_return)
self._assert_tx_init_cond(self.msg.total_collateral is None) _assert_tx_init_cond(msg.total_collateral is None)
self._assert_tx_init_cond(self.msg.reference_inputs_count == 0) _assert_tx_init_cond(msg.reference_inputs_count == 0)
async def _confirm_tx(self, tx_hash: bytes) -> None: async def _confirm_tx(self, tx_hash: bytes) -> None:
msg = self.msg # local_cache_attribute
# super() omitted intentionally # super() omitted intentionally
is_network_id_verifiable = self._is_network_id_verifiable() is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx( await layout.confirm_tx(
self.ctx, self.ctx,
self.msg.fee, msg.fee,
self.msg.network_id, msg.network_id,
self.msg.protocol_magic, msg.protocol_magic,
self.msg.ttl, msg.ttl,
self.msg.validity_interval_start, msg.validity_interval_start,
self.msg.total_collateral, msg.total_collateral,
is_network_id_verifiable, is_network_id_verifiable,
tx_hash=None, tx_hash=None,
) )
def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None: def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None:
from trezor.enums import CardanoCertificateType
super()._validate_certificate(certificate) super()._validate_certificate(certificate)
if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION: if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
if certificate.script_hash or certificate.key_hash: if certificate.script_hash or certificate.key_hash:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
def _validate_withdrawal(self, withdrawal: messages.CardanoTxWithdrawal) -> None: def _validate_withdrawal(self, withdrawal: messages.CardanoTxWithdrawal) -> None:
super()._validate_withdrawal(withdrawal) super()._validate_withdrawal(withdrawal)
if withdrawal.script_hash or withdrawal.key_hash: if withdrawal.script_hash or withdrawal.key_hash:
raise wire.ProcessError("Invalid withdrawal") raise ProcessError("Invalid withdrawal")
def _validate_witness_request( def _validate_witness_request(
self, witness_request: messages.CardanoTxWitnessRequest self, witness_request: messages.CardanoTxWitnessRequest
) -> None: ) -> None:
from .. import seed
super()._validate_witness_request(witness_request) super()._validate_witness_request(witness_request)
is_minting = SCHEMA_MINT.match(witness_request.path) is_minting = SCHEMA_MINT.match(witness_request.path)
tx_has_token_minting = self.msg.minting_asset_groups_count > 0 tx_has_token_minting = self.msg.minting_asset_groups_count > 0
@ -65,9 +73,15 @@ class OrdinarySigner(Signer):
or seed.is_shelley_path(witness_request.path) or seed.is_shelley_path(witness_request.path)
or (is_minting and tx_has_token_minting) or (is_minting and tx_has_token_minting)
): ):
raise wire.ProcessError("Invalid witness request") raise ProcessError("Invalid witness request")
async def _show_witness_request(self, witness_path: list[int]) -> None: async def _show_witness_request(self, witness_path: list[int]) -> None:
from ..helpers.paths import (
SCHEMA_PAYMENT,
SCHEMA_STAKING,
WITNESS_PATH_NAME,
)
# super() omitted intentionally # super() omitted intentionally
# We only allow payment, staking or minting paths. # We only allow payment, staking or minting paths.
# If the path is an unusual payment or staking path, we either fail or show the # If the path is an unusual payment or staking path, we either fail or show the

View File

@ -1,11 +1,13 @@
from trezor import messages, wire from typing import TYPE_CHECKING
from trezor.enums import CardanoCertificateType
from .. import layout, seed from trezor import wire
from ..helpers.credential import Credential, should_show_credentials
from ..helpers.paths import SCHEMA_MINT from .. import layout
from .signer import Signer from .signer import Signer
if TYPE_CHECKING:
from trezor import messages
class PlutusSigner(Signer): class PlutusSigner(Signer):
""" """
@ -28,6 +30,8 @@ class PlutusSigner(Signer):
await layout.warn_unknown_total_collateral(self.ctx) await layout.warn_unknown_total_collateral(self.ctx)
async def _confirm_tx(self, tx_hash: bytes) -> None: async def _confirm_tx(self, tx_hash: bytes) -> None:
msg = self.msg # local_cache_attribute
# super() omitted intentionally # super() omitted intentionally
# We display tx hash so that experienced users can compare it to the tx hash # We display tx hash so that experienced users can compare it to the tx hash
# computed by a trusted device (in case the tx contains many items which are # computed by a trusted device (in case the tx contains many items which are
@ -35,12 +39,12 @@ class PlutusSigner(Signer):
is_network_id_verifiable = self._is_network_id_verifiable() is_network_id_verifiable = self._is_network_id_verifiable()
await layout.confirm_tx( await layout.confirm_tx(
self.ctx, self.ctx,
self.msg.fee, msg.fee,
self.msg.network_id, msg.network_id,
self.msg.protocol_magic, msg.protocol_magic,
self.msg.ttl, msg.ttl,
self.msg.validity_interval_start, msg.validity_interval_start,
self.msg.total_collateral, msg.total_collateral,
is_network_id_verifiable, is_network_id_verifiable,
tx_hash, tx_hash,
) )
@ -53,6 +57,8 @@ class PlutusSigner(Signer):
async def _show_output_credentials( async def _show_output_credentials(
self, address_parameters: messages.CardanoAddressParametersType self, address_parameters: messages.CardanoAddressParametersType
) -> None: ) -> None:
from ..helpers.credential import Credential, should_show_credentials
# In ordinary txs, change outputs with matching payment and staking paths can be # In ordinary txs, change outputs with matching payment and staking paths can be
# hidden, but we need to show them in Plutus txs because of the script # hidden, but we need to show them in Plutus txs because of the script
# evaluation. We at least hide the staking path if it matches the payment path. # evaluation. We at least hide the staking path if it matches the payment path.
@ -80,6 +86,8 @@ class PlutusSigner(Signer):
return False return False
def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None: def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None:
from trezor.enums import CardanoCertificateType
super()._validate_certificate(certificate) super()._validate_certificate(certificate)
if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION: if certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION:
raise wire.ProcessError("Invalid certificate") raise wire.ProcessError("Invalid certificate")
@ -87,6 +95,9 @@ class PlutusSigner(Signer):
def _validate_witness_request( def _validate_witness_request(
self, witness_request: messages.CardanoTxWitnessRequest self, witness_request: messages.CardanoTxWitnessRequest
) -> None: ) -> None:
from .. import seed
from ..helpers.paths import SCHEMA_MINT
super()._validate_witness_request(witness_request) super()._validate_witness_request(witness_request)
is_minting = SCHEMA_MINT.match(witness_request.path) is_minting = SCHEMA_MINT.match(witness_request.path)

View File

@ -1,10 +1,12 @@
from trezor import messages, wire from typing import TYPE_CHECKING
from trezor.enums import CardanoCertificateType
from trezor.wire import ProcessError
from .. import layout
from ..helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
from .signer import Signer from .signer import Signer
if TYPE_CHECKING:
from trezor import messages
class PoolOwnerSigner(Signer): class PoolOwnerSigner(Signer):
""" """
@ -22,18 +24,25 @@ class PoolOwnerSigner(Signer):
SIGNING_MODE_TITLE = "Confirming pool registration as owner." SIGNING_MODE_TITLE = "Confirming pool registration as owner."
def _validate_tx_init(self) -> None: def _validate_tx_init(self) -> None:
msg = self.msg # local_cache_attribute
super()._validate_tx_init() super()._validate_tx_init()
self._assert_tx_init_cond(self.msg.certificates_count == 1) for condition in (
self._assert_tx_init_cond(self.msg.withdrawals_count == 0) msg.certificates_count == 1,
self._assert_tx_init_cond(self.msg.minting_asset_groups_count == 0) msg.withdrawals_count == 0,
self._assert_tx_init_cond(self.msg.script_data_hash is None) msg.minting_asset_groups_count == 0,
self._assert_tx_init_cond(self.msg.collateral_inputs_count == 0) msg.script_data_hash is None,
self._assert_tx_init_cond(self.msg.required_signers_count == 0) msg.collateral_inputs_count == 0,
self._assert_tx_init_cond(not self.msg.has_collateral_return) msg.required_signers_count == 0,
self._assert_tx_init_cond(self.msg.total_collateral is None) not msg.has_collateral_return,
self._assert_tx_init_cond(self.msg.reference_inputs_count == 0) msg.total_collateral is None,
msg.reference_inputs_count == 0,
):
self._assert_tx_init_cond(condition)
async def _confirm_tx(self, tx_hash: bytes) -> None: async def _confirm_tx(self, tx_hash: bytes) -> None:
from .. import layout
# super() omitted intentionally # super() omitted intentionally
await layout.confirm_stake_pool_registration_final( await layout.confirm_stake_pool_registration_final(
self.ctx, self.ctx,
@ -50,7 +59,7 @@ class PoolOwnerSigner(Signer):
or output.inline_datum_size > 0 or output.inline_datum_size > 0
or output.reference_script_size > 0 or output.reference_script_size > 0
): ):
raise wire.ProcessError("Invalid output") raise ProcessError("Invalid output")
def _should_show_output(self, output: messages.CardanoTxOutput) -> bool: def _should_show_output(self, output: messages.CardanoTxOutput) -> bool:
# super() omitted intentionally # super() omitted intentionally
@ -58,16 +67,20 @@ class PoolOwnerSigner(Signer):
return False return False
def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None: def _validate_certificate(self, certificate: messages.CardanoTxCertificate) -> None:
from trezor.enums import CardanoCertificateType
super()._validate_certificate(certificate) super()._validate_certificate(certificate)
if certificate.type != CardanoCertificateType.STAKE_POOL_REGISTRATION: if certificate.type != CardanoCertificateType.STAKE_POOL_REGISTRATION:
raise wire.ProcessError("Invalid certificate") raise ProcessError("Invalid certificate")
def _validate_witness_request( def _validate_witness_request(
self, witness_request: messages.CardanoTxWitnessRequest self, witness_request: messages.CardanoTxWitnessRequest
) -> None: ) -> None:
from ..helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
super()._validate_witness_request(witness_request) super()._validate_witness_request(witness_request)
if not SCHEMA_STAKING_ANY_ACCOUNT.match(witness_request.path): if not SCHEMA_STAKING_ANY_ACCOUNT.match(witness_request.path):
raise wire.ProcessError( raise ProcessError(
"Stakepool registration transaction can only contain staking witnesses" "Stakepool registration transaction can only contain staking witnesses"
) )

View File

@ -1,54 +1,34 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import messages, wire from trezor import messages
from trezor.crypto import hashlib
from trezor.crypto.curve import ed25519
from trezor.enums import ( from trezor.enums import (
CardanoAddressType,
CardanoCertificateType, CardanoCertificateType,
CardanoTxOutputSerializationFormat, CardanoTxOutputSerializationFormat,
CardanoTxWitnessType, CardanoTxWitnessType,
) )
from trezor.messages import CardanoTxItemAck, CardanoTxOutput
from trezor.wire import DataError, ProcessError
from apps.common import cbor, safety_checks from apps.common import safety_checks
from .. import addresses, auxiliary_data, certificates, layout, seed from .. import addresses, certificates, layout, seed
from ..helpers import ( from ..helpers import INPUT_PREV_HASH_SIZE, LOVELACE_MAX_SUPPLY
ADDRESS_KEY_HASH_SIZE, from ..helpers.credential import Credential
INPUT_PREV_HASH_SIZE, from ..helpers.hash_builder_collection import HashBuilderDict, HashBuilderList
LOVELACE_MAX_SUPPLY, from ..helpers.paths import SCHEMA_STAKING
OUTPUT_DATUM_HASH_SIZE, from ..helpers.utils import derive_public_key
SCRIPT_DATA_HASH_SIZE,
)
from ..helpers.account_path_check import AccountPathChecker
from ..helpers.credential import Credential, should_show_credentials
from ..helpers.hash_builder_collection import (
HashBuilderDict,
HashBuilderEmbeddedCBOR,
HashBuilderList,
)
from ..helpers.paths import (
CERTIFICATE_PATH_NAME,
CHANGE_OUTPUT_PATH_NAME,
CHANGE_OUTPUT_STAKING_PATH_NAME,
POOL_OWNER_STAKING_PATH_NAME,
SCHEMA_STAKING,
)
from ..helpers.utils import (
derive_public_key,
get_public_key_hash,
validate_network_info,
validate_stake_credential,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Awaitable, ClassVar from typing import Any, Awaitable, ClassVar
from trezor.wire import Context
from trezor.enums import CardanoAddressType
from apps.common.paths import PathSchema from apps.common.paths import PathSchema
from apps.common import cbor
CardanoTxResponseType = ( from ..helpers.hash_builder_collection import HashBuilderEmbeddedCBOR
messages.CardanoTxItemAck | messages.CardanoTxWitnessResponse
) CardanoTxResponseType = CardanoTxItemAck | messages.CardanoTxWitnessResponse
_MINTING_POLICY_ID_LENGTH = const(28) _MINTING_POLICY_ID_LENGTH = const(28)
_MAX_ASSET_NAME_LENGTH = const(32) _MAX_ASSET_NAME_LENGTH = const(32)
@ -96,10 +76,12 @@ class Signer:
def __init__( def __init__(
self, self,
ctx: wire.Context, ctx: Context,
msg: messages.CardanoSignTxInit, msg: messages.CardanoSignTxInit,
keychain: seed.Keychain, keychain: seed.Keychain,
) -> None: ) -> None:
from ..helpers.account_path_check import AccountPathChecker
self.ctx = ctx self.ctx = ctx
self.msg = msg self.msg = msg
self.keychain = keychain self.keychain = keychain
@ -125,12 +107,14 @@ class Signer:
) )
) )
self.tx_dict: HashBuilderDict[int, Any] = HashBuilderDict( self.tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(
tx_dict_items_count, wire.ProcessError("Invalid tx signing request") tx_dict_items_count, ProcessError("Invalid tx signing request")
) )
self.should_show_details = False self.should_show_details = False
async def sign(self) -> None: async def sign(self) -> None:
from trezor.crypto import hashlib
hash_fn = hashlib.blake2b(outlen=32) hash_fn = hashlib.blake2b(outlen=32)
self.tx_dict.start(hash_fn) self.tx_dict.start(hash_fn)
with self.tx_dict: with self.tx_dict:
@ -150,96 +134,95 @@ class Signer:
async def _processs_tx_init(self) -> None: async def _processs_tx_init(self) -> None:
self._validate_tx_init() self._validate_tx_init()
await self._show_tx_init() await self._show_tx_init()
msg = self.msg # local_cache_attribute
add = self.tx_dict.add # local_cache_attribute
HBL = HashBuilderList # local_cache_global
inputs_list: HashBuilderList[tuple[bytes, int]] = HashBuilderList( inputs_list: HashBuilderList[tuple[bytes, int]] = HBL(msg.inputs_count)
self.msg.inputs_count with add(_TX_BODY_KEY_INPUTS, inputs_list):
)
with self.tx_dict.add(_TX_BODY_KEY_INPUTS, inputs_list):
await self._process_inputs(inputs_list) await self._process_inputs(inputs_list)
outputs_list: HashBuilderList = HashBuilderList(self.msg.outputs_count) outputs_list: HashBuilderList = HBL(msg.outputs_count)
with self.tx_dict.add(_TX_BODY_KEY_OUTPUTS, outputs_list): with add(_TX_BODY_KEY_OUTPUTS, outputs_list):
await self._process_outputs(outputs_list) await self._process_outputs(outputs_list)
self.tx_dict.add(_TX_BODY_KEY_FEE, self.msg.fee) add(_TX_BODY_KEY_FEE, msg.fee)
if self.msg.ttl is not None: if msg.ttl is not None:
self.tx_dict.add(_TX_BODY_KEY_TTL, self.msg.ttl) add(_TX_BODY_KEY_TTL, msg.ttl)
if self.msg.certificates_count > 0: if msg.certificates_count > 0:
certificates_list: HashBuilderList = HashBuilderList( certificates_list: HashBuilderList = HBL(msg.certificates_count)
self.msg.certificates_count with add(_TX_BODY_KEY_CERTIFICATES, certificates_list):
)
with self.tx_dict.add(_TX_BODY_KEY_CERTIFICATES, certificates_list):
await self._process_certificates(certificates_list) await self._process_certificates(certificates_list)
if self.msg.withdrawals_count > 0: if msg.withdrawals_count > 0:
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict( withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
self.msg.withdrawals_count, wire.ProcessError("Invalid withdrawal") msg.withdrawals_count, ProcessError("Invalid withdrawal")
) )
with self.tx_dict.add(_TX_BODY_KEY_WITHDRAWALS, withdrawals_dict): with add(_TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
await self._process_withdrawals(withdrawals_dict) await self._process_withdrawals(withdrawals_dict)
if self.msg.has_auxiliary_data: if msg.has_auxiliary_data:
await self._process_auxiliary_data() await self._process_auxiliary_data()
if self.msg.validity_interval_start is not None: if msg.validity_interval_start is not None:
self.tx_dict.add( add(_TX_BODY_KEY_VALIDITY_INTERVAL_START, msg.validity_interval_start)
_TX_BODY_KEY_VALIDITY_INTERVAL_START, self.msg.validity_interval_start
)
if self.msg.minting_asset_groups_count > 0: if msg.minting_asset_groups_count > 0:
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict( minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
self.msg.minting_asset_groups_count, msg.minting_asset_groups_count,
wire.ProcessError("Invalid mint token bundle"), ProcessError("Invalid mint token bundle"),
) )
with self.tx_dict.add(_TX_BODY_KEY_MINT, minting_dict): with add(_TX_BODY_KEY_MINT, minting_dict):
await self._process_minting(minting_dict) await self._process_minting(minting_dict)
if self.msg.script_data_hash is not None: if msg.script_data_hash is not None:
await self._process_script_data_hash() await self._process_script_data_hash()
if self.msg.collateral_inputs_count > 0: if msg.collateral_inputs_count > 0:
collateral_inputs_list: HashBuilderList[ collateral_inputs_list: HashBuilderList[tuple[bytes, int]] = HBL(
tuple[bytes, int] msg.collateral_inputs_count
] = HashBuilderList(self.msg.collateral_inputs_count) )
with self.tx_dict.add( with add(_TX_BODY_KEY_COLLATERAL_INPUTS, collateral_inputs_list):
_TX_BODY_KEY_COLLATERAL_INPUTS, collateral_inputs_list
):
await self._process_collateral_inputs(collateral_inputs_list) await self._process_collateral_inputs(collateral_inputs_list)
if self.msg.required_signers_count > 0: if msg.required_signers_count > 0:
required_signers_list: HashBuilderList[bytes] = HashBuilderList( required_signers_list: HashBuilderList[bytes] = HBL(
self.msg.required_signers_count msg.required_signers_count
) )
with self.tx_dict.add(_TX_BODY_KEY_REQUIRED_SIGNERS, required_signers_list): with add(_TX_BODY_KEY_REQUIRED_SIGNERS, required_signers_list):
await self._process_required_signers(required_signers_list) await self._process_required_signers(required_signers_list)
if self.msg.include_network_id: if msg.include_network_id:
self.tx_dict.add(_TX_BODY_KEY_NETWORK_ID, self.msg.network_id) add(_TX_BODY_KEY_NETWORK_ID, msg.network_id)
if self.msg.has_collateral_return: if msg.has_collateral_return:
await self._process_collateral_return() await self._process_collateral_return()
if self.msg.total_collateral is not None: if msg.total_collateral is not None:
self.tx_dict.add(_TX_BODY_KEY_TOTAL_COLLATERAL, self.msg.total_collateral) add(_TX_BODY_KEY_TOTAL_COLLATERAL, msg.total_collateral)
if self.msg.reference_inputs_count > 0: if msg.reference_inputs_count > 0:
reference_inputs_list: HashBuilderList[tuple[bytes, int]] = HashBuilderList( reference_inputs_list: HashBuilderList[tuple[bytes, int]] = HBL(
self.msg.reference_inputs_count msg.reference_inputs_count
) )
with self.tx_dict.add(_TX_BODY_KEY_REFERENCE_INPUTS, reference_inputs_list): with add(_TX_BODY_KEY_REFERENCE_INPUTS, reference_inputs_list):
await self._process_reference_inputs(reference_inputs_list) await self._process_reference_inputs(reference_inputs_list)
def _validate_tx_init(self) -> None: def _validate_tx_init(self) -> None:
if self.msg.fee > LOVELACE_MAX_SUPPLY: from ..helpers.utils import validate_network_info
raise wire.ProcessError("Fee is out of range!")
msg = self.msg # local_cache_attribute
if msg.fee > LOVELACE_MAX_SUPPLY:
raise ProcessError("Fee is out of range!")
if ( if (
self.msg.total_collateral is not None msg.total_collateral is not None
and self.msg.total_collateral > LOVELACE_MAX_SUPPLY and msg.total_collateral > LOVELACE_MAX_SUPPLY
): ):
raise wire.ProcessError("Total collateral is out of range!") raise ProcessError("Total collateral is out of range!")
validate_network_info(self.msg.network_id, self.msg.protocol_magic) validate_network_info(msg.network_id, msg.protocol_magic)
async def _show_tx_init(self) -> None: async def _show_tx_init(self) -> None:
self.should_show_details = await layout.show_tx_init( self.should_show_details = await layout.show_tx_init(
@ -260,7 +243,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(self.msg.inputs_count): for _ in range(self.msg.inputs_count):
input: messages.CardanoTxInput = await self.ctx.call( input: messages.CardanoTxInput = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxInput CardanoTxItemAck(), messages.CardanoTxInput
) )
self._validate_input(input) self._validate_input(input)
await self._show_input(input) await self._show_input(input)
@ -268,7 +251,7 @@ class Signer:
def _validate_input(self, input: messages.CardanoTxInput) -> None: def _validate_input(self, input: messages.CardanoTxInput) -> None:
if len(input.prev_hash) != INPUT_PREV_HASH_SIZE: if len(input.prev_hash) != INPUT_PREV_HASH_SIZE:
raise wire.ProcessError("Invalid input") raise ProcessError("Invalid input")
async def _show_input(self, input: messages.CardanoTxInput) -> None: async def _show_input(self, input: messages.CardanoTxInput) -> None:
# We never show the inputs, except for Plutus txs. # We never show the inputs, except for Plutus txs.
@ -279,18 +262,18 @@ class Signer:
async def _process_outputs(self, outputs_list: HashBuilderList) -> None: async def _process_outputs(self, outputs_list: HashBuilderList) -> None:
total_amount = 0 total_amount = 0
for _ in range(self.msg.outputs_count): for _ in range(self.msg.outputs_count):
output: messages.CardanoTxOutput = await self.ctx.call( output: CardanoTxOutput = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxOutput CardanoTxItemAck(), CardanoTxOutput
) )
await self._process_output(outputs_list, output) await self._process_output(outputs_list, output)
total_amount += output.amount total_amount += output.amount
if total_amount > LOVELACE_MAX_SUPPLY: if total_amount > LOVELACE_MAX_SUPPLY:
raise wire.ProcessError("Total transaction amount is out of range!") raise ProcessError("Total transaction amount is out of range!")
async def _process_output( async def _process_output(
self, outputs_list: HashBuilderList, output: messages.CardanoTxOutput self, outputs_list: HashBuilderList, output: CardanoTxOutput
) -> None: ) -> None:
self._validate_output(output) self._validate_output(output)
should_show = self._should_show_output(output) should_show = self._should_show_output(output)
@ -310,49 +293,53 @@ class Signer:
await self._process_legacy_output(output_list, output, should_show) await self._process_legacy_output(output_list, output, should_show)
elif output.format == CardanoTxOutputSerializationFormat.MAP_BABBAGE: elif output.format == CardanoTxOutputSerializationFormat.MAP_BABBAGE:
output_dict: HashBuilderDict[int, Any] = HashBuilderDict( output_dict: HashBuilderDict[int, Any] = HashBuilderDict(
output_items_count, wire.ProcessError("Invalid output") output_items_count, ProcessError("Invalid output")
) )
with outputs_list.append(output_dict): with outputs_list.append(output_dict):
await self._process_babbage_output(output_dict, output, should_show) await self._process_babbage_output(output_dict, output, should_show)
else: else:
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
def _validate_output(self, output: messages.CardanoTxOutput) -> None: def _validate_output(self, output: CardanoTxOutput) -> None:
if output.address_parameters is not None and output.address is not None: from ..helpers import OUTPUT_DATUM_HASH_SIZE
raise wire.ProcessError("Invalid output")
if output.address_parameters is not None: address_parameters = output.address_parameters # local_cache_attribute
addresses.validate_output_address_parameters(output.address_parameters)
self._fail_if_strict_and_unusual(output.address_parameters) if address_parameters is not None and output.address is not None:
raise ProcessError("Invalid output")
if address_parameters is not None:
addresses.validate_output_address_parameters(address_parameters)
self._fail_if_strict_and_unusual(address_parameters)
elif output.address is not None: elif output.address is not None:
addresses.validate_output_address( addresses.validate_output_address(
output.address, self.msg.protocol_magic, self.msg.network_id output.address, self.msg.protocol_magic, self.msg.network_id
) )
else: else:
raise wire.ProcessError("Invalid output") raise ProcessError("Invalid output")
# datum hash # datum hash
if output.datum_hash is not None: if output.datum_hash is not None:
if len(output.datum_hash) != OUTPUT_DATUM_HASH_SIZE: if len(output.datum_hash) != OUTPUT_DATUM_HASH_SIZE:
raise wire.ProcessError("Invalid output datum hash") raise ProcessError("Invalid output datum hash")
# inline datum # inline datum
if output.inline_datum_size > 0: if output.inline_datum_size > 0:
if output.format != CardanoTxOutputSerializationFormat.MAP_BABBAGE: if output.format != CardanoTxOutputSerializationFormat.MAP_BABBAGE:
raise wire.ProcessError("Invalid output") raise ProcessError("Invalid output")
# datum hash and inline datum are mutually exclusive # datum hash and inline datum are mutually exclusive
if output.datum_hash is not None and output.inline_datum_size > 0: if output.datum_hash is not None and output.inline_datum_size > 0:
raise wire.ProcessError("Invalid output") raise ProcessError("Invalid output")
# reference script # reference script
if output.reference_script_size > 0: if output.reference_script_size > 0:
if output.format != CardanoTxOutputSerializationFormat.MAP_BABBAGE: if output.format != CardanoTxOutputSerializationFormat.MAP_BABBAGE:
raise wire.ProcessError("Invalid output") raise ProcessError("Invalid output")
self.account_path_checker.add_output(output) self.account_path_checker.add_output(output)
async def _show_output_init(self, output: messages.CardanoTxOutput) -> None: async def _show_output_init(self, output: CardanoTxOutput) -> None:
address_type = self._get_output_address_type(output) address_type = self._get_output_address_type(output)
if ( if (
output.datum_hash is None output.datum_hash is None
@ -393,7 +380,7 @@ class Signer:
Credential.stake_credential(address_parameters), Credential.stake_credential(address_parameters),
) )
def _should_show_output(self, output: messages.CardanoTxOutput) -> bool: def _should_show_output(self, output: CardanoTxOutput) -> bool:
""" """
Determines whether the output should be shown. Extracted from _show_output Determines whether the output should be shown. Extracted from _show_output
because of readability. because of readability.
@ -419,12 +406,14 @@ class Signer:
return True return True
def _is_change_output(self, output: messages.CardanoTxOutput) -> bool: def _is_change_output(self, output: CardanoTxOutput) -> bool:
"""Used only to determine what message to show to the user when confirming sending.""" """Used only to determine what message to show to the user when confirming sending."""
return output.address_parameters is not None return output.address_parameters is not None
def _is_simple_change_output(self, output: messages.CardanoTxOutput) -> bool: def _is_simple_change_output(self, output: CardanoTxOutput) -> bool:
"""Used to determine whether an output is a change output with ordinary credentials.""" """Used to determine whether an output is a change output with ordinary credentials."""
from ..helpers.credential import should_show_credentials
return output.address_parameters is not None and not should_show_credentials( return output.address_parameters is not None and not should_show_credentials(
output.address_parameters output.address_parameters
) )
@ -432,7 +421,7 @@ class Signer:
async def _process_legacy_output( async def _process_legacy_output(
self, self,
output_list: HashBuilderList, output_list: HashBuilderList,
output: messages.CardanoTxOutput, output: CardanoTxOutput,
should_show: bool, should_show: bool,
) -> None: ) -> None:
address = self._get_output_address(output) address = self._get_output_address(output)
@ -457,23 +446,27 @@ class Signer:
async def _process_babbage_output( async def _process_babbage_output(
self, self,
output_dict: HashBuilderDict[int, Any], output_dict: HashBuilderDict[int, Any],
output: messages.CardanoTxOutput, output: CardanoTxOutput,
should_show: bool, should_show: bool,
) -> None: ) -> None:
""" """
This output format corresponds to the post-Alonzo format in CDDL. This output format corresponds to the post-Alonzo format in CDDL.
Note that it is to be used also for outputs with no Plutus elements. Note that it is to be used also for outputs with no Plutus elements.
""" """
from ..helpers.hash_builder_collection import HashBuilderEmbeddedCBOR
add = output_dict.add # local_cache_attribute
address = self._get_output_address(output) address = self._get_output_address(output)
output_dict.add(_BABBAGE_OUTPUT_KEY_ADDRESS, address) add(_BABBAGE_OUTPUT_KEY_ADDRESS, address)
if output.asset_groups_count == 0: if output.asset_groups_count == 0:
# Only amount is added to the dict. # Only amount is added to the dict.
output_dict.add(_BABBAGE_OUTPUT_KEY_AMOUNT, output.amount) add(_BABBAGE_OUTPUT_KEY_AMOUNT, output.amount)
else: else:
# [amount, asset_groups] is added to the dict. # [amount, asset_groups] is added to the dict.
output_value_list: HashBuilderList = HashBuilderList(2) output_value_list: HashBuilderList = HashBuilderList(2)
with output_dict.add(_BABBAGE_OUTPUT_KEY_AMOUNT, output_value_list): with add(_BABBAGE_OUTPUT_KEY_AMOUNT, output_value_list):
await self._process_output_value(output_value_list, output, should_show) await self._process_output_value(output_value_list, output, should_show)
if output.datum_hash is not None: if output.datum_hash is not None:
@ -481,13 +474,13 @@ class Signer:
await self._show_if_showing_details( await self._show_if_showing_details(
layout.confirm_datum_hash(self.ctx, output.datum_hash) layout.confirm_datum_hash(self.ctx, output.datum_hash)
) )
output_dict.add( add(
_BABBAGE_OUTPUT_KEY_DATUM_OPTION, _BABBAGE_OUTPUT_KEY_DATUM_OPTION,
(_DATUM_OPTION_KEY_HASH, output.datum_hash), (_DATUM_OPTION_KEY_HASH, output.datum_hash),
) )
elif output.inline_datum_size > 0: elif output.inline_datum_size > 0:
inline_datum_list: HashBuilderList = HashBuilderList(2) inline_datum_list: HashBuilderList = HashBuilderList(2)
with output_dict.add(_BABBAGE_OUTPUT_KEY_DATUM_OPTION, inline_datum_list): with add(_BABBAGE_OUTPUT_KEY_DATUM_OPTION, inline_datum_list):
inline_datum_list.append(_DATUM_OPTION_KEY_INLINE) inline_datum_list.append(_DATUM_OPTION_KEY_INLINE)
inline_datum_cbor: HashBuilderEmbeddedCBOR = HashBuilderEmbeddedCBOR( inline_datum_cbor: HashBuilderEmbeddedCBOR = HashBuilderEmbeddedCBOR(
output.inline_datum_size output.inline_datum_size
@ -501,9 +494,7 @@ class Signer:
reference_script_cbor: HashBuilderEmbeddedCBOR = HashBuilderEmbeddedCBOR( reference_script_cbor: HashBuilderEmbeddedCBOR = HashBuilderEmbeddedCBOR(
output.reference_script_size output.reference_script_size
) )
with output_dict.add( with add(_BABBAGE_OUTPUT_KEY_REFERENCE_SCRIPT, reference_script_cbor):
_BABBAGE_OUTPUT_KEY_REFERENCE_SCRIPT, reference_script_cbor
):
await self._process_reference_script( await self._process_reference_script(
reference_script_cbor, output.reference_script_size, should_show reference_script_cbor, output.reference_script_size, should_show
) )
@ -511,7 +502,7 @@ class Signer:
async def _process_output_value( async def _process_output_value(
self, self,
output_value_list: HashBuilderList, output_value_list: HashBuilderList,
output: messages.CardanoTxOutput, output: CardanoTxOutput,
should_show_tokens: bool, should_show_tokens: bool,
) -> None: ) -> None:
"""Should be used only when the output contains tokens.""" """Should be used only when the output contains tokens."""
@ -523,7 +514,7 @@ class Signer:
bytes, HashBuilderDict[bytes, int] bytes, HashBuilderDict[bytes, int]
] = HashBuilderDict( ] = HashBuilderDict(
output.asset_groups_count, output.asset_groups_count,
wire.ProcessError("Invalid token bundle in output"), ProcessError("Invalid token bundle in output"),
) )
with output_value_list.append(asset_groups_dict): with output_value_list.append(asset_groups_dict):
await self._process_asset_groups( await self._process_asset_groups(
@ -542,13 +533,13 @@ class Signer:
) -> None: ) -> None:
for _ in range(asset_groups_count): for _ in range(asset_groups_count):
asset_group: messages.CardanoAssetGroup = await self.ctx.call( asset_group: messages.CardanoAssetGroup = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoAssetGroup CardanoTxItemAck(), messages.CardanoAssetGroup
) )
self._validate_asset_group(asset_group) self._validate_asset_group(asset_group)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict( tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, asset_group.tokens_count,
wire.ProcessError("Invalid token bundle in output"), ProcessError("Invalid token bundle in output"),
) )
with asset_groups_dict.add(asset_group.policy_id, tokens): with asset_groups_dict.add(asset_group.policy_id, tokens):
await self._process_tokens( await self._process_tokens(
@ -562,9 +553,9 @@ class Signer:
self, asset_group: messages.CardanoAssetGroup, is_mint: bool = False self, asset_group: messages.CardanoAssetGroup, is_mint: bool = False
) -> None: ) -> None:
INVALID_TOKEN_BUNDLE = ( INVALID_TOKEN_BUNDLE = (
wire.ProcessError("Invalid mint token bundle") ProcessError("Invalid mint token bundle")
if is_mint if is_mint
else wire.ProcessError("Invalid token bundle in output") else ProcessError("Invalid token bundle in output")
) )
if len(asset_group.policy_id) != _MINTING_POLICY_ID_LENGTH: if len(asset_group.policy_id) != _MINTING_POLICY_ID_LENGTH:
@ -583,7 +574,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(tokens_count): for _ in range(tokens_count):
token: messages.CardanoToken = await self.ctx.call( token: messages.CardanoToken = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoToken CardanoTxItemAck(), messages.CardanoToken
) )
self._validate_token(token) self._validate_token(token)
if should_show_tokens: if should_show_tokens:
@ -596,9 +587,9 @@ class Signer:
self, token: messages.CardanoToken, is_mint: bool = False self, token: messages.CardanoToken, is_mint: bool = False
) -> None: ) -> None:
INVALID_TOKEN_BUNDLE = ( INVALID_TOKEN_BUNDLE = (
wire.ProcessError("Invalid mint token bundle") ProcessError("Invalid mint token bundle")
if is_mint if is_mint
else wire.ProcessError("Invalid token bundle in output") else ProcessError("Invalid token bundle in output")
) )
if is_mint: if is_mint:
@ -624,13 +615,13 @@ class Signer:
chunks_count = self._get_chunks_count(inline_datum_size) chunks_count = self._get_chunks_count(inline_datum_size)
for chunk_number in range(chunks_count): for chunk_number in range(chunks_count):
chunk: messages.CardanoTxInlineDatumChunk = await self.ctx.call( chunk: messages.CardanoTxInlineDatumChunk = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxInlineDatumChunk CardanoTxItemAck(), messages.CardanoTxInlineDatumChunk
) )
self._validate_chunk( self._validate_chunk(
chunk.data, chunk.data,
chunk_number, chunk_number,
chunks_count, chunks_count,
wire.ProcessError("Invalid inline datum chunk"), ProcessError("Invalid inline datum chunk"),
) )
if chunk_number == 0 and should_show: if chunk_number == 0 and should_show:
await self._show_if_showing_details( await self._show_if_showing_details(
@ -651,13 +642,13 @@ class Signer:
chunks_count = self._get_chunks_count(reference_script_size) chunks_count = self._get_chunks_count(reference_script_size)
for chunk_number in range(chunks_count): for chunk_number in range(chunks_count):
chunk: messages.CardanoTxReferenceScriptChunk = await self.ctx.call( chunk: messages.CardanoTxReferenceScriptChunk = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxReferenceScriptChunk CardanoTxItemAck(), messages.CardanoTxReferenceScriptChunk
) )
self._validate_chunk( self._validate_chunk(
chunk.data, chunk.data,
chunk_number, chunk_number,
chunks_count, chunks_count,
wire.ProcessError("Invalid reference script chunk"), ProcessError("Invalid reference script chunk"),
) )
if chunk_number == 0 and should_show: if chunk_number == 0 and should_show:
await self._show_if_showing_details( await self._show_if_showing_details(
@ -672,7 +663,7 @@ class Signer:
async def _process_certificates(self, certificates_list: HashBuilderList) -> None: async def _process_certificates(self, certificates_list: HashBuilderList) -> None:
for _ in range(self.msg.certificates_count): for _ in range(self.msg.certificates_count):
certificate: messages.CardanoTxCertificate = await self.ctx.call( certificate: messages.CardanoTxCertificate = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxCertificate CardanoTxItemAck(), messages.CardanoTxCertificate
) )
self._validate_certificate(certificate) self._validate_certificate(certificate)
await self._show_certificate(certificate) await self._show_certificate(certificate)
@ -725,6 +716,8 @@ class Signer:
async def _show_certificate( async def _show_certificate(
self, certificate: messages.CardanoTxCertificate self, certificate: messages.CardanoTxCertificate
) -> None: ) -> None:
from ..helpers.paths import CERTIFICATE_PATH_NAME
if certificate.path: if certificate.path:
await self._fail_or_warn_if_invalid_path( await self._fail_or_warn_if_invalid_path(
SCHEMA_STAKING, certificate.path, CERTIFICATE_PATH_NAME SCHEMA_STAKING, certificate.path, CERTIFICATE_PATH_NAME
@ -749,7 +742,7 @@ class Signer:
owners_as_path_count = 0 owners_as_path_count = 0
for _ in range(owners_count): for _ in range(owners_count):
owner: messages.CardanoPoolOwner = await self.ctx.call( owner: messages.CardanoPoolOwner = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoPoolOwner CardanoTxItemAck(), messages.CardanoPoolOwner
) )
certificates.validate_pool_owner(owner, self.account_path_checker) certificates.validate_pool_owner(owner, self.account_path_checker)
await self._show_pool_owner(owner) await self._show_pool_owner(owner)
@ -763,6 +756,8 @@ class Signer:
certificates.assert_cond(owners_as_path_count == 1) certificates.assert_cond(owners_as_path_count == 1)
async def _show_pool_owner(self, owner: messages.CardanoPoolOwner) -> None: async def _show_pool_owner(self, owner: messages.CardanoPoolOwner) -> None:
from ..helpers.paths import POOL_OWNER_STAKING_PATH_NAME
if owner.staking_key_path: if owner.staking_key_path:
await self._fail_or_warn_if_invalid_path( await self._fail_or_warn_if_invalid_path(
SCHEMA_STAKING, owner.staking_key_path, POOL_OWNER_STAKING_PATH_NAME SCHEMA_STAKING, owner.staking_key_path, POOL_OWNER_STAKING_PATH_NAME
@ -781,7 +776,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(relays_count): for _ in range(relays_count):
relay: messages.CardanoPoolRelayParameters = await self.ctx.call( relay: messages.CardanoPoolRelayParameters = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoPoolRelayParameters CardanoTxItemAck(), messages.CardanoPoolRelayParameters
) )
certificates.validate_pool_relay(relay) certificates.validate_pool_relay(relay)
relays_list.append(certificates.cborize_pool_relay(relay)) relays_list.append(certificates.cborize_pool_relay(relay))
@ -793,7 +788,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(self.msg.withdrawals_count): for _ in range(self.msg.withdrawals_count):
withdrawal: messages.CardanoTxWithdrawal = await self.ctx.call( withdrawal: messages.CardanoTxWithdrawal = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxWithdrawal CardanoTxItemAck(), messages.CardanoTxWithdrawal
) )
self._validate_withdrawal(withdrawal) self._validate_withdrawal(withdrawal)
address_bytes = self._derive_withdrawal_address_bytes(withdrawal) address_bytes = self._derive_withdrawal_address_bytes(withdrawal)
@ -805,23 +800,29 @@ class Signer:
withdrawals_dict.add(address_bytes, withdrawal.amount) withdrawals_dict.add(address_bytes, withdrawal.amount)
def _validate_withdrawal(self, withdrawal: messages.CardanoTxWithdrawal) -> None: def _validate_withdrawal(self, withdrawal: messages.CardanoTxWithdrawal) -> None:
from ..helpers.utils import validate_stake_credential
validate_stake_credential( validate_stake_credential(
withdrawal.path, withdrawal.path,
withdrawal.script_hash, withdrawal.script_hash,
withdrawal.key_hash, withdrawal.key_hash,
wire.ProcessError("Invalid withdrawal"), ProcessError("Invalid withdrawal"),
) )
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY: if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
raise wire.ProcessError("Invalid withdrawal") raise ProcessError("Invalid withdrawal")
self.account_path_checker.add_withdrawal(withdrawal) self.account_path_checker.add_withdrawal(withdrawal)
# auxiliary data # auxiliary data
async def _process_auxiliary_data(self) -> None: async def _process_auxiliary_data(self) -> None:
from .. import auxiliary_data
msg = self.msg # local_cache_attribute
data: messages.CardanoTxAuxiliaryData = await self.ctx.call( data: messages.CardanoTxAuxiliaryData = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxAuxiliaryData CardanoTxItemAck(), messages.CardanoTxAuxiliaryData
) )
auxiliary_data.validate(data) auxiliary_data.validate(data)
@ -829,15 +830,15 @@ class Signer:
auxiliary_data_hash, auxiliary_data_hash,
auxiliary_data_supplement, auxiliary_data_supplement,
) = auxiliary_data.get_hash_and_supplement( ) = auxiliary_data.get_hash_and_supplement(
self.keychain, data, self.msg.protocol_magic, self.msg.network_id self.keychain, data, msg.protocol_magic, msg.network_id
) )
await auxiliary_data.show( await auxiliary_data.show(
self.ctx, self.ctx,
self.keychain, self.keychain,
auxiliary_data_hash, auxiliary_data_hash,
data.governance_registration_parameters, data.governance_registration_parameters,
self.msg.protocol_magic, msg.protocol_magic,
self.msg.network_id, msg.network_id,
self.should_show_details, self.should_show_details,
) )
self.tx_dict.add(_TX_BODY_KEY_AUXILIARY_DATA, auxiliary_data_hash) self.tx_dict.add(_TX_BODY_KEY_AUXILIARY_DATA, auxiliary_data_hash)
@ -850,19 +851,19 @@ class Signer:
self, minting_dict: HashBuilderDict[bytes, HashBuilderDict] self, minting_dict: HashBuilderDict[bytes, HashBuilderDict]
) -> None: ) -> None:
token_minting: messages.CardanoTxMint = await self.ctx.call( token_minting: messages.CardanoTxMint = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxMint CardanoTxItemAck(), messages.CardanoTxMint
) )
await layout.warn_tx_contains_mint(self.ctx) await layout.warn_tx_contains_mint(self.ctx)
for _ in range(token_minting.asset_groups_count): for _ in range(token_minting.asset_groups_count):
asset_group: messages.CardanoAssetGroup = await self.ctx.call( asset_group: messages.CardanoAssetGroup = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoAssetGroup CardanoTxItemAck(), messages.CardanoAssetGroup
) )
self._validate_asset_group(asset_group, is_mint=True) self._validate_asset_group(asset_group, is_mint=True)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict( tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, wire.ProcessError("Invalid mint token bundle") asset_group.tokens_count, ProcessError("Invalid mint token bundle")
) )
with minting_dict.add(asset_group.policy_id, tokens): with minting_dict.add(asset_group.policy_id, tokens):
await self._process_minting_tokens( await self._process_minting_tokens(
@ -881,7 +882,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(tokens_count): for _ in range(tokens_count):
token: messages.CardanoToken = await self.ctx.call( token: messages.CardanoToken = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoToken CardanoTxItemAck(), messages.CardanoToken
) )
self._validate_token(token, is_mint=True) self._validate_token(token, is_mint=True)
await layout.confirm_token_minting(self.ctx, policy_id, token) await layout.confirm_token_minting(self.ctx, policy_id, token)
@ -900,9 +901,11 @@ class Signer:
self.tx_dict.add(_TX_BODY_KEY_SCRIPT_DATA_HASH, self.msg.script_data_hash) self.tx_dict.add(_TX_BODY_KEY_SCRIPT_DATA_HASH, self.msg.script_data_hash)
def _validate_script_data_hash(self) -> None: def _validate_script_data_hash(self) -> None:
from ..helpers import SCRIPT_DATA_HASH_SIZE
assert self.msg.script_data_hash is not None assert self.msg.script_data_hash is not None
if len(self.msg.script_data_hash) != SCRIPT_DATA_HASH_SIZE: if len(self.msg.script_data_hash) != SCRIPT_DATA_HASH_SIZE:
raise wire.ProcessError("Invalid script data hash") raise ProcessError("Invalid script data hash")
# collateral inputs # collateral inputs
@ -911,7 +914,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(self.msg.collateral_inputs_count): for _ in range(self.msg.collateral_inputs_count):
collateral_input: messages.CardanoTxCollateralInput = await self.ctx.call( collateral_input: messages.CardanoTxCollateralInput = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxCollateralInput CardanoTxItemAck(), messages.CardanoTxCollateralInput
) )
self._validate_collateral_input(collateral_input) self._validate_collateral_input(collateral_input)
await self._show_collateral_input(collateral_input) await self._show_collateral_input(collateral_input)
@ -923,7 +926,7 @@ class Signer:
self, collateral_input: messages.CardanoTxCollateralInput self, collateral_input: messages.CardanoTxCollateralInput
) -> None: ) -> None:
if len(collateral_input.prev_hash) != INPUT_PREV_HASH_SIZE: if len(collateral_input.prev_hash) != INPUT_PREV_HASH_SIZE:
raise wire.ProcessError("Invalid collateral input") raise ProcessError("Invalid collateral input")
async def _show_collateral_input( async def _show_collateral_input(
self, collateral_input: messages.CardanoTxCollateralInput self, collateral_input: messages.CardanoTxCollateralInput
@ -938,9 +941,11 @@ class Signer:
async def _process_required_signers( async def _process_required_signers(
self, required_signers_list: HashBuilderList[bytes] self, required_signers_list: HashBuilderList[bytes]
) -> None: ) -> None:
from ..helpers.utils import get_public_key_hash
for _ in range(self.msg.required_signers_count): for _ in range(self.msg.required_signers_count):
required_signer: messages.CardanoTxRequiredSigner = await self.ctx.call( required_signer: messages.CardanoTxRequiredSigner = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxRequiredSigner CardanoTxItemAck(), messages.CardanoTxRequiredSigner
) )
self._validate_required_signer(required_signer) self._validate_required_signer(required_signer)
await self._show_if_showing_details( await self._show_if_showing_details(
@ -955,19 +960,23 @@ class Signer:
def _validate_required_signer( def _validate_required_signer(
self, required_signer: messages.CardanoTxRequiredSigner self, required_signer: messages.CardanoTxRequiredSigner
) -> None: ) -> None:
INVALID_REQUIRED_SIGNER = wire.ProcessError("Invalid required signer") from ..helpers import ADDRESS_KEY_HASH_SIZE
if required_signer.key_hash and required_signer.key_path: key_path = required_signer.key_path # local_cache_attribute
INVALID_REQUIRED_SIGNER = ProcessError("Invalid required signer")
if required_signer.key_hash and key_path:
raise INVALID_REQUIRED_SIGNER raise INVALID_REQUIRED_SIGNER
if required_signer.key_hash: if required_signer.key_hash:
if len(required_signer.key_hash) != ADDRESS_KEY_HASH_SIZE: if len(required_signer.key_hash) != ADDRESS_KEY_HASH_SIZE:
raise INVALID_REQUIRED_SIGNER raise INVALID_REQUIRED_SIGNER
elif required_signer.key_path: elif key_path:
if not ( if not (
seed.is_shelley_path(required_signer.key_path) seed.is_shelley_path(key_path)
or seed.is_multisig_path(required_signer.key_path) or seed.is_multisig_path(key_path)
or seed.is_minting_path(required_signer.key_path) or seed.is_minting_path(key_path)
): ):
raise INVALID_REQUIRED_SIGNER raise INVALID_REQUIRED_SIGNER
else: else:
@ -976,8 +985,8 @@ class Signer:
# collateral return # collateral return
async def _process_collateral_return(self) -> None: async def _process_collateral_return(self) -> None:
output: messages.CardanoTxOutput = await self.ctx.call( output: CardanoTxOutput = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxOutput CardanoTxItemAck(), CardanoTxOutput
) )
self._validate_collateral_return(output) self._validate_collateral_return(output)
should_show_init = self._should_show_collateral_return_init(output) should_show_init = self._should_show_collateral_return_init(output)
@ -995,7 +1004,7 @@ class Signer:
) )
elif output.format == CardanoTxOutputSerializationFormat.MAP_BABBAGE: elif output.format == CardanoTxOutputSerializationFormat.MAP_BABBAGE:
output_dict: HashBuilderDict[int, Any] = HashBuilderDict( output_dict: HashBuilderDict[int, Any] = HashBuilderDict(
output_items_count, wire.ProcessError("Invalid collateral return") output_items_count, ProcessError("Invalid collateral return")
) )
with self.tx_dict.add(_TX_BODY_KEY_COLLATERAL_RETURN, output_dict): with self.tx_dict.add(_TX_BODY_KEY_COLLATERAL_RETURN, output_dict):
await self._process_babbage_output( await self._process_babbage_output(
@ -1004,23 +1013,21 @@ class Signer:
else: else:
raise RuntimeError # should be unreachable raise RuntimeError # should be unreachable
def _validate_collateral_return(self, output: messages.CardanoTxOutput) -> None: def _validate_collateral_return(self, output: CardanoTxOutput) -> None:
self._validate_output(output) self._validate_output(output)
address_type = self._get_output_address_type(output) address_type = self._get_output_address_type(output)
if address_type not in addresses.ADDRESS_TYPES_PAYMENT_KEY: if address_type not in addresses.ADDRESS_TYPES_PAYMENT_KEY:
raise wire.ProcessError("Invalid collateral return") raise ProcessError("Invalid collateral return")
if ( if (
output.datum_hash is not None output.datum_hash is not None
or output.inline_datum_size > 0 or output.inline_datum_size > 0
or output.reference_script_size > 0 or output.reference_script_size > 0
): ):
raise wire.ProcessError("Invalid collateral return") raise ProcessError("Invalid collateral return")
async def _show_collateral_return_init( async def _show_collateral_return_init(self, output: CardanoTxOutput) -> None:
self, output: messages.CardanoTxOutput
) -> None:
# We don't display missing datum warning since datums are forbidden. # We don't display missing datum warning since datums are forbidden.
if output.asset_groups_count > 0: if output.asset_groups_count > 0:
@ -1050,9 +1057,7 @@ class Signer:
self.msg.network_id, self.msg.network_id,
) )
def _should_show_collateral_return_init( def _should_show_collateral_return_init(self, output: CardanoTxOutput) -> bool:
self, output: messages.CardanoTxOutput
) -> bool:
if self.msg.total_collateral is None: if self.msg.total_collateral is None:
return True return True
@ -1061,9 +1066,7 @@ class Signer:
return True return True
def _should_show_collateral_return_tokens( def _should_show_collateral_return_tokens(self, output: CardanoTxOutput) -> bool:
self, output: messages.CardanoTxOutput
) -> bool:
if self._is_simple_change_output(output): if self._is_simple_change_output(output):
return False return False
@ -1076,7 +1079,7 @@ class Signer:
) -> None: ) -> None:
for _ in range(self.msg.reference_inputs_count): for _ in range(self.msg.reference_inputs_count):
reference_input: messages.CardanoTxReferenceInput = await self.ctx.call( reference_input: messages.CardanoTxReferenceInput = await self.ctx.call(
messages.CardanoTxItemAck(), messages.CardanoTxReferenceInput CardanoTxItemAck(), messages.CardanoTxReferenceInput
) )
self._validate_reference_input(reference_input) self._validate_reference_input(reference_input)
await self._show_if_showing_details( await self._show_if_showing_details(
@ -1090,12 +1093,12 @@ class Signer:
self, reference_input: messages.CardanoTxReferenceInput self, reference_input: messages.CardanoTxReferenceInput
) -> None: ) -> None:
if len(reference_input.prev_hash) != INPUT_PREV_HASH_SIZE: if len(reference_input.prev_hash) != INPUT_PREV_HASH_SIZE:
raise wire.ProcessError("Invalid reference input") raise ProcessError("Invalid reference input")
# witness requests # witness requests
async def _process_witness_requests(self, tx_hash: bytes) -> CardanoTxResponseType: async def _process_witness_requests(self, tx_hash: bytes) -> CardanoTxResponseType:
response: CardanoTxResponseType = messages.CardanoTxItemAck() response: CardanoTxResponseType = CardanoTxItemAck()
for _ in range(self.msg.witness_requests_count): for _ in range(self.msg.witness_requests_count):
witness_request = await self.ctx.call( witness_request = await self.ctx.call(
@ -1126,7 +1129,7 @@ class Signer:
def _assert_tx_init_cond(self, condition: bool) -> None: def _assert_tx_init_cond(self, condition: bool) -> None:
if not condition: if not condition:
raise wire.ProcessError("Invalid tx signing request") raise ProcessError("Invalid tx signing request")
def _is_network_id_verifiable(self) -> bool: def _is_network_id_verifiable(self) -> bool:
""" """
@ -1144,7 +1147,7 @@ class Signer:
or self.msg.withdrawals_count != 0 or self.msg.withdrawals_count != 0
) )
def _get_output_address(self, output: messages.CardanoTxOutput) -> bytes: def _get_output_address(self, output: CardanoTxOutput) -> bytes:
if output.address_parameters: if output.address_parameters:
return addresses.derive_bytes( return addresses.derive_bytes(
self.keychain, self.keychain,
@ -1156,9 +1159,7 @@ class Signer:
assert output.address is not None # _validate_output assert output.address is not None # _validate_output
return addresses.get_bytes_unsafe(output.address) return addresses.get_bytes_unsafe(output.address)
def _get_output_address_type( def _get_output_address_type(self, output: CardanoTxOutput) -> CardanoAddressType:
self, output: messages.CardanoTxOutput
) -> CardanoAddressType:
if output.address_parameters: if output.address_parameters:
return output.address_parameters.address_type return output.address_parameters.address_type
assert output.address is not None # _validate_output assert output.address is not None # _validate_output
@ -1167,6 +1168,8 @@ class Signer:
def _derive_withdrawal_address_bytes( def _derive_withdrawal_address_bytes(
self, withdrawal: messages.CardanoTxWithdrawal self, withdrawal: messages.CardanoTxWithdrawal
) -> bytes: ) -> bytes:
from trezor.enums import CardanoAddressType
reward_address_type = ( reward_address_type = (
CardanoAddressType.REWARD CardanoAddressType.REWARD
if withdrawal.path or withdrawal.key_hash if withdrawal.path or withdrawal.key_hash
@ -1193,7 +1196,7 @@ class Signer:
chunk_data: bytes, chunk_data: bytes,
chunk_number: int, chunk_number: int,
chunks_count: int, chunks_count: int,
error: wire.ProcessError, error: ProcessError,
) -> None: ) -> None:
if chunk_number < chunks_count - 1 and len(chunk_data) != _MAX_CHUNK_SIZE: if chunk_number < chunks_count - 1 and len(chunk_data) != _MAX_CHUNK_SIZE:
raise error raise error
@ -1221,6 +1224,8 @@ class Signer:
) )
def _sign_tx_hash(self, tx_body_hash: bytes, path: list[int]) -> bytes: def _sign_tx_hash(self, tx_body_hash: bytes, path: list[int]) -> bytes:
from trezor.crypto.curve import ed25519
node = self.keychain.derive(path) node = self.keychain.derive(path)
return ed25519.sign_ext( return ed25519.sign_ext(
node.private_key(), node.private_key_ext(), tx_body_hash node.private_key(), node.private_key_ext(), tx_body_hash
@ -1234,21 +1239,26 @@ class Signer:
async def _fail_or_warn_path(self, path: list[int], path_name: str) -> None: async def _fail_or_warn_path(self, path: list[int], path_name: str) -> None:
if safety_checks.is_strict(): if safety_checks.is_strict():
raise wire.DataError(f"Invalid {path_name.lower()}") raise DataError(f"Invalid {path_name.lower()}")
else: else:
await layout.warn_path(self.ctx, path, path_name) await layout.warn_path(self.ctx, path, path_name)
def _fail_if_strict_and_unusual( def _fail_if_strict_and_unusual(
self, address_parameters: messages.CardanoAddressParametersType self, address_parameters: messages.CardanoAddressParametersType
) -> None: ) -> None:
from ..helpers.paths import (
CHANGE_OUTPUT_PATH_NAME,
CHANGE_OUTPUT_STAKING_PATH_NAME,
)
if not safety_checks.is_strict(): if not safety_checks.is_strict():
return return
if Credential.payment_credential(address_parameters).is_unusual_path: if Credential.payment_credential(address_parameters).is_unusual_path:
raise wire.DataError(f"Invalid {CHANGE_OUTPUT_PATH_NAME.lower()}") raise DataError(f"Invalid {CHANGE_OUTPUT_PATH_NAME.lower()}")
if Credential.stake_credential(address_parameters).is_unusual_path: if Credential.stake_credential(address_parameters).is_unusual_path:
raise wire.DataError(f"Invalid {CHANGE_OUTPUT_STAKING_PATH_NAME.lower()}") raise DataError(f"Invalid {CHANGE_OUTPUT_STAKING_PATH_NAME.lower()}")
async def _show_if_showing_details(self, layout_fn: Awaitable) -> None: async def _show_if_showing_details(self, layout_fn: Awaitable) -> None:
if self.should_show_details: if self.should_show_details:

View File

@ -18,7 +18,7 @@ class TestCardanoBech32(unittest.TestCase):
] ]
for expected_human_readable_part, expected_bech in expected_bechs: for expected_human_readable_part, expected_bech in expected_bechs:
decoded = bech32.decode(expected_human_readable_part, expected_bech) decoded = bech32._decode(expected_human_readable_part, expected_bech)
actual_bech = bech32.encode(expected_human_readable_part, decoded) actual_bech = bech32.encode(expected_human_readable_part, decoded)
self.assertEqual(actual_bech, expected_bech) self.assertEqual(actual_bech, expected_bech)