1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-30 02:18:16 +00:00

feat(core): added Ethereum definitions object - handle definitions from host

This commit is contained in:
Martin Novak 2022-06-23 11:56:11 +02:00
parent 3cf9b7c235
commit 824abe7d2f
15 changed files with 355 additions and 84 deletions

View File

@ -491,6 +491,8 @@ if not utils.BITCOIN_ONLY:
import apps.eos.writers import apps.eos.writers
apps.ethereum apps.ethereum
import apps.ethereum import apps.ethereum
apps.ethereum.definitions
import apps.ethereum.definitions
apps.ethereum.get_address apps.ethereum.get_address
import apps.ethereum.get_address import apps.ethereum.get_address
apps.ethereum.get_public_key apps.ethereum.get_public_key

View File

@ -0,0 +1,192 @@
from collections import defaultdict
from typing import TYPE_CHECKING, Tuple
from apps.ethereum import tokens
from trezor import protobuf, wire
from trezor.crypto.curve import ed25519
from trezor.enums import EthereumDefinitionType
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
from . import networks
if TYPE_CHECKING:
from trezor.protobuf import MessageType
from .networks import NetworkInfo
from .tokens import TokenInfo
DEFINITIONS_PUBLIC_KEY = b""
MIN_DATA_VERSION = 1
FORMAT_VERSION = "trzd1"
if __debug__:
DEFINITIONS_DEV_PUBLIC_KEY = b""
class EthereumDefinitionParser:
def __init__(self, definition_bytes: bytes) -> None:
if len(definition_bytes) <= (8 + 1 + 4 + 64):
raise wire.DataError("Received Ethereum definition is probably malformed (too few data).")
self.format_version: str = definition_bytes[:8].rstrip(b'\0').decode("utf-8")
self.definition_type: int = definition_bytes[8]
self.data_version: int = int.from_bytes(definition_bytes[9:13], 'big')
self.clean_payload = definition_bytes[13:-64]
self.payload = definition_bytes[:-64]
self.signature = definition_bytes[-64:]
def decode_definition(
definition: bytes, expected_type: EthereumDefinitionType
) -> NetworkInfo | TokenInfo:
# check network definition
parsed_definition = EthereumDefinitionParser(definition)
# first check format version
if parsed_definition.format_version != FORMAT_VERSION:
raise wire.DataError("Used different Ethereum definition format version.")
# second check the type of the data
if parsed_definition.definition_type != expected_type:
raise wire.DataError("Definition of invalid type for Ethereum.")
# third check data version
if parsed_definition.data_version < MIN_DATA_VERSION:
raise wire.DataError("Used Ethereum definition data version too low.")
# at the end verify the signature
if not ed25519.verify(DEFINITIONS_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload):
error_msg = wire.DataError("Ethereum definition signature is invalid.")
if __debug__:
# check against dev key
if not ed25519.verify(DEFINITIONS_DEV_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload):
raise error_msg
else:
raise error_msg
# decode it if it's OK
if expected_type == EthereumDefinitionType.NETWORK:
info = protobuf.decode(parsed_definition.payload, EthereumNetworkInfo, True)
# TODO: temporarily convert to internal class
if info is not None:
from .networks import NetworkInfo
info = NetworkInfo(
chain_id=info.chain_id,
slip44=info.slip44,
shortcut=info.shortcut,
name=info.name,
rskip60=info.rskip60
)
else:
info = protobuf.decode(parsed_definition.payload, EthereumTokenInfo, True)
# TODO: temporarily convert to internal class
if info is not None:
from .tokens import TokenInfo
info = TokenInfo(
symbol=info.symbol,
decimals=info.decimals,
address=info.address,
chain_id=info.chain_id,
name=info.name,
)
return info
def _get_network_definiton(encoded_network_definition: bytes | None, ref_chain_id: int | None = None) -> NetworkInfo | None:
if encoded_network_definition is None and ref_chain_id is None:
return None
if ref_chain_id is not None:
# if we have a built-in definition, use it
network = networks.by_chain_id(ref_chain_id)
if network is not None:
return network
if encoded_network_definition is not None:
# get definition if it was send
network = decode_definition(encoded_network_definition, EthereumDefinitionType.NETWORK)
# check referential chain_id with encoded chain_id
if network.chain_id != ref_chain_id:
raise wire.DataError("Invalid network definition - chain IDs not equal.")
return network
return None
def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: int | None = None, ref_address: int | None = None) -> TokenInfo:
if encoded_token_definition is None and (ref_chain_id is None or ref_address is None):
return None
# if we have a built-in definition, use it
if ref_chain_id is not None and ref_address is not None:
token = tokens.token_by_chain_address(ref_chain_id, ref_address)
if token is not tokens.UNKNOWN_TOKEN:
return token
if encoded_token_definition is not None:
# get definition if it was send
token = decode_definition(encoded_token_definition, EthereumDefinitionType.TOKEN)
# check token against ref_chain_id and ref_address
if (
(ref_chain_id is None or token.chain_id == ref_chain_id)
and (ref_address is None or token.address == ref_address)
):
return token
return tokens.UNKNOWN_TOKEN
class EthereumDefinitions:
"""Class that holds Ethereum definitions - network and tokens. Prefers built-in definitions over encoded ones."""
def __init__(
self,
encoded_network_definition: bytes | None = None,
encoded_token_definition: bytes | None = None,
ref_chain_id: int | None = None,
ref_token_address: int | None = None,
) -> None:
self.network = _get_network_definiton(encoded_network_definition, ref_chain_id)
self.token_dict: defaultdict[bytes, TokenInfo] = defaultdict(lambda: tokens.UNKNOWN_TOKEN)
# if we have some network, we can try to get token
if self.network is not None:
received_token = _get_token_definiton(encoded_token_definition, self.network.chain_id, ref_token_address)
if received_token is not tokens.UNKNOWN_TOKEN:
self.token_dict[received_token.address] = received_token
def get_definitions_from_msg(msg: MessageType) -> EthereumDefinitions:
encoded_network_definition: bytes | None = None
encoded_token_definition: bytes | None = None
chain_id: int | None = None
token_address: int | None = None
# first try to get both definitions
try:
if msg.definitions is not None:
encoded_network_definition = msg.definitions.encoded_network
encoded_token_definition = msg.definitions.encoded_token
except AttributeError:
encoded_network_definition = msg.encoded_network
# get chain_id
try:
chain_id = msg.chain_id
except AttributeError:
pass
# get token_address
try:
token_address = msg.to
except AttributeError:
pass
return EthereumDefinitions(encoded_network_definition, encoded_token_definition, chain_id, token_address)

