1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-20 23:18:46 +00:00

refator(cardano): validate map key order in HashBuilderDict

This commit is contained in:
David Misiak 2022-03-08 14:48:10 +01:00 committed by matejcik
parent a36fc6cadc
commit fec4fa2257
3 changed files with 50 additions and 54 deletions

View File

@ -4,6 +4,7 @@ from apps.common import cbor
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
from trezor import wire
from trezor.utils import HashContext from trezor.utils import HashContext
T = TypeVar("T") T = TypeVar("T")
@ -43,7 +44,13 @@ class HashBuilderCollection:
self.remaining -= 1 self.remaining -= 1
def _hash_item(self, item: Any) -> None: def _hash_item(self, item: Any) -> bytes:
assert self.hash_fn is not None
encoded_item = cbor.encode(item)
self.hash_fn.update(encoded_item)
return encoded_item
def _hash_item_streamed(self, item: Any) -> None:
assert self.hash_fn is not None assert self.hash_fn is not None
for chunk in cbor.encode_streamed(item): for chunk in cbor.encode_streamed(item):
self.hash_fn.update(chunk) self.hash_fn.update(chunk)
@ -74,7 +81,7 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
if isinstance(item, HashBuilderCollection): if isinstance(item, HashBuilderCollection):
self._insert_child(item) self._insert_child(item)
else: else:
self._hash_item(item) self._hash_item_streamed(item)
return item return item
@ -83,16 +90,31 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
class HashBuilderDict(HashBuilderCollection, Generic[K, V]): class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
key_order_error: wire.ProcessError
previous_encoded_key: bytes
def __init__(self, size: int, key_order_error: wire.ProcessError):
super().__init__(size)
self.key_order_error = key_order_error
self.previous_encoded_key = b""
def add(self, key: K, value: V) -> V: def add(self, key: K, value: V) -> V:
self._do_enter_item() self._do_enter_item()
# enter key, this must not nest # enter key, this must not nest
assert not isinstance(key, HashBuilderCollection) assert not isinstance(key, HashBuilderCollection)
self._hash_item(key) encoded_key = self._hash_item(key)
# check key ordering
if not cbor.precedes(self.previous_encoded_key, encoded_key):
raise self.key_order_error
self.previous_encoded_key = encoded_key
# enter value, this can nest # enter value, this can nest
if isinstance(value, HashBuilderCollection): if isinstance(value, HashBuilderCollection):
self._insert_child(value) self._insert_child(value)
else: else:
self._hash_item(value) self._hash_item_streamed(value)
return value return value

View File

