mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 06:18:07 +00:00
chore(core): decrease ethereum size by 17250 bytes
This commit is contained in:
parent
0c3423b1c7
commit
26fd0de198
@ -1,16 +1,9 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.messages import EthereumAddress
|
||||
from trezor.ui.layouts import show_address
|
||||
|
||||
from apps.common import paths
|
||||
|
||||
from . import networks
|
||||
from .helpers import address_from_bytes
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumGetAddress
|
||||
from trezor.messages import EthereumGetAddress, EthereumAddress
|
||||
from trezor.wire import Context
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
@ -20,18 +13,26 @@ if TYPE_CHECKING:
|
||||
async def get_address(
|
||||
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
|
||||
) -> EthereumAddress:
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
from trezor.messages import EthereumAddress
|
||||
from trezor.ui.layouts import show_address
|
||||
from apps.common import paths
|
||||
from . import networks
|
||||
from .helpers import address_from_bytes
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
address_n = msg.address_n # local_cache_attribute
|
||||
|
||||
if len(msg.address_n) > 1: # path has slip44 network identifier
|
||||
network = networks.by_slip44(msg.address_n[1] & 0x7FFF_FFFF)
|
||||
await paths.validate_path(ctx, keychain, address_n)
|
||||
|
||||
node = keychain.derive(address_n)
|
||||
|
||||
if len(address_n) > 1: # path has slip44 network identifier
|
||||
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
|
||||
else:
|
||||
network = None
|
||||
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
|
||||
|
||||
if msg.show_display:
|
||||
title = paths.address_n_to_str(msg.address_n)
|
||||
await show_address(ctx, address=address, title=title)
|
||||
title = paths.address_n_to_str(address_n)
|
||||
await show_address(ctx, address, title=title)
|
||||
|
||||
return EthereumAddress(address=address)
|
||||
|
@ -1,15 +1,11 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
from trezor.messages import EthereumPublicKey, HDNodeType
|
||||
from trezor.ui.layouts import show_pubkey
|
||||
|
||||
from apps.common import coins, paths
|
||||
from apps.common import paths
|
||||
|
||||
from .keychain import with_keychain_from_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumGetPublicKey
|
||||
from trezor.messages import EthereumGetPublicKey, EthereumPublicKey
|
||||
from trezor.wire import Context
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
@ -19,6 +15,11 @@ if TYPE_CHECKING:
|
||||
async def get_public_key(
|
||||
ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain
|
||||
) -> EthereumPublicKey:
|
||||
from ubinascii import hexlify
|
||||
from trezor.messages import EthereumPublicKey, HDNodeType
|
||||
from trezor.ui.layouts import show_pubkey
|
||||
from apps.common import coins
|
||||
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
node = keychain.derive(msg.address_n)
|
||||
|
||||
|
@ -1,8 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify, unhexlify
|
||||
|
||||
from trezor import wire
|
||||
from trezor.enums import EthereumDataType
|
||||
from ubinascii import hexlify
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumFieldType
|
||||
@ -24,7 +21,7 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
|
||||
address_hex = hexlify(address_bytes).decode()
|
||||
digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest()
|
||||
|
||||
def maybe_upper(i: int) -> str:
|
||||
def _maybe_upper(i: int) -> str:
|
||||
"""Uppercase i-th letter only if the corresponding nibble has high bit set."""
|
||||
digest_byte = digest[i // 2]
|
||||
hex_letter = address_hex[i]
|
||||
@ -39,10 +36,13 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
|
||||
else:
|
||||
return hex_letter
|
||||
|
||||
return "0x" + "".join(maybe_upper(i) for i in range(len(address_hex)))
|
||||
return "0x" + "".join(_maybe_upper(i) for i in range(len(address_hex)))
|
||||
|
||||
|
||||
def bytes_from_address(address: str) -> bytes:
|
||||
from ubinascii import unhexlify
|
||||
from trezor import wire
|
||||
|
||||
if len(address) == 40:
|
||||
return unhexlify(address)
|
||||
|
||||
@ -59,6 +59,8 @@ def bytes_from_address(address: str) -> bytes:
|
||||
|
||||
def get_type_name(field: EthereumFieldType) -> str:
|
||||
"""Create a string from type definition (like uint256 or bytes16)."""
|
||||
from trezor.enums import EthereumDataType
|
||||
|
||||
data_type = field.data_type
|
||||
size = field.size
|
||||
|
||||
@ -109,12 +111,12 @@ def decode_typed_data(data: bytes, type_name: str) -> str:
|
||||
return str(int.from_bytes(data, "big"))
|
||||
elif type_name.startswith("int"):
|
||||
# Micropython does not implement "signed" arg in int.from_bytes()
|
||||
return str(from_bytes_bigendian_signed(data))
|
||||
return str(_from_bytes_bigendian_signed(data))
|
||||
|
||||
raise ValueError # Unsupported data type for direct field decoding
|
||||
|
||||
|
||||
def from_bytes_bigendian_signed(b: bytes) -> int:
|
||||
def _from_bytes_bigendian_signed(b: bytes) -> int:
|
||||
negative = b[0] & 0x80
|
||||
if negative:
|
||||
neg_b = bytearray(b)
|
||||
|
@ -1,7 +1,5 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import wire
|
||||
|
||||
from apps.common import paths
|
||||
from apps.common.keychain import get_keychain
|
||||
|
||||
@ -10,6 +8,8 @@ from . import CURVE, networks
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, Iterable, TypeVar
|
||||
|
||||
from trezor.wire import Context
|
||||
|
||||
from trezor.messages import (
|
||||
EthereumGetAddress,
|
||||
EthereumGetPublicKey,
|
||||
@ -64,7 +64,7 @@ def with_keychain_from_path(
|
||||
*patterns: str,
|
||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
|
||||
schemas = _schemas_from_address_n(patterns, msg.address_n)
|
||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||
with keychain:
|
||||
@ -97,7 +97,7 @@ def with_keychain_from_chain_id(
|
||||
func: HandlerWithKeychain[MsgInChainId, MsgOut]
|
||||
) -> Handler[MsgInChainId, MsgOut]:
|
||||
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
|
||||
async def wrapper(ctx: wire.Context, msg: MsgInChainId) -> MsgOut:
|
||||
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
|
||||
schemas = _schemas_from_chain_id(msg)
|
||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||
with keychain:
|
||||
|
@ -1,29 +1,19 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
from trezor import ui
|
||||
from trezor.enums import ButtonRequestType, EthereumDataType
|
||||
from trezor.strings import format_amount, format_plural
|
||||
from trezor.ui.layouts import (
|
||||
confirm_action,
|
||||
confirm_address,
|
||||
confirm_amount,
|
||||
confirm_blob,
|
||||
confirm_output,
|
||||
confirm_text,
|
||||
confirm_total,
|
||||
should_show_more,
|
||||
)
|
||||
from trezor.ui.layouts.altcoin import confirm_total_ethereum
|
||||
from trezor.enums import ButtonRequestType
|
||||
from trezor.strings import format_plural
|
||||
from trezor.ui.layouts import confirm_blob, confirm_text, should_show_more
|
||||
|
||||
from . import networks, tokens
|
||||
from .helpers import address_from_bytes, decode_typed_data, get_type_name
|
||||
from . import networks
|
||||
from .helpers import decode_typed_data
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable, Iterable
|
||||
|
||||
from trezor.messages import EthereumFieldType, EthereumStructMember
|
||||
from trezor.wire import Context
|
||||
from . import tokens
|
||||
|
||||
|
||||
def require_confirm_tx(
|
||||
@ -33,15 +23,18 @@ def require_confirm_tx(
|
||||
chain_id: int,
|
||||
token: tokens.TokenInfo | None = None,
|
||||
) -> Awaitable[None]:
|
||||
from .helpers import address_from_bytes
|
||||
from trezor.ui.layouts import confirm_output
|
||||
|
||||
if to_bytes:
|
||||
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
|
||||
else:
|
||||
to_str = "new contract?"
|
||||
return confirm_output(
|
||||
ctx,
|
||||
address=to_str,
|
||||
amount=format_ethereum_amount(value, token, chain_id),
|
||||
font_amount=ui.BOLD,
|
||||
to_str,
|
||||
format_ethereum_amount(value, token, chain_id),
|
||||
ui.BOLD,
|
||||
color_to=ui.GREY,
|
||||
br_code=ButtonRequestType.SignTx,
|
||||
)
|
||||
@ -55,6 +48,8 @@ def require_confirm_fee(
|
||||
chain_id: int,
|
||||
token: tokens.TokenInfo | None = None,
|
||||
) -> Awaitable[None]:
|
||||
from trezor.ui.layouts.altcoin import confirm_total_ethereum
|
||||
|
||||
return confirm_total_ethereum(
|
||||
ctx,
|
||||
format_ethereum_amount(spending, token, chain_id),
|
||||
@ -72,22 +67,24 @@ async def require_confirm_eip1559_fee(
|
||||
chain_id: int,
|
||||
token: tokens.TokenInfo | None = None,
|
||||
) -> None:
|
||||
from trezor.ui.layouts import confirm_amount, confirm_total
|
||||
|
||||
await confirm_amount(
|
||||
ctx,
|
||||
title="Confirm fee",
|
||||
description="Maximum fee per gas",
|
||||
amount=format_ethereum_amount(max_gas_fee, None, chain_id),
|
||||
"Confirm fee",
|
||||
format_ethereum_amount(max_gas_fee, None, chain_id),
|
||||
"Maximum fee per gas",
|
||||
)
|
||||
await confirm_amount(
|
||||
ctx,
|
||||
title="Confirm fee",
|
||||
description="Priority fee per gas",
|
||||
amount=format_ethereum_amount(max_priority_fee, None, chain_id),
|
||||
"Confirm fee",
|
||||
format_ethereum_amount(max_priority_fee, None, chain_id),
|
||||
"Priority fee per gas",
|
||||
)
|
||||
await confirm_total(
|
||||
ctx,
|
||||
total_amount=format_ethereum_amount(spending, token, chain_id),
|
||||
fee_amount=format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
|
||||
format_ethereum_amount(spending, token, chain_id),
|
||||
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
|
||||
total_label="Amount sent:\n",
|
||||
fee_label="\nMaximum fee:\n",
|
||||
)
|
||||
@ -96,13 +93,16 @@ async def require_confirm_eip1559_fee(
|
||||
def require_confirm_unknown_token(
|
||||
ctx: Context, address_bytes: bytes
|
||||
) -> Awaitable[None]:
|
||||
from ubinascii import hexlify
|
||||
from trezor.ui.layouts import confirm_address
|
||||
|
||||
contract_address_hex = "0x" + hexlify(address_bytes).decode()
|
||||
return confirm_address(
|
||||
ctx,
|
||||
"Unknown token",
|
||||
contract_address_hex,
|
||||
description="Contract:",
|
||||
br_type="unknown_token",
|
||||
"Contract:",
|
||||
"unknown_token",
|
||||
icon_color=ui.ORANGE,
|
||||
br_code=ButtonRequestType.SignTx,
|
||||
)
|
||||
@ -112,20 +112,22 @@ def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitabl
|
||||
return confirm_blob(
|
||||
ctx,
|
||||
"confirm_data",
|
||||
title="Confirm data",
|
||||
description=f"Size: {data_total} bytes",
|
||||
data=data,
|
||||
"Confirm data",
|
||||
data,
|
||||
f"Size: {data_total} bytes",
|
||||
br_code=ButtonRequestType.SignTx,
|
||||
ask_pagination=True,
|
||||
)
|
||||
|
||||
|
||||
async def confirm_typed_data_final(ctx: Context) -> None:
|
||||
from trezor.ui.layouts import confirm_action
|
||||
|
||||
await confirm_action(
|
||||
ctx,
|
||||
"confirm_typed_data_final",
|
||||
title="Confirm typed data",
|
||||
action="Really sign EIP-712 typed data?",
|
||||
"Confirm typed data",
|
||||
"Really sign EIP-712 typed data?",
|
||||
verb="Hold to confirm",
|
||||
hold=True,
|
||||
)
|
||||
@ -135,9 +137,9 @@ def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
|
||||
return confirm_text(
|
||||
ctx,
|
||||
"confirm_empty_typed_message",
|
||||
title="Confirm message",
|
||||
data="",
|
||||
description="No message field",
|
||||
"Confirm message",
|
||||
"",
|
||||
"No message field",
|
||||
)
|
||||
|
||||
|
||||
@ -152,10 +154,10 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
|
||||
)
|
||||
return await should_show_more(
|
||||
ctx,
|
||||
title="Confirm domain",
|
||||
para=para,
|
||||
button_text="Show full domain",
|
||||
br_type="should_show_domain",
|
||||
"Confirm domain",
|
||||
para,
|
||||
"Show full domain",
|
||||
"should_show_domain",
|
||||
)
|
||||
|
||||
|
||||
@ -176,10 +178,10 @@ async def should_show_struct(
|
||||
)
|
||||
return await should_show_more(
|
||||
ctx,
|
||||
title=title,
|
||||
para=para,
|
||||
button_text=button_text,
|
||||
br_type="should_show_struct",
|
||||
title,
|
||||
para,
|
||||
button_text,
|
||||
"should_show_struct",
|
||||
)
|
||||
|
||||
|
||||
@ -192,10 +194,10 @@ async def should_show_array(
|
||||
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
|
||||
return await should_show_more(
|
||||
ctx,
|
||||
title=limit_str(".".join(parent_objects)),
|
||||
para=para,
|
||||
button_text="Show full array",
|
||||
br_type="should_show_array",
|
||||
limit_str(".".join(parent_objects)),
|
||||
para,
|
||||
"Show full array",
|
||||
"should_show_array",
|
||||
)
|
||||
|
||||
|
||||
@ -207,6 +209,9 @@ async def confirm_typed_value(
|
||||
field: EthereumFieldType,
|
||||
array_index: int | None = None,
|
||||
) -> None:
|
||||
from trezor.enums import EthereumDataType
|
||||
from .helpers import get_type_name
|
||||
|
||||
type_name = get_type_name(field)
|
||||
|
||||
if array_index is not None:
|
||||
@ -222,24 +227,26 @@ async def confirm_typed_value(
|
||||
await confirm_blob(
|
||||
ctx,
|
||||
"confirm_typed_value",
|
||||
title=title,
|
||||
data=data,
|
||||
description=description,
|
||||
title,
|
||||
data,
|
||||
description,
|
||||
ask_pagination=True,
|
||||
)
|
||||
else:
|
||||
await confirm_text(
|
||||
ctx,
|
||||
"confirm_typed_value",
|
||||
title=title,
|
||||
data=data,
|
||||
description=description,
|
||||
title,
|
||||
data,
|
||||
description,
|
||||
)
|
||||
|
||||
|
||||
def format_ethereum_amount(
|
||||
value: int, token: tokens.TokenInfo | None, chain_id: int
|
||||
) -> str:
|
||||
from trezor.strings import format_amount
|
||||
|
||||
if token:
|
||||
suffix = token.symbol
|
||||
decimals = token.decimals
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,10 +1,28 @@
|
||||
# generated from networks.py.mako
|
||||
# (by running `make templates` in `core`)
|
||||
# do not edit manually!
|
||||
from typing import Iterator
|
||||
|
||||
# NOTE: using positional arguments saves 4400 bytes in flash size,
|
||||
# returning tuples instead of classes saved 800 bytes
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from apps.common.paths import HARDENED
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Iterator
|
||||
|
||||
# Removing the necessity to construct object to save space
|
||||
# fmt: off
|
||||
NetworkInfoTuple = tuple[
|
||||
int, # chain_id
|
||||
int, # slip44
|
||||
str, # shortcut
|
||||
str, # name
|
||||
bool # rskip60
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def shortcut_by_chain_id(chain_id: int) -> str:
|
||||
n = by_chain_id(chain_id)
|
||||
@ -13,21 +31,36 @@ def shortcut_by_chain_id(chain_id: int) -> str:
|
||||
|
||||
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
|
||||
for n in _networks_iterator():
|
||||
if n.chain_id == chain_id:
|
||||
return n
|
||||
n_chain_id = n[0]
|
||||
if n_chain_id == chain_id:
|
||||
return NetworkInfo(
|
||||
chain_id=n[0],
|
||||
slip44=n[1],
|
||||
shortcut=n[2],
|
||||
name=n[3],
|
||||
rskip60=n[4],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def by_slip44(slip44: int) -> "NetworkInfo" | None:
|
||||
for n in _networks_iterator():
|
||||
if n.slip44 == slip44:
|
||||
return n
|
||||
n_slip44 = n[1]
|
||||
if n_slip44 == slip44:
|
||||
return NetworkInfo(
|
||||
chain_id=n[0],
|
||||
slip44=n[1],
|
||||
shortcut=n[2],
|
||||
name=n[3],
|
||||
rskip60=n[4],
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def all_slip44_ids_hardened() -> Iterator[int]:
|
||||
for n in _networks_iterator():
|
||||
yield n.slip44 | HARDENED
|
||||
# n_slip_44 is the second element
|
||||
yield n[1] | HARDENED
|
||||
|
||||
|
||||
class NetworkInfo:
|
||||
@ -42,13 +75,13 @@ class NetworkInfo:
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _networks_iterator() -> Iterator[NetworkInfo]:
|
||||
def _networks_iterator() -> Iterator[NetworkInfoTuple]:
|
||||
% for n in supported_on("trezor2", eth):
|
||||
yield NetworkInfo(
|
||||
chain_id=${n.chain_id},
|
||||
slip44=${n.slip44},
|
||||
shortcut="${n.shortcut}",
|
||||
name="${n.name}",
|
||||
rskip60=${n.rskip60},
|
||||
yield (
|
||||
${n.chain_id}, # chain_id
|
||||
${n.slip44}, # slip44
|
||||
"${n.shortcut}", # shortcut
|
||||
"${n.name}", # name
|
||||
${n.rskip60}, # rskip60
|
||||
)
|
||||
% endfor
|
||||
|
@ -1,25 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.messages import EthereumMessageSignature
|
||||
from trezor.ui.layouts import confirm_signverify
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
from apps.common import paths
|
||||
from apps.common.signverify import decode_message
|
||||
|
||||
from .helpers import address_from_bytes
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumSignMessage
|
||||
from trezor.messages import EthereumSignMessage, EthereumMessageSignature
|
||||
from trezor.wire import Context
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
|
||||
|
||||
def message_digest(message: bytes) -> bytes:
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
h = HashWriter(sha3_256(keccak=True))
|
||||
signed_message_header = b"\x19Ethereum Signed Message:\n"
|
||||
h.extend(signed_message_header)
|
||||
@ -32,6 +25,15 @@ def message_digest(message: bytes) -> bytes:
|
||||
async def sign_message(
|
||||
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
|
||||
) -> EthereumMessageSignature:
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.messages import EthereumMessageSignature
|
||||
from trezor.ui.layouts import confirm_signverify
|
||||
|
||||
from apps.common import paths
|
||||
from apps.common.signverify import decode_message
|
||||
|
||||
from .helpers import address_from_bytes
|
||||
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
|
@ -1,29 +1,19 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto import rlp
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.messages import EthereumTxAck, EthereumTxRequest
|
||||
from trezor.utils import HashWriter
|
||||
from trezor.messages import EthereumTxRequest
|
||||
from trezor.wire import DataError
|
||||
|
||||
from apps.common import paths
|
||||
|
||||
from . import tokens
|
||||
from .helpers import bytes_from_address
|
||||
from .keychain import with_keychain_from_chain_id
|
||||
from .layout import (
|
||||
require_confirm_data,
|
||||
require_confirm_fee,
|
||||
require_confirm_tx,
|
||||
require_confirm_unknown_token,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.keychain import Keychain
|
||||
from trezor.messages import EthereumSignTx
|
||||
from trezor.messages import EthereumSignTx, EthereumTxAck
|
||||
from trezor.wire import Context
|
||||
|
||||
from .keychain import EthereumSignTxAny
|
||||
from . import tokens
|
||||
|
||||
|
||||
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
|
||||
@ -34,9 +24,24 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
|
||||
|
||||
@with_keychain_from_chain_id
|
||||
async def sign_tx(
|
||||
ctx: wire.Context, msg: EthereumSignTx, keychain: Keychain
|
||||
ctx: Context, msg: EthereumSignTx, keychain: Keychain
|
||||
) -> EthereumTxRequest:
|
||||
check(msg)
|
||||
from trezor.utils import HashWriter
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from apps.common import paths
|
||||
from .layout import (
|
||||
require_confirm_data,
|
||||
require_confirm_fee,
|
||||
require_confirm_tx,
|
||||
)
|
||||
|
||||
# check
|
||||
if msg.tx_type not in [1, 6, None]:
|
||||
raise DataError("tx_type out of bounds")
|
||||
if len(msg.gas_price) + len(msg.gas_limit) > 30:
|
||||
raise DataError("Fee overflow")
|
||||
check_common_fields(msg)
|
||||
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
# Handle ERC20s
|
||||
@ -61,7 +66,7 @@ async def sign_tx(
|
||||
data += msg.data_initial_chunk
|
||||
data_left = data_total - len(msg.data_initial_chunk)
|
||||
|
||||
total_length = get_total_length(msg, data_total)
|
||||
total_length = _get_total_length(msg, data_total)
|
||||
|
||||
sha = HashWriter(sha3_256(keccak=True))
|
||||
rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE)
|
||||
@ -89,14 +94,19 @@ async def sign_tx(
|
||||
rlp.write(sha, 0)
|
||||
|
||||
digest = sha.get_digest()
|
||||
result = sign_digest(msg, keychain, digest)
|
||||
result = _sign_digest(msg, keychain, digest)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def handle_erc20(
|
||||
ctx: wire.Context, msg: EthereumSignTxAny
|
||||
ctx: Context, msg: EthereumSignTxAny
|
||||
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
|
||||
from .layout import require_confirm_unknown_token
|
||||
from . import tokens
|
||||
|
||||
data_initial_chunk = msg.data_initial_chunk # local_cache_attribute
|
||||
|
||||
token = None
|
||||
address_bytes = recipient = bytes_from_address(msg.to)
|
||||
value = int.from_bytes(msg.value, "big")
|
||||
@ -104,13 +114,13 @@ async def handle_erc20(
|
||||
len(msg.to) in (40, 42)
|
||||
and len(msg.value) == 0
|
||||
and msg.data_length == 68
|
||||
and len(msg.data_initial_chunk) == 68
|
||||
and msg.data_initial_chunk[:16]
|
||||
and len(data_initial_chunk) == 68
|
||||
and data_initial_chunk[:16]
|
||||
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
):
|
||||
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
|
||||
recipient = msg.data_initial_chunk[16:36]
|
||||
value = int.from_bytes(msg.data_initial_chunk[36:68], "big")
|
||||
recipient = data_initial_chunk[16:36]
|
||||
value = int.from_bytes(data_initial_chunk[36:68], "big")
|
||||
|
||||
if token is tokens.UNKNOWN_TOKEN:
|
||||
await require_confirm_unknown_token(ctx, address_bytes)
|
||||
@ -118,7 +128,7 @@ async def handle_erc20(
|
||||
return token, address_bytes, recipient, value
|
||||
|
||||
|
||||
def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
||||
def _get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
||||
length = 0
|
||||
if msg.tx_type is not None:
|
||||
length += rlp.length(msg.tx_type)
|
||||
@ -143,20 +153,20 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
|
||||
return length
|
||||
|
||||
|
||||
async def send_request_chunk(ctx: wire.Context, data_left: int) -> EthereumTxAck:
|
||||
async def send_request_chunk(ctx: Context, data_left: int) -> EthereumTxAck:
|
||||
from trezor.messages import EthereumTxAck
|
||||
|
||||
# TODO: layoutProgress ?
|
||||
req = EthereumTxRequest()
|
||||
if data_left <= 1024:
|
||||
req.data_length = data_left
|
||||
else:
|
||||
req.data_length = 1024
|
||||
|
||||
req.data_length = min(data_left, 1024)
|
||||
return await ctx.call(req, EthereumTxAck)
|
||||
|
||||
|
||||
def sign_digest(
|
||||
def _sign_digest(
|
||||
msg: EthereumSignTx, keychain: Keychain, digest: bytes
|
||||
) -> EthereumTxRequest:
|
||||
from trezor.crypto.curve import secp256k1
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
signature = secp256k1.sign(
|
||||
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
|
||||
@ -175,33 +185,25 @@ def sign_digest(
|
||||
return req
|
||||
|
||||
|
||||
def check(msg: EthereumSignTx) -> None:
|
||||
if msg.tx_type not in [1, 6, None]:
|
||||
raise wire.DataError("tx_type out of bounds")
|
||||
|
||||
if len(msg.gas_price) + len(msg.gas_limit) > 30:
|
||||
raise wire.DataError("Fee overflow")
|
||||
|
||||
check_common_fields(msg)
|
||||
|
||||
|
||||
def check_common_fields(msg: EthereumSignTxAny) -> None:
|
||||
if msg.data_length > 0:
|
||||
data_length = msg.data_length # local_cache_attribute
|
||||
|
||||
if data_length > 0:
|
||||
if not msg.data_initial_chunk:
|
||||
raise wire.DataError("Data length provided, but no initial chunk")
|
||||
raise DataError("Data length provided, but no initial chunk")
|
||||
# Our encoding only supports transactions up to 2^24 bytes. To
|
||||
# prevent exceeding the limit we use a stricter limit on data length.
|
||||
if msg.data_length > 16_000_000:
|
||||
raise wire.DataError("Data length exceeds limit")
|
||||
if len(msg.data_initial_chunk) > msg.data_length:
|
||||
raise wire.DataError("Invalid size of initial chunk")
|
||||
if data_length > 16_000_000:
|
||||
raise DataError("Data length exceeds limit")
|
||||
if len(msg.data_initial_chunk) > data_length:
|
||||
raise DataError("Invalid size of initial chunk")
|
||||
|
||||
if len(msg.to) not in (0, 40, 42):
|
||||
raise wire.DataError("Invalid recipient address")
|
||||
raise DataError("Invalid recipient address")
|
||||
|
||||
if not msg.to and msg.data_length == 0:
|
||||
if not msg.to and data_length == 0:
|
||||
# sending transaction to address 0 (contract creation) without a data field
|
||||
raise wire.DataError("Contract creation without data")
|
||||
raise DataError("Contract creation without data")
|
||||
|
||||
if msg.chain_id == 0:
|
||||
raise wire.DataError("Chain ID out of bounds")
|
||||
raise DataError("Chain ID out of bounds")
|
||||
|
@ -1,28 +1,21 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto import rlp
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.messages import EthereumAccessList, EthereumTxRequest
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
from apps.common import paths
|
||||
|
||||
from .helpers import bytes_from_address
|
||||
from .keychain import with_keychain_from_chain_id
|
||||
from .layout import (
|
||||
require_confirm_data,
|
||||
require_confirm_eip1559_fee,
|
||||
require_confirm_tx,
|
||||
)
|
||||
from .sign_tx import check_common_fields, handle_erc20, send_request_chunk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumSignTxEIP1559
|
||||
from trezor.messages import (
|
||||
EthereumSignTxEIP1559,
|
||||
EthereumAccessList,
|
||||
EthereumTxRequest,
|
||||
)
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
from trezor.wire import Context
|
||||
|
||||
|
||||
_TX_TYPE = const(2)
|
||||
|
||||
@ -35,28 +28,30 @@ def access_list_item_length(item: EthereumAccessList) -> int:
|
||||
)
|
||||
|
||||
|
||||
def access_list_length(access_list: list[EthereumAccessList]) -> int:
|
||||
payload_length = sum(access_list_item_length(i) for i in access_list)
|
||||
return rlp.header_length(payload_length) + payload_length
|
||||
|
||||
|
||||
def write_access_list(w: HashWriter, access_list: list[EthereumAccessList]) -> None:
|
||||
payload_length = sum(access_list_item_length(i) for i in access_list)
|
||||
rlp.write_header(w, payload_length, rlp.LIST_HEADER_BYTE)
|
||||
for item in access_list:
|
||||
address_bytes = bytes_from_address(item.address)
|
||||
address_length = rlp.length(address_bytes)
|
||||
keys_length = rlp.length(item.storage_keys)
|
||||
rlp.write_header(w, address_length + keys_length, rlp.LIST_HEADER_BYTE)
|
||||
rlp.write(w, address_bytes)
|
||||
rlp.write(w, item.storage_keys)
|
||||
|
||||
|
||||
@with_keychain_from_chain_id
|
||||
async def sign_tx_eip1559(
|
||||
ctx: wire.Context, msg: EthereumSignTxEIP1559, keychain: Keychain
|
||||
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
|
||||
) -> EthereumTxRequest:
|
||||
check(msg)
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.utils import HashWriter
|
||||
from trezor import wire
|
||||
from trezor.crypto import rlp # local_cache_global
|
||||
from apps.common import paths
|
||||
from .layout import (
|
||||
require_confirm_data,
|
||||
require_confirm_eip1559_fee,
|
||||
require_confirm_tx,
|
||||
)
|
||||
from .sign_tx import handle_erc20, send_request_chunk, check_common_fields
|
||||
|
||||
gas_limit = msg.gas_limit # local_cache_attribute
|
||||
|
||||
# check
|
||||
if len(msg.max_gas_fee) + len(gas_limit) > 30:
|
||||
raise wire.DataError("Fee overflow")
|
||||
if len(msg.max_priority_fee) + len(gas_limit) > 30:
|
||||
raise wire.DataError("Fee overflow")
|
||||
check_common_fields(msg)
|
||||
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
@ -74,7 +69,7 @@ async def sign_tx_eip1559(
|
||||
value,
|
||||
int.from_bytes(msg.max_priority_fee, "big"),
|
||||
int.from_bytes(msg.max_gas_fee, "big"),
|
||||
int.from_bytes(msg.gas_limit, "big"),
|
||||
int.from_bytes(gas_limit, "big"),
|
||||
msg.chain_id,
|
||||
token,
|
||||
)
|
||||
@ -83,7 +78,7 @@ async def sign_tx_eip1559(
|
||||
data += msg.data_initial_chunk
|
||||
data_left = data_total - len(msg.data_initial_chunk)
|
||||
|
||||
total_length = get_total_length(msg, data_total)
|
||||
total_length = _get_total_length(msg, data_total)
|
||||
|
||||
sha = HashWriter(sha3_256(keccak=True))
|
||||
|
||||
@ -96,7 +91,7 @@ async def sign_tx_eip1559(
|
||||
msg.nonce,
|
||||
msg.max_priority_fee,
|
||||
msg.max_gas_fee,
|
||||
msg.gas_limit,
|
||||
gas_limit,
|
||||
address_bytes,
|
||||
msg.value,
|
||||
)
|
||||
@ -114,15 +109,24 @@ async def sign_tx_eip1559(
|
||||
data_left -= len(resp.data_chunk)
|
||||
sha.extend(resp.data_chunk)
|
||||
|
||||
write_access_list(sha, msg.access_list)
|
||||
# write_access_list
|
||||
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
|
||||
rlp.write_header(sha, payload_length, rlp.LIST_HEADER_BYTE)
|
||||
for item in msg.access_list:
|
||||
address_bytes = bytes_from_address(item.address)
|
||||
address_length = rlp.length(address_bytes)
|
||||
keys_length = rlp.length(item.storage_keys)
|
||||
rlp.write_header(sha, address_length + keys_length, rlp.LIST_HEADER_BYTE)
|
||||
rlp.write(sha, address_bytes)
|
||||
rlp.write(sha, item.storage_keys)
|
||||
|
||||
digest = sha.get_digest()
|
||||
result = sign_digest(msg, keychain, digest)
|
||||
result = _sign_digest(msg, keychain, digest)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
|
||||
def _get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
|
||||
length = 0
|
||||
|
||||
fields: tuple[rlp.RLPItem, ...] = (
|
||||
@ -140,14 +144,21 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
|
||||
length += rlp.header_length(data_total, msg.data_initial_chunk)
|
||||
length += data_total
|
||||
|
||||
length += access_list_length(msg.access_list)
|
||||
# access_list_length
|
||||
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
|
||||
access_list_length = rlp.header_length(payload_length) + payload_length
|
||||
|
||||
length += access_list_length
|
||||
|
||||
return length
|
||||
|
||||
|
||||
def sign_digest(
|
||||
def _sign_digest(
|
||||
msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes
|
||||
) -> EthereumTxRequest:
|
||||
from trezor.messages import EthereumTxRequest
|
||||
from trezor.crypto.curve import secp256k1
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
signature = secp256k1.sign(
|
||||
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
|
||||
@ -159,12 +170,3 @@ def sign_digest(
|
||||
req.signature_s = signature[33:]
|
||||
|
||||
return req
|
||||
|
||||
|
||||
def check(msg: EthereumSignTxEIP1559) -> None:
|
||||
if len(msg.max_gas_fee) + len(msg.gas_limit) > 30:
|
||||
raise wire.DataError("Fee overflow")
|
||||
if len(msg.max_priority_fee) + len(msg.gas_limit) > 30:
|
||||
raise wire.DataError("Fee overflow")
|
||||
|
||||
check_common_fields(msg)
|
||||
|
@ -1,38 +1,24 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.enums import EthereumDataType
|
||||
from trezor.messages import (
|
||||
EthereumFieldType,
|
||||
EthereumTypedDataSignature,
|
||||
EthereumTypedDataStructAck,
|
||||
EthereumTypedDataStructRequest,
|
||||
EthereumTypedDataValueAck,
|
||||
EthereumTypedDataValueRequest,
|
||||
)
|
||||
from trezor.utils import HashWriter
|
||||
from trezor.wire import DataError
|
||||
|
||||
from apps.common import paths
|
||||
|
||||
from .helpers import address_from_bytes, get_type_name
|
||||
from .helpers import get_type_name
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||
from .layout import (
|
||||
confirm_empty_typed_message,
|
||||
confirm_typed_data_final,
|
||||
confirm_typed_value,
|
||||
should_show_array,
|
||||
should_show_domain,
|
||||
should_show_struct,
|
||||
)
|
||||
from .layout import should_show_struct
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from apps.common.keychain import Keychain
|
||||
from trezor.wire import Context
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
from trezor.messages import EthereumSignTypedData
|
||||
from trezor.messages import (
|
||||
EthereumSignTypedData,
|
||||
EthereumFieldType,
|
||||
EthereumTypedDataSignature,
|
||||
EthereumTypedDataStructAck,
|
||||
)
|
||||
|
||||
|
||||
# Maximum data size we support
|
||||
@ -43,9 +29,14 @@ _MAX_VALUE_BYTE_SIZE = const(1024)
|
||||
async def sign_typed_data(
|
||||
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
|
||||
) -> EthereumTypedDataSignature:
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from apps.common import paths
|
||||
from .helpers import address_from_bytes
|
||||
from trezor.messages import EthereumTypedDataSignature
|
||||
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
data_hash = await generate_typed_data_hash(
|
||||
data_hash = await _generate_typed_data_hash(
|
||||
ctx, msg.primary_type, msg.metamask_v4_compat
|
||||
)
|
||||
|
||||
@ -60,7 +51,7 @@ async def sign_typed_data(
|
||||
)
|
||||
|
||||
|
||||
async def generate_typed_data_hash(
|
||||
async def _generate_typed_data_hash(
|
||||
ctx: Context, primary_type: str, metamask_v4_compat: bool = True
|
||||
) -> bytes:
|
||||
"""
|
||||
@ -69,20 +60,26 @@ async def generate_typed_data_hash(
|
||||
|
||||
metamask_v4_compat - a flag that enables compatibility with MetaMask's signTypedData_v4 method
|
||||
"""
|
||||
from .layout import (
|
||||
confirm_empty_typed_message,
|
||||
confirm_typed_data_final,
|
||||
should_show_domain,
|
||||
)
|
||||
|
||||
typed_data_envelope = TypedDataEnvelope(
|
||||
ctx=ctx,
|
||||
primary_type=primary_type,
|
||||
metamask_v4_compat=metamask_v4_compat,
|
||||
ctx,
|
||||
primary_type,
|
||||
metamask_v4_compat,
|
||||
)
|
||||
await typed_data_envelope.collect_types()
|
||||
|
||||
name, version = await get_name_and_version_for_domain(ctx, typed_data_envelope)
|
||||
name, version = await _get_name_and_version_for_domain(ctx, typed_data_envelope)
|
||||
show_domain = await should_show_domain(ctx, name, version)
|
||||
domain_separator = await typed_data_envelope.hash_struct(
|
||||
primary_type="EIP712Domain",
|
||||
member_path=[0],
|
||||
show_data=show_domain,
|
||||
parent_objects=["EIP712Domain"],
|
||||
"EIP712Domain",
|
||||
[0],
|
||||
show_domain,
|
||||
["EIP712Domain"],
|
||||
)
|
||||
|
||||
# Setting the primary_type to "EIP712Domain" is technically in spec
|
||||
@ -94,16 +91,16 @@ async def generate_typed_data_hash(
|
||||
else:
|
||||
show_message = await should_show_struct(
|
||||
ctx,
|
||||
description=primary_type,
|
||||
data_members=typed_data_envelope.types[primary_type].members,
|
||||
title="Confirm message",
|
||||
button_text="Show full message",
|
||||
primary_type,
|
||||
typed_data_envelope.types[primary_type].members,
|
||||
"Confirm message",
|
||||
"Show full message",
|
||||
)
|
||||
message_hash = await typed_data_envelope.hash_struct(
|
||||
primary_type=primary_type,
|
||||
member_path=[1],
|
||||
show_data=show_message,
|
||||
parent_objects=[primary_type],
|
||||
primary_type,
|
||||
[1],
|
||||
show_message,
|
||||
[primary_type],
|
||||
)
|
||||
|
||||
await confirm_typed_data_final(ctx)
|
||||
@ -112,6 +109,9 @@ async def generate_typed_data_hash(
|
||||
|
||||
|
||||
def get_hash_writer() -> HashWriter:
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.utils import HashWriter
|
||||
|
||||
return HashWriter(sha3_256(keccak=True))
|
||||
|
||||
|
||||
@ -142,6 +142,11 @@ class TypedDataEnvelope:
|
||||
|
||||
async def _collect_types(self, type_name: str) -> None:
|
||||
"""Recursively collect types from the client."""
|
||||
from trezor.messages import (
|
||||
EthereumTypedDataStructRequest,
|
||||
EthereumTypedDataStructAck,
|
||||
)
|
||||
|
||||
req = EthereumTypedDataStructRequest(name=type_name)
|
||||
current_type = await self.ctx.call(req, EthereumTypedDataStructAck)
|
||||
self.types[type_name] = current_type
|
||||
@ -169,11 +174,11 @@ class TypedDataEnvelope:
|
||||
w = get_hash_writer()
|
||||
self.hash_type(w, primary_type)
|
||||
await self.get_and_encode_data(
|
||||
w=w,
|
||||
primary_type=primary_type,
|
||||
member_path=member_path,
|
||||
show_data=show_data,
|
||||
parent_objects=parent_objects,
|
||||
w,
|
||||
primary_type,
|
||||
member_path,
|
||||
show_data,
|
||||
parent_objects,
|
||||
)
|
||||
return w.get_digest()
|
||||
|
||||
@ -242,6 +247,10 @@ class TypedDataEnvelope:
|
||||
i.e. the concatenation of the encoded member values in the order that they appear in the type.
|
||||
Each encoded member value is exactly 32-byte long.
|
||||
"""
|
||||
from .layout import confirm_typed_value, should_show_array
|
||||
|
||||
ctx = self.ctx # local_cache_attribute
|
||||
|
||||
type_members = self.types[primary_type].members
|
||||
member_value_path = member_path + [0]
|
||||
current_parent_objects = parent_objects + [""]
|
||||
@ -258,25 +267,25 @@ class TypedDataEnvelope:
|
||||
|
||||
if show_data:
|
||||
show_struct = await should_show_struct(
|
||||
ctx=self.ctx,
|
||||
description=struct_name,
|
||||
data_members=self.types[struct_name].members,
|
||||
title=".".join(current_parent_objects),
|
||||
ctx,
|
||||
struct_name, # description
|
||||
self.types[struct_name].members, # data_members
|
||||
".".join(current_parent_objects), # title
|
||||
)
|
||||
else:
|
||||
show_struct = False
|
||||
|
||||
res = await self.hash_struct(
|
||||
primary_type=struct_name,
|
||||
member_path=member_value_path,
|
||||
show_data=show_struct,
|
||||
parent_objects=current_parent_objects,
|
||||
struct_name,
|
||||
member_value_path,
|
||||
show_struct,
|
||||
current_parent_objects,
|
||||
)
|
||||
w.extend(res)
|
||||
elif field_type.data_type == EthereumDataType.ARRAY:
|
||||
# Getting the length of the array first, if not fixed
|
||||
if field_type.size is None:
|
||||
array_size = await get_array_size(self.ctx, member_value_path)
|
||||
array_size = await _get_array_size(ctx, member_value_path)
|
||||
else:
|
||||
array_size = field_type.size
|
||||
|
||||
@ -286,10 +295,10 @@ class TypedDataEnvelope:
|
||||
|
||||
if show_data:
|
||||
show_array = await should_show_array(
|
||||
ctx=self.ctx,
|
||||
parent_objects=current_parent_objects,
|
||||
data_type=get_type_name(entry_type),
|
||||
size=array_size,
|
||||
ctx,
|
||||
current_parent_objects,
|
||||
get_type_name(entry_type),
|
||||
array_size,
|
||||
)
|
||||
else:
|
||||
show_array = False
|
||||
@ -309,43 +318,43 @@ class TypedDataEnvelope:
|
||||
# Metamask V4 is using hash_struct() even in this case
|
||||
if self.metamask_v4_compat:
|
||||
res = await self.hash_struct(
|
||||
primary_type=struct_name,
|
||||
member_path=el_member_path,
|
||||
show_data=show_array,
|
||||
parent_objects=current_parent_objects,
|
||||
struct_name,
|
||||
el_member_path,
|
||||
show_array,
|
||||
current_parent_objects,
|
||||
)
|
||||
arr_w.extend(res)
|
||||
else:
|
||||
await self.get_and_encode_data(
|
||||
w=arr_w,
|
||||
primary_type=struct_name,
|
||||
member_path=el_member_path,
|
||||
show_data=show_array,
|
||||
parent_objects=current_parent_objects,
|
||||
arr_w,
|
||||
struct_name,
|
||||
el_member_path,
|
||||
show_array,
|
||||
current_parent_objects,
|
||||
)
|
||||
else:
|
||||
value = await get_value(self.ctx, entry_type, el_member_path)
|
||||
value = await get_value(ctx, entry_type, el_member_path)
|
||||
encode_field(arr_w, entry_type, value)
|
||||
if show_array:
|
||||
await confirm_typed_value(
|
||||
ctx=self.ctx,
|
||||
name=field_name,
|
||||
value=value,
|
||||
parent_objects=parent_objects,
|
||||
field=entry_type,
|
||||
array_index=i,
|
||||
ctx,
|
||||
field_name,
|
||||
value,
|
||||
parent_objects,
|
||||
entry_type,
|
||||
i,
|
||||
)
|
||||
w.extend(arr_w.get_digest())
|
||||
else:
|
||||
value = await get_value(self.ctx, field_type, member_value_path)
|
||||
value = await get_value(ctx, field_type, member_value_path)
|
||||
encode_field(w, field_type, value)
|
||||
if show_data:
|
||||
await confirm_typed_value(
|
||||
ctx=self.ctx,
|
||||
name=field_name,
|
||||
value=value,
|
||||
parent_objects=parent_objects,
|
||||
field=field_type,
|
||||
ctx,
|
||||
field_name,
|
||||
value,
|
||||
parent_objects,
|
||||
field_type,
|
||||
)
|
||||
|
||||
|
||||
@ -370,21 +379,27 @@ def encode_field(
|
||||
encodeData of their contents
|
||||
- Struct values are encoded recursively as hashStruct(value)
|
||||
"""
|
||||
EDT = EthereumDataType # local_cache_global
|
||||
|
||||
data_type = field.data_type
|
||||
|
||||
if data_type == EthereumDataType.BYTES:
|
||||
if data_type == EDT.BYTES:
|
||||
if field.size is None:
|
||||
w.extend(keccak256(value))
|
||||
else:
|
||||
write_rightpad32(w, value)
|
||||
elif data_type == EthereumDataType.STRING:
|
||||
# write_rightpad32
|
||||
assert len(value) <= 32
|
||||
w.extend(value)
|
||||
for _ in range(32 - len(value)):
|
||||
w.append(0x00)
|
||||
elif data_type == EDT.STRING:
|
||||
w.extend(keccak256(value))
|
||||
elif data_type == EthereumDataType.INT:
|
||||
elif data_type == EDT.INT:
|
||||
write_leftpad32(w, value, signed=True)
|
||||
elif data_type in (
|
||||
EthereumDataType.UINT,
|
||||
EthereumDataType.BOOL,
|
||||
EthereumDataType.ADDRESS,
|
||||
EDT.UINT,
|
||||
EDT.BOOL,
|
||||
EDT.ADDRESS,
|
||||
):
|
||||
write_leftpad32(w, value)
|
||||
else:
|
||||
@ -405,93 +420,90 @@ def write_leftpad32(w: HashWriter, value: bytes, signed: bool = False) -> None:
|
||||
w.extend(value)
|
||||
|
||||
|
||||
def write_rightpad32(w: HashWriter, value: bytes) -> None:
|
||||
assert len(value) <= 32
|
||||
|
||||
w.extend(value)
|
||||
for _ in range(32 - len(value)):
|
||||
w.append(0x00)
|
||||
|
||||
|
||||
def validate_value(field: EthereumFieldType, value: bytes) -> None:
|
||||
def _validate_value(field: EthereumFieldType, value: bytes) -> None:
|
||||
"""
|
||||
Make sure the byte data we receive are not corrupted or incorrect.
|
||||
|
||||
Raise wire.DataError if encountering a problem, so clients are notified.
|
||||
Raise DataError if encountering a problem, so clients are notified.
|
||||
"""
|
||||
# Checking if the size corresponds to what is defined in types,
|
||||
# and also setting our maximum supported size in bytes
|
||||
if field.size is not None:
|
||||
if len(value) != field.size:
|
||||
raise wire.DataError("Invalid length")
|
||||
raise DataError("Invalid length")
|
||||
else:
|
||||
if len(value) > _MAX_VALUE_BYTE_SIZE:
|
||||
raise wire.DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
|
||||
raise DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
|
||||
|
||||
# Specific tests for some data types
|
||||
if field.data_type == EthereumDataType.BOOL:
|
||||
if value not in (b"\x00", b"\x01"):
|
||||
raise wire.DataError("Invalid boolean value")
|
||||
raise DataError("Invalid boolean value")
|
||||
elif field.data_type == EthereumDataType.ADDRESS:
|
||||
if len(value) != 20:
|
||||
raise wire.DataError("Invalid address")
|
||||
raise DataError("Invalid address")
|
||||
elif field.data_type == EthereumDataType.STRING:
|
||||
try:
|
||||
value.decode()
|
||||
except UnicodeError:
|
||||
raise wire.DataError("Invalid UTF-8")
|
||||
raise DataError("Invalid UTF-8")
|
||||
|
||||
|
||||
def validate_field_type(field: EthereumFieldType) -> None:
|
||||
"""
|
||||
Make sure the field type is consistent with our expectation.
|
||||
|
||||
Raise wire.DataError if encountering a problem, so clients are notified.
|
||||
Raise DataError if encountering a problem, so clients are notified.
|
||||
"""
|
||||
EDT = EthereumDataType # local_cache_global
|
||||
|
||||
data_type = field.data_type
|
||||
|
||||
# entry_type is only for arrays
|
||||
if data_type == EthereumDataType.ARRAY:
|
||||
if data_type == EDT.ARRAY:
|
||||
if field.entry_type is None:
|
||||
raise wire.DataError("Missing entry_type in array")
|
||||
raise DataError("Missing entry_type in array")
|
||||
# We also need to validate it recursively
|
||||
validate_field_type(field.entry_type)
|
||||
else:
|
||||
if field.entry_type is not None:
|
||||
raise wire.DataError("Unexpected entry_type in nonarray")
|
||||
raise DataError("Unexpected entry_type in nonarray")
|
||||
|
||||
# struct_name is only for structs
|
||||
if data_type == EthereumDataType.STRUCT:
|
||||
if data_type == EDT.STRUCT:
|
||||
if field.struct_name is None:
|
||||
raise wire.DataError("Missing struct_name in struct")
|
||||
raise DataError("Missing struct_name in struct")
|
||||
else:
|
||||
if field.struct_name is not None:
|
||||
raise wire.DataError("Unexpected struct_name in nonstruct")
|
||||
raise DataError("Unexpected struct_name in nonstruct")
|
||||
size = field.size # local_cache_attribute
|
||||
|
||||
# size is special for each type
|
||||
if data_type == EthereumDataType.STRUCT:
|
||||
if field.size is None:
|
||||
raise wire.DataError("Missing size in struct")
|
||||
elif data_type == EthereumDataType.BYTES:
|
||||
if field.size is not None and not 1 <= field.size <= 32:
|
||||
raise wire.DataError("Invalid size in bytes")
|
||||
if data_type == EDT.STRUCT:
|
||||
if size is None:
|
||||
raise DataError("Missing size in struct")
|
||||
elif data_type == EDT.BYTES:
|
||||
if size is not None and not 1 <= size <= 32:
|
||||
raise DataError("Invalid size in bytes")
|
||||
elif data_type in (
|
||||
EthereumDataType.UINT,
|
||||
EthereumDataType.INT,
|
||||
EDT.UINT,
|
||||
EDT.INT,
|
||||
):
|
||||
if field.size is None or not 1 <= field.size <= 32:
|
||||
raise wire.DataError("Invalid size in int/uint")
|
||||
if size is None or not 1 <= size <= 32:
|
||||
raise DataError("Invalid size in int/uint")
|
||||
elif data_type in (
|
||||
EthereumDataType.STRING,
|
||||
EthereumDataType.BOOL,
|
||||
EthereumDataType.ADDRESS,
|
||||
EDT.STRING,
|
||||
EDT.BOOL,
|
||||
EDT.ADDRESS,
|
||||
):
|
||||
if field.size is not None:
|
||||
raise wire.DataError("Unexpected size in str/bool/addr")
|
||||
if size is not None:
|
||||
raise DataError("Unexpected size in str/bool/addr")
|
||||
|
||||
|
||||
async def get_array_size(ctx: Context, member_path: list[int]) -> int:
|
||||
async def _get_array_size(ctx: Context, member_path: list[int]) -> int:
|
||||
"""Get the length of an array at specific `member_path` from the client."""
|
||||
from trezor.messages import EthereumFieldType
|
||||
|
||||
# Field type for getting the array length from client, so we can check the return value
|
||||
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
|
||||
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path)
|
||||
@ -504,18 +516,20 @@ async def get_value(
|
||||
member_value_path: list[int],
|
||||
) -> bytes:
|
||||
"""Get a single value from the client and perform its validation."""
|
||||
from trezor.messages import EthereumTypedDataValueAck, EthereumTypedDataValueRequest
|
||||
|
||||
req = EthereumTypedDataValueRequest(
|
||||
member_path=member_value_path,
|
||||
)
|
||||
res = await ctx.call(req, EthereumTypedDataValueAck)
|
||||
value = res.value
|
||||
|
||||
validate_value(field=field, value=value)
|
||||
_validate_value(field=field, value=value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
async def get_name_and_version_for_domain(
|
||||
async def _get_name_and_version_for_domain(
|
||||
ctx: Context, typed_data_envelope: TypedDataEnvelope
|
||||
) -> tuple[bytes, bytes]:
|
||||
domain_name = b"unknown"
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -2,6 +2,20 @@
|
||||
# (by running `make templates` in `core`)
|
||||
# do not edit manually!
|
||||
# fmt: off
|
||||
|
||||
# NOTE: returning a tuple instead of `TokenInfo` from the "data" function
|
||||
# saves 5600 bytes of flash size. Implementing the `_token_iterator`
|
||||
# instead of if-tree approach saves another 5600 bytes.
|
||||
|
||||
# NOTE: interestingly, it did not save much flash size to use smaller
|
||||
# parts of the address, for example address length of 10 bytes saves
|
||||
# 1 byte per entry, so 1887 bytes overall (and further decrease does not help).
|
||||
# (The idea was not having to store the whole address, even a smaller part
|
||||
# of it has enough collision-resistance.)
|
||||
# (In the if-tree approach the address length did not have any effect whatsoever.)
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
<%
|
||||
from collections import defaultdict
|
||||
|
||||
@ -22,11 +36,20 @@ UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
|
||||
|
||||
|
||||
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
|
||||
for addr, symbol, decimal in _token_iterator(chain_id):
|
||||
if address == addr:
|
||||
return TokenInfo(symbol, decimal)
|
||||
return UNKNOWN_TOKEN
|
||||
|
||||
|
||||
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]:
|
||||
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
|
||||
if chain_id == ${token_chain_id}:
|
||||
% for t in tokens:
|
||||
if address == ${black_repr(t.address_bytes)}:
|
||||
return TokenInfo(${black_repr(t.symbol)}, ${t.decimals}) # ${t.chain} / ${t.name.strip()}
|
||||
yield ( # address, symbol, decimals
|
||||
${black_repr(t.address_bytes)},
|
||||
${black_repr(t.symbol)},
|
||||
${t.decimals},
|
||||
)
|
||||
% endfor
|
||||
% endfor
|
||||
return UNKNOWN_TOKEN
|
||||
|
@ -1,42 +1,42 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import wire
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.messages import Success
|
||||
from trezor.ui.layouts import confirm_signverify, show_success
|
||||
|
||||
from apps.common.signverify import decode_message
|
||||
|
||||
from .helpers import address_from_bytes, bytes_from_address
|
||||
from .sign_message import message_digest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumVerifyMessage
|
||||
from trezor.messages import EthereumVerifyMessage, Success
|
||||
from trezor.wire import Context
|
||||
|
||||
|
||||
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
|
||||
from trezor.wire import DataError
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.messages import Success
|
||||
from trezor.ui.layouts import confirm_signverify, show_success
|
||||
|
||||
from apps.common.signverify import decode_message
|
||||
|
||||
from .helpers import address_from_bytes, bytes_from_address
|
||||
from .sign_message import message_digest
|
||||
|
||||
digest = message_digest(msg.message)
|
||||
if len(msg.signature) != 65:
|
||||
raise wire.DataError("Invalid signature")
|
||||
raise DataError("Invalid signature")
|
||||
sig = bytearray([msg.signature[64]]) + msg.signature[:64]
|
||||
|
||||
pubkey = secp256k1.verify_recover(sig, digest)
|
||||
|
||||
if not pubkey:
|
||||
raise wire.DataError("Invalid signature")
|
||||
raise DataError("Invalid signature")
|
||||
|
||||
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
|
||||
|
||||
address_bytes = bytes_from_address(msg.address)
|
||||
if address_bytes != pkh:
|
||||
raise wire.DataError("Invalid signature")
|
||||
raise DataError("Invalid signature")
|
||||
|
||||
address = address_from_bytes(address_bytes)
|
||||
|
||||
await confirm_signverify(
|
||||
ctx, "ETH", decode_message(msg.message), address=address, verify=True
|
||||
ctx, "ETH", decode_message(msg.message), address, verify=True
|
||||
)
|
||||
|
||||
await show_success(ctx, "verify_message", "The signature is valid.")
|
||||
|
@ -11,7 +11,7 @@ from trezor.enums import EthereumDataType as EDT
|
||||
if not utils.BITCOIN_ONLY:
|
||||
from apps.ethereum.sign_typed_data import (
|
||||
encode_field,
|
||||
validate_value,
|
||||
_validate_value,
|
||||
validate_field_type,
|
||||
keccak256,
|
||||
TypedDataEnvelope,
|
||||
@ -594,10 +594,10 @@ class TestEthereumSignTypedData(unittest.TestCase):
|
||||
|
||||
for field, valid_values, invalid_values in VECTORS_VALID_INVALID:
|
||||
for valid_value in valid_values:
|
||||
validate_value(field=field, value=valid_value)
|
||||
_validate_value(field=field, value=valid_value)
|
||||
for invalid_value in invalid_values:
|
||||
with self.assertRaises(wire.DataError):
|
||||
validate_value(field=field, value=invalid_value)
|
||||
_validate_value(field=field, value=invalid_value)
|
||||
|
||||
def test_validate_field_type(self):
|
||||
ET = EFT(data_type=EDT.BYTES, size=8)
|
||||
|
Loading…
Reference in New Issue
Block a user