feat(core): external Ethereum definitions

pull/2914/head
Martin Novák 1 year ago committed by matejcik
parent 168ab2944c
commit c2c0900c5d

@ -0,0 +1 @@
Signed Ethereum network and token definitions from host

@ -375,6 +375,8 @@ if not utils.BITCOIN_ONLY:
import trezor.enums.CardanoTxWitnessType import trezor.enums.CardanoTxWitnessType
trezor.enums.EthereumDataType trezor.enums.EthereumDataType
import trezor.enums.EthereumDataType import trezor.enums.EthereumDataType
trezor.enums.EthereumDefinitionType
import trezor.enums.EthereumDefinitionType
trezor.enums.MoneroNetworkType trezor.enums.MoneroNetworkType
import trezor.enums.MoneroNetworkType import trezor.enums.MoneroNetworkType
trezor.enums.NEMImportanceTransferMode trezor.enums.NEMImportanceTransferMode
@ -481,6 +483,10 @@ if not utils.BITCOIN_ONLY:
import apps.eos.writers import apps.eos.writers
apps.ethereum apps.ethereum
import apps.ethereum import apps.ethereum
apps.ethereum.definitions
import apps.ethereum.definitions
apps.ethereum.definitions_constants
import apps.ethereum.definitions_constants
apps.ethereum.get_address apps.ethereum.get_address
import apps.ethereum.get_address import apps.ethereum.get_address
apps.ethereum.get_public_key apps.ethereum.get_public_key

@ -33,19 +33,25 @@ def read_compact_size(r: BufferReader) -> int:
def read_uint16_be(r: BufferReader) -> int: def read_uint16_be(r: BufferReader) -> int:
n = r.get() data = r.read_memoryview(2)
return (n << 8) + r.get() return int.from_bytes(data, "big")
def read_uint32_be(r: BufferReader) -> int: def read_uint32_be(r: BufferReader) -> int:
n = r.get() data = r.read_memoryview(4)
for _ in range(3): return int.from_bytes(data, "big")
n = (n << 8) + r.get()
return n
def read_uint64_be(r: BufferReader) -> int: def read_uint64_be(r: BufferReader) -> int:
n = r.get() data = r.read_memoryview(8)
for _ in range(7): return int.from_bytes(data, "big")
n = (n << 8) + r.get()
return n
def read_uint16_le(r: BufferReader) -> int:
data = r.read_memoryview(2)
return int.from_bytes(data, "little")
def read_uint32_le(r: BufferReader) -> int:
data = r.read_memoryview(4)
return int.from_bytes(data, "little")

@ -0,0 +1,154 @@
from typing import TYPE_CHECKING
from trezor import protobuf, utils
from trezor.crypto.curve import ed25519
from trezor.crypto.hashlib import sha256
from trezor.enums import EthereumDefinitionType
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
from trezor.wire import DataError
from apps.common import readers
from . import definitions_constants as consts, networks, tokens
from .networks import UNKNOWN_NETWORK
if TYPE_CHECKING:
from typing import TypeVar
from typing_extensions import Self
DefType = TypeVar("DefType", EthereumNetworkInfo, EthereumTokenInfo)
def decode_definition(definition: bytes, expected_type: type[DefType]) -> DefType:
# check network definition
r = utils.BufferReader(definition)
expected_type_number = EthereumDefinitionType.NETWORK
# TODO: can't check equality of MsgDefObjs now, so we check the name
if expected_type.MESSAGE_NAME == EthereumTokenInfo.MESSAGE_NAME:
expected_type_number = EthereumDefinitionType.TOKEN
try:
# first check format version
if r.read_memoryview(len(consts.FORMAT_VERSION)) != consts.FORMAT_VERSION:
raise DataError("Invalid Ethereum definition")
# second check the type of the data
if r.get() != expected_type_number:
raise DataError("Definition type mismatch")
# third check data version
if readers.read_uint32_le(r) < consts.MIN_DATA_VERSION:
raise DataError("Definition is outdated")
# get payload
payload_length = readers.read_uint16_le(r)
payload = r.read_memoryview(payload_length)
# at the end compute Merkle tree root hash using
# provided leaf data (payload with prefix) and proof
hasher = sha256(b"\x00")
hasher.update(memoryview(definition)[: r.offset])
hash = hasher.digest()
proof_length = r.get()
for _ in range(proof_length):
proof_entry = r.read_memoryview(32)
hash_a = min(hash, proof_entry)
hash_b = max(hash, proof_entry)
hasher = sha256(b"\x01")
hasher.update(hash_a)
hasher.update(hash_b)
hash = hasher.digest()
signed_tree_root = r.read_memoryview(64)
if r.remaining_count():
raise DataError("Invalid Ethereum definition")
except EOFError:
raise DataError("Invalid Ethereum definition")
# verify signature
if not ed25519.verify(consts.DEFINITIONS_PUBLIC_KEY, signed_tree_root, hash):
error_msg = DataError("Invalid definition signature")
if __debug__:
# check against dev key
if not ed25519.verify(
consts.DEFINITIONS_DEV_PUBLIC_KEY,
signed_tree_root,
hash,
):
raise error_msg
else:
raise error_msg
# decode it if it's OK
try:
return protobuf.decode(payload, expected_type, True)
except ValueError:
raise DataError("Invalid Ethereum definition")
class Definitions:
"""Class that holds Ethereum definitions - network and tokens.
Prefers built-in definitions over encoded ones.
"""
def __init__(
self, network: EthereumNetworkInfo, tokens: dict[bytes, EthereumTokenInfo]
) -> None:
self.network = network
self._tokens = tokens
@classmethod
def from_encoded(
cls,
encoded_network: bytes | None,
encoded_token: bytes | None,
chain_id: int | None = None,
slip44: int | None = None,
) -> Self:
network = UNKNOWN_NETWORK
tokens: dict[bytes, EthereumTokenInfo] = {}
# if we have a built-in definition, use it
if chain_id is not None:
network = networks.by_chain_id(chain_id)
elif slip44 is not None:
network = networks.by_slip44(slip44)
else:
# ignore encoded definitions if we can't match them to request details
return cls(UNKNOWN_NETWORK, {})
if network is UNKNOWN_NETWORK and encoded_network is not None:
network = decode_definition(encoded_network, EthereumNetworkInfo)
if network is UNKNOWN_NETWORK:
# ignore tokens if we don't have a network
return cls(UNKNOWN_NETWORK, {})
if chain_id is not None and network.chain_id != chain_id:
raise DataError("Network definition mismatch")
if slip44 is not None and network.slip44 != slip44:
raise DataError("Network definition mismatch")
# get token definition
if encoded_token is not None:
token = decode_definition(encoded_token, EthereumTokenInfo)
# Ignore token if it doesn't match the network instead of raising an error.
# This might help us in the future if we allow multiple networks/tokens
# in the same message.
if token.chain_id == network.chain_id:
tokens[token.address] = token
return cls(network, tokens)
def get_token(self, address: bytes) -> EthereumTokenInfo:
# if we have a built-in definition, use it
token = tokens.token_by_chain_address(self.network.chain_id, address)
if token is not None:
return token
if address in self._tokens:
return self._tokens[address]
return tokens.UNKNOWN_TOKEN

@ -0,0 +1,14 @@
# generated from definitions_constants.py.mako
# (by running `make templates` in `core`)
# do not edit manually!
from ubinascii import unhexlify
DEFINITIONS_PUBLIC_KEY = b""
MIN_DATA_VERSION = 1669892465
FORMAT_VERSION = b"trzd1"
if __debug__:
DEFINITIONS_DEV_PUBLIC_KEY = unhexlify(
"db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d"
)

@ -0,0 +1,14 @@
# generated from definitions_constants.py.mako
# (by running `make templates` in `core`)
# do not edit manually!
from ubinascii import unhexlify
DEFINITIONS_PUBLIC_KEY = b""
MIN_DATA_VERSION = ${ethereum_defs_timestamp}
FORMAT_VERSION = b"trzd1"
if __debug__:
DEFINITIONS_DEV_PUBLIC_KEY = unhexlify(
"db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d"
)

@ -7,16 +7,19 @@ if TYPE_CHECKING:
from trezor.wire import Context from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .definitions import Definitions
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def get_address( async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain ctx: Context,
msg: EthereumGetAddress,
keychain: Keychain,
defs: Definitions,
) -> EthereumAddress: ) -> EthereumAddress:
from trezor.messages import EthereumAddress from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address from trezor.ui.layouts import show_address
from apps.common import paths from apps.common import paths
from . import networks
from .helpers import address_from_bytes from .helpers import address_from_bytes
address_n = msg.address_n # local_cache_attribute address_n = msg.address_n # local_cache_attribute
@ -25,11 +28,7 @@ async def get_address(
node = keychain.derive(address_n) node = keychain.derive(address_n)
if len(address_n) > 1: # path has slip44 network identifier address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
else:
network = None
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
if msg.show_display: if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(address_n)) await show_address(ctx, address, path=paths.address_n_to_str(address_n))

@ -1,19 +1,27 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ubinascii import hexlify from ubinascii import hexlify
from . import networks
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumFieldType from trezor.messages import EthereumFieldType
from .networks import NetworkInfo from .networks import EthereumNetworkInfo
RSKIP60_NETWORKS = (30, 31)
def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str: def address_from_bytes(
address_bytes: bytes, network: EthereumNetworkInfo = networks.UNKNOWN_NETWORK
) -> str:
""" """
Converts address in bytes to a checksummed string as defined Converts address in bytes to a checksummed string as defined
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
""" """
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
if network is not None and network.rskip60: if network.chain_id in RSKIP60_NETWORKS:
# rskip60 is a different way to calculate checksum
prefix = str(network.chain_id) + "0x" prefix = str(network.chain_id) + "0x"
else: else:
prefix = "" prefix = ""