View File

@ -1,6 +1,13 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumGetAddress, EthereumAddress from trezor.messages import EthereumGetAddress, EthereumAddress
@ -8,10 +15,12 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from . import definitions
@with_keychain_from_path(*PATTERNS_ADDRESS)
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
async def get_address( async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain ctx: Context, msg: EthereumGetAddress, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumAddress: ) -> EthereumAddress:
from trezor.messages import EthereumAddress from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
@ -25,8 +34,12 @@ async def get_address(
node = keychain.derive(address_n) node = keychain.derive(address_n)
if len(address_n) > 1: # path has slip44 network identifier if len(msg.address_n) > 1: # path has slip44 network identifier
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF) slip44 = msg.address_n[1] & 0x7FFF_FFFF
if slip44 == defs.network.slip44:
network = defs.network
else:
network = networks.by_slip44(slip44)
else: else:
network = None network = None
address = address_from_bytes(node.ethereum_pubkeyhash(), network) address = address_from_bytes(node.ethereum_pubkeyhash(), network)

View File

@ -1,12 +1,14 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify from ubinascii import hexlify
from .networks import by_chain_id
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumFieldType from trezor.messages import EthereumFieldType
from .networks import NetworkInfo from .networks import NetworkInfo
def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str: def address_from_bytes(address_bytes: bytes, network: NetworkInfo = by_chain_id(1)) -> str:
""" """
Converts address in bytes to a checksummed string as defined Converts address in bytes to a checksummed string as defined
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md

