mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-20 14:39:22 +00:00
refator(cardano): validate map key order in HashBuilderDict
This commit is contained in:
parent
a36fc6cadc
commit
fec4fa2257
@ -4,6 +4,7 @@ from apps.common import cbor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any, Generic, TypeVar
|
||||
from trezor import wire
|
||||
from trezor.utils import HashContext
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -43,7 +44,13 @@ class HashBuilderCollection:
|
||||
|
||||
self.remaining -= 1
|
||||
|
||||
def _hash_item(self, item: Any) -> None:
|
||||
def _hash_item(self, item: Any) -> bytes:
|
||||
assert self.hash_fn is not None
|
||||
encoded_item = cbor.encode(item)
|
||||
self.hash_fn.update(encoded_item)
|
||||
return encoded_item
|
||||
|
||||
def _hash_item_streamed(self, item: Any) -> None:
|
||||
assert self.hash_fn is not None
|
||||
for chunk in cbor.encode_streamed(item):
|
||||
self.hash_fn.update(chunk)
|
||||
@ -74,7 +81,7 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
|
||||
if isinstance(item, HashBuilderCollection):
|
||||
self._insert_child(item)
|
||||
else:
|
||||
self._hash_item(item)
|
||||
self._hash_item_streamed(item)
|
||||
|
||||
return item
|
||||
|
||||
@ -83,16 +90,31 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
|
||||
|
||||
|
||||
class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
|
||||
key_order_error: wire.ProcessError
|
||||
previous_encoded_key: bytes
|
||||
|
||||
def __init__(self, size: int, key_order_error: wire.ProcessError):
|
||||
super().__init__(size)
|
||||
self.key_order_error = key_order_error
|
||||
self.previous_encoded_key = b""
|
||||
|
||||
def add(self, key: K, value: V) -> V:
|
||||
self._do_enter_item()
|
||||
|
||||
# enter key, this must not nest
|
||||
assert not isinstance(key, HashBuilderCollection)
|
||||
self._hash_item(key)
|
||||
encoded_key = self._hash_item(key)
|
||||
|
||||
# check key ordering
|
||||
if not cbor.precedes(self.previous_encoded_key, encoded_key):
|
||||
raise self.key_order_error
|
||||
self.previous_encoded_key = encoded_key
|
||||
|
||||
# enter value, this can nest
|
||||
if isinstance(value, HashBuilderCollection):
|
||||
self._insert_child(value)
|
||||
else:
|
||||
self._hash_item(value)
|
||||
self._hash_item_streamed(value)
|
||||
|
||||
return value
|
||||
|
||||
|
@ -185,7 +185,9 @@ async def sign_tx(
|
||||
account_path_checker = AccountPathChecker()
|
||||
|
||||
hash_fn = hashlib.blake2b(outlen=32)
|
||||
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(tx_body_map_item_count)
|
||||
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(
|
||||
tx_body_map_item_count, INVALID_TX_SIGNING_REQUEST
|
||||
)
|
||||
tx_dict.start(hash_fn)
|
||||
with tx_dict:
|
||||
await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker)
|
||||
@ -296,7 +298,7 @@ async def _process_transaction(
|
||||
|
||||
if msg.withdrawals_count > 0:
|
||||
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
|
||||
msg.withdrawals_count
|
||||
msg.withdrawals_count, INVALID_WITHDRAWAL
|
||||
)
|
||||
with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
|
||||
await _process_withdrawals(
|
||||
@ -324,7 +326,7 @@ async def _process_transaction(
|
||||
|
||||
if msg.minting_asset_groups_count > 0:
|
||||
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
|
||||
msg.minting_asset_groups_count
|
||||
msg.minting_asset_groups_count, INVALID_TOKEN_BUNDLE_MINT
|
||||
)
|
||||
with tx_dict.add(TX_BODY_KEY_MINT, minting_dict):
|
||||
await _process_minting(ctx, minting_dict)
|
||||
@ -468,7 +470,9 @@ async def _process_outputs(
|
||||
output_value_list.append(output.amount)
|
||||
asset_groups_dict: HashBuilderDict[
|
||||
bytes, HashBuilderDict[bytes, int]
|
||||
] = HashBuilderDict(output.asset_groups_count)
|
||||
] = HashBuilderDict(
|
||||
output.asset_groups_count, INVALID_TOKEN_BUNDLE_OUTPUT
|
||||
)
|
||||
with output_value_list.append(asset_groups_dict):
|
||||
await _process_asset_groups(
|
||||
ctx,
|
||||
@ -492,15 +496,15 @@ async def _process_asset_groups(
|
||||
should_show_tokens: bool,
|
||||
) -> None:
|
||||
"""Read, validate and serialize the asset groups of an output."""
|
||||
previous_policy_id: bytes = b""
|
||||
for _ in range(asset_groups_count):
|
||||
asset_group: CardanoAssetGroup = await ctx.call(
|
||||
CardanoTxItemAck(), CardanoAssetGroup
|
||||
)
|
||||
_validate_asset_group(asset_group, previous_policy_id)
|
||||
previous_policy_id = asset_group.policy_id
|
||||
_validate_asset_group(asset_group)
|
||||
|
||||
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
|
||||
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
|
||||
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_OUTPUT
|
||||
)
|
||||
with asset_groups_dict.add(asset_group.policy_id, tokens):
|
||||
await _process_tokens(
|
||||
ctx,
|
||||
@ -519,11 +523,9 @@ async def _process_tokens(
|
||||
should_show_tokens: bool,
|
||||
) -> None:
|
||||
"""Read, validate, confirm and serialize the tokens of an asset group."""
|
||||
previous_asset_name_bytes: bytes = b""
|
||||
for _ in range(tokens_count):
|
||||
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
|
||||
_validate_token(token, previous_asset_name_bytes)
|
||||
previous_asset_name_bytes = token.asset_name_bytes
|
||||
_validate_token(token)
|
||||
if should_show_tokens:
|
||||
await confirm_sending_token(ctx, policy_id, token)
|
||||
|
||||
@ -641,24 +643,18 @@ async def _process_withdrawals(
|
||||
if withdrawals_count == 0:
|
||||
return
|
||||
|
||||
previous_reward_address_bytes: bytes = b""
|
||||
for _ in range(withdrawals_count):
|
||||
withdrawal: CardanoTxWithdrawal = await ctx.call(
|
||||
CardanoTxItemAck(), CardanoTxWithdrawal
|
||||
)
|
||||
_validate_withdrawal(
|
||||
keychain,
|
||||
withdrawal,
|
||||
signing_mode,
|
||||
protocol_magic,
|
||||
network_id,
|
||||
account_path_checker,
|
||||
previous_reward_address_bytes,
|
||||
)
|
||||
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
|
||||
keychain, withdrawal, protocol_magic, network_id
|
||||
)
|
||||
previous_reward_address_bytes = reward_address_bytes
|
||||
|
||||
await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id)
|
||||
|
||||
@ -707,15 +703,15 @@ async def _process_minting(
|
||||
|
||||
await show_warning_tx_contains_mint(ctx)
|
||||
|
||||
previous_policy_id: bytes = b""
|
||||
for _ in range(token_minting.asset_groups_count):
|
||||
asset_group: CardanoAssetGroup = await ctx.call(
|
||||
CardanoTxItemAck(), CardanoAssetGroup
|
||||
)
|
||||
_validate_asset_group(asset_group, previous_policy_id, is_mint=True)
|
||||
previous_policy_id = asset_group.policy_id
|
||||
_validate_asset_group(asset_group, is_mint=True)
|
||||
|
||||
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
|
||||
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
|
||||
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_MINT
|
||||
)
|
||||
with minting_dict.add(asset_group.policy_id, tokens):
|
||||
await _process_minting_tokens(
|
||||
ctx,
|
||||
@ -732,11 +728,9 @@ async def _process_minting_tokens(
|
||||
tokens_count: int,
|
||||
) -> None:
|
||||
"""Read, validate, confirm and serialize the tokens of an asset group."""
|
||||
previous_asset_name_bytes: bytes = b""
|
||||
for _ in range(tokens_count):
|
||||
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
|
||||
_validate_token(token, previous_asset_name_bytes, is_mint=True)
|
||||
previous_asset_name_bytes = token.asset_name_bytes
|
||||
_validate_token(token, is_mint=True)
|
||||
await confirm_token_minting(ctx, policy_id, token)
|
||||
|
||||
assert token.mint_amount is not None # _validate_token
|
||||
@ -1005,7 +999,7 @@ async def _show_output(
|
||||
|
||||
|
||||
def _validate_asset_group(
|
||||
asset_group: CardanoAssetGroup, previous_policy_id: bytes, is_mint: bool = False
|
||||
asset_group: CardanoAssetGroup, is_mint: bool = False
|
||||
) -> None:
|
||||
INVALID_TOKEN_BUNDLE = (
|
||||
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
|
||||
@ -1015,13 +1009,9 @@ def _validate_asset_group(
|
||||
raise INVALID_TOKEN_BUNDLE
|
||||
if asset_group.tokens_count == 0:
|
||||
raise INVALID_TOKEN_BUNDLE
|
||||
if not cbor.are_canonically_ordered(previous_policy_id, asset_group.policy_id):
|
||||
raise INVALID_TOKEN_BUNDLE
|
||||
|
||||
|
||||
def _validate_token(
|
||||
token: CardanoToken, previous_asset_name_bytes: bytes, is_mint: bool = False
|
||||
) -> None:
|
||||
def _validate_token(token: CardanoToken, is_mint: bool = False) -> None:
|
||||
INVALID_TOKEN_BUNDLE = (
|
||||
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
|
||||
)
|
||||
@ -1035,10 +1025,6 @@ def _validate_token(
|
||||
|
||||
if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH:
|
||||
raise INVALID_TOKEN_BUNDLE
|
||||
if not cbor.are_canonically_ordered(
|
||||
previous_asset_name_bytes, token.asset_name_bytes
|
||||
):
|
||||
raise INVALID_TOKEN_BUNDLE
|
||||
|
||||
|
||||
async def _show_certificate(
|
||||
@ -1065,13 +1051,9 @@ async def _show_certificate(
|
||||
|
||||
|
||||
def _validate_withdrawal(
|
||||
keychain: seed.Keychain,
|
||||
withdrawal: CardanoTxWithdrawal,
|
||||
signing_mode: CardanoTxSigningMode,
|
||||
protocol_magic: int,
|
||||
network_id: int,
|
||||
account_path_checker: AccountPathChecker,
|
||||
previous_reward_address_bytes: bytes,
|
||||
) -> None:
|
||||
validate_stake_credential(
|
||||
withdrawal.path,
|
||||
@ -1084,14 +1066,6 @@ def _validate_withdrawal(
|
||||
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
|
||||
raise INVALID_WITHDRAWAL
|
||||
|
||||
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
|
||||
keychain, withdrawal, protocol_magic, network_id
|
||||
)
|
||||
if not cbor.are_canonically_ordered(
|
||||
previous_reward_address_bytes, reward_address_bytes
|
||||
):
|
||||
raise INVALID_WITHDRAWAL
|
||||
|
||||
account_path_checker.add_withdrawal(withdrawal)
|
||||
|
||||
|
||||
|
@ -320,11 +320,11 @@ def create_map_header(size: int) -> bytes:
|
||||
return _header(_CBOR_MAP, size)
|
||||
|
||||
|
||||
def are_canonically_ordered(previous: Value, current: Value) -> bool:
|
||||
def precedes(prev: bytes, curr: bytes) -> bool:
|
||||
"""
|
||||
Returns True if `previous` is smaller than `current` with regards to
|
||||
Returns True if `prev` is smaller than `curr` with regards to
|
||||
the cbor map key ordering as defined in
|
||||
https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
|
||||
Note that `prev` and `curr` must already be cbor-encoded.
|
||||
"""
|
||||
u, v = encode(previous), encode(current)
|
||||
return len(u) < len(v) or (len(u) == len(v) and u < v)
|
||||
return len(prev) < len(curr) or (len(prev) == len(curr) and prev < curr)
|
||||
|
Loading…
Reference in New Issue
Block a user