1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 06:18:07 +00:00

chore(core): decrease ethereum size by 17250 bytes

This commit is contained in:
grdddj 2022-09-19 13:11:08 +02:00 committed by matejcik
parent 0c3423b1c7
commit 26fd0de198
15 changed files with 11882 additions and 6084 deletions

View File

@ -1,16 +1,9 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumGetAddress from trezor.messages import EthereumGetAddress, EthereumAddress
from trezor.wire import Context from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -20,18 +13,26 @@ if TYPE_CHECKING:
async def get_address( async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain ctx: Context, msg: EthereumGetAddress, keychain: Keychain
) -> EthereumAddress: ) -> EthereumAddress:
await paths.validate_path(ctx, keychain, msg.address_n) from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
node = keychain.derive(msg.address_n) address_n = msg.address_n # local_cache_attribute
if len(msg.address_n) > 1: # path has slip44 network identifier await paths.validate_path(ctx, keychain, address_n)
network = networks.by_slip44(msg.address_n[1] & 0x7FFF_FFFF)
node = keychain.derive(address_n)
if len(address_n) > 1: # path has slip44 network identifier
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
else: else:
network = None network = None
address = address_from_bytes(node.ethereum_pubkeyhash(), network) address = address_from_bytes(node.ethereum_pubkeyhash(), network)
if msg.show_display: if msg.show_display:
title = paths.address_n_to_str(msg.address_n) title = paths.address_n_to_str(address_n)
await show_address(ctx, address=address, title=title) await show_address(ctx, address, title=title)
return EthereumAddress(address=address) return EthereumAddress(address=address)

View File

