1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 20:38:10 +00:00

feat(cardano): streamed transaction signing

This commit is contained in:
gabrielkerekes 2021-06-30 14:13:43 +02:00 committed by matejcik
parent d2a5be4e38
commit b0c8590f00
13 changed files with 1322 additions and 903 deletions

View File

@ -448,6 +448,8 @@ if utils.BITCOIN_ONLY:
import apps.cardano.helpers import apps.cardano.helpers
apps.cardano.helpers.bech32 apps.cardano.helpers.bech32
import apps.cardano.helpers.bech32 import apps.cardano.helpers.bech32
apps.cardano.helpers.hash_builder_collection
import apps.cardano.helpers.hash_builder_collection
apps.cardano.helpers.network_ids apps.cardano.helpers.network_ids
import apps.cardano.helpers.network_ids import apps.cardano.helpers.network_ids
apps.cardano.helpers.paths apps.cardano.helpers.paths

View File

@ -1,6 +1,7 @@
from trezor.crypto import hashlib from trezor.crypto import hashlib
from trezor.crypto.curve import ed25519 from trezor.crypto.curve import ed25519
from trezor.enums import CardanoAddressType from trezor.enums import CardanoAddressType, CardanoTxAuxiliaryDataSupplementType
from trezor.messages import CardanoTxAuxiliaryDataSupplement
from apps.common import cbor from apps.common import cbor
@ -21,10 +22,11 @@ if False:
from trezor.messages import ( from trezor.messages import (
CardanoCatalystRegistrationParametersType, CardanoCatalystRegistrationParametersType,
CardanoTxAuxiliaryDataType, CardanoTxAuxiliaryData,
) )
CatalystRegistrationPayload = dict[int, Union[bytes, int]] CatalystRegistrationPayload = dict[int, Union[bytes, int]]
SignedCatalystRegistrationPayload = tuple[CatalystRegistrationPayload, bytes]
CatalystRegistrationSignature = dict[int, bytes] CatalystRegistrationSignature = dict[int, bytes]
CatalystRegistration = dict[ CatalystRegistration = dict[
int, Union[CatalystRegistrationPayload, CatalystRegistrationSignature] int, Union[CatalystRegistrationPayload, CatalystRegistrationSignature]
@ -40,14 +42,11 @@ METADATA_KEY_CATALYST_REGISTRATION = 61284
METADATA_KEY_CATALYST_REGISTRATION_SIGNATURE = 61285 METADATA_KEY_CATALYST_REGISTRATION_SIGNATURE = 61285
def validate_auxiliary_data(auxiliary_data: CardanoTxAuxiliaryDataType | None) -> None: def validate_auxiliary_data(auxiliary_data: CardanoTxAuxiliaryData) -> None:
if not auxiliary_data:
return
fields_provided = 0 fields_provided = 0
if auxiliary_data.blob: if auxiliary_data.hash:
fields_provided += 1 fields_provided += 1
_validate_auxiliary_data_blob(auxiliary_data.blob) _validate_auxiliary_data_hash(auxiliary_data.hash)
if auxiliary_data.catalyst_registration_parameters: if auxiliary_data.catalyst_registration_parameters:
fields_provided += 1 fields_provided += 1
_validate_catalyst_registration_parameters( _validate_catalyst_registration_parameters(
@ -58,12 +57,8 @@ def validate_auxiliary_data(auxiliary_data: CardanoTxAuxiliaryDataType | None) -
raise INVALID_AUXILIARY_DATA raise INVALID_AUXILIARY_DATA
def _validate_auxiliary_data_blob(auxiliary_data_blob: bytes) -> None: def _validate_auxiliary_data_hash(auxiliary_data_hash: bytes) -> None:
try: if len(auxiliary_data_hash) != AUXILIARY_DATA_HASH_SIZE:
# validation to prevent CBOR injection and invalid CBOR
# we don't validate data format, just that it's a valid CBOR
cbor.decode(auxiliary_data_blob)
except Exception:
raise INVALID_AUXILIARY_DATA raise INVALID_AUXILIARY_DATA
@ -91,27 +86,20 @@ def _validate_catalyst_registration_parameters(
async def show_auxiliary_data( async def show_auxiliary_data(
ctx: wire.Context, ctx: wire.Context,
keychain: seed.Keychain, keychain: seed.Keychain,
auxiliary_data: CardanoTxAuxiliaryDataType | None, auxiliary_data_hash: bytes,
catalyst_registration_parameters: CardanoCatalystRegistrationParametersType | None,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
) -> None: ) -> None:
if not auxiliary_data: if catalyst_registration_parameters:
return
if auxiliary_data.catalyst_registration_parameters:
await _show_catalyst_registration( await _show_catalyst_registration(
ctx, ctx,
keychain, keychain,
auxiliary_data.catalyst_registration_parameters, catalyst_registration_parameters,
protocol_magic, protocol_magic,
network_id, network_id,
) )
auxiliary_data_bytes = get_auxiliary_data_cbor(
keychain, auxiliary_data, protocol_magic, network_id
)
auxiliary_data_hash = hash_auxiliary_data(bytes(auxiliary_data_bytes))
await show_auxiliary_data_hash(ctx, auxiliary_data_hash) await show_auxiliary_data_hash(ctx, auxiliary_data_hash)
@ -138,37 +126,71 @@ async def _show_catalyst_registration(
) )
def get_auxiliary_data_cbor( def get_auxiliary_data_hash_and_supplement(
keychain: seed.Keychain, keychain: seed.Keychain,
auxiliary_data: CardanoTxAuxiliaryDataType, auxiliary_data: CardanoTxAuxiliaryData,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
) -> bytes: ) -> tuple[bytes, CardanoTxAuxiliaryDataSupplement]:
if auxiliary_data.blob: if parameters := auxiliary_data.catalyst_registration_parameters:
return auxiliary_data.blob (
elif auxiliary_data.catalyst_registration_parameters: catalyst_registration_payload,
cborized_catalyst_registration = _cborize_catalyst_registration( catalyst_signature,
keychain, ) = _get_signed_catalyst_registration_payload(
auxiliary_data.catalyst_registration_parameters, keychain, parameters, protocol_magic, network_id
protocol_magic,
network_id,
) )
return cbor.encode(_wrap_metadata(cborized_catalyst_registration)) auxiliary_data_hash = _get_catalyst_registration_auxiliary_data_hash(
catalyst_registration_payload, catalyst_signature
)
auxiliary_data_supplement = CardanoTxAuxiliaryDataSupplement(
type=CardanoTxAuxiliaryDataSupplementType.CATALYST_REGISTRATION_SIGNATURE,
auxiliary_data_hash=auxiliary_data_hash,
catalyst_signature=catalyst_signature,
)
return auxiliary_data_hash, auxiliary_data_supplement
else: else:
raise INVALID_AUXILIARY_DATA assert auxiliary_data.hash is not None # validate_auxiliary_data
return auxiliary_data.hash, CardanoTxAuxiliaryDataSupplement(
type=CardanoTxAuxiliaryDataSupplementType.NONE
)
def _get_catalyst_registration_auxiliary_data_hash(
catalyst_registration_payload: CatalystRegistrationPayload,
catalyst_registration_payload_signature: bytes,
) -> bytes:
cborized_catalyst_registration = _cborize_catalyst_registration(
catalyst_registration_payload,
catalyst_registration_payload_signature,
)
return _hash_auxiliary_data(
cbor.encode(_wrap_metadata(cborized_catalyst_registration))
)
def _cborize_catalyst_registration( def _cborize_catalyst_registration(
catalyst_registration_payload: CatalystRegistrationPayload,
catalyst_registration_payload_signature: bytes,
) -> CatalystRegistration:
catalyst_registration_signature = {1: catalyst_registration_payload_signature}
return {
METADATA_KEY_CATALYST_REGISTRATION: catalyst_registration_payload,
METADATA_KEY_CATALYST_REGISTRATION_SIGNATURE: catalyst_registration_signature,
}
def _get_signed_catalyst_registration_payload(
keychain: seed.Keychain, keychain: seed.Keychain,
catalyst_registration_parameters: CardanoCatalystRegistrationParametersType, catalyst_registration_parameters: CardanoCatalystRegistrationParametersType,
protocol_magic: int, protocol_magic: int,
network_id: int, network_id: int,
) -> CatalystRegistration: ) -> SignedCatalystRegistrationPayload:
staking_key = derive_public_key( staking_key = derive_public_key(
keychain, catalyst_registration_parameters.staking_path keychain, catalyst_registration_parameters.staking_path
) )
catalyst_registration_payload: CatalystRegistrationPayload = { payload: CatalystRegistrationPayload = {
1: catalyst_registration_parameters.voting_public_key, 1: catalyst_registration_parameters.voting_public_key,
2: staking_key, 2: staking_key,
3: derive_address_bytes( 3: derive_address_bytes(
@ -180,19 +202,13 @@ def _cborize_catalyst_registration(
4: catalyst_registration_parameters.nonce, 4: catalyst_registration_parameters.nonce,
} }
catalyst_registration_payload_signature = ( signature = _create_catalyst_registration_payload_signature(
_create_catalyst_registration_payload_signature( keychain,
keychain, payload,
catalyst_registration_payload, catalyst_registration_parameters.staking_path,
catalyst_registration_parameters.staking_path,
)
) )
catalyst_registration_signature = {1: catalyst_registration_payload_signature}
return { return payload, signature
METADATA_KEY_CATALYST_REGISTRATION: catalyst_registration_payload,
METADATA_KEY_CATALYST_REGISTRATION_SIGNATURE: catalyst_registration_signature,
}
def _create_catalyst_registration_payload_signature( def _create_catalyst_registration_payload_signature(
@ -228,7 +244,7 @@ def _wrap_metadata(metadata: dict) -> tuple[dict, tuple]:
return metadata, () return metadata, ()
def hash_auxiliary_data(auxiliary_data: bytes) -> bytes: def _hash_auxiliary_data(auxiliary_data: bytes) -> bytes:
return hashlib.blake2b( return hashlib.blake2b(
data=auxiliary_data, outlen=AUXILIARY_DATA_HASH_SIZE data=auxiliary_data, outlen=AUXILIARY_DATA_HASH_SIZE
).digest() ).digest()

View File

@ -1,4 +1,8 @@
from trezor.enums import CardanoCertificateType, CardanoPoolRelayType from trezor.enums import (
CardanoCertificateType,
CardanoPoolRelayType,
CardanoTxSigningMode,
)
from apps.common import cbor from apps.common import cbor
@ -13,10 +17,10 @@ from .helpers.paths import SCHEMA_STAKING_ANY_ACCOUNT
if False: if False:
from trezor.messages import ( from trezor.messages import (
CardanoPoolMetadataType, CardanoPoolMetadataType,
CardanoPoolOwnerType, CardanoPoolOwner,
CardanoPoolParametersType, CardanoPoolParametersType,
CardanoPoolRelayParametersType, CardanoPoolRelayParameters,
CardanoTxCertificateType, CardanoTxCertificate,
) )
from apps.common.cbor import CborSequence from apps.common.cbor import CborSequence
@ -34,8 +38,22 @@ MAX_PORT_NUMBER = 65535
def validate_certificate( def validate_certificate(
certificate: CardanoTxCertificateType, protocol_magic: int, network_id: int certificate: CardanoTxCertificate,
signing_mode: CardanoTxSigningMode,
protocol_magic: int,
network_id: int,
) -> None: ) -> None:
if (
signing_mode == CardanoTxSigningMode.ORDINARY_TRANSACTION
and certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION
):
raise INVALID_CERTIFICATE
elif (
signing_mode == CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER
and certificate.type != CardanoCertificateType.STAKE_POOL_REGISTRATION
):
raise INVALID_CERTIFICATE
if certificate.type in ( if certificate.type in (
CardanoCertificateType.STAKE_DELEGATION, CardanoCertificateType.STAKE_DELEGATION,
CardanoCertificateType.STAKE_REGISTRATION, CardanoCertificateType.STAKE_REGISTRATION,
@ -57,7 +75,7 @@ def validate_certificate(
def cborize_certificate( def cborize_certificate(
keychain: seed.Keychain, certificate: CardanoTxCertificateType keychain: seed.Keychain, certificate: CardanoTxCertificate
) -> CborSequence: ) -> CborSequence:
if certificate.type in ( if certificate.type in (
CardanoCertificateType.STAKE_REGISTRATION, CardanoCertificateType.STAKE_REGISTRATION,
@ -73,35 +91,37 @@ def cborize_certificate(
(0, get_public_key_hash(keychain, certificate.path)), (0, get_public_key_hash(keychain, certificate.path)),
certificate.pool, certificate.pool,
) )
elif certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION:
pool_parameters = certificate.pool_parameters
assert pool_parameters is not None
return (
certificate.type,
pool_parameters.pool_id,
pool_parameters.vrf_key_hash,
pool_parameters.pledge,
pool_parameters.cost,
cbor.Tagged(
30,
(
pool_parameters.margin_numerator,
pool_parameters.margin_denominator,
),
),
# this relies on pool_parameters.reward_account being validated beforehand
# in _validate_pool_parameters
get_address_bytes_unsafe(pool_parameters.reward_account),
_cborize_pool_owners(keychain, pool_parameters.owners),
_cborize_pool_relays(pool_parameters.relays),
_cborize_pool_metadata(pool_parameters.metadata),
)
else: else:
raise INVALID_CERTIFICATE raise INVALID_CERTIFICATE
def cborize_initial_pool_registration_certificate_fields(
certificate: CardanoTxCertificate,
) -> CborSequence:
assert certificate.type == CardanoCertificateType.STAKE_POOL_REGISTRATION
pool_parameters = certificate.pool_parameters
assert pool_parameters is not None
return (
certificate.type,
pool_parameters.pool_id,
pool_parameters.vrf_key_hash,
pool_parameters.pledge,
pool_parameters.cost,
cbor.Tagged(
30,
(
pool_parameters.margin_numerator,
pool_parameters.margin_denominator,
),
),
# this relies on pool_parameters.reward_account being validated beforehand
# in _validate_pool_parameters
get_address_bytes_unsafe(pool_parameters.reward_account),
)
def assert_certificate_cond(condition: bool) -> None: def assert_certificate_cond(condition: bool) -> None:
if not condition: if not condition:
raise INVALID_CERTIFICATE raise INVALID_CERTIFICATE
@ -119,41 +139,27 @@ def _validate_pool_parameters(
assert_certificate_cond( assert_certificate_cond(
pool_parameters.margin_numerator <= pool_parameters.margin_denominator pool_parameters.margin_numerator <= pool_parameters.margin_denominator
) )
assert_certificate_cond(len(pool_parameters.owners) > 0) assert_certificate_cond(pool_parameters.owners_count > 0)
validate_reward_address(pool_parameters.reward_account, protocol_magic, network_id) validate_reward_address(pool_parameters.reward_account, protocol_magic, network_id)
for pool_relay in pool_parameters.relays:
_validate_pool_relay(pool_relay)
_validate_pool_owners(pool_parameters.owners)
if pool_parameters.metadata: if pool_parameters.metadata:
_validate_pool_metadata(pool_parameters.metadata) _validate_pool_metadata(pool_parameters.metadata)
def _validate_pool_owners(owners: list[CardanoPoolOwnerType]) -> None: def validate_pool_owner(owner: CardanoPoolOwner) -> None:
owners_as_path_count = 0 assert_certificate_cond(
for owner in owners: owner.staking_key_hash is not None or owner.staking_key_path is not None
)
if owner.staking_key_hash is not None:
assert_certificate_cond(len(owner.staking_key_hash) == ADDRESS_KEY_HASH_SIZE)
if owner.staking_key_path:
assert_certificate_cond( assert_certificate_cond(
owner.staking_key_hash is not None or owner.staking_key_path is not None SCHEMA_STAKING_ANY_ACCOUNT.match(owner.staking_key_path)
) )
if owner.staking_key_hash is not None:
assert_certificate_cond(
len(owner.staking_key_hash) == ADDRESS_KEY_HASH_SIZE
)
if owner.staking_key_path:
assert_certificate_cond(
SCHEMA_STAKING_ANY_ACCOUNT.match(owner.staking_key_path)
)
if owner.staking_key_path:
owners_as_path_count += 1
assert_certificate_cond(owners_as_path_count == 1)
def _validate_pool_relay(pool_relay: CardanoPoolRelayParametersType) -> None: def validate_pool_relay(pool_relay: CardanoPoolRelayParameters) -> None:
if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP: if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP:
assert_certificate_cond( assert_certificate_cond(
pool_relay.ipv4_address is not None or pool_relay.ipv6_address is not None pool_relay.ipv4_address is not None or pool_relay.ipv6_address is not None
@ -188,20 +194,13 @@ def _validate_pool_metadata(pool_metadata: CardanoPoolMetadataType) -> None:
assert_certificate_cond(all((32 <= ord(c) < 127) for c in pool_metadata.url)) assert_certificate_cond(all((32 <= ord(c) < 127) for c in pool_metadata.url))
def _cborize_pool_owners( def cborize_pool_owner(keychain: seed.Keychain, pool_owner: CardanoPoolOwner) -> bytes:
keychain: seed.Keychain, pool_owners: list[CardanoPoolOwnerType] if pool_owner.staking_key_path:
) -> list[bytes]: return get_public_key_hash(keychain, pool_owner.staking_key_path)
result = [] elif pool_owner.staking_key_hash:
return pool_owner.staking_key_hash
for pool_owner in pool_owners: else:
if pool_owner.staking_key_path: raise ValueError
result.append(get_public_key_hash(keychain, pool_owner.staking_key_path))
elif pool_owner.staking_key_hash:
result.append(pool_owner.staking_key_hash)
else:
raise ValueError
return result
def _cborize_ipv6_address(ipv6_address: bytes | None) -> bytes | None: def _cborize_ipv6_address(ipv6_address: bytes | None) -> bytes | None:
@ -218,41 +217,32 @@ def _cborize_ipv6_address(ipv6_address: bytes | None) -> bytes | None:
return result return result
def _cborize_pool_relays( def cborize_pool_relay(
pool_relays: list[CardanoPoolRelayParametersType], pool_relay: CardanoPoolRelayParameters,
) -> list[CborSequence]: ) -> CborSequence:
result: list[CborSequence] = [] if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP:
return (
for pool_relay in pool_relays: pool_relay.type,
if pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_IP: pool_relay.port,
result.append( pool_relay.ipv4_address,
( _cborize_ipv6_address(pool_relay.ipv6_address),
pool_relay.type, )
pool_relay.port, elif pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_NAME:
pool_relay.ipv4_address, return (
_cborize_ipv6_address(pool_relay.ipv6_address), pool_relay.type,
) pool_relay.port,
) pool_relay.host_name,
elif pool_relay.type == CardanoPoolRelayType.SINGLE_HOST_NAME: )
result.append( elif pool_relay.type == CardanoPoolRelayType.MULTIPLE_HOST_NAME:
( return (
pool_relay.type, pool_relay.type,
pool_relay.port, pool_relay.host_name,
pool_relay.host_name, )
) else:
) raise INVALID_CERTIFICATE
elif pool_relay.type == CardanoPoolRelayType.MULTIPLE_HOST_NAME:
result.append(
(
pool_relay.type,
pool_relay.host_name,
)
)
return result
def _cborize_pool_metadata( def cborize_pool_metadata(
pool_metadata: CardanoPoolMetadataType | None, pool_metadata: CardanoPoolMetadataType | None,
) -> CborSequence | None: ) -> CborSequence | None:
if not pool_metadata: if not pool_metadata:

View File

@ -10,8 +10,8 @@ INVALID_AUXILIARY_DATA = wire.ProcessError("Invalid auxiliary data")
INVALID_STAKE_POOL_REGISTRATION_TX_STRUCTURE = wire.ProcessError( INVALID_STAKE_POOL_REGISTRATION_TX_STRUCTURE = wire.ProcessError(
"Stakepool registration transaction cannot contain other certificates nor withdrawals" "Stakepool registration transaction cannot contain other certificates nor withdrawals"
) )
INVALID_STAKEPOOL_REGISTRATION_TX_INPUTS = wire.ProcessError( INVALID_STAKEPOOL_REGISTRATION_TX_WITNESSES = wire.ProcessError(
"Stakepool registration transaction can contain only external inputs" "Stakepool registration transaction can only contain staking witnesses"
) )
LOVELACE_MAX_SUPPLY = 45_000_000_000 * 1_000_000 LOVELACE_MAX_SUPPLY = 45_000_000_000 * 1_000_000

View File

@ -0,0 +1,98 @@
from apps.common import cbor
if False:
from typing import Any, Generic, TypeVar
from trezor.utils import HashContext
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")
else:
T = 0 # type: ignore
K = 0 # type: ignore
V = 0 # type: ignore
Generic = {T: object, (K, V): object} # type: ignore
class HashBuilderCollection:
def __init__(self, size: int) -> None:
self.size = size
self.remaining = size
self.hash_fn: HashContext | None = None
self.parent: "HashBuilderCollection" | None = None
self.has_unfinished_child = False
def start(self, hash_fn: HashContext) -> "HashBuilderCollection":
self.hash_fn = hash_fn
self.hash_fn.update(self._header_bytes())
return self
def _insert_child(self, child: "HashBuilderCollection") -> None:
child.parent = self
assert self.hash_fn is not None
child.start(self.hash_fn)
self.has_unfinished_child = True
def _do_enter_item(self) -> None:
assert self.hash_fn is not None
assert self.remaining > 0
if self.has_unfinished_child:
raise RuntimeError # can't add item until child is finished
self.remaining -= 1
def _hash_item(self, item: Any) -> None:
assert self.hash_fn is not None
for chunk in cbor.encode_streamed(item):
self.hash_fn.update(chunk)
def _header_bytes(self) -> bytes:
raise NotImplementedError
def finish(self) -> None:
if self.remaining != 0:
raise RuntimeError # not all items were added
if self.parent is not None:
self.parent.has_unfinished_child = False
self.hash_fn = None
self.parent = None
def __enter__(self) -> "HashBuilderCollection":
assert self.hash_fn is not None
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if exc_type is None:
self.finish()
class HashBuilderList(HashBuilderCollection, Generic[T]):
def append(self, item: T) -> T:
self._do_enter_item()
if isinstance(item, HashBuilderCollection):
self._insert_child(item)
else:
self._hash_item(item)
return item
def _header_bytes(self) -> bytes:
return cbor.create_array_header(self.size)
class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
def add(self, key: K, value: V) -> V:
self._do_enter_item()
# enter key, this must not nest
assert not isinstance(key, HashBuilderCollection)
self._hash_item(key)
# enter value, this can nest
if isinstance(value, HashBuilderCollection):
self._insert_child(value)
else:
self._hash_item(value)
return value
def _header_bytes(self) -> bytes:
return cbor.create_map_header(self.size)

View File

@ -26,6 +26,7 @@ CHANGE_OUTPUT_PATH_NAME = "Change output path"
CHANGE_OUTPUT_STAKING_PATH_NAME = "Change output staking path" CHANGE_OUTPUT_STAKING_PATH_NAME = "Change output staking path"
CERTIFICATE_PATH_NAME = "Certificate path" CERTIFICATE_PATH_NAME = "Certificate path"
POOL_OWNER_STAKING_PATH_NAME = "Pool owner staking path" POOL_OWNER_STAKING_PATH_NAME = "Pool owner staking path"
WITNESS_PATH_NAME = "Witness path"
def unharden(item: int) -> int: def unharden(item: int) -> int:

View File

@ -29,12 +29,12 @@ if False:
from trezor import wire from trezor import wire
from trezor.messages import ( from trezor.messages import (
CardanoBlockchainPointerType, CardanoBlockchainPointerType,
CardanoTxCertificateType, CardanoTxCertificate,
CardanoTxWithdrawalType, CardanoTxWithdrawal,
CardanoPoolParametersType, CardanoPoolParametersType,
CardanoPoolOwnerType, CardanoPoolOwner,
CardanoPoolMetadataType, CardanoPoolMetadataType,
CardanoAssetGroupType, CardanoToken,
) )
from trezor.ui.layouts import PropertyType from trezor.ui.layouts import PropertyType
@ -67,11 +67,8 @@ def is_printable_ascii_bytestring(bytestr: bytes) -> bool:
async def confirm_sending( async def confirm_sending(
ctx: wire.Context, ctx: wire.Context,
ada_amount: int, ada_amount: int,
token_bundle: list[CardanoAssetGroupType],
to: str, to: str,
) -> None: ) -> None:
await confirm_sending_token_bundle(ctx, token_bundle)
await confirm_output( await confirm_output(
ctx, ctx,
to, to,
@ -86,27 +83,25 @@ async def confirm_sending(
) )
async def confirm_sending_token_bundle( async def confirm_sending_token(
ctx: wire.Context, token_bundle: list[CardanoAssetGroupType] ctx: wire.Context, policy_id: bytes, token: CardanoToken
) -> None: ) -> None:
for token_group in token_bundle: await confirm_properties(
for token in token_group.tokens: ctx,
await confirm_properties( "confirm_token",
ctx, title="Confirm transaction",
"confirm_token", props=[
title="Confirm transaction", (
props=[ "Asset fingerprint:",
( format_asset_fingerprint(
"Asset fingerprint:", policy_id=policy_id,
format_asset_fingerprint( asset_name_bytes=token.asset_name_bytes,
policy_id=token_group.policy_id, ),
asset_name_bytes=token.asset_name_bytes, ),
), ("Amount sent:", format_amount(token.amount, 0)),
), ],
("Amount sent:", format_amount(token.amount, 0)), br_code=ButtonRequestType.Other,
], )
br_code=ButtonRequestType.Other,
)
async def show_warning_tx_output_contains_tokens(ctx: wire.Context) -> None: async def show_warning_tx_output_contains_tokens(ctx: wire.Context) -> None:
@ -212,7 +207,6 @@ async def show_warning_tx_staking_key_hash(
async def confirm_transaction( async def confirm_transaction(
ctx: wire.Context, ctx: wire.Context,
amount: int,
fee: int, fee: int,
protocol_magic: int, protocol_magic: int,
ttl: int | None, ttl: int | None,
@ -220,12 +214,13 @@ async def confirm_transaction(
is_network_id_verifiable: bool, is_network_id_verifiable: bool,
) -> None: ) -> None:
props: list[PropertyType] = [ props: list[PropertyType] = [
("Transaction amount:", format_coin_amount(amount)),
("Transaction fee:", format_coin_amount(fee)), ("Transaction fee:", format_coin_amount(fee)),
] ]
if is_network_id_verifiable: if is_network_id_verifiable:
props.append(("Network:", protocol_magics.to_ui_string(protocol_magic))) props.append(
("Network: %s" % protocol_magics.to_ui_string(protocol_magic), None)
)
props.append( props.append(
("Valid since: %s" % format_optional_int(validity_interval_start), None) ("Valid since: %s" % format_optional_int(validity_interval_start), None)
@ -243,7 +238,7 @@ async def confirm_transaction(
async def confirm_certificate( async def confirm_certificate(
ctx: wire.Context, certificate: CardanoTxCertificateType ctx: wire.Context, certificate: CardanoTxCertificate
) -> None: ) -> None:
# stake pool registration requires custom confirmation logic not covered # stake pool registration requires custom confirmation logic not covered
# in this call # in this call
@ -270,10 +265,7 @@ async def confirm_certificate(
async def confirm_stake_pool_parameters( async def confirm_stake_pool_parameters(
ctx: wire.Context, ctx: wire.Context, pool_parameters: CardanoPoolParametersType
pool_parameters: CardanoPoolParametersType,
network_id: int,
protocol_magic: int,
) -> None: ) -> None:
margin_percentage = ( margin_percentage = (
100.0 * pool_parameters.margin_numerator / pool_parameters.margin_denominator 100.0 * pool_parameters.margin_numerator / pool_parameters.margin_denominator
@ -302,39 +294,36 @@ async def confirm_stake_pool_parameters(
) )
async def confirm_stake_pool_owners( async def confirm_stake_pool_owner(
ctx: wire.Context, ctx: wire.Context,
keychain: seed.Keychain, keychain: seed.Keychain,
owners: list[CardanoPoolOwnerType], owner: CardanoPoolOwner,
network_id: int, network_id: int,
) -> None: ) -> None:
props: list[tuple[str, str | None]] = [] props: list[tuple[str, str | None]] = []
for index, owner in enumerate(owners, 1): if owner.staking_key_path:
if owner.staking_key_path: props.append(("Pool owner:", address_n_to_str(owner.staking_key_path)))
props.append( props.append(
("Pool owner #%d:" % index, address_n_to_str(owner.staking_key_path)) (
encode_human_readable_address(
pack_reward_address_bytes(
get_public_key_hash(keychain, owner.staking_key_path),
network_id,
)
),
None,
) )
props.append( )
( else:
encode_human_readable_address( assert owner.staking_key_hash is not None # validate_pool_owners
pack_reward_address_bytes( props.append(
get_public_key_hash(keychain, owner.staking_key_path), (
network_id, "Pool owner:",
) encode_human_readable_address(
), pack_reward_address_bytes(owner.staking_key_hash, network_id)
None, ),
)
)
else:
assert owner.staking_key_hash is not None # validate_pool_owners
props.append(
(
"Pool owner #%d:" % index,
encode_human_readable_address(
pack_reward_address_bytes(owner.staking_key_hash, network_id)
),
)
) )
)
await confirm_properties( await confirm_properties(
ctx, ctx,
@ -371,44 +360,29 @@ async def confirm_stake_pool_metadata(
) )
async def confirm_transaction_network_ttl( async def confirm_stake_pool_registration_final(
ctx: wire.Context, ctx: wire.Context,
protocol_magic: int, protocol_magic: int,
ttl: int | None, ttl: int | None,
validity_interval_start: int | None, validity_interval_start: int | None,
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx,
"confirm_pool_network",
title="Confirm transaction",
props=[
("Network:", protocol_magics.to_ui_string(protocol_magic)),
(
"Valid since: %s" % format_optional_int(validity_interval_start),
None,
),
("TTL: %s" % format_optional_int(ttl), None),
],
br_code=ButtonRequestType.Other,
)
async def confirm_stake_pool_registration_final(
ctx: wire.Context,
) -> None:
await confirm_metadata(
ctx, ctx,
"confirm_pool_final", "confirm_pool_final",
title="Confirm transaction", title="Confirm transaction",
content="Confirm signing the stake pool registration as an owner", props=[
hide_continue=True, ("Confirm signing the stake pool registration as an owner.", None),
("Network:", protocol_magics.to_ui_string(protocol_magic)),
("Valid since:", format_optional_int(validity_interval_start)),
("TTL:", format_optional_int(ttl)),
],
hold=True, hold=True,
br_code=ButtonRequestType.Other, br_code=ButtonRequestType.Other,
) )
async def confirm_withdrawal( async def confirm_withdrawal(
ctx: wire.Context, withdrawal: CardanoTxWithdrawalType ctx: wire.Context, withdrawal: CardanoTxWithdrawal
) -> None: ) -> None:
await confirm_properties( await confirm_properties(
ctx, ctx,

File diff suppressed because it is too large Load Diff

View File

@ -309,3 +309,11 @@ def decode(cbor: bytes, offset: int = 0) -> Value:
if r.remaining_count(): if r.remaining_count():
raise ValueError raise ValueError
return res return res
def create_array_header(size: int) -> bytes:
return _header(_CBOR_ARRAY, size)
def create_map_header(size: int) -> bytes:
return _header(_CBOR_MAP, size)

View File

@ -149,7 +149,7 @@ def find_message_handler_module(msg_type: int) -> str:
return "apps.cardano.get_address" return "apps.cardano.get_address"
elif msg_type == MessageType.CardanoGetPublicKey: elif msg_type == MessageType.CardanoGetPublicKey:
return "apps.cardano.get_public_key" return "apps.cardano.get_public_key"
elif msg_type == MessageType.CardanoSignTx: elif msg_type == MessageType.CardanoSignTxInit:
return "apps.cardano.sign_tx" return "apps.cardano.sign_tx"
# tezos # tezos

View File

@ -6,13 +6,48 @@ from apps.common.cbor import (
IndefiniteLengthArray, IndefiniteLengthArray,
OrderedMap, OrderedMap,
Tagged, Tagged,
create_array_header,
create_map_header,
decode, decode,
encode, encode,
encode_chunked, encode_chunked,
encode_streamed, encode_streamed,
) )
class TestCardanoCbor(unittest.TestCase): class TestCardanoCbor(unittest.TestCase):
def test_create_array_header(self):
test_vectors = [
(0, '80'),
(23, '97'),
((2 ** 8) - 1, '98ff'),
((2 ** 16) - 1, '99ffff'),
((2 ** 32) - 1, '9affffffff'),
((2 ** 64) - 1, '9bffffffffffffffff'),
]
for val, header_hex in test_vectors:
header = unhexlify(header_hex)
self.assertEqual(create_array_header(val), header)
with self.assertRaises(NotImplementedError):
create_array_header(2 ** 64)
def test_create_map_header(self):
test_vectors = [
(0, 'a0'),
(23, 'b7'),
((2 ** 8) - 1, 'b8ff'),
((2 ** 16) - 1, 'b9ffff'),
((2 ** 32) - 1, 'baffffffff'),
((2 ** 64) - 1, 'bbffffffffffffffff'),
]
for val, header_hex in test_vectors:
header = unhexlify(header_hex)
self.assertEqual(create_map_header(val), header)
with self.assertRaises(NotImplementedError):
create_map_header(2 ** 64)
def test_cbor_encoding(self): def test_cbor_encoding(self):
test_vectors = [ test_vectors = [
# unsigned integers # unsigned integers

View File

@ -15,11 +15,17 @@
# If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>. # If not, see <https://www.gnu.org/licenses/lgpl-3.0.html>.
from ipaddress import ip_address from ipaddress import ip_address
from typing import List, Optional from itertools import chain
from typing import Dict, Iterator, List, Optional, Tuple, Union
from . import exceptions, messages, tools from . import exceptions, messages, tools
from .tools import expect from .tools import expect
SIGNING_MODE_IDS = {
"ORDINARY_TRANSACTION": messages.CardanoTxSigningMode.ORDINARY_TRANSACTION,
"POOL_REGISTRATION_AS_OWNER": messages.CardanoTxSigningMode.POOL_REGISTRATION_AS_OWNER,
}
PROTOCOL_MAGICS = {"mainnet": 764824073, "testnet": 42} PROTOCOL_MAGICS = {"mainnet": 764824073, "testnet": 42}
NETWORK_IDS = {"mainnet": 1, "testnet": 0} NETWORK_IDS = {"mainnet": 1, "testnet": 0}
@ -57,6 +63,28 @@ ADDRESS_TYPES = (
messages.CardanoAddressType.REWARD, messages.CardanoAddressType.REWARD,
) )
InputWithPath = Tuple[messages.CardanoTxInput, List[int]]
AssetGroupWithTokens = Tuple[messages.CardanoAssetGroup, List[messages.CardanoToken]]
OutputWithAssetGroups = Tuple[messages.CardanoTxOutput, List[AssetGroupWithTokens]]
OutputItem = Union[
messages.CardanoTxOutput, messages.CardanoAssetGroup, messages.CardanoToken
]
CertificateItem = Union[
messages.CardanoTxCertificate,
messages.CardanoPoolOwner,
messages.CardanoPoolRelayParameters,
]
PoolOwnersAndRelays = Tuple[
List[messages.CardanoPoolOwner], List[messages.CardanoPoolRelayParameters]
]
CertificateWithPoolOwnersAndRelays = Tuple[
messages.CardanoTxCertificate, Optional[PoolOwnersAndRelays]
]
Path = List[int]
Witness = Tuple[Path, bytes]
AuxiliaryDataSupplement = Dict[str, Union[int, bytes]]
SignTxResponse = Dict[str, Union[bytes, List[Witness], AuxiliaryDataSupplement]]
def create_address_parameters( def create_address_parameters(
address_type: messages.CardanoAddressType, address_type: messages.CardanoAddressType,
@ -97,18 +125,21 @@ def _create_certificate_pointer(
) )
def parse_input(tx_input) -> messages.CardanoTxInputType: def parse_input(tx_input) -> InputWithPath:
if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT): if not all(k in tx_input for k in REQUIRED_FIELDS_INPUT):
raise ValueError("The input is missing some fields") raise ValueError("The input is missing some fields")
return messages.CardanoTxInputType( path = tools.parse_path(tx_input.get("path"))
address_n=tools.parse_path(tx_input.get("path")), return (
prev_hash=bytes.fromhex(tx_input["prev_hash"]), messages.CardanoTxInput(
prev_index=tx_input["prev_index"], prev_hash=bytes.fromhex(tx_input["prev_hash"]),
prev_index=tx_input["prev_index"],
),
path,
) )
def parse_output(output) -> messages.CardanoTxOutputType: def parse_output(output) -> OutputWithAssetGroups:
contains_address = "address" in output contains_address = "address" in output
contains_address_type = "addressType" in output contains_address_type = "addressType" in output
@ -119,7 +150,7 @@ def parse_output(output) -> messages.CardanoTxOutputType:
address = None address = None
address_parameters = None address_parameters = None
token_bundle = None token_bundle = []
if contains_address: if contains_address:
address = output["address"] address = output["address"]
@ -130,38 +161,46 @@ def parse_output(output) -> messages.CardanoTxOutputType:
if "token_bundle" in output: if "token_bundle" in output:
token_bundle = _parse_token_bundle(output["token_bundle"]) token_bundle = _parse_token_bundle(output["token_bundle"])
return messages.CardanoTxOutputType( return (
address=address, messages.CardanoTxOutput(
address_parameters=address_parameters, address=address,
amount=int(output["amount"]), address_parameters=address_parameters,
token_bundle=token_bundle, amount=int(output["amount"]),
asset_groups_count=len(token_bundle),
),
token_bundle,
) )
def _parse_token_bundle(token_bundle) -> List[messages.CardanoAssetGroupType]: def _parse_token_bundle(token_bundle) -> List[AssetGroupWithTokens]:
result = [] result = []
for token_group in token_bundle: for token_group in token_bundle:
if not all(k in token_group for k in REQUIRED_FIELDS_TOKEN_GROUP): if not all(k in token_group for k in REQUIRED_FIELDS_TOKEN_GROUP):
raise ValueError(INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY) raise ValueError(INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY)
tokens = _parse_tokens(token_group["tokens"])
result.append( result.append(
messages.CardanoAssetGroupType( (
policy_id=bytes.fromhex(token_group["policy_id"]), messages.CardanoAssetGroup(
tokens=_parse_tokens(token_group["tokens"]), policy_id=bytes.fromhex(token_group["policy_id"]),
tokens_count=len(tokens),
),
tokens,
) )
) )
return result return result
def _parse_tokens(tokens) -> List[messages.CardanoTokenType]: def _parse_tokens(tokens) -> List[messages.CardanoToken]:
result = [] result = []
for token in tokens: for token in tokens:
if not all(k in token for k in REQUIRED_FIELDS_TOKEN): if not all(k in token for k in REQUIRED_FIELDS_TOKEN):
raise ValueError(INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY) raise ValueError(INVALID_OUTPUT_TOKEN_BUNDLE_ENTRY)
result.append( result.append(
messages.CardanoTokenType( messages.CardanoToken(
asset_name_bytes=bytes.fromhex(token["asset_name_bytes"]), asset_name_bytes=bytes.fromhex(token["asset_name_bytes"]),
amount=int(token["amount"]), amount=int(token["amount"]),
) )
@ -191,7 +230,7 @@ def _parse_address_parameters(
) )
def parse_certificate(certificate) -> messages.CardanoTxCertificateType: def parse_certificate(certificate) -> CertificateWithPoolOwnersAndRelays:
CERTIFICATE_MISSING_FIELDS_ERROR = ValueError( CERTIFICATE_MISSING_FIELDS_ERROR = ValueError(
"The certificate is missing some fields" "The certificate is missing some fields"
) )
@ -205,10 +244,13 @@ def parse_certificate(certificate) -> messages.CardanoTxCertificateType:
if "pool" not in certificate: if "pool" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
return messages.CardanoTxCertificateType( return (
type=certificate_type, messages.CardanoTxCertificate(
path=tools.parse_path(certificate["path"]), type=certificate_type,
pool=bytes.fromhex(certificate["pool"]), path=tools.parse_path(certificate["path"]),
pool=bytes.fromhex(certificate["pool"]),
),
None,
) )
elif certificate_type in ( elif certificate_type in (
messages.CardanoCertificateType.STAKE_REGISTRATION, messages.CardanoCertificateType.STAKE_REGISTRATION,
@ -216,9 +258,12 @@ def parse_certificate(certificate) -> messages.CardanoTxCertificateType:
): ):
if "path" not in certificate: if "path" not in certificate:
raise CERTIFICATE_MISSING_FIELDS_ERROR raise CERTIFICATE_MISSING_FIELDS_ERROR
return messages.CardanoTxCertificateType( return (
type=certificate_type, messages.CardanoTxCertificate(
path=tools.parse_path(certificate["path"]), type=certificate_type,
path=tools.parse_path(certificate["path"]),
),
None,
) )
elif certificate_type == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION: elif certificate_type == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION:
pool_parameters = certificate["pool_parameters"] pool_parameters = certificate["pool_parameters"]
@ -237,45 +282,49 @@ def parse_certificate(certificate) -> messages.CardanoTxCertificateType:
else: else:
pool_metadata = None pool_metadata = None
return messages.CardanoTxCertificateType( owners = [
type=certificate_type, _parse_pool_owner(pool_owner)
pool_parameters=messages.CardanoPoolParametersType( for pool_owner in pool_parameters.get("owners", [])
pool_id=bytes.fromhex(pool_parameters["pool_id"]), ]
vrf_key_hash=bytes.fromhex(pool_parameters["vrf_key_hash"]), relays = [
pledge=int(pool_parameters["pledge"]), _parse_pool_relay(pool_relay)
cost=int(pool_parameters["cost"]), for pool_relay in pool_parameters.get("relays", [])
margin_numerator=int(pool_parameters["margin"]["numerator"]), ]
margin_denominator=int(pool_parameters["margin"]["denominator"]),
reward_account=pool_parameters["reward_account"], return (
metadata=pool_metadata, messages.CardanoTxCertificate(
owners=[ type=certificate_type,
_parse_pool_owner(pool_owner) pool_parameters=messages.CardanoPoolParametersType(
for pool_owner in pool_parameters.get("owners", []) pool_id=bytes.fromhex(pool_parameters["pool_id"]),
], vrf_key_hash=bytes.fromhex(pool_parameters["vrf_key_hash"]),
relays=[ pledge=int(pool_parameters["pledge"]),
_parse_pool_relay(pool_relay) cost=int(pool_parameters["cost"]),
for pool_relay in pool_parameters.get("relays", []) margin_numerator=int(pool_parameters["margin"]["numerator"]),
] margin_denominator=int(pool_parameters["margin"]["denominator"]),
if "relays" in pool_parameters reward_account=pool_parameters["reward_account"],
else [], metadata=pool_metadata,
owners_count=len(owners),
relays_count=len(relays),
),
), ),
(owners, relays),
) )
else: else:
raise ValueError("Unknown certificate type") raise ValueError("Unknown certificate type")
def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwnerType: def _parse_pool_owner(pool_owner) -> messages.CardanoPoolOwner:
if "staking_key_path" in pool_owner: if "staking_key_path" in pool_owner:
return messages.CardanoPoolOwnerType( return messages.CardanoPoolOwner(
staking_key_path=tools.parse_path(pool_owner["staking_key_path"]) staking_key_path=tools.parse_path(pool_owner["staking_key_path"])
) )
return messages.CardanoPoolOwnerType( return messages.CardanoPoolOwner(
staking_key_hash=bytes.fromhex(pool_owner["staking_key_hash"]) staking_key_hash=bytes.fromhex(pool_owner["staking_key_hash"])
) )
def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParametersType: def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParameters:
pool_relay_type = int(pool_relay["type"]) pool_relay_type = int(pool_relay["type"])
if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP: if pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_IP:
@ -290,20 +339,20 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParametersType:
else None else None
) )
return messages.CardanoPoolRelayParametersType( return messages.CardanoPoolRelayParameters(
type=pool_relay_type, type=pool_relay_type,
port=int(pool_relay["port"]), port=int(pool_relay["port"]),
ipv4_address=ipv4_address_packed, ipv4_address=ipv4_address_packed,
ipv6_address=ipv6_address_packed, ipv6_address=ipv6_address_packed,
) )
elif pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_NAME: elif pool_relay_type == messages.CardanoPoolRelayType.SINGLE_HOST_NAME:
return messages.CardanoPoolRelayParametersType( return messages.CardanoPoolRelayParameters(
type=pool_relay_type, type=pool_relay_type,
port=int(pool_relay["port"]), port=int(pool_relay["port"]),
host_name=pool_relay["host_name"], host_name=pool_relay["host_name"],
) )
elif pool_relay_type == messages.CardanoPoolRelayType.MULTIPLE_HOST_NAME: elif pool_relay_type == messages.CardanoPoolRelayType.MULTIPLE_HOST_NAME:
return messages.CardanoPoolRelayParametersType( return messages.CardanoPoolRelayParameters(
type=pool_relay_type, type=pool_relay_type,
host_name=pool_relay["host_name"], host_name=pool_relay["host_name"],
) )
@ -311,18 +360,18 @@ def _parse_pool_relay(pool_relay) -> messages.CardanoPoolRelayParametersType:
raise ValueError("Unknown pool relay type") raise ValueError("Unknown pool relay type")
def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawalType: def parse_withdrawal(withdrawal) -> messages.CardanoTxWithdrawal:
if not all(k in withdrawal for k in REQUIRED_FIELDS_WITHDRAWAL): if not all(k in withdrawal for k in REQUIRED_FIELDS_WITHDRAWAL):
raise ValueError("Withdrawal is missing some fields") raise ValueError("Withdrawal is missing some fields")
path = withdrawal["path"] path = withdrawal["path"]
return messages.CardanoTxWithdrawalType( return messages.CardanoTxWithdrawal(
path=tools.parse_path(path), path=tools.parse_path(path),
amount=int(withdrawal["amount"]), amount=int(withdrawal["amount"]),
) )
def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryDataType: def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryData:
if auxiliary_data is None: if auxiliary_data is None:
return None return None
@ -331,9 +380,9 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryDataType:
) )
# include all provided fields so we can test validation in FW # include all provided fields so we can test validation in FW
blob = None hash = None
if "blob" in auxiliary_data: if "hash" in auxiliary_data:
blob = bytes.fromhex(auxiliary_data["blob"]) hash = bytes.fromhex(auxiliary_data["hash"])
catalyst_registration_parameters = None catalyst_registration_parameters = None
if "catalyst_registration_parameters" in auxiliary_data: if "catalyst_registration_parameters" in auxiliary_data:
@ -356,15 +405,68 @@ def parse_auxiliary_data(auxiliary_data) -> messages.CardanoTxAuxiliaryDataType:
) )
) )
if blob is None and catalyst_registration_parameters is None: if hash is None and catalyst_registration_parameters is None:
raise AUXILIARY_DATA_MISSING_FIELDS_ERROR raise AUXILIARY_DATA_MISSING_FIELDS_ERROR
return messages.CardanoTxAuxiliaryDataType( return messages.CardanoTxAuxiliaryData(
blob=blob, hash=hash,
catalyst_registration_parameters=catalyst_registration_parameters, catalyst_registration_parameters=catalyst_registration_parameters,
) )
def _get_witness_paths(
inputs: List[InputWithPath],
certificates: List[CertificateWithPoolOwnersAndRelays],
withdrawals: List[messages.CardanoTxWithdrawal],
) -> List[Path]:
paths = set()
for _, path in inputs:
if path:
paths.add(tuple(path))
for certificate, pool_owners_and_relays in certificates:
if certificate.type in (
messages.CardanoCertificateType.STAKE_DEREGISTRATION,
messages.CardanoCertificateType.STAKE_DELEGATION,
):
paths.add(tuple(certificate.path))
elif (
certificate.type == messages.CardanoCertificateType.STAKE_POOL_REGISTRATION
and pool_owners_and_relays is not None
):
owners, _ = pool_owners_and_relays
for pool_owner in owners:
if pool_owner.staking_key_path:
paths.add(tuple(pool_owner.staking_key_path))
for withdrawal in withdrawals:
paths.add(tuple(withdrawal.path))
return sorted([list(path) for path in paths])
def _get_input_items(inputs: List[InputWithPath]) -> Iterator[messages.CardanoTxInput]:
for input, _ in inputs:
yield input
def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputItem]:
for output, asset_groups in outputs:
yield output
for asset_group, tokens in asset_groups:
yield asset_group
yield from tokens
def _get_certificate_items(
certificates: List[CertificateWithPoolOwnersAndRelays],
) -> Iterator[CertificateItem]:
for certificate, pool_owners_and_relays in certificates:
yield certificate
if pool_owners_and_relays is not None:
owners, relays = pool_owners_and_relays
yield from owners
yield from relays
# ====== Client functions ====== # # ====== Client functions ====== #
@ -391,44 +493,94 @@ def get_public_key(client, address_n: List[int]) -> messages.CardanoPublicKey:
return client.call(messages.CardanoGetPublicKey(address_n=address_n)) return client.call(messages.CardanoGetPublicKey(address_n=address_n))
@expect(messages.CardanoSignedTx)
def sign_tx( def sign_tx(
client, client,
inputs: List[messages.CardanoTxInputType], signing_mode: messages.CardanoTxSigningMode,
outputs: List[messages.CardanoTxOutputType], inputs: List[InputWithPath],
outputs: List[OutputWithAssetGroups],
fee: int, fee: int,
ttl: Optional[int], ttl: Optional[int],
validity_interval_start: Optional[int], validity_interval_start: Optional[int],
certificates: List[messages.CardanoTxCertificateType] = (), certificates: List[CertificateWithPoolOwnersAndRelays] = (),
withdrawals: List[messages.CardanoTxWithdrawalType] = (), withdrawals: List[messages.CardanoTxWithdrawal] = (),
protocol_magic: int = PROTOCOL_MAGICS["mainnet"], protocol_magic: int = PROTOCOL_MAGICS["mainnet"],
network_id: int = NETWORK_IDS["mainnet"], network_id: int = NETWORK_IDS["mainnet"],
auxiliary_data: messages.CardanoTxAuxiliaryDataType = None, auxiliary_data: messages.CardanoTxAuxiliaryData = None,
) -> messages.CardanoSignedTx: ) -> SignTxResponse:
UNEXPECTED_RESPONSE_ERROR = exceptions.TrezorException("Unexpected response")
witness_paths = _get_witness_paths(inputs, certificates, withdrawals)
response = client.call( response = client.call(
messages.CardanoSignTx( messages.CardanoSignTxInit(
inputs=inputs, signing_mode=signing_mode,
outputs=outputs, inputs_count=len(inputs),
outputs_count=len(outputs),
fee=fee, fee=fee,
ttl=ttl, ttl=ttl,
validity_interval_start=validity_interval_start, validity_interval_start=validity_interval_start,
certificates=certificates, certificates_count=len(certificates),
withdrawals=withdrawals, withdrawals_count=len(withdrawals),
protocol_magic=protocol_magic, protocol_magic=protocol_magic,
network_id=network_id, network_id=network_id,
auxiliary_data=auxiliary_data, has_auxiliary_data=auxiliary_data is not None,
witness_requests_count=len(witness_paths),
) )
) )
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
result = bytearray() for tx_item in chain(
while isinstance(response, messages.CardanoSignedTxChunk): _get_input_items(inputs),
result.extend(response.signed_tx_chunk) _get_output_items(outputs),
response = client.call(messages.CardanoSignedTxChunkAck()) _get_certificate_items(certificates),
withdrawals,
):
response = client.call(tx_item)
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
if not isinstance(response, messages.CardanoSignedTx): sign_tx_response = {}
raise exceptions.TrezorException("Unexpected response")
if response.serialized_tx is not None: if auxiliary_data is not None:
result.extend(response.serialized_tx) auxiliary_data_supplement = client.call(auxiliary_data)
if not isinstance(
auxiliary_data_supplement, messages.CardanoTxAuxiliaryDataSupplement
):
raise UNEXPECTED_RESPONSE_ERROR
if (
auxiliary_data_supplement.type
!= messages.CardanoTxAuxiliaryDataSupplementType.NONE
):
sign_tx_response[
"auxiliary_data_supplement"
] = auxiliary_data_supplement.__dict__
return messages.CardanoSignedTx(tx_hash=response.tx_hash, serialized_tx=result) response = client.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxItemAck):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"] = []
for path in witness_paths:
response = client.call(messages.CardanoTxWitnessRequest(path=path))
if not isinstance(response, messages.CardanoTxWitnessResponse):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["witnesses"].append(
{
"type": response.type,
"pub_key": response.pub_key,
"signature": response.signature,
"chain_code": response.chain_code,
}
)
response = client.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoTxBodyHash):
raise UNEXPECTED_RESPONSE_ERROR
sign_tx_response["tx_hash"] = response.tx_hash
response = client.call(messages.CardanoTxHostAck())
if not isinstance(response, messages.CardanoSignTxFinished):
raise UNEXPECTED_RESPONSE_ERROR
return sign_tx_response

View File

@ -40,13 +40,19 @@ def cli():
@cli.command() @cli.command()
@click.argument("file", type=click.File("r")) @click.argument("file", type=click.File("r"))
@click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False) @click.option("-f", "--file", "_ignore", is_flag=True, hidden=True, expose_value=False)
@click.option(
"-s",
"--signing-mode",
required=True,
type=ChoiceType({m.name: m for m in messages.CardanoTxSigningMode}),
)
@click.option( @click.option(
"-p", "--protocol-magic", type=int, default=cardano.PROTOCOL_MAGICS["mainnet"] "-p", "--protocol-magic", type=int, default=cardano.PROTOCOL_MAGICS["mainnet"]
) )
@click.option("-N", "--network-id", type=int, default=cardano.NETWORK_IDS["mainnet"]) @click.option("-N", "--network-id", type=int, default=cardano.NETWORK_IDS["mainnet"])
@click.option("-t", "--testnet", is_flag=True) @click.option("-t", "--testnet", is_flag=True)
@with_client @with_client
def sign_tx(client, file, protocol_magic, network_id, testnet): def sign_tx(client, file, signing_mode, protocol_magic, network_id, testnet):
"""Sign Cardano transaction.""" """Sign Cardano transaction."""
transaction = json.load(file) transaction = json.load(file)
@ -69,8 +75,9 @@ def sign_tx(client, file, protocol_magic, network_id, testnet):
] ]
auxiliary_data = cardano.parse_auxiliary_data(transaction.get("auxiliary_data")) auxiliary_data = cardano.parse_auxiliary_data(transaction.get("auxiliary_data"))
signed_transaction = cardano.sign_tx( sign_tx_response = cardano.sign_tx(
client, client,
signing_mode,
inputs, inputs,
outputs, outputs,
fee, fee,
@ -83,10 +90,28 @@ def sign_tx(client, file, protocol_magic, network_id, testnet):
auxiliary_data, auxiliary_data,
) )
return { sign_tx_response["tx_hash"] = sign_tx_response["tx_hash"].hex()
"tx_hash": signed_transaction.tx_hash.hex(), sign_tx_response["witnesses"] = [
"serialized_tx": signed_transaction.serialized_tx.hex(), {
} "type": witness["type"],
"pub_key": witness["pub_key"].hex(),
"signature": witness["signature"].hex(),
"chain_code": witness["chain_code"].hex()
if witness["chain_code"] is not None
else None,
}
for witness in sign_tx_response["witnesses"]
]
auxiliary_data_supplement = sign_tx_response.get("auxiliary_data_supplement")
if auxiliary_data_supplement:
auxiliary_data_supplement["auxiliary_data_hash"] = auxiliary_data_supplement[
"auxiliary_data_hash"
].hex()
catalyst_signature = auxiliary_data_supplement.get("catalyst_signature")
if catalyst_signature:
auxiliary_data_supplement["catalyst_signature"] = catalyst_signature.hex()
sign_tx_response["auxiliary_data_supplement"] = auxiliary_data_supplement
return sign_tx_response
@cli.command() @cli.command()