@ -185,7 +185,9 @@ async def sign_tx(
account_path_checker = AccountPathChecker() account_path_checker = AccountPathChecker()
hash_fn = hashlib.blake2b(outlen=32) hash_fn = hashlib.blake2b(outlen=32)
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(tx_body_map_item_count) tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(
tx_body_map_item_count, INVALID_TX_SIGNING_REQUEST
)
tx_dict.start(hash_fn) tx_dict.start(hash_fn)
with tx_dict: with tx_dict:
await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker) await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker)
@ -296,7 +298,7 @@ async def _process_transaction(
if msg.withdrawals_count > 0: if msg.withdrawals_count > 0:
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict( withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
msg.withdrawals_count msg.withdrawals_count, INVALID_WITHDRAWAL
) )
with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict): with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
await _process_withdrawals( await _process_withdrawals(
@ -324,7 +326,7 @@ async def _process_transaction(
if msg.minting_asset_groups_count > 0: if msg.minting_asset_groups_count > 0:
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict( minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
msg.minting_asset_groups_count msg.minting_asset_groups_count, INVALID_TOKEN_BUNDLE_MINT
) )
with tx_dict.add(TX_BODY_KEY_MINT, minting_dict): with tx_dict.add(TX_BODY_KEY_MINT, minting_dict):
await _process_minting(ctx, minting_dict) await _process_minting(ctx, minting_dict)
@ -468,7 +470,9 @@ async def _process_outputs(
output_value_list.append(output.amount) output_value_list.append(output.amount)
asset_groups_dict: HashBuilderDict[ asset_groups_dict: HashBuilderDict[
bytes, HashBuilderDict[bytes, int] bytes, HashBuilderDict[bytes, int]
] = HashBuilderDict(output.asset_groups_count) ] = HashBuilderDict(
output.asset_groups_count, INVALID_TOKEN_BUNDLE_OUTPUT
)
with output_value_list.append(asset_groups_dict): with output_value_list.append(asset_groups_dict):
await _process_asset_groups( await _process_asset_groups(
ctx, ctx,
@ -492,15 +496,15 @@ async def _process_asset_groups(
should_show_tokens: bool, should_show_tokens: bool,
) -> None: ) -> None:
"""Read, validate and serialize the asset groups of an output.""" """Read, validate and serialize the asset groups of an output."""
previous_policy_id: bytes = b""
for _ in range(asset_groups_count): for _ in range(asset_groups_count):
asset_group: CardanoAssetGroup = await ctx.call( asset_group: CardanoAssetGroup = await ctx.call(
CardanoTxItemAck(), CardanoAssetGroup CardanoTxItemAck(), CardanoAssetGroup
) )
_validate_asset_group(asset_group, previous_policy_id) _validate_asset_group(asset_group)
previous_policy_id = asset_group.policy_id
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count) tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_OUTPUT
)
with asset_groups_dict.add(asset_group.policy_id, tokens): with asset_groups_dict.add(asset_group.policy_id, tokens):
await _process_tokens( await _process_tokens(
ctx, ctx,
@ -519,11 +523,9 @@ async def _process_tokens(
should_show_tokens: bool, should_show_tokens: bool,
) -> None: ) -> None:
"""Read, validate, confirm and serialize the tokens of an asset group.""" """Read, validate, confirm and serialize the tokens of an asset group."""
previous_asset_name_bytes: bytes = b""
for _ in range(tokens_count): for _ in range(tokens_count):
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken) token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
_validate_token(token, previous_asset_name_bytes) _validate_token(token)
previous_asset_name_bytes = token.asset_name_bytes
if should_show_tokens: if should_show_tokens:
await confirm_sending_token(ctx, policy_id, token) await confirm_sending_token(ctx, policy_id, token)
@ -641,24 +643,18 @@ async def _process_withdrawals(
if withdrawals_count == 0: if withdrawals_count == 0:
return return
previous_reward_address_bytes: bytes = b""
for _ in range(withdrawals_count): for _ in range(withdrawals_count):
withdrawal: CardanoTxWithdrawal = await ctx.call( withdrawal: CardanoTxWithdrawal = await ctx.call(
CardanoTxItemAck(), CardanoTxWithdrawal CardanoTxItemAck(), CardanoTxWithdrawal
) )
_validate_withdrawal( _validate_withdrawal(
keychain,
withdrawal, withdrawal,
signing_mode, signing_mode,
protocol_magic,
network_id,
account_path_checker, account_path_checker,
previous_reward_address_bytes,
) )
reward_address_bytes = _derive_withdrawal_reward_address_bytes( reward_address_bytes = _derive_withdrawal_reward_address_bytes(
keychain, withdrawal, protocol_magic, network_id keychain, withdrawal, protocol_magic, network_id
) )
previous_reward_address_bytes = reward_address_bytes
await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id) await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id)
@ -707,15 +703,15 @@ async def _process_minting(
await show_warning_tx_contains_mint(ctx) await show_warning_tx_contains_mint(ctx)
previous_policy_id: bytes = b""
for _ in range(token_minting.asset_groups_count): for _ in range(token_minting.asset_groups_count):
asset_group: CardanoAssetGroup = await ctx.call( asset_group: CardanoAssetGroup = await ctx.call(
CardanoTxItemAck(), CardanoAssetGroup CardanoTxItemAck(), CardanoAssetGroup
) )
_validate_asset_group(asset_group, previous_policy_id, is_mint=True) _validate_asset_group(asset_group, is_mint=True)
previous_policy_id = asset_group.policy_id
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count) tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_MINT
)
with minting_dict.add(asset_group.policy_id, tokens): with minting_dict.add(asset_group.policy_id, tokens):
await _process_minting_tokens( await _process_minting_tokens(
ctx, ctx,
@ -732,11 +728,9 @@ async def _process_minting_tokens(
tokens_count: int, tokens_count: int,
) -> None: ) -> None:
"""Read, validate, confirm and serialize the tokens of an asset group.""" """Read, validate, confirm and serialize the tokens of an asset group."""
previous_asset_name_bytes: bytes = b""
for _ in range(tokens_count): for _ in range(tokens_count):
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken) token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
_validate_token(token, previous_asset_name_bytes, is_mint=True) _validate_token(token, is_mint=True)
previous_asset_name_bytes = token.asset_name_bytes
await confirm_token_minting(ctx, policy_id, token) await confirm_token_minting(ctx, policy_id, token)
assert token.mint_amount is not None # _validate_token assert token.mint_amount is not None # _validate_token
@ -1005,7 +999,7 @@ async def _show_output(
def _validate_asset_group( def _validate_asset_group(
asset_group: CardanoAssetGroup, previous_policy_id: bytes, is_mint: bool = False asset_group: CardanoAssetGroup, is_mint: bool = False
) -> None: ) -> None:
INVALID_TOKEN_BUNDLE = ( INVALID_TOKEN_BUNDLE = (
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
@ -1015,13 +1009,9 @@ def _validate_asset_group(
raise INVALID_TOKEN_BUNDLE raise INVALID_TOKEN_BUNDLE
if asset_group.tokens_count == 0: if asset_group.tokens_count == 0:
raise INVALID_TOKEN_BUNDLE raise INVALID_TOKEN_BUNDLE
if not cbor.are_canonically_ordered(previous_policy_id, asset_group.policy_id):
raise INVALID_TOKEN_BUNDLE
def _validate_token( def _validate_token(token: CardanoToken, is_mint: bool = False) -> None:
token: CardanoToken, previous_asset_name_bytes: bytes, is_mint: bool = False
) -> None:
INVALID_TOKEN_BUNDLE = ( INVALID_TOKEN_BUNDLE = (
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
) )
@ -1035,10 +1025,6 @@ def _validate_token(
if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH: if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH:
raise INVALID_TOKEN_BUNDLE raise INVALID_TOKEN_BUNDLE
if not cbor.are_canonically_ordered(
previous_asset_name_bytes, token.asset_name_bytes
):
raise INVALID_TOKEN_BUNDLE
async def _show_certificate( async def _show_certificate(
@ -1065,13 +1051,9 @@ async def _show_certificate(
def _validate_withdrawal( def _validate_withdrawal(
keychain: seed.Keychain,
withdrawal: CardanoTxWithdrawal, withdrawal: CardanoTxWithdrawal,
signing_mode: CardanoTxSigningMode, signing_mode: CardanoTxSigningMode,
protocol_magic: int,
network_id: int,
account_path_checker: AccountPathChecker, account_path_checker: AccountPathChecker,
previous_reward_address_bytes: bytes,
) -> None: ) -> None:
validate_stake_credential( validate_stake_credential(
withdrawal.path, withdrawal.path,
@ -1084,14 +1066,6 @@ def _validate_withdrawal(
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY: if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
raise INVALID_WITHDRAWAL raise INVALID_WITHDRAWAL
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
keychain, withdrawal, protocol_magic, network_id
)
if not cbor.are_canonically_ordered(
previous_reward_address_bytes, reward_address_bytes
):
raise INVALID_WITHDRAWAL
account_path_checker.add_withdrawal(withdrawal) account_path_checker.add_withdrawal(withdrawal)

View File

@ -320,11 +320,11 @@ def create_map_header(size: int) -> bytes:
return _header(_CBOR_MAP, size) return _header(_CBOR_MAP, size)
def are_canonically_ordered(previous: Value, current: Value) -> bool: def precedes(prev: bytes, curr: bytes) -> bool:
""" """
Returns True if `previous` is smaller than `current` with regards to Returns True if `prev` is smaller than `curr` with regards to
the cbor map key ordering as defined in the cbor map key ordering as defined in
https://datatracker.ietf.org/doc/html/rfc7049#section-3.9 https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
Note that `prev` and `curr` must already be cbor-encoded.
""" """
u, v = encode(previous), encode(current) return len(prev) < len(curr) or (len(prev) == len(curr) and prev < curr)
return len(u) < len(v) or (len(u) == len(v) and u < v)