View File

@ -3,10 +3,12 @@ from typing import TYPE_CHECKING
from apps.common import paths from apps.common import paths
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from . import CURVE, networks from . import CURVE, networks, definitions
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, TypeVar from typing import Awaitable, Callable, Iterable, TypeVar
from apps.common.keychain import Keychain
from trezor.wire import Context from trezor.wire import Context
@ -19,19 +21,17 @@ if TYPE_CHECKING:
EthereumSignTypedData, EthereumSignTypedData,
) )
from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain from apps.common.keychain import MsgIn as MsgInGeneric, MsgOut, Handler, HandlerWithKeychain
EthereumMessages = ( # messages for "with_keychain_from_path" decorator
EthereumGetAddress MsgInKeychainPath = TypeVar("MsgInKeychainPath", bound=EthereumGetPublicKey)
| EthereumGetPublicKey # messages for "with_keychain_from_path_and_defs" decorator
| EthereumSignTx MsgInKeychainPathDefs = TypeVar("MsgInKeychainPathDefs", bound=EthereumGetAddress | EthereumSignMessage | EthereumSignTypedData)
| EthereumSignMessage # messages for "with_keychain_from_chain_id_and_defs" decorator
| EthereumSignTypedData MsgInKeychainChainIdDefs = TypeVar("MsgInKeychainChainIdDefs", bound=EthereumSignTx | EthereumSignTxEIP1559)
)
MsgIn = TypeVar("MsgIn", bound=EthereumMessages)
EthereumSignTxAny = EthereumSignTx | EthereumSignTxEIP1559 # TODO: check the types of messages
MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny) HandlerWithKeychainAndDefinitions = Callable[[Context, MsgInGeneric, Keychain, definitions.EthereumDefinitions], Awaitable[MsgOut]]
# We believe Ethereum should use 44'/60'/a' for everything, because it is # We believe Ethereum should use 44'/60'/a' for everything, because it is
@ -48,13 +48,20 @@ PATTERNS_ADDRESS = (
def _schemas_from_address_n( def _schemas_from_address_n(
patterns: Iterable[str], address_n: paths.Bip32Path patterns: Iterable[str], address_n: paths.Bip32Path, network_info: networks.NetworkInfo | None
) -> Iterable[paths.PathSchema]: ) -> Iterable[paths.PathSchema]:
if len(address_n) < 2: if len(address_n) < 2:
return () return ()
slip44_hardened = address_n[1] slip44_hardened = address_n[1]
if slip44_hardened not in networks.all_slip44_ids_hardened():
def _get_hardened_slip44_networks():
if network_info is not None:
yield network_info.slip44 | paths.HARDENED
yield from networks.all_slip44_ids_hardened()
# check with network from definitions and if that is None then with built-in ones
if slip44_hardened not in _get_hardened_slip44_networks():
return () return ()
if not slip44_hardened & paths.HARDENED: if not slip44_hardened & paths.HARDENED:
@ -67,10 +74,11 @@ def _schemas_from_address_n(
def with_keychain_from_path( def with_keychain_from_path(
*patterns: str, *patterns: str,
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: ) -> Callable[[HandlerWithKeychain[MsgInKeychainPath, MsgOut]], Handler[MsgInKeychainPath, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: def decorator(func: HandlerWithKeychain[MsgInKeychainPath, MsgOut]) -> Handler[MsgInKeychainPath, MsgOut]:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut: async def wrapper(ctx: Context, msg: MsgInKeychainPath) -> MsgOut:
schemas = _schemas_from_address_n(patterns, msg.address_n) defs = definitions.get_definitions_from_msg(msg)
schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(ctx, CURVE, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain) return await func(ctx, msg, keychain)
@ -80,17 +88,32 @@ def with_keychain_from_path(
return decorator return decorator
def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]: def with_keychain_from_path_and_defs(
info = networks.by_chain_id(msg.chain_id) *patterns: str,
) -> Callable[[HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]], Handler[MsgInKeychainPathDefs, MsgOut]]:
def decorator(func: HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]) -> Handler[MsgInKeychainPathDefs, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInKeychainPathDefs) -> MsgOut:
defs = definitions.get_definitions_from_msg(msg)
schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain, defs)
return wrapper
return decorator
def _schemas_from_chain_id(network_info: networks.NetworkInfo | None) -> Iterable[paths.PathSchema]:
slip44_id: tuple[int, ...] slip44_id: tuple[int, ...]
if info is None: if network_info is None:
# allow Ethereum or testnet paths for unknown networks # allow Ethereum or testnet paths for unknown networks
slip44_id = (60, 1) slip44_id = (60, 1)
elif info.slip44 not in (60, 1): elif network_info.slip44 not in (60, 1):
# allow cross-signing with Ethereum unless it's testnet # allow cross-signing with Ethereum unless it's testnet
slip44_id = (info.slip44, 60) slip44_id = (network_info.slip44, 60)
else: else:
slip44_id = (info.slip44,) slip44_id = (network_info.slip44,)
schemas = [ schemas = [
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
@ -98,14 +121,15 @@ def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]
return [s.copy() for s in schemas] return [s.copy() for s in schemas]
def with_keychain_from_chain_id( def with_keychain_from_chain_id_and_defs(
func: HandlerWithKeychain[MsgInChainId, MsgOut] func: HandlerWithKeychainAndDefinitions[MsgInKeychainChainIdDefs, MsgOut]
) -> Handler[MsgInChainId, MsgOut]: ) -> Handler[MsgInKeychainChainIdDefs, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed # this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut: async def wrapper(ctx: Context, msg: MsgInKeychainChainIdDefs) -> MsgOut:
schemas = _schemas_from_chain_id(msg) defs = definitions.get_definitions_from_msg(msg)
schemas = _schemas_from_chain_id(defs.network)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(ctx, CURVE, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain) return await func(ctx, msg, keychain, defs)
return wrapper return wrapper

View File

@ -12,7 +12,7 @@ from trezor.ui.layouts import (
) )
from . import networks from . import networks
from .helpers import decode_typed_data from .helpers import address_from_bytes, decode_typed_data, get_type_name
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable, Iterable from typing import Awaitable, Iterable
@ -21,25 +21,29 @@ if TYPE_CHECKING:
from trezor.wire import Context from trezor.wire import Context
from . import tokens from . import tokens
from . import tokens
def require_confirm_tx( def require_confirm_tx(
ctx: Context, ctx: Context,
to_bytes: bytes, to_bytes: bytes,
value: int, value: int,
chain_id: int, network: networks.NetworkInfo,
token: tokens.TokenInfo | None = None, token: tokens.TokenInfo,
) -> Awaitable[None]: ) -> Awaitable[None]:
from .helpers import address_from_bytes from .helpers import address_from_bytes
from trezor.ui.layouts import confirm_output from trezor.ui.layouts import confirm_output
if to_bytes: if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id)) to_str = address_from_bytes(to_bytes, network)
else: else:
to_str = "new contract?" to_str = "new contract?"
return confirm_output( return confirm_output(
ctx, ctx,
to_str, address=to_str,
format_ethereum_amount(value, token, chain_id), amount=format_ethereum_amount(value, token, network),
font_amount=ui.BOLD,
color_to=ui.GREY,
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
) )
@ -49,19 +53,19 @@ async def require_confirm_fee(
spending: int, spending: int,
gas_price: int, gas_price: int,
gas_limit: int, gas_limit: int,
chain_id: int, network: networks.NetworkInfo,
token: tokens.TokenInfo | None = None, token: tokens.TokenInfo,
) -> None: ) -> None:
await confirm_amount( await confirm_amount(
ctx, ctx,
title="Confirm fee", title="Confirm fee",
description="Gas price:", description="Gas price:",
amount=format_ethereum_amount(gas_price, None, chain_id), amount=format_ethereum_amount(gas_price, None, network),
) )
await confirm_total( await confirm_total(
ctx, ctx,
total_amount=format_ethereum_amount(spending, token, chain_id), total_amount=format_ethereum_amount(spending, token, network),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, chain_id), fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
total_label="Amount sent:", total_label="Amount sent:",
fee_label="Maximum fee:", fee_label="Maximum fee:",
) )
@ -73,25 +77,25 @@ async def require_confirm_eip1559_fee(
max_priority_fee: int, max_priority_fee: int,
max_gas_fee: int, max_gas_fee: int,
gas_limit: int, gas_limit: int,
chain_id: int, network: networks.NetworkInfo,
token: tokens.TokenInfo | None = None, token: tokens.TokenInfo,
) -> None: ) -> None:
await confirm_amount( await confirm_amount(
ctx, ctx,
"Confirm fee", "Confirm fee",
format_ethereum_amount(max_gas_fee, None, chain_id), format_ethereum_amount(max_gas_fee, None, network),
"Maximum fee per gas", "Maximum fee per gas",
) )
await confirm_amount( await confirm_amount(
ctx, ctx,
"Confirm fee", "Confirm fee",
format_ethereum_amount(max_priority_fee, None, chain_id), format_ethereum_amount(max_priority_fee, None, network),
"Priority fee per gas", "Priority fee per gas",
) )
await confirm_total( await confirm_total(
ctx, ctx,
format_ethereum_amount(spending, token, chain_id), format_ethereum_amount(spending, token, network),
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id), format_ethereum_amount(max_gas_fee * gas_limit, None, network),
total_label="Amount sent:", total_label="Amount sent:",
fee_label="Maximum fee:", fee_label="Maximum fee:",
) )
@ -249,7 +253,7 @@ async def confirm_typed_value(
def format_ethereum_amount( def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int value: int, token: tokens.TokenInfo | None, network_info: networks.NetworkInfo | None
) -> str: ) -> str:
from trezor.strings import format_amount from trezor.strings import format_amount
@ -257,7 +261,10 @@ def format_ethereum_amount(
suffix = token.symbol suffix = token.symbol
decimals = token.decimals decimals = token.decimals
else: else:
suffix = networks.shortcut_by_chain_id(chain_id) if network_info is not None:
suffix = network_info.shortcut
else:
suffix = networks.UNKNOWN_NETWORK_SHORTCUT
decimals = 18 decimals = 18
# Don't want to display wei values for tokens with small decimal numbers # Don't want to display wei values for tokens with small decimal numbers

View File

@ -22,11 +22,12 @@ if TYPE_CHECKING:
bool # rskip60 bool # rskip60
] ]
# fmt: on # fmt: on
UNKNOWN_NETWORK_SHORTCUT = "UNKN"
def shortcut_by_chain_id(chain_id: int) -> str: def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id) n = by_chain_id(chain_id)
return n.shortcut if n is not None else "UNKN" return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT
def by_chain_id(chain_id: int) -> "NetworkInfo" | None: def by_chain_id(chain_id: int) -> "NetworkInfo" | None:

