diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 50a7aa3b7f..cc0615941a 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -491,6 +491,8 @@ if not utils.BITCOIN_ONLY: import apps.eos.writers apps.ethereum import apps.ethereum + apps.ethereum.definitions + import apps.ethereum.definitions apps.ethereum.get_address import apps.ethereum.get_address apps.ethereum.get_public_key diff --git a/core/src/apps/ethereum/definitions.py b/core/src/apps/ethereum/definitions.py new file mode 100644 index 0000000000..31b8d8541c --- /dev/null +++ b/core/src/apps/ethereum/definitions.py @@ -0,0 +1,192 @@ +from collections import defaultdict +from typing import TYPE_CHECKING, Tuple + +from apps.ethereum import tokens + +from trezor import protobuf, wire +from trezor.crypto.curve import ed25519 +from trezor.enums import EthereumDefinitionType +from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo + +from . import networks + +if TYPE_CHECKING: + from trezor.protobuf import MessageType + + from .networks import NetworkInfo + from .tokens import TokenInfo + + +DEFINITIONS_PUBLIC_KEY = b"" +MIN_DATA_VERSION = 1 +FORMAT_VERSION = "trzd1" + +if __debug__: + DEFINITIONS_DEV_PUBLIC_KEY = b"" + + +class EthereumDefinitionParser: + def __init__(self, definition_bytes: bytes) -> None: + if len(definition_bytes) <= (8 + 1 + 4 + 64): + raise wire.DataError("Received Ethereum definition is probably malformed (too few data).") + + self.format_version: str = definition_bytes[:8].rstrip(b'\0').decode("utf-8") + self.definition_type: int = definition_bytes[8] + self.data_version: int = int.from_bytes(definition_bytes[9:13], 'big') + self.clean_payload = definition_bytes[13:-64] + self.payload = definition_bytes[:-64] + self.signature = definition_bytes[-64:] + + +def decode_definition( + definition: bytes, expected_type: EthereumDefinitionType +) -> NetworkInfo | TokenInfo: + # check network definition + parsed_definition = EthereumDefinitionParser(definition) + + # first check format version + if parsed_definition.format_version != FORMAT_VERSION: + raise wire.DataError("Used different Ethereum definition format version.") + + # second check the type of the data + if parsed_definition.definition_type != expected_type: + raise wire.DataError("Definition of invalid type for Ethereum.") + + # third check data version + if parsed_definition.data_version < MIN_DATA_VERSION: + raise wire.DataError("Used Ethereum definition data version too low.") + + # at the end verify the signature + if not ed25519.verify(DEFINITIONS_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload): + error_msg = wire.DataError("Ethereum definition signature is invalid.") + if __debug__: + # check against dev key + if not ed25519.verify(DEFINITIONS_DEV_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload): + raise error_msg + else: + raise error_msg + + # decode it if it's OK + if expected_type == EthereumDefinitionType.NETWORK: + info = protobuf.decode(parsed_definition.payload, EthereumNetworkInfo, True) + + # TODO: temporarily convert to internal class + if info is not None: + from .networks import NetworkInfo + info = NetworkInfo( + chain_id=info.chain_id, + slip44=info.slip44, + shortcut=info.shortcut, + name=info.name, + rskip60=info.rskip60 + ) + else: + info = protobuf.decode(parsed_definition.payload, EthereumTokenInfo, True) + + # TODO: temporarily convert to internal class + if info is not None: + from .tokens import TokenInfo + info = TokenInfo( + symbol=info.symbol, + decimals=info.decimals, + address=info.address, + chain_id=info.chain_id, + name=info.name, + ) + + return info + + +def _get_network_definiton(encoded_network_definition: bytes | None, ref_chain_id: int | None = None) -> NetworkInfo | None: + if encoded_network_definition is None and ref_chain_id is None: + return None + + if ref_chain_id is not None: + # if we have a built-in definition, use it + network = networks.by_chain_id(ref_chain_id) + if network is not None: + return network + + if encoded_network_definition is not None: + # get definition if it was send + network = decode_definition(encoded_network_definition, EthereumDefinitionType.NETWORK) + + # check referential chain_id with encoded chain_id + if network.chain_id != ref_chain_id: + raise wire.DataError("Invalid network definition - chain IDs not equal.") + + return network + + return None + + +def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: int | None = None, ref_address: int | None = None) -> TokenInfo: + if encoded_token_definition is None and (ref_chain_id is None or ref_address is None): + return None + + # if we have a built-in definition, use it + if ref_chain_id is not None and ref_address is not None: + token = tokens.token_by_chain_address(ref_chain_id, ref_address) + if token is not tokens.UNKNOWN_TOKEN: + return token + + if encoded_token_definition is not None: + # get definition if it was send + token = decode_definition(encoded_token_definition, EthereumDefinitionType.TOKEN) + + # check token against ref_chain_id and ref_address + if ( + (ref_chain_id is None or token.chain_id == ref_chain_id) + and (ref_address is None or token.address == ref_address) + ): + return token + + return tokens.UNKNOWN_TOKEN + + +class EthereumDefinitions: + """Class that holds Ethereum definitions - network and tokens. Prefers built-in definitions over encoded ones.""" + def __init__( + self, + encoded_network_definition: bytes | None = None, + encoded_token_definition: bytes | None = None, + ref_chain_id: int | None = None, + ref_token_address: int | None = None, + ) -> None: + self.network = _get_network_definiton(encoded_network_definition, ref_chain_id) + self.token_dict: defaultdict[bytes, TokenInfo] = defaultdict(lambda: tokens.UNKNOWN_TOKEN) + + # if we have some network, we can try to get token + if self.network is not None: + received_token = _get_token_definiton(encoded_token_definition, self.network.chain_id, ref_token_address) + if received_token is not tokens.UNKNOWN_TOKEN: + self.token_dict[received_token.address] = received_token + + +def get_definitions_from_msg(msg: MessageType) -> EthereumDefinitions: + encoded_network_definition: bytes | None = None + encoded_token_definition: bytes | None = None + chain_id: int | None = None + token_address: int | None = None + + # first try to get both definitions + try: + if msg.definitions is not None: + encoded_network_definition = msg.definitions.encoded_network + encoded_token_definition = msg.definitions.encoded_token + except AttributeError: + encoded_network_definition = msg.encoded_network + + # get chain_id + try: + chain_id = msg.chain_id + except AttributeError: + pass + + # get token_address + try: + token_address = msg.to + except AttributeError: + pass + + return EthereumDefinitions(encoded_network_definition, encoded_token_definition, chain_id, token_address) diff --git a/core/src/apps/ethereum/get_address.py b/core/src/apps/ethereum/get_address.py index f33a5390a7..f031a77bd3 100644 --- a/core/src/apps/ethereum/get_address.py +++ b/core/src/apps/ethereum/get_address.py @@ -1,6 +1,13 @@ from typing import TYPE_CHECKING -from .keychain import PATTERNS_ADDRESS, with_keychain_from_path +from trezor.messages import EthereumAddress +from trezor.ui.layouts import show_address + +from apps.common import paths + +from . import networks +from .helpers import address_from_bytes +from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs if TYPE_CHECKING: from trezor.messages import EthereumGetAddress, EthereumAddress @@ -8,10 +15,12 @@ if TYPE_CHECKING: from apps.common.keychain import Keychain + from . import definitions -@with_keychain_from_path(*PATTERNS_ADDRESS) + +@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS) async def get_address( - ctx: Context, msg: EthereumGetAddress, keychain: Keychain + ctx: Context, msg: EthereumGetAddress, keychain: Keychain, defs: definitions.EthereumDefinitions ) -> EthereumAddress: from trezor.messages import EthereumAddress from trezor.ui.layouts import show_address @@ -25,8 +34,12 @@ async def get_address( node = keychain.derive(address_n) - if len(address_n) > 1: # path has slip44 network identifier - network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF) + if len(msg.address_n) > 1: # path has slip44 network identifier + slip44 = msg.address_n[1] & 0x7FFF_FFFF + if slip44 == defs.network.slip44: + network = defs.network + else: + network = networks.by_slip44(slip44) else: network = None address = address_from_bytes(node.ethereum_pubkeyhash(), network) diff --git a/core/src/apps/ethereum/helpers.py b/core/src/apps/ethereum/helpers.py index 7560522824..5f931f4593 100644 --- a/core/src/apps/ethereum/helpers.py +++ b/core/src/apps/ethereum/helpers.py @@ -1,12 +1,14 @@ from typing import TYPE_CHECKING from ubinascii import hexlify +from .networks import by_chain_id + if TYPE_CHECKING: from trezor.messages import EthereumFieldType from .networks import NetworkInfo -def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str: +def address_from_bytes(address_bytes: bytes, network: NetworkInfo = by_chain_id(1)) -> str: """ Converts address in bytes to a checksummed string as defined in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md diff --git a/core/src/apps/ethereum/keychain.py b/core/src/apps/ethereum/keychain.py index 8243b2942d..81a4748c79 100644 --- a/core/src/apps/ethereum/keychain.py +++ b/core/src/apps/ethereum/keychain.py @@ -3,10 +3,12 @@ from typing import TYPE_CHECKING from apps.common import paths from apps.common.keychain import get_keychain -from . import CURVE, networks +from . import CURVE, networks, definitions if TYPE_CHECKING: - from typing import Callable, Iterable, TypeVar + from typing import Awaitable, Callable, Iterable, TypeVar + + from apps.common.keychain import Keychain from trezor.wire import Context @@ -19,19 +21,17 @@ if TYPE_CHECKING: EthereumSignTypedData, ) - from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain + from apps.common.keychain import MsgIn as MsgInGeneric, MsgOut, Handler, HandlerWithKeychain - EthereumMessages = ( - EthereumGetAddress - | EthereumGetPublicKey - | EthereumSignTx - | EthereumSignMessage - | EthereumSignTypedData - ) - MsgIn = TypeVar("MsgIn", bound=EthereumMessages) + # messages for "with_keychain_from_path" decorator + MsgInKeychainPath = TypeVar("MsgInKeychainPath", bound=EthereumGetPublicKey) + # messages for "with_keychain_from_path_and_defs" decorator + MsgInKeychainPathDefs = TypeVar("MsgInKeychainPathDefs", bound=EthereumGetAddress | EthereumSignMessage | EthereumSignTypedData) + # messages for "with_keychain_from_chain_id_and_defs" decorator + MsgInKeychainChainIdDefs = TypeVar("MsgInKeychainChainIdDefs", bound=EthereumSignTx | EthereumSignTxEIP1559) - EthereumSignTxAny = EthereumSignTx | EthereumSignTxEIP1559 - MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny) + # TODO: check the types of messages + HandlerWithKeychainAndDefinitions = Callable[[Context, MsgInGeneric, Keychain, definitions.EthereumDefinitions], Awaitable[MsgOut]] # We believe Ethereum should use 44'/60'/a' for everything, because it is @@ -48,13 +48,20 @@ PATTERNS_ADDRESS = ( def _schemas_from_address_n( - patterns: Iterable[str], address_n: paths.Bip32Path + patterns: Iterable[str], address_n: paths.Bip32Path, network_info: networks.NetworkInfo | None ) -> Iterable[paths.PathSchema]: if len(address_n) < 2: return () slip44_hardened = address_n[1] - if slip44_hardened not in networks.all_slip44_ids_hardened(): + + def _get_hardened_slip44_networks(): + if network_info is not None: + yield network_info.slip44 | paths.HARDENED + yield from networks.all_slip44_ids_hardened() + + # check with network from definitions and if that is None then with built-in ones + if slip44_hardened not in _get_hardened_slip44_networks(): return () if not slip44_hardened & paths.HARDENED: @@ -67,10 +74,11 @@ def _schemas_from_address_n( def with_keychain_from_path( *patterns: str, -) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: - def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: - async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut: - schemas = _schemas_from_address_n(patterns, msg.address_n) +) -> Callable[[HandlerWithKeychain[MsgInKeychainPath, MsgOut]], Handler[MsgInKeychainPath, MsgOut]]: + def decorator(func: HandlerWithKeychain[MsgInKeychainPath, MsgOut]) -> Handler[MsgInKeychainPath, MsgOut]: + async def wrapper(ctx: Context, msg: MsgInKeychainPath) -> MsgOut: + defs = definitions.get_definitions_from_msg(msg) + schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network) keychain = await get_keychain(ctx, CURVE, schemas) with keychain: return await func(ctx, msg, keychain) @@ -80,17 +88,32 @@ def with_keychain_from_path( return decorator -def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]: - info = networks.by_chain_id(msg.chain_id) +def with_keychain_from_path_and_defs( + *patterns: str, +) -> Callable[[HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]], Handler[MsgInKeychainPathDefs, MsgOut]]: + def decorator(func: HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]) -> Handler[MsgInKeychainPathDefs, MsgOut]: + async def wrapper(ctx: Context, msg: MsgInKeychainPathDefs) -> MsgOut: + defs = definitions.get_definitions_from_msg(msg) + schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network) + keychain = await get_keychain(ctx, CURVE, schemas) + with keychain: + return await func(ctx, msg, keychain, defs) + + return wrapper + + return decorator + + +def _schemas_from_chain_id(network_info: networks.NetworkInfo | None) -> Iterable[paths.PathSchema]: slip44_id: tuple[int, ...] - if info is None: + if network_info is None: # allow Ethereum or testnet paths for unknown networks slip44_id = (60, 1) - elif info.slip44 not in (60, 1): + elif network_info.slip44 not in (60, 1): # allow cross-signing with Ethereum unless it's testnet - slip44_id = (info.slip44, 60) + slip44_id = (network_info.slip44, 60) else: - slip44_id = (info.slip44,) + slip44_id = (network_info.slip44,) schemas = [ paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS @@ -98,14 +121,15 @@ def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema] return [s.copy() for s in schemas] -def with_keychain_from_chain_id( - func: HandlerWithKeychain[MsgInChainId, MsgOut] -) -> Handler[MsgInChainId, MsgOut]: +def with_keychain_from_chain_id_and_defs( + func: HandlerWithKeychainAndDefinitions[MsgInKeychainChainIdDefs, MsgOut] +) -> Handler[MsgInKeychainChainIdDefs, MsgOut]: # this is only for SignTx, and only PATTERN_ADDRESS is allowed - async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut: - schemas = _schemas_from_chain_id(msg) + async def wrapper(ctx: Context, msg: MsgInKeychainChainIdDefs) -> MsgOut: + defs = definitions.get_definitions_from_msg(msg) + schemas = _schemas_from_chain_id(defs.network) keychain = await get_keychain(ctx, CURVE, schemas) with keychain: - return await func(ctx, msg, keychain) + return await func(ctx, msg, keychain, defs) return wrapper diff --git a/core/src/apps/ethereum/layout.py b/core/src/apps/ethereum/layout.py index 72513517f9..825a818614 100644 --- a/core/src/apps/ethereum/layout.py +++ b/core/src/apps/ethereum/layout.py @@ -12,7 +12,7 @@ from trezor.ui.layouts import ( ) from . import networks -from .helpers import decode_typed_data +from .helpers import address_from_bytes, decode_typed_data, get_type_name if TYPE_CHECKING: from typing import Awaitable, Iterable @@ -21,25 +21,29 @@ if TYPE_CHECKING: from trezor.wire import Context from . import tokens + from . import tokens + def require_confirm_tx( ctx: Context, to_bytes: bytes, value: int, - chain_id: int, - token: tokens.TokenInfo | None = None, + network: networks.NetworkInfo, + token: tokens.TokenInfo, ) -> Awaitable[None]: from .helpers import address_from_bytes from trezor.ui.layouts import confirm_output if to_bytes: - to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id)) + to_str = address_from_bytes(to_bytes, network) else: to_str = "new contract?" return confirm_output( ctx, - to_str, - format_ethereum_amount(value, token, chain_id), + address=to_str, + amount=format_ethereum_amount(value, token, network), + font_amount=ui.BOLD, + color_to=ui.GREY, br_code=ButtonRequestType.SignTx, ) @@ -49,19 +53,19 @@ async def require_confirm_fee( spending: int, gas_price: int, gas_limit: int, - chain_id: int, - token: tokens.TokenInfo | None = None, + network: networks.NetworkInfo, + token: tokens.TokenInfo, ) -> None: await confirm_amount( ctx, title="Confirm fee", description="Gas price:", - amount=format_ethereum_amount(gas_price, None, chain_id), + amount=format_ethereum_amount(gas_price, None, network), ) await confirm_total( ctx, - total_amount=format_ethereum_amount(spending, token, chain_id), - fee_amount=format_ethereum_amount(gas_price * gas_limit, None, chain_id), + total_amount=format_ethereum_amount(spending, token, network), + fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network), total_label="Amount sent:", fee_label="Maximum fee:", ) @@ -73,25 +77,25 @@ async def require_confirm_eip1559_fee( max_priority_fee: int, max_gas_fee: int, gas_limit: int, - chain_id: int, - token: tokens.TokenInfo | None = None, + network: networks.NetworkInfo, + token: tokens.TokenInfo, ) -> None: await confirm_amount( ctx, "Confirm fee", - format_ethereum_amount(max_gas_fee, None, chain_id), + format_ethereum_amount(max_gas_fee, None, network), "Maximum fee per gas", ) await confirm_amount( ctx, "Confirm fee", - format_ethereum_amount(max_priority_fee, None, chain_id), + format_ethereum_amount(max_priority_fee, None, network), "Priority fee per gas", ) await confirm_total( ctx, - format_ethereum_amount(spending, token, chain_id), - format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id), + format_ethereum_amount(spending, token, network), + format_ethereum_amount(max_gas_fee * gas_limit, None, network), total_label="Amount sent:", fee_label="Maximum fee:", ) @@ -249,7 +253,7 @@ async def confirm_typed_value( def format_ethereum_amount( - value: int, token: tokens.TokenInfo | None, chain_id: int + value: int, token: tokens.TokenInfo | None, network_info: networks.NetworkInfo | None ) -> str: from trezor.strings import format_amount @@ -257,7 +261,10 @@ def format_ethereum_amount( suffix = token.symbol decimals = token.decimals else: - suffix = networks.shortcut_by_chain_id(chain_id) + if network_info is not None: + suffix = network_info.shortcut + else: + suffix = networks.UNKNOWN_NETWORK_SHORTCUT decimals = 18 # Don't want to display wei values for tokens with small decimal numbers diff --git a/core/src/apps/ethereum/networks.py b/core/src/apps/ethereum/networks.py index 26730b04b5..5ea30cea84 100644 --- a/core/src/apps/ethereum/networks.py +++ b/core/src/apps/ethereum/networks.py @@ -22,11 +22,12 @@ if TYPE_CHECKING: bool # rskip60 ] # fmt: on +UNKNOWN_NETWORK_SHORTCUT = "UNKN" def shortcut_by_chain_id(chain_id: int) -> str: n = by_chain_id(chain_id) - return n.shortcut if n is not None else "UNKN" + return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT def by_chain_id(chain_id: int) -> "NetworkInfo" | None: diff --git a/core/src/apps/ethereum/networks.py.mako b/core/src/apps/ethereum/networks.py.mako index 07e1351208..f2ed36682c 100644 --- a/core/src/apps/ethereum/networks.py.mako +++ b/core/src/apps/ethereum/networks.py.mako @@ -22,11 +22,12 @@ if TYPE_CHECKING: bool # rskip60 ] # fmt: on +UNKNOWN_NETWORK_SHORTCUT = "UNKN" def shortcut_by_chain_id(chain_id: int) -> str: n = by_chain_id(chain_id) - return n.shortcut if n is not None else "UNKN" + return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT def by_chain_id(chain_id: int) -> "NetworkInfo" | None: diff --git a/core/src/apps/ethereum/sign_message.py b/core/src/apps/ethereum/sign_message.py index 808cc98923..a14d39d6a6 100644 --- a/core/src/apps/ethereum/sign_message.py +++ b/core/src/apps/ethereum/sign_message.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from .keychain import PATTERNS_ADDRESS, with_keychain_from_path +from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs if TYPE_CHECKING: from trezor.messages import EthereumSignMessage, EthereumMessageSignature @@ -8,6 +8,8 @@ if TYPE_CHECKING: from apps.common.keychain import Keychain + from . import definitions + def message_digest(message: bytes) -> bytes: from trezor.crypto.hashlib import sha3_256 @@ -21,9 +23,9 @@ def message_digest(message: bytes) -> bytes: return h.get_digest() -@with_keychain_from_path(*PATTERNS_ADDRESS) +@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS) async def sign_message( - ctx: Context, msg: EthereumSignMessage, keychain: Keychain + ctx: Context, msg: EthereumSignMessage, keychain: Keychain, defs: definitions.EthereumDefinitions ) -> EthereumMessageSignature: from trezor.crypto.curve import secp256k1 from trezor.messages import EthereumMessageSignature @@ -37,7 +39,7 @@ async def sign_message( await paths.validate_path(ctx, keychain, msg.address_n) node = keychain.derive(msg.address_n) - address = address_from_bytes(node.ethereum_pubkeyhash()) + address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network) await confirm_signverify( ctx, "ETH", decode_message(msg.message), address, verify=False ) diff --git a/core/src/apps/ethereum/sign_tx.py b/core/src/apps/ethereum/sign_tx.py index a6ccc68974..c75284bed1 100644 --- a/core/src/apps/ethereum/sign_tx.py +++ b/core/src/apps/ethereum/sign_tx.py @@ -5,15 +5,16 @@ from trezor.messages import EthereumTxRequest from trezor.wire import DataError from .helpers import bytes_from_address -from .keychain import with_keychain_from_chain_id +from .keychain import with_keychain_from_chain_id_and_defs if TYPE_CHECKING: + from collections import defaultdict from apps.common.keychain import Keychain from trezor.messages import EthereumSignTx, EthereumTxAck from trezor.wire import Context from .keychain import EthereumSignTxAny - from . import tokens + from . import tokens, definitions # Maximum chain_id which returns the full signature_v (which must fit into an uint32). @@ -22,9 +23,9 @@ if TYPE_CHECKING: MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2 -@with_keychain_from_chain_id +@with_keychain_from_chain_id_and_defs async def sign_tx( - ctx: Context, msg: EthereumSignTx, keychain: Keychain + ctx: Context, msg: EthereumSignTx, keychain: Keychain, defs: definitions.EthereumDefinitions ) -> EthereumTxRequest: from trezor.utils import HashWriter from trezor.crypto.hashlib import sha3_256 @@ -45,11 +46,11 @@ async def sign_tx( await paths.validate_path(ctx, keychain, msg.address_n) # Handle ERC20s - token, address_bytes, recipient, value = await handle_erc20(ctx, msg) + token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict) data_total = msg.data_length - await require_confirm_tx(ctx, recipient, value, msg.chain_id, token) + await require_confirm_tx(ctx, recipient, value, defs.network, token) if token is None and msg.data_length > 0: await require_confirm_data(ctx, msg.data_initial_chunk, data_total) @@ -58,7 +59,7 @@ async def sign_tx( value, int.from_bytes(msg.gas_price, "big"), int.from_bytes(msg.gas_limit, "big"), - msg.chain_id, + defs.network, token, ) @@ -100,7 +101,7 @@ async def sign_tx( async def handle_erc20( - ctx: Context, msg: EthereumSignTxAny + ctx: Context, msg: EthereumSignTxAny, token_dict: defaultdict[bytes, tokens.TokenInfo] ) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]: from .layout import require_confirm_unknown_token from . import tokens @@ -118,7 +119,7 @@ async def handle_erc20( and data_initial_chunk[:16] == b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" ): - token = tokens.token_by_chain_address(msg.chain_id, address_bytes) + token = token_dict[address_bytes] recipient = data_initial_chunk[16:36] value = int.from_bytes(data_initial_chunk[36:68], "big") diff --git a/core/src/apps/ethereum/sign_tx_eip1559.py b/core/src/apps/ethereum/sign_tx_eip1559.py index 098ebc3521..961501b998 100644 --- a/core/src/apps/ethereum/sign_tx_eip1559.py +++ b/core/src/apps/ethereum/sign_tx_eip1559.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING from trezor.crypto import rlp from .helpers import bytes_from_address -from .keychain import with_keychain_from_chain_id +from .keychain import with_keychain_from_chain_id_and_defs if TYPE_CHECKING: from trezor.messages import ( @@ -13,6 +13,7 @@ if TYPE_CHECKING: EthereumTxRequest, ) + from . import definitions from apps.common.keychain import Keychain from trezor.wire import Context @@ -28,9 +29,9 @@ def access_list_item_length(item: EthereumAccessList) -> int: ) -@with_keychain_from_chain_id +@with_keychain_from_chain_id_and_defs async def sign_tx_eip1559( - ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain + ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain, defs: definitions.EthereumDefinitions ) -> EthereumTxRequest: from trezor.crypto.hashlib import sha3_256 from trezor.utils import HashWriter @@ -56,11 +57,11 @@ async def sign_tx_eip1559( await paths.validate_path(ctx, keychain, msg.address_n) # Handle ERC20s - token, address_bytes, recipient, value = await handle_erc20(ctx, msg) + token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict) data_total = msg.data_length - await require_confirm_tx(ctx, recipient, value, msg.chain_id, token) + await require_confirm_tx(ctx, recipient, value, defs.network, token) if token is None and msg.data_length > 0: await require_confirm_data(ctx, msg.data_initial_chunk, data_total) @@ -70,7 +71,7 @@ async def sign_tx_eip1559( int.from_bytes(msg.max_priority_fee, "big"), int.from_bytes(msg.max_gas_fee, "big"), int.from_bytes(gas_limit, "big"), - msg.chain_id, + defs.network, token, ) diff --git a/core/src/apps/ethereum/sign_typed_data.py b/core/src/apps/ethereum/sign_typed_data.py index d0426b6130..56478f0d06 100644 --- a/core/src/apps/ethereum/sign_typed_data.py +++ b/core/src/apps/ethereum/sign_typed_data.py @@ -5,7 +5,7 @@ from trezor.enums import EthereumDataType from trezor.wire import DataError from .helpers import get_type_name -from .keychain import PATTERNS_ADDRESS, with_keychain_from_path +from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs from .layout import should_show_struct if TYPE_CHECKING: @@ -20,14 +20,16 @@ if TYPE_CHECKING: EthereumTypedDataStructAck, ) + from . import definitions + # Maximum data size we support _MAX_VALUE_BYTE_SIZE = const(1024) -@with_keychain_from_path(*PATTERNS_ADDRESS) +@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS) async def sign_typed_data( - ctx: Context, msg: EthereumSignTypedData, keychain: Keychain + ctx: Context, msg: EthereumSignTypedData, keychain: Keychain, defs: definitions.EthereumDefinitions ) -> EthereumTypedDataSignature: from trezor.crypto.curve import secp256k1 from apps.common import paths @@ -46,7 +48,7 @@ async def sign_typed_data( ) return EthereumTypedDataSignature( - address=address_from_bytes(node.ethereum_pubkeyhash()), + address=address_from_bytes(node.ethereum_pubkeyhash(), defs.network), signature=signature[1:] + signature[0:1], ) diff --git a/core/src/apps/ethereum/tokens.py b/core/src/apps/ethereum/tokens.py index b0f66afb43..c7ff897ad6 100644 --- a/core/src/apps/ethereum/tokens.py +++ b/core/src/apps/ethereum/tokens.py @@ -18,9 +18,19 @@ from typing import Iterator class TokenInfo: - def __init__(self, symbol: str, decimals: int) -> None: + def __init__( + self, + symbol: str, + decimals: int, + address: bytes = None, + chain_id: int = None, + name: str = None, + ) -> None: self.symbol = symbol self.decimals = decimals + self.address = address + self.chain_id = chain_id + self.name = name UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0) diff --git a/core/src/apps/ethereum/tokens.py.mako b/core/src/apps/ethereum/tokens.py.mako index bf1455deba..a94b573f24 100644 --- a/core/src/apps/ethereum/tokens.py.mako +++ b/core/src/apps/ethereum/tokens.py.mako @@ -27,9 +27,19 @@ def group_tokens(tokens): %>\ class TokenInfo: - def __init__(self, symbol: str, decimals: int) -> None: + def __init__( + self, + symbol: str, + decimals: int, + address: bytes = None, + chain_id: int = None, + name: str = None, + ) -> None: self.symbol = symbol self.decimals = decimals + self.address = address + self.chain_id = chain_id + self.name = name UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0) diff --git a/core/src/apps/ethereum/verify_message.py b/core/src/apps/ethereum/verify_message.py index cc6bd552d1..7fc9158078 100644 --- a/core/src/apps/ethereum/verify_message.py +++ b/core/src/apps/ethereum/verify_message.py @@ -14,9 +14,12 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success: from apps.common.signverify import decode_message + from . import definitions from .helpers import address_from_bytes, bytes_from_address from .sign_message import message_digest + defs = definitions.get_definitions_from_msg(msg) + digest = message_digest(msg.message) if len(msg.signature) != 65: raise DataError("Invalid signature") @@ -33,7 +36,7 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success: if address_bytes != pkh: raise DataError("Invalid signature") - address = address_from_bytes(address_bytes) + address = address_from_bytes(address_bytes, defs.network) await confirm_signverify( ctx, "ETH", decode_message(msg.message), address, verify=True