mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-07-29 09:58:47 +00:00
feat(core): added Ethereum definitions object - handle definitions from host
This commit is contained in:
parent
3cf9b7c235
commit
824abe7d2f
@ -491,6 +491,8 @@ if not utils.BITCOIN_ONLY:
|
||||
import apps.eos.writers
|
||||
apps.ethereum
|
||||
import apps.ethereum
|
||||
apps.ethereum.definitions
|
||||
import apps.ethereum.definitions
|
||||
apps.ethereum.get_address
|
||||
import apps.ethereum.get_address
|
||||
apps.ethereum.get_public_key
|
||||
|
192
core/src/apps/ethereum/definitions.py
Normal file
192
core/src/apps/ethereum/definitions.py
Normal file
@ -0,0 +1,192 @@
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from apps.ethereum import tokens
|
||||
|
||||
from trezor import protobuf, wire
|
||||
from trezor.crypto.curve import ed25519
|
||||
from trezor.enums import EthereumDefinitionType
|
||||
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
|
||||
|
||||
from . import networks
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.protobuf import MessageType
|
||||
|
||||
from .networks import NetworkInfo
|
||||
from .tokens import TokenInfo
|
||||
|
||||
|
||||
DEFINITIONS_PUBLIC_KEY = b""
|
||||
MIN_DATA_VERSION = 1
|
||||
FORMAT_VERSION = "trzd1"
|
||||
|
||||
if __debug__:
|
||||
DEFINITIONS_DEV_PUBLIC_KEY = b""
|
||||
|
||||
|
||||
class EthereumDefinitionParser:
|
||||
def __init__(self, definition_bytes: bytes) -> None:
|
||||
if len(definition_bytes) <= (8 + 1 + 4 + 64):
|
||||
raise wire.DataError("Received Ethereum definition is probably malformed (too few data).")
|
||||
|
||||
self.format_version: str = definition_bytes[:8].rstrip(b'\0').decode("utf-8")
|
||||
self.definition_type: int = definition_bytes[8]
|
||||
self.data_version: int = int.from_bytes(definition_bytes[9:13], 'big')
|
||||
self.clean_payload = definition_bytes[13:-64]
|
||||
self.payload = definition_bytes[:-64]
|
||||
self.signature = definition_bytes[-64:]
|
||||
|
||||
|
||||
def decode_definition(
|
||||
definition: bytes, expected_type: EthereumDefinitionType
|
||||
) -> NetworkInfo | TokenInfo:
|
||||
# check network definition
|
||||
parsed_definition = EthereumDefinitionParser(definition)
|
||||
|
||||
# first check format version
|
||||
if parsed_definition.format_version != FORMAT_VERSION:
|
||||
raise wire.DataError("Used different Ethereum definition format version.")
|
||||
|
||||
# second check the type of the data
|
||||
if parsed_definition.definition_type != expected_type:
|
||||
raise wire.DataError("Definition of invalid type for Ethereum.")
|
||||
|
||||
# third check data version
|
||||
if parsed_definition.data_version < MIN_DATA_VERSION:
|
||||
raise wire.DataError("Used Ethereum definition data version too low.")
|
||||
|
||||
# at the end verify the signature
|
||||
if not ed25519.verify(DEFINITIONS_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload):
|
||||
error_msg = wire.DataError("Ethereum definition signature is invalid.")
|
||||
if __debug__:
|
||||
# check against dev key
|
||||
if not ed25519.verify(DEFINITIONS_DEV_PUBLIC_KEY, parsed_definition.signature, parsed_definition.payload):
|
||||
raise error_msg
|
||||
else:
|
||||
raise error_msg
|
||||
|
||||
# decode it if it's OK
|
||||
if expected_type == EthereumDefinitionType.NETWORK:
|
||||
info = protobuf.decode(parsed_definition.payload, EthereumNetworkInfo, True)
|
||||
|
||||
# TODO: temporarily convert to internal class
|
||||
if info is not None:
|
||||
from .networks import NetworkInfo
|
||||
info = NetworkInfo(
|
||||
chain_id=info.chain_id,
|
||||
slip44=info.slip44,
|
||||
shortcut=info.shortcut,
|
||||
name=info.name,
|
||||
rskip60=info.rskip60
|
||||
)
|
||||
else:
|
||||
info = protobuf.decode(parsed_definition.payload, EthereumTokenInfo, True)
|
||||
|
||||
# TODO: temporarily convert to internal class
|
||||
if info is not None:
|
||||
from .tokens import TokenInfo
|
||||
info = TokenInfo(
|
||||
symbol=info.symbol,
|
||||
decimals=info.decimals,
|
||||
address=info.address,
|
||||
chain_id=info.chain_id,
|
||||
name=info.name,
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def _get_network_definiton(encoded_network_definition: bytes | None, ref_chain_id: int | None = None) -> NetworkInfo | None:
|
||||
if encoded_network_definition is None and ref_chain_id is None:
|
||||
return None
|
||||
|
||||
if ref_chain_id is not None:
|
||||
# if we have a built-in definition, use it
|
||||
network = networks.by_chain_id(ref_chain_id)
|
||||
if network is not None:
|
||||
return network
|
||||
|
||||
if encoded_network_definition is not None:
|
||||
# get definition if it was send
|
||||
network = decode_definition(encoded_network_definition, EthereumDefinitionType.NETWORK)
|
||||
|
||||
# check referential chain_id with encoded chain_id
|
||||
if network.chain_id != ref_chain_id:
|
||||
raise wire.DataError("Invalid network definition - chain IDs not equal.")
|
||||
|
||||
return network
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: int | None = None, ref_address: int | None = None) -> TokenInfo:
|
||||
if encoded_token_definition is None and (ref_chain_id is None or ref_address is None):
|
||||
return None
|
||||
|
||||
# if we have a built-in definition, use it
|
||||
if ref_chain_id is not None and ref_address is not None:
|
||||
token = tokens.token_by_chain_address(ref_chain_id, ref_address)
|
||||
if token is not tokens.UNKNOWN_TOKEN:
|
||||
return token
|
||||
|
||||
if encoded_token_definition is not None:
|
||||
# get definition if it was send
|
||||
token = decode_definition(encoded_token_definition, EthereumDefinitionType.TOKEN)
|
||||
|
||||
# check token against ref_chain_id and ref_address
|
||||
if (
|
||||
(ref_chain_id is None or token.chain_id == ref_chain_id)
|
||||
and (ref_address is None or token.address == ref_address)
|
||||
):
|
||||
return token
|
||||
|
||||
return tokens.UNKNOWN_TOKEN
|
||||
|
||||
|
||||
class EthereumDefinitions:
|
||||
"""Class that holds Ethereum definitions - network and tokens. Prefers built-in definitions over encoded ones."""
|
||||
def __init__(
|
||||
self,
|
||||
encoded_network_definition: bytes | None = None,
|
||||
encoded_token_definition: bytes | None = None,
|
||||
ref_chain_id: int | None = None,
|
||||
ref_token_address: int | None = None,
|
||||
) -> None:
|
||||
self.network = _get_network_definiton(encoded_network_definition, ref_chain_id)
|
||||
self.token_dict: defaultdict[bytes, TokenInfo] = defaultdict(lambda: tokens.UNKNOWN_TOKEN)
|
||||
|
||||
# if we have some network, we can try to get token
|
||||
if self.network is not None:
|
||||
received_token = _get_token_definiton(encoded_token_definition, self.network.chain_id, ref_token_address)
|
||||
if received_token is not tokens.UNKNOWN_TOKEN:
|
||||
self.token_dict[received_token.address] = received_token
|
||||
|
||||
|
||||
def get_definitions_from_msg(msg: MessageType) -> EthereumDefinitions:
|
||||
encoded_network_definition: bytes | None = None
|
||||
encoded_token_definition: bytes | None = None
|
||||
chain_id: int | None = None
|
||||
token_address: int | None = None
|
||||
|
||||
# first try to get both definitions
|
||||
try:
|
||||
if msg.definitions is not None:
|
||||
encoded_network_definition = msg.definitions.encoded_network
|
||||
encoded_token_definition = msg.definitions.encoded_token
|
||||
except AttributeError:
|
||||
encoded_network_definition = msg.encoded_network
|
||||
|
||||
# get chain_id
|
||||
try:
|
||||
chain_id = msg.chain_id
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# get token_address
|
||||
try:
|
||||
token_address = msg.to
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
return EthereumDefinitions(encoded_network_definition, encoded_token_definition, chain_id, token_address)
|
@ -1,6 +1,13 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||
from trezor.messages import EthereumAddress
|
||||
from trezor.ui.layouts import show_address
|
||||
|
||||
from apps.common import paths
|
||||
|
||||
from . import networks
|
||||
from .helpers import address_from_bytes
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumGetAddress, EthereumAddress
|
||||
@ -8,10 +15,12 @@ if TYPE_CHECKING:
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
|
||||
from . import definitions
|
||||
|
||||
@with_keychain_from_path(*PATTERNS_ADDRESS)
|
||||
|
||||
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
|
||||
async def get_address(
|
||||
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
|
||||
ctx: Context, msg: EthereumGetAddress, keychain: Keychain, defs: definitions.EthereumDefinitions
|
||||
) -> EthereumAddress:
|
||||
from trezor.messages import EthereumAddress
|
||||
from trezor.ui.layouts import show_address
|
||||
@ -25,8 +34,12 @@ async def get_address(
|
||||
|
||||
node = keychain.derive(address_n)
|
||||
|
||||
if len(address_n) > 1: # path has slip44 network identifier
|
||||
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
|
||||
if len(msg.address_n) > 1: # path has slip44 network identifier
|
||||
slip44 = msg.address_n[1] & 0x7FFF_FFFF
|
||||
if slip44 == defs.network.slip44:
|
||||
network = defs.network
|
||||
else:
|
||||
network = networks.by_slip44(slip44)
|
||||
else:
|
||||
network = None
|
||||
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
|
||||
|
@ -1,12 +1,14 @@
|
||||
from typing import TYPE_CHECKING
|
||||
from ubinascii import hexlify
|
||||
|
||||
from .networks import by_chain_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumFieldType
|
||||
from .networks import NetworkInfo
|
||||
|
||||
|
||||
def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str:
|
||||
def address_from_bytes(address_bytes: bytes, network: NetworkInfo = by_chain_id(1)) -> str:
|
||||
"""
|
||||
Converts address in bytes to a checksummed string as defined
|
||||
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
|
||||
|
@ -3,10 +3,12 @@ from typing import TYPE_CHECKING
|
||||
from apps.common import paths
|
||||
from apps.common.keychain import get_keychain
|
||||
|
||||
from . import CURVE, networks
|
||||
from . import CURVE, networks, definitions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Callable, Iterable, TypeVar
|
||||
from typing import Awaitable, Callable, Iterable, TypeVar
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
|
||||
from trezor.wire import Context
|
||||
|
||||
@ -19,19 +21,17 @@ if TYPE_CHECKING:
|
||||
EthereumSignTypedData,
|
||||
)
|
||||
|
||||
from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain
|
||||
from apps.common.keychain import MsgIn as MsgInGeneric, MsgOut, Handler, HandlerWithKeychain
|
||||
|
||||
EthereumMessages = (
|
||||
EthereumGetAddress
|
||||
| EthereumGetPublicKey
|
||||
| EthereumSignTx
|
||||
| EthereumSignMessage
|
||||
| EthereumSignTypedData
|
||||
)
|
||||
MsgIn = TypeVar("MsgIn", bound=EthereumMessages)
|
||||
# messages for "with_keychain_from_path" decorator
|
||||
MsgInKeychainPath = TypeVar("MsgInKeychainPath", bound=EthereumGetPublicKey)
|
||||
# messages for "with_keychain_from_path_and_defs" decorator
|
||||
MsgInKeychainPathDefs = TypeVar("MsgInKeychainPathDefs", bound=EthereumGetAddress | EthereumSignMessage | EthereumSignTypedData)
|
||||
# messages for "with_keychain_from_chain_id_and_defs" decorator
|
||||
MsgInKeychainChainIdDefs = TypeVar("MsgInKeychainChainIdDefs", bound=EthereumSignTx | EthereumSignTxEIP1559)
|
||||
|
||||
EthereumSignTxAny = EthereumSignTx | EthereumSignTxEIP1559
|
||||
MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny)
|
||||
# TODO: check the types of messages
|
||||
HandlerWithKeychainAndDefinitions = Callable[[Context, MsgInGeneric, Keychain, definitions.EthereumDefinitions], Awaitable[MsgOut]]
|
||||
|
||||
|
||||
# We believe Ethereum should use 44'/60'/a' for everything, because it is
|
||||
@ -48,13 +48,20 @@ PATTERNS_ADDRESS = (
|
||||
|
||||
|
||||
def _schemas_from_address_n(
|
||||
patterns: Iterable[str], address_n: paths.Bip32Path
|
||||
patterns: Iterable[str], address_n: paths.Bip32Path, network_info: networks.NetworkInfo | None
|
||||
) -> Iterable[paths.PathSchema]:
|
||||
if len(address_n) < 2:
|
||||
return ()
|
||||
|
||||
slip44_hardened = address_n[1]
|
||||
if slip44_hardened not in networks.all_slip44_ids_hardened():
|
||||
|
||||
def _get_hardened_slip44_networks():
|
||||
if network_info is not None:
|
||||
yield network_info.slip44 | paths.HARDENED
|
||||
yield from networks.all_slip44_ids_hardened()
|
||||
|
||||
# check with network from definitions and if that is None then with built-in ones
|
||||
if slip44_hardened not in _get_hardened_slip44_networks():
|
||||
return ()
|
||||
|
||||
if not slip44_hardened & paths.HARDENED:
|
||||
@ -67,10 +74,11 @@ def _schemas_from_address_n(
|
||||
|
||||
def with_keychain_from_path(
|
||||
*patterns: str,
|
||||
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
|
||||
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
|
||||
schemas = _schemas_from_address_n(patterns, msg.address_n)
|
||||
) -> Callable[[HandlerWithKeychain[MsgInKeychainPath, MsgOut]], Handler[MsgInKeychainPath, MsgOut]]:
|
||||
def decorator(func: HandlerWithKeychain[MsgInKeychainPath, MsgOut]) -> Handler[MsgInKeychainPath, MsgOut]:
|
||||
async def wrapper(ctx: Context, msg: MsgInKeychainPath) -> MsgOut:
|
||||
defs = definitions.get_definitions_from_msg(msg)
|
||||
schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network)
|
||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
@ -80,17 +88,32 @@ def with_keychain_from_path(
|
||||
return decorator
|
||||
|
||||
|
||||
def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]:
|
||||
info = networks.by_chain_id(msg.chain_id)
|
||||
def with_keychain_from_path_and_defs(
|
||||
*patterns: str,
|
||||
) -> Callable[[HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]], Handler[MsgInKeychainPathDefs, MsgOut]]:
|
||||
def decorator(func: HandlerWithKeychainAndDefinitions[MsgInKeychainPathDefs, MsgOut]) -> Handler[MsgInKeychainPathDefs, MsgOut]:
|
||||
async def wrapper(ctx: Context, msg: MsgInKeychainPathDefs) -> MsgOut:
|
||||
defs = definitions.get_definitions_from_msg(msg)
|
||||
schemas = _schemas_from_address_n(patterns, msg.address_n, defs.network)
|
||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain, defs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _schemas_from_chain_id(network_info: networks.NetworkInfo | None) -> Iterable[paths.PathSchema]:
|
||||
slip44_id: tuple[int, ...]
|
||||
if info is None:
|
||||
if network_info is None:
|
||||
# allow Ethereum or testnet paths for unknown networks
|
||||
slip44_id = (60, 1)
|
||||
elif info.slip44 not in (60, 1):
|
||||
elif network_info.slip44 not in (60, 1):
|
||||
# allow cross-signing with Ethereum unless it's testnet
|
||||
slip44_id = (info.slip44, 60)
|
||||
slip44_id = (network_info.slip44, 60)
|
||||
else:
|
||||
slip44_id = (info.slip44,)
|
||||
slip44_id = (network_info.slip44,)
|
||||
|
||||
schemas = [
|
||||
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
|
||||
@ -98,14 +121,15 @@ def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]
|
||||
return [s.copy() for s in schemas]
|
||||
|
||||
|
||||
def with_keychain_from_chain_id(
|
||||
func: HandlerWithKeychain[MsgInChainId, MsgOut]
|
||||
) -> Handler[MsgInChainId, MsgOut]:
|
||||
def with_keychain_from_chain_id_and_defs(
|
||||
func: HandlerWithKeychainAndDefinitions[MsgInKeychainChainIdDefs, MsgOut]
|
||||
) -> Handler[MsgInKeychainChainIdDefs, MsgOut]:
|
||||
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
|
||||
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
|
||||
schemas = _schemas_from_chain_id(msg)
|
||||
async def wrapper(ctx: Context, msg: MsgInKeychainChainIdDefs) -> MsgOut:
|
||||
defs = definitions.get_definitions_from_msg(msg)
|
||||
schemas = _schemas_from_chain_id(defs.network)
|
||||
keychain = await get_keychain(ctx, CURVE, schemas)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
return await func(ctx, msg, keychain, defs)
|
||||
|
||||
return wrapper
|
||||
|
@ -12,7 +12,7 @@ from trezor.ui.layouts import (
|
||||
)
|
||||
|
||||
from . import networks
|
||||
from .helpers import decode_typed_data
|
||||
from .helpers import address_from_bytes, decode_typed_data, get_type_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Awaitable, Iterable
|
||||
@ -21,25 +21,29 @@ if TYPE_CHECKING:
|
||||
from trezor.wire import Context
|
||||
from . import tokens
|
||||
|
||||
from . import tokens
|
||||
|
||||
|
||||
def require_confirm_tx(
|
||||
ctx: Context,
|
||||
to_bytes: bytes,
|
||||
value: int,
|
||||
chain_id: int,
|
||||
token: tokens.TokenInfo | None = None,
|
||||
network: networks.NetworkInfo,
|
||||
token: tokens.TokenInfo,
|
||||
) -> Awaitable[None]:
|
||||
from .helpers import address_from_bytes
|
||||
from trezor.ui.layouts import confirm_output
|
||||
|
||||
if to_bytes:
|
||||
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
|
||||
to_str = address_from_bytes(to_bytes, network)
|
||||
else:
|
||||
to_str = "new contract?"
|
||||
return confirm_output(
|
||||
ctx,
|
||||
to_str,
|
||||
format_ethereum_amount(value, token, chain_id),
|
||||
address=to_str,
|
||||
amount=format_ethereum_amount(value, token, network),
|
||||
font_amount=ui.BOLD,
|
||||
color_to=ui.GREY,
|
||||
br_code=ButtonRequestType.SignTx,
|
||||
)
|
||||
|
||||
@ -49,19 +53,19 @@ async def require_confirm_fee(
|
||||
spending: int,
|
||||
gas_price: int,
|
||||
gas_limit: int,
|
||||
chain_id: int,
|
||||
token: tokens.TokenInfo | None = None,
|
||||
network: networks.NetworkInfo,
|
||||
token: tokens.TokenInfo,
|
||||
) -> None:
|
||||
await confirm_amount(
|
||||
ctx,
|
||||
title="Confirm fee",
|
||||
description="Gas price:",
|
||||
amount=format_ethereum_amount(gas_price, None, chain_id),
|
||||
amount=format_ethereum_amount(gas_price, None, network),
|
||||
)
|
||||
await confirm_total(
|
||||
ctx,
|
||||
total_amount=format_ethereum_amount(spending, token, chain_id),
|
||||
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, chain_id),
|
||||
total_amount=format_ethereum_amount(spending, token, network),
|
||||
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
|
||||
total_label="Amount sent:",
|
||||
fee_label="Maximum fee:",
|
||||
)
|
||||
@ -73,25 +77,25 @@ async def require_confirm_eip1559_fee(
|
||||
max_priority_fee: int,
|
||||
max_gas_fee: int,
|
||||
gas_limit: int,
|
||||
chain_id: int,
|
||||
token: tokens.TokenInfo | None = None,
|
||||
network: networks.NetworkInfo,
|
||||
token: tokens.TokenInfo,
|
||||
) -> None:
|
||||
await confirm_amount(
|
||||
ctx,
|
||||
"Confirm fee",
|
||||
format_ethereum_amount(max_gas_fee, None, chain_id),
|
||||
format_ethereum_amount(max_gas_fee, None, network),
|
||||
"Maximum fee per gas",
|
||||
)
|
||||
await confirm_amount(
|
||||
ctx,
|
||||
"Confirm fee",
|
||||
format_ethereum_amount(max_priority_fee, None, chain_id),
|
||||
format_ethereum_amount(max_priority_fee, None, network),
|
||||
"Priority fee per gas",
|
||||
)
|
||||
await confirm_total(
|
||||
ctx,
|
||||
format_ethereum_amount(spending, token, chain_id),
|
||||
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
|
||||
format_ethereum_amount(spending, token, network),
|
||||
format_ethereum_amount(max_gas_fee * gas_limit, None, network),
|
||||
total_label="Amount sent:",
|
||||
fee_label="Maximum fee:",
|
||||
)
|
||||
@ -249,7 +253,7 @@ async def confirm_typed_value(
|
||||
|
||||
|
||||
def format_ethereum_amount(
|
||||
value: int, token: tokens.TokenInfo | None, chain_id: int
|
||||
value: int, token: tokens.TokenInfo | None, network_info: networks.NetworkInfo | None
|
||||
) -> str:
|
||||
from trezor.strings import format_amount
|
||||
|
||||
@ -257,7 +261,10 @@ def format_ethereum_amount(
|
||||
suffix = token.symbol
|
||||
decimals = token.decimals
|
||||
else:
|
||||
suffix = networks.shortcut_by_chain_id(chain_id)
|
||||
if network_info is not None:
|
||||
suffix = network_info.shortcut
|
||||
else:
|
||||
suffix = networks.UNKNOWN_NETWORK_SHORTCUT
|
||||
decimals = 18
|
||||
|
||||
# Don't want to display wei values for tokens with small decimal numbers
|
||||
|
@ -22,11 +22,12 @@ if TYPE_CHECKING:
|
||||
bool # rskip60
|
||||
]
|
||||
# fmt: on
|
||||
UNKNOWN_NETWORK_SHORTCUT = "UNKN"
|
||||
|
||||
|
||||
def shortcut_by_chain_id(chain_id: int) -> str:
|
||||
n = by_chain_id(chain_id)
|
||||
return n.shortcut if n is not None else "UNKN"
|
||||
return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT
|
||||
|
||||
|
||||
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
|
||||
|
@ -22,11 +22,12 @@ if TYPE_CHECKING:
|
||||
bool # rskip60
|
||||
]
|
||||
# fmt: on
|
||||
UNKNOWN_NETWORK_SHORTCUT = "UNKN"
|
||||
|
||||
|
||||
def shortcut_by_chain_id(chain_id: int) -> str:
|
||||
n = by_chain_id(chain_id)
|
||||
return n.shortcut if n is not None else "UNKN"
|
||||
return n.shortcut if n is not None else UNKNOWN_NETWORK_SHORTCUT
|
||||
|
||||
|
||||
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import EthereumSignMessage, EthereumMessageSignature
|
||||
@ -8,6 +8,8 @@ if TYPE_CHECKING:
|
||||
|
||||
from apps.common.keychain import Keychain
|
||||
|
||||
from . import definitions
|
||||
|
||||
|
||||
def message_digest(message: bytes) -> bytes:
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
@ -21,9 +23,9 @@ def message_digest(message: bytes) -> bytes:
|
||||
return h.get_digest()
|
||||
|
||||
|
||||
@with_keychain_from_path(*PATTERNS_ADDRESS)
|
||||
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
|
||||
async def sign_message(
|
||||
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
|
||||
ctx: Context, msg: EthereumSignMessage, keychain: Keychain, defs: definitions.EthereumDefinitions
|
||||
) -> EthereumMessageSignature:
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from trezor.messages import EthereumMessageSignature
|
||||
@ -37,7 +39,7 @@ async def sign_message(
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
address = address_from_bytes(node.ethereum_pubkeyhash())
|
||||
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
|
||||
await confirm_signverify(
|
||||
ctx, "ETH", decode_message(msg.message), address, verify=False
|
||||
)
|
||||
|
@ -5,15 +5,16 @@ from trezor.messages import EthereumTxRequest
|
||||
from trezor.wire import DataError
|
||||
|
||||
from .helpers import bytes_from_address
|
||||
from .keychain import with_keychain_from_chain_id
|
||||
from .keychain import with_keychain_from_chain_id_and_defs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import defaultdict
|
||||
from apps.common.keychain import Keychain
|
||||
from trezor.messages import EthereumSignTx, EthereumTxAck
|
||||
from trezor.wire import Context
|
||||
|
||||
from .keychain import EthereumSignTxAny
|
||||
from . import tokens
|
||||
from . import tokens, definitions
|
||||
|
||||
|
||||
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
|
||||
@ -22,9 +23,9 @@ if TYPE_CHECKING:
|
||||
MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
|
||||
|
||||
|
||||
@with_keychain_from_chain_id
|
||||
@with_keychain_from_chain_id_and_defs
|
||||
async def sign_tx(
|
||||
ctx: Context, msg: EthereumSignTx, keychain: Keychain
|
||||
ctx: Context, msg: EthereumSignTx, keychain: Keychain, defs: definitions.EthereumDefinitions
|
||||
) -> EthereumTxRequest:
|
||||
from trezor.utils import HashWriter
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
@ -45,11 +46,11 @@ async def sign_tx(
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
# Handle ERC20s
|
||||
token, address_bytes, recipient, value = await handle_erc20(ctx, msg)
|
||||
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict)
|
||||
|
||||
data_total = msg.data_length
|
||||
|
||||
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token)
|
||||
await require_confirm_tx(ctx, recipient, value, defs.network, token)
|
||||
if token is None and msg.data_length > 0:
|
||||
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
|
||||
|
||||
@ -58,7 +59,7 @@ async def sign_tx(
|
||||
value,
|
||||
int.from_bytes(msg.gas_price, "big"),
|
||||
int.from_bytes(msg.gas_limit, "big"),
|
||||
msg.chain_id,
|
||||
defs.network,
|
||||
token,
|
||||
)
|
||||
|
||||
@ -100,7 +101,7 @@ async def sign_tx(
|
||||
|
||||
|
||||
async def handle_erc20(
|
||||
ctx: Context, msg: EthereumSignTxAny
|
||||
ctx: Context, msg: EthereumSignTxAny, token_dict: defaultdict[bytes, tokens.TokenInfo]
|
||||
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
|
||||
from .layout import require_confirm_unknown_token
|
||||
from . import tokens
|
||||
@ -118,7 +119,7 @@ async def handle_erc20(
|
||||
and data_initial_chunk[:16]
|
||||
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
|
||||
):
|
||||
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
|
||||
token = token_dict[address_bytes]
|
||||
recipient = data_initial_chunk[16:36]
|
||||
value = int.from_bytes(data_initial_chunk[36:68], "big")
|
||||
|
||||
|
@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
|
||||
from trezor.crypto import rlp
|
||||
|
||||
from .helpers import bytes_from_address
|
||||
from .keychain import with_keychain_from_chain_id
|
||||
from .keychain import with_keychain_from_chain_id_and_defs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import (
|
||||
@ -13,6 +13,7 @@ if TYPE_CHECKING:
|
||||
EthereumTxRequest,
|
||||
)
|
||||
|
||||
from . import definitions
|
||||
from apps.common.keychain import Keychain
|
||||
from trezor.wire import Context
|
||||
|
||||
@ -28,9 +29,9 @@ def access_list_item_length(item: EthereumAccessList) -> int:
|
||||
)
|
||||
|
||||
|
||||
@with_keychain_from_chain_id
|
||||
@with_keychain_from_chain_id_and_defs
|
||||
async def sign_tx_eip1559(
|
||||
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
|
||||
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain, defs: definitions.EthereumDefinitions
|
||||
) -> EthereumTxRequest:
|
||||
from trezor.crypto.hashlib import sha3_256
|
||||
from trezor.utils import HashWriter
|
||||
@ -56,11 +57,11 @@ async def sign_tx_eip1559(
|
||||
await paths.validate_path(ctx, keychain, msg.address_n)
|
||||
|
||||
# Handle ERC20s
|
||||
token, address_bytes, recipient, value = await handle_erc20(ctx, msg)
|
||||
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict)
|
||||
|
||||
data_total = msg.data_length
|
||||
|
||||
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token)
|
||||
await require_confirm_tx(ctx, recipient, value, defs.network, token)
|
||||
if token is None and msg.data_length > 0:
|
||||
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
|
||||
|
||||
@ -70,7 +71,7 @@ async def sign_tx_eip1559(
|
||||
int.from_bytes(msg.max_priority_fee, "big"),
|
||||
int.from_bytes(msg.max_gas_fee, "big"),
|
||||
int.from_bytes(gas_limit, "big"),
|
||||
msg.chain_id,
|
||||
defs.network,
|
||||
token,
|
||||
)
|
||||
|
||||
|
@ -5,7 +5,7 @@ from trezor.enums import EthereumDataType
|
||||
from trezor.wire import DataError
|
||||
|
||||
from .helpers import get_type_name
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
|
||||
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path_and_defs
|
||||
from .layout import should_show_struct
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -20,14 +20,16 @@ if TYPE_CHECKING:
|
||||
EthereumTypedDataStructAck,
|
||||
)
|
||||
|
||||
from . import definitions
|
||||
|
||||
|
||||
# Maximum data size we support
|
||||
_MAX_VALUE_BYTE_SIZE = const(1024)
|
||||
|
||||
|
||||
@with_keychain_from_path(*PATTERNS_ADDRESS)
|
||||
@with_keychain_from_path_and_defs(*PATTERNS_ADDRESS)
|
||||
async def sign_typed_data(
|
||||
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
|
||||
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain, defs: definitions.EthereumDefinitions
|
||||
) -> EthereumTypedDataSignature:
|
||||
from trezor.crypto.curve import secp256k1
|
||||
from apps.common import paths
|
||||
@ -46,7 +48,7 @@ async def sign_typed_data(
|
||||
)
|
||||
|
||||
return EthereumTypedDataSignature(
|
||||
address=address_from_bytes(node.ethereum_pubkeyhash()),
|
||||
address=address_from_bytes(node.ethereum_pubkeyhash(), defs.network),
|
||||
signature=signature[1:] + signature[0:1],
|
||||
)
|
||||
|
||||
|
@ -18,9 +18,19 @@ from typing import Iterator
|
||||
|
||||
|
||||
class TokenInfo:
|
||||
def __init__(self, symbol: str, decimals: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
symbol: str,
|
||||
decimals: int,
|
||||
address: bytes = None,
|
||||
chain_id: int = None,
|
||||
name: str = None,
|
||||
) -> None:
|
||||
self.symbol = symbol
|
||||
self.decimals = decimals
|
||||
self.address = address
|
||||
self.chain_id = chain_id
|
||||
self.name = name
|
||||
|
||||
|
||||
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
|
||||
|
@ -27,9 +27,19 @@ def group_tokens(tokens):
|
||||
%>\
|
||||
|
||||
class TokenInfo:
|
||||
def __init__(self, symbol: str, decimals: int) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
symbol: str,
|
||||
decimals: int,
|
||||
address: bytes = None,
|
||||
chain_id: int = None,
|
||||
name: str = None,
|
||||
) -> None:
|
||||
self.symbol = symbol
|
||||
self.decimals = decimals
|
||||
self.address = address
|
||||
self.chain_id = chain_id
|
||||
self.name = name
|
||||
|
||||
|
||||
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
|
||||
|
@ -14,9 +14,12 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
|
||||
|
||||
from apps.common.signverify import decode_message
|
||||
|
||||
from . import definitions
|
||||
from .helpers import address_from_bytes, bytes_from_address
|
||||
from .sign_message import message_digest
|
||||
|
||||
defs = definitions.get_definitions_from_msg(msg)
|
||||
|
||||
digest = message_digest(msg.message)
|
||||
if len(msg.signature) != 65:
|
||||
raise DataError("Invalid signature")
|
||||
@ -33,7 +36,7 @@ async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success:
|
||||
if address_bytes != pkh:
|
||||
raise DataError("Invalid signature")
|
||||
|
||||
address = address_from_bytes(address_bytes)
|
||||
address = address_from_bytes(address_bytes, defs.network)
|
||||
|
||||
await confirm_signverify(
|
||||
ctx, "ETH", decode_message(msg.message), address, verify=True
|
||||
|
Loading…
Reference in New Issue
Block a user