1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 04:18:10 +00:00

feat(core): external Ethereum definitions

This commit is contained in:
Martin Novák 2023-02-03 14:39:31 +01:00 committed by matejcik
parent 168ab2944c
commit c2c0900c5d
29 changed files with 1387 additions and 12645 deletions

View File

@ -0,0 +1 @@
Signed Ethereum network and token definitions from host

View File

@ -375,6 +375,8 @@ if not utils.BITCOIN_ONLY:
import trezor.enums.CardanoTxWitnessType
trezor.enums.EthereumDataType
import trezor.enums.EthereumDataType
trezor.enums.EthereumDefinitionType
import trezor.enums.EthereumDefinitionType
trezor.enums.MoneroNetworkType
import trezor.enums.MoneroNetworkType
trezor.enums.NEMImportanceTransferMode
@ -481,6 +483,10 @@ if not utils.BITCOIN_ONLY:
import apps.eos.writers
apps.ethereum
import apps.ethereum
apps.ethereum.definitions
import apps.ethereum.definitions
apps.ethereum.definitions_constants
import apps.ethereum.definitions_constants
apps.ethereum.get_address
import apps.ethereum.get_address
apps.ethereum.get_public_key

View File

@ -33,19 +33,25 @@ def read_compact_size(r: BufferReader) -> int:
def read_uint16_be(r: BufferReader) -> int:
n = r.get()
return (n << 8) + r.get()
data = r.read_memoryview(2)
return int.from_bytes(data, "big")
def read_uint32_be(r: BufferReader) -> int:
n = r.get()
for _ in range(3):
n = (n << 8) + r.get()
return n
data = r.read_memoryview(4)
return int.from_bytes(data, "big")
def read_uint64_be(r: BufferReader) -> int:
n = r.get()
for _ in range(7):
n = (n << 8) + r.get()
return n
data = r.read_memoryview(8)
return int.from_bytes(data, "big")
def read_uint16_le(r: BufferReader) -> int:
data = r.read_memoryview(2)
return int.from_bytes(data, "little")
def read_uint32_le(r: BufferReader) -> int:
data = r.read_memoryview(4)
return int.from_bytes(data, "little")

View File

@ -0,0 +1,154 @@
from typing import TYPE_CHECKING
from trezor import protobuf, utils
from trezor.crypto.curve import ed25519
from trezor.crypto.hashlib import sha256
from trezor.enums import EthereumDefinitionType
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
from trezor.wire import DataError
from apps.common import readers
from . import definitions_constants as consts, networks, tokens
from .networks import UNKNOWN_NETWORK
if TYPE_CHECKING:
from typing import TypeVar
from typing_extensions import Self
DefType = TypeVar("DefType", EthereumNetworkInfo, EthereumTokenInfo)
def decode_definition(definition: bytes, expected_type: type[DefType]) -> DefType:
# check network definition
r = utils.BufferReader(definition)
expected_type_number = EthereumDefinitionType.NETWORK
# TODO: can't check equality of MsgDefObjs now, so we check the name
if expected_type.MESSAGE_NAME == EthereumTokenInfo.MESSAGE_NAME:
expected_type_number = EthereumDefinitionType.TOKEN
try:
# first check format version
if r.read_memoryview(len(consts.FORMAT_VERSION)) != consts.FORMAT_VERSION:
raise DataError("Invalid Ethereum definition")
# second check the type of the data
if r.get() != expected_type_number:
raise DataError("Definition type mismatch")
# third check data version
if readers.read_uint32_le(r) < consts.MIN_DATA_VERSION:
raise DataError("Definition is outdated")
# get payload
payload_length = readers.read_uint16_le(r)
payload = r.read_memoryview(payload_length)
# at the end compute Merkle tree root hash using
# provided leaf data (payload with prefix) and proof
hasher = sha256(b"\x00")
hasher.update(memoryview(definition)[: r.offset])
hash = hasher.digest()
proof_length = r.get()
for _ in range(proof_length):
proof_entry = r.read_memoryview(32)
hash_a = min(hash, proof_entry)
hash_b = max(hash, proof_entry)
hasher = sha256(b"\x01")
hasher.update(hash_a)
hasher.update(hash_b)
hash = hasher.digest()
signed_tree_root = r.read_memoryview(64)
if r.remaining_count():
raise DataError("Invalid Ethereum definition")
except EOFError:
raise DataError("Invalid Ethereum definition")
# verify signature
if not ed25519.verify(consts.DEFINITIONS_PUBLIC_KEY, signed_tree_root, hash):
error_msg = DataError("Invalid definition signature")
if __debug__:
# check against dev key
if not ed25519.verify(
consts.DEFINITIONS_DEV_PUBLIC_KEY,
signed_tree_root,
hash,
):
raise error_msg
else:
raise error_msg
# decode it if it's OK
try:
return protobuf.decode(payload, expected_type, True)
except ValueError:
raise DataError("Invalid Ethereum definition")
class Definitions:
"""Class that holds Ethereum definitions - network and tokens.
Prefers built-in definitions over encoded ones.
"""
def __init__(
self, network: EthereumNetworkInfo, tokens: dict[bytes, EthereumTokenInfo]
) -> None:
self.network = network
self._tokens = tokens
@classmethod
def from_encoded(
cls,
encoded_network: bytes | None,
encoded_token: bytes | None,
chain_id: int | None = None,
slip44: int | None = None,
) -> Self:
network = UNKNOWN_NETWORK
tokens: dict[bytes, EthereumTokenInfo] = {}
# if we have a built-in definition, use it
if chain_id is not None:
network = networks.by_chain_id(chain_id)
elif slip44 is not None:
network = networks.by_slip44(slip44)
else:
# ignore encoded definitions if we can't match them to request details
return cls(UNKNOWN_NETWORK, {})
if network is UNKNOWN_NETWORK and encoded_network is not None:
network = decode_definition(encoded_network, EthereumNetworkInfo)
if network is UNKNOWN_NETWORK:
# ignore tokens if we don't have a network
return cls(UNKNOWN_NETWORK, {})
if chain_id is not None and network.chain_id != chain_id:
raise DataError("Network definition mismatch")
if slip44 is not None and network.slip44 != slip44:
raise DataError("Network definition mismatch")
# get token definition
if encoded_token is not None:
token = decode_definition(encoded_token, EthereumTokenInfo)
# Ignore token if it doesn't match the network instead of raising an error.
# This might help us in the future if we allow multiple networks/tokens
# in the same message.
if token.chain_id == network.chain_id:
tokens[token.address] = token
return cls(network, tokens)
def get_token(self, address: bytes) -> EthereumTokenInfo:
# if we have a built-in definition, use it
token = tokens.token_by_chain_address(self.network.chain_id, address)
if token is not None:
return token
if address in self._tokens:
return self._tokens[address]
return tokens.UNKNOWN_TOKEN

View File

@ -0,0 +1,14 @@
# generated from definitions_constants.py.mako
# (by running `make templates` in `core`)
# do not edit manually!
from ubinascii import unhexlify
DEFINITIONS_PUBLIC_KEY = b""
MIN_DATA_VERSION = 1669892465
FORMAT_VERSION = b"trzd1"
if __debug__:
DEFINITIONS_DEV_PUBLIC_KEY = unhexlify(
"db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d"
)

View File

@ -0,0 +1,14 @@
# generated from definitions_constants.py.mako
# (by running `make templates` in `core`)
# do not edit manually!
from ubinascii import unhexlify
DEFINITIONS_PUBLIC_KEY = b""
MIN_DATA_VERSION = ${ethereum_defs_timestamp}
FORMAT_VERSION = b"trzd1"
if __debug__:
DEFINITIONS_DEV_PUBLIC_KEY = unhexlify(
"db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d"
)

View File

