chore(core): decrease ethereum size by 17250 bytes

pull/2633/head
grdddj 2 years ago committed by matejcik
parent 0c3423b1c7
commit 26fd0de198

@ -1,16 +1,9 @@
from typing import TYPE_CHECKING
from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING:
from trezor.messages import EthereumGetAddress
from trezor.messages import EthereumGetAddress, EthereumAddress
from trezor.wire import Context
from apps.common.keychain import Keychain
@ -20,18 +13,26 @@ if TYPE_CHECKING:
async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
) -> EthereumAddress:
await paths.validate_path(ctx, keychain, msg.address_n)
from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
address_n = msg.address_n # local_cache_attribute
await paths.validate_path(ctx, keychain, address_n)
node = keychain.derive(msg.address_n)
node = keychain.derive(address_n)
if len(msg.address_n) > 1: # path has slip44 network identifier
network = networks.by_slip44(msg.address_n[1] & 0x7FFF_FFFF)
if len(address_n) > 1: # path has slip44 network identifier
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
else:
network = None
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
if msg.show_display:
title = paths.address_n_to_str(msg.address_n)
await show_address(ctx, address=address, title=title)
title = paths.address_n_to_str(address_n)
await show_address(ctx, address, title=title)
return EthereumAddress(address=address)

@ -1,15 +1,11 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor.messages import EthereumPublicKey, HDNodeType
from trezor.ui.layouts import show_pubkey
from apps.common import coins, paths
from apps.common import paths
from .keychain import with_keychain_from_path
if TYPE_CHECKING:
from trezor.messages import EthereumGetPublicKey
from trezor.messages import EthereumGetPublicKey, EthereumPublicKey
from trezor.wire import Context
from apps.common.keychain import Keychain
@ -19,6 +15,11 @@ if TYPE_CHECKING:
async def get_public_key(
ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain
) -> EthereumPublicKey:
from ubinascii import hexlify
from trezor.messages import EthereumPublicKey, HDNodeType
from trezor.ui.layouts import show_pubkey
from apps.common import coins
await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n)

@ -1,8 +1,5 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify, unhexlify
from trezor import wire
from trezor.enums import EthereumDataType
from ubinascii import hexlify
if TYPE_CHECKING:
from trezor.messages import EthereumFieldType
@ -24,7 +21,7 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
address_hex = hexlify(address_bytes).decode()
digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest()
def maybe_upper(i: int) -> str:
def _maybe_upper(i: int) -> str:
"""Uppercase i-th letter only if the corresponding nibble has high bit set."""
digest_byte = digest[i // 2]
hex_letter = address_hex[i]
@ -39,10 +36,13 @@ def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None)
else:
return hex_letter
return "0x" + "".join(maybe_upper(i) for i in range(len(address_hex)))
return "0x" + "".join(_maybe_upper(i) for i in range(len(address_hex)))
def bytes_from_address(address: str) -> bytes:
from ubinascii import unhexlify
from trezor import wire
if len(address) == 40:
return unhexlify(address)
@ -59,6 +59,8 @@ def bytes_from_address(address: str) -> bytes:
def get_type_name(field: EthereumFieldType) -> str:
"""Create a string from type definition (like uint256 or bytes16)."""
from trezor.enums import EthereumDataType
data_type = field.data_type
size = field.size
@ -109,12 +111,12 @@ def decode_typed_data(data: bytes, type_name: str) -> str:
return str(int.from_bytes(data, "big"))
elif type_name.startswith("int"):
# Micropython does not implement "signed" arg in int.from_bytes()
return str(from_bytes_bigendian_signed(data))
return str(_from_bytes_bigendian_signed(data))
raise ValueError # Unsupported data type for direct field decoding
def from_bytes_bigendian_signed(b: bytes) -> int:
def _from_bytes_bigendian_signed(b: bytes) -> int:
negative = b[0] & 0x80
if negative:
neg_b = bytearray(b)

