mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-06-20 23:18:46 +00:00
refator(cardano): validate map key order in HashBuilderDict
This commit is contained in:
parent
a36fc6cadc
commit
fec4fa2257
@ -4,6 +4,7 @@ from apps.common import cbor
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
from trezor import wire
|
||||||
from trezor.utils import HashContext
|
from trezor.utils import HashContext
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -43,7 +44,13 @@ class HashBuilderCollection:
|
|||||||
|
|
||||||
self.remaining -= 1
|
self.remaining -= 1
|
||||||
|
|
||||||
def _hash_item(self, item: Any) -> None:
|
def _hash_item(self, item: Any) -> bytes:
|
||||||
|
assert self.hash_fn is not None
|
||||||
|
encoded_item = cbor.encode(item)
|
||||||
|
self.hash_fn.update(encoded_item)
|
||||||
|
return encoded_item
|
||||||
|
|
||||||
|
def _hash_item_streamed(self, item: Any) -> None:
|
||||||
assert self.hash_fn is not None
|
assert self.hash_fn is not None
|
||||||
for chunk in cbor.encode_streamed(item):
|
for chunk in cbor.encode_streamed(item):
|
||||||
self.hash_fn.update(chunk)
|
self.hash_fn.update(chunk)
|
||||||
@ -74,7 +81,7 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
|
|||||||
if isinstance(item, HashBuilderCollection):
|
if isinstance(item, HashBuilderCollection):
|
||||||
self._insert_child(item)
|
self._insert_child(item)
|
||||||
else:
|
else:
|
||||||
self._hash_item(item)
|
self._hash_item_streamed(item)
|
||||||
|
|
||||||
return item
|
return item
|
||||||
|
|
||||||
@ -83,16 +90,31 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
|
class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
|
||||||
|
key_order_error: wire.ProcessError
|
||||||
|
previous_encoded_key: bytes
|
||||||
|
|
||||||
|
def __init__(self, size: int, key_order_error: wire.ProcessError):
|
||||||
|
super().__init__(size)
|
||||||
|
self.key_order_error = key_order_error
|
||||||
|
self.previous_encoded_key = b""
|
||||||
|
|
||||||
def add(self, key: K, value: V) -> V:
|
def add(self, key: K, value: V) -> V:
|
||||||
self._do_enter_item()
|
self._do_enter_item()
|
||||||
|
|
||||||
# enter key, this must not nest
|
# enter key, this must not nest
|
||||||
assert not isinstance(key, HashBuilderCollection)
|
assert not isinstance(key, HashBuilderCollection)
|
||||||
self._hash_item(key)
|
encoded_key = self._hash_item(key)
|
||||||
|
|
||||||
|
# check key ordering
|
||||||
|
if not cbor.precedes(self.previous_encoded_key, encoded_key):
|
||||||
|
raise self.key_order_error
|
||||||
|
self.previous_encoded_key = encoded_key
|
||||||
|
|
||||||
# enter value, this can nest
|
# enter value, this can nest
|
||||||
if isinstance(value, HashBuilderCollection):
|
if isinstance(value, HashBuilderCollection):
|
||||||
self._insert_child(value)
|
self._insert_child(value)
|
||||||
else:
|
else:
|
||||||
self._hash_item(value)
|
self._hash_item_streamed(value)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@ -185,7 +185,9 @@ async def sign_tx(
|
|||||||
account_path_checker = AccountPathChecker()
|
account_path_checker = AccountPathChecker()
|
||||||
|
|
||||||
hash_fn = hashlib.blake2b(outlen=32)
|
hash_fn = hashlib.blake2b(outlen=32)
|
||||||
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(tx_body_map_item_count)
|
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(
|
||||||
|
tx_body_map_item_count, INVALID_TX_SIGNING_REQUEST
|
||||||
|
)
|
||||||
tx_dict.start(hash_fn)
|
tx_dict.start(hash_fn)
|
||||||
with tx_dict:
|
with tx_dict:
|
||||||
await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker)
|
await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker)
|
||||||
@ -296,7 +298,7 @@ async def _process_transaction(
|
|||||||
|
|
||||||
if msg.withdrawals_count > 0:
|
if msg.withdrawals_count > 0:
|
||||||
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
|
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
|
||||||
msg.withdrawals_count
|
msg.withdrawals_count, INVALID_WITHDRAWAL
|
||||||
)
|
)
|
||||||
with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
|
with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
|
||||||
await _process_withdrawals(
|
await _process_withdrawals(
|
||||||
@ -324,7 +326,7 @@ async def _process_transaction(
|
|||||||
|
|
||||||
if msg.minting_asset_groups_count > 0:
|
if msg.minting_asset_groups_count > 0:
|
||||||
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
|
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
|
||||||
msg.minting_asset_groups_count
|
msg.minting_asset_groups_count, INVALID_TOKEN_BUNDLE_MINT
|
||||||
)
|
)
|
||||||
with tx_dict.add(TX_BODY_KEY_MINT, minting_dict):
|
with tx_dict.add(TX_BODY_KEY_MINT, minting_dict):
|
||||||
await _process_minting(ctx, minting_dict)
|
await _process_minting(ctx, minting_dict)
|
||||||
@ -468,7 +470,9 @@ async def _process_outputs(
|
|||||||
output_value_list.append(output.amount)
|
output_value_list.append(output.amount)
|
||||||
asset_groups_dict: HashBuilderDict[
|
asset_groups_dict: HashBuilderDict[
|
||||||
bytes, HashBuilderDict[bytes, int]
|
bytes, HashBuilderDict[bytes, int]
|
||||||
] = HashBuilderDict(output.asset_groups_count)
|
] = HashBuilderDict(
|
||||||
|
output.asset_groups_count, INVALID_TOKEN_BUNDLE_OUTPUT
|
||||||
|
)
|
||||||
with output_value_list.append(asset_groups_dict):
|
with output_value_list.append(asset_groups_dict):
|
||||||
await _process_asset_groups(
|
await _process_asset_groups(
|
||||||
ctx,
|
ctx,
|
||||||
@ -492,15 +496,15 @@ async def _process_asset_groups(
|
|||||||
should_show_tokens: bool,
|
should_show_tokens: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Read, validate and serialize the asset groups of an output."""
|
"""Read, validate and serialize the asset groups of an output."""
|
||||||
previous_policy_id: bytes = b""
|
|
||||||
for _ in range(asset_groups_count):
|
for _ in range(asset_groups_count):
|
||||||
asset_group: CardanoAssetGroup = await ctx.call(
|
asset_group: CardanoAssetGroup = await ctx.call(
|
||||||
CardanoTxItemAck(), CardanoAssetGroup
|
CardanoTxItemAck(), CardanoAssetGroup
|
||||||
)
|
)
|
||||||
_validate_asset_group(asset_group, previous_policy_id)
|
_validate_asset_group(asset_group)
|
||||||
previous_policy_id = asset_group.policy_id
|
|
||||||
|
|
||||||
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
|
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
|
||||||
|
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_OUTPUT
|
||||||
|
)
|
||||||
with asset_groups_dict.add(asset_group.policy_id, tokens):
|
with asset_groups_dict.add(asset_group.policy_id, tokens):
|
||||||
await _process_tokens(
|
await _process_tokens(
|
||||||
ctx,
|
ctx,
|
||||||
@ -519,11 +523,9 @@ async def _process_tokens(
|
|||||||
should_show_tokens: bool,
|
should_show_tokens: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Read, validate, confirm and serialize the tokens of an asset group."""
|
"""Read, validate, confirm and serialize the tokens of an asset group."""
|
||||||
previous_asset_name_bytes: bytes = b""
|
|
||||||
for _ in range(tokens_count):
|
for _ in range(tokens_count):
|
||||||
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
|
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
|
||||||
_validate_token(token, previous_asset_name_bytes)
|
_validate_token(token)
|
||||||
previous_asset_name_bytes = token.asset_name_bytes
|
|
||||||
if should_show_tokens:
|
if should_show_tokens:
|
||||||
await confirm_sending_token(ctx, policy_id, token)
|
await confirm_sending_token(ctx, policy_id, token)
|
||||||
|
|
||||||
@ -641,24 +643,18 @@ async def _process_withdrawals(
|
|||||||
if withdrawals_count == 0:
|
if withdrawals_count == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
previous_reward_address_bytes: bytes = b""
|
|
||||||
for _ in range(withdrawals_count):
|
for _ in range(withdrawals_count):
|
||||||
withdrawal: CardanoTxWithdrawal = await ctx.call(
|
withdrawal: CardanoTxWithdrawal = await ctx.call(
|
||||||
CardanoTxItemAck(), CardanoTxWithdrawal
|
CardanoTxItemAck(), CardanoTxWithdrawal
|
||||||
)
|
)
|
||||||
_validate_withdrawal(
|
_validate_withdrawal(
|
||||||
keychain,
|
|
||||||
withdrawal,
|
withdrawal,
|
||||||
signing_mode,
|
signing_mode,
|
||||||
protocol_magic,
|
|
||||||
network_id,
|
|
||||||
account_path_checker,
|
account_path_checker,
|
||||||
previous_reward_address_bytes,
|
|
||||||
)
|
)
|
||||||
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
|
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
|
||||||
keychain, withdrawal, protocol_magic, network_id
|
keychain, withdrawal, protocol_magic, network_id
|
||||||
)
|
)
|
||||||
previous_reward_address_bytes = reward_address_bytes
|
|
||||||
|
|
||||||
await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id)
|
await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id)
|
||||||
|
|
||||||
@ -707,15 +703,15 @@ async def _process_minting(
|
|||||||
|
|
||||||
await show_warning_tx_contains_mint(ctx)
|
await show_warning_tx_contains_mint(ctx)
|
||||||
|
|
||||||
previous_policy_id: bytes = b""
|
|
||||||
for _ in range(token_minting.asset_groups_count):
|
for _ in range(token_minting.asset_groups_count):
|
||||||
asset_group: CardanoAssetGroup = await ctx.call(
|
asset_group: CardanoAssetGroup = await ctx.call(
|
||||||
CardanoTxItemAck(), CardanoAssetGroup
|
CardanoTxItemAck(), CardanoAssetGroup
|
||||||
)
|
)
|
||||||
_validate_asset_group(asset_group, previous_policy_id, is_mint=True)
|
_validate_asset_group(asset_group, is_mint=True)
|
||||||
previous_policy_id = asset_group.policy_id
|
|
||||||
|
|
||||||
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
|
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
|
||||||
|
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_MINT
|
||||||
|
)
|
||||||
with minting_dict.add(asset_group.policy_id, tokens):
|
with minting_dict.add(asset_group.policy_id, tokens):
|
||||||
await _process_minting_tokens(
|
await _process_minting_tokens(
|
||||||
ctx,
|
ctx,
|
||||||
@ -732,11 +728,9 @@ async def _process_minting_tokens(
|
|||||||
tokens_count: int,
|
tokens_count: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Read, validate, confirm and serialize the tokens of an asset group."""
|
"""Read, validate, confirm and serialize the tokens of an asset group."""
|
||||||
previous_asset_name_bytes: bytes = b""
|
|
||||||
for _ in range(tokens_count):
|
for _ in range(tokens_count):
|
||||||
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
|
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
|
||||||
_validate_token(token, previous_asset_name_bytes, is_mint=True)
|
_validate_token(token, is_mint=True)
|
||||||
previous_asset_name_bytes = token.asset_name_bytes
|
|
||||||
await confirm_token_minting(ctx, policy_id, token)
|
await confirm_token_minting(ctx, policy_id, token)
|
||||||
|
|
||||||
assert token.mint_amount is not None # _validate_token
|
assert token.mint_amount is not None # _validate_token
|
||||||
@ -1005,7 +999,7 @@ async def _show_output(
|
|||||||
|
|
||||||
|
|
||||||
def _validate_asset_group(
|
def _validate_asset_group(
|
||||||
asset_group: CardanoAssetGroup, previous_policy_id: bytes, is_mint: bool = False
|
asset_group: CardanoAssetGroup, is_mint: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
INVALID_TOKEN_BUNDLE = (
|
INVALID_TOKEN_BUNDLE = (
|
||||||
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
|
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
|
||||||
@ -1015,13 +1009,9 @@ def _validate_asset_group(
|
|||||||
raise INVALID_TOKEN_BUNDLE
|
raise INVALID_TOKEN_BUNDLE
|
||||||
if asset_group.tokens_count == 0:
|
if asset_group.tokens_count == 0:
|
||||||
raise INVALID_TOKEN_BUNDLE
|
raise INVALID_TOKEN_BUNDLE
|
||||||
if not cbor.are_canonically_ordered(previous_policy_id, asset_group.policy_id):
|
|
||||||
raise INVALID_TOKEN_BUNDLE
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_token(
|
def _validate_token(token: CardanoToken, is_mint: bool = False) -> None:
|
||||||
token: CardanoToken, previous_asset_name_bytes: bytes, is_mint: bool = False
|
|
||||||
) -> None:
|
|
||||||
INVALID_TOKEN_BUNDLE = (
|
INVALID_TOKEN_BUNDLE = (
|
||||||
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
|
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
|
||||||
)
|
)
|
||||||
@ -1035,10 +1025,6 @@ def _validate_token(
|
|||||||
|
|
||||||
if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH:
|
if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH:
|
||||||
raise INVALID_TOKEN_BUNDLE
|
raise INVALID_TOKEN_BUNDLE
|
||||||
if not cbor.are_canonically_ordered(
|
|
||||||
previous_asset_name_bytes, token.asset_name_bytes
|
|
||||||
):
|
|
||||||
raise INVALID_TOKEN_BUNDLE
|
|
||||||
|
|
||||||
|
|
||||||
async def _show_certificate(
|
async def _show_certificate(
|
||||||
@ -1065,13 +1051,9 @@ async def _show_certificate(
|
|||||||
|
|
||||||
|
|
||||||
def _validate_withdrawal(
|
def _validate_withdrawal(
|
||||||
keychain: seed.Keychain,
|
|
||||||
withdrawal: CardanoTxWithdrawal,
|
withdrawal: CardanoTxWithdrawal,
|
||||||
signing_mode: CardanoTxSigningMode,
|
signing_mode: CardanoTxSigningMode,
|
||||||
protocol_magic: int,
|
|
||||||
network_id: int,
|
|
||||||
account_path_checker: AccountPathChecker,
|
account_path_checker: AccountPathChecker,
|
||||||
previous_reward_address_bytes: bytes,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
validate_stake_credential(
|
validate_stake_credential(
|
||||||
withdrawal.path,
|
withdrawal.path,
|
||||||
@ -1084,14 +1066,6 @@ def _validate_withdrawal(
|
|||||||
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
|
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
|
||||||
raise INVALID_WITHDRAWAL
|
raise INVALID_WITHDRAWAL
|
||||||
|
|
||||||
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
|
|
||||||
keychain, withdrawal, protocol_magic, network_id
|
|
||||||
)
|
|
||||||
if not cbor.are_canonically_ordered(
|
|
||||||
previous_reward_address_bytes, reward_address_bytes
|
|
||||||
):
|
|
||||||
raise INVALID_WITHDRAWAL
|
|
||||||
|
|
||||||
account_path_checker.add_withdrawal(withdrawal)
|
account_path_checker.add_withdrawal(withdrawal)
|
||||||
|
|
||||||
|
|
||||||
|
@ -320,11 +320,11 @@ def create_map_header(size: int) -> bytes:
|
|||||||
return _header(_CBOR_MAP, size)
|
return _header(_CBOR_MAP, size)
|
||||||
|
|
||||||
|
|
||||||
def are_canonically_ordered(previous: Value, current: Value) -> bool:
|
def precedes(prev: bytes, curr: bytes) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns True if `previous` is smaller than `current` with regards to
|
Returns True if `prev` is smaller than `curr` with regards to
|
||||||
the cbor map key ordering as defined in
|
the cbor map key ordering as defined in
|
||||||
https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
|
https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
|
||||||
|
Note that `prev` and `curr` must already be cbor-encoded.
|
||||||
"""
|
"""
|
||||||
u, v = encode(previous), encode(current)
|
return len(prev) < len(curr) or (len(prev) == len(curr) and prev < curr)
|
||||||
return len(u) < len(v) or (len(u) == len(v) and u < v)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user