@ -7,16 +7,19 @@ if TYPE_CHECKING:
from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def get_address(
ctx: Context, msg: EthereumGetAddress, keychain: Keychain
ctx: Context,
msg: EthereumGetAddress,
keychain: Keychain,
defs: Definitions,
) -> EthereumAddress:
from trezor.messages import EthereumAddress
from trezor.ui.layouts import show_address
from apps.common import paths
from . import networks
from .helpers import address_from_bytes
address_n = msg.address_n # local_cache_attribute
@ -25,11 +28,7 @@ async def get_address(
node = keychain.derive(address_n)
if len(address_n) > 1: # path has slip44 network identifier
network = networks.by_slip44(address_n[1] & 0x7FFF_FFFF)
else:
network = None
address = address_from_bytes(node.ethereum_pubkeyhash(), network)
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
if msg.show_display:
await show_address(ctx, address, path=paths.address_n_to_str(address_n))

View File

@ -1,19 +1,27 @@
from typing import TYPE_CHECKING
from ubinascii import hexlify
from . import networks
if TYPE_CHECKING:
from trezor.messages import EthereumFieldType
from .networks import NetworkInfo
from .networks import EthereumNetworkInfo
def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str:
RSKIP60_NETWORKS = (30, 31)
def address_from_bytes(
address_bytes: bytes, network: EthereumNetworkInfo = networks.UNKNOWN_NETWORK
) -> str:
"""
Converts address in bytes to a checksummed string as defined
in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
"""
from trezor.crypto.hashlib import sha3_256
if network is not None and network.rskip60:
if network.chain_id in RSKIP60_NETWORKS:
# rskip60 is a different way to calculate checksum
prefix = str(network.chain_id) + "0x"
else:
prefix = ""

View File

@ -1,37 +1,56 @@
from typing import TYPE_CHECKING
from trezor.messages import EthereumNetworkInfo
from apps.common import paths
from apps.common.keychain import get_keychain
from . import CURVE, networks
from . import CURVE, definitions, networks
if TYPE_CHECKING:
from typing import Callable, Iterable, TypeVar
from typing import Any, Awaitable, Callable, Iterable, TypeVar
from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.messages import (
EthereumGetAddress,
EthereumGetPublicKey,
EthereumSignMessage,
EthereumSignTx,
EthereumSignTxEIP1559,
EthereumSignTypedData,
)
from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain
EthereumMessages = (
EthereumGetAddress
| EthereumGetPublicKey
| EthereumSignTx
| EthereumSignMessage
| EthereumSignTypedData
from apps.common.keychain import (
MsgOut,
Handler,
)
MsgIn = TypeVar("MsgIn", bound=EthereumMessages)
EthereumSignTxAny = EthereumSignTx | EthereumSignTxEIP1559
MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny)
# messages for "with_keychain_and_network_from_path" decorator
MsgInAddressN = TypeVar(
"MsgInAddressN",
EthereumGetAddress,
EthereumSignMessage,
EthereumSignTypedData,
)
HandlerAddressN = Callable[
[Context, MsgInAddressN, Keychain, definitions.Definitions],
Awaitable[MsgOut],
]
# messages for "with_keychain_and_defs_from_chain_id" decorator
MsgInSignTx = TypeVar(
"MsgInSignTx",
EthereumSignTx,
EthereumSignTxEIP1559,
)
HandlerChainId = Callable[
[Context, MsgInSignTx, Keychain, definitions.Definitions],
Awaitable[MsgOut],
]
# We believe Ethereum should use 44'/60'/a' for everything, because it is
@ -48,67 +67,83 @@ PATTERNS_ADDRESS = (
)
def _schemas_from_address_n(
patterns: Iterable[str], address_n: paths.Bip32Path
def _slip44_from_address_n(address_n: paths.Bip32Path) -> int | None:
HARDENED = paths.HARDENED # local_cache_attribute
if len(address_n) < 2:
return None
if address_n[0] == 45 | HARDENED and not address_n[1] & HARDENED:
return address_n[1]
return address_n[1] & ~HARDENED
def _defs_from_message(
msg: Any, chain_id: int | None = None, slip44: int | None = None
) -> definitions.Definitions:
encoded_network = None
encoded_token = None
# try to get both from msg.definitions
if hasattr(msg, "definitions"):
if msg.definitions is not None:
encoded_network = msg.definitions.encoded_network
encoded_token = msg.definitions.encoded_token
elif hasattr(msg, "encoded_network"):
encoded_network = msg.encoded_network
return definitions.Definitions.from_encoded(
encoded_network, encoded_token, chain_id, slip44
)
def _schemas_from_network(
patterns: Iterable[str],
network_info: EthereumNetworkInfo,
) -> Iterable[paths.PathSchema]:
# Casa paths (purpose of 45) do not have hardened coin types
if address_n[0] == 45 | paths.HARDENED and not address_n[1] & paths.HARDENED:
slip44_hardened = address_n[1] | paths.HARDENED
slip44_id: tuple[int, ...]
if network_info is networks.UNKNOWN_NETWORK:
# allow Ethereum or testnet paths for unknown networks
slip44_id = (60, 1)
elif network_info.slip44 not in (60, 1):
# allow cross-signing with Ethereum unless it's testnet
slip44_id = (network_info.slip44, 60)
else:
slip44_hardened = address_n[1]
slip44_id = (network_info.slip44,)
if slip44_hardened not in networks.all_slip44_ids_hardened():
return ()
if not slip44_hardened & paths.HARDENED:
return ()
slip44_id = slip44_hardened - paths.HARDENED
schemas = [paths.PathSchema.parse(pattern, slip44_id) for pattern in patterns]
return [s.copy() for s in schemas]
def with_keychain_from_path(
*patterns: str,
) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]:
def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: Context, msg: MsgIn) -> MsgOut:
schemas = _schemas_from_address_n(patterns, msg.address_n)
) -> Callable[[HandlerAddressN[MsgInAddressN, MsgOut]], Handler[MsgInAddressN, MsgOut]]:
def decorator(
func: HandlerAddressN[MsgInAddressN, MsgOut]
) -> Handler[MsgInAddressN, MsgOut]:
async def wrapper(ctx: Context, msg: MsgInAddressN) -> MsgOut:
slip44 = _slip44_from_address_n(msg.address_n)
defs = _defs_from_message(msg, slip44=slip44)
schemas = _schemas_from_network(patterns, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain)
return await func(ctx, msg, keychain, defs)
return wrapper
return decorator
def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]:
info = networks.by_chain_id(msg.chain_id)
slip44_id: tuple[int, ...]
if info is None:
# allow Ethereum or testnet paths for unknown networks
slip44_id = (60, 1)
elif info.slip44 not in (60, 1):
# allow cross-signing with Ethereum unless it's testnet
slip44_id = (info.slip44, 60)
else:
slip44_id = (info.slip44,)
schemas = [
paths.PathSchema.parse(pattern, slip44_id) for pattern in PATTERNS_ADDRESS
]
return [s.copy() for s in schemas]
def with_keychain_from_chain_id(
func: HandlerWithKeychain[MsgInChainId, MsgOut]
) -> Handler[MsgInChainId, MsgOut]:
func: HandlerChainId[MsgInSignTx, MsgOut]
) -> Handler[MsgInSignTx, MsgOut]:
# this is only for SignTx, and only PATTERN_ADDRESS is allowed
async def wrapper(ctx: Context, msg: MsgInChainId) -> MsgOut:
schemas = _schemas_from_chain_id(msg)
async def wrapper(ctx: Context, msg: MsgInSignTx) -> MsgOut:
defs = _defs_from_message(msg, chain_id=msg.chain_id)
schemas = _schemas_from_network(PATTERNS_ADDRESS, defs.network)
keychain = await get_keychain(ctx, CURVE, schemas)
with keychain:
return await func(ctx, msg, keychain)
return await func(ctx, msg, keychain, defs)
return wrapper

View File

