1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-29 09:58:47 +00:00

feat(core): added Ethereum definitions object - handle definitions from host

This commit is contained in:
Martin Novak 2022-06-23 11:56:11 +02:00
parent 3cf9b7c235
commit 824abe7d2f
15 changed files with 355 additions and 84 deletions

View File

@ -491,6 +491,8 @@ if not utils.BITCOIN_ONLY:
import apps.eos.writers
apps.ethereum
import apps.ethereum
apps.ethereum.definitions
import apps.ethereum.definitions
apps.ethereum.get_address
import apps.ethereum.get_address
apps.ethereum.get_public_key

View File

@ -0,0 +1,192 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple
from apps.ethereum import tokens
from trezor import protobuf, wire
from trezor.crypto.curve import ed25519
from trezor.enums import EthereumDefinitionType
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
from . import networks
if TYPE_CHECKING:
from trezor.protobuf import MessageType
from .networks import NetworkInfo
from .tokens import TokenInfo
DEFINITIONS_PUBLIC_KEY = b""
MIN_DATA_VERSION = 1
FORMAT_VERSION = "trzd1"
if __debug__:
DEFINITIONS_DEV_PUBLIC_KEY = b""
class EthereumDefinitionParser:
def __init__(self, definition_bytes: bytes) -> None:
if len(definition_bytes) <= (8 + 1 + 4 + 64):
raise wire.DataError("Received Ethereum definition is probably malformed (too few data).")
self.format_version: str = definition_bytes[:8].rstrip(b'\0').decode("utf-8")
self.definition_type: int = definition_bytes[8]
self.data_version: int = int.from_bytes(definition_bytes[9:13], 'big')
self.clean_payload = definition_bytes[13:-64]
self.payload = definition_bytes[:-64]
self.signature = definition_bytes[-64:]
def decode_definition(
definition: bytes, expected_type: EthereumDefinitionType
) -> NetworkInfo | TokenInfo:
# check network definition
parsed_definition = EthereumDefinitionParser(definition)
# first check format version
if parsed_definition.format_version != FORMAT_VERSION:
raise wire.DataError("Used different Ethereum definition format version.")
# second check the type of the data
if parsed_definition.definition_type != expected_type:
raise wire.DataError("Definition of invalid type for Ethereum.")
# third check data version
if parsed_definition.data_version < MIN_DATA_VERSION:
raise wire.DataError("Used Ethereum definition data version too low.")
# at the end verify the signature
if not ed25519.verify(DEFINITIONS_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload):
error_msg = wire.DataError("Ethereum definition signature is invalid.")
if __debug__:
# check against dev key
if not ed25519.verify(DEFINITIONS_DEV_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload):
raise error_msg
else:
raise error_msg
# decode it if it's OK
if expected_type == EthereumDefinitionType.NETWORK:
info = protobuf.decode(parsed_definition.payload, EthereumNetworkInfo, True)
# TODO: temporarily convert to internal class
if info is not None:
from .networks import NetworkInfo
info = NetworkInfo(
chain_id=info.chain_id,
slip44=info.slip44,
shortcut=info.shortcut,
name=info.name,
rskip60=info.rskip60
)
else:
info = protobuf.decode(parsed_definition.payload, EthereumTokenInfo, True)
# TODO: temporarily convert to internal class
if info is not None:
from .tokens import TokenInfo
info = TokenInfo(
symbol=info.symbol,
decimals=info.decimals,
address=info.address,
chain_id=info.chain_id,
name=info.name,
)
return info
def _get_network_definiton(encoded_network_definition: bytes | None, ref_chain_id: int | None = None) -> NetworkInfo | None:
if encoded_network_definition is None and ref_chain_id is None:
return None
if ref_chain_id is not None:
# if we have a built-in definition, use it
network = networks.by_chain_id(ref_chain_id)
if network is not None:
return network
if encoded_network_definition is not None:
# get definition if it was send
network = decode_definition(encoded_network_definition, EthereumDefinitionType.NETWORK)
# check referential chain_id with encoded chain_id
if network.chain_id != ref_chain_id:
raise wire.DataError("Invalid network definition - chain IDs not equal.")
return network
return None
def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: int | None = None, ref_address: int | None = None) -> TokenInfo:
if encoded_token_definition is None and (ref_chain_id is None or ref_address is None):
return None
# if we have a built-in definition, use it
if ref_chain_id is not None and ref_address is not None:
token = tokens.token_by_chain_address(ref_chain_id, ref_address)
if token is not tokens.UNKNOWN_TOKEN:
return token
if encoded_token_definition is not None:
# get definition if it was send
token = decode_definition(encoded_token_definition, EthereumDefinitionType.TOKEN)
# check token against ref_chain_id and ref_address
if (
(ref_chain_id is None or token.chain_id == ref_chain_id)
and (ref_address is None or token.address == ref_address)
):
return token
return tokens.UNKNOWN_TOKEN
class EthereumDefinitions:
"""Class that holds Ethereum definitions - network and tokens. Prefers built-in definitions over encoded ones."""
def __init__(
self,
encoded_network_definition: bytes | None = None,
encoded_token_definition: bytes | None = None,
ref_chain_id: int | None = None,
ref_token_address: int | None = None,
) -> None:
self.network = _get_network_definiton(encoded_network_definition, ref_chain_id)
self.token_dict: defaultdict[bytes, TokenInfo] = defaultdict(lambda: tokens.UNKNOWN_TOKEN)
# if we have some network, we can try to get token
if self.network is not None:
received_token = _get_token_definiton(encoded_token_definition, self.network.chain_id, ref_token_address)
if received_token is not tokens.UNKNOWN_TOKEN:
self.token_dict[received_token.address] = received_token
def get_definitions_from_msg(msg: MessageType) -> EthereumDefinitions:
encoded_network_definition: bytes | None = None
encoded_token_definition: bytes | None = None
chain_id: int | None = None
token_address: int | None = None
# first try to get both definitions
try:
if msg.definitions is not None:
encoded_network_definition = msg.definitions.encoded_network
encoded_token_definition = msg.definitions.encoded_token
except AttributeError:
encoded_network_definition = msg.encoded_network
# get chain_id
try:
chain_id = msg.chain_id
except AttributeError:
pass
# get token_address
try:
token_address = msg.to
except AttributeError:
pass
return EthereumDefinitions(encoded_network_definition, encoded_token_definition, chain_id, token_address)