View File

@ -22,11 +22,12 @@ if TYPE_CHECKING:
bool # rskip60 bool # rskip60
] ]
# fmt: on # fmt: on
UNKNOWN_NETWORK_SHORTCUT = "UNKN"
def shortcut_by_chain_id(chain_id: int) -> str: def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id) n = by_chain_id(chain_id)
return n.shortcut if n is not None else "UNKN" return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT
def by_chain_id(chain_id: int) -> "NetworkInfo" | None: def by_chain_id(chain_id: int) -> "NetworkInfo" | None:

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumSignMessage, EthereumMessageSignature from trezor.messages import EthereumSignMessage, EthereumMessageSignature
@ -8,6 +8,8 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from . import definitions
def message_digest(message: bytes) -> bytes: def message_digest(message: bytes) -> bytes:
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
@ -21,9 +23,9 @@ def message_digest(message: bytes) -> bytes:
return h.get_digest() return h.get_digest()
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
async def sign_message( async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain ctx: Context, msg: EthereumSignMessage, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumMessageSignature: ) -> EthereumMessageSignature:
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.messages import EthereumMessageSignature from trezor.messages import EthereumMessageSignature
@ -37,7 +39,7 @@ async def sign_message(
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
address = address_from_bytes(node.ethereum_pubkeyhash()) address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
await confirm_signverify( await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=False ctx, "ETH", decode_message(msg.message), address, verify=False
) )