@ -11,35 +11,38 @@ from trezor.ui.layouts import (
should_show_more,
)
from . import networks
from .helpers import decode_typed_data
if TYPE_CHECKING:
from typing import Awaitable, Iterable
from trezor.messages import EthereumFieldType, EthereumStructMember
from trezor.messages import (
EthereumFieldType,
EthereumNetworkInfo,
EthereumStructMember,
EthereumTokenInfo,
)
from trezor.wire import Context
from . import tokens
def require_confirm_tx(
ctx: Context,
to_bytes: bytes,
value: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
network: EthereumNetworkInfo,
token: EthereumTokenInfo | None,
) -> Awaitable[None]:
from .helpers import address_from_bytes
from trezor.ui.layouts import confirm_output
if to_bytes:
to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id))
to_str = address_from_bytes(to_bytes, network)
else:
to_str = "new contract?"
return confirm_output(
ctx,
to_str,
format_ethereum_amount(value, token, chain_id),
format_ethereum_amount(value, token, network),
br_code=ButtonRequestType.SignTx,
)
@ -49,19 +52,19 @@ async def require_confirm_fee(
spending: int,
gas_price: int,
gas_limit: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
network: EthereumNetworkInfo,
token: EthereumTokenInfo | None,
) -> None:
await confirm_amount(
ctx,
title="Confirm fee",
description="Gas price:",
amount=format_ethereum_amount(gas_price, None, chain_id),
amount=format_ethereum_amount(gas_price, None, network),
)
await confirm_total(
ctx,
total_amount=format_ethereum_amount(spending, token, chain_id),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, chain_id),
total_amount=format_ethereum_amount(spending, token, network),
fee_amount=format_ethereum_amount(gas_price * gas_limit, None, network),
total_label="Amount sent:",
fee_label="Maximum fee:",
)
@ -73,25 +76,25 @@ async def require_confirm_eip1559_fee(
max_priority_fee: int,
max_gas_fee: int,
gas_limit: int,
chain_id: int,
token: tokens.TokenInfo | None = None,
network: EthereumNetworkInfo,
token: EthereumTokenInfo | None,
) -> None:
await confirm_amount(
ctx,
"Confirm fee",
format_ethereum_amount(max_gas_fee, None, chain_id),
format_ethereum_amount(max_gas_fee, None, network),
"Maximum fee per gas",
)
await confirm_amount(
ctx,
"Confirm fee",
format_ethereum_amount(max_priority_fee, None, chain_id),
format_ethereum_amount(max_priority_fee, None, network),
"Priority fee per gas",
)
await confirm_total(
ctx,
format_ethereum_amount(spending, token, chain_id),
format_ethereum_amount(max_gas_fee * gas_limit, None, chain_id),
format_ethereum_amount(spending, token, network),
format_ethereum_amount(max_gas_fee * gas_limit, None, network),
total_label="Amount sent:",
fee_label="Maximum fee:",
)
@ -262,7 +265,9 @@ async def confirm_typed_value(
def format_ethereum_amount(
value: int, token: tokens.TokenInfo | None, chain_id: int
value: int,
token: EthereumTokenInfo | None,
network: EthereumNetworkInfo,
) -> str:
from trezor.strings import format_amount
@ -270,7 +275,7 @@ def format_ethereum_amount(
suffix = token.symbol
decimals = token.decimals
else:
suffix = networks.shortcut_by_chain_id(chain_id)
suffix = network.symbol
decimals = 18
# Don't want to display wei values for tokens with small decimal numbers

File diff suppressed because it is too large Load Diff

View File

@ -7,7 +7,7 @@
from typing import TYPE_CHECKING
from apps.common.paths import HARDENED
from trezor.messages import EthereumNetworkInfo
if TYPE_CHECKING:
from typing import Iterator
@ -17,71 +17,52 @@ if TYPE_CHECKING:
NetworkInfoTuple = tuple[
int, # chain_id
int, # slip44
str, # shortcut
str, # symbol
str, # name
bool # rskip60
]
# fmt: on
def shortcut_by_chain_id(chain_id: int) -> str:
n = by_chain_id(chain_id)
return n.shortcut if n is not None else "UNKN"
UNKNOWN_NETWORK = EthereumNetworkInfo(
chain_id=0,
slip44=0,
symbol="UNKN",
name="Unknown network",
)
def by_chain_id(chain_id: int) -> "NetworkInfo" | None:
def by_chain_id(chain_id: int) -> EthereumNetworkInfo:
for n in _networks_iterator():
n_chain_id = n[0]
if n_chain_id == chain_id:
return NetworkInfo(
return EthereumNetworkInfo(
chain_id=n[0],
slip44=n[1],
shortcut=n[2],
symbol=n[2],
name=n[3],
rskip60=n[4],
)
return None
return UNKNOWN_NETWORK
def by_slip44(slip44: int) -> "NetworkInfo" | None:
def by_slip44(slip44: int) -> EthereumNetworkInfo:
for n in _networks_iterator():
n_slip44 = n[1]
if n_slip44 == slip44:
return NetworkInfo(
return EthereumNetworkInfo(
chain_id=n[0],
slip44=n[1],
shortcut=n[2],
symbol=n[2],
name=n[3],
rskip60=n[4],
)
return None
def all_slip44_ids_hardened() -> Iterator[int]:
for n in _networks_iterator():
# n_slip_44 is the second element
yield n[1] | HARDENED
class NetworkInfo:
def __init__(
self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool
) -> None:
self.chain_id = chain_id
self.slip44 = slip44
self.shortcut = shortcut
self.name = name
self.rskip60 = rskip60
return UNKNOWN_NETWORK
# fmt: off
def _networks_iterator() -> Iterator[NetworkInfoTuple]:
% for n in supported_on("trezor2", eth):
% for n in sorted(supported_on("trezor2", eth), key=lambda network: (int(network.chain_id), network.name)):
yield (
${n.chain_id}, # chain_id
${n.slip44}, # slip44
"${n.shortcut}", # shortcut
"${n.shortcut}", # symbol
"${n.name}", # name
${n.rskip60}, # rskip60
)
% endfor

View File

@ -3,10 +3,14 @@ from typing import TYPE_CHECKING
from .keychain import PATTERNS_ADDRESS, with_keychain_from_path
if TYPE_CHECKING:
from trezor.messages import EthereumSignMessage, EthereumMessageSignature
from trezor.messages import (
EthereumSignMessage,
EthereumMessageSignature,
)
from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
def message_digest(message: bytes) -> bytes:
@ -23,7 +27,10 @@ def message_digest(message: bytes) -> bytes:
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_message(
ctx: Context, msg: EthereumSignMessage, keychain: Keychain
ctx: Context,
msg: EthereumSignMessage,
keychain: Keychain,
defs: Definitions,
) -> EthereumMessageSignature:
from trezor.crypto.curve import secp256k1
from trezor.messages import EthereumMessageSignature
@ -37,7 +44,7 @@ async def sign_message(
await paths.validate_path(ctx, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
address = address_from_bytes(node.ethereum_pubkeyhash())
address = address_from_bytes(node.ethereum_pubkeyhash(), defs.network)
await confirm_signverify(
ctx, "ETH", decode_message(msg.message), address, verify=False
)

View File

@ -9,11 +9,11 @@ from .keychain import with_keychain_from_chain_id
if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.messages import EthereumSignTx, EthereumTxAck
from trezor.messages import EthereumSignTx, EthereumTxAck, EthereumTokenInfo
from trezor.wire import Context
from .keychain import EthereumSignTxAny
from . import tokens
from .definitions import Definitions
from .keychain import MsgInSignTx
# Maximum chain_id which returns the full signature_v (which must fit into an uint32).
@ -24,7 +24,10 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2
@with_keychain_from_chain_id
async def sign_tx(
ctx: Context, msg: EthereumSignTx, keychain: Keychain
ctx: Context,
msg: EthereumSignTx,
keychain: Keychain,
defs: Definitions,
) -> EthereumTxRequest:
from trezor.utils import HashWriter
from trezor.crypto.hashlib import sha3_256
@ -45,11 +48,11 @@ async def sign_tx(
await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg)
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs)
data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token)
await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -58,7 +61,7 @@ async def sign_tx(
value,
int.from_bytes(msg.gas_price, "big"),
int.from_bytes(msg.gas_limit, "big"),
msg.chain_id,
defs.network,
token,
)
@ -100,13 +103,14 @@ async def sign_tx(
async def handle_erc20(
ctx: Context, msg: EthereumSignTxAny
) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]:
ctx: Context,
msg: MsgInSignTx,
definitions: Definitions,
) -> tuple[EthereumTokenInfo | None, bytes, bytes, int]:
from .layout import require_confirm_unknown_token
from . import tokens
data_initial_chunk = msg.data_initial_chunk # local_cache_attribute
token = None
address_bytes = recipient = bytes_from_address(msg.to)
value = int.from_bytes(msg.value, "big")
@ -118,7 +122,7 @@ async def handle_erc20(
and data_initial_chunk[:16]
== b"\xa9\x05\x9c\xbb\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
):
token = tokens.token_by_chain_address(msg.chain_id, address_bytes)
token = definitions.get_token(address_bytes)
recipient = data_initial_chunk[16:36]
value = int.from_bytes(data_initial_chunk[36:68], "big")
@ -185,7 +189,7 @@ def _sign_digest(
return req
def check_common_fields(msg: EthereumSignTxAny) -> None:
def check_common_fields(msg: MsgInSignTx) -> None:
data_length = msg.data_length # local_cache_attribute
if data_length > 0:

View File

@ -12,9 +12,9 @@ if TYPE_CHECKING:
EthereumAccessList,
EthereumTxRequest,
)
from apps.common.keychain import Keychain
from trezor.wire import Context
from apps.common.keychain import Keychain
from .definitions import Definitions
_TX_TYPE = const(2)
@ -30,7 +30,10 @@ def access_list_item_length(item: EthereumAccessList) -> int:
@with_keychain_from_chain_id
async def sign_tx_eip1559(
ctx: Context, msg: EthereumSignTxEIP1559, keychain: Keychain
ctx: Context,
msg: EthereumSignTxEIP1559,
keychain: Keychain,
defs: Definitions,
) -> EthereumTxRequest:
from trezor.crypto.hashlib import sha3_256
from trezor.utils import HashWriter
@ -56,11 +59,11 @@ async def sign_tx_eip1559(
await paths.validate_path(ctx, keychain, msg.address_n)
# Handle ERC20s
token, address_bytes, recipient, value = await handle_erc20(ctx, msg)
token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs)
data_total = msg.data_length
await require_confirm_tx(ctx, recipient, value, msg.chain_id, token)
await require_confirm_tx(ctx, recipient, value, defs.network, token)
if token is None and msg.data_length > 0:
await require_confirm_data(ctx, msg.data_initial_chunk, data_total)
@ -70,7 +73,7 @@ async def sign_tx_eip1559(
int.from_bytes(msg.max_priority_fee, "big"),
int.from_bytes(msg.max_gas_fee, "big"),
int.from_bytes(gas_limit, "big"),
msg.chain_id,
defs.network,
token,
)

View File

@ -11,6 +11,7 @@ if TYPE_CHECKING:
from apps.common.keychain import Keychain
from trezor.wire import Context
from trezor.utils import HashWriter
from .definitions import Definitions
from trezor.messages import (
EthereumSignTypedData,
@ -22,7 +23,10 @@ if TYPE_CHECKING:
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def sign_typed_data(
ctx: Context, msg: EthereumSignTypedData, keychain: Keychain
ctx: Context,
msg: EthereumSignTypedData,
keychain: Keychain,
defs: Definitions,
) -> EthereumTypedDataSignature:
from trezor.crypto.curve import secp256k1
from apps.common import paths
@ -47,7 +51,7 @@ async def sign_typed_data(
)
return EthereumTypedDataSignature(
address=address_from_bytes(address_bytes),
address=address_from_bytes(address_bytes, defs.network),
signature=signature[1:] + signature[0:1],
)

File diff suppressed because it is too large Load Diff

View File

@ -16,6 +16,7 @@
from typing import Iterator
from trezor.messages import EthereumTokenInfo
<%
from collections import defaultdict
@ -26,30 +27,37 @@ def group_tokens(tokens):
return r
%>\
class TokenInfo:
def __init__(self, symbol: str, decimals: int) -> None:
self.symbol = symbol
self.decimals = decimals
UNKNOWN_TOKEN = EthereumTokenInfo(
symbol="Wei UNKN",
decimals=0,
address=b"",
chain_id=0,
name="Unknown token",
)
UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0)
def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo:
for addr, symbol, decimal in _token_iterator(chain_id):
def token_by_chain_address(chain_id: int, address: bytes) -> EthereumTokenInfo | None:
for addr, symbol, decimal, name in _token_iterator(chain_id):
if address == addr:
return TokenInfo(symbol, decimal)
return UNKNOWN_TOKEN
return EthereumTokenInfo(
symbol=symbol,
decimals=decimal,
address=address,
chain_id=chain_id,
name=name,
)
return None
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]:
def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int, str]]:
% for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items():
if chain_id == ${token_chain_id}:
if chain_id == ${token_chain_id}: # ${tokens[0].chain}
% for t in tokens:
yield ( # address, symbol, decimals
yield ( # address, symbol, decimals, name
${black_repr(t.address_bytes)},
${black_repr(t.symbol)},
${t.decimals},
${black_repr(t.name.strip())},
)
% endfor
% endfor

View File

@ -0,0 +1,6 @@
# Automatically generated by pb2py
# fmt: off
# isort:skip_file
NETWORK = 0
TOKEN = 1

View File

@ -455,6 +455,10 @@ if TYPE_CHECKING:
YES = 1
INFO = 2
class EthereumDefinitionType(IntEnum):
NETWORK = 0
TOKEN = 1
class EthereumDataType(IntEnum):
UINT = 1
INT = 2

View File

@ -38,6 +38,7 @@ if TYPE_CHECKING:
from trezor.enums import DebugSwipeDirection # noqa: F401
from trezor.enums import DecredStakingSpendType # noqa: F401
from trezor.enums import EthereumDataType # noqa: F401
from trezor.enums import EthereumDefinitionType # noqa: F401
from trezor.enums import FailureType # noqa: F401
from trezor.enums import HomescreenFormat # noqa: F401
from trezor.enums import InputScriptType # noqa: F401
@ -3372,10 +3373,69 @@ if TYPE_CHECKING:
def is_type_of(cls, msg: Any) -> TypeGuard["EosActionUnknown"]:
return isinstance(msg, cls)
class EthereumNetworkInfo(protobuf.MessageType):
chain_id: "int"
symbol: "str"
slip44: "int"
name: "str"
def __init__(
self,
*,
chain_id: "int",
symbol: "str",
slip44: "int",
name: "str",
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["EthereumNetworkInfo"]:
return isinstance(msg, cls)
class EthereumTokenInfo(protobuf.MessageType):
address: "bytes"
chain_id: "int"
symbol: "str"
decimals: "int"
name: "str"
def __init__(
self,
*,
address: "bytes",
chain_id: "int",
symbol: "str",
decimals: "int",
name: "str",
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["EthereumTokenInfo"]:
return isinstance(msg, cls)
class EthereumDefinitions(protobuf.MessageType):
encoded_network: "bytes | None"
encoded_token: "bytes | None"
def __init__(
self,
*,
encoded_network: "bytes | None" = None,
encoded_token: "bytes | None" = None,
) -> None:
pass
@classmethod
def is_type_of(cls, msg: Any) -> TypeGuard["EthereumDefinitions"]:
return isinstance(msg, cls)
class EthereumSignTypedData(protobuf.MessageType):
address_n: "list[int]"
primary_type: "str"
metamask_v4_compat: "bool"
definitions: "EthereumDefinitions | None"
def __init__(
self,
@ -3383,6 +3443,7 @@ if TYPE_CHECKING:
primary_type: "str",
address_n: "list[int] | None" = None,
metamask_v4_compat: "bool | None" = None,
definitions: "EthereumDefinitions | None" = None,
) -> None:
pass
@ -3517,12 +3578,14 @@ if TYPE_CHECKING:
class EthereumGetAddress(protobuf.MessageType):
address_n: "list[int]"
show_display: "bool | None"
encoded_network: "bytes | None"
def __init__(
self,
*,
address_n: "list[int] | None" = None,
show_display: "bool | None" = None,
encoded_network: "bytes | None" = None,
) -> None:
pass
@ -3555,6 +3618,7 @@ if TYPE_CHECKING:
data_length: "int"
chain_id: "int"
tx_type: "int | None"
definitions: "EthereumDefinitions | None"
def __init__(
self,
@ -3569,6 +3633,7 @@ if TYPE_CHECKING:
data_initial_chunk: "bytes | None" = None,
data_length: "int | None" = None,
tx_type: "int | None" = None,
definitions: "EthereumDefinitions | None" = None,
) -> None:
pass
@ -3588,6 +3653,7 @@ if TYPE_CHECKING:
data_length: "int"
chain_id: "int"
access_list: "list[EthereumAccessList]"
definitions: "EthereumDefinitions | None"
def __init__(
self,
@ -3603,6 +3669,7 @@ if TYPE_CHECKING:
access_list: "list[EthereumAccessList] | None" = None,
to: "str | None" = None,
data_initial_chunk: "bytes | None" = None,
definitions: "EthereumDefinitions | None" = None,
) -> None:
pass
@ -3647,12 +3714,14 @@ if TYPE_CHECKING:
class EthereumSignMessage(protobuf.MessageType):
address_n: "list[int]"
message: "bytes"
encoded_network: "bytes | None"
def __init__(
self,
*,
message: "bytes",
address_n: "list[int] | None" = None,
encoded_network: "bytes | None" = None,
) -> None:
pass
@ -3698,6 +3767,7 @@ if TYPE_CHECKING:
address_n: "list[int]"
domain_separator_hash: "bytes"
message_hash: "bytes | None"
encoded_network: "bytes | None"
def __init__(
self,
@ -3705,6 +3775,7 @@ if TYPE_CHECKING:
domain_separator_hash: "bytes",
address_n: "list[int] | None" = None,
message_hash: "bytes | None" = None,
encoded_network: "bytes | None" = None,
) -> None:
pass

View File

@ -5,15 +5,25 @@ sys.path.append("../src")
from ubinascii import hexlify, unhexlify # noqa: F401
import unittest # noqa: F401
from typing import Any, Awaitable
from trezor import utils # noqa: F401
from apps.common.paths import HARDENED
def H_(x: int) -> int:
"""
Shortcut function that "hardens" a number in a BIP44 path.
"""
return x | HARDENED
def UH_(x: int) -> int:
"""
Shortcut function that "un-hardens" a number in a BIP44 path.
"""
return x & ~(HARDENED)
def await_result(task: Awaitable) -> Any:
value = None
while True:

View File

@ -0,0 +1,102 @@
from ubinascii import unhexlify # noqa: F401
from trezor import messages, protobuf
from trezor.enums import EthereumDefinitionType
from trezor.crypto.curve import ed25519
from trezor.crypto.hashlib import sha256
DEFINITIONS_DEV_PRIVATE_KEY = unhexlify(
"4141414141414141414141414141414141414141414141414141414141414141"
)
def make_network(
chain_id: int = 0,
slip44: int = 0,
symbol: str = "FAKE",
name: str = "Fake network",
) -> messages.EthereumNetworkInfo:
return messages.EthereumNetworkInfo(
chain_id=chain_id,
slip44=slip44,
symbol=symbol,
name=name,
)
def make_token(
symbol: str = "FAKE",
decimals: int = 18,
address: bytes = b"",
chain_id: int = 0,
name: str = "Fake token",
) -> messages.EthereumTokenInfo:
return messages.EthereumTokenInfo(
symbol=symbol,
decimals=decimals,
address=address,
chain_id=chain_id,
name=name,
)
def make_payload(
prefix: bytes = b"trzd1",
data_type: EthereumDefinitionType = EthereumDefinitionType.NETWORK,
timestamp: int = 0xFFFF_FFFF,
message: messages.EthereumNetworkInfo
| messages.EthereumTokenInfo
| bytes = make_network(),
) -> bytes:
payload = prefix
payload += data_type.to_bytes(1, "little")
payload += timestamp.to_bytes(4, "little")
if isinstance(message, bytes):
message_bytes = message
else:
message_bytes = protobuf.dump_message_buffer(message)
payload += len(message_bytes).to_bytes(2, "little")
payload += message_bytes
return payload
def sign_payload(payload: bytes, merkle_neighbors: list[bytes]) -> tuple[bytes, bytes]:
digest = sha256(b"\x00" + payload).digest()
merkle_proof = []
for item in merkle_neighbors:
left, right = min(digest, item), max(digest, item)
digest = sha256(b"\x01" + left + right).digest()
merkle_proof.append(digest)
merkle_proof = len(merkle_proof).to_bytes(1, "little") + b"".join(merkle_proof)
signature = ed25519.sign(DEFINITIONS_DEV_PRIVATE_KEY, digest)
return merkle_proof, signature
def encode_network(
network: messages.EthereumNetworkInfo | None = None,
chain_id: int = 0,
slip44: int = 0,
symbol: str = "FAKE",
name: str = "Fake network",
) -> bytes:
if network is None:
network = make_network(chain_id, slip44, symbol, name)
payload = make_payload(data_type=EthereumDefinitionType.NETWORK, message=network)
proof, signature = sign_payload(payload, [])
return payload + proof + signature
def encode_token(
token: messages.EthereumTokenInfo | None = None,
symbol: str = "FAKE",
decimals: int = 18,
address: bytes = b"",
chain_id: int = 0,
name: str = "Fake token",
) -> bytes:
if token is None:
token = make_token(symbol, decimals, address, chain_id, name)
payload = make_payload(data_type=EthereumDefinitionType.TOKEN, message=token)
proof, signature = sign_payload(payload, [])
return payload + proof + signature

View File

@ -0,0 +1,250 @@
from common import *
import unittest
import typing as t
from trezor import utils, wire
from ubinascii import hexlify # noqa: F401
if not utils.BITCOIN_ONLY:
from apps.ethereum import networks, tokens
from apps.ethereum.definitions import decode_definition, Definitions
from ethereum_common import *
from trezor import protobuf
from trezor.enums import EthereumDefinitionType
from trezor.messages import (
EthereumDefinitions,
EthereumNetworkInfo,
EthereumTokenInfo,
EthereumSignTx,
EthereumSignTxEIP1559,
EthereumSignTypedData,
)
TETHER_ADDRESS = b"\xda\xc1\x7f\x95\x8d\x2e\xe5\x23\xa2\x20\x62\x06\x99\x45\x97\xc1\x3d\x83\x1e\xc7"
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestDecodeDefinition(unittest.TestCase):
def test_short_message(self):
with self.assertRaises(wire.DataError):
decode_definition(b"\x00", EthereumNetworkInfo)
with self.assertRaises(wire.DataError):
decode_definition(b"\x00", EthereumTokenInfo)
# successful decode network
def test_network_definition(self):
network = make_network(chain_id=42, slip44=69, symbol="FAKE", name="Fakenet")
encoded = encode_network(network)
try:
self.assertEqual(decode_definition(encoded, EthereumNetworkInfo), network)
except Exception as e:
print(e.message)
# successful decode token
def test_token_definition(self):
token = make_token("FAKE", decimals=33, address=b"abcd" * 5, chain_id=42)
encoded = encode_token(token)
self.assertEqual(decode_definition(encoded, EthereumTokenInfo), token)
def assertFailed(self, data: bytes) -> None:
with self.assertRaises(wire.DataError):
decode_definition(data, EthereumNetworkInfo)
def test_mangled_signature(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_signature = signature[:-1] + b"\xff"
self.assertFailed(payload + proof + bad_signature)
def test_missing_signature(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof)
def test_mangled_payload(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_payload = payload[:-1] + b"\xff"
self.assertFailed(bad_payload + proof + signature)
def test_proof_length_mismatch(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_proof = b"\x01"
self.assertFailed(payload + bad_proof + signature)
def test_bad_proof(self):
payload = make_payload()
proof, signature = sign_payload(payload, [sha256(b"x").digest()])
bad_proof = proof[:-1] + b"\xff"
self.assertFailed(payload + bad_proof + signature)
def test_trimmed_proof(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
bad_proof = proof[:-1]
self.assertFailed(payload + bad_proof + signature)
def test_bad_prefix(self):
payload = make_payload(prefix=b"trzd2")
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_bad_type(self):
payload = make_payload(
data_type=EthereumDefinitionType.TOKEN, message=make_token()
)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_outdated(self):
payload = make_payload(timestamp=0)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_malformed_protobuf(self):
payload = make_payload(message=b"\x00")
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_protobuf_mismatch(self):
payload = make_payload(
data_type=EthereumDefinitionType.NETWORK, message=make_token()
)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
payload = make_payload(
data_type=EthereumDefinitionType.TOKEN, message=make_network()
)
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature)
def test_trailing_garbage(self):
payload = make_payload()
proof, signature = sign_payload(payload, [])
self.assertFailed(payload + proof + signature + b"\x00")
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumDefinitions(unittest.TestCase):
def assertUnknown(self, what: t.Any) -> None:
if what is networks.UNKNOWN_NETWORK:
return
if what is tokens.UNKNOWN_TOKEN:
return
self.fail("Expected UNKNOWN_*, got %r" % what)
def assertKnown(self, what: t.Any) -> None:
if not EthereumNetworkInfo.is_type_of(
what
) and not EthereumTokenInfo.is_type_of(what):
self.fail("Expected network / token info, got %r" % what)
if what is networks.UNKNOWN_NETWORK:
self.fail("Expected known network, got UNKNOWN_NETWORK")
if what is tokens.UNKNOWN_TOKEN:
self.fail("Expected known token, got UNKNOWN_TOKEN")
def test_empty(self) -> None:
# no slip44 nor chain_id -- should short-circuit and always be unknown
defs = Definitions.from_encoded(None, None)
self.assertUnknown(defs.network)
self.assertFalse(defs._tokens)
self.assertUnknown(defs.get_token(TETHER_ADDRESS))
# chain_id provided, no definition
defs = Definitions.from_encoded(None, None, chain_id=100_000)
self.assertUnknown(defs.network)
self.assertFalse(defs._tokens)
self.assertUnknown(defs.get_token(TETHER_ADDRESS))
def test_builtin(self) -> None:
defs = Definitions.from_encoded(None, None, chain_id=1)
self.assertKnown(defs.network)
self.assertFalse(defs._tokens)
self.assertKnown(defs.get_token(TETHER_ADDRESS))
self.assertUnknown(defs.get_token(b"\x00" * 20))
defs = Definitions.from_encoded(None, None, slip44=60)
self.assertKnown(defs.network)
self.assertFalse(defs._tokens)
self.assertKnown(defs.get_token(TETHER_ADDRESS))
self.assertUnknown(defs.get_token(b"\x00" * 20))
def test_external(self) -> None:
network = make_network(chain_id=42)
defs = Definitions.from_encoded(encode_network(network), None, chain_id=42)
self.assertEqual(defs.network, network)
self.assertUnknown(defs.get_token(b"\x00" * 20))
token = make_token(chain_id=42, address=b"\x00" * 20)
defs = Definitions.from_encoded(
encode_network(network), encode_token(token), chain_id=42
)
self.assertEqual(defs.network, network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
token = make_token(chain_id=1, address=b"\x00" * 20)
defs = Definitions.from_encoded(None, encode_token(token), chain_id=1)
self.assertKnown(defs.network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
def test_external_token_mismatch(self) -> None:
network = make_network(chain_id=42)
token = make_token(chain_id=43, address=b"\x00" * 20)
defs = Definitions.from_encoded(encode_network(network), encode_token(token))
self.assertUnknown(defs.get_token(b"\x00" * 20))
def test_external_chain_match(self) -> None:
network = make_network(chain_id=42)
token = make_token(chain_id=42, address=b"\x00" * 20)
defs = Definitions.from_encoded(
encode_network(network), encode_token(token), chain_id=42
)
self.assertEqual(defs.network, network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
with self.assertRaises(wire.DataError):
Definitions.from_encoded(
encode_network(network), encode_token(token), chain_id=333
)
def test_external_slip44_mismatch(self) -> None:
network = make_network(chain_id=42, slip44=1999)
token = make_token(chain_id=42, address=b"\x00" * 20)
defs = Definitions.from_encoded(
encode_network(network), encode_token(token), slip44=1999
)
self.assertEqual(defs.network, network)
self.assertEqual(defs.get_token(b"\x00" * 20), token)
with self.assertRaises(wire.DataError):
Definitions.from_encoded(
encode_network(network), encode_token(token), slip44=333
)
def test_ignore_encoded_network(self) -> None:
# when network is builtin, ignore the encoded one
network = encode_network(chain_id=1, symbol="BAD")
defs = Definitions.from_encoded(network, None, chain_id=1)
self.assertNotEqual(defs.network, network)
def test_ignore_encoded_token(self) -> None:
# when token is builtin, ignore the encoded one
token = encode_token(chain_id=1, address=TETHER_ADDRESS, symbol="BAD")
defs = Definitions.from_encoded(None, token, chain_id=1)
self.assertNotEqual(defs.get_token(TETHER_ADDRESS), token)
def test_ignore_with_no_match(self) -> None:
network = encode_network(chain_id=100_000, symbol="BAD")
# smoke test: definition is accepted
defs = Definitions.from_encoded(network, None, chain_id=100_000)
self.assertKnown(defs.network)
# same definition but nothing to match it to
defs = Definitions.from_encoded(network, None)
self.assertUnknown(defs.network)
if __name__ == "__main__":
unittest.main()

View File

@ -3,23 +3,22 @@ from apps.common.paths import HARDENED
if not utils.BITCOIN_ONLY:
from apps.ethereum.helpers import address_from_bytes
from apps.ethereum.networks import NetworkInfo
from ethereum_common import make_network
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumGetAddress(unittest.TestCase):
def test_address_from_bytes_eip55(self):
# https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md
eip55 = [
'0x52908400098527886E0F7030069857D2E4169EE7',
'0x8617E340B3D01FA5F11F306F4090FD50E238070D',
'0xde709f2102306220921060314715629080e2fb77',
'0x27b1fdb04752bbc536007a920d24acb045561c26',
'0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed',
'0xfB6916095ca1df60bB79Ce92cE3Ea74c37c5d359',
'0xdbF03B407c01E7cD3CBea99509d93f8DDDC8C6FB',
'0xD1220A0cf47c7B9Be7A2E6BA89F429762e7b9aDb',
"0x52908400098527886E0F7030069857D2E4169EE7",
"0x8617E340B3D01FA5F11F306F4090FD50E238070D",
"0xde709f2102306220921060314715629080e2fb77",
"0x27b1fdb04752bbc536007a920d24acb045561c26",
"0x5aAeb6053F3E94C9b9A09f33669435E7Ef1BeAed",
"0xfB6916095ca1df60bB79Ce92cE3Ea74c37c5d359",
"0xdbF03B407c01E7cD3CBea99509d93f8DDDC8C6FB",
"0xD1220A0cf47c7B9Be7A2E6BA89F429762e7b9aDb",
]
for s in eip55:
b = unhexlify(s[2:])
@ -29,28 +28,30 @@ class TestEthereumGetAddress(unittest.TestCase):
def test_address_from_bytes_rskip60(self):
# https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md
rskip60_chain_30 = [
'0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD',
'0xFb6916095cA1Df60bb79ce92cE3EA74c37c5d359',
'0xDBF03B407c01E7CD3cBea99509D93F8Dddc8C6FB',
'0xD1220A0Cf47c7B9BE7a2e6ba89F429762E7B9adB'
"0x5aaEB6053f3e94c9b9a09f33669435E7ef1bEAeD",
"0xFb6916095cA1Df60bb79ce92cE3EA74c37c5d359",
"0xDBF03B407c01E7CD3cBea99509D93F8Dddc8C6FB",
"0xD1220A0Cf47c7B9BE7a2e6ba89F429762E7B9adB",
]
rskip60_chain_31 = [
'0x5aAeb6053F3e94c9b9A09F33669435E7EF1BEaEd',
'0xFb6916095CA1dF60bb79CE92ce3Ea74C37c5D359',
'0xdbF03B407C01E7cd3cbEa99509D93f8dDDc8C6fB',
'0xd1220a0CF47c7B9Be7A2E6Ba89f429762E7b9adB'
"0x5aAeb6053F3e94c9b9A09F33669435E7EF1BEaEd",
"0xFb6916095CA1dF60bb79CE92ce3Ea74C37c5D359",
"0xdbF03B407C01E7cd3cbEa99509D93f8dDDc8C6fB",
"0xd1220a0CF47c7B9Be7A2E6Ba89f429762E7b9adB",
]
n = NetworkInfo(chain_id=30, slip44=1, shortcut='T', name='T', rskip60=True)
n = make_network(chain_id=30)
for s in rskip60_chain_30:
b = unhexlify(s[2:])
h = address_from_bytes(b, n)
self.assertEqual(h, s)
n.chain_id = 31
n = make_network(chain_id=31)
for s in rskip60_chain_31:
b = unhexlify(s[2:])
h = address_from_bytes(b, n)
self.assertEqual(h, s)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -1,22 +1,35 @@
from common import *
import unittest
from storage import cache
from trezor import wire
from trezor import wire, utils
from trezor.crypto import bip39
from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED
if not utils.BITCOIN_ONLY:
from apps.ethereum import CURVE
from apps.ethereum.networks import UNKNOWN_NETWORK
from apps.ethereum.keychain import (
PATTERNS_ADDRESS,
_schemas_from_address_n,
_schemas_from_network,
_defs_from_message,
_slip44_from_address_n,
with_keychain_from_path,
with_keychain_from_chain_id,
)
from apps.ethereum.networks import by_chain_id, by_slip44
from trezor.messages import EthereumGetAddress
from trezor.messages import EthereumSignTx
from trezor.messages import (
EthereumGetAddress,
EthereumSignTx,
EthereumDefinitions,
EthereumSignMessage,
EthereumSignTypedData,
EthereumSignTxEIP1559,
)
from ethereum_common import make_network, encode_network
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
@ -43,11 +56,19 @@ class TestEthereumKeychain(unittest.TestCase):
[44 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0],
[44 | HARDENED, slip44_id | HARDENED, 1 | HARDENED, 0],
[44 | HARDENED, slip44_id | HARDENED, 0 | HARDENED, 0 | HARDENED, 0],
[44 | HARDENED, slip44_id | HARDENED, 0 | HARDENED, 0 | HARDENED, 0 | HARDENED],
[
44 | HARDENED,
slip44_id | HARDENED,
0 | HARDENED,
0 | HARDENED,
0 | HARDENED,
],
)
for addr in invalid_addresses:
self.assertRaises(
wire.DataError, keychain.derive, addr,
wire.DataError,
keychain.derive,
addr,
)
def setUp(self):
@ -56,7 +77,9 @@ class TestEthereumKeychain(unittest.TestCase):
cache.set(cache.APP_COMMON_SEED, seed)
def from_address_n(self, address_n):
schemas = _schemas_from_address_n(PATTERNS_ADDRESS, address_n)
slip44 = _slip44_from_address_n(address_n)
network = make_network(slip44=slip44)
schemas = _schemas_from_network(PATTERNS_ADDRESS, network)
return await_result(get_keychain(wire.DUMMY_CONTEXT, CURVE, schemas))
def test_from_address_n(self):
@ -69,51 +92,50 @@ class TestEthereumKeychain(unittest.TestCase):
keychain = self.from_address_n([44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0])
self._check_keychain(keychain, 60)
def test_from_address_n_unknown(self):
# try Bitcoin slip44 id m/44'/0'/0'
schemas = tuple(_schemas_from_address_n(PATTERNS_ADDRESS, [44 | HARDENED, 0 | HARDENED, 0 | HARDENED]))
self.assertEqual(schemas, ())
def test_from_address_n_casa45(self):
# valid keychain m/45'/60/0
keychain = self.from_address_n([45 | HARDENED, 60, 0, 0, 0])
keychain.derive([45 | HARDENED, 60, 0, 0, 0])
with self.assertRaises(wire.DataError):
keychain.derive([45 | HARDENED, 60 | HARDENED, 0, 0, 0])
def test_bad_address_n(self):
# keychain generated from valid slip44 id but invalid address m/0'/60'/0'
keychain = self.from_address_n([0 | HARDENED, 60 | HARDENED, 0 | HARDENED])
self._check_keychain(keychain, 60)
def test_with_keychain_from_path(self):
def test_with_keychain_from_path_short(self):
# check that the keychain will not die when the address_n is too short
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def handler(ctx, msg, keychain):
self._check_keychain(keychain, msg.address_n[1] & ~HARDENED)
async def handler(ctx, msg, keychain, defs):
# in this case the network is unknown so the keychain should allow access
# to Ethereum and testnet paths
self._check_keychain(keychain, 60)
self._check_keychain(keychain, 1)
self.assertIs(defs.network, UNKNOWN_NETWORK)
await_result(
handler(
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED]
),
)
await_result(handler(wire.DUMMY_CONTEXT, EthereumGetAddress(address_n=[])))
await_result(handler(wire.DUMMY_CONTEXT, EthereumGetAddress(address_n=[0])))
def test_with_keychain_from_path_builtins(self):
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def handler(ctx, msg, keychain, defs):
slip44 = msg.address_n[1] & ~HARDENED
self._check_keychain(keychain, slip44)
self.assertEqual(defs.network.slip44, slip44)
vectors = (
# Ethereum
[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
# Ethereum from Ledger Live legacy path
[44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0],
# Ethereum Classic
[44 | HARDENED, 61 | HARDENED, 0 | HARDENED],
)
await_result( # Ethereum from Ledger Live legacy path
handler(
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0]
),
for address_n in vectors:
await_result(
handler(wire.DUMMY_CONTEXT, EthereumGetAddress(address_n=address_n))
)
)
await_result(
handler(
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, 108 | HARDENED, 0 | HARDENED]
),
)
)
with self.assertRaises(wire.DataError):
await_result(
handler(
handler( # unknown network
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, 0 | HARDENED, 0 | HARDENED]
@ -121,64 +143,89 @@ class TestEthereumKeychain(unittest.TestCase):
)
)
def test_with_keychain_from_chain_id(self):
def test_with_keychain_from_path_external(self):
FORBIDDEN_SYMBOL = "forbidden name"
@with_keychain_from_path(*PATTERNS_ADDRESS)
async def handler(ctx, msg, keychain, defs):
slip44 = msg.address_n[1] & ~HARDENED
self._check_keychain(keychain, slip44)
self.assertEqual(defs.network.slip44, slip44)
self.assertNotEqual(defs.network.name, FORBIDDEN_SYMBOL)
vectors_valid = ( # slip44, network_def
# invalid network is ignored when there is a builtin
(60, b"hello"),
# valid network is ignored when there is a builtin
(60, encode_network(slip44=60, symbol=FORBIDDEN_SYMBOL)),
# valid network is accepted for unknown slip44 ids
(33333, encode_network(slip44=33333)),
)
for slip44, encoded_network in vectors_valid:
await_result(
handler(
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, slip44 | HARDENED, 0 | HARDENED],
encoded_network=encoded_network,
),
)
)
vectors_invalid = ( # slip44, network_def
# invalid network is rejected
(30000, b"hello"),
# invalid network does not prove mismatched slip44 id
(30000, encode_network(slip44=666)),
)
for slip44, encoded_network in vectors_invalid:
with self.assertRaises(wire.DataError):
await_result(
handler(
wire.DUMMY_CONTEXT,
EthereumGetAddress(
address_n=[44 | HARDENED, slip44 | HARDENED, 0 | HARDENED],
encoded_network=encoded_network,
),
)
)
def test_with_keychain_from_chain_id_builtin(self):
@with_keychain_from_chain_id
async def handler_chain_id(ctx, msg, keychain):
async def handler_chain_id(ctx, msg, keychain, defs):
slip44_id = msg.address_n[1] & ~HARDENED
# standard tests
self._check_keychain(keychain, slip44_id)
# provided address should succeed too
keychain.derive(msg.address_n)
self.assertEqual(defs.network.chain_id, msg.chain_id)
await_result( # Ethereum
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
chain_id=1,
gas_price=b"",
gas_limit=b"",
),
)
vectors = ( # chain_id, address_n
# Ethereum
(1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED]),
# Ethereum from Ledger Live legacy path
(1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0]),
# Ethereum Classic
(61, [44 | HARDENED, 61 | HARDENED, 0 | HARDENED]),
# ETH slip44, ETC chain_id
# (known networks are allowed to use eth slip44 for cross-signing)
(61, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED]),
)
await_result( # Ethereum from Ledger Live legacy path
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED, 0],
chain_id=1,
gas_price=b"",
gas_limit=b"",
),
for chain_id, address_n in vectors:
await_result( # Ethereum
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=address_n,
chain_id=chain_id,
gas_price=b"",
gas_limit=b"",
),
)
)
)
await_result( # Ethereum Classic
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=[44 | HARDENED, 61 | HARDENED, 0 | HARDENED],
chain_id=61,
gas_price=b"",
gas_limit=b"",
),
)
)
# Known chain-ids are allowed to use Ethereum derivation paths too, as there is
# no risk of replaying the transaction on the Ethereum chain
await_result( # ETH slip44 with ETC chain-id
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
chain_id=61,
gas_price=b"",
gas_limit=b"",
),
)
)
with self.assertRaises(wire.DataError):
await_result( # chain_id and network mismatch
@ -193,6 +240,136 @@ class TestEthereumKeychain(unittest.TestCase):
)
)
def test_with_keychain_from_chain_id_external(self):
FORBIDDEN_SYMBOL = "forbidden name"
@with_keychain_from_chain_id
async def handler_chain_id(ctx, msg, keychain, defs):
slip44_id = msg.address_n[1] & ~HARDENED
# standard tests
self._check_keychain(keychain, slip44_id)
# provided address should succeed too
keychain.derive(msg.address_n)
self.assertEqual(defs.network.chain_id, msg.chain_id)
self.assertNotEqual(defs.network.name, FORBIDDEN_SYMBOL)
vectors_valid = ( # chain_id, address_n, encoded_network
# invalid network is ignored when there is a builtin
(1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED], b"hello"),
# valid network is ignored when there is a builtin
(
1,
[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
encode_network(slip44=60, symbol=FORBIDDEN_SYMBOL),
),
# valid network is accepted for unknown chain ids
(
33333,
[44 | HARDENED, 33333 | HARDENED, 0 | HARDENED],
encode_network(slip44=33333, chain_id=33333),
),
# valid network is allowed to cross-sign for Ethereum slip44
(
33333,
[44 | HARDENED, 60 | HARDENED, 0 | HARDENED],
encode_network(slip44=33333, chain_id=33333),
),
# valid network where slip44 and chain_id are different
(
44444,
[44 | HARDENED, 33333 | HARDENED, 0 | HARDENED],
encode_network(slip44=33333, chain_id=44444),
),
)
for chain_id, address_n, encoded_network in vectors_valid:
await_result(
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=address_n,
chain_id=chain_id,
gas_price=b"",
gas_limit=b"",
definitions=EthereumDefinitions(
encoded_network=encoded_network
),
),
)
)
vectors_invalid = ( # chain_id, address_n, encoded_network
# invalid network is rejected
(30000, [44 | HARDENED, 30000 | HARDENED, 0 | HARDENED], b"hello"),
# invalid network does not prove mismatched slip44 id
(
30000,
[44 | HARDENED, 30000 | HARDENED, 0 | HARDENED],
encode_network(chain_id=30000, slip44=666),
),
# invalid network does not prove mismatched chain_id
(
30000,
[44 | HARDENED, 30000 | HARDENED, 0 | HARDENED],
encode_network(chain_id=666, slip44=30000),
),
)
for chain_id, address_n, encoded_network in vectors_invalid:
with self.assertRaises(wire.DataError):
await_result(
handler_chain_id(
wire.DUMMY_CONTEXT,
EthereumSignTx(
address_n=address_n,
chain_id=chain_id,
gas_price=b"",
gas_limit=b"",
definitions=EthereumDefinitions(
encoded_network=encoded_network
),
),
)
)
def test_message_types(self) -> None:
network = make_network(symbol="Testing Network")
encoded_network = encode_network(network)
messages = (
EthereumSignTx(
gas_price=b"",
gas_limit=b"",
chain_id=0,
definitions=EthereumDefinitions(encoded_network=encoded_network),
),
EthereumSignMessage(
message=b"",
encoded_network=encoded_network,
),
EthereumSignTxEIP1559(
chain_id=0,
gas_limit=b"",
max_gas_fee=b"",
max_priority_fee=b"",
nonce=b"",
value=b"",
data_length=0,
definitions=EthereumDefinitions(encoded_network=encoded_network),
),
EthereumSignTypedData(
primary_type="",
definitions=EthereumDefinitions(encoded_network=encoded_network),
),
EthereumGetAddress(
encoded_network=encoded_network,
),
)
for message in messages:
defs = _defs_from_message(message, chain_id=0)
self.assertEqual(defs.network, network)
if __name__ == "__main__":
unittest.main()

