1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-20 14:39:22 +00:00

refator(cardano): validate map key order in HashBuilderDict

This commit is contained in:
David Misiak 2022-03-08 14:48:10 +01:00 committed by matejcik
parent a36fc6cadc
commit fec4fa2257
3 changed files with 50 additions and 54 deletions

View File

@ -4,6 +4,7 @@ from apps.common import cbor
if TYPE_CHECKING:
from typing import Any, Generic, TypeVar
from trezor import wire
from trezor.utils import HashContext
T = TypeVar("T")
@ -43,7 +44,13 @@ class HashBuilderCollection:
self.remaining -= 1
def _hash_item(self, item: Any) -> None:
def _hash_item(self, item: Any) -> bytes:
assert self.hash_fn is not None
encoded_item = cbor.encode(item)
self.hash_fn.update(encoded_item)
return encoded_item
def _hash_item_streamed(self, item: Any) -> None:
assert self.hash_fn is not None
for chunk in cbor.encode_streamed(item):
self.hash_fn.update(chunk)
@ -74,7 +81,7 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
if isinstance(item, HashBuilderCollection):
self._insert_child(item)
else:
self._hash_item(item)
self._hash_item_streamed(item)
return item
@ -83,16 +90,31 @@ class HashBuilderList(HashBuilderCollection, Generic[T]):
class HashBuilderDict(HashBuilderCollection, Generic[K, V]):
key_order_error: wire.ProcessError
previous_encoded_key: bytes
def __init__(self, size: int, key_order_error: wire.ProcessError):
super().__init__(size)
self.key_order_error = key_order_error
self.previous_encoded_key = b""
def add(self, key: K, value: V) -> V:
self._do_enter_item()
# enter key, this must not nest
assert not isinstance(key, HashBuilderCollection)
self._hash_item(key)
encoded_key = self._hash_item(key)
# check key ordering
if not cbor.precedes(self.previous_encoded_key, encoded_key):
raise self.key_order_error
self.previous_encoded_key = encoded_key
# enter value, this can nest
if isinstance(value, HashBuilderCollection):
self._insert_child(value)
else:
self._hash_item(value)
self._hash_item_streamed(value)
return value

View File

