mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 06:18:07 +00:00
chore(core): decrease ethereum size by 17250 bytes
This commit is contained in:
parent
0c3423b1c7
commit
26fd0de198
@ -1,16 +1,9 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor.messages import EthereumAddress
|
|
||||||
from trezor.ui.layouts import show_address
|
|
||||||
|
|
||||||
from apps.common import paths
|
|
||||||
|
|
||||||
from . import networks
|
|
||||||
from .helpers import address_from_bytes
|
|
||||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.messages import EthereumGetAddress
|
from trezor.messages import EthereumGetAddress, EthereumAddress
|
||||||
from trezor.wire import Context
|
from trezor.wire import Context
|
||||||
|
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
@ -20,18 +13,26 @@ if TYPE_CHECKING:
|
|||||||
async def get_address(
|
async def get_address(
|
||||||
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
|
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
|
||||||
) -> EthereumAddress:
|
) -> EthereumAddress:
|
||||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
from trezor.messages import EthereumAddress
|
||||||
|
from trezor.ui.layouts import show_address
|
||||||
|
from apps.common import paths
|
||||||
|
from . import networks
|
||||||
|
from .helpers import address_from_bytes
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
address_n = msg.address_n # local_cache_attribute
|
||||||
|
|
||||||
if len(msg.address_n) > 1: # path has slip44 network identifier
|
await paths.validate_path(ctx, keychain, address_n)
|
||||||
network = networks.by_slip44(msg.address_n[1] & 0x7FFF_FFFF)
|
|
||||||
|
node = keychain.derive(address_n)
|
||||||
|
|
||||||
|
if len(address_n) > 1: # path has slip44 network identifier
|
||||||
|
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
|
||||||
else:
|
else:
|
||||||
network = None
|
network = None
|
||||||
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
|
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
|
||||||
|
|
||||||
if msg.show_display:
|
if msg.show_display:
|
||||||
title = paths.address_n_to_str(msg.address_n)
|
title = paths.address_n_to_str(address_n)
|
||||||
await show_address(ctx, address=address, title=title)
|
await show_address(ctx, address, title=title)
|
||||||
|
|
||||||
return EthereumAddress(address=address)
|
return EthereumAddress(address=address)
|
||||||
|
@ -1,15 +1,11 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from ubinascii import hexlify
|
|
||||||
|
|
||||||
from trezor.messages import EthereumPublicKey, HDNodeType
|
from apps.common import paths
|
||||||
from trezor.ui.layouts import show_pubkey
|
|
||||||
|
|
||||||
from apps.common import coins, paths
|
|
||||||
|
|
||||||
from .keychain import with_keychain_from_path
|
from .keychain import with_keychain_from_path
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.messages import EthereumGetPublicKey
|
from trezor.messages import EthereumGetPublicKey, EthereumPublicKey
|
||||||
from trezor.wire import Context
|
from trezor.wire import Context
|
||||||
|
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
@ -19,6 +15,11 @@ if TYPE_CHECKING:
|
|||||||
async def get_public_key(
|
async def get_public_key(
|
||||||
ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain
|
ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain
|
||||||
) -> EthereumPublicKey:
|
) -> EthereumPublicKey:
|
||||||
|
from ubinascii import hexlify
|
||||||
|
from trezor.messages import EthereumPublicKey, HDNodeType
|
||||||
|
from trezor.ui.layouts import show_pubkey
|
||||||
|
from apps.common import coins
|
||||||
|
|
||||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
|
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from ubinascii import hexlify, unhexlify
|
from ubinascii import hexlify
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
from trezor.enums import EthereumDataType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.messages import EthereumFieldType
|
from trezor.messages import EthereumFieldType
|
||||||
@ -24,7 +21,7 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
|
|||||||
address_hex = hexlify(address_bytes).decode()
|
address_hex = hexlify(address_bytes).decode()
|
||||||
digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest()
|
digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest()
|
||||||
|
|
||||||
def maybe_upper(i: int) -> str:
|
def _maybe_upper(i: int) -> str:
|
||||||
"""Uppercase i-th letter only if the corresponding nibble has high bit set."""
|
"""Uppercase i-th letter only if the corresponding nibble has high bit set."""
|
||||||
digest_byte = digest[i // 2]
|
digest_byte = digest[i // 2]
|
||||||
hex_letter = address_hex[i]
|
hex_letter = address_hex[i]
|
||||||
@ -39,10 +36,13 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
|
|||||||
else:
|
else:
|
||||||
return hex_letter
|
return hex_letter
|
||||||
|
|
||||||
return "0x" + "".join(maybe_upper(i) for i in range(len(address_hex)))
|
return "0x" + "".join(_maybe_upper(i) for i in range(len(address_hex)))
|
||||||
|
|
||||||
|
|
||||||
def bytes_from_address(address: str) -> bytes:
|
def bytes_from_address(address: str) -> bytes:
|
||||||
|
from ubinascii import unhexlify
|
||||||
|
from trezor import wire
|
||||||
|
|
||||||
if len(address) == 40:
|
if len(address) == 40:
|
||||||
return unhexlify(address)
|
return unhexlify(address)
|
||||||
|
|
||||||
@ -59,6 +59,8 @@ def bytes_from_address(address: str) -> bytes:
|
|||||||
|
|
||||||
def get_type_name(field: EthereumFieldType) -> str:
|
def get_type_name(field: EthereumFieldType) -> str:
|
||||||
"""Create a string from type definition (like uint256 or bytes16)."""
|
"""Create a string from type definition (like uint256 or bytes16)."""
|
||||||
|
from trezor.enums import EthereumDataType
|
||||||
|
|
||||||
data_type = field.data_type
|
data_type = field.data_type
|
||||||
size = field.size
|
size = field.size
|
||||||
|
|
||||||
@ -109,12 +111,12 @@ def decode_typed_data(data: bytes, type_name: str) -> str:
|
|||||||
return str(int.from_bytes(data, "big"))
|
return str(int.from_bytes(data, "big"))
|
||||||
elif type_name.startswith("int"):
|
elif type_name.startswith("int"):
|
||||||
# Micropython does not implement "signed" arg in int.from_bytes()
|
# Micropython does not implement "signed" arg in int.from_bytes()
|
||||||
return str(from_bytes_bigendian_signed(data))
|
return str(_from_bytes_bigendian_signed(data))
|
||||||
|
|
||||||
raise ValueError # Unsupported data type for direct field decoding
|
raise ValueError # Unsupported data type for direct field decoding
|
||||||
|
|
||||||
|
|
||||||
def from_bytes_bigendian_signed(b: bytes) -> int:
|
def _from_bytes_bigendian_signed(b: bytes) -> int:
|
||||||
negative = b[0] & 0x80
|
negative = b[0] & 0x80
|
||||||
if negative:
|
if negative:
|
||||||
neg_b = bytearray(b)
|
neg_b = bytearray(b)
|
||||||
|
@ -1,7 +1,5 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
|
|
||||||
from apps.common import paths
|
from apps.common import paths
|
||||||
from apps.common.keychain import get_keychain
|
from apps.common.keychain import get_keychain
|
||||||
|
|
||||||
@ -10,6 +8,8 @@ from . import CURVE, networks
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Callable, Iterable, TypeVar
|
from typing import Callable, Iterable, TypeVar
|
||||||
|
|
||||||
|
from trezor.wire import Context
|
||||||
|
|
||||||
from trezor.messages import (
|
from trezor.messages import (
|
||||||
EthereumGetAddress,
|
EthereumGetAddress,
|
||||||
EthereumGetPublicKey,
|
EthereumGetPublicKey,
|
||||||
@ -64,7 +64,7 @@ def with_keychain_from_path(
|
|||||||
*patterns: str,
|
*patterns: str,
|
||||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
|
||||||
schemas = _schemas_from_address_n(patterns, msg.address_n)
|
schemas = _schemas_from_address_n(patterns, msg.address_n)
|
||||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||||
with keychain:
|
with keychain:
|
||||||
@ -97,7 +97,7 @@ def with_keychain_from_chain_id(
|
|||||||
func: HandlerWithKeychain[MsgInChainId, MsgOut]
|
func: HandlerWithKeychain[MsgInChainId, MsgOut]
|
||||||
) -> Handler[MsgInChainId, MsgOut]:
|
) -> Handler[MsgInChainId, MsgOut]:
|
||||||
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
|
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
|
||||||
async def wrapper(ctx: wire.Context, msg: MsgInChainId) -> MsgOut:
|
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
|
||||||
schemas = _schemas_from_chain_id(msg)
|
schemas = _schemas_from_chain_id(msg)
|
||||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||||
with keychain:
|
with keychain:
|
||||||
|
@ -1,29 +1,19 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
from ubinascii import hexlify
|
|
||||||
|
|
||||||
from trezor import ui
|
from trezor import ui
|
||||||
from trezor.enums import ButtonRequestType, EthereumDataType
|
from trezor.enums import ButtonRequestType
|
||||||
from trezor.strings import format_amount, format_plural
|
from trezor.strings import format_plural
|
||||||
from trezor.ui.layouts import (
|
from trezor.ui.layouts import confirm_blob, confirm_text, should_show_more
|
||||||
confirm_action,
|
|
||||||
confirm_address,
|
|
||||||
confirm_amount,
|
|
||||||
confirm_blob,
|
|
||||||
confirm_output,
|
|
||||||
confirm_text,
|
|
||||||
confirm_total,
|
|
||||||
should_show_more,
|
|
||||||
)
|
|
||||||
from trezor.ui.layouts.altcoin import confirm_total_ethereum
|
|
||||||
|
|
||||||
from . import networks, tokens
|
from . import networks
|
||||||
from .helpers import address_from_bytes, decode_typed_data, get_type_name
|
from .helpers import decode_typed_data
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from typing import Awaitable, Iterable
|
from typing import Awaitable, Iterable
|
||||||
|
|
||||||
from trezor.messages import EthereumFieldType, EthereumStructMember
|
from trezor.messages import EthereumFieldType, EthereumStructMember
|
||||||
from trezor.wire import Context
|
from trezor.wire import Context
|
||||||
|
from . import tokens
|
||||||
|
|
||||||
|
|
||||||
def require_confirm_tx(
|
def require_confirm_tx(
|
||||||
@ -33,15 +23,18 @@ def require_confirm_tx(
|
|||||||
chain_id: int,
|
chain_id: int,
|
||||||
token: tokens.TokenInfo | None = None,
|
token: tokens.TokenInfo | None = None,
|
||||||
) -> Awaitable[None]:
|
) -> Awaitable[None]:
|
||||||
|
from .helpers import address_from_bytes
|
||||||
|
from trezor.ui.layouts import confirm_output
|
||||||
|
|
||||||
if to_bytes:
|
if to_bytes:
|
||||||
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
|
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
|
||||||
else:
|
else:
|
||||||
to_str = "new contract?"
|
to_str = "new contract?"
|
||||||
return confirm_output(
|
return confirm_output(
|
||||||
ctx,
|
ctx,
|
||||||
address=to_str,
|
to_str,
|
||||||
amount=format_ethereum_amount(value, token, chain_id),
|
format_ethereum_amount(value, token, chain_id),
|
||||||
font_amount=ui.BOLD,
|
ui.BOLD,
|
||||||
color_to=ui.GREY,
|
color_to=ui.GREY,
|
||||||
br_code=ButtonRequestType.SignTx,
|
br_code=ButtonRequestType.SignTx,
|
||||||
)
|
)
|
||||||
@ -55,6 +48,8 @@ def require_confirm_fee(
|
|||||||
chain_id: int,
|
chain_id: int,
|
||||||
token: tokens.TokenInfo | None = None,
|
token: tokens.TokenInfo | None = None,
|
||||||
) -> Awaitable[None]:
|
) -> Awaitable[None]:
|
||||||
|
from trezor.ui.layouts.altcoin import confirm_total_ethereum
|
||||||
|
|
||||||
return confirm_total_ethereum(
|
return confirm_total_ethereum(
|
||||||
ctx,
|
ctx,
|
||||||
format_ethereum_amount(spending, token, chain_id),
|
format_ethereum_amount(spending, token, chain_id),
|
||||||
@ -72,22 +67,24 @@ async def require_confirm_eip1559_fee(
|
|||||||
chain_id: int,
|
chain_id: int,
|
||||||
token: tokens.TokenInfo | None = None,
|
token: tokens.TokenInfo | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
from trezor.ui.layouts import confirm_amount, confirm_total
|
||||||
|
|
||||||
await confirm_amount(
|
await confirm_amount(
|
||||||
ctx,
|
ctx,
|
||||||
title="Confirm fee",
|
"Confirm fee",
|
||||||
description="Maximum fee per gas",
|
format_ethereum_amount(max_gas_fee, None, chain_id),
|
||||||
amount=format_ethereum_amount(max_gas_fee, None, chain_id),
|
"Maximum fee per gas",
|
||||||
)
|
)
|
||||||
await confirm_amount(
|
await confirm_amount(
|
||||||
ctx,
|
ctx,
|
||||||
title="Confirm fee",
|
"Confirm fee",
|
||||||
description="Priority fee per gas",
|
format_ethereum_amount(max_priority_fee, None, chain_id),
|
||||||
amount=format_ethereum_amount(max_priority_fee, None, chain_id),
|
"Priority fee per gas",
|
||||||
)
|
)
|
||||||
await confirm_total(
|
await confirm_total(
|
||||||
ctx,
|
ctx,
|
||||||
total_amount=format_ethereum_amount(spending, token, chain_id),
|
format_ethereum_amount(spending, token, chain_id),
|
||||||
fee_amount=format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
|
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
|
||||||
total_label="Amount sent:\n",
|
total_label="Amount sent:\n",
|
||||||
fee_label="\nMaximum fee:\n",
|
fee_label="\nMaximum fee:\n",
|
||||||
)
|
)
|
||||||
@ -96,13 +93,16 @@ async def require_confirm_eip1559_fee(
|
|||||||
def require_confirm_unknown_token(
|
def require_confirm_unknown_token(
|
||||||
ctx: Context, address_bytes: bytes
|
ctx: Context, address_bytes: bytes
|
||||||
) -> Awaitable[None]:
|
) -> Awaitable[None]:
|
||||||
|
from ubinascii import hexlify
|
||||||
|
from trezor.ui.layouts import confirm_address
|
||||||
|
|
||||||
contract_address_hex = "0x" + hexlify(address_bytes).decode()
|
contract_address_hex = "0x" + hexlify(address_bytes).decode()
|
||||||
return confirm_address(
|
return confirm_address(
|
||||||
ctx,
|
ctx,
|
||||||
"Unknown token",
|
"Unknown token",
|
||||||
contract_address_hex,
|
contract_address_hex,
|
||||||
description="Contract:",
|
"Contract:",
|
||||||
br_type="unknown_token",
|
"unknown_token",
|
||||||
icon_color=ui.ORANGE,
|
icon_color=ui.ORANGE,
|
||||||
br_code=ButtonRequestType.SignTx,
|
br_code=ButtonRequestType.SignTx,
|
||||||
)
|
)
|
||||||
@ -112,20 +112,22 @@ def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitabl
|
|||||||
return confirm_blob(
|
return confirm_blob(
|
||||||
ctx,
|
ctx,
|
||||||
"confirm_data",
|
"confirm_data",
|
||||||
title="Confirm data",
|
"Confirm data",
|
||||||
description=f"Size: {data_total} bytes",
|
data,
|
||||||
data=data,
|
f"Size: {data_total} bytes",
|
||||||
br_code=ButtonRequestType.SignTx,
|
br_code=ButtonRequestType.SignTx,
|
||||||
ask_pagination=True,
|
ask_pagination=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def confirm_typed_data_final(ctx: Context) -> None:
|
async def confirm_typed_data_final(ctx: Context) -> None:
|
||||||
|
from trezor.ui.layouts import confirm_action
|
||||||
|
|
||||||
await confirm_action(
|
await confirm_action(
|
||||||
ctx,
|
ctx,
|
||||||
"confirm_typed_data_final",
|
"confirm_typed_data_final",
|
||||||
title="Confirm typed data",
|
"Confirm typed data",
|
||||||
action="Really sign EIP-712 typed data?",
|
"Really sign EIP-712 typed data?",
|
||||||
verb="Hold to confirm",
|
verb="Hold to confirm",
|
||||||
hold=True,
|
hold=True,
|
||||||
)
|
)
|
||||||
@ -135,9 +137,9 @@ def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
|
|||||||
return confirm_text(
|
return confirm_text(
|
||||||
ctx,
|
ctx,
|
||||||
"confirm_empty_typed_message",
|
"confirm_empty_typed_message",
|
||||||
title="Confirm message",
|
"Confirm message",
|
||||||
data="",
|
"",
|
||||||
description="No message field",
|
"No message field",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -152,10 +154,10 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
|
|||||||
)
|
)
|
||||||
return await should_show_more(
|
return await should_show_more(
|
||||||
ctx,
|
ctx,
|
||||||
title="Confirm domain",
|
"Confirm domain",
|
||||||
para=para,
|
para,
|
||||||
button_text="Show full domain",
|
"Show full domain",
|
||||||
br_type="should_show_domain",
|
"should_show_domain",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -176,10 +178,10 @@ async def should_show_struct(
|
|||||||
)
|
)
|
||||||
return await should_show_more(
|
return await should_show_more(
|
||||||
ctx,
|
ctx,
|
||||||
title=title,
|
title,
|
||||||
para=para,
|
para,
|
||||||
button_text=button_text,
|
button_text,
|
||||||
br_type="should_show_struct",
|
"should_show_struct",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -192,10 +194,10 @@ async def should_show_array(
|
|||||||
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
|
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
|
||||||
return await should_show_more(
|
return await should_show_more(
|
||||||
ctx,
|
ctx,
|
||||||
title=limit_str(".".join(parent_objects)),
|
limit_str(".".join(parent_objects)),
|
||||||
para=para,
|
para,
|
||||||
button_text="Show full array",
|
"Show full array",
|
||||||
br_type="should_show_array",
|
"should_show_array",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -207,6 +209,9 @@ async def confirm_typed_value(
|
|||||||
field: EthereumFieldType,
|
field: EthereumFieldType,
|
||||||
array_index: int | None = None,
|
array_index: int | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
from trezor.enums import EthereumDataType
|
||||||
|
from .helpers import get_type_name
|
||||||
|
|
||||||
type_name = get_type_name(field)
|
type_name = get_type_name(field)
|
||||||
|
|
||||||
if array_index is not None:
|
if array_index is not None:
|
||||||
@ -222,24 +227,26 @@ async def confirm_typed_value(
|
|||||||
await confirm_blob(
|
await confirm_blob(
|
||||||
ctx,
|
ctx,
|
||||||
"confirm_typed_value",
|
"confirm_typed_value",
|
||||||
title=title,
|
title,
|
||||||
data=data,
|
data,
|
||||||
description=description,
|
description,
|
||||||
ask_pagination=True,
|
ask_pagination=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await confirm_text(
|
await confirm_text(
|
||||||
ctx,
|
ctx,
|
||||||
"confirm_typed_value",
|
"confirm_typed_value",
|
||||||
title=title,
|
title,
|
||||||
data=data,
|
data,
|
||||||
description=description,
|
description,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def format_ethereum_amount(
|
def format_ethereum_amount(
|
||||||
value: int, token: tokens.TokenInfo | None, chain_id: int
|
value: int, token: tokens.TokenInfo | None, chain_id: int
|
||||||
) -> str:
|
) -> str:
|
||||||
|
from trezor.strings import format_amount
|
||||||
|
|
||||||
if token:
|
if token:
|
||||||
suffix = token.symbol
|
suffix = token.symbol
|
||||||
decimals = token.decimals
|
decimals = token.decimals
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,28 @@
|
|||||||
# generated from networks.py.mako
|
# generated from networks.py.mako
|
||||||
# (by running `make templates` in `core`)
|
# (by running `make templates` in `core`)
|
||||||
# do not edit manually!
|
# do not edit manually!
|
||||||
from typing import Iterator
|
|
||||||
|
# NOTE: using positional arguments saves 4400 bytes in flash size,
|
||||||
|
# returning tuples instead of classes saved 800 bytes
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from apps.common.paths import HARDENED
|
from apps.common.paths import HARDENED
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
|
# Removing the necessity to construct object to save space
|
||||||
|
# fmt: off
|
||||||
|
NetworkInfoTuple = tuple[
|
||||||
|
int, # chain_id
|
||||||
|
int, # slip44
|
||||||
|
str, # shortcut
|
||||||
|
str, # name
|
||||||
|
bool # rskip60
|
||||||
|
]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
def shortcut_by_chain_id(chain_id: int) -> str:
|
def shortcut_by_chain_id(chain_id: int) -> str:
|
||||||
n = by_chain_id(chain_id)
|
n = by_chain_id(chain_id)
|
||||||
@ -13,21 +31,36 @@ def shortcut_by_chain_id(chain_id: int) -> str:
|
|||||||
|
|
||||||
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
|
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
|
||||||
for n in _networks_iterator():
|
for n in _networks_iterator():
|
||||||
if n.chain_id == chain_id:
|
n_chain_id = n[0]
|
||||||
return n
|
if n_chain_id == chain_id:
|
||||||
|
return NetworkInfo(
|
||||||
|
chain_id=n[0],
|
||||||
|
slip44=n[1],
|
||||||
|
shortcut=n[2],
|
||||||
|
name=n[3],
|
||||||
|
rskip60=n[4],
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def by_slip44(slip44: int) -> "NetworkInfo" | None:
|
def by_slip44(slip44: int) -> "NetworkInfo" | None:
|
||||||
for n in _networks_iterator():
|
for n in _networks_iterator():
|
||||||
if n.slip44 == slip44:
|
n_slip44 = n[1]
|
||||||
return n
|
if n_slip44 == slip44:
|
||||||
|
return NetworkInfo(
|
||||||
|
chain_id=n[0],
|
||||||
|
slip44=n[1],
|
||||||
|
shortcut=n[2],
|
||||||
|
name=n[3],
|
||||||
|
rskip60=n[4],
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def all_slip44_ids_hardened() -> Iterator[int]:
|
def all_slip44_ids_hardened() -> Iterator[int]:
|
||||||
for n in _networks_iterator():
|
for n in _networks_iterator():
|
||||||
yield n.slip44 | HARDENED
|
# n_slip_44 is the second element
|
||||||
|
yield n[1] | HARDENED
|
||||||
|
|
||||||
|
|
||||||
class NetworkInfo:
|
class NetworkInfo:
|
||||||
@ -42,13 +75,13 @@ class NetworkInfo:
|
|||||||
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
def _networks_iterator() -> Iterator[NetworkInfo]:
|
def _networks_iterator() -> Iterator[NetworkInfoTuple]:
|
||||||
% for n in supported_on("trezor2", eth):
|
% for n in supported_on("trezor2", eth):
|
||||||
yield NetworkInfo(
|
yield (
|
||||||
chain_id=${n.chain_id},
|
${n.chain_id}, # chain_id
|
||||||
slip44=${n.slip44},
|
${n.slip44}, # slip44
|
||||||
shortcut="${n.shortcut}",
|
"${n.shortcut}", # shortcut
|
||||||
name="${n.name}",
|
"${n.name}", # name
|
||||||
rskip60=${n.rskip60},
|
${n.rskip60}, # rskip60
|
||||||
)
|
)
|
||||||
% endfor
|
% endfor
|
||||||
|
@ -1,25 +1,18 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor.crypto.curve import secp256k1
|
|
||||||
from trezor.crypto.hashlib import sha3_256
|
|
||||||
from trezor.messages import EthereumMessageSignature
|
|
||||||
from trezor.ui.layouts import confirm_signverify
|
|
||||||
from trezor.utils import HashWriter
|
|
||||||
|
|
||||||
from apps.common import paths
|
|
||||||
from apps.common.signverify import decode_message
|
|
||||||
|
|
||||||
from .helpers import address_from_bytes
|
|
||||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.messages import EthereumSignMessage
|
from trezor.messages import EthereumSignMessage, EthereumMessageSignature
|
||||||
from trezor.wire import Context
|
from trezor.wire import Context
|
||||||
|
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
|
|
||||||
|
|
||||||
def message_digest(message: bytes) -> bytes:
|
def message_digest(message: bytes) -> bytes:
|
||||||
|
from trezor.crypto.hashlib import sha3_256
|
||||||
|
from trezor.utils import HashWriter
|
||||||
|
|
||||||
h = HashWriter(sha3_256(keccak=True))
|
h = HashWriter(sha3_256(keccak=True))
|
||||||
signed_message_header = b"\x19Ethereum Signed Message:\n"
|
signed_message_header = b"\x19Ethereum Signed Message:\n"
|
||||||
h.extend(signed_message_header)
|
h.extend(signed_message_header)
|
||||||
@ -32,6 +25,15 @@ def message_digest(message: bytes) -> bytes:
|
|||||||
async def sign_message(
|
async def sign_message(
|
||||||
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
|
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
|
||||||
) -> EthereumMessageSignature:
|
) -> EthereumMessageSignature:
|
||||||
|
from trezor.crypto.curve import secp256k1
|
||||||
|
from trezor.messages import EthereumMessageSignature
|
||||||
|
from trezor.ui.layouts import confirm_signverify
|
||||||
|
|
||||||
|
from apps.common import paths
|
||||||
|
from apps.common.signverify import decode_message
|
||||||
|
|
||||||
|
from .helpers import address_from_bytes
|
||||||
|
|
||||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
|
@ -1,29 +1,19 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
from trezor.crypto import rlp
|
from trezor.crypto import rlp
|
||||||
from trezor.crypto.curve import secp256k1
|
from trezor.messages import EthereumTxRequest
|
||||||
from trezor.crypto.hashlib import sha3_256
|
from trezor.wire import DataError
|
||||||
from trezor.messages import EthereumTxAck, EthereumTxRequest
|
|
||||||
from trezor.utils import HashWriter
|
|
||||||
|
|
||||||
from apps.common import paths
|
|
||||||
|
|
||||||
from . import tokens
|
|
||||||
from .helpers import bytes_from_address
|
from .helpers import bytes_from_address
|
||||||
from .keychain import with_keychain_from_chain_id
|
from .keychain import with_keychain_from_chain_id
|
||||||
from .layout import (
|
|
||||||
require_confirm_data,
|
|
||||||
require_confirm_fee,
|
|
||||||
require_confirm_tx,
|
|
||||||
require_confirm_unknown_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
from trezor.messages import EthereumSignTx
|
from trezor.messages import EthereumSignTx, EthereumTxAck
|
||||||
|
from trezor.wire import Context
|
||||||
|
|
||||||
from .keychain import EthereumSignTxAny
|
from .keychain import EthereumSignTxAny
|
||||||
|
from . import tokens
|
||||||
|
|
||||||
|
|
||||||
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
|
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
|
||||||
@ -34,9 +24,24 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
|
|||||||
|
|
||||||
@with_keychain_from_chain_id
|
@with_keychain_from_chain_id
|
||||||
async def sign_tx(
|
async def sign_tx(
|
||||||
ctx: wire.Context, msg: EthereumSignTx, keychain: Keychain
|
ctx: Context, msg: EthereumSignTx, keychain: Keychain
|
||||||
) -> EthereumTxRequest:
|
) -> EthereumTxRequest:
|
||||||
check(msg)
|
from trezor.utils import HashWriter
|
||||||
|
from trezor.crypto.hashlib import sha3_256
|
||||||
|
from apps.common import paths
|
||||||
|
from .layout import (
|
||||||
|
require_confirm_data,
|
||||||
|
require_confirm_fee,
|
||||||
|
require_confirm_tx,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check
|
||||||
|
if msg.tx_type not in [1, 6, None]:
|
||||||
|
raise DataError("tx_type out of bounds")
|
||||||
|
if len(msg.gas_price) + len(msg.gas_limit) > 30:
|
||||||
|
raise DataError("Fee overflow")
|
||||||
|
check_common_fields(msg)
|
||||||
|
|
||||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||||
|
|
||||||
# Handle ERC20s
|
# Handle ERC20s
|
||||||
@ -61,7 +66,7 @@ async def sign_tx(
|
|||||||
data += msg.data_initial_chunk
|
data += msg.data_initial_chunk
|
||||||
data_left = data_total - len(msg.data_initial_chunk)
|
data_left = data_total - len(msg.data_initial_chunk)
|
||||||
|
|
||||||
total_length = get_total_length(msg, data_total)
|
total_length = _get_total_length(msg, data_total)
|
||||||
|
|
||||||
sha = HashWriter(sha3_256(keccak=True))
|
sha = HashWriter(sha3_256(keccak=True))
|
||||||
rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE)
|
rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE)
|
||||||
@ -89,14 +94,19 @@ async def sign_tx(
|
|||||||
rlp.write(sha, 0)
|
rlp.write(sha, 0)
|
||||||
|
|
||||||
digest = sha.get_digest()
|
digest = sha.get_digest()
|
||||||
result = sign_digest(msg, keychain, digest)
|
result = _sign_digest(msg, keychain, digest)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def handle_erc20(
|
async def handle_erc20(
|
||||||
ctx: wire.Context, msg: EthereumSignTxAny
|
ctx: Context, msg: EthereumSignTxAny
|
||||||
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
|
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
|
||||||
|
from .layout import require_confirm_unknown_token
|
||||||
|
from . import tokens
|
||||||
|
|
||||||
|
data_initial_chunk = msg.data_initial_chunk # local_cache_attribute
|
||||||
|
|
||||||
token = None
|
token = None
|
||||||
address_bytes = recipient = bytes_from_address(msg.to)
|
address_bytes = recipient = bytes_from_address(msg.to)
|
||||||
value = int.from_bytes(msg.value, "big")
|
value = int.from_bytes(msg.value, "big")
|
||||||
@ -104,13 +114,13 @@ async def handle_erc20(
|
|||||||
len(msg.to) in (40, 42)
|
len(msg.to) in (40, 42)
|
||||||
and len(msg.value) == 0
|
and len(msg.value) == 0
|
||||||
and msg.data_length == 68
|
and msg.data_length == 68
|
||||||
and len(msg.data_initial_chunk) == 68
|
and len(data_initial_chunk) == 68
|
||||||
and msg.data_initial_chunk[:16]
|
and data_initial_chunk[:16]
|
||||||
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||||
):
|
):
|
||||||
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
|
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
|
||||||
recipient = msg.data_initial_chunk[16:36]
|
recipient = data_initial_chunk[16:36]
|
||||||
value = int.from_bytes(msg.data_initial_chunk[36:68], "big")
|
value = int.from_bytes(data_initial_chunk[36:68], "big")
|
||||||
|
|
||||||
if token is tokens.UNKNOWN_TOKEN:
|
if token is tokens.UNKNOWN_TOKEN:
|
||||||
await require_confirm_unknown_token(ctx, address_bytes)
|
await require_confirm_unknown_token(ctx, address_bytes)
|
||||||
@ -118,7 +128,7 @@ async def handle_erc20(
|
|||||||
return token, address_bytes, recipient, value
|
return token, address_bytes, recipient, value
|
||||||
|
|
||||||
|
|
||||||
def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
def _get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
||||||
length = 0
|
length = 0
|
||||||
if msg.tx_type is not None:
|
if msg.tx_type is not None:
|
||||||
length += rlp.length(msg.tx_type)
|
length += rlp.length(msg.tx_type)
|
||||||
@ -143,20 +153,20 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
|||||||
return length
|
return length
|
||||||
|
|
||||||
|
|
||||||
async def send_request_chunk(ctx: wire.Context, data_left: int) -> EthereumTxAck:
|
async def send_request_chunk(ctx: Context, data_left: int) -> EthereumTxAck:
|
||||||
|
from trezor.messages import EthereumTxAck
|
||||||
|
|
||||||
# TODO: layoutProgress ?
|
# TODO: layoutProgress ?
|
||||||
req = EthereumTxRequest()
|
req = EthereumTxRequest()
|
||||||
if data_left <= 1024:
|
req.data_length = min(data_left, 1024)
|
||||||
req.data_length = data_left
|
|
||||||
else:
|
|
||||||
req.data_length = 1024
|
|
||||||
|
|
||||||
return await ctx.call(req, EthereumTxAck)
|
return await ctx.call(req, EthereumTxAck)
|
||||||
|
|
||||||
|
|
||||||
def sign_digest(
|
def _sign_digest(
|
||||||
msg: EthereumSignTx, keychain: Keychain, digest: bytes
|
msg: EthereumSignTx, keychain: Keychain, digest: bytes
|
||||||
) -> EthereumTxRequest:
|
) -> EthereumTxRequest:
|
||||||
|
from trezor.crypto.curve import secp256k1
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
signature = secp256k1.sign(
|
signature = secp256k1.sign(
|
||||||
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
|
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
|
||||||
@ -175,33 +185,25 @@ def sign_digest(
|
|||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
||||||
def check(msg: EthereumSignTx) -> None:
|
|
||||||
if msg.tx_type not in [1, 6, None]:
|
|
||||||
raise wire.DataError("tx_type out of bounds")
|
|
||||||
|
|
||||||
if len(msg.gas_price) + len(msg.gas_limit) > 30:
|
|
||||||
raise wire.DataError("Fee overflow")
|
|
||||||
|
|
||||||
check_common_fields(msg)
|
|
||||||
|
|
||||||
|
|
||||||
def check_common_fields(msg: EthereumSignTxAny) -> None:
|
def check_common_fields(msg: EthereumSignTxAny) -> None:
|
||||||
if msg.data_length > 0:
|
data_length = msg.data_length # local_cache_attribute
|
||||||
|
|
||||||
|
if data_length > 0:
|
||||||
if not msg.data_initial_chunk:
|
if not msg.data_initial_chunk:
|
||||||
raise wire.DataError("Data length provided, but no initial chunk")
|
raise DataError("Data length provided, but no initial chunk")
|
||||||
# Our encoding only supports transactions up to 2^24 bytes. To
|
# Our encoding only supports transactions up to 2^24 bytes. To
|
||||||
# prevent exceeding the limit we use a stricter limit on data length.
|
# prevent exceeding the limit we use a stricter limit on data length.
|
||||||
if msg.data_length > 16_000_000:
|
if data_length > 16_000_000:
|
||||||
raise wire.DataError("Data length exceeds limit")
|
raise DataError("Data length exceeds limit")
|
||||||
if len(msg.data_initial_chunk) > msg.data_length:
|
if len(msg.data_initial_chunk) > data_length:
|
||||||
raise wire.DataError("Invalid size of initial chunk")
|
raise DataError("Invalid size of initial chunk")
|
||||||
|
|
||||||
if len(msg.to) not in (0, 40, 42):
|
if len(msg.to) not in (0, 40, 42):
|
||||||
raise wire.DataError("Invalid recipient address")
|
raise DataError("Invalid recipient address")
|
||||||
|
|
||||||
if not msg.to and msg.data_length == 0:
|
if not msg.to and data_length == 0:
|
||||||
# sending transaction to address 0 (contract creation) without a data field
|
# sending transaction to address 0 (contract creation) without a data field
|
||||||
raise wire.DataError("Contract creation without data")
|
raise DataError("Contract creation without data")
|
||||||
|
|
||||||
if msg.chain_id == 0:
|
if msg.chain_id == 0:
|
||||||
raise wire.DataError("Chain ID out of bounds")
|
raise DataError("Chain ID out of bounds")
|
||||||
|
@ -1,28 +1,21 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
from trezor.crypto import rlp
|
from trezor.crypto import rlp
|
||||||
from trezor.crypto.curve import secp256k1
|
|
||||||
from trezor.crypto.hashlib import sha3_256
|
|
||||||
from trezor.messages import EthereumAccessList, EthereumTxRequest
|
|
||||||
from trezor.utils import HashWriter
|
|
||||||
|
|
||||||
from apps.common import paths
|
|
||||||
|
|
||||||
from .helpers import bytes_from_address
|
from .helpers import bytes_from_address
|
||||||
from .keychain import with_keychain_from_chain_id
|
from .keychain import with_keychain_from_chain_id
|
||||||
from .layout import (
|
|
||||||
require_confirm_data,
|
|
||||||
require_confirm_eip1559_fee,
|
|
||||||
require_confirm_tx,
|
|
||||||
)
|
|
||||||
from .sign_tx import check_common_fields, handle_erc20, send_request_chunk
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.messages import EthereumSignTxEIP1559
|
from trezor.messages import (
|
||||||
|
EthereumSignTxEIP1559,
|
||||||
|
EthereumAccessList,
|
||||||
|
EthereumTxRequest,
|
||||||
|
)
|
||||||
|
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
|
from trezor.wire import Context
|
||||||
|
|
||||||
|
|
||||||
_TX_TYPE = const(2)
|
_TX_TYPE = const(2)
|
||||||
|
|
||||||
@ -35,28 +28,30 @@ def access_list_item_length(item: EthereumAccessList) -> int:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def access_list_length(access_list: list[EthereumAccessList]) -> int:
|
|
||||||
payload_length = sum(access_list_item_length(i) for i in access_list)
|
|
||||||
return rlp.header_length(payload_length) + payload_length
|
|
||||||
|
|
||||||
|
|
||||||
def write_access_list(w: HashWriter, access_list: list[EthereumAccessList]) -> None:
|
|
||||||
payload_length = sum(access_list_item_length(i) for i in access_list)
|
|
||||||
rlp.write_header(w, payload_length, rlp.LIST_HEADER_BYTE)
|
|
||||||
for item in access_list:
|
|
||||||
address_bytes = bytes_from_address(item.address)
|
|
||||||
address_length = rlp.length(address_bytes)
|
|
||||||
keys_length = rlp.length(item.storage_keys)
|
|
||||||
rlp.write_header(w, address_length + keys_length, rlp.LIST_HEADER_BYTE)
|
|
||||||
rlp.write(w, address_bytes)
|
|
||||||
rlp.write(w, item.storage_keys)
|
|
||||||
|
|
||||||
|
|
||||||
@with_keychain_from_chain_id
|
@with_keychain_from_chain_id
|
||||||
async def sign_tx_eip1559(
|
async def sign_tx_eip1559(
|
||||||
ctx: wire.Context, msg: EthereumSignTxEIP1559, keychain: Keychain
|
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
|
||||||
) -> EthereumTxRequest:
|
) -> EthereumTxRequest:
|
||||||
check(msg)
|
from trezor.crypto.hashlib import sha3_256
|
||||||
|
from trezor.utils import HashWriter
|
||||||
|
from trezor import wire
|
||||||
|
from trezor.crypto import rlp # local_cache_global
|
||||||
|
from apps.common import paths
|
||||||
|
from .layout import (
|
||||||
|
require_confirm_data,
|
||||||
|
require_confirm_eip1559_fee,
|
||||||
|
require_confirm_tx,
|
||||||
|
)
|
||||||
|
from .sign_tx import handle_erc20, send_request_chunk, check_common_fields
|
||||||
|
|
||||||
|
gas_limit = msg.gas_limit # local_cache_attribute
|
||||||
|
|
||||||
|
# check
|
||||||
|
if len(msg.max_gas_fee) + len(gas_limit) > 30:
|
||||||
|
raise wire.DataError("Fee overflow")
|
||||||
|
if len(msg.max_priority_fee) + len(gas_limit) > 30:
|
||||||
|
raise wire.DataError("Fee overflow")
|
||||||
|
check_common_fields(msg)
|
||||||
|
|
||||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||||
|
|
||||||
@ -74,7 +69,7 @@ async def sign_tx_eip1559(
|
|||||||
value,
|
value,
|
||||||
int.from_bytes(msg.max_priority_fee, "big"),
|
int.from_bytes(msg.max_priority_fee, "big"),
|
||||||
int.from_bytes(msg.max_gas_fee, "big"),
|
int.from_bytes(msg.max_gas_fee, "big"),
|
||||||
int.from_bytes(msg.gas_limit, "big"),
|
int.from_bytes(gas_limit, "big"),
|
||||||
msg.chain_id,
|
msg.chain_id,
|
||||||
token,
|
token,
|
||||||
)
|
)
|
||||||
@ -83,7 +78,7 @@ async def sign_tx_eip1559(
|
|||||||
data += msg.data_initial_chunk
|
data += msg.data_initial_chunk
|
||||||
data_left = data_total - len(msg.data_initial_chunk)
|
data_left = data_total - len(msg.data_initial_chunk)
|
||||||
|
|
||||||
total_length = get_total_length(msg, data_total)
|
total_length = _get_total_length(msg, data_total)
|
||||||
|
|
||||||
sha = HashWriter(sha3_256(keccak=True))
|
sha = HashWriter(sha3_256(keccak=True))
|
||||||
|
|
||||||
@ -96,7 +91,7 @@ async def sign_tx_eip1559(
|
|||||||
msg.nonce,
|
msg.nonce,
|
||||||
msg.max_priority_fee,
|
msg.max_priority_fee,
|
||||||
msg.max_gas_fee,
|
msg.max_gas_fee,
|
||||||
msg.gas_limit,
|
gas_limit,
|
||||||
address_bytes,
|
address_bytes,
|
||||||
msg.value,
|
msg.value,
|
||||||
)
|
)
|
||||||
@ -114,15 +109,24 @@ async def sign_tx_eip1559(
|
|||||||
data_left -= len(resp.data_chunk)
|
data_left -= len(resp.data_chunk)
|
||||||
sha.extend(resp.data_chunk)
|
sha.extend(resp.data_chunk)
|
||||||
|
|
||||||
write_access_list(sha, msg.access_list)
|
# write_access_list
|
||||||
|
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
|
||||||
|
rlp.write_header(sha, payload_length, rlp.LIST_HEADER_BYTE)
|
||||||
|
for item in msg.access_list:
|
||||||
|
address_bytes = bytes_from_address(item.address)
|
||||||
|
address_length = rlp.length(address_bytes)
|
||||||
|
keys_length = rlp.length(item.storage_keys)
|
||||||
|
rlp.write_header(sha, address_length + keys_length, rlp.LIST_HEADER_BYTE)
|
||||||
|
rlp.write(sha, address_bytes)
|
||||||
|
rlp.write(sha, item.storage_keys)
|
||||||
|
|
||||||
digest = sha.get_digest()
|
digest = sha.get_digest()
|
||||||
result = sign_digest(msg, keychain, digest)
|
result = _sign_digest(msg, keychain, digest)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
|
def _get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
|
||||||
length = 0
|
length = 0
|
||||||
|
|
||||||
fields: tuple[rlp.RLPItem, ...] = (
|
fields: tuple[rlp.RLPItem, ...] = (
|
||||||
@ -140,14 +144,21 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
|
|||||||
length += rlp.header_length(data_total, msg.data_initial_chunk)
|
length += rlp.header_length(data_total, msg.data_initial_chunk)
|
||||||
length += data_total
|
length += data_total
|
||||||
|
|
||||||
length += access_list_length(msg.access_list)
|
# access_list_length
|
||||||
|
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
|
||||||
|
access_list_length = rlp.header_length(payload_length) + payload_length
|
||||||
|
|
||||||
|
length += access_list_length
|
||||||
|
|
||||||
return length
|
return length
|
||||||
|
|
||||||
|
|
||||||
def sign_digest(
|
def _sign_digest(
|
||||||
msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes
|
msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes
|
||||||
) -> EthereumTxRequest:
|
) -> EthereumTxRequest:
|
||||||
|
from trezor.messages import EthereumTxRequest
|
||||||
|
from trezor.crypto.curve import secp256k1
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
signature = secp256k1.sign(
|
signature = secp256k1.sign(
|
||||||
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
|
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
|
||||||
@ -159,12 +170,3 @@ def sign_digest(
|
|||||||
req.signature_s = signature[33:]
|
req.signature_s = signature[33:]
|
||||||
|
|
||||||
return req
|
return req
|
||||||
|
|
||||||
|
|
||||||
def check(msg: EthereumSignTxEIP1559) -> None:
|
|
||||||
if len(msg.max_gas_fee) + len(msg.gas_limit) > 30:
|
|
||||||
raise wire.DataError("Fee overflow")
|
|
||||||
if len(msg.max_priority_fee) + len(msg.gas_limit) > 30:
|
|
||||||
raise wire.DataError("Fee overflow")
|
|
||||||
|
|
||||||
check_common_fields(msg)
|
|
||||||
|
@ -1,38 +1,24 @@
|
|||||||
from micropython import const
|
from micropython import const
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
from trezor.crypto.curve import secp256k1
|
|
||||||
from trezor.crypto.hashlib import sha3_256
|
|
||||||
from trezor.enums import EthereumDataType
|
from trezor.enums import EthereumDataType
|
||||||
from trezor.messages import (
|
from trezor.wire import DataError
|
||||||
EthereumFieldType,
|
|
||||||
EthereumTypedDataSignature,
|
|
||||||
EthereumTypedDataStructAck,
|
|
||||||
EthereumTypedDataStructRequest,
|
|
||||||
EthereumTypedDataValueAck,
|
|
||||||
EthereumTypedDataValueRequest,
|
|
||||||
)
|
|
||||||
from trezor.utils import HashWriter
|
|
||||||
|
|
||||||
from apps.common import paths
|
from .helpers import get_type_name
|
||||||
|
|
||||||
from .helpers import address_from_bytes, get_type_name
|
|
||||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||||
from .layout import (
|
from .layout import should_show_struct
|
||||||
confirm_empty_typed_message,
|
|
||||||
confirm_typed_data_final,
|
|
||||||
confirm_typed_value,
|
|
||||||
should_show_array,
|
|
||||||
should_show_domain,
|
|
||||||
should_show_struct,
|
|
||||||
)
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from apps.common.keychain import Keychain
|
from apps.common.keychain import Keychain
|
||||||
from trezor.wire import Context
|
from trezor.wire import Context
|
||||||
|
from trezor.utils import HashWriter
|
||||||
|
|
||||||
from trezor.messages import EthereumSignTypedData
|
from trezor.messages import (
|
||||||
|
EthereumSignTypedData,
|
||||||
|
EthereumFieldType,
|
||||||
|
EthereumTypedDataSignature,
|
||||||
|
EthereumTypedDataStructAck,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Maximum data size we support
|
# Maximum data size we support
|
||||||
@ -43,9 +29,14 @@ _MAX_VALUE_BYTE_SIZE = const(1024)
|
|||||||
async def sign_typed_data(
|
async def sign_typed_data(
|
||||||
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
|
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
|
||||||
) -> EthereumTypedDataSignature:
|
) -> EthereumTypedDataSignature:
|
||||||
|
from trezor.crypto.curve import secp256k1
|
||||||
|
from apps.common import paths
|
||||||
|
from .helpers import address_from_bytes
|
||||||
|
from trezor.messages import EthereumTypedDataSignature
|
||||||
|
|
||||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||||
|
|
||||||
data_hash = await generate_typed_data_hash(
|
data_hash = await _generate_typed_data_hash(
|
||||||
ctx, msg.primary_type, msg.metamask_v4_compat
|
ctx, msg.primary_type, msg.metamask_v4_compat
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,7 +51,7 @@ async def sign_typed_data(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def generate_typed_data_hash(
|
async def _generate_typed_data_hash(
|
||||||
ctx: Context, primary_type: str, metamask_v4_compat: bool = True
|
ctx: Context, primary_type: str, metamask_v4_compat: bool = True
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""
|
"""
|
||||||
@ -69,20 +60,26 @@ async def generate_typed_data_hash(
|
|||||||
|
|
||||||
metamask_v4_compat - a flag that enables compatibility with MetaMask's signTypedData_v4 method
|
metamask_v4_compat - a flag that enables compatibility with MetaMask's signTypedData_v4 method
|
||||||
"""
|
"""
|
||||||
|
from .layout import (
|
||||||
|
confirm_empty_typed_message,
|
||||||
|
confirm_typed_data_final,
|
||||||
|
should_show_domain,
|
||||||
|
)
|
||||||
|
|
||||||
typed_data_envelope = TypedDataEnvelope(
|
typed_data_envelope = TypedDataEnvelope(
|
||||||
ctx=ctx,
|
ctx,
|
||||||
primary_type=primary_type,
|
primary_type,
|
||||||
metamask_v4_compat=metamask_v4_compat,
|
metamask_v4_compat,
|
||||||
)
|
)
|
||||||
await typed_data_envelope.collect_types()
|
await typed_data_envelope.collect_types()
|
||||||
|
|
||||||
name, version = await get_name_and_version_for_domain(ctx, typed_data_envelope)
|
name, version = await _get_name_and_version_for_domain(ctx, typed_data_envelope)
|
||||||
show_domain = await should_show_domain(ctx, name, version)
|
show_domain = await should_show_domain(ctx, name, version)
|
||||||
domain_separator = await typed_data_envelope.hash_struct(
|
domain_separator = await typed_data_envelope.hash_struct(
|
||||||
primary_type="EIP712Domain",
|
"EIP712Domain",
|
||||||
member_path=[0],
|
[0],
|
||||||
show_data=show_domain,
|
show_domain,
|
||||||
parent_objects=["EIP712Domain"],
|
["EIP712Domain"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setting the primary_type to "EIP712Domain" is technically in spec
|
# Setting the primary_type to "EIP712Domain" is technically in spec
|
||||||
@ -94,16 +91,16 @@ async def generate_typed_data_hash(
|
|||||||
else:
|
else:
|
||||||
show_message = await should_show_struct(
|
show_message = await should_show_struct(
|
||||||
ctx,
|
ctx,
|
||||||
description=primary_type,
|
primary_type,
|
||||||
data_members=typed_data_envelope.types[primary_type].members,
|
typed_data_envelope.types[primary_type].members,
|
||||||
title="Confirm message",
|
"Confirm message",
|
||||||
button_text="Show full message",
|
"Show full message",
|
||||||
)
|
)
|
||||||
message_hash = await typed_data_envelope.hash_struct(
|
message_hash = await typed_data_envelope.hash_struct(
|
||||||
primary_type=primary_type,
|
primary_type,
|
||||||
member_path=[1],
|
[1],
|
||||||
show_data=show_message,
|
show_message,
|
||||||
parent_objects=[primary_type],
|
[primary_type],
|
||||||
)
|
)
|
||||||
|
|
||||||
await confirm_typed_data_final(ctx)
|
await confirm_typed_data_final(ctx)
|
||||||
@ -112,6 +109,9 @@ async def generate_typed_data_hash(
|
|||||||
|
|
||||||
|
|
||||||
def get_hash_writer() -> HashWriter:
|
def get_hash_writer() -> HashWriter:
|
||||||
|
from trezor.crypto.hashlib import sha3_256
|
||||||
|
from trezor.utils import HashWriter
|
||||||
|
|
||||||
return HashWriter(sha3_256(keccak=True))
|
return HashWriter(sha3_256(keccak=True))
|
||||||
|
|
||||||
|
|
||||||
@ -142,6 +142,11 @@ class TypedDataEnvelope:
|
|||||||
|
|
||||||
async def _collect_types(self, type_name: str) -> None:
|
async def _collect_types(self, type_name: str) -> None:
|
||||||
"""Recursively collect types from the client."""
|
"""Recursively collect types from the client."""
|
||||||
|
from trezor.messages import (
|
||||||
|
EthereumTypedDataStructRequest,
|
||||||
|
EthereumTypedDataStructAck,
|
||||||
|
)
|
||||||
|
|
||||||
req = EthereumTypedDataStructRequest(name=type_name)
|
req = EthereumTypedDataStructRequest(name=type_name)
|
||||||
current_type = await self.ctx.call(req, EthereumTypedDataStructAck)
|
current_type = await self.ctx.call(req, EthereumTypedDataStructAck)
|
||||||
self.types[type_name] = current_type
|
self.types[type_name] = current_type
|
||||||
@ -169,11 +174,11 @@ class TypedDataEnvelope:
|
|||||||
w = get_hash_writer()
|
w = get_hash_writer()
|
||||||
self.hash_type(w, primary_type)
|
self.hash_type(w, primary_type)
|
||||||
await self.get_and_encode_data(
|
await self.get_and_encode_data(
|
||||||
w=w,
|
w,
|
||||||
primary_type=primary_type,
|
primary_type,
|
||||||
member_path=member_path,
|
member_path,
|
||||||
show_data=show_data,
|
show_data,
|
||||||
parent_objects=parent_objects,
|
parent_objects,
|
||||||
)
|
)
|
||||||
return w.get_digest()
|
return w.get_digest()
|
||||||
|
|
||||||
@ -242,6 +247,10 @@ class TypedDataEnvelope:
|
|||||||
i.e. the concatenation of the encoded member values in the order that they appear in the type.
|
i.e. the concatenation of the encoded member values in the order that they appear in the type.
|
||||||
Each encoded member value is exactly 32-byte long.
|
Each encoded member value is exactly 32-byte long.
|
||||||
"""
|
"""
|
||||||
|
from .layout import confirm_typed_value, should_show_array
|
||||||
|
|
||||||
|
ctx = self.ctx # local_cache_attribute
|
||||||
|
|
||||||
type_members = self.types[primary_type].members
|
type_members = self.types[primary_type].members
|
||||||
member_value_path = member_path + [0]
|
member_value_path = member_path + [0]
|
||||||
current_parent_objects = parent_objects + [""]
|
current_parent_objects = parent_objects + [""]
|
||||||
@ -258,25 +267,25 @@ class TypedDataEnvelope:
|
|||||||
|
|
||||||
if show_data:
|
if show_data:
|
||||||
show_struct = await should_show_struct(
|
show_struct = await should_show_struct(
|
||||||
ctx=self.ctx,
|
ctx,
|
||||||
description=struct_name,
|
struct_name, # description
|
||||||
data_members=self.types[struct_name].members,
|
self.types[struct_name].members, # data_members
|
||||||
title=".".join(current_parent_objects),
|
".".join(current_parent_objects), # title
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
show_struct = False
|
show_struct = False
|
||||||
|
|
||||||
res = await self.hash_struct(
|
res = await self.hash_struct(
|
||||||
primary_type=struct_name,
|
struct_name,
|
||||||
member_path=member_value_path,
|
member_value_path,
|
||||||
show_data=show_struct,
|
show_struct,
|
||||||
parent_objects=current_parent_objects,
|
current_parent_objects,
|
||||||
)
|
)
|
||||||
w.extend(res)
|
w.extend(res)
|
||||||
elif field_type.data_type == EthereumDataType.ARRAY:
|
elif field_type.data_type == EthereumDataType.ARRAY:
|
||||||
# Getting the length of the array first, if not fixed
|
# Getting the length of the array first, if not fixed
|
||||||
if field_type.size is None:
|
if field_type.size is None:
|
||||||
array_size = await get_array_size(self.ctx, member_value_path)
|
array_size = await _get_array_size(ctx, member_value_path)
|
||||||
else:
|
else:
|
||||||
array_size = field_type.size
|
array_size = field_type.size
|
||||||
|
|
||||||
@ -286,10 +295,10 @@ class TypedDataEnvelope:
|
|||||||
|
|
||||||
if show_data:
|
if show_data:
|
||||||
show_array = await should_show_array(
|
show_array = await should_show_array(
|
||||||
ctx=self.ctx,
|
ctx,
|
||||||
parent_objects=current_parent_objects,
|
current_parent_objects,
|
||||||
data_type=get_type_name(entry_type),
|
get_type_name(entry_type),
|
||||||
size=array_size,
|
array_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
show_array = False
|
show_array = False
|
||||||
@ -309,43 +318,43 @@ class TypedDataEnvelope:
|
|||||||
# Metamask V4 is using hash_struct() even in this case
|
# Metamask V4 is using hash_struct() even in this case
|
||||||
if self.metamask_v4_compat:
|
if self.metamask_v4_compat:
|
||||||
res = await self.hash_struct(
|
res = await self.hash_struct(
|
||||||
primary_type=struct_name,
|
struct_name,
|
||||||
member_path=el_member_path,
|
el_member_path,
|
||||||
show_data=show_array,
|
show_array,
|
||||||
parent_objects=current_parent_objects,
|
current_parent_objects,
|
||||||
)
|
)
|
||||||
arr_w.extend(res)
|
arr_w.extend(res)
|
||||||
else:
|
else:
|
||||||
await self.get_and_encode_data(
|
await self.get_and_encode_data(
|
||||||
w=arr_w,
|
arr_w,
|
||||||
primary_type=struct_name,
|
struct_name,
|
||||||
member_path=el_member_path,
|
el_member_path,
|
||||||
show_data=show_array,
|
show_array,
|
||||||
parent_objects=current_parent_objects,
|
current_parent_objects,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
value = await get_value(self.ctx, entry_type, el_member_path)
|
value = await get_value(ctx, entry_type, el_member_path)
|
||||||
encode_field(arr_w, entry_type, value)
|
encode_field(arr_w, entry_type, value)
|
||||||
if show_array:
|
if show_array:
|
||||||
await confirm_typed_value(
|
await confirm_typed_value(
|
||||||
ctx=self.ctx,
|
ctx,
|
||||||
name=field_name,
|
field_name,
|
||||||
value=value,
|
value,
|
||||||
parent_objects=parent_objects,
|
parent_objects,
|
||||||
field=entry_type,
|
entry_type,
|
||||||
array_index=i,
|
i,
|
||||||
)
|
)
|
||||||
w.extend(arr_w.get_digest())
|
w.extend(arr_w.get_digest())
|
||||||
else:
|
else:
|
||||||
value = await get_value(self.ctx, field_type, member_value_path)
|
value = await get_value(ctx, field_type, member_value_path)
|
||||||
encode_field(w, field_type, value)
|
encode_field(w, field_type, value)
|
||||||
if show_data:
|
if show_data:
|
||||||
await confirm_typed_value(
|
await confirm_typed_value(
|
||||||
ctx=self.ctx,
|
ctx,
|
||||||
name=field_name,
|
field_name,
|
||||||
value=value,
|
value,
|
||||||
parent_objects=parent_objects,
|
parent_objects,
|
||||||
field=field_type,
|
field_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -370,21 +379,27 @@ def encode_field(
|
|||||||
encodeData of their contents
|
encodeData of their contents
|
||||||
- Struct values are encoded recursively as hashStruct(value)
|
- Struct values are encoded recursively as hashStruct(value)
|
||||||
"""
|
"""
|
||||||
|
EDT = EthereumDataType # local_cache_global
|
||||||
|
|
||||||
data_type = field.data_type
|
data_type = field.data_type
|
||||||
|
|
||||||
if data_type == EthereumDataType.BYTES:
|
if data_type == EDT.BYTES:
|
||||||
if field.size is None:
|
if field.size is None:
|
||||||
w.extend(keccak256(value))
|
w.extend(keccak256(value))
|
||||||
else:
|
else:
|
||||||
write_rightpad32(w, value)
|
# write_rightpad32
|
||||||
elif data_type == EthereumDataType.STRING:
|
assert len(value) <= 32
|
||||||
|
w.extend(value)
|
||||||
|
for _ in range(32 - len(value)):
|
||||||
|
w.append(0x00)
|
||||||
|
elif data_type == EDT.STRING:
|
||||||
w.extend(keccak256(value))
|
w.extend(keccak256(value))
|
||||||
elif data_type == EthereumDataType.INT:
|
elif data_type == EDT.INT:
|
||||||
write_leftpad32(w, value, signed=True)
|
write_leftpad32(w, value, signed=True)
|
||||||
elif data_type in (
|
elif data_type in (
|
||||||
EthereumDataType.UINT,
|
EDT.UINT,
|
||||||
EthereumDataType.BOOL,
|
EDT.BOOL,
|
||||||
EthereumDataType.ADDRESS,
|
EDT.ADDRESS,
|
||||||
):
|
):
|
||||||
write_leftpad32(w, value)
|
write_leftpad32(w, value)
|
||||||
else:
|
else:
|
||||||
@ -405,93 +420,90 @@ def write_leftpad32(w: HashWriter, value: bytes, signed: bool = False) -> None:
|
|||||||
w.extend(value)
|
w.extend(value)
|
||||||
|
|
||||||
|
|
||||||
def write_rightpad32(w: HashWriter, value: bytes) -> None:
|
def _validate_value(field: EthereumFieldType, value: bytes) -> None:
|
||||||
assert len(value) <= 32
|
|
||||||
|
|
||||||
w.extend(value)
|
|
||||||
for _ in range(32 - len(value)):
|
|
||||||
w.append(0x00)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_value(field: EthereumFieldType, value: bytes) -> None:
|
|
||||||
"""
|
"""
|
||||||
Make sure the byte data we receive are not corrupted or incorrect.
|
Make sure the byte data we receive are not corrupted or incorrect.
|
||||||
|
|
||||||
Raise wire.DataError if encountering a problem, so clients are notified.
|
Raise DataError if encountering a problem, so clients are notified.
|
||||||
"""
|
"""
|
||||||
# Checking if the size corresponds to what is defined in types,
|
# Checking if the size corresponds to what is defined in types,
|
||||||
# and also setting our maximum supported size in bytes
|
# and also setting our maximum supported size in bytes
|
||||||
if field.size is not None:
|
if field.size is not None:
|
||||||
if len(value) != field.size:
|
if len(value) != field.size:
|
||||||
raise wire.DataError("Invalid length")
|
raise DataError("Invalid length")
|
||||||
else:
|
else:
|
||||||
if len(value) > _MAX_VALUE_BYTE_SIZE:
|
if len(value) > _MAX_VALUE_BYTE_SIZE:
|
||||||
raise wire.DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
|
raise DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
|
||||||
|
|
||||||
# Specific tests for some data types
|
# Specific tests for some data types
|
||||||
if field.data_type == EthereumDataType.BOOL:
|
if field.data_type == EthereumDataType.BOOL:
|
||||||
if value not in (b"\x00", b"\x01"):
|
if value not in (b"\x00", b"\x01"):
|
||||||
raise wire.DataError("Invalid boolean value")
|
raise DataError("Invalid boolean value")
|
||||||
elif field.data_type == EthereumDataType.ADDRESS:
|
elif field.data_type == EthereumDataType.ADDRESS:
|
||||||
if len(value) != 20:
|
if len(value) != 20:
|
||||||
raise wire.DataError("Invalid address")
|
raise DataError("Invalid address")
|
||||||
elif field.data_type == EthereumDataType.STRING:
|
elif field.data_type == EthereumDataType.STRING:
|
||||||
try:
|
try:
|
||||||
value.decode()
|
value.decode()
|
||||||
except UnicodeError:
|
except UnicodeError:
|
||||||
raise wire.DataError("Invalid UTF-8")
|
raise DataError("Invalid UTF-8")
|
||||||
|
|
||||||
|
|
||||||
def validate_field_type(field: EthereumFieldType) -> None:
|
def validate_field_type(field: EthereumFieldType) -> None:
|
||||||
"""
|
"""
|
||||||
Make sure the field type is consistent with our expectation.
|
Make sure the field type is consistent with our expectation.
|
||||||
|
|
||||||
Raise wire.DataError if encountering a problem, so clients are notified.
|
Raise DataError if encountering a problem, so clients are notified.
|
||||||
"""
|
"""
|
||||||
|
EDT = EthereumDataType # local_cache_global
|
||||||
|
|
||||||
data_type = field.data_type
|
data_type = field.data_type
|
||||||
|
|
||||||
# entry_type is only for arrays
|
# entry_type is only for arrays
|
||||||
if data_type == EthereumDataType.ARRAY:
|
if data_type == EDT.ARRAY:
|
||||||
if field.entry_type is None:
|
if field.entry_type is None:
|
||||||
raise wire.DataError("Missing entry_type in array")
|
raise DataError("Missing entry_type in array")
|
||||||
# We also need to validate it recursively
|
# We also need to validate it recursively
|
||||||
validate_field_type(field.entry_type)
|
validate_field_type(field.entry_type)
|
||||||
else:
|
else:
|
||||||
if field.entry_type is not None:
|
if field.entry_type is not None:
|
||||||
raise wire.DataError("Unexpected entry_type in nonarray")
|
raise DataError("Unexpected entry_type in nonarray")
|
||||||
|
|
||||||
# struct_name is only for structs
|
# struct_name is only for structs
|
||||||
if data_type == EthereumDataType.STRUCT:
|
if data_type == EDT.STRUCT:
|
||||||
if field.struct_name is None:
|
if field.struct_name is None:
|
||||||
raise wire.DataError("Missing struct_name in struct")
|
raise DataError("Missing struct_name in struct")
|
||||||
else:
|
else:
|
||||||
if field.struct_name is not None:
|
if field.struct_name is not None:
|
||||||
raise wire.DataError("Unexpected struct_name in nonstruct")
|
raise DataError("Unexpected struct_name in nonstruct")
|
||||||
|
size = field.size # local_cache_attribute
|
||||||
|
|
||||||
# size is special for each type
|
# size is special for each type
|
||||||
if data_type == EthereumDataType.STRUCT:
|
if data_type == EDT.STRUCT:
|
||||||
if field.size is None:
|
if size is None:
|
||||||
raise wire.DataError("Missing size in struct")
|
raise DataError("Missing size in struct")
|
||||||
elif data_type == EthereumDataType.BYTES:
|
elif data_type == EDT.BYTES:
|
||||||
if field.size is not None and not 1 <= field.size <= 32:
|
if size is not None and not 1 <= size <= 32:
|
||||||
raise wire.DataError("Invalid size in bytes")
|
raise DataError("Invalid size in bytes")
|
||||||
elif data_type in (
|
elif data_type in (
|
||||||
EthereumDataType.UINT,
|
EDT.UINT,
|
||||||
EthereumDataType.INT,
|
EDT.INT,
|
||||||
):
|
):
|
||||||
if field.size is None or not 1 <= field.size <= 32:
|
if size is None or not 1 <= size <= 32:
|
||||||
raise wire.DataError("Invalid size in int/uint")
|
raise DataError("Invalid size in int/uint")
|
||||||
elif data_type in (
|
elif data_type in (
|
||||||
EthereumDataType.STRING,
|
EDT.STRING,
|
||||||
EthereumDataType.BOOL,
|
EDT.BOOL,
|
||||||
EthereumDataType.ADDRESS,
|
EDT.ADDRESS,
|
||||||
):
|
):
|
||||||
if field.size is not None:
|
if size is not None:
|
||||||
raise wire.DataError("Unexpected size in str/bool/addr")
|
raise DataError("Unexpected size in str/bool/addr")
|
||||||
|
|
||||||
|
|
||||||
async def get_array_size(ctx: Context, member_path: list[int]) -> int:
|
async def _get_array_size(ctx: Context, member_path: list[int]) -> int:
|
||||||
"""Get the length of an array at specific `member_path` from the client."""
|
"""Get the length of an array at specific `member_path` from the client."""
|
||||||
|
from trezor.messages import EthereumFieldType
|
||||||
|
|
||||||
# Field type for getting the array length from client, so we can check the return value
|
# Field type for getting the array length from client, so we can check the return value
|
||||||
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
|
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
|
||||||
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path)
|
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path)
|
||||||
@ -504,18 +516,20 @@ async def get_value(
|
|||||||
member_value_path: list[int],
|
member_value_path: list[int],
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
"""Get a single value from the client and perform its validation."""
|
"""Get a single value from the client and perform its validation."""
|
||||||
|
from trezor.messages import EthereumTypedDataValueAck, EthereumTypedDataValueRequest
|
||||||
|
|
||||||
req = EthereumTypedDataValueRequest(
|
req = EthereumTypedDataValueRequest(
|
||||||
member_path=member_value_path,
|
member_path=member_value_path,
|
||||||
)
|
)
|
||||||
res = await ctx.call(req, EthereumTypedDataValueAck)
|
res = await ctx.call(req, EthereumTypedDataValueAck)
|
||||||
value = res.value
|
value = res.value
|
||||||
|
|
||||||
validate_value(field=field, value=value)
|
_validate_value(field=field, value=value)
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
async def get_name_and_version_for_domain(
|
async def _get_name_and_version_for_domain(
|
||||||
ctx: Context, typed_data_envelope: TypedDataEnvelope
|
ctx: Context, typed_data_envelope: TypedDataEnvelope
|
||||||
) -> tuple[bytes, bytes]:
|
) -> tuple[bytes, bytes]:
|
||||||
domain_name = b"unknown"
|
domain_name = b"unknown"
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -2,6 +2,20 @@
|
|||||||
# (by running `make templates` in `core`)
|
# (by running `make templates` in `core`)
|
||||||
# do not edit manually!
|
# do not edit manually!
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
|
||||||
|
# NOTE: returning a tuple instead of `TokenInfo` from the "data" function
|
||||||
|
# saves 5600 bytes of flash size. Implementing the `_token_iterator`
|
||||||
|
# instead of if-tree approach saves another 5600 bytes.
|
||||||
|
|
||||||
|
# NOTE: interestingly, it did not save much flash size to use smaller
|
||||||
|
# parts of the address, for example address length of 10 bytes saves
|
||||||
|
# 1 byte per entry, so 1887 bytes overall (and further decrease does not help).
|
||||||
|
# (The idea was not having to store the whole address, even a smaller part
|
||||||
|
# of it has enough collision-resistance.)
|
||||||
|
# (In the if-tree approach the address length did not have any effect whatsoever.)
|
||||||
|
|
||||||
|
from typing import Iterator
|
||||||
|
|
||||||
<%
|
<%
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
@ -22,11 +36,20 @@ UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
|
|||||||
|
|
||||||
|
|
||||||
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
|
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
|
||||||
|
for addr, symbol, decimal in _token_iterator(chain_id):
|
||||||
|
if address == addr:
|
||||||
|
return TokenInfo(symbol, decimal)
|
||||||
|
return UNKNOWN_TOKEN
|
||||||
|
|
||||||
|
|
||||||
|
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]:
|
||||||
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
|
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
|
||||||
if chain_id == ${token_chain_id}:
|
if chain_id == ${token_chain_id}:
|
||||||
% for t in tokens:
|
% for t in tokens:
|
||||||
if address == ${black_repr(t.address_bytes)}:
|
yield ( # address, symbol, decimals
|
||||||
return TokenInfo(${black_repr(t.symbol)}, ${t.decimals}) # ${t.chain} / ${t.name.strip()}
|
${black_repr(t.address_bytes)},
|
||||||
|
${black_repr(t.symbol)},
|
||||||
|
${t.decimals},
|
||||||
|
)
|
||||||
% endfor
|
% endfor
|
||||||
% endfor
|
% endfor
|
||||||
return UNKNOWN_TOKEN
|
|
||||||
|
@ -1,42 +1,42 @@
|
|||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from trezor import wire
|
|
||||||
from trezor.crypto.curve import secp256k1
|
|
||||||
from trezor.crypto.hashlib import sha3_256
|
|
||||||
from trezor.messages import Success
|
|
||||||
from trezor.ui.layouts import confirm_signverify, show_success
|
|
||||||
|
|
||||||
from apps.common.signverify import decode_message
|
|
||||||
|
|
||||||
from .helpers import address_from_bytes, bytes_from_address
|
|
||||||
from .sign_message import message_digest
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from trezor.messages import EthereumVerifyMessage
|
from trezor.messages import EthereumVerifyMessage, Success
|
||||||
from trezor.wire import Context
|
from trezor.wire import Context
|
||||||
|
|
||||||
|
|
||||||
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
|
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
|
||||||
|
from trezor.wire import DataError
|
||||||
|
from trezor.crypto.curve import secp256k1
|
||||||
|
from trezor.crypto.hashlib import sha3_256
|
||||||
|
from trezor.messages import Success
|
||||||
|
from trezor.ui.layouts import confirm_signverify, show_success
|
||||||
|
|
||||||
|
from apps.common.signverify import decode_message
|
||||||
|
|
||||||
|
from .helpers import address_from_bytes, bytes_from_address
|
||||||
|
from .sign_message import message_digest
|
||||||
|
|
||||||
digest = message_digest(msg.message)
|
digest = message_digest(msg.message)
|
||||||
if len(msg.signature) != 65:
|
if len(msg.signature) != 65:
|
||||||
raise wire.DataError("Invalid signature")
|
raise DataError("Invalid signature")
|
||||||
sig = bytearray([msg.signature[64]]) + msg.signature[:64]
|
sig = bytearray([msg.signature[64]]) + msg.signature[:64]
|
||||||
|
|
||||||
pubkey = secp256k1.verify_recover(sig, digest)
|
pubkey = secp256k1.verify_recover(sig, digest)
|
||||||
|
|
||||||
if not pubkey:
|
if not pubkey:
|
||||||
raise wire.DataError("Invalid signature")
|
raise DataError("Invalid signature")
|
||||||
|
|
||||||
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
|
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
|
||||||
|
|
||||||
address_bytes = bytes_from_address(msg.address)
|
address_bytes = bytes_from_address(msg.address)
|
||||||
if address_bytes != pkh:
|
if address_bytes != pkh:
|
||||||
raise wire.DataError("Invalid signature")
|
raise DataError("Invalid signature")
|
||||||
|
|
||||||
address = address_from_bytes(address_bytes)
|
address = address_from_bytes(address_bytes)
|
||||||
|
|
||||||
await confirm_signverify(
|
await confirm_signverify(
|
||||||
ctx, "ETH", decode_message(msg.message), address=address, verify=True
|
ctx, "ETH", decode_message(msg.message), address, verify=True
|
||||||
)
|
)
|
||||||
|
|
||||||
await show_success(ctx, "verify_message", "The signature is valid.")
|
await show_success(ctx, "verify_message", "The signature is valid.")
|
||||||
|
@ -11,7 +11,7 @@ from trezor.enums import EthereumDataType as EDT
|
|||||||
if not utils.BITCOIN_ONLY:
|
if not utils.BITCOIN_ONLY:
|
||||||
from apps.ethereum.sign_typed_data import (
|
from apps.ethereum.sign_typed_data import (
|
||||||
encode_field,
|
encode_field,
|
||||||
validate_value,
|
_validate_value,
|
||||||
validate_field_type,
|
validate_field_type,
|
||||||
keccak256,
|
keccak256,
|
||||||
TypedDataEnvelope,
|
TypedDataEnvelope,
|
||||||
@ -594,10 +594,10 @@ class TestEthereumSignTypedData(unittest.TestCase):
|
|||||||
|
|
||||||
for field, valid_values, invalid_values in VECTORS_VALID_INVALID:
|
for field, valid_values, invalid_values in VECTORS_VALID_INVALID:
|
||||||
for valid_value in valid_values:
|
for valid_value in valid_values:
|
||||||
validate_value(field=field, value=valid_value)
|
_validate_value(field=field, value=valid_value)
|
||||||
for invalid_value in invalid_values:
|
for invalid_value in invalid_values:
|
||||||
with self.assertRaises(wire.DataError):
|
with self.assertRaises(wire.DataError):
|
||||||
validate_value(field=field, value=invalid_value)
|
_validate_value(field=field, value=invalid_value)
|
||||||
|
|
||||||
def test_validate_field_type(self):
|
def test_validate_field_type(self):
|
||||||
ET = EFT(data_type=EDT.BYTES, size=8)
|
ET = EFT(data_type=EDT.BYTES, size=8)
|
||||||
|
Loading…
Reference in New Issue
Block a user