View File

@ -5,15 +5,16 @@ from trezor.messages import EthereumTxRequest
from trezor.wire import DataError from trezor.wire import DataError
from .helpers import bytes_from_address from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id from .keychain import with_keychain_from_chain_id_and_defs
if TYPE_CHECKING: if TYPE_CHECKING:
from collections import defaultdict
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx, EthereumTxAck from trezor.messages import EthereumSignTx, EthereumTxAck
from trezor.wire import Context from trezor.wire import Context
from .keychain import EthereumSignTxAny from .keychain import EthereumSignTxAny
from . import tokens from . import tokens, definitions
# Maximum chain_id which returns the full signature_v (which must fit into an uint32). # Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -22,9 +23,9 @@ if TYPE_CHECKING:
MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2 MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id @with_keychain_from_chain_id_and_defs
async def sign_tx( async def sign_tx(
ctx: Context, msg: EthereumSignTx, keychain: Keychain ctx: Context, msg: EthereumSignTx, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumTxRequest: ) -> EthereumTxRequest:
from trezor.utils import HashWriter from trezor.utils import HashWriter
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
@ -45,11 +46,11 @@ async def sign_tx(
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg) token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict)
data_total = msg.data_length data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token) await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -58,7 +59,7 @@ async def sign_tx(
value, value,
int.from_bytes(msg.gas_price, "big"), int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"), int.from_bytes(msg.gas_limit, "big"),
msg.chain_id, defs.network,
token, token,
) )
@ -100,7 +101,7 @@ async def sign_tx(
async def handle_erc20( async def handle_erc20(
ctx: Context, msg: EthereumSignTxAny ctx: Context, msg: EthereumSignTxAny, token_dict: defaultdict[bytes, tokens.TokenInfo]
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]: ) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
from .layout import require_confirm_unknown_token from .layout import require_confirm_unknown_token
from . import tokens from . import tokens
@ -118,7 +119,7 @@ async def handle_erc20(
and data_initial_chunk[:16] and data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" == b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
): ):
token = tokens.token_by_chain_address(msg.chain_id, address_bytes) token = token_dict[address_bytes]
recipient = data_initial_chunk[16:36] recipient = data_initial_chunk[16:36]
value = int.from_bytes(data_initial_chunk[36:68], "big") value = int.from_bytes(data_initial_chunk[36:68], "big")