@ -1,15 +1,11 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor.messages import EthereumPublicKey, HDNodeType from apps.common import paths
from trezor.ui.layouts import show_pubkey
from apps.common import coins, paths
from .keychain import with_keychain_from_path from .keychain import with_keychain_from_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumGetPublicKey from trezor.messages import EthereumGetPublicKey, EthereumPublicKey
from trezor.wire import Context from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
@ -19,6 +15,11 @@ if TYPE_CHECKING:
async def get_public_key( async def get_public_key(
ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain
) -> EthereumPublicKey: ) -> EthereumPublicKey:
from ubinascii import hexlify
from trezor.messages import EthereumPublicKey, HDNodeType
from trezor.ui.layouts import show_pubkey
from apps.common import coins
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -1,8 +1,5 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify, unhexlify from ubinascii import hexlify
from trezor import wire
from trezor.enums import EthereumDataType
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumFieldType from trezor.messages import EthereumFieldType
@ -24,7 +21,7 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
address_hex = hexlify(address_bytes).decode() address_hex = hexlify(address_bytes).decode()
digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest() digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest()
def maybe_upper(i: int) -> str: def _maybe_upper(i: int) -> str:
"""Uppercase i-th letter only if the corresponding nibble has high bit set.""" """Uppercase i-th letter only if the corresponding nibble has high bit set."""
digest_byte = digest[i // 2] digest_byte = digest[i // 2]
hex_letter = address_hex[i] hex_letter = address_hex[i]
@ -39,10 +36,13 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
else: else:
return hex_letter return hex_letter
return "0x" + "".join(maybe_upper(i) for i in range(len(address_hex))) return "0x" + "".join(_maybe_upper(i) for i in range(len(address_hex)))
def bytes_from_address(address: str) -> bytes: def bytes_from_address(address: str) -> bytes:
from ubinascii import unhexlify
from trezor import wire
if len(address) == 40: if len(address) == 40:
return unhexlify(address) return unhexlify(address)
@ -59,6 +59,8 @@ def bytes_from_address(address: str) -> bytes:
def get_type_name(field: EthereumFieldType) -> str: def get_type_name(field: EthereumFieldType) -> str:
"""Create a string from type definition (like uint256 or bytes16).""" """Create a string from type definition (like uint256 or bytes16)."""
from trezor.enums import EthereumDataType
data_type = field.data_type data_type = field.data_type
size = field.size size = field.size
@ -109,12 +111,12 @@ def decode_typed_data(data: bytes, type_name: str) -> str:
return str(int.from_bytes(data, "big")) return str(int.from_bytes(data, "big"))
elif type_name.startswith("int"): elif type_name.startswith("int"):
# Micropython does not implement "signed" arg in int.from_bytes() # Micropython does not implement "signed" arg in int.from_bytes()
return str(from_bytes_bigendian_signed(data)) return str(_from_bytes_bigendian_signed(data))
raise ValueError # Unsupported data type for direct field decoding raise ValueError # Unsupported data type for direct field decoding
def from_bytes_bigendian_signed(b: bytes) -> int: def _from_bytes_bigendian_signed(b: bytes) -> int:
negative = b[0] & 0x80 negative = b[0] & 0x80
if negative: if negative:
neg_b = bytearray(b) neg_b = bytearray(b)

View File

@ -1,7 +1,5 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from apps.common import paths from apps.common import paths
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
@ -10,6 +8,8 @@ from . import CURVE, networks
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, TypeVar from typing import Callable, Iterable, TypeVar
from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
EthereumGetAddress, EthereumGetAddress,
EthereumGetPublicKey, EthereumGetPublicKey,
@ -64,7 +64,7 @@ def with_keychain_from_path(
*patterns: str, *patterns: str,
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: ) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
schemas = _schemas_from_address_n(patterns, msg.address_n) schemas = _schemas_from_address_n(patterns, msg.address_n)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(ctx, CURVE, schemas)
with keychain: with keychain:
@ -97,7 +97,7 @@ def with_keychain_from_chain_id(
func: HandlerWithKeychain[MsgInChainId, MsgOut] func: HandlerWithKeychain[MsgInChainId, MsgOut]
) -> Handler[MsgInChainId, MsgOut]: ) -> Handler[MsgInChainId, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed # this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: wire.Context, msg: MsgInChainId) -> MsgOut: async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
schemas = _schemas_from_chain_id(msg) schemas = _schemas_from_chain_id(msg)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(ctx, CURVE, schemas)
with keychain: with keychain:

View File

@ -1,29 +1,19 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import ui from trezor import ui
from trezor.enums import ButtonRequestType, EthereumDataType from trezor.enums import ButtonRequestType
from trezor.strings import format_amount, format_plural from trezor.strings import format_plural
from trezor.ui.layouts import ( from trezor.ui.layouts import confirm_blob, confirm_text, should_show_more
confirm_action,
confirm_address,
confirm_amount,
confirm_blob,
confirm_output,
confirm_text,
confirm_total,
should_show_more,
)
from trezor.ui.layouts.altcoin import confirm_total_ethereum
from . import networks, tokens from . import networks
from .helpers import address_from_bytes, decode_typed_data, get_type_name from .helpers import decode_typed_data
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable, Iterable from typing import Awaitable, Iterable
from trezor.messages import EthereumFieldType, EthereumStructMember from trezor.messages import EthereumFieldType, EthereumStructMember
from trezor.wire import Context from trezor.wire import Context
from . import tokens
def require_confirm_tx( def require_confirm_tx(
@ -33,15 +23,18 @@ def require_confirm_tx(
chain_id: int, chain_id: int,
token: tokens.TokenInfo | None = None, token: tokens.TokenInfo | None = None,
) -> Awaitable[None]: ) -> Awaitable[None]:
from .helpers import address_from_bytes
from trezor.ui.layouts import confirm_output
if to_bytes: if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id)) to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
else: else:
to_str = "new contract?" to_str = "new contract?"
return confirm_output( return confirm_output(
ctx, ctx,
address=to_str, to_str,
amount=format_ethereum_amount(value, token, chain_id), format_ethereum_amount(value, token, chain_id),
font_amount=ui.BOLD, ui.BOLD,
color_to=ui.GREY, color_to=ui.GREY,
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
) )
@ -55,6 +48,8 @@ def require_confirm_fee(
chain_id: int, chain_id: int,
token: tokens.TokenInfo | None = None, token: tokens.TokenInfo | None = None,
) -> Awaitable[None]: ) -> Awaitable[None]:
from trezor.ui.layouts.altcoin import confirm_total_ethereum
return confirm_total_ethereum( return confirm_total_ethereum(
ctx, ctx,
format_ethereum_amount(spending, token, chain_id), format_ethereum_amount(spending, token, chain_id),
@ -72,22 +67,24 @@ async def require_confirm_eip1559_fee(
chain_id: int, chain_id: int,
token: tokens.TokenInfo | None = None, token: tokens.TokenInfo | None = None,
) -> None: ) -> None:
from trezor.ui.layouts import confirm_amount, confirm_total
await confirm_amount( await confirm_amount(
ctx, ctx,
title="Confirm fee", "Confirm fee",
description="Maximum fee per gas", format_ethereum_amount(max_gas_fee, None, chain_id),
amount=format_ethereum_amount(max_gas_fee, None, chain_id), "Maximum fee per gas",
) )
await confirm_amount( await confirm_amount(
ctx, ctx,
title="Confirm fee", "Confirm fee",
description="Priority fee per gas", format_ethereum_amount(max_priority_fee, None, chain_id),
amount=format_ethereum_amount(max_priority_fee, None, chain_id), "Priority fee per gas",
) )
await confirm_total( await confirm_total(
ctx, ctx,
total_amount=format_ethereum_amount(spending, token, chain_id), format_ethereum_amount(spending, token, chain_id),
fee_amount=format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id), format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
total_label="Amount sent:\n", total_label="Amount sent:\n",
fee_label="\nMaximum fee:\n", fee_label="\nMaximum fee:\n",
) )
@ -96,13 +93,16 @@ async def require_confirm_eip1559_fee(
def require_confirm_unknown_token( def require_confirm_unknown_token(
ctx: Context, address_bytes: bytes ctx: Context, address_bytes: bytes
) -> Awaitable[None]: ) -> Awaitable[None]:
from ubinascii import hexlify
from trezor.ui.layouts import confirm_address
contract_address_hex = "0x" + hexlify(address_bytes).decode() contract_address_hex = "0x" + hexlify(address_bytes).decode()
return confirm_address( return confirm_address(
ctx, ctx,
"Unknown token", "Unknown token",
contract_address_hex, contract_address_hex,
description="Contract:", "Contract:",
br_type="unknown_token", "unknown_token",
icon_color=ui.ORANGE, icon_color=ui.ORANGE,
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
) )
@ -112,20 +112,22 @@ def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitabl
return confirm_blob( return confirm_blob(
ctx, ctx,
"confirm_data", "confirm_data",
title="Confirm data", "Confirm data",
description=f"Size: {data_total} bytes", data,
data=data, f"Size: {data_total} bytes",
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
ask_pagination=True, ask_pagination=True,
) )
async def confirm_typed_data_final(ctx: Context) -> None: async def confirm_typed_data_final(ctx: Context) -> None:
from trezor.ui.layouts import confirm_action
await confirm_action( await confirm_action(
ctx, ctx,
"confirm_typed_data_final", "confirm_typed_data_final",
title="Confirm typed data", "Confirm typed data",
action="Really sign EIP-712 typed data?", "Really sign EIP-712 typed data?",
verb="Hold to confirm", verb="Hold to confirm",
hold=True, hold=True,
) )
@ -135,9 +137,9 @@ def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
return confirm_text( return confirm_text(
ctx, ctx,
"confirm_empty_typed_message", "confirm_empty_typed_message",
title="Confirm message", "Confirm message",
data="", "",
description="No message field", "No message field",
) )
@ -152,10 +154,10 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
) )
return await should_show_more( return await should_show_more(
ctx, ctx,
title="Confirm domain", "Confirm domain",
para=para, para,
button_text="Show full domain", "Show full domain",
br_type="should_show_domain", "should_show_domain",
) )
@ -176,10 +178,10 @@ async def should_show_struct(
) )
return await should_show_more( return await should_show_more(
ctx, ctx,
title=title, title,
para=para, para,
button_text=button_text, button_text,
br_type="should_show_struct", "should_show_struct",
) )
@ -192,10 +194,10 @@ async def should_show_array(
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),) para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
return await should_show_more( return await should_show_more(
ctx, ctx,
title=limit_str(".".join(parent_objects)), limit_str(".".join(parent_objects)),
para=para, para,
button_text="Show full array", "Show full array",
br_type="should_show_array", "should_show_array",
) )
@ -207,6 +209,9 @@ async def confirm_typed_value(
field: EthereumFieldType, field: EthereumFieldType,
array_index: int | None = None, array_index: int | None = None,
) -> None: ) -> None:
from trezor.enums import EthereumDataType
from .helpers import get_type_name
type_name = get_type_name(field) type_name = get_type_name(field)
if array_index is not None: if array_index is not None:
@ -222,24 +227,26 @@ async def confirm_typed_value(
await confirm_blob( await confirm_blob(
ctx, ctx,
"confirm_typed_value", "confirm_typed_value",
title=title, title,
data=data, data,
description=description, description,
ask_pagination=True, ask_pagination=True,
) )
else: else:
await confirm_text( await confirm_text(
ctx, ctx,
"confirm_typed_value", "confirm_typed_value",
title=title, title,
data=data, data,
description=description, description,
) )
def format_ethereum_amount( def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int value: int, token: tokens.TokenInfo | None, chain_id: int
) -> str: ) -> str:
from trezor.strings import format_amount
if token: if token:
suffix = token.symbol suffix = token.symbol
decimals = token.decimals decimals = token.decimals

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,28 @@
# generated from networks.py.mako # generated from networks.py.mako
# (by running `make templates` in `core`) # (by running `make templates` in `core`)
# do not edit manually! # do not edit manually!
from typing import Iterator
# NOTE: using positional arguments saves 4400 bytes in flash size,
# returning tuples instead of classes saved 800 bytes
from typing import TYPE_CHECKING
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
if TYPE_CHECKING:
from typing import Iterator
# Removing the necessity to construct object to save space
# fmt: off
NetworkInfoTuple = tuple[
int, # chain_id
int, # slip44
str, # shortcut
str, # name
bool # rskip60
]
# fmt: on
def shortcut_by_chain_id(chain_id: int) -> str: def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id) n = by_chain_id(chain_id)
@ -13,21 +31,36 @@ def shortcut_by_chain_id(chain_id: int) -> str:
def by_chain_id(chain_id: int) -> "NetworkInfo" | None: def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
for n in _networks_iterator(): for n in _networks_iterator():
if n.chain_id == chain_id: n_chain_id = n[0]
return n if n_chain_id == chain_id:
return NetworkInfo(
chain_id=n[0],
slip44=n[1],
shortcut=n[2],
name=n[3],
rskip60=n[4],
)
return None return None
def by_slip44(slip44: int) -> "NetworkInfo" | None: def by_slip44(slip44: int) -> "NetworkInfo" | None:
for n in _networks_iterator(): for n in _networks_iterator():
if n.slip44 == slip44: n_slip44 = n[1]
return n if n_slip44 == slip44:
return NetworkInfo(
chain_id=n[0],
slip44=n[1],
shortcut=n[2],
name=n[3],
rskip60=n[4],
)
return None return None
def all_slip44_ids_hardened() -> Iterator[int]: def all_slip44_ids_hardened() -> Iterator[int]:
for n in _networks_iterator(): for n in _networks_iterator():
yield n.slip44 | HARDENED # n_slip_44 is the second element
yield n[1] | HARDENED
class NetworkInfo: class NetworkInfo:
@ -42,13 +75,13 @@ class NetworkInfo:
# fmt: off # fmt: off
def _networks_iterator() -> Iterator[NetworkInfo]: def _networks_iterator() -> Iterator[NetworkInfoTuple]:
% for n in supported_on("trezor2", eth): % for n in supported_on("trezor2", eth):
yield NetworkInfo( yield (
chain_id=${n.chain_id}, ${n.chain_id}, # chain_id
slip44=${n.slip44}, ${n.slip44}, # slip44
shortcut="${n.shortcut}", "${n.shortcut}", # shortcut
name="${n.name}", "${n.name}", # name
rskip60=${n.rskip60}, ${n.rskip60}, # rskip60
) )
% endfor % endfor