@ -185,7 +185,9 @@ async def sign_tx(
account_path_checker = AccountPathChecker()
hash_fn = hashlib.blake2b(outlen=32)
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(tx_body_map_item_count)
tx_dict: HashBuilderDict[int, Any] = HashBuilderDict(
tx_body_map_item_count, INVALID_TX_SIGNING_REQUEST
)
tx_dict.start(hash_fn)
with tx_dict:
await _process_transaction(ctx, msg, keychain, tx_dict, account_path_checker)
@ -296,7 +298,7 @@ async def _process_transaction(
if msg.withdrawals_count > 0:
withdrawals_dict: HashBuilderDict[bytes, int] = HashBuilderDict(
msg.withdrawals_count
msg.withdrawals_count, INVALID_WITHDRAWAL
)
with tx_dict.add(TX_BODY_KEY_WITHDRAWALS, withdrawals_dict):
await _process_withdrawals(
@ -324,7 +326,7 @@ async def _process_transaction(
if msg.minting_asset_groups_count > 0:
minting_dict: HashBuilderDict[bytes, HashBuilderDict] = HashBuilderDict(
msg.minting_asset_groups_count
msg.minting_asset_groups_count, INVALID_TOKEN_BUNDLE_MINT
)
with tx_dict.add(TX_BODY_KEY_MINT, minting_dict):
await _process_minting(ctx, minting_dict)
@ -468,7 +470,9 @@ async def _process_outputs(
output_value_list.append(output.amount)
asset_groups_dict: HashBuilderDict[
bytes, HashBuilderDict[bytes, int]
] = HashBuilderDict(output.asset_groups_count)
] = HashBuilderDict(
output.asset_groups_count, INVALID_TOKEN_BUNDLE_OUTPUT
)
with output_value_list.append(asset_groups_dict):
await _process_asset_groups(
ctx,
@ -492,15 +496,15 @@ async def _process_asset_groups(
should_show_tokens: bool,
) -> None:
"""Read, validate and serialize the asset groups of an output."""
previous_policy_id: bytes = b""
for _ in range(asset_groups_count):
asset_group: CardanoAssetGroup = await ctx.call(
CardanoTxItemAck(), CardanoAssetGroup
)
_validate_asset_group(asset_group, previous_policy_id)
previous_policy_id = asset_group.policy_id
_validate_asset_group(asset_group)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_OUTPUT
)
with asset_groups_dict.add(asset_group.policy_id, tokens):
await _process_tokens(
ctx,
@ -519,11 +523,9 @@ async def _process_tokens(
should_show_tokens: bool,
) -> None:
"""Read, validate, confirm and serialize the tokens of an asset group."""
previous_asset_name_bytes: bytes = b""
for _ in range(tokens_count):
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
_validate_token(token, previous_asset_name_bytes)
previous_asset_name_bytes = token.asset_name_bytes
_validate_token(token)
if should_show_tokens:
await confirm_sending_token(ctx, policy_id, token)
@ -641,24 +643,18 @@ async def _process_withdrawals(
if withdrawals_count == 0:
return
previous_reward_address_bytes: bytes = b""
for _ in range(withdrawals_count):
withdrawal: CardanoTxWithdrawal = await ctx.call(
CardanoTxItemAck(), CardanoTxWithdrawal
)
_validate_withdrawal(
keychain,
withdrawal,
signing_mode,
protocol_magic,
network_id,
account_path_checker,
previous_reward_address_bytes,
)
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
keychain, withdrawal, protocol_magic, network_id
)
previous_reward_address_bytes = reward_address_bytes
await confirm_withdrawal(ctx, withdrawal, reward_address_bytes, network_id)
@ -707,15 +703,15 @@ async def _process_minting(
await show_warning_tx_contains_mint(ctx)
previous_policy_id: bytes = b""
for _ in range(token_minting.asset_groups_count):
asset_group: CardanoAssetGroup = await ctx.call(
CardanoTxItemAck(), CardanoAssetGroup
)
_validate_asset_group(asset_group, previous_policy_id, is_mint=True)
previous_policy_id = asset_group.policy_id
_validate_asset_group(asset_group, is_mint=True)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(asset_group.tokens_count)
tokens: HashBuilderDict[bytes, int] = HashBuilderDict(
asset_group.tokens_count, INVALID_TOKEN_BUNDLE_MINT
)
with minting_dict.add(asset_group.policy_id, tokens):
await _process_minting_tokens(
ctx,
@ -732,11 +728,9 @@ async def _process_minting_tokens(
tokens_count: int,
) -> None:
"""Read, validate, confirm and serialize the tokens of an asset group."""
previous_asset_name_bytes: bytes = b""
for _ in range(tokens_count):
token: CardanoToken = await ctx.call(CardanoTxItemAck(), CardanoToken)
_validate_token(token, previous_asset_name_bytes, is_mint=True)
previous_asset_name_bytes = token.asset_name_bytes
_validate_token(token, is_mint=True)
await confirm_token_minting(ctx, policy_id, token)
assert token.mint_amount is not None # _validate_token
@ -1005,7 +999,7 @@ async def _show_output(
def _validate_asset_group(
asset_group: CardanoAssetGroup, previous_policy_id: bytes, is_mint: bool = False
asset_group: CardanoAssetGroup, is_mint: bool = False
) -> None:
INVALID_TOKEN_BUNDLE = (
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
@ -1015,13 +1009,9 @@ def _validate_asset_group(
raise INVALID_TOKEN_BUNDLE
if asset_group.tokens_count == 0:
raise INVALID_TOKEN_BUNDLE
if not cbor.are_canonically_ordered(previous_policy_id, asset_group.policy_id):
raise INVALID_TOKEN_BUNDLE
def _validate_token(
token: CardanoToken, previous_asset_name_bytes: bytes, is_mint: bool = False
) -> None:
def _validate_token(token: CardanoToken, is_mint: bool = False) -> None:
INVALID_TOKEN_BUNDLE = (
INVALID_TOKEN_BUNDLE_MINT if is_mint else INVALID_TOKEN_BUNDLE_OUTPUT
)
@ -1035,10 +1025,6 @@ def _validate_token(
if len(token.asset_name_bytes) > MAX_ASSET_NAME_LENGTH:
raise INVALID_TOKEN_BUNDLE
if not cbor.are_canonically_ordered(
previous_asset_name_bytes, token.asset_name_bytes
):
raise INVALID_TOKEN_BUNDLE
async def _show_certificate(
@ -1065,13 +1051,9 @@ async def _show_certificate(
def _validate_withdrawal(
keychain: seed.Keychain,
withdrawal: CardanoTxWithdrawal,
signing_mode: CardanoTxSigningMode,
protocol_magic: int,
network_id: int,
account_path_checker: AccountPathChecker,
previous_reward_address_bytes: bytes,
) -> None:
validate_stake_credential(
withdrawal.path,
@ -1084,14 +1066,6 @@ def _validate_withdrawal(
if not 0 <= withdrawal.amount < LOVELACE_MAX_SUPPLY:
raise INVALID_WITHDRAWAL
reward_address_bytes = _derive_withdrawal_reward_address_bytes(
keychain, withdrawal, protocol_magic, network_id
)
if not cbor.are_canonically_ordered(
previous_reward_address_bytes, reward_address_bytes
):
raise INVALID_WITHDRAWAL
account_path_checker.add_withdrawal(withdrawal)

View File

@ -320,11 +320,11 @@ def create_map_header(size: int) -> bytes:
return _header(_CBOR_MAP, size)
def are_canonically_ordered(previous: Value, current: Value) -> bool:
def precedes(prev: bytes, curr: bytes) -> bool:
"""
Returns True if `previous` is smaller than `current` with regards to
Returns True if `prev` is smaller than `curr` with regards to
the cbor map key ordering as defined in
https://datatracker.ietf.org/doc/html/rfc7049#section-3.9
Note that `prev` and `curr` must already be cbor-encoded.
"""
u, v = encode(previous), encode(current)
return len(u) < len(v) or (len(u) == len(v) and u < v)
return len(prev) < len(curr) or (len(prev) == len(curr) and prev < curr)