View File

@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
from trezor.crypto import rlp from trezor.crypto import rlp
from .helpers import bytes_from_address from .helpers import bytes_from_address
from .keychain import with_keychain_from_chain_id from .keychain import with_keychain_from_chain_id_and_defs
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import ( from trezor.messages import (
@ -13,6 +13,7 @@ if TYPE_CHECKING:
EthereumTxRequest, EthereumTxRequest,
) )
from . import definitions
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context from trezor.wire import Context
@ -28,9 +29,9 @@ def access_list_item_length(item: EthereumAccessList) -> int:
) )
@with_keychain_from_chain_id @with_keychain_from_chain_id_and_defs
async def sign_tx_eip1559( async def sign_tx_eip1559(
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumTxRequest: ) -> EthereumTxRequest:
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter from trezor.utils import HashWriter
@ -56,11 +57,11 @@ async def sign_tx_eip1559(
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg) token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict)
data_total = msg.data_length data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token) await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -70,7 +71,7 @@ async def sign_tx_eip1559(
int.from_bytes(msg.max_priority_fee, "big"), int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"), int.from_bytes(msg.max_gas_fee, "big"),
int.from_bytes(gas_limit, "big"), int.from_bytes(gas_limit, "big"),
msg.chain_id, defs.network,
token, token,
) )