@ -1,37 +1,56 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.messages import EthereumNetworkInfo
from apps.common import paths from apps.common import paths
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from . import CURVE, networks from . import CURVE, definitions, networks
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Callable, Iterable, TypeVar from typing import Any, Awaitable, Callable, Iterable, TypeVar
from apps.common.keychain import Keychain
from trezor.wire import Context from trezor.wire import Context
from trezor.messages import ( from trezor.messages import (
EthereumGetAddress, EthereumGetAddress,
EthereumGetPublicKey,
EthereumSignMessage, EthereumSignMessage,
EthereumSignTx, EthereumSignTx,
EthereumSignTxEIP1559, EthereumSignTxEIP1559,
EthereumSignTypedData, EthereumSignTypedData,
) )
from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain from apps.common.keychain import (
MsgOut,
Handler,
)
EthereumMessages = ( # messages for "with_keychain_and_network_from_path" decorator
EthereumGetAddress MsgInAddressN = TypeVar(
| EthereumGetPublicKey "MsgInAddressN",
| EthereumSignTx EthereumGetAddress,
| EthereumSignMessage EthereumSignMessage,
| EthereumSignTypedData EthereumSignTypedData,
) )
MsgIn = TypeVar("MsgIn", bound=EthereumMessages)
EthereumSignTxAny = EthereumSignTx | EthereumSignTxEIP1559 HandlerAddressN = Callable[
MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny) [Context, MsgInAddressN, Keychain, definitions.Definitions],
Awaitable[MsgOut],
]
# messages for "with_keychain_and_defs_from_chain_id" decorator
MsgInSignTx = TypeVar(
"MsgInSignTx",
EthereumSignTx,
EthereumSignTxEIP1559,
)
HandlerChainId = Callable[
[Context, MsgInSignTx, Keychain, definitions.Definitions],
Awaitable[MsgOut],
]
# We believe Ethereum should use 44'/60'/a' for everything, because it is # We believe Ethereum should use 44'/60'/a' for everything, because it is
@ -48,67 +67,83 @@ PATTERNS_ADDRESS = (
) )
def _schemas_from_address_n( def _slip44_from_address_n(address_n: paths.Bip32Path) -> int | None:
patterns: Iterable[str], address_n: paths.Bip32Path HARDENED = paths.HARDENED # local_cache_attribute
) -> Iterable[paths.PathSchema]: if len(address_n) < 2:
# Casa paths (purpose of 45) do not have hardened coin types return None
if address_n[0] == 45 | paths.HARDENED and not address_n[1] & paths.HARDENED:
slip44_hardened = address_n[1] | paths.HARDENED
else:
slip44_hardened = address_n[1]
if slip44_hardened not in networks.all_slip44_ids_hardened(): if address_n[0] == 45 | HARDENED and not address_n[1] & HARDENED:
return () return address_n[1]
if not slip44_hardened & paths.HARDENED: return address_n[1] & ~HARDENED
return ()
slip44_id = slip44_hardened - paths.HARDENED
schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
return [s.copy() for s in schemas]
def _defs_from_message(
msg: Any, chain_id: int | None = None, slip44: int | None = None
) -> definitions.Definitions:
encoded_network = None
encoded_token = None
def with_keychain_from_path( # try to get both from msg.definitions
*patterns: str, if hasattr(msg, "definitions"):
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: if msg.definitions is not None:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: encoded_network = msg.definitions.encoded_network
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut: encoded_token = msg.definitions.encoded_token
schemas = _schemas_from_address_n(patterns, msg.address_n)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain)
return wrapper elif hasattr(msg, "encoded_network"):
encoded_network = msg.encoded_network
return decorator return definitions.Definitions.from_encoded(
encoded_network, encoded_token, chain_id, slip44
)
def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]: def _schemas_from_network(
info = networks.by_chain_id(msg.chain_id) patterns: Iterable[str],
network_info: EthereumNetworkInfo,
) -> Iterable[paths.PathSchema]:
slip44_id: tuple[int, ...] slip44_id: tuple[int, ...]
if info is None: if network_info is networks.UNKNOWN_NETWORK:
# allow Ethereum or testnet paths for unknown networks # allow Ethereum or testnet paths for unknown networks
slip44_id = (60, 1) slip44_id = (60, 1)
elif info.slip44 not in (60, 1): elif network_info.slip44 not in (60, 1):
# allow cross-signing with Ethereum unless it's testnet # allow cross-signing with Ethereum unless it's testnet
slip44_id = (info.slip44, 60) slip44_id = (network_info.slip44, 60)
else: else:
slip44_id = (info.slip44,) slip44_id = (network_info.slip44,)
schemas = [ schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
]
return [s.copy() for s in schemas] return [s.copy() for s in schemas]
def with_keychain_from_path(
*patterns: str,
) -> Callable[[HandlerAddressN[MsgInAddressN, MsgOut]], Handler[MsgInAddressN, MsgOut]]:
def decorator(
func: HandlerAddressN[MsgInAddressN, MsgOut]
) -> Handler[MsgInAddressN, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInAddressN) -> MsgOut:
slip44 = _slip44_from_address_n(msg.address_n)
defs = _defs_from_message(msg, slip44=slip44)
schemas = _schemas_from_network(patterns, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain, defs)
return wrapper
return decorator
def with_keychain_from_chain_id( def with_keychain_from_chain_id(
func: HandlerWithKeychain[MsgInChainId, MsgOut] func: HandlerChainId[MsgInSignTx, MsgOut]
) -> Handler[MsgInChainId, MsgOut]: ) -> Handler[MsgInSignTx, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed # this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut: async def wrapper(ctx: Context, msg: MsgInSignTx) -> MsgOut:
schemas = _schemas_from_chain_id(msg) defs = _defs_from_message(msg, chain_id=msg.chain_id)
schemas = _schemas_from_network(PATTERNS_ADDRESS, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas) keychain = await get_keychain(ctx, CURVE, schemas)
with keychain: with keychain:
return await func(ctx, msg, keychain) return await func(ctx, msg, keychain, defs)
return wrapper return wrapper

@ -11,35 +11,38 @@ from trezor.ui.layouts import (
should_show_more, should_show_more,
) )
from . import networks
from .helpers import decode_typed_data from .helpers import decode_typed_data
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Awaitable, Iterable from typing import Awaitable, Iterable
from trezor.messages import EthereumFieldType, EthereumStructMember from trezor.messages import (
EthereumFieldType,
EthereumNetworkInfo,
EthereumStructMember,
EthereumTokenInfo,
)
from trezor.wire import Context from trezor.wire import Context
from . import tokens
def require_confirm_tx( def require_confirm_tx(
ctx: Context, ctx: Context,
to_bytes: bytes, to_bytes: bytes,
value: int, value: int,
chain_id: int, network: EthereumNetworkInfo,
token: tokens.TokenInfo | None = None, token: EthereumTokenInfo | None,
) -> Awaitable[None]: ) -> Awaitable[None]:
from .helpers import address_from_bytes from .helpers import address_from_bytes
from trezor.ui.layouts import confirm_output from trezor.ui.layouts import confirm_output
if to_bytes: if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id)) to_str = address_from_bytes(to_bytes, network)
else: else:
to_str = "new contract?" to_str = "new contract?"
return confirm_output( return confirm_output(
ctx, ctx,
to_str, to_str,
format_ethereum_amount(value, token, chain_id), format_ethereum_amount(value, token, network),
br_code=ButtonRequestType.SignTx, br_code=ButtonRequestType.SignTx,
) )
@ -49,19 +52,19 @@ async def require_confirm_fee(
spending: int, spending: int,
gas_price: int, gas_price: int,
gas_limit: int, gas_limit: int,
chain_id: int, network: EthereumNetworkInfo,
token: tokens.TokenInfo | None = None, token: EthereumTokenInfo | None,
) -> None: ) -> None:
await confirm_amount( await confirm_amount(
ctx, ctx,
title="Confirm fee", title="Confirm fee",
description="Gas price:", description="Gas price:",
amount=format_ethereum_amount(gas_price, None, chain_id), amount=format_ethereum_amount(gas_price, None, network),
) )
await confirm_total( await confirm_total(
ctx, ctx,
total_amount=format_ethereum_amount(spending, token, chain_id), total_amount=format_ethereum_amount(spending, token, network),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, chain_id), fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
total_label="Amount sent:", total_label="Amount sent:",
fee_label="Maximum fee:", fee_label="Maximum fee:",
) )
@ -73,25 +76,25 @@ async def require_confirm_eip1559_fee(
max_priority_fee: int, max_priority_fee: int,
max_gas_fee: int, max_gas_fee: int,
gas_limit: int, gas_limit: int,
chain_id: int, network: EthereumNetworkInfo,
token: tokens.TokenInfo | None = None, token: EthereumTokenInfo | None,
) -> None: ) -> None:
await confirm_amount( await confirm_amount(
ctx, ctx,
"Confirm fee", "Confirm fee",
format_ethereum_amount(max_gas_fee, None, chain_id), format_ethereum_amount(max_gas_fee, None, network),
"Maximum fee per gas", "Maximum fee per gas",
) )
await confirm_amount( await confirm_amount(
ctx, ctx,
"Confirm fee", "Confirm fee",
format_ethereum_amount(max_priority_fee, None, chain_id), format_ethereum_amount(max_priority_fee, None, network),
"Priority fee per gas", "Priority fee per gas",
) )
await confirm_total( await confirm_total(
ctx, ctx,
format_ethereum_amount(spending, token, chain_id), format_ethereum_amount(spending, token, network),
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id), format_ethereum_amount(max_gas_fee * gas_limit, None, network),
total_label="Amount sent:", total_label="Amount sent:",
fee_label="Maximum fee:", fee_label="Maximum fee:",
) )
@ -262,7 +265,9 @@ async def confirm_typed_value(
def format_ethereum_amount( def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int value: int,
token: EthereumTokenInfo | None,
network: EthereumNetworkInfo,
) -> str: ) -> str:
from trezor.strings import format_amount from trezor.strings import format_amount
@ -270,7 +275,7 @@ def format_ethereum_amount(
suffix = token.symbol suffix = token.symbol
decimals = token.decimals decimals = token.decimals
else: else:
suffix = networks.shortcut_by_chain_id(chain_id) suffix = network.symbol
decimals = 18 decimals = 18
# Don't want to display wei values for tokens with small decimal numbers # Don't want to display wei values for tokens with small decimal numbers

File diff suppressed because it is too large Load Diff

@ -7,7 +7,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from apps.common.paths import HARDENED from trezor.messages import EthereumNetworkInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Iterator from typing import Iterator
@ -17,71 +17,52 @@ if TYPE_CHECKING:
NetworkInfoTuple = tuple[ NetworkInfoTuple = tuple[
int, # chain_id int, # chain_id
int, # slip44 int, # slip44
str, # shortcut str, # symbol
str, # name str, # name
bool # rskip60
] ]
# fmt: on # fmt: on
UNKNOWN_NETWORK = EthereumNetworkInfo(
chain_id=0,
slip44=0,
symbol="UNKN",
name="Unknown network",
)
def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id)
return n.shortcut if n is not None else "UNKN"
def by_chain_id(chain_id: int) -> EthereumNetworkInfo:
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
for n in _networks_iterator(): for n in _networks_iterator():
n_chain_id = n[0] n_chain_id = n[0]
if n_chain_id == chain_id: if n_chain_id == chain_id:
return NetworkInfo( return EthereumNetworkInfo(
chain_id=n[0], chain_id=n[0],
slip44=n[1], slip44=n[1],
shortcut=n[2], symbol=n[2],
name=n[3], name=n[3],
rskip60=n[4],
) )
return None return UNKNOWN_NETWORK
def by_slip44(slip44: int) -> "NetworkInfo" | None: def by_slip44(slip44: int) -> EthereumNetworkInfo:
for n in _networks_iterator(): for n in _networks_iterator():
n_slip44 = n[1] n_slip44 = n[1]
if n_slip44 == slip44: if n_slip44 == slip44:
return NetworkInfo( return EthereumNetworkInfo(
chain_id=n[0], chain_id=n[0],
slip44=n[1], slip44=n[1],
shortcut=n[2], symbol=n[2],
name=n[3], name=n[3],
rskip60=n[4],
) )
return None return UNKNOWN_NETWORK
def all_slip44_ids_hardened() -> Iterator[int]:
for n in _networks_iterator():
# n_slip_44 is the second element
yield n[1] | HARDENED
class NetworkInfo:
def __init__(
self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool
) -> None:
self.chain_id = chain_id
self.slip44 = slip44
self.shortcut = shortcut
self.name = name
self.rskip60 = rskip60
# fmt: off # fmt: off
def _networks_iterator() -> Iterator[NetworkInfoTuple]: def _networks_iterator() -> Iterator[NetworkInfoTuple]:
% for n in supported_on("trezor2", eth): % for n in sorted(supported_on("trezor2", eth), key=lambda network: (int(network.chain_id), network.name)):
yield ( yield (
${n.chain_id}, # chain_id ${n.chain_id}, # chain_id
${n.slip44}, # slip44 ${n.slip44}, # slip44
"${n.shortcut}", # shortcut "${n.shortcut}", # symbol
"${n.name}", # name "${n.name}", # name
${n.rskip60}, # rskip60
) )
% endfor % endfor

