1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 15:38:11 +00:00

refactor(core) allow recursion in ethereum sign_typed_data

This commit is contained in:
Jan Hnatek 2024-09-05 11:38:25 +02:00
parent 777ad11bec
commit d98b5fe064

View File

@ -252,8 +252,6 @@ class TypedDataEnvelope:
i.e. the concatenation of the encoded member values in the order that they appear in the type.
Each encoded member value is exactly 32-byte long.
"""
from .layout import confirm_typed_value, should_show_array
type_members = self.types[primary_type].members
member_value_path = member_path + [0]
current_parent_objects = parent_objects + [""]
@ -265,96 +263,165 @@ class TypedDataEnvelope:
# Arrays and structs need special recursive handling
if field_type.data_type == EthereumDataType.STRUCT:
assert field_type.struct_name is not None # validate_field_type
struct_name = field_type.struct_name
current_parent_objects[-1] = field_name
if show_data:
show_struct = await should_show_struct(
struct_name, # description
self.types[struct_name].members, # data_members
".".join(current_parent_objects), # title
)
else:
show_struct = False
res = await self.hash_struct(
struct_name,
await self.encode_struct(
w,
field_type.struct_name,
member_value_path,
show_struct,
show_data,
current_parent_objects,
False,
)
elif field_type.data_type == EthereumDataType.ARRAY:
await self.encode_array(
w,
field_name,
field_type,
member_value_path,
show_data,
parent_objects,
current_parent_objects,
)
w.extend(res)
elif field_type.data_type == EthereumDataType.ARRAY:
# Getting the length of the array first, if not fixed
if field_type.size is None:
array_size = await _get_array_size(member_value_path)
else:
array_size = field_type.size
assert field_type.entry_type is not None # validate_field_type
entry_type = field_type.entry_type
current_parent_objects[-1] = field_name
if show_data:
show_array = await should_show_array(
current_parent_objects,
get_type_name(entry_type),
array_size,
)
else:
show_array = False
arr_w = get_hash_writer()
el_member_path = member_value_path + [0]
for i in range(array_size):
el_member_path[-1] = i
# TODO: we do not support arrays of arrays, check if we should
if entry_type.data_type == EthereumDataType.STRUCT:
assert entry_type.struct_name is not None # validate_field_type
struct_name = entry_type.struct_name
# Metamask V4 implementation has a bug, that causes the
# behavior of structs in array be different from SPEC
# Explanation at https://github.com/MetaMask/eth-sig-util/pull/107
# encode_data() is the way to process structs in arrays, but
# Metamask V4 is using hash_struct() even in this case
if self.metamask_v4_compat:
res = await self.hash_struct(
struct_name,
el_member_path,
show_array,
current_parent_objects,
)
arr_w.extend(res)
else:
await self.get_and_encode_data(
arr_w,
struct_name,
el_member_path,
show_array,
current_parent_objects,
)
else:
value = await get_value(entry_type, el_member_path)
encode_field(arr_w, entry_type, value)
if show_array:
await confirm_typed_value(
field_name,
value,
parent_objects,
entry_type,
i,
)
w.extend(arr_w.get_digest())
else:
value = await get_value(field_type, member_value_path)
encode_field(w, field_type, value)
if show_data:
await confirm_typed_value(
field_name,
value,
parent_objects,
field_type,
)
await self.encode_nonref(
w,
field_name,
field_type,
member_value_path,
show_data,
parent_objects,
)
async def encode_struct(
self,
w: HashWriter,
struct_name: str,
field_value_path: list[int],
show_data: bool,
parent_objects: list[str],
is_array_member: bool,
):
"""Encode a struct field."""
if show_data:
show_struct = await should_show_struct(
struct_name, # description
self.types[struct_name].members, # data_members
".".join(parent_objects), # title
)
else:
show_struct = False
# Metamask V4 implementation has a bug, that causes the
# behavior of structs in array be different from SPEC
# Explanation at https://github.com/MetaMask/eth-sig-util/pull/107
# encode_data() is the way to process structs in arrays, but
# Metamask V4 is using hash_struct() even in this case
if not is_array_member or self.metamask_v4_compat:
res = await self.hash_struct(
struct_name,
field_value_path,
show_struct,
parent_objects,
)
w.extend(res)
else:
await self.get_and_encode_data(
w,
struct_name,
field_value_path,
show_struct,
parent_objects,
)
async def encode_array(
self,
w: HashWriter,
array_name: str,
array_field_type: EthereumFieldType,
array_value_path: list[int],
show_data: bool,
parent_objects: list[str],
current_parent_objects: list[str],
):
"""Encode an array field."""
from .layout import should_show_array
# Get the length of the array first, if not fixed
if array_field_type.size is None:
array_size = await _get_array_size(array_value_path)
else:
array_size = array_field_type.size
assert array_field_type.entry_type is not None # validate_field_type
entry_type = array_field_type.entry_type
current_parent_objects[-1] = array_name
if show_data:
show_array = await should_show_array(
current_parent_objects,
get_type_name(entry_type),
array_size,
)
else:
show_array = False
arr_w = get_hash_writer()
field_member_path = array_value_path + [0]
for i in range(array_size):
field_member_path[-1] = i
# TODO: we do not support arrays of arrays, check if we should
if entry_type.data_type == EthereumDataType.STRUCT:
assert entry_type.struct_name is not None # validate_field_type
await self.encode_struct(
arr_w,
entry_type.struct_name,
field_member_path,
show_array,
current_parent_objects,
is_array_member=True,
)
else:
await self.encode_nonref(
arr_w,
array_name,
entry_type,
field_member_path,
show_array,
parent_objects,
i,
)
w.extend(arr_w.get_digest())
async def encode_nonref(
self,
w: HashWriter,
field_name: str,
field_type: EthereumFieldType,
field_value_path: list[int],
show_data: bool,
parent_objects: list[str],
array_index: int | None = None,
):
"""Encode a non-reference field."""
from .layout import confirm_typed_value
value = await get_value(field_type, field_value_path)
encode_field(w, field_type, value)
if show_data:
await confirm_typed_value(
field_name,
value,
parent_objects,
field_type,
array_index,
)
def encode_field(