View File

@ -5,7 +5,7 @@ from trezor.enums import EthereumDataType
from trezor.wire import DataError from trezor.wire import DataError
from .helpers import get_type_name from .helpers import get_type_name
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
from .layout import should_show_struct from .layout import should_show_struct
if TYPE_CHECKING: if TYPE_CHECKING:
@ -20,14 +20,16 @@ if TYPE_CHECKING:
EthereumTypedDataStructAck, EthereumTypedDataStructAck,
) )
from . import definitions
# Maximum data size we support # Maximum data size we support
_MAX_VALUE_BYTE_SIZE = const(1024) _MAX_VALUE_BYTE_SIZE = const(1024)
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
async def sign_typed_data( async def sign_typed_data(
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain ctx: Context, msg: EthereumSignTypedData, keychain: Keychain, defs: definitions.EthereumDefinitions
) -> EthereumTypedDataSignature: ) -> EthereumTypedDataSignature:
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from apps.common import paths from apps.common import paths
@ -46,7 +48,7 @@ async def sign_typed_data(
) )
return EthereumTypedDataSignature( return EthereumTypedDataSignature(
address=address_from_bytes(node.ethereum_pubkeyhash()), address=address_from_bytes(node.ethereum_pubkeyhash(), defs.network),
signature=signature[1:] + signature[0:1], signature=signature[1:] + signature[0:1],
) )

View File

@ -18,9 +18,19 @@ from typing import Iterator
class TokenInfo: class TokenInfo:
def __init__(self, symbol: str, decimals: int) -> None: def __init__(
self,
symbol: str,
decimals: int,
address: bytes = None,
chain_id: int = None,
name: str = None,
) -> None:
self.symbol = symbol self.symbol = symbol
self.decimals = decimals self.decimals = decimals
self.address = address
self.chain_id = chain_id
self.name = name
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0) UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)

View File

@ -27,9 +27,19 @@ def group_tokens(tokens):
%>\ %>\
class TokenInfo: class TokenInfo:
def __init__(self, symbol: str, decimals: int) -> None: def __init__(
self,
symbol: str,
decimals: int,
address: bytes = None,
chain_id: int = None,
name: str = None,
) -> None:
self.symbol = symbol self.symbol = symbol
self.decimals = decimals self.decimals = decimals
self.address = address
self.chain_id = chain_id
self.name = name
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0) UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)

View File

@ -14,9 +14,12 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
from apps.common.signverify import decode_message from apps.common.signverify import decode_message
from . import definitions
from .helpers import address_from_bytes, bytes_from_address from .helpers import address_from_bytes, bytes_from_address
from .sign_message import message_digest from .sign_message import message_digest
defs = definitions.get_definitions_from_msg(msg)
digest = message_digest(msg.message) digest = message_digest(msg.message)
if len(msg.signature) != 65: if len(msg.signature) != 65:
raise DataError("Invalid signature") raise DataError("Invalid signature")
@ -33,7 +36,7 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
if address_bytes != pkh: if address_bytes != pkh:
raise DataError("Invalid signature") raise DataError("Invalid signature")
address = address_from_bytes(address_bytes) address = address_from_bytes(address_bytes, defs.network)
await confirm_signverify( await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=True ctx, "ETH", decode_message(msg.message), address, verify=True