@ -3,10 +3,14 @@ from typing import TYPE_CHECKING
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.messages import EthereumSignMessage, EthereumMessageSignature from trezor.messages import (
EthereumSignMessage,
EthereumMessageSignature,
)
from trezor.wire import Context from trezor.wire import Context
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from .definitions import Definitions
def message_digest(message: bytes) -> bytes: def message_digest(message: bytes) -> bytes:
@ -23,7 +27,10 @@ def message_digest(message: bytes) -> bytes:
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_message( async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain ctx: Context,
msg: EthereumSignMessage,
keychain: Keychain,
defs: Definitions,
) -> EthereumMessageSignature: ) -> EthereumMessageSignature:
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from trezor.messages import EthereumMessageSignature from trezor.messages import EthereumMessageSignature
@ -37,7 +44,7 @@ async def sign_message(
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
address = address_from_bytes(node.ethereum_pubkeyhash()) address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
await confirm_signverify( await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=False ctx, "ETH", decode_message(msg.message), address, verify=False
) )

@ -9,11 +9,11 @@ from .keychain import with_keychain_from_chain_id
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx, EthereumTxAck from trezor.messages import EthereumSignTx, EthereumTxAck, EthereumTokenInfo
from trezor.wire import Context from trezor.wire import Context
from .keychain import EthereumSignTxAny from .definitions import Definitions
from . import tokens from .keychain import MsgInSignTx
# Maximum chain_id which returns the full signature_v (which must fit into an uint32). # Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -24,7 +24,10 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def sign_tx( async def sign_tx(
ctx: Context, msg: EthereumSignTx, keychain: Keychain ctx: Context,
msg: EthereumSignTx,
keychain: Keychain,
defs: Definitions,
) -> EthereumTxRequest: ) -> EthereumTxRequest:
from trezor.utils import HashWriter from trezor.utils import HashWriter
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
@ -45,11 +48,11 @@ async def sign_tx(
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg) token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs)
data_total = msg.data_length data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token) await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -58,7 +61,7 @@ async def sign_tx(
value, value,
int.from_bytes(msg.gas_price, "big"), int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"), int.from_bytes(msg.gas_limit, "big"),
msg.chain_id, defs.network,
token, token,
) )
@ -100,13 +103,14 @@ async def sign_tx(
async def handle_erc20( async def handle_erc20(
ctx: Context, msg: EthereumSignTxAny ctx: Context,
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]: msg: MsgInSignTx,
definitions: Definitions,
) -> tuple[EthereumTokenInfo | None, bytes, bytes, int]:
from .layout import require_confirm_unknown_token from .layout import require_confirm_unknown_token
from . import tokens from . import tokens
data_initial_chunk = msg.data_initial_chunk # local_cache_attribute data_initial_chunk = msg.data_initial_chunk # local_cache_attribute
token = None token = None
address_bytes = recipient = bytes_from_address(msg.to) address_bytes = recipient = bytes_from_address(msg.to)
value = int.from_bytes(msg.value, "big") value = int.from_bytes(msg.value, "big")
@ -118,7 +122,7 @@ async def handle_erc20(
and data_initial_chunk[:16] and data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" == b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
): ):
token = tokens.token_by_chain_address(msg.chain_id, address_bytes) token = definitions.get_token(address_bytes)
recipient = data_initial_chunk[16:36] recipient = data_initial_chunk[16:36]
value = int.from_bytes(data_initial_chunk[36:68], "big") value = int.from_bytes(data_initial_chunk[36:68], "big")
@ -185,7 +189,7 @@ def _sign_digest(
return req return req
def check_common_fields(msg: EthereumSignTxAny) -> None: def check_common_fields(msg: MsgInSignTx) -> None:
data_length = msg.data_length # local_cache_attribute data_length = msg.data_length # local_cache_attribute
if data_length > 0: if data_length > 0:

@ -12,9 +12,9 @@ if TYPE_CHECKING:
EthereumAccessList, EthereumAccessList,
EthereumTxRequest, EthereumTxRequest,
) )
from apps.common.keychain import Keychain
from trezor.wire import Context from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
_TX_TYPE = const(2) _TX_TYPE = const(2)
@ -30,7 +30,10 @@ def access_list_item_length(item: EthereumAccessList) -> int:
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def sign_tx_eip1559( async def sign_tx_eip1559(
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain ctx: Context,
msg: EthereumSignTxEIP1559,
keychain: Keychain,
defs: Definitions,
) -> EthereumTxRequest: ) -> EthereumTxRequest:
from trezor.crypto.hashlib import sha3_256 from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter from trezor.utils import HashWriter
@ -56,11 +59,11 @@ async def sign_tx_eip1559(
await paths.validate_path(ctx, keychain, msg.address_n) await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s # Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg) token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs)
data_total = msg.data_length data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token) await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0: if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total) await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -70,7 +73,7 @@ async def sign_tx_eip1559(
int.from_bytes(msg.max_priority_fee, "big"), int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"), int.from_bytes(msg.max_gas_fee, "big"),
int.from_bytes(gas_limit, "big"), int.from_bytes(gas_limit, "big"),
msg.chain_id, defs.network,
token, token,
) )

@ -11,6 +11,7 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain from apps.common.keychain import Keychain
from trezor.wire import Context from trezor.wire import Context
from trezor.utils import HashWriter from trezor.utils import HashWriter
from .definitions import Definitions
from trezor.messages import ( from trezor.messages import (
EthereumSignTypedData, EthereumSignTypedData,
@ -22,7 +23,10 @@ if TYPE_CHECKING:
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_typed_data( async def sign_typed_data(
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain ctx: Context,
msg: EthereumSignTypedData,
keychain: Keychain,
defs: Definitions,
) -> EthereumTypedDataSignature: ) -> EthereumTypedDataSignature:
from trezor.crypto.curve import secp256k1 from trezor.crypto.curve import secp256k1
from apps.common import paths from apps.common import paths
@ -47,7 +51,7 @@ async def sign_typed_data(
) )
return EthereumTypedDataSignature( return EthereumTypedDataSignature(
address=address_from_bytes(address_bytes), address=address_from_bytes(address_bytes, defs.network),
signature=signature[1:] + signature[0:1], signature=signature[1:] + signature[0:1],
) )

File diff suppressed because it is too large Load Diff

@ -16,6 +16,7 @@
from typing import Iterator from typing import Iterator
from trezor.messages import EthereumTokenInfo
<% <%
from collections import defaultdict from collections import defaultdict
@ -26,30 +27,37 @@ def group_tokens(tokens):
return r return r
%>\ %>\
class TokenInfo: UNKNOWN_TOKEN = EthereumTokenInfo(
def __init__(self, symbol: str, decimals: int) -> None: symbol="Wei UNKN",
self.symbol = symbol decimals=0,
self.decimals = decimals address=b"",
chain_id=0,
name="Unknown token",
)
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0) def token_by_chain_address(chain_id: int, address: bytes) -> EthereumTokenInfo | None:
for addr, symbol, decimal, name in _token_iterator(chain_id):
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
for addr, symbol, decimal in _token_iterator(chain_id):
if address == addr: if address == addr:
return TokenInfo(symbol, decimal) return EthereumTokenInfo(
return UNKNOWN_TOKEN symbol=symbol,
decimals=decimal,
address=address,
chain_id=chain_id,
name=name,
)
return None
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]: def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int, str]]:
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items(): % for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
if chain_id == ${token_chain_id}: if chain_id == ${token_chain_id}: # ${tokens[0].chain}
% for t in tokens: % for t in tokens:
yield ( # address, symbol, decimals yield ( # address, symbol, decimals, name
${black_repr(t.address_bytes)}, ${black_repr(t.address_bytes)},
${black_repr(t.symbol)}, ${black_repr(t.symbol)},
${t.decimals}, ${t.decimals},
${black_repr(t.name.strip())},
) )
% endfor % endfor
% endfor % endfor

