diff --git a/core/src/apps/cardano/helpers/hash_builder_collection.py b/core/src/apps/cardano/helpers/hash_builder_collection.py index a30122aca..33998cd72 100644 --- a/core/src/apps/cardano/helpers/hash_builder_collection.py +++ b/core/src/apps/cardano/helpers/hash_builder_collection.py @@ -4,6 +4,7 @@ from apps.common import cbor if TYPE_CHECKING: from typing import Any, Generic, TypeVar + from trezor import wire from trezor.utils import HashContext T = TypeVar("T") @@ -43,7 +44,13 @@ class HashBuilderCollection: self.remaining -= 1 - def _hash_item(self, item: Any) -> None: + def _hash_item(self, item: Any) -> bytes: + assert self.hash_fn is not None + encoded_item = cbor.encode(item) + self.hash_fn.update(encoded_item) + return encoded_item + + def _hash_item_streamed(self, item: Any) -> None: assert self.hash_fn is not None for chunk in cbor.encode_streamed(item): self.hash_fn.update(chunk) @@ -74,7 +81,7 @@ class HashBuilderList(HashBuilderCollection, Generic[T]): if isinstance(item, HashBuilderCollection): self._insert_child(item) else: - self._hash_item(item) + self._hash_item_streamed(item) return item @@ -83,16 +90,31 @@ class HashBuilderList(HashBuilderCollection, Generic[T]): class HashBuilderDict(HashBuilderCollection, Generic[K, V]): + key_order_error: wire.ProcessError + previous_encoded_key: bytes + + def __init__(self, size: int, key_order_error: wire.ProcessError): + super().__init__(size) + self.key_order_error = key_order_error + self.previous_encoded_key = b"" + def add(self, key: K, value: V) -> V: self._do_enter_item() + # enter key, this must not nest assert not isinstance(key, HashBuilderCollection) - self._hash_item(key) + encoded_key = self._hash_item(key) + + # check key ordering + if not cbor.precedes(self.previous_encoded_key, encoded_key): + raise self.key_order_error + self.previous_encoded_key = encoded_key + # enter value, this can nest if isinstance(value, HashBuilderCollection): self._insert_child(value) else: - self._hash_item(value) + self._hash_item_streamed(value) return value diff --git a/core/src/apps/cardano/sign_tx.py b/core/src/apps/cardano/sign_tx.py index 643eaa9da..c3ae8324c 100644 --- a/core/src/apps/cardano/sign_tx.py +++ b/core/src/apps/cardano/sign_tx.py @@ -185,7 +185,9 @@ async def sign_tx( account_path_checker = AccountPathChecker() hash_fn = hashlib.blake2b(outlen=32) - tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(tx_body_map_item_count) + tx_dict: HashBuilderDict[int, Any] = HashBuilderDict( + tx_body_map_item_count, INVALID_TX_SIGNING_REQUEST + ) tx_dict.start(hash_fn) with tx_dict: await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker) @@ -296,7 +298,7 @@ async def _process_transaction( if msg.withdrawals_count > 0: withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict( - msg.withdrawals_count + msg.withdrawals_count, INVALID_WITHDRAWAL ) with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict): await _process_withdrawals( @@ -324,7 +326,7 @@ async def _process_transaction( if msg.minting_asset_groups_count > 0: minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict( - msg.minting_asset_groups_count + msg.minting_asset_groups_count, INVALID_TOKEN_BUNDLE_MINT ) with tx_dict.add(TX_BODY_KEY_MINT, minting_dict): await _process_minting(ctx, minting_dict) @@ -468,7 +470,9 @@ async def _process_outputs( output_value_list.append(output.amount) asset_groups_dict: HashBuilderDict[ bytes, HashBuilderDict[bytes, int] - ] = HashBuilderDict(output.asset_groups_count) + ] = HashBuilderDict( + output.asset_groups_count, INVALID_TOKEN_BUNDLE_OUTPUT + ) with output_value_list.append(asset_groups_dict): await _process_asset_groups( ctx, @@ -492,15 +496,15 @@ async def _process_asset_groups( should_show_tokens: bool, ) -> None: """Read, validate and serialize the asset groups of an output.""" - previous_policy_id: bytes = b"" for _ in range(asset_groups_count): asset_group: CardanoAssetGroup = await ctx.call( CardanoTxItemAck(), CardanoAssetGroup ) - _validate_asset_group(asset_group, previous_policy_id) - previous_policy_id = asset_group.policy_id + _validate_asset_group(asset_group) - tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count) + tokens: HashBuilderDict[bytes, int] = HashBuilderDict( + asset_group.tokens_count, INVALID_TOKEN_BUNDLE_OUTPUT + ) with asset_groups_dict.add(asset_group.policy_id, tokens): await _process_tokens( ctx, @@ -519,11 +523,9 @@ async def _process_tokens( should_show_tokens: bool, ) -> None: """Read, validate, confirm and serialize the tokens of an asset group.""" - previous_asset_name_bytes: bytes = b"" for _ in range(tokens_count): token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken) - _validate_token(token, previous_asset_name_bytes) - previous_asset_name_bytes = token.asset_name_bytes + _validate_token(token) if should_show_tokens: await confirm_sending_token(ctx, policy_id, token) @@ -641,24 +643,18 @@ async def _process_withdrawals( if withdrawals_count == 0: return - previous_reward_address_bytes: bytes = b"" for _ in range(withdrawals_count): withdrawal: CardanoTxWithdrawal = await ctx.call( CardanoTxItemAck(), CardanoTxWithdrawal ) _validate_withdrawal( - keychain, withdrawal, signing_mode, - protocol_magic, - network_id, account_path_checker, - previous_reward_address_bytes, ) reward_address_bytes = _derive_withdrawal_reward_address_bytes( keychain, withdrawal, protocol_magic, network_id ) - previous_reward_address_bytes = reward_address_bytes await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id) @@ -707,15 +703,15 @@ async def _process_minting( await show_warning_tx_contains_mint(ctx) - previous_policy_id: bytes = b"" for _ in range(token_minting.asset_groups_count): asset_group: CardanoAssetGroup = await ctx.call( CardanoTxItemAck(), CardanoAssetGroup ) - _validate_asset_group(asset_group, previous_policy_id, is_mint=True) - previous_policy_id = asset_group.policy_id + _validate_asset_group(asset_group, is_mint=True) - tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count) + tokens: HashBuilderDict[bytes, int] = HashBuilderDict( + asset_group.tokens_count, INVALID_TOKEN_BUNDLE_MINT + ) with minting_dict.add(asset_group.policy_id, tokens): await _process_minting_tokens( ctx, @@ -732,11 +728,9 @@ async def _process_minting_tokens( tokens_count: int, ) -> None: """Read, validate, confirm and serialize the tokens of an asset group.""" - previous_asset_name_bytes: bytes = b"" for _ in range(tokens_count): token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken) - _validate_token(token, previous_asset_name_bytes, is_mint=True) - previous_asset_name_bytes = token.asset_name_bytes + _validate_token(token, is_mint=True) await confirm_token_minting(ctx, policy_id, token) assert token.mint_amount is not None # _validate_token @@ -1005,7 +999,7 @@ async def _show_output( def _validate_asset_group( - asset_group: CardanoAssetGroup, previous_policy_id: bytes, is_mint: bool = False + asset_group: CardanoAssetGroup, is_mint: bool = False ) -> None: INVALID_TOKEN_BUNDLE = ( INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT @@ -1015,13 +1009,9 @@ def _validate_asset_group( raise INVALID_TOKEN_BUNDLE if asset_group.tokens_count == 0: raise INVALID_TOKEN_BUNDLE - if not cbor.are_canonically_ordered(previous_policy_id, asset_group.policy_id): - raise INVALID_TOKEN_BUNDLE -def _validate_token( - token: CardanoToken, previous_asset_name_bytes: bytes, is_mint: bool = False -) -> None: +def _validate_token(token: CardanoToken, is_mint: bool = False) -> None: INVALID_TOKEN_BUNDLE = ( INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT ) @@ -1035,10 +1025,6 @@ def _validate_token( if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH: raise INVALID_TOKEN_BUNDLE - if not cbor.are_canonically_ordered( - previous_asset_name_bytes, token.asset_name_bytes - ): - raise INVALID_TOKEN_BUNDLE async def _show_certificate( @@ -1065,13 +1051,9 @@ async def _show_certificate( def _validate_withdrawal( - keychain: seed.Keychain, withdrawal: CardanoTxWithdrawal, signing_mode: CardanoTxSigningMode, - protocol_magic: int, - network_id: int, account_path_checker: AccountPathChecker, - previous_reward_address_bytes: bytes, ) -> None: validate_stake_credential( withdrawal.path, @@ -1084,14 +1066,6 @@ def _validate_withdrawal( if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY: raise INVALID_WITHDRAWAL - reward_address_bytes = _derive_withdrawal_reward_address_bytes( - keychain, withdrawal, protocol_magic, network_id - ) - if not cbor.are_canonically_ordered( - previous_reward_address_bytes, reward_address_bytes - ): - raise INVALID_WITHDRAWAL - account_path_checker.add_withdrawal(withdrawal) diff --git a/core/src/apps/common/cbor.py b/core/src/apps/common/cbor.py index 9ea69f6f3..042b9e593 100644 --- a/core/src/apps/common/cbor.py +++ b/core/src/apps/common/cbor.py @@ -320,11 +320,11 @@ def create_map_header(size: int) -> bytes: return _header(_CBOR_MAP, size) -def are_canonically_ordered(previous: Value, current: Value) -> bool: +def precedes(prev: bytes, curr: bytes) -> bool: """ - Returns True if `previous` is smaller than `current` with regards to + Returns True if `prev` is smaller than `curr` with regards to the cbor map key ordering as defined in https://datatracker.ietf.org/doc/html/rfc7049#section-3.9 + Note that `prev` and `curr` must already be cbor-encoded. """ - u, v = encode(previous), encode(current) - return len(u) < len(v) or (len(u) == len(v) and u < v) + return len(prev) < len(curr) or (len(prev) == len(curr) and prev < curr)