View File

@ -1,6 +1,13 @@
from typing import TYPE_CHECKING
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
if TYPE_CHECKING:
from trezor.messages import EthereumGetAddress, EthereumAddress
@ -8,10 +15,12 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain
from . import definitions
@with_keychain_from_path(*PATTERNS_ADDRESS)
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
ctx: Context, msg: EthereumGetAddress, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumAddress:
from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
@ -25,8 +34,12 @@ async def get_address(
node = keychain.derive(address_n)
if len(address_n) > 1: # path has slip44 network identifier
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
if len(msg.address_n) > 1: # path has slip44 network identifier
slip44 = msg.address_n[1] & 0x7FFF_FFFF
if slip44 == defs.network.slip44:
network = defs.network
else:
network = networks.by_slip44(slip44)
else:
network = None
address = address_from_bytes(node.ethereum_pubkeyhash(), network)

View File

@ -1,12 +1,14 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from .networks import by_chain_id
if TYPE_CHECKING:
from trezor.messages import EthereumFieldType
from .networks import NetworkInfo
def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str:
def address_from_bytes(address_bytes: bytes, network: NetworkInfo = by_chain_id(1)) -> str:
"""
Converts address in bytes to a checksummed string as defined
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md

View File

@ -3,10 +3,12 @@ from typing import TYPE_CHECKING
from apps.common import paths
from apps.common.keychain import get_keychain
from . import CURVE, networks
from . import CURVE, networks, definitions
if TYPE_CHECKING:
from typing import Callable, Iterable, TypeVar
from typing import Awaitable, Callable, Iterable, TypeVar
from apps.common.keychain import Keychain
from trezor.wire import Context
@ -19,19 +21,17 @@ if TYPE_CHECKING:
EthereumSignTypedData,
)
from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain
from apps.common.keychain import MsgIn as MsgInGeneric, MsgOut, Handler, HandlerWithKeychain
EthereumMessages = (
EthereumGetAddress
| EthereumGetPublicKey
| EthereumSignTx
| EthereumSignMessage
| EthereumSignTypedData
)
MsgIn = TypeVar("MsgIn", bound=EthereumMessages)
# messages for "with_keychain_from_path" decorator
MsgInKeychainPath = TypeVar("MsgInKeychainPath", bound=EthereumGetPublicKey)
# messages for "with_keychain_from_path_and_defs" decorator
MsgInKeychainPathDefs = TypeVar("MsgInKeychainPathDefs", bound=EthereumGetAddress | EthereumSignMessage | EthereumSignTypedData)
# messages for "with_keychain_from_chain_id_and_defs" decorator
MsgInKeychainChainIdDefs = TypeVar("MsgInKeychainChainIdDefs", bound=EthereumSignTx | EthereumSignTxEIP1559)
EthereumSignTxAny = EthereumSignTx | EthereumSignTxEIP1559
MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny)
# TODO: check the types of messages
HandlerWithKeychainAndDefinitions = Callable[[Context, MsgInGeneric, Keychain, definitions.EthereumDefinitions], Awaitable[MsgOut]]
# We believe Ethereum should use 44'/60'/a' for everything, because it is
@ -48,13 +48,20 @@ PATTERNS_ADDRESS = (
def _schemas_from_address_n(
patterns: Iterable[str], address_n: paths.Bip32Path
patterns: Iterable[str], address_n: paths.Bip32Path, network_info: networks.NetworkInfo | None
) -> Iterable[paths.PathSchema]:
if len(address_n) < 2:
return ()
slip44_hardened = address_n[1]
if slip44_hardened not in networks.all_slip44_ids_hardened():
def _get_hardened_slip44_networks():
if network_info is not None:
yield network_info.slip44 | paths.HARDENED
yield from networks.all_slip44_ids_hardened()
# check with network from definitions and if that is None then with built-in ones
if slip44_hardened not in _get_hardened_slip44_networks():
return ()
if not slip44_hardened & paths.HARDENED:
@ -67,10 +74,11 @@ def _schemas_from_address_n(
def with_keychain_from_path(
*patterns: str,
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
schemas = _schemas_from_address_n(patterns, msg.address_n)
) -> Callable[[HandlerWithKeychain[MsgInKeychainPath, MsgOut]], Handler[MsgInKeychainPath, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgInKeychainPath, MsgOut]) -> Handler[MsgInKeychainPath, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInKeychainPath) -> MsgOut:
defs = definitions.get_definitions_from_msg(msg)
schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain)
@ -80,17 +88,32 @@ def with_keychain_from_path(
return decorator
def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]:
info = networks.by_chain_id(msg.chain_id)
def with_keychain_from_path_and_defs(
*patterns: str,
) -> Callable[[HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]], Handler[MsgInKeychainPathDefs, MsgOut]]:
def decorator(func: HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]) -> Handler[MsgInKeychainPathDefs, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInKeychainPathDefs) -> MsgOut:
defs = definitions.get_definitions_from_msg(msg)
schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain, defs)
return wrapper
return decorator
def _schemas_from_chain_id(network_info: networks.NetworkInfo | None) -> Iterable[paths.PathSchema]:
slip44_id: tuple[int, ...]
if info is None:
if network_info is None:
# allow Ethereum or testnet paths for unknown networks
slip44_id = (60, 1)
elif info.slip44 not in (60, 1):
elif network_info.slip44 not in (60, 1):
# allow cross-signing with Ethereum unless it's testnet
slip44_id = (info.slip44, 60)
slip44_id = (network_info.slip44, 60)
else:
slip44_id = (info.slip44,)
slip44_id = (network_info.slip44,)
schemas = [
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
@ -98,14 +121,15 @@ def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]
return [s.copy() for s in schemas]
def with_keychain_from_chain_id(
func: HandlerWithKeychain[MsgInChainId, MsgOut]
) -> Handler[MsgInChainId, MsgOut]:
def with_keychain_from_chain_id_and_defs(
func: HandlerWithKeychainAndDefinitions[MsgInKeychainChainIdDefs, MsgOut]
) -> Handler[MsgInKeychainChainIdDefs, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
schemas = _schemas_from_chain_id(msg)
async def wrapper(ctx: Context, msg: MsgInKeychainChainIdDefs) -> MsgOut:
defs = definitions.get_definitions_from_msg(msg)
schemas = _schemas_from_chain_id(defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain)
return await func(ctx, msg, keychain, defs)
return wrapper

View File

@ -12,7 +12,7 @@ from trezor.ui.layouts import (
)
from . import networks
from .helpers import decode_typed_data
from .helpers import address_from_bytes, decode_typed_data, get_type_name
if TYPE_CHECKING:
from typing import Awaitable, Iterable
@ -21,25 +21,29 @@ if TYPE_CHECKING:
from trezor.wire import Context
from . import tokens
from . import tokens
def require_confirm_tx(
ctx: Context,
to_bytes: bytes,
value: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
network: networks.NetworkInfo,
token: tokens.TokenInfo,
) -> Awaitable[None]:
from .helpers import address_from_bytes
from trezor.ui.layouts import confirm_output
if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
to_str = address_from_bytes(to_bytes, network)
else:
to_str = "new contract?"
return confirm_output(
ctx,
to_str,
format_ethereum_amount(value, token, chain_id),
address=to_str,
amount=format_ethereum_amount(value, token, network),
font_amount=ui.BOLD,
color_to=ui.GREY,
br_code=ButtonRequestType.SignTx,
)
@ -49,19 +53,19 @@ async def require_confirm_fee(
spending: int,
gas_price: int,
gas_limit: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
network: networks.NetworkInfo,
token: tokens.TokenInfo,
) -> None:
await confirm_amount(
ctx,
title="Confirm fee",
description="Gas price:",
amount=format_ethereum_amount(gas_price, None, chain_id),
amount=format_ethereum_amount(gas_price, None, network),
)
await confirm_total(
ctx,
total_amount=format_ethereum_amount(spending, token, chain_id),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, chain_id),
total_amount=format_ethereum_amount(spending, token, network),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
total_label="Amount sent:",
fee_label="Maximum fee:",
)
@ -73,25 +77,25 @@ async def require_confirm_eip1559_fee(
max_priority_fee: int,
max_gas_fee: int,
gas_limit: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
network: networks.NetworkInfo,
token: tokens.TokenInfo,
) -> None:
await confirm_amount(
ctx,
"Confirm fee",
format_ethereum_amount(max_gas_fee, None, chain_id),
format_ethereum_amount(max_gas_fee, None, network),
"Maximum fee per gas",
)
await confirm_amount(
ctx,
"Confirm fee",
format_ethereum_amount(max_priority_fee, None, chain_id),
format_ethereum_amount(max_priority_fee, None, network),
"Priority fee per gas",
)
await confirm_total(
ctx,
format_ethereum_amount(spending, token, chain_id),
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
format_ethereum_amount(spending, token, network),
format_ethereum_amount(max_gas_fee * gas_limit, None, network),
total_label="Amount sent:",
fee_label="Maximum fee:",
)
@ -249,7 +253,7 @@ async def confirm_typed_value(
def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int
value: int, token: tokens.TokenInfo | None, network_info: networks.NetworkInfo | None
) -> str:
from trezor.strings import format_amount
@ -257,7 +261,10 @@ def format_ethereum_amount(
suffix = token.symbol
decimals = token.decimals
else:
suffix = networks.shortcut_by_chain_id(chain_id)
if network_info is not None:
suffix = network_info.shortcut
else:
suffix = networks.UNKNOWN_NETWORK_SHORTCUT
decimals = 18
# Don't want to display wei values for tokens with small decimal numbers

View File

@ -22,11 +22,12 @@ if TYPE_CHECKING:
bool # rskip60
]
# fmt: on
UNKNOWN_NETWORK_SHORTCUT = "UNKN"
def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id)
return n.shortcut if n is not None else "UNKN"
return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:

View File

@ -22,11 +22,12 @@ if TYPE_CHECKING:
bool # rskip60
]
# fmt: on
UNKNOWN_NETWORK_SHORTCUT = "UNKN"
def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id)
return n.shortcut if n is not None else "UNKN"
return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
if TYPE_CHECKING:
from trezor.messages import EthereumSignMessage, EthereumMessageSignature
@ -8,6 +8,8 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain
from . import definitions
def message_digest(message: bytes) -> bytes:
from trezor.crypto.hashlib import sha3_256
@ -21,9 +23,9 @@ def message_digest(message: bytes) -> bytes:
return h.get_digest()
@with_keychain_from_path(*PATTERNS_ADDRESS)
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
ctx: Context, msg: EthereumSignMessage, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumMessageSignature:
from trezor.crypto.curve import secp256k1
from trezor.messages import EthereumMessageSignature
@ -37,7 +39,7 @@ async def sign_message(
await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
address = address_from_bytes(node.ethereum_pubkeyhash())
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=False
)

View File

@ -5,15 +5,16 @@ from trezor.messages import EthereumTxRequest
from trezor.wire import DataError
from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id
from .keychain import with_keychain_from_chain_id_and_defs
if TYPE_CHECKING:
from collections import defaultdict
from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx, EthereumTxAck
from trezor.wire import Context
from .keychain import EthereumSignTxAny
from . import tokens
from . import tokens, definitions
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -22,9 +23,9 @@ if TYPE_CHECKING:
MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id
@with_keychain_from_chain_id_and_defs
async def sign_tx(
ctx: Context, msg: EthereumSignTx, keychain: Keychain
ctx: Context, msg: EthereumSignTx, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumTxRequest:
from trezor.utils import HashWriter
from trezor.crypto.hashlib import sha3_256
@ -45,11 +46,11 @@ async def sign_tx(
await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg)
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict)
data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token)
await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -58,7 +59,7 @@ async def sign_tx(
value,
int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"),
msg.chain_id,
defs.network,
token,
)
@ -100,7 +101,7 @@ async def sign_tx(
async def handle_erc20(
ctx: Context, msg: EthereumSignTxAny
ctx: Context, msg: EthereumSignTxAny, token_dict: defaultdict[bytes, tokens.TokenInfo]
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
from .layout import require_confirm_unknown_token
from . import tokens
@ -118,7 +119,7 @@ async def handle_erc20(
and data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
):
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
token = token_dict[address_bytes]
recipient = data_initial_chunk[16:36]
value = int.from_bytes(data_initial_chunk[36:68], "big")

View File

@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
from trezor.crypto import rlp
from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id
from .keychain import with_keychain_from_chain_id_and_defs
if TYPE_CHECKING:
from trezor.messages import (
@ -13,6 +13,7 @@ if TYPE_CHECKING:
EthereumTxRequest,
)
from . import definitions
from apps.common.keychain import Keychain
from trezor.wire import Context
@ -28,9 +29,9 @@ def access_list_item_length(item: EthereumAccessList) -> int:
)
@with_keychain_from_chain_id
@with_keychain_from_chain_id_and_defs
async def sign_tx_eip1559(
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumTxRequest:
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
@ -56,11 +57,11 @@ async def sign_tx_eip1559(
await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg)
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict)
data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token)
await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -70,7 +71,7 @@ async def sign_tx_eip1559(
int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"),
int.from_bytes(gas_limit, "big"),
msg.chain_id,
defs.network,
token,
)

View File

@ -5,7 +5,7 @@ from trezor.enums import EthereumDataType
from trezor.wire import DataError
from .helpers import get_type_name
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
from .layout import should_show_struct
if TYPE_CHECKING:
@ -20,14 +20,16 @@ if TYPE_CHECKING:
EthereumTypedDataStructAck,
)
from . import definitions
# Maximum data size we support
_MAX_VALUE_BYTE_SIZE = const(1024)
@with_keychain_from_path(*PATTERNS_ADDRESS)
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
async def sign_typed_data(
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumTypedDataSignature:
from trezor.crypto.curve import secp256k1
from apps.common import paths
@ -46,7 +48,7 @@ async def sign_typed_data(
)
return EthereumTypedDataSignature(
address=address_from_bytes(node.ethereum_pubkeyhash()),
address=address_from_bytes(node.ethereum_pubkeyhash(), defs.network),
signature=signature[1:] + signature[0:1],
)

View File

@ -18,9 +18,19 @@ from typing import Iterator
class TokenInfo:
def __init__(self, symbol: str, decimals: int) -> None:
def __init__(
self,
symbol: str,
decimals: int,
address: bytes = None,
chain_id: int = None,
name: str = None,
) -> None:
self.symbol = symbol
self.decimals = decimals
self.address = address
self.chain_id = chain_id
self.name = name
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)

View File

@ -27,9 +27,19 @@ def group_tokens(tokens):
%>\
class TokenInfo:
def __init__(self, symbol: str, decimals: int) -> None:
def __init__(
self,
symbol: str,
decimals: int,
address: bytes = None,
chain_id: int = None,
name: str = None,
) -> None:
self.symbol = symbol
self.decimals = decimals
self.address = address
self.chain_id = chain_id
self.name = name
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)

View File

@ -14,9 +14,12 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
from apps.common.signverify import decode_message
from . import definitions
from .helpers import address_from_bytes, bytes_from_address
from .sign_message import message_digest
defs = definitions.get_definitions_from_msg(msg)
digest = message_digest(msg.message)
if len(msg.signature) != 65:
raise DataError("Invalid signature")
@ -33,7 +36,7 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
if address_bytes != pkh:
raise DataError("Invalid signature")
address = address_from_bytes(address_bytes)
address = address_from_bytes(address_bytes, defs.network)
await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=True