@ -0,0 +1,6 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
NETWORK = 0
TOKEN = 1

@ -455,6 +455,10 @@ if TYPE_CHECKING:
YES = 1 YES = 1
INFO = 2 INFO = 2
class EthereumDefinitionType(IntEnum):
NETWORK = 0
TOKEN = 1
class EthereumDataType(IntEnum): class EthereumDataType(IntEnum):
UINT = 1 UINT = 1
INT = 2 INT = 2

@ -38,6 +38,7 @@ if TYPE_CHECKING:
from trezor.enums import DebugSwipeDirection # noqa: F401 from trezor.enums import DebugSwipeDirection # noqa: F401
from trezor.enums import DecredStakingSpendType # noqa: F401 from trezor.enums import DecredStakingSpendType # noqa: F401
from trezor.enums import EthereumDataType # noqa: F401 from trezor.enums import EthereumDataType # noqa: F401
from trezor.enums import EthereumDefinitionType # noqa: F401
from trezor.enums import FailureType # noqa: F401 from trezor.enums import FailureType # noqa: F401
from trezor.enums import HomescreenFormat # noqa: F401 from trezor.enums import HomescreenFormat # noqa: F401
from trezor.enums import InputScriptType # noqa: F401 from trezor.enums import InputScriptType # noqa: F401
@ -3372,10 +3373,69 @@ if TYPE_CHECKING:
def is_type_of(cls, msg: Any) -> TypeGuard["EosActionUnknown"]: def is_type_of(cls, msg: Any) -> TypeGuard["EosActionUnknown"]:
return isinstance(msg, cls) return isinstance(msg, cls)
class EthereumNetworkInfo(protobuf.MessageType):
chain_id: "int"
symbol: "str"
slip44: "int"
name: "str"
def __init__(
self,
*,
chain_id: "int",
symbol: "str",
slip44: "int",
name: "str",
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["EthereumNetworkInfo"]:
return isinstance(msg, cls)
class EthereumTokenInfo(protobuf.MessageType):
address: "bytes"
chain_id: "int"
symbol: "str"
decimals: "int"
name: "str"
def __init__(
self,
*,
address: "bytes",
chain_id: "int",
symbol: "str",
decimals: "int",
name: "str",
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["EthereumTokenInfo"]:
return isinstance(msg, cls)
class EthereumDefinitions(protobuf.MessageType):
encoded_network: "bytes | None"
encoded_token: "bytes | None"
def __init__(
self,
*,
encoded_network: "bytes | None" = None,
encoded_token: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["EthereumDefinitions"]:
return isinstance(msg, cls)
class EthereumSignTypedData(protobuf.MessageType): class EthereumSignTypedData(protobuf.MessageType):
address_n: "list[int]" address_n: "list[int]"
primary_type: "str" primary_type: "str"
metamask_v4_compat: "bool" metamask_v4_compat: "bool"
definitions: "EthereumDefinitions | None"
def __init__( def __init__(
self, self,
@ -3383,6 +3443,7 @@ if TYPE_CHECKING:
primary_type: "str", primary_type: "str",
address_n: "list[int] | None" = None, address_n: "list[int] | None" = None,
metamask_v4_compat: "bool | None" = None, metamask_v4_compat: "bool | None" = None,
definitions: "EthereumDefinitions | None" = None,
) -> None: ) -> None:
pass pass
@ -3517,12 +3578,14 @@ if TYPE_CHECKING:
class EthereumGetAddress(protobuf.MessageType): class EthereumGetAddress(protobuf.MessageType):
address_n: "list[int]" address_n: "list[int]"
show_display: "bool | None" show_display: "bool | None"
encoded_network: "bytes | None"
def __init__( def __init__(
self, self,
*, *,
address_n: "list[int] | None" = None, address_n: "list[int] | None" = None,
show_display: "bool | None" = None, show_display: "bool | None" = None,
encoded_network: "bytes | None" = None,
) -> None: ) -> None:
pass pass
@ -3555,6 +3618,7 @@ if TYPE_CHECKING:
data_length: "int" data_length: "int"
chain_id: "int" chain_id: "int"
tx_type: "int | None" tx_type: "int | None"
definitions: "EthereumDefinitions | None"
def __init__( def __init__(
self, self,
@ -3569,6 +3633,7 @@ if TYPE_CHECKING:
data_initial_chunk: "bytes | None" = None, data_initial_chunk: "bytes | None" = None,
data_length: "int | None" = None, data_length: "int | None" = None,
tx_type: "int | None" = None, tx_type: "int | None" = None,
definitions: "EthereumDefinitions | None" = None,
) -> None: ) -> None:
pass pass
@ -3588,6 +3653,7 @@ if TYPE_CHECKING:
data_length: "int" data_length: "int"
chain_id: "int" chain_id: "int"
access_list: "list[EthereumAccessList]" access_list: "list[EthereumAccessList]"
definitions: "EthereumDefinitions | None"
def __init__( def __init__(
self, self,
@ -3603,6 +3669,7 @@ if TYPE_CHECKING:
access_list: "list[EthereumAccessList] | None" = None, access_list: "list[EthereumAccessList] | None" = None,
to: "str | None" = None, to: "str | None" = None,
data_initial_chunk: "bytes | None" = None, data_initial_chunk: "bytes | None" = None,
definitions: "EthereumDefinitions | None" = None,
) -> None: ) -> None:
pass pass
@ -3647,12 +3714,14 @@ if TYPE_CHECKING:
class EthereumSignMessage(protobuf.MessageType): class EthereumSignMessage(protobuf.MessageType):
address_n: "list[int]" address_n: "list[int]"
message: "bytes" message: "bytes"
encoded_network: "bytes | None"
def __init__( def __init__(
self, self,
*, *,
message: "bytes", message: "bytes",
address_n: "list[int] | None" = None, address_n: "list[int] | None" = None,
encoded_network: "bytes | None" = None,
) -> None: ) -> None:
pass pass
@ -3698,6 +3767,7 @@ if TYPE_CHECKING:
address_n: "list[int]" address_n: "list[int]"
domain_separator_hash: "bytes" domain_separator_hash: "bytes"
message_hash: "bytes | None" message_hash: "bytes | None"
encoded_network: "bytes | None"
def __init__( def __init__(
self, self,
@ -3705,6 +3775,7 @@ if TYPE_CHECKING:
domain_separator_hash: "bytes", domain_separator_hash: "bytes",
address_n: "list[int] | None" = None, address_n: "list[int] | None" = None,
message_hash: "bytes | None" = None, message_hash: "bytes | None" = None,
encoded_network: "bytes | None" = None,
) -> None: ) -> None:
pass pass

@ -5,15 +5,25 @@ sys.path.append("../src")
from ubinascii import hexlify, unhexlify # noqa: F401 from ubinascii import hexlify, unhexlify # noqa: F401
import unittest # noqa: F401 import unittest # noqa: F401
from typing import Any, Awaitable
from trezor import utils # noqa: F401 from trezor import utils # noqa: F401
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
def H_(x: int) -> int: def H_(x: int) -> int:
"""
Shortcut function that "hardens" a number in a BIP44 path.
"""
return x | HARDENED return x | HARDENED
def UH_(x: int) -> int:
"""
Shortcut function that "un-hardens" a number in a BIP44 path.
"""
return x & ~(HARDENED)
def await_result(task: Awaitable) -> Any: def await_result(task: Awaitable) -> Any:
value = None value = None
while True: while True:

@ -0,0 +1,102 @@
from ubinascii import unhexlify # noqa: F401
from trezor import messages, protobuf
from trezor.enums import EthereumDefinitionType
from trezor.crypto.curve import ed25519
from trezor.crypto.hashlib import sha256
DEFINITIONS_DEV_PRIVATE_KEY = unhexlify(
"4141414141414141414141414141414141414141414141414141414141414141"
)
def make_network(
chain_id: int = 0,
slip44: int = 0,
symbol: str = "FAKE",
name: str = "Fake network",
) -> messages.EthereumNetworkInfo:
return messages.EthereumNetworkInfo(
chain_id=chain_id,
slip44=slip44,
symbol=symbol,
name=name,
)
def make_token(
symbol: str = "FAKE",
decimals: int = 18,
address: bytes = b"",
chain_id: int = 0,
name: str = "Fake token",
) -> messages.EthereumTokenInfo:
return messages.EthereumTokenInfo(
symbol=symbol,
decimals=decimals,
address=address,
chain_id=chain_id,
name=name,
)
def make_payload(
prefix: bytes = b"trzd1",
data_type: EthereumDefinitionType = EthereumDefinitionType.NETWORK,
timestamp: int = 0xFFFF_FFFF,
message: messages.EthereumNetworkInfo
| messages.EthereumTokenInfo
| bytes = make_network(),
) -> bytes:
payload = prefix
payload += data_type.to_bytes(1, "little")
payload += timestamp.to_bytes(4, "little")
if isinstance(message, bytes):
message_bytes = message
else:
message_bytes = protobuf.dump_message_buffer(message)
payload += len(message_bytes).to_bytes(2, "little")
payload += message_bytes
return payload
def sign_payload(payload: bytes, merkle_neighbors: list[bytes]) -> tuple[bytes, bytes]:
digest = sha256(b"\x00" + payload).digest()
merkle_proof = []
for item in merkle_neighbors:
left, right = min(digest, item), max(digest, item)
digest = sha256(b"\x01" + left + right).digest()
merkle_proof.append(digest)
merkle_proof = len(merkle_proof).to_bytes(1, "little") + b"".join(merkle_proof)
signature = ed25519.sign(DEFINITIONS_DEV_PRIVATE_KEY, digest)
return merkle_proof, signature
def encode_network(
network: messages.EthereumNetworkInfo | None = None,
chain_id: int = 0,
slip44: int = 0,
symbol: str = "FAKE",
name: str = "Fake network",
) -> bytes:
if network is None:
network = make_network(chain_id, slip44, symbol, name)
payload = make_payload(data_type=EthereumDefinitionType.NETWORK, message=network)
proof, signature = sign_payload(payload, [])
return payload + proof + signature
def encode_token(
token: messages.EthereumTokenInfo | None = None,
symbol: str = "FAKE",
decimals: int = 18,
address: bytes = b"",
chain_id: int = 0,
name: str = "Fake token",
) -> bytes:
if token is None:
token = make_token(symbol, decimals, address, chain_id, name)
payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=token)
proof, signature = sign_payload(payload, [])
return payload + proof + signature

@ -0,0 +1,250 @@
from common import *
import unittest
import typing as t
from trezor import utils, wire
from ubinascii import hexlify # noqa: F401
if not utils.BITCOIN_ONLY:
from apps.ethereum import networks, tokens
from apps.ethereum.definitions import decode_definition, Definitions
from ethereum_common import *
from trezor import protobuf
from trezor.enums import EthereumDefinitionType
from trezor.messages import (
EthereumDefinitions,
EthereumNetworkInfo,
EthereumTokenInfo,
EthereumSignTx,
EthereumSignTxEIP1559,
EthereumSignTypedData,
)
TETHER_ADDRESS = b"\xda\xc1\x7f\x95\x8d\x2e\xe5\x23\xa2\x20\x62\x06\x99\x45\x97\xc1\x3d\x83\x1e\xc7"
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestDecodeDefinition(unittest.TestCase):
def test_short_message(self):
with self.assertRaises(wire.DataError):
decode_definition(b"\x00", EthereumNetworkInfo)
with self.assertRaises(wire.DataError):
decode_definition(b"\x00", EthereumTokenInfo)
# successful decode network
def test_network_definition(self):
network = make_network(chain_id=42, slip44=69, symbol="FAKE", name="Fakenet")
encoded = encode_network(network)
try:
self.assertEqual(decode_definition(encoded, EthereumNetworkInfo), network)
except Exception as e:
print(e.message)
# successful decode token
def test_token_definition(self):
token = make_token("FAKE", decimals=33, address=b"abcd" * 5, chain_id=42)
encoded = encode_token(token)
self.assertEqual(decode_definition(encoded, EthereumTokenInfo), token)
def assertFailed(self, data: bytes) -> None:
with self.assertRaises(wire.DataError):
decode_definition(data, EthereumNetworkInfo)
def test_mangled_signature(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_signature = signature[:-1] + b"\xff"
self.assertFailed(payload + proof + bad_signature)
def test_missing_signature(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof)
def test_mangled_payload(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_payload = payload[:-1] + b"\xff"
self.assertFailed(bad_payload + proof + signature)
def test_proof_length_mismatch(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_proof = b"\x01"
self.assertFailed(payload + bad_proof + signature)
def test_bad_proof(self):
payload = make_payload()
proof, signature = sign_payload(payload, [sha256(b"x").digest()])
bad_proof = proof[:-1] + b"\xff"
self.assertFailed(payload + bad_proof + signature)
def test_trimmed_proof(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_proof = proof[:-1]
self.assertFailed(payload + bad_proof + signature)
def test_bad_prefix(self):
payload = make_payload(prefix=b"trzd2")
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_bad_type(self):
payload = make_payload(
data_type=EthereumDefinitionType.TOKEN, message=make_token()
)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_outdated(self):
payload = make_payload(timestamp=0)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_malformed_protobuf(self):
payload = make_payload(message=b"\x00")
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_protobuf_mismatch(self):
payload = make_payload(
data_type=EthereumDefinitionType.NETWORK, message=make_token()
)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
payload = make_payload(
data_type=EthereumDefinitionType.TOKEN, message=make_network()
)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_trailing_garbage(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature + b"\x00")
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumDefinitions(unittest.TestCase):
def assertUnknown(self, what: t.Any) -> None:
if what is networks.UNKNOWN_NETWORK:
return
if what is tokens.UNKNOWN_TOKEN:
return
self.fail("Expected UNKNOWN_*, got %r" % what)
def assertKnown(self, what: t.Any) -> None:
if not EthereumNetworkInfo.is_type_of(
what
) and not EthereumTokenInfo.is_type_of(what):
self.fail("Expected network / token info, got %r" % what)
if what is networks.UNKNOWN_NETWORK:
self.fail("Expected known network, got UNKNOWN_NETWORK")
if what is tokens.UNKNOWN_TOKEN:
self.fail("Expected known token, got UNKNOWN_TOKEN")
def test_empty(self) -> None:
# no slip44 nor chain_id -- should short-circuit and always be unknown
defs = Definitions.from_encoded(None, None)
self.assertUnknown(defs.network)
self.assertFalse(defs._tokens)
self.assertUnknown(defs.get_token(TETHER_ADDRESS))
# chain_id provided, no definition
defs = Definitions.from_encoded(None, None, chain_id=100_000)
self.assertUnknown(defs.network)
self.assertFalse(defs._tokens)
self.assertUnknown(defs.get_token(TETHER_ADDRESS))
def test_builtin(self) -> None:
defs = Definitions.from_encoded(None, None, chain_id=1)
self.assertKnown(defs.network)
self.assertFalse(defs._tokens)
self.assertKnown(defs.get_token(TETHER_ADDRESS))
self.assertUnknown(defs.get_token(b"\x00" * 20))
defs = Definitions.from_encoded(None, None, slip44=60)
self.assertKnown(defs.network)
self.assertFalse(defs._tokens)
self.assertKnown(defs.get_token(TETHER_ADDRESS))
self.assertUnknown(defs.get_token(b"\x00" * 20))
def test_external(self) -> None:
network = make_network(chain_id=42)
defs = Definitions.from_encoded(encode_network(network), None, chain_id=42)
self.assertEqual(defs.network, network)
self.assertUnknown(defs.get_token(b"\x00" * 20))
token = make_token(chain_id=42, address=b"\x00" * 20)
defs = Definitions.from_encoded(
encode_network(network), encode_token(token), chain_id=42
)
self.assertEqual(defs.network, network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
token = make_token(chain_id=1, address=b"\x00" * 20)
defs = Definitions.from_encoded(None, encode_token(token), chain_id=1)
self.assertKnown(defs.network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
def test_external_token_mismatch(self) -> None:
network = make_network(chain_id=42)
token = make_token(chain_id=43, address=b"\x00" * 20)
defs = Definitions.from_encoded(encode_network(network), encode_token(token))
self.assertUnknown(defs.get_token(b"\x00" * 20))
def test_external_chain_match(self) -> None:
network = make_network(chain_id=42)
token = make_token(chain_id=42, address=b"\x00" * 20)
defs = Definitions.from_encoded(
encode_network(network), encode_token(token), chain_id=42
)
self.assertEqual(defs.network, network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
with self.assertRaises(wire.DataError):
Definitions.from_encoded(
encode_network(network), encode_token(token), chain_id=333
)
def test_external_slip44_mismatch(self) -> None:
network = make_network(chain_id=42, slip44=1999)
token = make_token(chain_id=42, address=b"\x00" * 20)
defs = Definitions.from_encoded(
encode_network(network), encode_token(token), slip44=1999
)
self.assertEqual(defs.network, network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
with self.assertRaises(wire.DataError):
Definitions.from_encoded(
encode_network(network), encode_token(token), slip44=333
)
def test_ignore_encoded_network(self) -> None:
# when network is builtin, ignore the encoded one
network = encode_network(chain_id=1, symbol="BAD")
defs = Definitions.from_encoded(network, None, chain_id=1)
self.assertNotEqual(defs.network, network)
def test_ignore_encoded_token(self) -> None:
# when token is builtin, ignore the encoded one
token = encode_token(chain_id=1, address=TETHER_ADDRESS, symbol="BAD")
defs = Definitions.from_encoded(None, token, chain_id=1)
self.assertNotEqual(defs.get_token(TETHER_ADDRESS), token)
def test_ignore_with_no_match(self) -> None:
network = encode_network(chain_id=100_000, symbol="BAD")
# smoke test: definition is accepted
defs = Definitions.from_encoded(network, None, chain_id=100_000)
self.assertKnown(defs.network)
# same definition but nothing to match it to
defs = Definitions.from_encoded(network, None)
self.assertUnknown(defs.network)
if __name__ == "__main__":
unittest.main()

@ -3,23 +3,22 @@ from apps.common.paths import HARDENED
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from apps.ethereum.helpers import address_from_bytes from apps.ethereum.helpers import address_from_bytes
from apps.ethereum.networks import NetworkInfo from ethereum_common import make_network
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumGetAddress(unittest.TestCase): class TestEthereumGetAddress(unittest.TestCase):
def test_address_from_bytes_eip55(self): def test_address_from_bytes_eip55(self):
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md # https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
eip55 = [ eip55 = [
'0x52908400098527886E0F7030069857D2E4169EE7', "0x52908400098527886E0F7030069857D2E4169EE7",
'0x8617E340B3D01FA5F11F306F4090FD50E238070D', "0x8617E340B3D01FA5F11F306F4090FD50E238070D",
'0xde709f2102306220921060314715629080e2fb77', "0xde709f2102306220921060314715629080e2fb77",
'0x27b1fdb04752bbc536007a920d24acb045561c26', "0x27b1fdb04752bbc536007a920d24acb045561c26",
'0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed', "0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed",
'0xfB6916095ca1df60bB79Ce92cE3Ea74c37c5d359', "0xfB6916095ca1df60bB79Ce92cE3Ea74c37c5d359",
'0xdbF03B407c01E7cD3CBea99509d93f8DDDC8C6FB', "0xdbF03B407c01E7cD3CBea99509d93f8DDDC8C6FB",
'0xD1220A0cf47c7B9Be7A2E6BA89F429762e7b9aDb', "0xD1220A0cf47c7B9Be7A2E6BA89F429762e7b9aDb",
] ]
for s in eip55: for s in eip55:
b = unhexlify(s[2:]) b = unhexlify(s[2:])
@ -29,28 +28,30 @@ class TestEthereumGetAddress(unittest.TestCase):
def test_address_from_bytes_rskip60(self): def test_address_from_bytes_rskip60(self):
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md # https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
rskip60_chain_30 = [ rskip60_chain_30 = [
'0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD', "0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD",
'0xFb6916095cA1Df60bb79ce92cE3EA74c37c5d359', "0xFb6916095cA1Df60bb79ce92cE3EA74c37c5d359",
'0xDBF03B407c01E7CD3cBea99509D93F8Dddc8C6FB', "0xDBF03B407c01E7CD3cBea99509D93F8Dddc8C6FB",
'0xD1220A0Cf47c7B9BE7a2e6ba89F429762E7B9adB' "0xD1220A0Cf47c7B9BE7a2e6ba89F429762E7B9adB",
] ]
rskip60_chain_31 = [ rskip60_chain_31 = [
'0x5aAeb6053F3e94c9b9A09F33669435E7EF1BEaEd', "0x5aAeb6053F3e94c9b9A09F33669435E7EF1BEaEd",
'0xFb6916095CA1dF60bb79CE92ce3Ea74C37c5D359', "0xFb6916095CA1dF60bb79CE92ce3Ea74C37c5D359",
'0xdbF03B407C01E7cd3cbEa99509D93f8dDDc8C6fB', "0xdbF03B407C01E7cd3cbEa99509D93f8dDDc8C6fB",
'0xd1220a0CF47c7B9Be7A2E6Ba89f429762E7b9adB' "0xd1220a0CF47c7B9Be7A2E6Ba89f429762E7b9adB",
] ]
n = NetworkInfo(chain_id=30, slip44=1, shortcut='T', name='T', rskip60=True)
n = make_network(chain_id=30)
for s in rskip60_chain_30: for s in rskip60_chain_30:
b = unhexlify(s[2:]) b = unhexlify(s[2:])
h = address_from_bytes(b, n) h = address_from_bytes(b, n)
self.assertEqual(h, s) self.assertEqual(h, s)
n.chain_id = 31
n = make_network(chain_id=31)
for s in rskip60_chain_31: for s in rskip60_chain_31:
b = unhexlify(s[2:]) b = unhexlify(s[2:])
h = address_from_bytes(b, n) h = address_from_bytes(b, n)
self.assertEqual(h, s) self.assertEqual(h, s)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

@ -1,22 +1,35 @@
from common import * from common import *
import unittest
from storage import cache from storage import cache
from trezor import wire from trezor import wire, utils
from trezor.crypto import bip39 from trezor.crypto import bip39
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from apps.ethereum import CURVE from apps.ethereum import CURVE
from apps.ethereum.networks import UNKNOWN_NETWORK
from apps.ethereum.keychain import ( from apps.ethereum.keychain import (
PATTERNS_ADDRESS, PATTERNS_ADDRESS,
_schemas_from_address_n, _schemas_from_network,
_defs_from_message,
_slip44_from_address_n,
with_keychain_from_path, with_keychain_from_path,
with_keychain_from_chain_id, with_keychain_from_chain_id,
) )
from apps.ethereum.networks import by_chain_id, by_slip44
from trezor.messages import EthereumGetAddress from trezor.messages import (
from trezor.messages import EthereumSignTx EthereumGetAddress,
EthereumSignTx,
EthereumDefinitions,
EthereumSignMessage,
EthereumSignTypedData,
EthereumSignTxEIP1559,
)
from ethereum_common import make_network, encode_network
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
@ -43,11 +56,19 @@ class TestEthereumKeychain(unittest.TestCase):
[44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0], [44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0],
[44 | HARDENED, slip44_id | HARDENED, 1 | HARDENED, 0], [44 | HARDENED, slip44_id | HARDENED, 1 | HARDENED, 0],
[44 | HARDENED, slip44_id | HARDENED, 0 | HARDENED, 0 | HARDENED, 0], [44 | HARDENED, slip44_id | HARDENED, 0 | HARDENED, 0 | HARDENED, 0],
[44 | HARDENED, slip44_id | HARDENED, 0 | HARDENED, 0 | HARDENED, 0 | HARDENED], [
44 | HARDENED,
slip44_id | HARDENED,
0 | HARDENED,
0 | HARDENED,
0 | HARDENED,
],
) )
for addr in invalid_addresses: for addr in invalid_addresses:
self.assertRaises( self.assertRaises(
wire.DataError, keychain.derive, addr, wire.DataError,
keychain.derive,
addr,
) )
def setUp(self): def setUp(self):
@ -56,7 +77,9 @@ class TestEthereumKeychain(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed) cache.set(cache.APP_COMMON_SEED, seed)
def from_address_n(self, address_n): def from_address_n(self, address_n):
schemas = _schemas_from_address_n(PATTERNS_ADDRESS, address_n) slip44 = _slip44_from_address_n(address_n)
network = make_network(slip44=slip44)
schemas = _schemas_from_network(PATTERNS_ADDRESS, network)
return await_result(get_keychain(wire.DUMMY_CONTEXT, CURVE, schemas)) return await_result(get_keychain(wire.DUMMY_CONTEXT, CURVE, schemas))
def test_from_address_n(self): def test_from_address_n(self):
@ -69,116 +92,140 @@ class TestEthereumKeychain(unittest.TestCase):
keychain = self.from_address_n([44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0]) keychain = self.from_address_n([44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0])
self._check_keychain(keychain, 60) self._check_keychain(keychain, 60)
def test_from_address_n_unknown(self): def test_from_address_n_casa45(self):
# try Bitcoin slip44 id m/44'/0'/0' # valid keychain m/45'/60/0
schemas = tuple(_schemas_from_address_n(PATTERNS_ADDRESS, [44 | HARDENED, 0 | HARDENED, 0 | HARDENED])) keychain = self.from_address_n([45 | HARDENED, 60, 0, 0, 0])
self.assertEqual(schemas, ()) keychain.derive([45 | HARDENED, 60, 0, 0, 0])
with self.assertRaises(wire.DataError):
keychain.derive([45 | HARDENED, 60 | HARDENED, 0, 0, 0])
def test_bad_address_n(self): def test_with_keychain_from_path_short(self):
# keychain generated from valid slip44 id but invalid address m/0'/60'/0' # check that the keychain will not die when the address_n is too short
keychain = self.from_address_n([0 | HARDENED, 60 | HARDENED, 0 | HARDENED]) @with_keychain_from_path(*PATTERNS_ADDRESS)
self._check_keychain(keychain, 60) async def handler(ctx, msg, keychain, defs):
# in this case the network is unknown so the keychain should allow access
# to Ethereum and testnet paths
self._check_keychain(keychain, 60)
self._check_keychain(keychain, 1)
self.assertIs(defs.network, UNKNOWN_NETWORK)
def test_with_keychain_from_path(self): await_result(handler(wire.DUMMY_CONTEXT, EthereumGetAddress(address_n=[])))
await_result(handler(wire.DUMMY_CONTEXT, EthereumGetAddress(address_n=[0])))
def test_with_keychain_from_path_builtins(self):
@with_keychain_from_path(*PATTERNS_ADDRESS) @with_keychain_from_path(*PATTERNS_ADDRESS)
async def handler(ctx, msg, keychain): async def handler(ctx, msg, keychain, defs):
self._check_keychain(keychain, msg.address_n[1] & ~HARDENED) slip44 = msg.address_n[1] & ~HARDENED
self._check_keychain(keychain, slip44)
await_result( self.assertEqual(defs.network.slip44, slip44)
handler(
wire.DUMMY_CONTEXT, vectors = (
EthereumGetAddress( # Ethereum
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED] [44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
), # Ethereum from Ledger Live legacy path
) [44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0],
# Ethereum Classic
[44 | HARDENED, 61 | HARDENED, 0 | HARDENED],
) )
await_result( # Ethereum from Ledger Live legacy path for address_n in vectors:
handler( await_result(
wire.DUMMY_CONTEXT, handler(wire.DUMMY_CONTEXT, EthereumGetAddress(address_n=address_n))
EthereumGetAddress(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0]
),
) )
)
await_result( with self.assertRaises(wire.DataError):
handler( await_result(
wire.DUMMY_CONTEXT, handler( # unknown network
EthereumGetAddress( wire.DUMMY_CONTEXT,
address_n=[44 | HARDENED, 108 | HARDENED, 0 | HARDENED] EthereumGetAddress(
), address_n=[44 | HARDENED, 0 | HARDENED, 0 | HARDENED]
),
)
) )
def test_with_keychain_from_path_external(self):
FORBIDDEN_SYMBOL = "forbidden name"
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def handler(ctx, msg, keychain, defs):
slip44 = msg.address_n[1] & ~HARDENED
self._check_keychain(keychain, slip44)
self.assertEqual(defs.network.slip44, slip44)
self.assertNotEqual(defs.network.name, FORBIDDEN_SYMBOL)
vectors_valid = ( # slip44, network_def
# invalid network is ignored when there is a builtin
(60, b"hello"),
# valid network is ignored when there is a builtin
(60, encode_network(slip44=60, symbol=FORBIDDEN_SYMBOL)),
# valid network is accepted for unknown slip44 ids
(33333, encode_network(slip44=33333)),
) )
with self.assertRaises(wire.DataError): for slip44, encoded_network in vectors_valid:
await_result( await_result(
handler( handler(
wire.DUMMY_CONTEXT, wire.DUMMY_CONTEXT,
EthereumGetAddress( EthereumGetAddress(
address_n=[44 | HARDENED, 0 | HARDENED, 0 | HARDENED] address_n=[44 | HARDENED, slip44 | HARDENED, 0 | HARDENED],
encoded_network=encoded_network,
), ),
) )
) )
def test_with_keychain_from_chain_id(self): vectors_invalid = ( # slip44, network_def
# invalid network is rejected
(30000, b"hello"),
# invalid network does not prove mismatched slip44 id
(30000, encode_network(slip44=666)),
)
for slip44, encoded_network in vectors_invalid:
with self.assertRaises(wire.DataError):
await_result(
handler(
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, slip44 | HARDENED, 0 | HARDENED],
encoded_network=encoded_network,
),
)
)
def test_with_keychain_from_chain_id_builtin(self):
@with_keychain_from_chain_id @with_keychain_from_chain_id
async def handler_chain_id(ctx, msg, keychain): async def handler_chain_id(ctx, msg, keychain, defs):
slip44_id = msg.address_n[1] & ~HARDENED slip44_id = msg.address_n[1] & ~HARDENED
# standard tests # standard tests
self._check_keychain(keychain, slip44_id) self._check_keychain(keychain, slip44_id)
# provided address should succeed too # provided address should succeed too
keychain.derive(msg.address_n) keychain.derive(msg.address_n)
self.assertEqual(defs.network.chain_id, msg.chain_id)
await_result( # Ethereum vectors = ( # chain_id, address_n
handler_chain_id( # Ethereum
wire.DUMMY_CONTEXT, (1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED]),
EthereumSignTx( # Ethereum from Ledger Live legacy path
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED], (1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0]),
chain_id=1, # Ethereum Classic
gas_price=b"", (61, [44 | HARDENED, 61 | HARDENED, 0 | HARDENED]),
gas_limit=b"", # ETH slip44, ETC chain_id
), # (known networks are allowed to use eth slip44 for cross-signing)
) (61, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED]),
)
await_result( # Ethereum from Ledger Live legacy path
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0],
chain_id=1,
gas_price=b"",
gas_limit=b"",
),
)
)
await_result( # Ethereum Classic
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=[44 | HARDENED, 61 | HARDENED, 0 | HARDENED],
chain_id=61,
gas_price=b"",
gas_limit=b"",
),
)
) )
# Known chain-ids are allowed to use Ethereum derivation paths too, as there is for chain_id, address_n in vectors:
# no risk of replaying the transaction on the Ethereum chain await_result( # Ethereum
await_result( # ETH slip44 with ETC chain-id handler_chain_id(
handler_chain_id( wire.DUMMY_CONTEXT,
wire.DUMMY_CONTEXT, EthereumSignTx(
EthereumSignTx( address_n=address_n,
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED], chain_id=chain_id,
chain_id=61, gas_price=b"",
gas_price=b"", gas_limit=b"",
gas_limit=b"", ),
), )
) )
)
with self.assertRaises(wire.DataError): with self.assertRaises(wire.DataError):
await_result( # chain_id and network mismatch await_result( # chain_id and network mismatch
@ -193,6 +240,136 @@ class TestEthereumKeychain(unittest.TestCase):
) )
) )
def test_with_keychain_from_chain_id_external(self):
FORBIDDEN_SYMBOL = "forbidden name"
@with_keychain_from_chain_id
async def handler_chain_id(ctx, msg, keychain, defs):
slip44_id = msg.address_n[1] & ~HARDENED
# standard tests
self._check_keychain(keychain, slip44_id)
# provided address should succeed too
keychain.derive(msg.address_n)
self.assertEqual(defs.network.chain_id, msg.chain_id)
self.assertNotEqual(defs.network.name, FORBIDDEN_SYMBOL)
vectors_valid = ( # chain_id, address_n, encoded_network
# invalid network is ignored when there is a builtin
(1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED], b"hello"),
# valid network is ignored when there is a builtin
(
1,
[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
encode_network(slip44=60, symbol=FORBIDDEN_SYMBOL),
),
# valid network is accepted for unknown chain ids
(
33333,
[44 | HARDENED, 33333 | HARDENED, 0 | HARDENED],
encode_network(slip44=33333, chain_id=33333),
),
# valid network is allowed to cross-sign for Ethereum slip44
(
33333,
[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
encode_network(slip44=33333, chain_id=33333),
),
# valid network where slip44 and chain_id are different
(
44444,
[44 | HARDENED, 33333 | HARDENED, 0 | HARDENED],
encode_network(slip44=33333, chain_id=44444),
),
)
for chain_id, address_n, encoded_network in vectors_valid:
await_result(
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=address_n,
chain_id=chain_id,
gas_price=b"",
gas_limit=b"",
definitions=EthereumDefinitions(
encoded_network=encoded_network
),
),
)
)
vectors_invalid = ( # chain_id, address_n, encoded_network
# invalid network is rejected
(30000, [44 | HARDENED, 30000 | HARDENED, 0 | HARDENED], b"hello"),
# invalid network does not prove mismatched slip44 id
(
30000,
[44 | HARDENED, 30000 | HARDENED, 0 | HARDENED],
encode_network(chain_id=30000, slip44=666),
),
# invalid network does not prove mismatched chain_id
(
30000,
[44 | HARDENED, 30000 | HARDENED, 0 | HARDENED],
encode_network(chain_id=666, slip44=30000),
),
)
for chain_id, address_n, encoded_network in vectors_invalid:
with self.assertRaises(wire.DataError):
await_result(
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=address_n,
chain_id=chain_id,
gas_price=b"",
gas_limit=b"",
definitions=EthereumDefinitions(
encoded_network=encoded_network
),
),
)
)
def test_message_types(self) -> None:
network = make_network(symbol="Testing Network")
encoded_network = encode_network(network)
messages = (
EthereumSignTx(
gas_price=b"",
gas_limit=b"",
chain_id=0,
definitions=EthereumDefinitions(encoded_network=encoded_network),
),
EthereumSignMessage(
message=b"",
encoded_network=encoded_network,
),
EthereumSignTxEIP1559(
chain_id=0,
gas_limit=b"",
max_gas_fee=b"",
max_priority_fee=b"",
nonce=b"",
value=b"",
data_length=0,
definitions=EthereumDefinitions(encoded_network=encoded_network),
),
EthereumSignTypedData(
primary_type="",
definitions=EthereumDefinitions(encoded_network=encoded_network),
),
EthereumGetAddress(
encoded_network=encoded_network,
),
)
for message in messages:
defs = _defs_from_message(message, chain_id=0)
self.assertEqual(defs.network, network)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