View File

@ -1,104 +1,109 @@
from common import *
if not utils.BITCOIN_ONLY:
from apps.ethereum import networks
from apps.ethereum.layout import format_ethereum_amount
from apps.ethereum.tokens import token_by_chain_address
from apps.ethereum.tokens import UNKNOWN_TOKEN
from ethereum_common import make_network, make_token
ETH = networks.by_chain_id(1)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestFormatEthereumAmount(unittest.TestCase):
def test_denominations(self):
text = format_ethereum_amount(1, None, ETH)
self.assertEqual(text, "1 Wei ETH")
text = format_ethereum_amount(1000, None, ETH)
self.assertEqual(text, "1,000 Wei ETH")
text = format_ethereum_amount(1000000, None, ETH)
self.assertEqual(text, "1,000,000 Wei ETH")
text = format_ethereum_amount(10000000, None, ETH)
self.assertEqual(text, "10,000,000 Wei ETH")
text = format_ethereum_amount(100000000, None, ETH)
self.assertEqual(text, "100,000,000 Wei ETH")
text = format_ethereum_amount(1000000000, None, ETH)
self.assertEqual(text, "0.000000001 ETH")
text = format_ethereum_amount(10000000000, None, ETH)
self.assertEqual(text, "0.00000001 ETH")
text = format_ethereum_amount(100000000000, None, ETH)
self.assertEqual(text, "0.0000001 ETH")
text = format_ethereum_amount(1000000000000, None, ETH)
self.assertEqual(text, "0.000001 ETH")
text = format_ethereum_amount(10000000000000, None, ETH)
self.assertEqual(text, "0.00001 ETH")
text = format_ethereum_amount(100000000000000, None, ETH)
self.assertEqual(text, "0.0001 ETH")
text = format_ethereum_amount(1000000000000000, None, ETH)
self.assertEqual(text, "0.001 ETH")
text = format_ethereum_amount(10000000000000000, None, ETH)
self.assertEqual(text, "0.01 ETH")
text = format_ethereum_amount(100000000000000000, None, ETH)
self.assertEqual(text, "0.1 ETH")
text = format_ethereum_amount(1000000000000000000, None, ETH)
self.assertEqual(text, "1 ETH")
text = format_ethereum_amount(10000000000000000000, None, ETH)
self.assertEqual(text, "10 ETH")
text = format_ethereum_amount(100000000000000000000, None, ETH)
self.assertEqual(text, "100 ETH")
text = format_ethereum_amount(1000000000000000000000, None, ETH)
self.assertEqual(text, "1,000 ETH")
def test_format(self):
text = format_ethereum_amount(1, None, 1)
self.assertEqual(text, '1 Wei ETH')
text = format_ethereum_amount(1000, None, 1)
self.assertEqual(text, '1,000 Wei ETH')
text = format_ethereum_amount(1000000, None, 1)
self.assertEqual(text, '1,000,000 Wei ETH')
text = format_ethereum_amount(10000000, None, 1)
self.assertEqual(text, '10,000,000 Wei ETH')
text = format_ethereum_amount(100000000, None, 1)
self.assertEqual(text, '100,000,000 Wei ETH')
text = format_ethereum_amount(1000000000, None, 1)
self.assertEqual(text, '0.000000001 ETH')
text = format_ethereum_amount(10000000000, None, 1)
self.assertEqual(text, '0.00000001 ETH')
text = format_ethereum_amount(100000000000, None, 1)
self.assertEqual(text, '0.0000001 ETH')
text = format_ethereum_amount(1000000000000, None, 1)
self.assertEqual(text, '0.000001 ETH')
text = format_ethereum_amount(10000000000000, None, 1)
self.assertEqual(text, '0.00001 ETH')
text = format_ethereum_amount(100000000000000, None, 1)
self.assertEqual(text, '0.0001 ETH')
text = format_ethereum_amount(1000000000000000, None, 1)
self.assertEqual(text, '0.001 ETH')
text = format_ethereum_amount(10000000000000000, None, 1)
self.assertEqual(text, '0.01 ETH')
text = format_ethereum_amount(100000000000000000, None, 1)
self.assertEqual(text, '0.1 ETH')
text = format_ethereum_amount(1000000000000000000, None, 1)
self.assertEqual(text, '1 ETH')
text = format_ethereum_amount(10000000000000000000, None, 1)
self.assertEqual(text, '10 ETH')
text = format_ethereum_amount(100000000000000000000, None, 1)
self.assertEqual(text, '100 ETH')
text = format_ethereum_amount(1000000000000000000000, None, 1)
self.assertEqual(text, '1,000 ETH')
def test_precision(self):
text = format_ethereum_amount(1000000000000000001, None, ETH)
self.assertEqual(text, "1.000000000000000001 ETH")
text = format_ethereum_amount(10000000000000000001, None, ETH)
self.assertEqual(text, "10.000000000000000001 ETH")
text = format_ethereum_amount(1000000000000000000, None, 61)
self.assertEqual(text, '1 ETC')
text = format_ethereum_amount(1000000000000000000, None, 31)
self.assertEqual(text, '1 tRBTC')
text = format_ethereum_amount(1000000000000000001, None, 1)
self.assertEqual(text, '1.000000000000000001 ETH')
text = format_ethereum_amount(10000000000000000001, None, 1)
self.assertEqual(text, '10.000000000000000001 ETH')
text = format_ethereum_amount(10000000000000000001, None, 61)
self.assertEqual(text, '10.000000000000000001 ETC')
text = format_ethereum_amount(1000000000000000001, None, 31)
self.assertEqual(text, '1.000000000000000001 tRBTC')
def test_symbols(self):
fake_network = make_network(symbol="FAKE")
text = format_ethereum_amount(1, None, fake_network)
self.assertEqual(text, "1 Wei FAKE")
text = format_ethereum_amount(1000000000000000000, None, fake_network)
self.assertEqual(text, "1 FAKE")
text = format_ethereum_amount(1000000000000000001, None, fake_network)
self.assertEqual(text, "1.000000000000000001 FAKE")
def test_unknown_chain(self):
# unknown chain
text = format_ethereum_amount(1, None, 9999)
self.assertEqual(text, '1 Wei UNKN')
text = format_ethereum_amount(10000000000000000001, None, 9999)
self.assertEqual(text, '10.000000000000000001 UNKN')
text = format_ethereum_amount(1, None, networks.UNKNOWN_NETWORK)
self.assertEqual(text, "1 Wei UNKN")
text = format_ethereum_amount(
10000000000000000001, None, networks.UNKNOWN_NETWORK
)
self.assertEqual(text, "10.000000000000000001 UNKN")
def test_tokens(self):
# tokens with low decimal values
# USDC has 6 decimals
usdc_token = token_by_chain_address(1, unhexlify("a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48"))
# ICO has 10 decimals
ico_token = token_by_chain_address(1, unhexlify("a33e729bf4fdeb868b534e1f20523463d9c46bee"))
usdc_token = make_token(symbol="USDC", decimals=6)
# when decimals < 10, should never display 'Wei' format
text = format_ethereum_amount(1, usdc_token, 1)
self.assertEqual(text, '0.000001 USDC')
text = format_ethereum_amount(0, usdc_token, 1)
self.assertEqual(text, '0 USDC')
text = format_ethereum_amount(1, usdc_token, ETH)
self.assertEqual(text, "0.000001 USDC")
text = format_ethereum_amount(0, usdc_token, ETH)
self.assertEqual(text, "0 USDC")
text = format_ethereum_amount(1, ico_token, 1)
self.assertEqual(text, '1 Wei ICO')
text = format_ethereum_amount(9, ico_token, 1)
self.assertEqual(text, '9 Wei ICO')
text = format_ethereum_amount(10, ico_token, 1)
self.assertEqual(text, '0.000000001 ICO')
text = format_ethereum_amount(11, ico_token, 1)
self.assertEqual(text, '0.0000000011 ICO')
# ICO has 10 decimals
ico_token = make_token(symbol="ICO", decimals=10)
text = format_ethereum_amount(1, ico_token, ETH)
self.assertEqual(text, "1 Wei ICO")
text = format_ethereum_amount(9, ico_token, ETH)
self.assertEqual(text, "9 Wei ICO")
text = format_ethereum_amount(10, ico_token, ETH)
self.assertEqual(text, "0.000000001 ICO")
text = format_ethereum_amount(11, ico_token, ETH)
self.assertEqual(text, "0.0000000011 ICO")
def test_unknown_token(self):
unknown_token = token_by_chain_address(1, b"hello")
text = format_ethereum_amount(1, unknown_token, 1)
self.assertEqual(text, '1 Wei UNKN')
text = format_ethereum_amount(0, unknown_token, 1)
self.assertEqual(text, '0 Wei UNKN')
text = format_ethereum_amount(1, UNKNOWN_TOKEN, ETH)
self.assertEqual(text, "1 Wei UNKN")
text = format_ethereum_amount(0, UNKNOWN_TOKEN, ETH)
self.assertEqual(text, "0 Wei UNKN")
# unknown token has 0 decimals so is always wei
text = format_ethereum_amount(1000000000000000000, unknown_token, 1)
self.assertEqual(text, '1,000,000,000,000,000,000 Wei UNKN')
text = format_ethereum_amount(1000000000000000000, UNKNOWN_TOKEN, ETH)
self.assertEqual(text, "1,000,000,000,000,000,000 Wei UNKN")
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -6,28 +6,17 @@ if not utils.BITCOIN_ONLY:
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumTokens(unittest.TestCase):
def test_token_by_chain_address(self):
token = tokens.token_by_chain_address(1, b'\x7d\xd7\xf5\x6d\x69\x7c\xc0\xf2\xb5\x2b\xd5\x5c\x05\x7f\x37\x8f\x1f\xe6\xab\x4b')
self.assertEqual(token.symbol, '$TEAK')
token = tokens.token_by_chain_address(1, b'\x59\x41\x6a\x25\x62\x8a\x76\xb4\x73\x0e\xc5\x14\x86\x11\x4c\x32\xe0\xb5\x82\xa1')
self.assertEqual(token.symbol, 'PLASMA')
self.assertEqual(token.decimals, 6)
token = tokens.token_by_chain_address(4, b'\x0a\x05\x7a\x87\xce\x9c\x56\xd7\xe3\x36\xb4\x17\xc7\x9c\xf3\x0e\x8d\x27\x86\x0b')
self.assertEqual(token.symbol, 'WALL')
self.assertEqual(token.decimals, 15)
token = tokens.token_by_chain_address(8, b'\x4b\x48\x99\xa1\x0f\x3e\x50\x7d\xb2\x07\xb0\xee\x24\x26\x02\x9e\xfa\x16\x8a\x67')
self.assertEqual(token.symbol, 'QWARK')
token = tokens.token_by_chain_address(1, b"\x7f\xc6\x65\x00\xc8\x4a\x76\xad\x7e\x9c\x93\x43\x7b\xfc\x5a\xc3\x3e\x2d\xda\xe9")
self.assertEqual(token.symbol, 'AAVE')
# invalid adress, invalid chain
token = tokens.token_by_chain_address(999, b'\x00\xFF')
self.assertIs(token, tokens.UNKNOWN_TOKEN)
self.assertEqual(token.symbol, 'Wei UNKN')
self.assertEqual(token.decimals, 0)
self.assertIs(token, None)
self.assertEqual(tokens.UNKNOWN_TOKEN.symbol, 'Wei UNKN')
self.assertEqual(tokens.UNKNOWN_TOKEN.decimals, 0)
if __name__ == '__main__':