@ -1,7 +1,5 @@
from typing import TYPE_CHECKING
from trezor import wire
from apps.common import paths
from apps.common.keychain import get_keychain
@ -10,6 +8,8 @@ from . import CURVE, networks
if TYPE_CHECKING:
from typing import Callable, Iterable, TypeVar
from trezor.wire import Context
from trezor.messages import (
EthereumGetAddress,
EthereumGetPublicKey,
@ -64,7 +64,7 @@ def with_keychain_from_path(
*patterns: str,
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
schemas = _schemas_from_address_n(patterns, msg.address_n)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
@ -97,7 +97,7 @@ def with_keychain_from_chain_id(
func: HandlerWithKeychain[MsgInChainId, MsgOut]
) -> Handler[MsgInChainId, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: wire.Context, msg: MsgInChainId) -> MsgOut:
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
schemas = _schemas_from_chain_id(msg)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:

@ -1,29 +1,19 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from trezor import ui
from trezor.enums import ButtonRequestType, EthereumDataType
from trezor.strings import format_amount, format_plural
from trezor.ui.layouts import (
confirm_action,
confirm_address,
confirm_amount,
confirm_blob,
confirm_output,
confirm_text,
confirm_total,
should_show_more,
)
from trezor.ui.layouts.altcoin import confirm_total_ethereum
from . import networks, tokens
from .helpers import address_from_bytes, decode_typed_data, get_type_name
from trezor.enums import ButtonRequestType
from trezor.strings import format_plural
from trezor.ui.layouts import confirm_blob, confirm_text, should_show_more
from . import networks
from .helpers import decode_typed_data
if TYPE_CHECKING:
from typing import Awaitable, Iterable
from trezor.messages import EthereumFieldType, EthereumStructMember
from trezor.wire import Context
from . import tokens
def require_confirm_tx(
@ -33,15 +23,18 @@ def require_confirm_tx(
chain_id: int,
token: tokens.TokenInfo | None = None,
) -> Awaitable[None]:
from .helpers import address_from_bytes
from trezor.ui.layouts import confirm_output
if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
else:
to_str = "new contract?"
return confirm_output(
ctx,
address=to_str,
amount=format_ethereum_amount(value, token, chain_id),
font_amount=ui.BOLD,
to_str,
format_ethereum_amount(value, token, chain_id),
ui.BOLD,
color_to=ui.GREY,
br_code=ButtonRequestType.SignTx,
)
@ -55,6 +48,8 @@ def require_confirm_fee(
chain_id: int,
token: tokens.TokenInfo | None = None,
) -> Awaitable[None]:
from trezor.ui.layouts.altcoin import confirm_total_ethereum
return confirm_total_ethereum(
ctx,
format_ethereum_amount(spending, token, chain_id),
@ -72,22 +67,24 @@ async def require_confirm_eip1559_fee(
chain_id: int,
token: tokens.TokenInfo | None = None,
) -> None:
from trezor.ui.layouts import confirm_amount, confirm_total
await confirm_amount(
ctx,
title="Confirm fee",
description="Maximum fee per gas",
amount=format_ethereum_amount(max_gas_fee, None, chain_id),
"Confirm fee",
format_ethereum_amount(max_gas_fee, None, chain_id),
"Maximum fee per gas",
)
await confirm_amount(
ctx,
title="Confirm fee",
description="Priority fee per gas",
amount=format_ethereum_amount(max_priority_fee, None, chain_id),
"Confirm fee",
format_ethereum_amount(max_priority_fee, None, chain_id),
"Priority fee per gas",
)
await confirm_total(
ctx,
total_amount=format_ethereum_amount(spending, token, chain_id),
fee_amount=format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
format_ethereum_amount(spending, token, chain_id),
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
total_label="Amount sent:\n",
fee_label="\nMaximum fee:\n",
)
@ -96,13 +93,16 @@ async def require_confirm_eip1559_fee(
def require_confirm_unknown_token(
ctx: Context, address_bytes: bytes
) -> Awaitable[None]:
from ubinascii import hexlify
from trezor.ui.layouts import confirm_address
contract_address_hex = "0x" + hexlify(address_bytes).decode()
return confirm_address(
ctx,
"Unknown token",
contract_address_hex,
description="Contract:",
br_type="unknown_token",
"Contract:",
"unknown_token",
icon_color=ui.ORANGE,
br_code=ButtonRequestType.SignTx,
)
@ -112,20 +112,22 @@ def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitabl
return confirm_blob(
ctx,
"confirm_data",
title="Confirm data",
description=f"Size: {data_total} bytes",
data=data,
"Confirm data",
data,
f"Size: {data_total} bytes",
br_code=ButtonRequestType.SignTx,
ask_pagination=True,
)
async def confirm_typed_data_final(ctx: Context) -> None:
from trezor.ui.layouts import confirm_action
await confirm_action(
ctx,
"confirm_typed_data_final",
title="Confirm typed data",
action="Really sign EIP-712 typed data?",
"Confirm typed data",
"Really sign EIP-712 typed data?",
verb="Hold to confirm",
hold=True,
)
@ -135,9 +137,9 @@ def confirm_empty_typed_message(ctx: Context) -> Awaitable[None]:
return confirm_text(
ctx,
"confirm_empty_typed_message",
title="Confirm message",
data="",
description="No message field",
"Confirm message",
"",
"No message field",
)
@ -152,10 +154,10 @@ async def should_show_domain(ctx: Context, name: bytes, version: bytes) -> bool:
)
return await should_show_more(
ctx,
title="Confirm domain",
para=para,
button_text="Show full domain",
br_type="should_show_domain",
"Confirm domain",
para,
"Show full domain",
"should_show_domain",
)
@ -176,10 +178,10 @@ async def should_show_struct(
)
return await should_show_more(
ctx,
title=title,
para=para,
button_text=button_text,
br_type="should_show_struct",
title,
para,
button_text,
"should_show_struct",
)
@ -192,10 +194,10 @@ async def should_show_array(
para = ((ui.NORMAL, format_plural("Array of {count} {plural}", size, data_type)),)
return await should_show_more(
ctx,
title=limit_str(".".join(parent_objects)),
para=para,
button_text="Show full array",
br_type="should_show_array",
limit_str(".".join(parent_objects)),
para,
"Show full array",
"should_show_array",
)
@ -207,6 +209,9 @@ async def confirm_typed_value(
field: EthereumFieldType,
array_index: int | None = None,
) -> None:
from trezor.enums import EthereumDataType
from .helpers import get_type_name
type_name = get_type_name(field)
if array_index is not None:
@ -222,24 +227,26 @@ async def confirm_typed_value(
await confirm_blob(
ctx,
"confirm_typed_value",
title=title,
data=data,
description=description,
title,
data,
description,
ask_pagination=True,
)
else:
await confirm_text(
ctx,
"confirm_typed_value",
title=title,
data=data,
description=description,
title,
data,
description,
)
def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int
) -> str:
from trezor.strings import format_amount
if token:
suffix = token.symbol
decimals = token.decimals

File diff suppressed because it is too large Load Diff

@ -1,10 +1,28 @@
# generated from networks.py.mako
# (by running `make templates` in `core`)
# do not edit manually!
from typing import Iterator
# NOTE: using positional arguments saves 4400 bytes in flash size,
# returning tuples instead of classes saved 800 bytes
from typing import TYPE_CHECKING
from apps.common.paths import HARDENED
if TYPE_CHECKING:
from typing import Iterator
# Removing the necessity to construct object to save space
# fmt: off
NetworkInfoTuple = tuple[
int, # chain_id
int, # slip44
str, # shortcut
str, # name
bool # rskip60
]
# fmt: on
def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id)
@ -13,21 +31,36 @@ def shortcut_by_chain_id(chain_id: int) -> str:
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
for n in _networks_iterator():
if n.chain_id == chain_id:
return n
n_chain_id = n[0]
if n_chain_id == chain_id:
return NetworkInfo(
chain_id=n[0],
slip44=n[1],
shortcut=n[2],
name=n[3],
rskip60=n[4],
)
return None
def by_slip44(slip44: int) -> "NetworkInfo" | None:
for n in _networks_iterator():
if n.slip44 == slip44:
return n
n_slip44 = n[1]
if n_slip44 == slip44:
return NetworkInfo(
chain_id=n[0],
slip44=n[1],
shortcut=n[2],
name=n[3],
rskip60=n[4],
)
return None
def all_slip44_ids_hardened() -> Iterator[int]:
for n in _networks_iterator():
yield n.slip44 | HARDENED
# n_slip_44 is the second element
yield n[1] | HARDENED
class NetworkInfo:
@ -42,13 +75,13 @@ class NetworkInfo:
# fmt: off
def _networks_iterator() -> Iterator[NetworkInfo]:
def _networks_iterator() -> Iterator[NetworkInfoTuple]:
% for n in supported_on("trezor2", eth):
yield NetworkInfo(
chain_id=${n.chain_id},
slip44=${n.slip44},
shortcut="${n.shortcut}",
name="${n.name}",
rskip60=${n.rskip60},
yield (
${n.chain_id}, # chain_id
${n.slip44}, # slip44
"${n.shortcut}", # shortcut
"${n.name}", # name
${n.rskip60}, # rskip60
)
% endfor

@ -1,25 +1,18 @@
from typing import TYPE_CHECKING
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import EthereumMessageSignature
from trezor.ui.layouts import confirm_signverify
from trezor.utils import HashWriter
from apps.common import paths
from apps.common.signverify import decode_message
from .helpers import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING:
from trezor.messages import EthereumSignMessage
from trezor.messages import EthereumSignMessage, EthereumMessageSignature
from trezor.wire import Context
from apps.common.keychain import Keychain
def message_digest(message: bytes) -> bytes:
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
h = HashWriter(sha3_256(keccak=True))
signed_message_header = b"\x19Ethereum Signed Message:\n"
h.extend(signed_message_header)
@ -32,6 +25,15 @@ def message_digest(message: bytes) -> bytes:
async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
) -> EthereumMessageSignature:
from trezor.crypto.curve import secp256k1
from trezor.messages import EthereumMessageSignature
from trezor.ui.layouts import confirm_signverify
from apps.common import paths
from apps.common.signverify import decode_message
from .helpers import address_from_bytes
await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n)

@ -1,29 +1,19 @@
from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import EthereumTxAck, EthereumTxRequest
from trezor.utils import HashWriter
from trezor.messages import EthereumTxRequest
from trezor.wire import DataError
from apps.common import paths
from . import tokens
from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id
from .layout import (
require_confirm_data,
require_confirm_fee,
require_confirm_tx,
require_confirm_unknown_token,
)
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx
from trezor.messages import EthereumSignTx, EthereumTxAck
from trezor.wire import Context
from .keychain import EthereumSignTxAny
from . import tokens
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -34,9 +24,24 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id
async def sign_tx(
ctx: wire.Context, msg: EthereumSignTx, keychain: Keychain
ctx: Context, msg: EthereumSignTx, keychain: Keychain
) -> EthereumTxRequest:
check(msg)
from trezor.utils import HashWriter
from trezor.crypto.hashlib import sha3_256
from apps.common import paths
from .layout import (
require_confirm_data,
require_confirm_fee,
require_confirm_tx,
)
# check
if msg.tx_type not in [1, 6, None]:
raise DataError("tx_type out of bounds")
if len(msg.gas_price) + len(msg.gas_limit) > 30:
raise DataError("Fee overflow")
check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s
@ -61,7 +66,7 @@ async def sign_tx(
data += msg.data_initial_chunk
data_left = data_total - len(msg.data_initial_chunk)
total_length = get_total_length(msg, data_total)
total_length = _get_total_length(msg, data_total)
sha = HashWriter(sha3_256(keccak=True))
rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE)
@ -89,14 +94,19 @@ async def sign_tx(
rlp.write(sha, 0)
digest = sha.get_digest()
result = sign_digest(msg, keychain, digest)
result = _sign_digest(msg, keychain, digest)
return result
async def handle_erc20(
ctx: wire.Context, msg: EthereumSignTxAny
ctx: Context, msg: EthereumSignTxAny
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
from .layout import require_confirm_unknown_token
from . import tokens
data_initial_chunk = msg.data_initial_chunk # local_cache_attribute
token = None
address_bytes = recipient = bytes_from_address(msg.to)
value = int.from_bytes(msg.value, "big")
@ -104,13 +114,13 @@ async def handle_erc20(
len(msg.to) in (40, 42)
and len(msg.value) == 0
and msg.data_length == 68
and len(msg.data_initial_chunk) == 68
and msg.data_initial_chunk[:16]
and len(data_initial_chunk) == 68
and data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
):
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
recipient = msg.data_initial_chunk[16:36]
value = int.from_bytes(msg.data_initial_chunk[36:68], "big")
recipient = data_initial_chunk[16:36]
value = int.from_bytes(data_initial_chunk[36:68], "big")
if token is tokens.UNKNOWN_TOKEN:
await require_confirm_unknown_token(ctx, address_bytes)
@ -118,7 +128,7 @@ async def handle_erc20(
return token, address_bytes, recipient, value
def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
def _get_total_length(msg: EthereumSignTx, data_total: int) -> int:
length = 0
if msg.tx_type is not None:
length += rlp.length(msg.tx_type)
@ -143,20 +153,20 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int:
return length
async def send_request_chunk(ctx: wire.Context, data_left: int) -> EthereumTxAck:
async def send_request_chunk(ctx: Context, data_left: int) -> EthereumTxAck:
from trezor.messages import EthereumTxAck
# TODO: layoutProgress ?
req = EthereumTxRequest()
if data_left <= 1024:
req.data_length = data_left
else:
req.data_length = 1024
req.data_length = min(data_left, 1024)
return await ctx.call(req, EthereumTxAck)
def sign_digest(
def _sign_digest(
msg: EthereumSignTx, keychain: Keychain, digest: bytes
) -> EthereumTxRequest:
from trezor.crypto.curve import secp256k1
node = keychain.derive(msg.address_n)
signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
@ -175,33 +185,25 @@ def sign_digest(
return req
def check(msg: EthereumSignTx) -> None:
if msg.tx_type not in [1, 6, None]:
raise wire.DataError("tx_type out of bounds")
if len(msg.gas_price) + len(msg.gas_limit) > 30:
raise wire.DataError("Fee overflow")
check_common_fields(msg)
def check_common_fields(msg: EthereumSignTxAny) -> None:
if msg.data_length > 0:
data_length = msg.data_length # local_cache_attribute
if data_length > 0:
if not msg.data_initial_chunk:
raise wire.DataError("Data length provided, but no initial chunk")
raise DataError("Data length provided, but no initial chunk")
# Our encoding only supports transactions up to 2^24 bytes. To
# prevent exceeding the limit we use a stricter limit on data length.
if msg.data_length > 16_000_000:
raise wire.DataError("Data length exceeds limit")
if len(msg.data_initial_chunk) > msg.data_length:
raise wire.DataError("Invalid size of initial chunk")
if data_length > 16_000_000:
raise DataError("Data length exceeds limit")
if len(msg.data_initial_chunk) > data_length:
raise DataError("Invalid size of initial chunk")
if len(msg.to) not in (0, 40, 42):
raise wire.DataError("Invalid recipient address")
raise DataError("Invalid recipient address")
if not msg.to and msg.data_length == 0:
if not msg.to and data_length == 0:
# sending transaction to address 0 (contract creation) without a data field
raise wire.DataError("Contract creation without data")
raise DataError("Contract creation without data")
if msg.chain_id == 0:
raise wire.DataError("Chain ID out of bounds")
raise DataError("Chain ID out of bounds")

@ -1,28 +1,21 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto import rlp
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import EthereumAccessList, EthereumTxRequest
from trezor.utils import HashWriter
from apps.common import paths
from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id
from .layout import (
require_confirm_data,
require_confirm_eip1559_fee,
require_confirm_tx,
)
from .sign_tx import check_common_fields, handle_erc20, send_request_chunk
if TYPE_CHECKING:
from trezor.messages import EthereumSignTxEIP1559
from trezor.messages import (
EthereumSignTxEIP1559,
EthereumAccessList,
EthereumTxRequest,
)
from apps.common.keychain import Keychain
from trezor.wire import Context
_TX_TYPE = const(2)
@ -35,28 +28,30 @@ def access_list_item_length(item: EthereumAccessList) -> int:
)
def access_list_length(access_list: list[EthereumAccessList]) -> int:
payload_length = sum(access_list_item_length(i) for i in access_list)
return rlp.header_length(payload_length) + payload_length
def write_access_list(w: HashWriter, access_list: list[EthereumAccessList]) -> None:
payload_length = sum(access_list_item_length(i) for i in access_list)
rlp.write_header(w, payload_length, rlp.LIST_HEADER_BYTE)
for item in access_list:
address_bytes = bytes_from_address(item.address)
address_length = rlp.length(address_bytes)
keys_length = rlp.length(item.storage_keys)
rlp.write_header(w, address_length + keys_length, rlp.LIST_HEADER_BYTE)
rlp.write(w, address_bytes)
rlp.write(w, item.storage_keys)
@with_keychain_from_chain_id
async def sign_tx_eip1559(
ctx: wire.Context, msg: EthereumSignTxEIP1559, keychain: Keychain
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
) -> EthereumTxRequest:
check(msg)
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
from trezor import wire
from trezor.crypto import rlp # local_cache_global
from apps.common import paths
from .layout import (
require_confirm_data,
require_confirm_eip1559_fee,
require_confirm_tx,
)
from .sign_tx import handle_erc20, send_request_chunk, check_common_fields
gas_limit = msg.gas_limit # local_cache_attribute
# check
if len(msg.max_gas_fee) + len(gas_limit) > 30:
raise wire.DataError("Fee overflow")
if len(msg.max_priority_fee) + len(gas_limit) > 30:
raise wire.DataError("Fee overflow")
check_common_fields(msg)
await paths.validate_path(ctx, keychain, msg.address_n)
@ -74,7 +69,7 @@ async def sign_tx_eip1559(
value,
int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"),
int.from_bytes(msg.gas_limit, "big"),
int.from_bytes(gas_limit, "big"),
msg.chain_id,
token,
)
@ -83,7 +78,7 @@ async def sign_tx_eip1559(
data += msg.data_initial_chunk
data_left = data_total - len(msg.data_initial_chunk)
total_length = get_total_length(msg, data_total)
total_length = _get_total_length(msg, data_total)
sha = HashWriter(sha3_256(keccak=True))
@ -96,7 +91,7 @@ async def sign_tx_eip1559(
msg.nonce,
msg.max_priority_fee,
msg.max_gas_fee,
msg.gas_limit,
gas_limit,
address_bytes,
msg.value,
)
@ -114,15 +109,24 @@ async def sign_tx_eip1559(
data_left -= len(resp.data_chunk)
sha.extend(resp.data_chunk)
write_access_list(sha, msg.access_list)
# write_access_list
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
rlp.write_header(sha, payload_length, rlp.LIST_HEADER_BYTE)
for item in msg.access_list:
address_bytes = bytes_from_address(item.address)
address_length = rlp.length(address_bytes)
keys_length = rlp.length(item.storage_keys)
rlp.write_header(sha, address_length + keys_length, rlp.LIST_HEADER_BYTE)
rlp.write(sha, address_bytes)
rlp.write(sha, item.storage_keys)
digest = sha.get_digest()
result = sign_digest(msg, keychain, digest)
result = _sign_digest(msg, keychain, digest)
return result
def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
def _get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
length = 0
fields: tuple[rlp.RLPItem, ...] = (
@ -140,14 +144,21 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int:
length += rlp.header_length(data_total, msg.data_initial_chunk)
length += data_total
length += access_list_length(msg.access_list)
# access_list_length
payload_length = sum(access_list_item_length(i) for i in msg.access_list)
access_list_length = rlp.header_length(payload_length) + payload_length
length += access_list_length
return length
def sign_digest(
def _sign_digest(
msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes
) -> EthereumTxRequest:
from trezor.messages import EthereumTxRequest
from trezor.crypto.curve import secp256k1
node = keychain.derive(msg.address_n)
signature = secp256k1.sign(
node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM
@ -159,12 +170,3 @@ def sign_digest(
req.signature_s = signature[33:]
return req
def check(msg: EthereumSignTxEIP1559) -> None:
if len(msg.max_gas_fee) + len(msg.gas_limit) > 30:
raise wire.DataError("Fee overflow")
if len(msg.max_priority_fee) + len(msg.gas_limit) > 30:
raise wire.DataError("Fee overflow")
check_common_fields(msg)

@ -1,38 +1,24 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.enums import EthereumDataType
from trezor.messages import (
EthereumFieldType,
EthereumTypedDataSignature,
EthereumTypedDataStructAck,
EthereumTypedDataStructRequest,
EthereumTypedDataValueAck,
EthereumTypedDataValueRequest,
)
from trezor.utils import HashWriter
from apps.common import paths
from .helpers import address_from_bytes, get_type_name
from trezor.wire import DataError
from .helpers import get_type_name
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
from .layout import (
confirm_empty_typed_message,
confirm_typed_data_final,
confirm_typed_value,
should_show_array,
should_show_domain,
should_show_struct,
)
from .layout import should_show_struct
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.utils import HashWriter
from trezor.messages import EthereumSignTypedData
from trezor.messages import (
EthereumSignTypedData,
EthereumFieldType,
EthereumTypedDataSignature,
EthereumTypedDataStructAck,
)
# Maximum data size we support
@ -43,9 +29,14 @@ _MAX_VALUE_BYTE_SIZE = const(1024)
async def sign_typed_data(
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
) -> EthereumTypedDataSignature:
from trezor.crypto.curve import secp256k1
from apps.common import paths
from .helpers import address_from_bytes
from trezor.messages import EthereumTypedDataSignature
await paths.validate_path(ctx, keychain, msg.address_n)
data_hash = await generate_typed_data_hash(
data_hash = await _generate_typed_data_hash(
ctx, msg.primary_type, msg.metamask_v4_compat
)
@ -60,7 +51,7 @@ async def sign_typed_data(
)
async def generate_typed_data_hash(
async def _generate_typed_data_hash(
ctx: Context, primary_type: str, metamask_v4_compat: bool = True
) -> bytes:
"""
@ -69,20 +60,26 @@ async def generate_typed_data_hash(
metamask_v4_compat - a flag that enables compatibility with MetaMask's signTypedData_v4 method
"""
from .layout import (
confirm_empty_typed_message,
confirm_typed_data_final,
should_show_domain,
)
typed_data_envelope = TypedDataEnvelope(
ctx=ctx,
primary_type=primary_type,
metamask_v4_compat=metamask_v4_compat,
ctx,
primary_type,
metamask_v4_compat,
)
await typed_data_envelope.collect_types()
name, version = await get_name_and_version_for_domain(ctx, typed_data_envelope)
name, version = await _get_name_and_version_for_domain(ctx, typed_data_envelope)
show_domain = await should_show_domain(ctx, name, version)
domain_separator = await typed_data_envelope.hash_struct(
primary_type="EIP712Domain",
member_path=[0],
show_data=show_domain,
parent_objects=["EIP712Domain"],
"EIP712Domain",
[0],
show_domain,
["EIP712Domain"],
)
# Setting the primary_type to "EIP712Domain" is technically in spec
@ -94,16 +91,16 @@ async def generate_typed_data_hash(
else:
show_message = await should_show_struct(
ctx,
description=primary_type,
data_members=typed_data_envelope.types[primary_type].members,
title="Confirm message",
button_text="Show full message",
primary_type,
typed_data_envelope.types[primary_type].members,
"Confirm message",
"Show full message",
)
message_hash = await typed_data_envelope.hash_struct(
primary_type=primary_type,
member_path=[1],
show_data=show_message,
parent_objects=[primary_type],
primary_type,
[1],
show_message,
[primary_type],
)
await confirm_typed_data_final(ctx)
@ -112,6 +109,9 @@ async def generate_typed_data_hash(
def get_hash_writer() -> HashWriter:
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
return HashWriter(sha3_256(keccak=True))
@ -142,6 +142,11 @@ class TypedDataEnvelope:
async def _collect_types(self, type_name: str) -> None:
"""Recursively collect types from the client."""
from trezor.messages import (
EthereumTypedDataStructRequest,
EthereumTypedDataStructAck,
)
req = EthereumTypedDataStructRequest(name=type_name)
current_type = await self.ctx.call(req, EthereumTypedDataStructAck)
self.types[type_name] = current_type
@ -169,11 +174,11 @@ class TypedDataEnvelope:
w = get_hash_writer()
self.hash_type(w, primary_type)
await self.get_and_encode_data(
w=w,
primary_type=primary_type,
member_path=member_path,
show_data=show_data,
parent_objects=parent_objects,
w,
primary_type,
member_path,
show_data,
parent_objects,
)
return w.get_digest()
@ -242,6 +247,10 @@ class TypedDataEnvelope:
i.e. the concatenation of the encoded member values in the order that they appear in the type.
Each encoded member value is exactly 32-byte long.
"""
from .layout import confirm_typed_value, should_show_array
ctx = self.ctx # local_cache_attribute
type_members = self.types[primary_type].members
member_value_path = member_path + [0]
current_parent_objects = parent_objects + [""]
@ -258,25 +267,25 @@ class TypedDataEnvelope:
if show_data:
show_struct = await should_show_struct(
ctx=self.ctx,
description=struct_name,
data_members=self.types[struct_name].members,
title=".".join(current_parent_objects),
ctx,
struct_name, # description
self.types[struct_name].members, # data_members
".".join(current_parent_objects), # title
)
else:
show_struct = False
res = await self.hash_struct(
primary_type=struct_name,
member_path=member_value_path,
show_data=show_struct,
parent_objects=current_parent_objects,
struct_name,
member_value_path,
show_struct,
current_parent_objects,
)
w.extend(res)
elif field_type.data_type == EthereumDataType.ARRAY:
# Getting the length of the array first, if not fixed
if field_type.size is None:
array_size = await get_array_size(self.ctx, member_value_path)
array_size = await _get_array_size(ctx, member_value_path)
else:
array_size = field_type.size
@ -286,10 +295,10 @@ class TypedDataEnvelope:
if show_data:
show_array = await should_show_array(
ctx=self.ctx,
parent_objects=current_parent_objects,
data_type=get_type_name(entry_type),
size=array_size,
ctx,
current_parent_objects,
get_type_name(entry_type),
array_size,
)
else:
show_array = False
@ -309,43 +318,43 @@ class TypedDataEnvelope:
# Metamask V4 is using hash_struct() even in this case
if self.metamask_v4_compat:
res = await self.hash_struct(
primary_type=struct_name,
member_path=el_member_path,
show_data=show_array,
parent_objects=current_parent_objects,
struct_name,
el_member_path,
show_array,
current_parent_objects,
)
arr_w.extend(res)
else:
await self.get_and_encode_data(
w=arr_w,
primary_type=struct_name,
member_path=el_member_path,
show_data=show_array,
parent_objects=current_parent_objects,
arr_w,
struct_name,
el_member_path,
show_array,
current_parent_objects,
)
else:
value = await get_value(self.ctx, entry_type, el_member_path)
value = await get_value(ctx, entry_type, el_member_path)
encode_field(arr_w, entry_type, value)
if show_array:
await confirm_typed_value(
ctx=self.ctx,
name=field_name,
value=value,
parent_objects=parent_objects,
field=entry_type,
array_index=i,
ctx,
field_name,
value,
parent_objects,
entry_type,
i,
)
w.extend(arr_w.get_digest())
else:
value = await get_value(self.ctx, field_type, member_value_path)
value = await get_value(ctx, field_type, member_value_path)
encode_field(w, field_type, value)
if show_data:
await confirm_typed_value(
ctx=self.ctx,
name=field_name,
value=value,
parent_objects=parent_objects,
field=field_type,
ctx,
field_name,
value,
parent_objects,
field_type,
)
@ -370,21 +379,27 @@ def encode_field(
encodeData of their contents
- Struct values are encoded recursively as hashStruct(value)
"""
EDT = EthereumDataType # local_cache_global
data_type = field.data_type
if data_type == EthereumDataType.BYTES:
if data_type == EDT.BYTES:
if field.size is None:
w.extend(keccak256(value))
else:
write_rightpad32(w, value)
elif data_type == EthereumDataType.STRING:
# write_rightpad32
assert len(value) <= 32
w.extend(value)
for _ in range(32 - len(value)):
w.append(0x00)
elif data_type == EDT.STRING:
w.extend(keccak256(value))
elif data_type == EthereumDataType.INT:
elif data_type == EDT.INT:
write_leftpad32(w, value, signed=True)
elif data_type in (
EthereumDataType.UINT,
EthereumDataType.BOOL,
EthereumDataType.ADDRESS,
EDT.UINT,
EDT.BOOL,
EDT.ADDRESS,
):
write_leftpad32(w, value)
else:
@ -405,93 +420,90 @@ def write_leftpad32(w: HashWriter, value: bytes, signed: bool = False) -> None:
w.extend(value)
def write_rightpad32(w: HashWriter, value: bytes) -> None:
assert len(value) <= 32
w.extend(value)
for _ in range(32 - len(value)):
w.append(0x00)
def validate_value(field: EthereumFieldType, value: bytes) -> None:
def _validate_value(field: EthereumFieldType, value: bytes) -> None:
"""
Make sure the byte data we receive are not corrupted or incorrect.
Raise wire.DataError if encountering a problem, so clients are notified.
Raise DataError if encountering a problem, so clients are notified.
"""
# Checking if the size corresponds to what is defined in types,
# and also setting our maximum supported size in bytes
if field.size is not None:
if len(value) != field.size:
raise wire.DataError("Invalid length")
raise DataError("Invalid length")
else:
if len(value) > _MAX_VALUE_BYTE_SIZE:
raise wire.DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
raise DataError(f"Invalid length, bigger than {_MAX_VALUE_BYTE_SIZE}")
# Specific tests for some data types
if field.data_type == EthereumDataType.BOOL:
if value not in (b"\x00", b"\x01"):
raise wire.DataError("Invalid boolean value")
raise DataError("Invalid boolean value")
elif field.data_type == EthereumDataType.ADDRESS:
if len(value) != 20:
raise wire.DataError("Invalid address")
raise DataError("Invalid address")
elif field.data_type == EthereumDataType.STRING:
try:
value.decode()
except UnicodeError:
raise wire.DataError("Invalid UTF-8")
raise DataError("Invalid UTF-8")
def validate_field_type(field: EthereumFieldType) -> None:
"""
Make sure the field type is consistent with our expectation.
Raise wire.DataError if encountering a problem, so clients are notified.
Raise DataError if encountering a problem, so clients are notified.
"""
EDT = EthereumDataType # local_cache_global
data_type = field.data_type
# entry_type is only for arrays
if data_type == EthereumDataType.ARRAY:
if data_type == EDT.ARRAY:
if field.entry_type is None:
raise wire.DataError("Missing entry_type in array")
raise DataError("Missing entry_type in array")
# We also need to validate it recursively
validate_field_type(field.entry_type)
else:
if field.entry_type is not None:
raise wire.DataError("Unexpected entry_type in nonarray")
raise DataError("Unexpected entry_type in nonarray")
# struct_name is only for structs
if data_type == EthereumDataType.STRUCT:
if data_type == EDT.STRUCT:
if field.struct_name is None:
raise wire.DataError("Missing struct_name in struct")
raise DataError("Missing struct_name in struct")
else:
if field.struct_name is not None:
raise wire.DataError("Unexpected struct_name in nonstruct")
raise DataError("Unexpected struct_name in nonstruct")
size = field.size # local_cache_attribute
# size is special for each type
if data_type == EthereumDataType.STRUCT:
if field.size is None:
raise wire.DataError("Missing size in struct")
elif data_type == EthereumDataType.BYTES:
if field.size is not None and not 1 <= field.size <= 32:
raise wire.DataError("Invalid size in bytes")
if data_type == EDT.STRUCT:
if size is None:
raise DataError("Missing size in struct")
elif data_type == EDT.BYTES:
if size is not None and not 1 <= size <= 32:
raise DataError("Invalid size in bytes")
elif data_type in (
EthereumDataType.UINT,
EthereumDataType.INT,
EDT.UINT,
EDT.INT,
):
if field.size is None or not 1 <= field.size <= 32:
raise wire.DataError("Invalid size in int/uint")
if size is None or not 1 <= size <= 32:
raise DataError("Invalid size in int/uint")
elif data_type in (
EthereumDataType.STRING,
EthereumDataType.BOOL,
EthereumDataType.ADDRESS,
EDT.STRING,
EDT.BOOL,
EDT.ADDRESS,
):
if field.size is not None:
raise wire.DataError("Unexpected size in str/bool/addr")
if size is not None:
raise DataError("Unexpected size in str/bool/addr")
async def get_array_size(ctx: Context, member_path: list[int]) -> int:
async def _get_array_size(ctx: Context, member_path: list[int]) -> int:
"""Get the length of an array at specific `member_path` from the client."""
from trezor.messages import EthereumFieldType
# Field type for getting the array length from client, so we can check the return value
ARRAY_LENGTH_TYPE = EthereumFieldType(data_type=EthereumDataType.UINT, size=2)
length_value = await get_value(ctx, ARRAY_LENGTH_TYPE, member_path)
@ -504,18 +516,20 @@ async def get_value(
member_value_path: list[int],
) -> bytes:
"""Get a single value from the client and perform its validation."""
from trezor.messages import EthereumTypedDataValueAck, EthereumTypedDataValueRequest
req = EthereumTypedDataValueRequest(
member_path=member_value_path,
)
res = await ctx.call(req, EthereumTypedDataValueAck)
value = res.value
validate_value(field=field, value=value)
_validate_value(field=field, value=value)
return value
async def get_name_and_version_for_domain(
async def _get_name_and_version_for_domain(
ctx: Context, typed_data_envelope: TypedDataEnvelope
) -> tuple[bytes, bytes]:
domain_name = b"unknown"

File diff suppressed because it is too large Load Diff

@ -2,6 +2,20 @@
# (by running `make templates` in `core`)
# do not edit manually!
# fmt: off
# NOTE: returning a tuple instead of `TokenInfo` from the "data" function
# saves 5600 bytes of flash size. Implementing the `_token_iterator`
# instead of if-tree approach saves another 5600 bytes.
# NOTE: interestingly, it did not save much flash size to use smaller
# parts of the address, for example address length of 10 bytes saves
# 1 byte per entry, so 1887 bytes overall (and further decrease does not help).
# (The idea was not having to store the whole address, even a smaller part
# of it has enough collision-resistance.)
# (In the if-tree approach the address length did not have any effect whatsoever.)
from typing import Iterator
<%
from collections import defaultdict
@ -22,11 +36,20 @@ UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
for addr, symbol, decimal in _token_iterator(chain_id):
if address == addr:
return TokenInfo(symbol, decimal)
return UNKNOWN_TOKEN
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]:
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
if chain_id == ${token_chain_id}:
% for t in tokens:
if address == ${black_repr(t.address_bytes)}:
return TokenInfo(${black_repr(t.symbol)}, ${t.decimals}) # ${t.chain} / ${t.name.strip()}
yield ( # address, symbol, decimals
${black_repr(t.address_bytes)},
${black_repr(t.symbol)},
${t.decimals},
)
% endfor
% endfor
return UNKNOWN_TOKEN

@ -1,42 +1,42 @@
from typing import TYPE_CHECKING
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import Success
from trezor.ui.layouts import confirm_signverify, show_success
from apps.common.signverify import decode_message
from .helpers import address_from_bytes, bytes_from_address
from .sign_message import message_digest
if TYPE_CHECKING:
from trezor.messages import EthereumVerifyMessage
from trezor.messages import EthereumVerifyMessage, Success
from trezor.wire import Context
async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
from trezor.wire import DataError
from trezor.crypto.curve import secp256k1
from trezor.crypto.hashlib import sha3_256
from trezor.messages import Success
from trezor.ui.layouts import confirm_signverify, show_success
from apps.common.signverify import decode_message
from .helpers import address_from_bytes, bytes_from_address
from .sign_message import message_digest
digest = message_digest(msg.message)
if len(msg.signature) != 65:
raise wire.DataError("Invalid signature")
raise DataError("Invalid signature")
sig = bytearray([msg.signature[64]]) + msg.signature[:64]
pubkey = secp256k1.verify_recover(sig, digest)
if not pubkey:
raise wire.DataError("Invalid signature")
raise DataError("Invalid signature")
pkh = sha3_256(pubkey[1:], keccak=True).digest()[-20:]
address_bytes = bytes_from_address(msg.address)
if address_bytes != pkh:
raise wire.DataError("Invalid signature")
raise DataError("Invalid signature")
address = address_from_bytes(address_bytes)
await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address=address, verify=True
ctx, "ETH", decode_message(msg.message), address, verify=True
)
await show_success(ctx, "verify_message", "The signature is valid.")

@ -11,7 +11,7 @@ from trezor.enums import EthereumDataType as EDT
if not utils.BITCOIN_ONLY:
from apps.ethereum.sign_typed_data import (
encode_field,
validate_value,
_validate_value,
validate_field_type,
keccak256,
TypedDataEnvelope,
@ -594,10 +594,10 @@ class TestEthereumSignTypedData(unittest.TestCase):
for field, valid_values, invalid_values in VECTORS_VALID_INVALID:
for valid_value in valid_values:
validate_value(field=field, value=valid_value)
_validate_value(field=field, value=valid_value)
for invalid_value in invalid_values:
with self.assertRaises(wire.DataError):
validate_value(field=field, value=invalid_value)
_validate_value(field=field, value=invalid_value)
def test_validate_field_type(self):
ET = EFT(data_type=EDT.BYTES, size=8)

Loading…
Cancel
Save