@ -1,104 +1,109 @@
from common import * from common import *
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from apps.ethereum import networks
from apps.ethereum.layout import format_ethereum_amount from apps.ethereum.layout import format_ethereum_amount
from apps.ethereum.tokens import token_by_chain_address from apps.ethereum.tokens import UNKNOWN_TOKEN
from ethereum_common import make_network, make_token
ETH = networks.by_chain_id(1)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestFormatEthereumAmount(unittest.TestCase): class TestFormatEthereumAmount(unittest.TestCase):
def test_denominations(self):
text = format_ethereum_amount(1, None, ETH)
self.assertEqual(text, "1 Wei ETH")
text = format_ethereum_amount(1000, None, ETH)
self.assertEqual(text, "1,000 Wei ETH")
text = format_ethereum_amount(1000000, None, ETH)
self.assertEqual(text, "1,000,000 Wei ETH")
text = format_ethereum_amount(10000000, None, ETH)
self.assertEqual(text, "10,000,000 Wei ETH")
text = format_ethereum_amount(100000000, None, ETH)
self.assertEqual(text, "100,000,000 Wei ETH")
text = format_ethereum_amount(1000000000, None, ETH)
self.assertEqual(text, "0.000000001 ETH")
text = format_ethereum_amount(10000000000, None, ETH)
self.assertEqual(text, "0.00000001 ETH")
text = format_ethereum_amount(100000000000, None, ETH)
self.assertEqual(text, "0.0000001 ETH")
text = format_ethereum_amount(1000000000000, None, ETH)
self.assertEqual(text, "0.000001 ETH")
text = format_ethereum_amount(10000000000000, None, ETH)
self.assertEqual(text, "0.00001 ETH")
text = format_ethereum_amount(100000000000000, None, ETH)
self.assertEqual(text, "0.0001 ETH")
text = format_ethereum_amount(1000000000000000, None, ETH)
self.assertEqual(text, "0.001 ETH")
text = format_ethereum_amount(10000000000000000, None, ETH)
self.assertEqual(text, "0.01 ETH")
text = format_ethereum_amount(100000000000000000, None, ETH)
self.assertEqual(text, "0.1 ETH")
text = format_ethereum_amount(1000000000000000000, None, ETH)
self.assertEqual(text, "1 ETH")
text = format_ethereum_amount(10000000000000000000, None, ETH)
self.assertEqual(text, "10 ETH")
text = format_ethereum_amount(100000000000000000000, None, ETH)
self.assertEqual(text, "100 ETH")
text = format_ethereum_amount(1000000000000000000000, None, ETH)
self.assertEqual(text, "1,000 ETH")
def test_format(self): def test_precision(self):
text = format_ethereum_amount(1, None, 1) text = format_ethereum_amount(1000000000000000001, None, ETH)
self.assertEqual(text, '1 Wei ETH') self.assertEqual(text, "1.000000000000000001 ETH")
text = format_ethereum_amount(1000, None, 1) text = format_ethereum_amount(10000000000000000001, None, ETH)
self.assertEqual(text, '1,000 Wei ETH') self.assertEqual(text, "10.000000000000000001 ETH")
text = format_ethereum_amount(1000000, None, 1)
self.assertEqual(text, '1,000,000 Wei ETH')
text = format_ethereum_amount(10000000, None, 1)
self.assertEqual(text, '10,000,000 Wei ETH')
text = format_ethereum_amount(100000000, None, 1)
self.assertEqual(text, '100,000,000 Wei ETH')
text = format_ethereum_amount(1000000000, None, 1)
self.assertEqual(text, '0.000000001 ETH')
text = format_ethereum_amount(10000000000, None, 1)
self.assertEqual(text, '0.00000001 ETH')
text = format_ethereum_amount(100000000000, None, 1)
self.assertEqual(text, '0.0000001 ETH')
text = format_ethereum_amount(1000000000000, None, 1)
self.assertEqual(text, '0.000001 ETH')
text = format_ethereum_amount(10000000000000, None, 1)
self.assertEqual(text, '0.00001 ETH')
text = format_ethereum_amount(100000000000000, None, 1)
self.assertEqual(text, '0.0001 ETH')
text = format_ethereum_amount(1000000000000000, None, 1)
self.assertEqual(text, '0.001 ETH')
text = format_ethereum_amount(10000000000000000, None, 1)
self.assertEqual(text, '0.01 ETH')
text = format_ethereum_amount(100000000000000000, None, 1)
self.assertEqual(text, '0.1 ETH')
text = format_ethereum_amount(1000000000000000000, None, 1)
self.assertEqual(text, '1 ETH')
text = format_ethereum_amount(10000000000000000000, None, 1)
self.assertEqual(text, '10 ETH')
text = format_ethereum_amount(100000000000000000000, None, 1)
self.assertEqual(text, '100 ETH')
text = format_ethereum_amount(1000000000000000000000, None, 1)
self.assertEqual(text, '1,000 ETH')
text = format_ethereum_amount(1000000000000000000, None, 61)
self.assertEqual(text, '1 ETC')
text = format_ethereum_amount(1000000000000000000, None, 31)
self.assertEqual(text, '1 tRBTC')
text = format_ethereum_amount(1000000000000000001, None, 1) def test_symbols(self):
self.assertEqual(text, '1.000000000000000001 ETH') fake_network = make_network(symbol="FAKE")
text = format_ethereum_amount(10000000000000000001, None, 1) text = format_ethereum_amount(1, None, fake_network)
self.assertEqual(text, '10.000000000000000001 ETH') self.assertEqual(text, "1 Wei FAKE")
text = format_ethereum_amount(10000000000000000001, None, 61) text = format_ethereum_amount(1000000000000000000, None, fake_network)
self.assertEqual(text, '10.000000000000000001 ETC') self.assertEqual(text, "1 FAKE")
text = format_ethereum_amount(1000000000000000001, None, 31) text = format_ethereum_amount(1000000000000000001, None, fake_network)
self.assertEqual(text, '1.000000000000000001 tRBTC') self.assertEqual(text, "1.000000000000000001 FAKE")
def test_unknown_chain(self): def test_unknown_chain(self):
# unknown chain # unknown chain
text = format_ethereum_amount(1, None, 9999) text = format_ethereum_amount(1, None, networks.UNKNOWN_NETWORK)
self.assertEqual(text, '1 Wei UNKN') self.assertEqual(text, "1 Wei UNKN")
text = format_ethereum_amount(10000000000000000001, None, 9999) text = format_ethereum_amount(
self.assertEqual(text, '10.000000000000000001 UNKN') 10000000000000000001, None, networks.UNKNOWN_NETWORK
)
self.assertEqual(text, "10.000000000000000001 UNKN")
def test_tokens(self): def test_tokens(self):
# tokens with low decimal values # tokens with low decimal values
# USDC has 6 decimals # USDC has 6 decimals
usdc_token = token_by_chain_address(1, unhexlify("a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48")) usdc_token = make_token(symbol="USDC", decimals=6)
# ICO has 10 decimals
ico_token = token_by_chain_address(1, unhexlify("a33e729bf4fdeb868b534e1f20523463d9c46bee"))
# when decimals < 10, should never display 'Wei' format # when decimals < 10, should never display 'Wei' format
text = format_ethereum_amount(1, usdc_token, 1) text = format_ethereum_amount(1, usdc_token, ETH)
self.assertEqual(text, '0.000001 USDC') self.assertEqual(text, "0.000001 USDC")
text = format_ethereum_amount(0, usdc_token, 1) text = format_ethereum_amount(0, usdc_token, ETH)
self.assertEqual(text, '0 USDC') self.assertEqual(text, "0 USDC")
text = format_ethereum_amount(1, ico_token, 1) # ICO has 10 decimals
self.assertEqual(text, '1 Wei ICO') ico_token = make_token(symbol="ICO", decimals=10)
text = format_ethereum_amount(9, ico_token, 1) text = format_ethereum_amount(1, ico_token, ETH)
self.assertEqual(text, '9 Wei ICO') self.assertEqual(text, "1 Wei ICO")
text = format_ethereum_amount(10, ico_token, 1) text = format_ethereum_amount(9, ico_token, ETH)
self.assertEqual(text, '0.000000001 ICO') self.assertEqual(text, "9 Wei ICO")
text = format_ethereum_amount(11, ico_token, 1) text = format_ethereum_amount(10, ico_token, ETH)
self.assertEqual(text, '0.0000000011 ICO') self.assertEqual(text, "0.000000001 ICO")
text = format_ethereum_amount(11, ico_token, ETH)
self.assertEqual(text, "0.0000000011 ICO")
def test_unknown_token(self): def test_unknown_token(self):
unknown_token = token_by_chain_address(1, b"hello") text = format_ethereum_amount(1, UNKNOWN_TOKEN, ETH)
text = format_ethereum_amount(1, unknown_token, 1) self.assertEqual(text, "1 Wei UNKN")
self.assertEqual(text, '1 Wei UNKN') text = format_ethereum_amount(0, UNKNOWN_TOKEN, ETH)
text = format_ethereum_amount(0, unknown_token, 1) self.assertEqual(text, "0 Wei UNKN")
self.assertEqual(text, '0 Wei UNKN')
# unknown token has 0 decimals so is always wei # unknown token has 0 decimals so is always wei
text = format_ethereum_amount(1000000000000000000, unknown_token, 1) text = format_ethereum_amount(1000000000000000000, UNKNOWN_TOKEN, ETH)
self.assertEqual(text, '1,000,000,000,000,000,000 Wei UNKN') self.assertEqual(text, "1,000,000,000,000,000,000 Wei UNKN")
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