View File

@ -1,25 +1,18 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import EthereumMessageSignature
from trezor.ui.layouts import confirm_signverify
from trezor.utils import HashWriter
from apps.common import paths
from apps.common.signverify import decode_message
from .helpers import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumSignMessage from trezor.messages import EthereumSignMessage, EthereumMessageSignature
from trezor.wire import Context from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
def message_digest(message: bytes) -> bytes: def message_digest(message: bytes) -> bytes:
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
h = HashWriter(sha3_256(keccak=True)) h = HashWriter(sha3_256(keccak=True))
signed_message_header = b"\x19Ethereum Signed Message:\n" signed_message_header = b"\x19Ethereum Signed Message:\n"
h.extend(signed_message_header) h.extend(signed_message_header)
@ -32,6 +25,15 @@ def message_digest(message: bytes) -> bytes:
async def sign_message( async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain ctx: Context, msg: EthereumSignMessage, keychain: Keychain
) -> EthereumMessageSignature: ) -> EthereumMessageSignature:
from trezor.crypto.curve import secp256k1
from trezor.messages import EthereumMessageSignature
from trezor.ui.layouts import confirm_signverify
from apps.common import paths
from apps.common.signverify import decode_message
from .helpers import address_from_bytes
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -1,29 +1,19 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import rlp from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1 from trezor.messages import EthereumTxRequest
from trezor.crypto.hashlib import sha3_256 from trezor.wire import DataError
from trezor.messages import EthereumTxAck, EthereumTxRequest
from trezor.utils import HashWriter
from apps.common import paths
from . import tokens
from .helpers import bytes_from_address from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id from .keychain import with_keychain_from_chain_id
from .layout import (
require_confirm_data,
require_confirm_fee,
require_confirm_tx,
require_confirm_unknown_token,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx from trezor.messages import EthereumSignTx, EthereumTxAck
from trezor.wire import Context
from .keychain import EthereumSignTxAny from .keychain import EthereumSignTxAny
from . import tokens
# Maximum chain_id which returns the full signature_v (which must fit into an uint32). # Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -34,9 +24,24 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def sign_tx( async def sign_tx(
ctx: wire.Context, msg: EthereumSignTx, keychain: Keychain ctx: Context, msg: EthereumSignTx, keychain: Keychain
) -> EthereumTxRequest: ) -> EthereumTxRequest:
check(msg) from trezor.utils import HashWriter
from trezor.crypto.hashlib import sha3_256
from apps.common import paths
from .layout import (
require_confirm_data,
require_confirm_fee,
require_confirm_tx,
)
# check
if msg.tx_type not in [1, 6, None]:
raise DataError("tx_type out of bounds")
if len(msg.gas_price) + len(msg.gas_limit) > 30:
raise DataError("Fee overflow")
check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
@ -61,7 +66,7 @@ async def sign_tx(
data += msg.data_initial_chunk data += msg.data_initial_chunk
data_left = data_total - len(msg.data_initial_chunk) data_left = data_total - len(msg.data_initial_chunk)
total_length = get_total_length(msg, data_total) total_length = _get_total_length(msg, data_total)
sha = HashWriter(sha3_256(keccak=True)) sha = HashWriter(sha3_256(keccak=True))
rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE) rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE)
@ -89,14 +94,19 @@ async def sign_tx(
rlp.write(sha, 0) rlp.write(sha, 0)
digest = sha.get_digest() digest = sha.get_digest()
result = sign_digest(msg, keychain, digest) result = _sign_digest(msg, keychain, digest)
return result return result
async def handle_erc20( async def handle_erc20(
ctx: wire.Context, msg: EthereumSignTxAny ctx: Context, msg: EthereumSignTxAny
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]: ) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
from .layout import require_confirm_unknown_token
from . import tokens
data_initial_chunk = msg.data_initial_chunk # local_cache_attribute
token = None token = None
address_bytes = recipient = bytes_from_address(msg.to) address_bytes = recipient = bytes_from_address(msg.to)
value = int.from_bytes(msg.value, "big") value = int.from_bytes(msg.value, "big")
@ -104,13 +114,13 @@ async def handle_erc20(
len(msg.to) in (40, 42) len(msg.to) in (40, 42)
and len(msg.value) == 0 and len(msg.value) == 0
and msg.data_length == 68 and msg.data_length == 68
and len(msg.data_initial_chunk) == 68 and len(data_initial_chunk) == 68
and msg.data_initial_chunk[:16] and data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" == b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
): ):
token = tokens.token_by_chain_address(msg.chain_id, address_bytes) token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
recipient = msg.data_initial_chunk[16:36] recipient = data_initial_chunk[16:36]
value = int.from_bytes(msg.data_initial_chunk[36:68], "big") value = int.from_bytes(data_initial_chunk[36:68], "big")
if token is tokens.UNKNOWN_TOKEN: if token is tokens.UNKNOWN_TOKEN:
await require_confirm_unknown_token(ctx, address_bytes) await require_confirm_unknown_token(ctx, address_bytes)
@ -118,7 +128,7 @@ async def handle_erc20(
return token, address_bytes, recipient, value return token, address_bytes, recipient, value
def get_total_length(msg: EthereumSignTx, data_total: int) -> int: def _get_total_length(msg: EthereumSignTx, data_total: int) -> int:
length = 0 length = 0
if msg.tx_type is not None: if msg.tx_type is not None:
length += rlp.length(msg.tx_type) length += rlp.length(msg.tx_type)
@ -143,20 +153,20 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
return length return length
async def send_request_chunk(ctx: wire.Context, data_left: int) -> EthereumTxAck: async def send_request_chunk(ctx: Context, data_left: int) -> EthereumTxAck:
from trezor.messages import EthereumTxAck
# TODO: layoutProgress ? # TODO: layoutProgress ?
req = EthereumTxRequest() req = EthereumTxRequest()
if data_left <= 1024: req.data_length = min(data_left, 1024)
req.data_length = data_left
else:
req.data_length = 1024
return await ctx.call(req, EthereumTxAck) return await ctx.call(req, EthereumTxAck)
def sign_digest( def _sign_digest(
msg: EthereumSignTx, keychain: Keychain, digest: bytes msg: EthereumSignTx, keychain: Keychain, digest: bytes
) -> EthereumTxRequest: ) -> EthereumTxRequest:
from trezor.crypto.curve import secp256k1
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
signature = secp256k1.sign( signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
@ -175,33 +185,25 @@ def sign_digest(
return req return req
def check(msg: EthereumSignTx) -> None:
if msg.tx_type not in [1, 6, None]:
raise wire.DataError("tx_type out of bounds")
if len(msg.gas_price) + len(msg.gas_limit) > 30:
raise wire.DataError("Fee overflow")
check_common_fields(msg)
def check_common_fields(msg: EthereumSignTxAny) -> None: def check_common_fields(msg: EthereumSignTxAny) -> None:
if msg.data_length > 0: data_length = msg.data_length # local_cache_attribute
if data_length > 0:
if not msg.data_initial_chunk: if not msg.data_initial_chunk:
raise wire.DataError("Data length provided, but no initial chunk") raise DataError("Data length provided, but no initial chunk")
# Our encoding only supports transactions up to 2^24 bytes. To # Our encoding only supports transactions up to 2^24 bytes. To
# prevent exceeding the limit we use a stricter limit on data length. # prevent exceeding the limit we use a stricter limit on data length.
if msg.data_length > 16_000_000: if data_length > 16_000_000:
raise wire.DataError("Data length exceeds limit") raise DataError("Data length exceeds limit")
if len(msg.data_initial_chunk) > msg.data_length: if len(msg.data_initial_chunk) > data_length:
raise wire.DataError("Invalid size of initial chunk") raise DataError("Invalid size of initial chunk")
if len(msg.to) not in (0, 40, 42): if len(msg.to) not in (0, 40, 42):
raise wire.DataError("Invalid recipient address") raise DataError("Invalid recipient address")
if not msg.to and msg.data_length == 0: if not msg.to and data_length == 0:
# sending transaction to address 0 (contract creation) without a data field # sending transaction to address 0 (contract creation) without a data field
raise wire.DataError("Contract creation without data") raise DataError("Contract creation without data")
if msg.chain_id == 0: if msg.chain_id == 0:
raise wire.DataError("Chain ID out of bounds") raise DataError("Chain ID out of bounds")

View File

@ -1,28 +1,21 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import rlp from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import EthereumAccessList, EthereumTxRequest
from trezor.utils import HashWriter
from apps.common import paths
from .helpers import bytes_from_address from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id from .keychain import with_keychain_from_chain_id
from .layout import (
require_confirm_data,
require_confirm_eip1559_fee,
require_confirm_tx,
)
from .sign_tx import check_common_fields, handle_erc20, send_request_chunk
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumSignTxEIP1559 from trezor.messages import (
EthereumSignTxEIP1559,
EthereumAccessList,
EthereumTxRequest,
)
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context
_TX_TYPE = const(2) _TX_TYPE = const(2)
@ -35,28 +28,30 @@ def access_list_item_length(item: EthereumAccessList) -> int:
) )
def access_list_length(access_list: list[EthereumAccessList]) -> int:
payload_length = sum(access_list_item_length(i) for i in access_list)
return rlp.header_length(payload_length) + payload_length
def write_access_list(w: HashWriter, access_list: list[EthereumAccessList]) -> None:
payload_length = sum(access_list_item_length(i) for i in access_list)
rlp.write_header(w, payload_length, rlp.LIST_HEADER_BYTE)
for item in access_list:
address_bytes = bytes_from_address(item.address)
address_length = rlp.length(address_bytes)
keys_length = rlp.length(item.storage_keys)
rlp.write_header(w, address_length + keys_length, rlp.LIST_HEADER_BYTE)
rlp.write(w, address_bytes)
rlp.write(w, item.storage_keys)
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def sign_tx_eip1559( async def sign_tx_eip1559(
ctx: wire.Context, msg: EthereumSignTxEIP1559, keychain: Keychain ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
) -> EthereumTxRequest: ) -> EthereumTxRequest:
check(msg) from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
from trezor import wire
from trezor.crypto import rlp # local_cache_global
from apps.common import paths
from .layout import (
require_confirm_data,
require_confirm_eip1559_fee,
require_confirm_tx,
)
from .sign_tx import handle_erc20, send_request_chunk, check_common_fields
gas_limit = msg.gas_limit # local_cache_attribute
# check
if len(msg.max_gas_fee) + len(gas_limit) > 30:
raise wire.DataError("Fee overflow")
if len(msg.max_priority_fee) + len(gas_limit) > 30:
raise wire.DataError("Fee overflow")
check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
@ -74,7 +69,7 @@ async def sign_tx_eip1559(
value, value,
int.from_bytes(msg.max_priority_fee, "big"), int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"), int.from_bytes(msg.max_gas_fee, "big"),
int.from_bytes(msg.gas_limit, "big"), int.from_bytes(gas_limit, "big"),
msg.chain_id, msg.chain_id,
token, token,
) )
@ -83,7 +78,7 @@ async def sign_tx_eip1559(
data += msg.data_initial_chunk data += msg.data_initial_chunk
data_left = data_total - len(msg.data_initial_chunk) data_left = data_total - len(msg.data_initial_chunk)
total_length = get_total_length(msg, data_total) total_length = _get_total_length(msg, data_total)
sha = HashWriter(sha3_256(keccak=True)) sha = HashWriter(sha3_256(keccak=True))
@ -96,7 +91,7 @@ async def sign_tx_eip1559(
msg.nonce, msg.nonce,
msg.max_priority_fee, msg.max_priority_fee,
msg.max_gas_fee, msg.max_gas_fee,
msg.gas_limit, gas_limit,
address_bytes, address_bytes,
msg.value, msg.value,
) )
@ -114,15 +109,24 @@ async def sign_tx_eip1559(
data_left -= len(resp.data_chunk) data_left -= len(resp.data_chunk)
sha.extend(resp.data_chunk) sha.extend(resp.data_chunk)
write_access_list(sha, msg.access_list) # write_access_list
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
rlp.write_header(sha, payload_length, rlp.LIST_HEADER_BYTE)
for item in msg.access_list:
address_bytes = bytes_from_address(item.address)
address_length = rlp.length(address_bytes)
keys_length = rlp.length(item.storage_keys)
rlp.write_header(sha, address_length + keys_length, rlp.LIST_HEADER_BYTE)
rlp.write(sha, address_bytes)
rlp.write(sha, item.storage_keys)
digest = sha.get_digest() digest = sha.get_digest()
result = sign_digest(msg, keychain, digest) result = _sign_digest(msg, keychain, digest)
return result return result
def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int: def _get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
length = 0 length = 0
fields: tuple[rlp.RLPItem, ...] = ( fields: tuple[rlp.RLPItem, ...] = (
@ -140,14 +144,21 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
length += rlp.header_length(data_total, msg.data_initial_chunk) length += rlp.header_length(data_total, msg.data_initial_chunk)
length += data_total length += data_total
length += access_list_length(msg.access_list) # access_list_length
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
access_list_length = rlp.header_length(payload_length) + payload_length
length += access_list_length
return length return length
def sign_digest( def _sign_digest(
msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes
) -> EthereumTxRequest: ) -> EthereumTxRequest:
from trezor.messages import EthereumTxRequest
from trezor.crypto.curve import secp256k1
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
signature = secp256k1.sign( signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
@ -159,12 +170,3 @@ def sign_digest(
req.signature_s = signature[33:] req.signature_s = signature[33:]
return req return req
def check(msg: EthereumSignTxEIP1559) -> None:
if len(msg.max_gas_fee) + len(msg.gas_limit) > 30:
raise wire.DataError("Fee overflow")
if len(msg.max_priority_fee) + len(msg.gas_limit) > 30:
raise wire.DataError("Fee overflow")
check_common_fields(msg)

View File

@ -1,38 +1,24 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.enums import EthereumDataType from trezor.enums import EthereumDataType
from trezor.messages import ( from trezor.wire import DataError
EthereumFieldType,
EthereumTypedDataSignature,
EthereumTypedDataStructAck,
EthereumTypedDataStructRequest,
EthereumTypedDataValueAck,
EthereumTypedDataValueRequest,
)
from trezor.utils import HashWriter
from apps.common import paths from .helpers import get_type_name
from .helpers import address_from_bytes, get_type_name
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
from .layout import ( from .layout import should_show_struct
confirm_empty_typed_message,
confirm_typed_data_final,
confirm_typed_value,
should_show_array,
should_show_domain,
should_show_struct,
)
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context from trezor.wire import Context
from trezor.utils import HashWriter
from trezor.messages import EthereumSignTypedData from trezor.messages import (
EthereumSignTypedData,
EthereumFieldType,
EthereumTypedDataSignature,
EthereumTypedDataStructAck,
)
# Maximum data size we support # Maximum data size we support
@ -43,9 +29,14 @@ _MAX_VALUE_BYTE_SIZE = const(1024)
async def sign_typed_data( async def sign_typed_data(
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
) -> EthereumTypedDataSignature: ) -> EthereumTypedDataSignature:
from trezor.crypto.curve import secp256k1
from apps.common import paths
from .helpers import address_from_bytes
from trezor.messages import EthereumTypedDataSignature
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
data_hash = await generate_typed_data_hash( data_hash = await _generate_typed_data_hash(
ctx, msg.primary_type, msg.metamask_v4_compat ctx, msg.primary_type, msg.metamask_v4_compat
) )
@ -60,7 +51,7 @@ async def sign_typed_data(
) )
async def generate_typed_data_hash( async def _generate_typed_data_hash(
ctx: Context, primary_type: str, metamask_v4_compat: bool = True ctx: Context, primary_type: str, metamask_v4_compat: bool = True
) -> bytes: ) -> bytes:
""" """
@ -69,20 +60,26 @@ async def generate_typed_data_hash(
metamask_v4_compat - a flag that enables compatibility with MetaMask's signTypedData_v4 method metamask_v4_compat - a flag that enables compatibility with MetaMask's signTypedData_v4 method
""" """
from .layout import (
confirm_empty_typed_message,
confirm_typed_data_final,
should_show_domain,
)
typed_data_envelope = TypedDataEnvelope( typed_data_envelope = TypedDataEnvelope(
ctx=ctx, ctx,
primary_type=primary_type, primary_type,
metamask_v4_compat=metamask_v4_compat, metamask_v4_compat,
) )
await typed_data_envelope.collect_types() await typed_data_envelope.collect_types()
name, version = await get_name_and_version_for_domain(ctx, typed_data_envelope) name, version = await _get_name_and_version_for_domain(ctx, typed_data_envelope)
show_domain = await should_show_domain(ctx, name, version) show_domain = await should_show_domain(ctx, name, version)
domain_separator = await typed_data_envelope.hash_struct( domain_separator = await typed_data_envelope.hash_struct(
primary_type="EIP712Domain", "EIP712Domain",
member_path=[0], [0],
show_data=show_domain, show_domain,
parent_objects=["EIP712Domain"], ["EIP712Domain"],
) )
# Setting the primary_type to "EIP712Domain" is technically in spec # Setting the primary_type to "EIP712Domain" is technically in spec
@ -94,16 +91,16 @@ async def generate_typed_data_hash(
else: else:
show_message = await should_show_struct( show_message = await should_show_struct(
ctx, ctx,
description=primary_type, primary_type,
data_members=typed_data_envelope.types[primary_type].members, typed_data_envelope.types[primary_type].members,
title="Confirm message", "Confirm message",
button_text="Show full message", "Show full message",
) )
message_hash = await typed_data_envelope.hash_struct( message_hash = await typed_data_envelope.hash_struct(
primary_type=primary_type, primary_type,
member_path=[1], [1],
show_data=show_message, show_message,
parent_objects=[primary_type], [primary_type],
) )
await confirm_typed_data_final(ctx) await confirm_typed_data_final(ctx)
@ -112,6 +109,9 @@ async def generate_typed_data_hash(
def get_hash_writer() -> HashWriter: def get_hash_writer() -> HashWriter:
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
return HashWriter(sha3_256(keccak=True)) return HashWriter(sha3_256(keccak=True))
@ -142,6 +142,11 @@ class TypedDataEnvelope:
async def _collect_types(self, type_name: str) -> None: async def _collect_types(self, type_name: str) -> None:
"""Recursively collect types from the client.""" """Recursively collect types from the client."""
from trezor.messages import (
EthereumTypedDataStructRequest,
EthereumTypedDataStructAck,
)
req = EthereumTypedDataStructRequest(name=type_name) req = EthereumTypedDataStructRequest(name=type_name)
current_type = await self.ctx.call(req, EthereumTypedDataStructAck) current_type = await self.ctx.call(req, EthereumTypedDataStructAck)
self.types[type_name] = current_type self.types[type_name] = current_type
@ -169,11 +174,11 @@ class TypedDataEnvelope:
w = get_hash_writer() w = get_hash_writer()
self.hash_type(w, primary_type) self.hash_type(w, primary_type)
await self.get_and_encode_data( await self.get_and_encode_data(
w=w, w,
primary_type=primary_type, primary_type,
member_path=member_path, member_path,
show_data=show_data, show_data,
parent_objects=parent_objects, parent_objects,
) )
return w.get_digest() return w.get_digest()
@ -242,6 +247,10 @@ class TypedDataEnvelope:
i.e. the concatenation of the encoded member values in the order that they appear in the type. i.e. the concatenation of the encoded member values in the order that they appear in the type.
Each encoded member value is exactly 32-byte long. Each encoded member value is exactly 32-byte long.
""" """
from .layout import confirm_typed_value, should_show_array
ctx = self.ctx # local_cache_attribute
type_members = self.types[primary_type].members type_members = self.types[primary_type].members
member_value_path = member_path + [0] member_value_path = member_path + [0]
current_parent_objects = parent_objects + [""] current_parent_objects = parent_objects + [""]
@ -258,25 +267,25 @@ class TypedDataEnvelope:
if show_data: if show_data:
show_struct = await should_show_struct( show_struct = await should_show_struct(
ctx=self.ctx, ctx,
description=struct_name, struct_name, # description
data_members=self.types[struct_name].members, self.types[struct_name].members, # data_members
title=".".join(current_parent_objects), ".".join(current_parent_objects), # title
) )
else: else:
show_struct = False show_struct = False
res = await self.hash_struct( res = await self.hash_struct(
primary_type=struct_name, struct_name,
member_path=member_value_path, member_value_path,
show_data=show_struct, show_struct,
parent_objects=current_parent_objects, current_parent_objects,
) )
w.extend(res) w.extend(res)
elif field_type.data_type == EthereumDataType.ARRAY: elif field_type.data_type == EthereumDataType.ARRAY:
# Getting the length of the array first, if not fixed # Getting the length of the array first, if not fixed
if field_type.size is None: if field_type.size is None:
array_size = await get_array_size(self.ctx, member_value_path) array_size = await _get_array_size(ctx, member_value_path)
else: else:
array_size = field_type.size array_size = field_type.size
@ -286,10 +295,10 @@ class TypedDataEnvelope:
if show_data: if show_data:
show_array = await should_show_array( show_array = await should_show_array(
ctx=self.ctx, ctx,
parent_objects=current_parent_objects, current_parent_objects,
data_type=get_type_name(entry_type), get_type_name(entry_type),
size=array_size, array_size,
) )
else: else:
show_array = False show_array = False
@ -309,43 +318,43 @@ class TypedDataEnvelope:
# Metamask V4 is using hash_struct() even in this case # Metamask V4 is using hash_struct() even in this case
if self.metamask_v4_compat: if self.metamask_v4_compat:
res = await self.hash_struct( res = await self.hash_struct(
primary_type=struct_name, struct_name,
member_path=el_member_path, el_member_path,
show_data=show_array, show_array,
parent_objects=current_parent_objects, current_parent_objects,
) )
arr_w.extend(res) arr_w.extend(res)
else: else:
await self.get_and_encode_data( await self.get_and_encode_data(
w=arr_w, arr_w,
primary_type=struct_name, struct_name,
member_path=el_member_path, el_member_path,
show_data=show_array, show_array,
parent_objects=current_parent_objects, current_parent_objects,
) )
else: else:
value = await get_value(self.ctx, entry_type, el_member_path) value = await get_value(ctx, entry_type, el_member_path)
encode_field(arr_w, entry_type, value) encode_field(arr_w, entry_type, value)
if show_array: if show_array:
await confirm_typed_value( await confirm_typed_value(
ctx=self.ctx, ctx,
name=field_name, field_name,
value=value, value,
parent_objects=parent_objects, parent_objects,
field=entry_type, entry_type,
array_index=i, i,
) )
w.extend(arr_w.get_digest()) w.extend(arr_w.get_digest())
else: else:
value = await get_value(self.ctx, field_type, member_value_path) value = await get_value(ctx, field_type, member_value_path)
encode_field(w, field_type, value) encode_field(w, field_type, value)
if show_data: if show_data:
await confirm_typed_value( await confirm_typed_value(
ctx=self.ctx, ctx,
name=field_name, field_name,
value=value, value,
parent_objects=parent_objects, parent_objects,
field=field_type, field_type,
) )
@ -370,21 +379,27 @@ def encode_field(
encodeData of their contents encodeData of their contents
- Struct values are encoded recursively as hashStruct(value) - Struct values are encoded recursively as hashStruct(value)
""" """
EDT = EthereumDataType # local_cache_global
data_type = field.data_type data_type = field.data_type
if data_type == EthereumDataType.BYTES: if data_type == EDT.BYTES:
if field.size is None: if field.size is None:
w.extend(keccak256(value)) w.extend(keccak256(value))
else: else:
write_rightpad32(w, value) # write_rightpad32
elif data_type == EthereumDataType.STRING: assert len(value) <= 32
w.extend(value)
for _ in range(32 - len(value)):
w.append(0x00)
elif data_type == EDT.STRING:
w.extend(keccak256(value)) w.extend(keccak256(value))
elif data_type == EthereumDataType.INT: elif data_type == EDT.INT:
write_leftpad32(w, value, signed=True) write_leftpad32(w, value, signed=True)
elif data_type in ( elif data_type in (
EthereumDataType.UINT, EDT.UINT,
EthereumDataType.BOOL, EDT.BOOL,
EthereumDataType.ADDRESS, EDT.ADDRESS,
): ):
write_leftpad32(w, value) write_leftpad32(w, value)
else: else:
@ -405,93 +420,90 @@ def write_leftpad32(w: HashWriter, value: bytes, signed: bool = False) -> None:
w.extend(value) w.extend(value)
def write_rightpad32(w: HashWriter, value: bytes) -> None: def _validate_value(field: EthereumFieldType, value: bytes) -> None:
assert len(value) <= 32
w.extend(value)
for _ in range(32 - len(value)):
w.append(0x00)
def validate_value(field: EthereumFieldType, value: bytes) -> None:
""" """
Make sure the byte data we receive are not corrupted or incorrect. Make sure the byte data we receive are not corrupted or incorrect.
Raise wire.DataError if encountering a problem, so clients are notified. Raise DataError if encountering a problem, so clients are notified.
""" """
# Checking if the size corresponds to what is defined in types, # Checking if the size corresponds to what is defined in types,
# and also setting our maximum supported size in bytes # and also setting our maximum supported size in bytes
if field.size is not None: if field.size is not None:
if len(value) != field.size: if len(value) != field.size:
raise wire.DataError("Invalid length") raise DataError("Invalid length")
else: else:
if len(value) > _MAX_VALUE_BYTE_SIZE: if len(value) > _MAX_VALUE_BYTE_SIZE:
raise wire.DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}") raise DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
# Specific tests for some data types # Specific tests for some data types
if field.data_type == EthereumDataType.BOOL: if field.data_type == EthereumDataType.BOOL:
if value not in (b"\x00", b"\x01"): if value not in (b"\x00", b"\x01"):
raise wire.DataError("Invalid boolean value") raise DataError("Invalid boolean value")
elif field.data_type == EthereumDataType.ADDRESS: elif field.data_type == EthereumDataType.ADDRESS:
if len(value) != 20: if len(value) != 20:
raise wire.DataError("Invalid address") raise DataError("Invalid address")
elif field.data_type == EthereumDataType.STRING: elif field.data_type == EthereumDataType.STRING:
try: try:
value.decode() value.decode()
except UnicodeError: except UnicodeError:
raise wire.DataError("Invalid UTF-8") raise DataError("Invalid UTF-8")
def validate_field_type(field: EthereumFieldType) -> None: def validate_field_type(field: EthereumFieldType) -> None:
""" """
Make sure the field type is consistent with our expectation. Make sure the field type is consistent with our expectation.
Raise wire.DataError if encountering a problem, so clients are notified. Raise DataError if encountering a problem, so clients are notified.
""" """
EDT = EthereumDataType # local_cache_global
data_type = field.data_type data_type = field.data_type
# entry_type is only for arrays # entry_type is only for arrays
if data_type == EthereumDataType.ARRAY: if data_type == EDT.ARRAY:
if field.entry_type is None: if field.entry_type is None:
raise wire.DataError("Missing entry_type in array") raise DataError("Missing entry_type in array")
# We also need to validate it recursively # We also need to validate it recursively
validate_field_type(field.entry_type) validate_field_type(field.entry_type)
else: else:
if field.entry_type is not None: if field.entry_type is not None:
raise wire.DataError("Unexpected entry_type in nonarray") raise DataError("Unexpected entry_type in nonarray")
# struct_name is only for structs # struct_name is only for structs
if data_type == EthereumDataType.STRUCT: if data_type == EDT.STRUCT:
if field.struct_name is None: if field.struct_name is None:
raise wire.DataError("Missing struct_name in struct") raise DataError("Missing struct_name in struct")
else: else:
if field.struct_name is not None: if field.struct_name is not None:
raise wire.DataError("Unexpected struct_name in nonstruct") raise DataError("Unexpected struct_name in nonstruct")
size = field.size # local_cache_attribute
# size is special for each type # size is special for each type
if data_type == EthereumDataType.STRUCT: if data_type == EDT.STRUCT:
if field.size is None: if size is None:
raise wire.DataError("Missing size in struct") raise DataError("Missing size in struct")
elif data_type == EthereumDataType.BYTES: elif data_type == EDT.BYTES:
if field.size is not None and not 1 <= field.size <= 32: if size is not None and not 1 <= size <= 32:
raise wire.DataError("Invalid size in bytes") raise DataError("Invalid size in bytes")
elif data_type in ( elif data_type in (
EthereumDataType.UINT, EDT.UINT,
EthereumDataType.INT, EDT.INT,
): ):
if field.size is None or not 1 <= field.size <= 32: if size is None or not 1 <= size <= 32:
raise wire.DataError("Invalid size in int/uint") raise DataError("Invalid size in int/uint")
elif data_type in ( elif data_type in (
EthereumDataType.STRING, EDT.STRING,
EthereumDataType.BOOL, EDT.BOOL,
EthereumDataType.ADDRESS, EDT.ADDRESS,
): ):
if field.size is not None: if size is not None:
raise wire.DataError("Unexpected size in str/bool/addr") raise DataError("Unexpected size in str/bool/addr")
async def get_array_size(ctx: Context, member_path: list[int]) -> int: async def _get_array_size(ctx: Context, member_path: list[int]) -> int:
"""Get the length of an array at specific `member_path` from the client.""" """Get the length of an array at specific `member_path` from the client."""
from trezor.messages import EthereumFieldType
# Field type for getting the array length from client, so we can check the return value # Field type for getting the array length from client, so we can check the return value
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2) ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path) length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path)
@ -504,18 +516,20 @@ async def get_value(
member_value_path: list[int], member_value_path: list[int],
) -> bytes: ) -> bytes:
"""Get a single value from the client and perform its validation.""" """Get a single value from the client and perform its validation."""
from trezor.messages import EthereumTypedDataValueAck, EthereumTypedDataValueRequest
req = EthereumTypedDataValueRequest( req = EthereumTypedDataValueRequest(
member_path=member_value_path, member_path=member_value_path,
) )
res = await ctx.call(req, EthereumTypedDataValueAck) res = await ctx.call(req, EthereumTypedDataValueAck)
value = res.value value = res.value
validate_value(field=field, value=value) _validate_value(field=field, value=value)
return value return value
async def get_name_and_version_for_domain( async def _get_name_and_version_for_domain(
ctx: Context, typed_data_envelope: TypedDataEnvelope ctx: Context, typed_data_envelope: TypedDataEnvelope
) -> tuple[bytes, bytes]: ) -> tuple[bytes, bytes]:
domain_name = b"unknown" domain_name = b"unknown"

File diff suppressed because it is too large Load Diff

View File

@ -2,6 +2,20 @@
# (by running `make templates` in `core`) # (by running `make templates` in `core`)
# do not edit manually! # do not edit manually!
# fmt: off # fmt: off
# NOTE: returning a tuple instead of `TokenInfo` from the "data" function
# saves 5600 bytes of flash size. Implementing the `_token_iterator`
# instead of if-tree approach saves another 5600 bytes.
# NOTE: interestingly, it did not save much flash size to use smaller
# parts of the address, for example address length of 10 bytes saves
# 1 byte per entry, so 1887 bytes overall (and further decrease does not help).
# (The idea was not having to store the whole address, even a smaller part
# of it has enough collision-resistance.)
# (In the if-tree approach the address length did not have any effect whatsoever.)
from typing import Iterator
<% <%
from collections import defaultdict from collections import defaultdict
@ -22,11 +36,20 @@ UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo: def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
for addr, symbol, decimal in _token_iterator(chain_id):
if address == addr:
return TokenInfo(symbol, decimal)
return UNKNOWN_TOKEN
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]:
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items(): % for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
if chain_id == ${token_chain_id}: if chain_id == ${token_chain_id}:
% for t in tokens: % for t in tokens:
if address == ${black_repr(t.address_bytes)}: yield ( # address, symbol, decimals
return TokenInfo(${black_repr(t.symbol)}, ${t.decimals}) # ${t.chain} / ${t.name.strip()} ${black_repr(t.address_bytes)},
${black_repr(t.symbol)},
${t.decimals},
)
% endfor % endfor
% endfor % endfor
return UNKNOWN_TOKEN

View File

@ -1,42 +1,42 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import Success
from trezor.ui.layouts import confirm_signverify, show_success
from apps.common.signverify import decode_message
from .helpers import address_from_bytes, bytes_from_address
from .sign_message import message_digest
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumVerifyMessage from trezor.messages import EthereumVerifyMessage, Success
from trezor.wire import Context from trezor.wire import Context
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success: async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
from trezor.wire import DataError
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import Success
from trezor.ui.layouts import confirm_signverify, show_success
from apps.common.signverify import decode_message
from .helpers import address_from_bytes, bytes_from_address
from .sign_message import message_digest
digest = message_digest(msg.message) digest = message_digest(msg.message)
if len(msg.signature) != 65: if len(msg.signature) != 65:
raise wire.DataError("Invalid signature") raise DataError("Invalid signature")
sig = bytearray([msg.signature[64]]) + msg.signature[:64] sig = bytearray([msg.signature[64]]) + msg.signature[:64]
pubkey = secp256k1.verify_recover(sig, digest) pubkey = secp256k1.verify_recover(sig, digest)
if not pubkey: if not pubkey:
raise wire.DataError("Invalid signature") raise DataError("Invalid signature")
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:] pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
address_bytes = bytes_from_address(msg.address) address_bytes = bytes_from_address(msg.address)
if address_bytes != pkh: if address_bytes != pkh:
raise wire.DataError("Invalid signature") raise DataError("Invalid signature")
address = address_from_bytes(address_bytes) address = address_from_bytes(address_bytes)
await confirm_signverify( await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address=address, verify=True ctx, "ETH", decode_message(msg.message), address, verify=True
) )
await show_success(ctx, "verify_message", "The signature is valid.") await show_success(ctx, "verify_message", "The signature is valid.")

View File

@ -11,7 +11,7 @@ from trezor.enums import EthereumDataType as EDT
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from apps.ethereum.sign_typed_data import ( from apps.ethereum.sign_typed_data import (
encode_field, encode_field,
validate_value, _validate_value,
validate_field_type, validate_field_type,
keccak256, keccak256,
TypedDataEnvelope, TypedDataEnvelope,
@ -594,10 +594,10 @@ class TestEthereumSignTypedData(unittest.TestCase):
for field, valid_values, invalid_values in VECTORS_VALID_INVALID: for field, valid_values, invalid_values in VECTORS_VALID_INVALID:
for valid_value in valid_values: for valid_value in valid_values:
validate_value(field=field, value=valid_value) _validate_value(field=field, value=valid_value)
for invalid_value in invalid_values: for invalid_value in invalid_values:
with self.assertRaises(wire.DataError): with self.assertRaises(wire.DataError):
validate_value(field=field, value=invalid_value) _validate_value(field=field, value=invalid_value)
def test_validate_field_type(self): def test_validate_field_type(self):
ET = EFT(data_type=EDT.BYTES, size=8) ET = EFT(data_type=EDT.BYTES, size=8)