diff --git a/core/src/apps/cardano/README.md b/core/src/apps/cardano/README.md index 5c20c50ad..1a9c88f42 100644 --- a/core/src/apps/cardano/README.md +++ b/core/src/apps/cardano/README.md @@ -14,17 +14,17 @@ REVIEWER = Jan Matejek , Tomas Susanka None: diff --git a/core/src/apps/cardano/layout.py b/core/src/apps/cardano/layout.py index 82785c225..b95fb8a8b 100644 --- a/core/src/apps/cardano/layout.py +++ b/core/src/apps/cardano/layout.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal from trezor import messages, ui from trezor.enums import ( @@ -193,7 +193,7 @@ async def show_plutus_tx(ctx: wire.Context) -> None: ctx, "confirm_signing_mode", title="Confirm transaction", - content="Confirming a Plutus transaction - loss of collateral is possible. Check all items carefully.", + content="Confirming a Plutus transaction.", br_code=ButtonRequestType.Other, ) @@ -215,16 +215,24 @@ async def confirm_sending( ctx: wire.Context, ada_amount: int, to: str, - is_change_output: bool, + output_type: Literal["address", "change", "collateral-return"], network_id: int, ) -> None: - subtitle = "Change amount:" if is_change_output else "Confirm sending:" + if output_type == "address": + message = "Confirm sending" + elif output_type == "change": + message = "Change amount" + elif output_type == "collateral-return": + message = "Collateral return" + else: + raise RuntimeError # should be unreachable + await confirm_output( ctx, to, format_coin_amount(ada_amount, network_id), title="Confirm transaction", - subtitle=subtitle, + subtitle=f"{message}:", font_amount=ui.BOLD, width_paginated=17, to_str="\nto\n", @@ -256,6 +264,68 @@ async def confirm_sending_token( ) +async def confirm_datum_hash(ctx: wire.Context, datum_hash: bytes) -> None: + await confirm_properties( + ctx, + "confirm_datum_hash", + title="Confirm transaction", + props=[ + ( + "Datum hash:", + bech32.encode(bech32.HRP_OUTPUT_DATUM_HASH, datum_hash), + ), + ], + br_code=ButtonRequestType.Other, + ) + + +async def confirm_inline_datum( + ctx: wire.Context, first_chunk: bytes, inline_datum_size: int +) -> None: + await _confirm_data_chunk( + ctx, + "confirm_inline_datum", + "Inline datum", + first_chunk, + inline_datum_size, + ) + + +async def confirm_reference_script( + ctx: wire.Context, first_chunk: bytes, reference_script_size: int +) -> None: + await _confirm_data_chunk( + ctx, + "confirm_reference_script", + "Reference script", + first_chunk, + reference_script_size, + ) + + +async def _confirm_data_chunk( + ctx: wire.Context, br_type: str, title: str, first_chunk: bytes, data_size: int +) -> None: + MAX_DISPLAYED_SIZE = 56 + displayed_bytes = first_chunk[:MAX_DISPLAYED_SIZE] + bytes_optional_plural = "byte" if data_size == 1 else "bytes" + props: list[tuple[str, bytes | None]] = [ + ( + f"{title} ({data_size} {bytes_optional_plural}):", + displayed_bytes, + ) + ] + if data_size > MAX_DISPLAYED_SIZE: + props.append(("...", None)) + await confirm_properties( + ctx, + br_type, + title="Confirm transaction", + props=props, + br_code=ButtonRequestType.Other, + ) + + async def show_credentials( ctx: wire.Context, payment_credential: Credential, @@ -350,12 +420,18 @@ async def warn_path(ctx: wire.Context, path: list[int], title: str) -> None: await confirm_path_warning(ctx, address_n_to_str(path), path_type=title) -async def warn_tx_output_contains_tokens(ctx: wire.Context) -> None: +async def warn_tx_output_contains_tokens( + ctx: wire.Context, is_collateral_return: bool = False +) -> None: + if is_collateral_return: + content = "The collateral return\noutput contains tokens." + else: + content = "The following\ntransaction output\ncontains tokens." await confirm_metadata( ctx, "confirm_tokens", title="Confirm transaction", - content="The following\ntransaction output\ncontains tokens.", + content=content, larger_vspace=True, br_code=ButtonRequestType.Other, ) @@ -372,30 +448,12 @@ async def warn_tx_contains_mint(ctx: wire.Context) -> None: ) -async def warn_tx_output_contains_datum_hash( - ctx: wire.Context, datum_hash: bytes -) -> None: - await confirm_properties( - ctx, - "confirm_datum_hash", - title="Confirm transaction", - props=[ - ( - "The following transaction output contains datum hash:", - bech32.encode(bech32.HRP_OUTPUT_DATUM_HASH, datum_hash), - ), - ("\nContinue?", None), - ], - br_code=ButtonRequestType.Other, - ) - - -async def warn_tx_output_no_datum_hash(ctx: wire.Context) -> None: +async def warn_tx_output_no_datum(ctx: wire.Context) -> None: await confirm_metadata( ctx, "confirm_no_datum_hash", title="Confirm transaction", - content="The following transaction output contains a script address, but does not contain a datum hash.", + content="The following transaction output contains a script address, but does not contain a datum.", br_code=ButtonRequestType.Other, ) @@ -420,6 +478,16 @@ async def warn_no_collateral_inputs(ctx: wire.Context) -> None: ) +async def warn_unknown_total_collateral(ctx: wire.Context) -> None: + await confirm_metadata( + ctx, + "confirm_unknown_total_collateral", + title="Warning", + content="Unknown collateral amount, check all items carefully.", + br_code=ButtonRequestType.Other, + ) + + async def confirm_witness_request( ctx: wire.Context, witness_path: list[int], @@ -448,6 +516,7 @@ async def confirm_tx( protocol_magic: int, ttl: int | None, validity_interval_start: int | None, + total_collateral: int | None, is_network_id_verifiable: bool, tx_hash: bytes | None, ) -> None: @@ -455,6 +524,11 @@ async def confirm_tx( ("Transaction fee:", format_coin_amount(fee, network_id)), ] + if total_collateral is not None: + props.append( + ("Total collateral:", format_coin_amount(total_collateral, network_id)) + ) + if is_network_id_verifiable: props.append((f"Network: {protocol_magics.to_ui_string(protocol_magic)}", None)) @@ -778,6 +852,21 @@ async def confirm_collateral_input( ) +async def confirm_reference_input( + ctx: wire.Context, reference_input: messages.CardanoTxReferenceInput +) -> None: + await confirm_properties( + ctx, + "confirm_reference_input", + title="Confirm transaction", + props=[ + ("Reference input ID:", reference_input.prev_hash), + ("Reference input index:", str(reference_input.prev_index)), + ], + br_code=ButtonRequestType.Other, + ) + + async def confirm_required_signer( ctx: wire.Context, required_signer: messages.CardanoTxRequiredSigner ) -> None: diff --git a/core/src/apps/cardano/sign_tx/multisig_signer.py b/core/src/apps/cardano/sign_tx/multisig_signer.py index 690f2b44a..2c402dcac 100644 --- a/core/src/apps/cardano/sign_tx/multisig_signer.py +++ b/core/src/apps/cardano/sign_tx/multisig_signer.py @@ -21,11 +21,11 @@ class MultisigSigner(Signer): def _validate_tx_init(self) -> None: super()._validate_tx_init() - if ( - self.msg.collateral_inputs_count != 0 - or self.msg.required_signers_count != 0 - ): - raise wire.ProcessError("Invalid tx signing request") + self._assert_tx_init_cond(self.msg.collateral_inputs_count == 0) + self._assert_tx_init_cond(self.msg.required_signers_count == 0) + self._assert_tx_init_cond(not self.msg.has_collateral_return) + self._assert_tx_init_cond(self.msg.total_collateral is None) + self._assert_tx_init_cond(self.msg.reference_inputs_count == 0) async def _show_tx_init(self) -> None: await layout.show_multisig_tx(self.ctx) @@ -41,6 +41,7 @@ class MultisigSigner(Signer): self.msg.protocol_magic, self.msg.ttl, self.msg.validity_interval_start, + self.msg.total_collateral, is_network_id_verifiable, tx_hash=None, ) diff --git a/core/src/apps/cardano/sign_tx/ordinary_signer.py b/core/src/apps/cardano/sign_tx/ordinary_signer.py index 9e83e8c00..dae4d380c 100644 --- a/core/src/apps/cardano/sign_tx/ordinary_signer.py +++ b/core/src/apps/cardano/sign_tx/ordinary_signer.py @@ -27,11 +27,11 @@ class OrdinarySigner(Signer): def _validate_tx_init(self) -> None: super()._validate_tx_init() - if ( - self.msg.collateral_inputs_count != 0 - or self.msg.required_signers_count != 0 - ): - raise wire.ProcessError("Invalid tx signing request") + self._assert_tx_init_cond(self.msg.collateral_inputs_count == 0) + self._assert_tx_init_cond(self.msg.required_signers_count == 0) + self._assert_tx_init_cond(not self.msg.has_collateral_return) + self._assert_tx_init_cond(self.msg.total_collateral is None) + self._assert_tx_init_cond(self.msg.reference_inputs_count == 0) async def _confirm_tx(self, tx_hash: bytes) -> None: # super() omitted intentionally @@ -43,6 +43,7 @@ class OrdinarySigner(Signer): self.msg.protocol_magic, self.msg.ttl, self.msg.validity_interval_start, + self.msg.total_collateral, is_network_id_verifiable, tx_hash=None, ) diff --git a/core/src/apps/cardano/sign_tx/plutus_signer.py b/core/src/apps/cardano/sign_tx/plutus_signer.py index 7073073fc..e858965b7 100644 --- a/core/src/apps/cardano/sign_tx/plutus_signer.py +++ b/core/src/apps/cardano/sign_tx/plutus_signer.py @@ -24,12 +24,16 @@ class PlutusSigner(Signer): async def _show_tx_init(self) -> None: await layout.show_plutus_tx(self.ctx) await super()._show_tx_init() + # These items should be present if a Plutus script is to be executed. if self.msg.script_data_hash is None: await layout.warn_no_script_data_hash(self.ctx) if self.msg.collateral_inputs_count == 0: await layout.warn_no_collateral_inputs(self.ctx) + if self.msg.total_collateral is None: + await layout.warn_unknown_total_collateral(self.ctx) + async def _confirm_tx(self, tx_hash: bytes) -> None: # super() omitted intentionally # We display tx hash so that experienced users can compare it to the tx hash @@ -43,10 +47,17 @@ class PlutusSigner(Signer): self.msg.protocol_magic, self.msg.ttl, self.msg.validity_interval_start, + self.msg.total_collateral, is_network_id_verifiable, tx_hash, ) + def _should_show_tx_hash(self) -> bool: + # super() omitted intentionally + # Plutus txs tend to contain a lot of opaque data, some users might + # want to verify only the tx hash. + return True + async def _show_input(self, input: messages.CardanoTxInput) -> None: # super() omitted intentionally # The inputs are not interchangeable (because of datums), so we must show them. diff --git a/core/src/apps/cardano/sign_tx/pool_owner_signer.py b/core/src/apps/cardano/sign_tx/pool_owner_signer.py index c81845994..6586a2e50 100644 --- a/core/src/apps/cardano/sign_tx/pool_owner_signer.py +++ b/core/src/apps/cardano/sign_tx/pool_owner_signer.py @@ -29,21 +29,15 @@ class PoolOwnerSigner(Signer): def _validate_tx_init(self) -> None: super()._validate_tx_init() - if ( - self.msg.certificates_count != 1 - or self.msg.withdrawals_count != 0 - or self.msg.minting_asset_groups_count != 0 - ): - raise wire.ProcessError( - "Stakepool registration transaction cannot contain other certificates, withdrawals or minting" - ) - - if ( - self.msg.script_data_hash is not None - or self.msg.collateral_inputs_count != 0 - or self.msg.required_signers_count != 0 - ): - raise wire.ProcessError("Invalid tx signing request") + self._assert_tx_init_cond(self.msg.certificates_count == 1) + self._assert_tx_init_cond(self.msg.withdrawals_count == 0) + self._assert_tx_init_cond(self.msg.minting_asset_groups_count == 0) + self._assert_tx_init_cond(self.msg.script_data_hash is None) + self._assert_tx_init_cond(self.msg.collateral_inputs_count == 0) + self._assert_tx_init_cond(self.msg.required_signers_count == 0) + self._assert_tx_init_cond(not self.msg.has_collateral_return) + self._assert_tx_init_cond(self.msg.total_collateral is None) + self._assert_tx_init_cond(self.msg.reference_inputs_count == 0) async def _confirm_tx(self, tx_hash: bytes) -> None: # super() omitted intentionally @@ -56,7 +50,12 @@ class PoolOwnerSigner(Signer): def _validate_output(self, output: messages.CardanoTxOutput) -> None: super()._validate_output(output) - if output.address_parameters is not None or output.datum_hash is not None: + if ( + output.address_parameters is not None + or output.datum_hash is not None + or output.inline_datum_size > 0 + or output.reference_script_size > 0 + ): raise wire.ProcessError("Invalid output") def _should_show_output(self, output: messages.CardanoTxOutput) -> bool: diff --git a/core/src/apps/cardano/sign_tx/signer.py b/core/src/apps/cardano/sign_tx/signer.py index 28fb4df0a..dbc348b34 100644 --- a/core/src/apps/cardano/sign_tx/signer.py +++ b/core/src/apps/cardano/sign_tx/signer.py @@ -7,6 +7,7 @@ from trezor.crypto.curve import ed25519 from trezor.enums import ( CardanoAddressType, CardanoCertificateType, + CardanoTxOutputSerializationFormat, CardanoTxWitnessType, ) @@ -22,7 +23,11 @@ from ..helpers import ( ) from ..helpers.account_path_check import AccountPathChecker from ..helpers.credential import Credential, should_show_credentials -from ..helpers.hash_builder_collection import HashBuilderDict, HashBuilderList +from ..helpers.hash_builder_collection import ( + HashBuilderDict, + HashBuilderEmbeddedCBOR, + HashBuilderList, +) from ..helpers.paths import ( CERTIFICATE_PATH_NAME, CHANGE_OUTPUT_PATH_NAME, @@ -61,9 +66,22 @@ TX_BODY_KEY_SCRIPT_DATA_HASH = const(11) TX_BODY_KEY_COLLATERAL_INPUTS = const(13) TX_BODY_KEY_REQUIRED_SIGNERS = const(14) TX_BODY_KEY_NETWORK_ID = const(15) +TX_BODY_KEY_COLLATERAL_RETURN = const(16) +TX_BODY_KEY_TOTAL_COLLATERAL = const(17) +TX_BODY_KEY_REFERENCE_INPUTS = const(18) + +BABBAGE_OUTPUT_KEY_ADDRESS = const(0) +BABBAGE_OUTPUT_KEY_AMOUNT = const(1) +BABBAGE_OUTPUT_KEY_DATUM_OPTION = const(2) +BABBAGE_OUTPUT_KEY_REFERENCE_SCRIPT = const(3) + +DATUM_OPTION_KEY_HASH = const(0) +DATUM_OPTION_KEY_INLINE = const(1) POOL_REGISTRATION_CERTIFICATE_ITEMS_COUNT = 10 +MAX_CHUNK_SIZE = 1024 + class Signer: """ @@ -84,6 +102,11 @@ class Signer: self.msg = msg self.keychain = keychain + # Some data (e.g. output inline datum) are too long to verify manually. + # We track their presence and eventually display the tx hash when + # confirming the tx. + self.has_hidden_data = False + self.account_path_checker = AccountPathChecker() # Inputs, outputs and fee are mandatory, count the number of optional fields present. @@ -99,6 +122,9 @@ class Signer: msg.script_data_hash is not None, msg.collateral_inputs_count > 0, msg.required_signers_count > 0, + msg.has_collateral_return, + msg.total_collateral is not None, + msg.reference_inputs_count > 0, ) ) self.tx_dict: HashBuilderDict[int, Any] = HashBuilderDict( @@ -193,9 +219,27 @@ class Signer: if self.msg.include_network_id: self.tx_dict.add(TX_BODY_KEY_NETWORK_ID, self.msg.network_id) + if self.msg.has_collateral_return: + await self._process_collateral_return() + + if self.msg.total_collateral is not None: + self.tx_dict.add(TX_BODY_KEY_TOTAL_COLLATERAL, self.msg.total_collateral) + + if self.msg.reference_inputs_count > 0: + reference_inputs_list: HashBuilderList[tuple[bytes, int]] = HashBuilderList( + self.msg.reference_inputs_count + ) + with self.tx_dict.add(TX_BODY_KEY_REFERENCE_INPUTS, reference_inputs_list): + await self._process_reference_inputs(reference_inputs_list) + def _validate_tx_init(self) -> None: if self.msg.fee > LOVELACE_MAX_SUPPLY: raise wire.ProcessError("Fee is out of range!") + if ( + self.msg.total_collateral is not None + and self.msg.total_collateral > LOVELACE_MAX_SUPPLY + ): + raise wire.ProcessError("Total collateral is out of range!") validate_network_info(self.msg.network_id, self.msg.protocol_magic) async def _show_tx_init(self) -> None: @@ -206,6 +250,10 @@ class Signer: # Final signing confirmation is handled separately in each signing mode. raise NotImplementedError + def _should_show_tx_hash(self) -> bool: + # By default, we display tx hash only if some data wasn't shown. + return self.has_hidden_data + # inputs async def _process_inputs( @@ -235,43 +283,41 @@ class Signer: output: messages.CardanoTxOutput = await self.ctx.call( messages.CardanoTxItemAck(), messages.CardanoTxOutput ) - self._validate_output(output) - await self._show_output(output) - - output_address = self._get_output_address(output) - - has_datum_hash = output.datum_hash is not None - output_list: HashBuilderList = HashBuilderList(2 + int(has_datum_hash)) - with outputs_list.append(output_list): - output_list.append(output_address) - if output.asset_groups_count == 0: - # output structure is: [address, amount, datum_hash?] - output_list.append(output.amount) - else: - # output structure is: [address, [amount, asset_groups], datum_hash?] - output_value_list: HashBuilderList = HashBuilderList(2) - with output_list.append(output_value_list): - output_value_list.append(output.amount) - asset_groups_dict: HashBuilderDict[ - bytes, HashBuilderDict[bytes, int] - ] = HashBuilderDict( - output.asset_groups_count, - wire.ProcessError("Invalid token bundle in output"), - ) - with output_value_list.append(asset_groups_dict): - await self._process_asset_groups( - asset_groups_dict, - output.asset_groups_count, - self._should_show_output(output), - ) - if has_datum_hash: - output_list.append(output.datum_hash) + await self._process_output(outputs_list, output) total_amount += output.amount if total_amount > LOVELACE_MAX_SUPPLY: raise wire.ProcessError("Total transaction amount is out of range!") + async def _process_output( + self, outputs_list: HashBuilderList, output: messages.CardanoTxOutput + ) -> None: + self._validate_output(output) + should_show = self._should_show_output(output) + if should_show: + await self._show_output_init(output) + + output_items_count = 2 + sum( + ( + output.datum_hash is not None, + output.inline_datum_size > 0, + output.reference_script_size > 0, + ) + ) + if output.format == CardanoTxOutputSerializationFormat.ARRAY_LEGACY: + output_list: HashBuilderList = HashBuilderList(output_items_count) + with outputs_list.append(output_list): + await self._process_legacy_output(output_list, output, should_show) + elif output.format == CardanoTxOutputSerializationFormat.MAP_BABBAGE: + output_dict: HashBuilderDict[int, Any] = HashBuilderDict( + output_items_count, wire.ProcessError("Invalid output") + ) + with outputs_list.append(output_dict): + await self._process_babbage_output(output_dict, output, should_show) + else: + raise RuntimeError # should be unreachable + def _validate_output(self, output: messages.CardanoTxOutput) -> None: if output.address_parameters is not None and output.address is not None: raise wire.ProcessError("Invalid output") @@ -286,25 +332,35 @@ class Signer: else: raise wire.ProcessError("Invalid output") + # datum hash if output.datum_hash is not None: if len(output.datum_hash) != OUTPUT_DATUM_HASH_SIZE: raise wire.ProcessError("Invalid output datum hash") - self.account_path_checker.add_output(output) + # inline datum + if output.inline_datum_size > 0: + if output.format != CardanoTxOutputSerializationFormat.MAP_BABBAGE: + raise wire.ProcessError("Invalid output") - async def _show_output(self, output: messages.CardanoTxOutput) -> None: - if not self._should_show_output(output): - return + # datum hash and inline datum are mutually exclusive + if output.datum_hash is not None and output.inline_datum_size > 0: + raise wire.ProcessError("Invalid output") - if output.datum_hash is not None: - await layout.warn_tx_output_contains_datum_hash(self.ctx, output.datum_hash) + # reference script + if output.reference_script_size > 0: + if output.format != CardanoTxOutputSerializationFormat.MAP_BABBAGE: + raise wire.ProcessError("Invalid output") + self.account_path_checker.add_output(output) + + async def _show_output_init(self, output: messages.CardanoTxOutput) -> None: address_type = self._get_output_address_type(output) if ( output.datum_hash is None + and output.inline_datum_size == 0 and address_type in addresses.ADDRESS_TYPES_PAYMENT_SCRIPT ): - await layout.warn_tx_output_no_datum_hash(self.ctx) + await layout.warn_tx_output_no_datum(self.ctx) if output.asset_groups_count > 0: await layout.warn_tx_output_contains_tokens(self.ctx) @@ -325,7 +381,7 @@ class Signer: self.ctx, output.amount, address, - self._is_change_output(output), + "change" if self._is_change_output(output) else "address", self.msg.network_id, ) @@ -340,24 +396,28 @@ class Signer: def _should_show_output(self, output: messages.CardanoTxOutput) -> bool: """ - Determines whether the output should be shown. Extracted from _show_output because - of readability and because the same decision is made when displaying output tokens. + Determines whether the output should be shown. Extracted from _show_output + because of readability. """ - if output.datum_hash is not None: + if ( + output.datum_hash is not None + or output.inline_datum_size > 0 + or output.reference_script_size > 0 + ): return True address_type = self._get_output_address_type(output) if ( output.datum_hash is None + and output.inline_datum_size == 0 and address_type in addresses.ADDRESS_TYPES_PAYMENT_SCRIPT ): - # Plutus script address without a datum hash is unspendable, we must show a warning. + # Plutus script address without a datum is unspendable, we must show a warning. return True - if output.address_parameters is not None: # change output - if not should_show_credentials(output.address_parameters): - # We don't need to display simple address outputs. - return False + if self._is_simple_change_output(output): + # We don't need to display simple address outputs. + return False return True @@ -365,6 +425,111 @@ class Signer: """Used only to determine what message to show to the user when confirming sending.""" return output.address_parameters is not None + def _is_simple_change_output(self, output: messages.CardanoTxOutput) -> bool: + """Used to determine whether an output is a change output with ordinary credentials.""" + return output.address_parameters is not None and not should_show_credentials( + output.address_parameters + ) + + async def _process_legacy_output( + self, + output_list: HashBuilderList, + output: messages.CardanoTxOutput, + should_show: bool, + ) -> None: + address = self._get_output_address(output) + output_list.append(address) + + if output.asset_groups_count == 0: + # Output structure is: [address, amount, datum_hash?] + output_list.append(output.amount) + else: + # Output structure is: [address, [amount, asset_groups], datum_hash?] + output_value_list: HashBuilderList = HashBuilderList(2) + with output_list.append(output_value_list): + await self._process_output_value(output_value_list, output, should_show) + + if output.datum_hash is not None: + if should_show: + await layout.confirm_datum_hash(self.ctx, output.datum_hash) + output_list.append(output.datum_hash) + + async def _process_babbage_output( + self, + output_dict: HashBuilderDict[int, Any], + output: messages.CardanoTxOutput, + should_show: bool, + ) -> None: + """ + This output format corresponds to the post-Alonzo format in CDDL. + Note that it is to be used also for outputs with no Plutus elements. + """ + address = self._get_output_address(output) + output_dict.add(BABBAGE_OUTPUT_KEY_ADDRESS, address) + + if output.asset_groups_count == 0: + # Only amount is added to the dict. + output_dict.add(BABBAGE_OUTPUT_KEY_AMOUNT, output.amount) + else: + # [amount, asset_groups] is added to the dict. + output_value_list: HashBuilderList = HashBuilderList(2) + with output_dict.add(BABBAGE_OUTPUT_KEY_AMOUNT, output_value_list): + await self._process_output_value(output_value_list, output, should_show) + + if output.datum_hash is not None: + if should_show: + await layout.confirm_datum_hash(self.ctx, output.datum_hash) + output_dict.add( + BABBAGE_OUTPUT_KEY_DATUM_OPTION, + (DATUM_OPTION_KEY_HASH, output.datum_hash), + ) + elif output.inline_datum_size > 0: + inline_datum_list: HashBuilderList = HashBuilderList(2) + with output_dict.add(BABBAGE_OUTPUT_KEY_DATUM_OPTION, inline_datum_list): + inline_datum_list.append(DATUM_OPTION_KEY_INLINE) + inline_datum_cbor: HashBuilderEmbeddedCBOR = HashBuilderEmbeddedCBOR( + output.inline_datum_size + ) + with inline_datum_list.append(inline_datum_cbor): + await self._process_inline_datum( + inline_datum_cbor, output.inline_datum_size, should_show + ) + + if output.reference_script_size > 0: + reference_script_cbor: HashBuilderEmbeddedCBOR = HashBuilderEmbeddedCBOR( + output.reference_script_size + ) + with output_dict.add( + BABBAGE_OUTPUT_KEY_REFERENCE_SCRIPT, reference_script_cbor + ): + await self._process_reference_script( + reference_script_cbor, output.reference_script_size, should_show + ) + + async def _process_output_value( + self, + output_value_list: HashBuilderList, + output: messages.CardanoTxOutput, + should_show_tokens: bool, + ) -> None: + """Should be used only when the output contains tokens.""" + assert output.asset_groups_count > 0 + + output_value_list.append(output.amount) + + asset_groups_dict: HashBuilderDict[ + bytes, HashBuilderDict[bytes, int] + ] = HashBuilderDict( + output.asset_groups_count, + wire.ProcessError("Invalid token bundle in output"), + ) + with output_value_list.append(asset_groups_dict): + await self._process_asset_groups( + asset_groups_dict, + output.asset_groups_count, + should_show_tokens, + ) + # asset groups async def _process_asset_groups( @@ -444,6 +609,62 @@ class Signer: if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH: raise INVALID_TOKEN_BUNDLE + # inline datum + + async def _process_inline_datum( + self, + inline_datum_cbor: HashBuilderEmbeddedCBOR, + inline_datum_size: int, + should_show: bool, + ) -> None: + assert inline_datum_size > 0 + self.has_hidden_data = True + + chunks_count = self._get_chunks_count(inline_datum_size) + for chunk_number in range(chunks_count): + chunk: messages.CardanoTxInlineDatumChunk = await self.ctx.call( + messages.CardanoTxItemAck(), messages.CardanoTxInlineDatumChunk + ) + self._validate_chunk( + chunk.data, + chunk_number, + chunks_count, + wire.ProcessError("Invalid inline datum chunk"), + ) + if chunk_number == 0 and should_show: + await layout.confirm_inline_datum( + self.ctx, chunk.data, inline_datum_size + ) + inline_datum_cbor.add(chunk.data) + + # reference script + + async def _process_reference_script( + self, + reference_script_cbor: HashBuilderEmbeddedCBOR, + reference_script_size: int, + should_show: bool, + ) -> None: + assert reference_script_size > 0 + self.has_hidden_data = True + + chunks_count = self._get_chunks_count(reference_script_size) + for chunk_number in range(chunks_count): + chunk: messages.CardanoTxReferenceScriptChunk = await self.ctx.call( + messages.CardanoTxItemAck(), messages.CardanoTxReferenceScriptChunk + ) + self._validate_chunk( + chunk.data, + chunk_number, + chunks_count, + wire.ProcessError("Invalid reference script chunk"), + ) + if chunk_number == 0 and should_show: + await layout.confirm_reference_script( + self.ctx, chunk.data, reference_script_size + ) + reference_script_cbor.add(chunk.data) + # certificates async def _process_certificates(self, certificates_list: HashBuilderList) -> None: @@ -686,7 +907,7 @@ class Signer: messages.CardanoTxItemAck(), messages.CardanoTxCollateralInput ) self._validate_collateral_input(collateral_input) - await layout.confirm_collateral_input(self.ctx, collateral_input) + await self._show_collateral_input(collateral_input) collateral_inputs_list.append( (collateral_input.prev_hash, collateral_input.prev_index) ) @@ -697,6 +918,12 @@ class Signer: if len(collateral_input.prev_hash) != INPUT_PREV_HASH_SIZE: raise wire.ProcessError("Invalid collateral input") + async def _show_collateral_input( + self, collateral_input: messages.CardanoTxCollateralInput + ) -> None: + if self.msg.total_collateral is None: + await layout.confirm_collateral_input(self.ctx, collateral_input) + # required signers async def _process_required_signers( @@ -735,6 +962,123 @@ class Signer: else: raise INVALID_REQUIRED_SIGNER + # collateral return + + async def _process_collateral_return(self) -> None: + output: messages.CardanoTxOutput = await self.ctx.call( + messages.CardanoTxItemAck(), messages.CardanoTxOutput + ) + self._validate_collateral_return(output) + should_show_init = self._should_show_collateral_return_init(output) + should_show_tokens = self._should_show_collateral_return_tokens(output) + if should_show_init: + await self._show_collateral_return_init(output) + + # Datums and reference scripts are forbidden, see _validate_collateral_return. + output_items_count = 2 + if output.format == CardanoTxOutputSerializationFormat.ARRAY_LEGACY: + output_list: HashBuilderList = HashBuilderList(output_items_count) + with self.tx_dict.add(TX_BODY_KEY_COLLATERAL_RETURN, output_list): + await self._process_legacy_output( + output_list, output, should_show_tokens + ) + elif output.format == CardanoTxOutputSerializationFormat.MAP_BABBAGE: + output_dict: HashBuilderDict[int, Any] = HashBuilderDict( + output_items_count, wire.ProcessError("Invalid collateral return") + ) + with self.tx_dict.add(TX_BODY_KEY_COLLATERAL_RETURN, output_dict): + await self._process_babbage_output( + output_dict, output, should_show_tokens + ) + else: + raise RuntimeError # should be unreachable + + def _validate_collateral_return(self, output: messages.CardanoTxOutput) -> None: + self._validate_output(output) + + address_type = self._get_output_address_type(output) + if address_type not in addresses.ADDRESS_TYPES_PAYMENT_KEY: + raise wire.ProcessError("Invalid collateral return") + + if ( + output.datum_hash is not None + or output.inline_datum_size > 0 + or output.reference_script_size > 0 + ): + raise wire.ProcessError("Invalid collateral return") + + async def _show_collateral_return_init( + self, output: messages.CardanoTxOutput + ) -> None: + # We don't display missing datum warning since datums are forbidden. + + if output.asset_groups_count > 0: + await layout.warn_tx_output_contains_tokens( + self.ctx, is_collateral_return=True + ) + + if output.address_parameters is not None: + address = addresses.derive_human_readable( + self.keychain, + output.address_parameters, + self.msg.protocol_magic, + self.msg.network_id, + ) + await self._show_output_credentials( + output.address_parameters, + ) + else: + assert output.address is not None # _validate_output + address = output.address + + await layout.confirm_sending( + self.ctx, + output.amount, + address, + "collateral-return", + self.msg.network_id, + ) + + def _should_show_collateral_return_init( + self, output: messages.CardanoTxOutput + ) -> bool: + if self.msg.total_collateral is None: + return True + + if self._is_simple_change_output(output): + return False + + return True + + def _should_show_collateral_return_tokens( + self, output: messages.CardanoTxOutput + ) -> bool: + if self._is_simple_change_output(output): + return False + + return True + + # reference inputs + + async def _process_reference_inputs( + self, reference_inputs_list: HashBuilderList[tuple[bytes, int]] + ) -> None: + for _ in range(self.msg.reference_inputs_count): + reference_input: messages.CardanoTxReferenceInput = await self.ctx.call( + messages.CardanoTxItemAck(), messages.CardanoTxReferenceInput + ) + self._validate_reference_input(reference_input) + await layout.confirm_reference_input(self.ctx, reference_input) + reference_inputs_list.append( + (reference_input.prev_hash, reference_input.prev_index) + ) + + def _validate_reference_input( + self, reference_input: messages.CardanoTxReferenceInput + ) -> None: + if len(reference_input.prev_hash) != INPUT_PREV_HASH_SIZE: + raise wire.ProcessError("Invalid reference input") + # witness requests async def _process_witness_requests(self, tx_hash: bytes) -> CardanoTxResponseType: @@ -767,6 +1111,10 @@ class Signer: # helpers + def _assert_tx_init_cond(self, condition: bool) -> None: + if not condition: + raise wire.ProcessError("Invalid tx signing request") + def _is_network_id_verifiable(self) -> bool: """ Checks whether there is at least one element that contains information about @@ -823,6 +1171,22 @@ class Signer: self.msg.network_id, ) + def _get_chunks_count(self, data_size: int) -> int: + assert data_size > 0 + return (data_size - 1) // MAX_CHUNK_SIZE + 1 + + def _validate_chunk( + self, + chunk_data: bytes, + chunk_number: int, + chunks_count: int, + error: wire.ProcessError, + ) -> None: + if chunk_number < chunks_count - 1 and len(chunk_data) != MAX_CHUNK_SIZE: + raise error + if chunk_number == chunks_count - 1 and len(chunk_data) > MAX_CHUNK_SIZE: + raise error + def _get_byron_witness( self, path: list[int], tx_hash: bytes ) -> messages.CardanoTxWitnessResponse: diff --git a/python/src/trezorlib/cardano.py b/python/src/trezorlib/cardano.py index d49dcfc1e..7c37623d9 100644 --- a/python/src/trezorlib/cardano.py +++ b/python/src/trezorlib/cardano.py @@ -26,6 +26,8 @@ from typing import ( Optional, Sequence, Tuple, + Type, + TypeVar, Union, ) @@ -39,6 +41,8 @@ if TYPE_CHECKING: PROTOCOL_MAGICS = {"mainnet": 764824073, "testnet": 1097911063} NETWORK_IDS = {"mainnet": 1, "testnet": 0} +MAX_CHUNK_SIZE = 1024 + REQUIRED_FIELDS_TRANSACTION = ("inputs", "outputs") REQUIRED_FIELDS_INPUT = ("prev_hash", "prev_index") REQUIRED_FIELDS_CERTIFICATE = ("type",) @@ -67,9 +71,18 @@ INVALID_MINT_TOKEN_BUNDLE_ENTRY = "The mint token_bundle entry is invalid" InputWithPath = Tuple[messages.CardanoTxInput, List[int]] CollateralInputWithPath = Tuple[messages.CardanoTxCollateralInput, List[int]] AssetGroupWithTokens = Tuple[messages.CardanoAssetGroup, List[messages.CardanoToken]] -OutputWithAssetGroups = Tuple[messages.CardanoTxOutput, List[AssetGroupWithTokens]] +OutputWithData = Tuple[ + messages.CardanoTxOutput, + List[AssetGroupWithTokens], + List[messages.CardanoTxInlineDatumChunk], + List[messages.CardanoTxReferenceScriptChunk], +] OutputItem = Union[ - messages.CardanoTxOutput, messages.CardanoAssetGroup, messages.CardanoToken + messages.CardanoTxOutput, + messages.CardanoAssetGroup, + messages.CardanoToken, + messages.CardanoTxInlineDatumChunk, + messages.CardanoTxReferenceScriptChunk, ] CertificateItem = Union[ messages.CardanoTxCertificate, @@ -89,6 +102,12 @@ Path = List[int] Witness = Tuple[Path, bytes] AuxiliaryDataSupplement = Dict[str, Union[int, bytes]] SignTxResponse = Dict[str, Union[bytes, List[Witness], AuxiliaryDataSupplement]] +Chunk = TypeVar( + "Chunk", + bound=Union[ + messages.CardanoTxInlineDatumChunk, messages.CardanoTxReferenceScriptChunk + ], +) def parse_optional_bytes(value: Optional[str]) -> Optional[bytes]: @@ -158,7 +177,7 @@ def parse_input(tx_input: dict) -> InputWithPath: ) -def parse_output(output: dict) -> OutputWithAssetGroups: +def parse_output(output: dict) -> OutputWithData: contains_address = "address" in output contains_address_type = "addressType" in output @@ -181,6 +200,20 @@ def parse_output(output: dict) -> OutputWithAssetGroups: datum_hash = parse_optional_bytes(output.get("datum_hash")) + serialization_format = messages.CardanoTxOutputSerializationFormat.ARRAY_LEGACY + if "format" in output: + serialization_format = output["format"] + + inline_datum_size, inline_datum_chunks = _parse_chunkable_data( + parse_optional_bytes(output.get("inline_datum")), + messages.CardanoTxInlineDatumChunk, + ) + + reference_script_size, reference_script_chunks = _parse_chunkable_data( + parse_optional_bytes(output.get("reference_script")), + messages.CardanoTxReferenceScriptChunk, + ) + return ( messages.CardanoTxOutput( address=address, @@ -188,8 +221,13 @@ def parse_output(output: dict) -> OutputWithAssetGroups: amount=int(output["amount"]), asset_groups_count=len(token_bundle), datum_hash=datum_hash, + format=serialization_format, + inline_datum_size=inline_datum_size, + reference_script_size=reference_script_size, ), token_bundle, + inline_datum_chunks, + reference_script_chunks, ) @@ -287,6 +325,23 @@ def _parse_address_parameters( ) +def _parse_chunkable_data( + data: Optional[bytes], chunk_type: Type[Chunk] +) -> Tuple[int, List[Chunk]]: + if data is None: + return 0, [] + data_size = len(data) + data_chunks = [chunk_type(data=chunk) for chunk in _create_data_chunks(data)] + return data_size, data_chunks + + +def _create_data_chunks(data: bytes) -> Iterator[bytes]: + processed_size = 0 + while processed_size < len(data): + yield data[processed_size : (processed_size + MAX_CHUNK_SIZE)] + processed_size += MAX_CHUNK_SIZE + + def parse_native_script(native_script: dict) -> messages.CardanoNativeScript: if "type" not in native_script: raise ValueError("Script is missing some fields") @@ -565,6 +620,16 @@ def parse_required_signer(required_signer: dict) -> messages.CardanoTxRequiredSi ) +def parse_reference_input(reference_input: dict) -> messages.CardanoTxReferenceInput: + if not all(k in reference_input for k in REQUIRED_FIELDS_INPUT): + raise ValueError("The reference input is missing some fields") + + return messages.CardanoTxReferenceInput( + prev_hash=bytes.fromhex(reference_input["prev_hash"]), + prev_index=reference_input["prev_index"], + ) + + def parse_additional_witness_request( additional_witness_request: dict, ) -> Path: @@ -630,20 +695,32 @@ def _get_witness_requests( return [messages.CardanoTxWitnessRequest(path=path) for path in sorted_paths] -def _get_input_items(inputs: List[InputWithPath]) -> Iterator[messages.CardanoTxInput]: +def _get_inputs_items(inputs: List[InputWithPath]) -> Iterator[messages.CardanoTxInput]: for input, _ in inputs: yield input -def _get_output_items(outputs: List[OutputWithAssetGroups]) -> Iterator[OutputItem]: - for output, asset_groups in outputs: - yield output - for asset_group, tokens in asset_groups: - yield asset_group - yield from tokens +def _get_outputs_items(outputs: List[OutputWithData]) -> Iterator[OutputItem]: + for output_with_data in outputs: + yield from _get_output_items(output_with_data) + + +def _get_output_items(output_with_data: OutputWithData) -> Iterator[OutputItem]: + ( + output, + asset_groups, + inline_datum_chunks, + reference_script_chunks, + ) = output_with_data + yield output + for asset_group, tokens in asset_groups: + yield asset_group + yield from tokens + yield from inline_datum_chunks + yield from reference_script_chunks -def _get_certificate_items( +def _get_certificates_items( certificates: Sequence[CertificateWithPoolOwnersAndRelays], ) -> Iterator[CertificateItem]: for certificate, pool_owners_and_relays in certificates: @@ -663,7 +740,7 @@ def _get_mint_items(mint: Sequence[AssetGroupWithTokens]) -> Iterator[MintItem]: yield from tokens -def _get_collateral_input_items( +def _get_collateral_inputs_items( collateral_inputs: Sequence[CollateralInputWithPath], ) -> Iterator[messages.CardanoTxCollateralInput]: for collateral_input, _ in collateral_inputs: @@ -726,7 +803,7 @@ def sign_tx( client: "TrezorClient", signing_mode: messages.CardanoTxSigningMode, inputs: List[InputWithPath], - outputs: List[OutputWithAssetGroups], + outputs: List[OutputWithData], fee: int, ttl: Optional[int], validity_interval_start: Optional[int], @@ -739,6 +816,9 @@ def sign_tx( script_data_hash: Optional[bytes] = None, collateral_inputs: Sequence[CollateralInputWithPath] = (), required_signers: Sequence[messages.CardanoTxRequiredSigner] = (), + collateral_return: Optional[OutputWithData] = None, + total_collateral: Optional[int] = None, + reference_inputs: Sequence[messages.CardanoTxReferenceInput] = (), additional_witness_requests: Sequence[Path] = (), derivation_type: messages.CardanoDerivationType = messages.CardanoDerivationType.ICARUS, include_network_id: bool = False, @@ -772,6 +852,9 @@ def sign_tx( script_data_hash=script_data_hash, collateral_inputs_count=len(collateral_inputs), required_signers_count=len(required_signers), + has_collateral_return=collateral_return is not None, + total_collateral=total_collateral, + reference_inputs_count=len(reference_inputs), witness_requests_count=len(witness_requests), derivation_type=derivation_type, include_network_id=include_network_id, @@ -781,9 +864,9 @@ def sign_tx( raise UNEXPECTED_RESPONSE_ERROR for tx_item in chain( - _get_input_items(inputs), - _get_output_items(outputs), - _get_certificate_items(certificates), + _get_inputs_items(inputs), + _get_outputs_items(outputs), + _get_certificates_items(certificates), withdrawals, ): response = client.call(tx_item) @@ -812,13 +895,24 @@ def sign_tx( for tx_item in chain( _get_mint_items(mint), - _get_collateral_input_items(collateral_inputs), + _get_collateral_inputs_items(collateral_inputs), required_signers, ): response = client.call(tx_item) if not isinstance(response, messages.CardanoTxItemAck): raise UNEXPECTED_RESPONSE_ERROR + if collateral_return is not None: + for tx_item in _get_output_items(collateral_return): + response = client.call(tx_item) + if not isinstance(response, messages.CardanoTxItemAck): + raise UNEXPECTED_RESPONSE_ERROR + + for reference_input in reference_inputs: + response = client.call(reference_input) + if not isinstance(response, messages.CardanoTxItemAck): + raise UNEXPECTED_RESPONSE_ERROR + sign_tx_response["witnesses"] = [] for witness_request in witness_requests: response = client.call(witness_request) diff --git a/python/src/trezorlib/cli/cardano.py b/python/src/trezorlib/cli/cardano.py index b7d01c803..29d764590 100644 --- a/python/src/trezorlib/cli/cardano.py +++ b/python/src/trezorlib/cli/cardano.py @@ -98,6 +98,16 @@ def sign_tx( cardano.parse_required_signer(required_signer) for required_signer in transaction.get("required_signers", ()) ] + collateral_return = ( + cardano.parse_output(transaction["collateral_return"]) + if transaction.get("collateral_return") + else None + ) + total_collateral = transaction.get("total_collateral") + reference_inputs = [ + cardano.parse_reference_input(reference_input) + for reference_input in transaction.get("reference_inputs", ()) + ] additional_witness_requests = [ cardano.parse_additional_witness_request(p) for p in transaction["additional_witness_requests"] @@ -121,6 +131,9 @@ def sign_tx( script_data_hash, collateral_inputs, required_signers, + collateral_return, + total_collateral, + reference_inputs, additional_witness_requests, derivation_type=derivation_type, include_network_id=include_network_id, diff --git a/tests/device_tests/cardano/test_sign_tx.py b/tests/device_tests/cardano/test_sign_tx.py index d71c393d1..498f170be 100644 --- a/tests/device_tests/cardano/test_sign_tx.py +++ b/tests/device_tests/cardano/test_sign_tx.py @@ -53,6 +53,14 @@ def test_cardano_sign_tx(client: Client, parameters, result): required_signers = [ cardano.parse_required_signer(s) for s in parameters["required_signers"] ] + collateral_return = ( + cardano.parse_output(parameters["collateral_return"]) + if parameters["collateral_return"] is not None + else None + ) + reference_inputs = [ + cardano.parse_reference_input(i) for i in parameters["reference_inputs"] + ] additional_witness_requests = [ cardano.parse_additional_witness_request(p) for p in parameters["additional_witness_requests"] @@ -72,8 +80,8 @@ def test_cardano_sign_tx(client: Client, parameters, result): inputs=inputs, outputs=outputs, fee=parameters["fee"], - ttl=parameters.get("ttl"), - validity_interval_start=parameters.get("validity_interval_start"), + ttl=parameters["ttl"], + validity_interval_start=parameters["validity_interval_start"], certificates=certificates, withdrawals=withdrawals, protocol_magic=parameters["protocol_magic"], @@ -83,6 +91,9 @@ def test_cardano_sign_tx(client: Client, parameters, result): script_data_hash=script_data_hash, collateral_inputs=collateral_inputs, required_signers=required_signers, + collateral_return=collateral_return, + total_collateral=parameters["total_collateral"], + reference_inputs=reference_inputs, additional_witness_requests=additional_witness_requests, include_network_id=parameters["include_network_id"], ) @@ -112,6 +123,14 @@ def test_cardano_sign_tx_failed(client: Client, parameters, result): required_signers = [ cardano.parse_required_signer(s) for s in parameters["required_signers"] ] + collateral_return = ( + cardano.parse_output(parameters["collateral_return"]) + if parameters["collateral_return"] is not None + else None + ) + reference_inputs = [ + cardano.parse_reference_input(i) for i in parameters["reference_inputs"] + ] additional_witness_requests = [ cardano.parse_additional_witness_request(p) for p in parameters["additional_witness_requests"] @@ -132,8 +151,8 @@ def test_cardano_sign_tx_failed(client: Client, parameters, result): inputs=inputs, outputs=outputs, fee=parameters["fee"], - ttl=parameters.get("ttl"), - validity_interval_start=parameters.get("validity_interval_start"), + ttl=parameters["ttl"], + validity_interval_start=parameters["validity_interval_start"], certificates=certificates, withdrawals=withdrawals, protocol_magic=parameters["protocol_magic"], @@ -143,6 +162,9 @@ def test_cardano_sign_tx_failed(client: Client, parameters, result): script_data_hash=script_data_hash, collateral_inputs=collateral_inputs, required_signers=required_signers, + collateral_return=collateral_return, + total_collateral=parameters["total_collateral"], + reference_inputs=reference_inputs, additional_witness_requests=additional_witness_requests, include_network_id=parameters["include_network_id"], )