@ -6,28 +6,17 @@ if not utils.BITCOIN_ONLY:
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumTokens(unittest.TestCase): class TestEthereumTokens(unittest.TestCase):
def test_token_by_chain_address(self): def test_token_by_chain_address(self):
token = tokens.token_by_chain_address(1, b'\x7d\xd7\xf5\x6d\x69\x7c\xc0\xf2\xb5\x2b\xd5\x5c\x05\x7f\x37\x8f\x1f\xe6\xab\x4b') token = tokens.token_by_chain_address(1, b"\x7f\xc6\x65\x00\xc8\x4a\x76\xad\x7e\x9c\x93\x43\x7b\xfc\x5a\xc3\x3e\x2d\xda\xe9")
self.assertEqual(token.symbol, '$TEAK') self.assertEqual(token.symbol, 'AAVE')
token = tokens.token_by_chain_address(1, b'\x59\x41\x6a\x25\x62\x8a\x76\xb4\x73\x0e\xc5\x14\x86\x11\x4c\x32\xe0\xb5\x82\xa1')
self.assertEqual(token.symbol, 'PLASMA')
self.assertEqual(token.decimals, 6)
token = tokens.token_by_chain_address(4, b'\x0a\x05\x7a\x87\xce\x9c\x56\xd7\xe3\x36\xb4\x17\xc7\x9c\xf3\x0e\x8d\x27\x86\x0b')
self.assertEqual(token.symbol, 'WALL')
self.assertEqual(token.decimals, 15)
token = tokens.token_by_chain_address(8, b'\x4b\x48\x99\xa1\x0f\x3e\x50\x7d\xb2\x07\xb0\xee\x24\x26\x02\x9e\xfa\x16\x8a\x67')
self.assertEqual(token.symbol, 'QWARK')
# invalid adress, invalid chain # invalid adress, invalid chain
token = tokens.token_by_chain_address(999, b'\x00\xFF') token = tokens.token_by_chain_address(999, b'\x00\xFF')
self.assertIs(token, tokens.UNKNOWN_TOKEN) self.assertIs(token, None)
self.assertEqual(token.symbol, 'Wei UNKN')
self.assertEqual(token.decimals, 0) self.assertEqual(tokens.UNKNOWN_TOKEN.symbol, 'Wei UNKN')
self.assertEqual(tokens.UNKNOWN_TOKEN.decimals, 0)
if __name__ == '__main__': if __name__ == '__main__':