View File

@ -26,16 +26,24 @@ class AssertRaisesContext:
class TestCase:
def __init__(self) -> None:
self.__equality_functions = {}
def fail(self, msg=''):
ensure(False, msg)
def addTypeEqualityFunc(self, typeobj, function):
ensure(callable(function))
self.__equality_functions[typeobj.__name__] = function
def assertEqual(self, x, y, msg=''):
if not msg:
msg = f"{repr(x)} vs (expected) {repr(y)}"
if x.__class__ == y.__class__ and x.__class__.__name__ == "Msg":
self.assertMessageEqual(x, y)
elif x.__class__.__name__ in self.__equality_functions:
self.__equality_functions[x.__class__.__name__](x, y, msg)
else:
ensure(x == y, msg)
@ -156,21 +164,32 @@ class TestCase:
self.assertIsInstance(a, b.__class__, msg)
self.assertEqual(a.__dict__, b.__dict__, msg)
def assertDictEqual(self, x, y):
self.assertEqual(
len(x),
len(y),
f"Dict lengths not equal - {len(x)} vs {len(y)}"
)
for key in x:
self.assertIn(
key,
y,
f"Key {key} not found in second dict."
)
self.assertEqual(
x[key],
y[key],
f"At key {key} expected {x[key]}, found {y[key]}"
)
def assertMessageEqual(self, x, y):
self.assertEqual(
x.MESSAGE_NAME,
y.MESSAGE_NAME,
f"Expected {x.MESSAGE_NAME}, found {y.MESSAGE_NAME}"
)
xdict = x.__dict__
ydict = y.__dict__
for key in xdict:
self.assertTrue(key in ydict)
self.assertEqual(
xdict[key],
ydict[key],
f"At {x.MESSAGE_NAME}.{key} expected {xdict[key]}, found {ydict[key]}"
)
self.assertDictEqual(x.__dict__, y.__dict__)
def skip(msg):