@ -26,16 +26,24 @@ class AssertRaisesContext:
class TestCase: class TestCase:
def __init__(self) -> None:
self.__equality_functions = {}
def fail(self, msg=''): def fail(self, msg=''):
ensure(False, msg) ensure(False, msg)
def addTypeEqualityFunc(self, typeobj, function):
ensure(callable(function))
self.__equality_functions[typeobj.__name__] = function
def assertEqual(self, x, y, msg=''): def assertEqual(self, x, y, msg=''):
if not msg: if not msg:
msg = f"{repr(x)} vs (expected) {repr(y)}" msg = f"{repr(x)} vs (expected) {repr(y)}"
if x.__class__ == y.__class__ and x.__class__.__name__ == "Msg": if x.__class__ == y.__class__ and x.__class__.__name__ == "Msg":
self.assertMessageEqual(x, y) self.assertMessageEqual(x, y)
elif x.__class__.__name__ in self.__equality_functions:
self.__equality_functions[x.__class__.__name__](x, y, msg)
else: else:
ensure(x == y, msg) ensure(x == y, msg)
@ -156,21 +164,32 @@ class TestCase:
self.assertIsInstance(a, b.__class__, msg) self.assertIsInstance(a, b.__class__, msg)
self.assertEqual(a.__dict__, b.__dict__, msg) self.assertEqual(a.__dict__, b.__dict__, msg)
def assertDictEqual(self, x, y):
self.assertEqual(
len(x),
len(y),
f"Dict lengths not equal - {len(x)} vs {len(y)}"
)
for key in x:
self.assertIn(
key,
y,
f"Key {key} not found in second dict."
)
self.assertEqual(
x[key],
y[key],
f"At key {key} expected {x[key]}, found {y[key]}"
)
def assertMessageEqual(self, x, y): def assertMessageEqual(self, x, y):
self.assertEqual( self.assertEqual(
x.MESSAGE_NAME, x.MESSAGE_NAME,
y.MESSAGE_NAME, y.MESSAGE_NAME,
f"Expected {x.MESSAGE_NAME}, found {y.MESSAGE_NAME}" f"Expected {x.MESSAGE_NAME}, found {y.MESSAGE_NAME}"
) )
xdict = x.__dict__ self.assertDictEqual(x.__dict__, y.__dict__)
ydict = y.__dict__
for key in xdict:
self.assertTrue(key in ydict)
self.assertEqual(
xdict[key],
ydict[key],
f"At {x.MESSAGE_NAME}.{key} expected {xdict[key]}, found {ydict[key]}"
)
def skip(msg): def skip(msg):

Loading…
Cancel
Save