diff --git a/core/.changelog.d/1794.added b/core/.changelog.d/1794.added new file mode 100644 index 000000000..d2a9e8bbc --- /dev/null +++ b/core/.changelog.d/1794.added @@ -0,0 +1 @@ +Full type-checking for Ethereum app diff --git a/core/Makefile b/core/Makefile index 26b04d51b..babe858f4 100644 --- a/core/Makefile +++ b/core/Makefile @@ -109,6 +109,7 @@ mypy: src/main.py \ src/apps/bitcoin \ src/apps/cardano \ + src/apps/ethereum \ src/apps/management \ src/apps/misc \ src/apps/webauthn \ diff --git a/core/src/apps/ethereum/address.py b/core/src/apps/ethereum/address.py index e11fecd44..d4fa1cde0 100644 --- a/core/src/apps/ethereum/address.py +++ b/core/src/apps/ethereum/address.py @@ -1,35 +1,42 @@ -from ubinascii import unhexlify +from ubinascii import hexlify, unhexlify from trezor import wire +if False: + from .networks import NetworkInfo -def address_from_bytes(address_bytes: bytes, network=None) -> str: + +def address_from_bytes(address_bytes: bytes, network: NetworkInfo | None = None) -> str: """ Converts address in bytes to a checksummed string as defined in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md """ - from ubinascii import hexlify from trezor.crypto.hashlib import sha3_256 - rskip60 = network is not None and network.rskip60 + if network is not None and network.rskip60: + prefix = str(network.chain_id) + "0x" + else: + prefix = "" - hx = hexlify(address_bytes).decode() + address_hex = hexlify(address_bytes).decode() + digest = sha3_256((prefix + address_hex).encode(), keccak=True).digest() - prefix = str(network.chain_id) + "0x" if rskip60 else "" - hs = sha3_256(prefix + hx, keccak=True).digest() - h = "" + def maybe_upper(i: int) -> str: + """Uppercase i-th letter only if the corresponding nibble has high bit set.""" + digest_byte = digest[i // 2] + hex_letter = address_hex[i] + if i % 2 == 0: + # even letter -> high nibble + bit = 0x80 + else: + # odd letter -> low nibble + bit = 0x08 + if digest_byte & bit: + return hex_letter.upper() + else: + return hex_letter - for i in range(20): - l = hx[i * 2] - if hs[i] & 0x80 and l >= "a" and l <= "f": - l = l.upper() - h += l - l = hx[i * 2 + 1] - if hs[i] & 0x08 and l >= "a" and l <= "f": - l = l.upper() - h += l - - return "0x" + h + return "0x" + "".join(maybe_upper(i) for i in range(len(address_hex))) def bytes_from_address(address: str) -> bytes: diff --git a/core/src/apps/ethereum/get_address.py b/core/src/apps/ethereum/get_address.py index b03dd36b9..b12f60bd1 100644 --- a/core/src/apps/ethereum/get_address.py +++ b/core/src/apps/ethereum/get_address.py @@ -9,9 +9,17 @@ from . import networks from .address import address_from_bytes from .keychain import PATTERNS_ADDRESS, with_keychain_from_path +if False: + from trezor.messages import EthereumGetAddress + from trezor.wire import Context + + from apps.common.keychain import Keychain + @with_keychain_from_path(*PATTERNS_ADDRESS) -async def get_address(ctx, msg, keychain): +async def get_address( + ctx: Context, msg: EthereumGetAddress, keychain: Keychain +) -> EthereumAddress: await paths.validate_path(ctx, keychain, msg.address_n) node = keychain.derive(msg.address_n) diff --git a/core/src/apps/ethereum/get_public_key.py b/core/src/apps/ethereum/get_public_key.py index 4f8e04fde..b62110e7f 100644 --- a/core/src/apps/ethereum/get_public_key.py +++ b/core/src/apps/ethereum/get_public_key.py @@ -7,9 +7,17 @@ from apps.common import coins, paths from .keychain import with_keychain_from_path +if False: + from trezor.messages import EthereumGetPublicKey + from trezor.wire import Context + + from apps.common.keychain import Keychain + @with_keychain_from_path(paths.PATTERN_BIP44_PUBKEY) -async def get_public_key(ctx, msg, keychain): +async def get_public_key( + ctx: Context, msg: EthereumGetPublicKey, keychain: Keychain +) -> EthereumPublicKey: await paths.validate_path(ctx, keychain, msg.address_n) node = keychain.derive(msg.address_n) diff --git a/core/src/apps/ethereum/keychain.py b/core/src/apps/ethereum/keychain.py index e1604c28b..010438286 100644 --- a/core/src/apps/ethereum/keychain.py +++ b/core/src/apps/ethereum/keychain.py @@ -6,17 +6,31 @@ from apps.common.keychain import get_keychain from . import CURVE, networks if False: - from typing import Callable, Iterable - from typing_extensions import Protocol + from typing import Callable, Iterable, TypeVar, Union - from protobuf import MessageType - - from trezor.messages import EthereumSignTx + from trezor.messages import ( + EthereumGetAddress, + EthereumGetPublicKey, + EthereumSignMessage, + EthereumSignTx, + EthereumSignTxEIP1559, + ) from apps.common.keychain import MsgOut, Handler, HandlerWithKeychain - class MsgWithAddressN(MessageType, Protocol): - address_n: paths.Bip32Path + EthereumMessages = Union[ + EthereumGetAddress, + EthereumGetPublicKey, + EthereumSignTx, + EthereumSignMessage, + ] + MsgIn = TypeVar("MsgIn", bound=EthereumMessages) + + EthereumSignTxAny = Union[ + EthereumSignTx, + EthereumSignTxEIP1559, + ] + MsgInChainId = TypeVar("MsgInChainId", bound=EthereumSignTxAny) # We believe Ethereum should use 44'/60'/a' for everything, because it is @@ -47,13 +61,9 @@ def _schemas_from_address_n( def with_keychain_from_path( *patterns: str, -) -> Callable[ - [HandlerWithKeychain[MsgWithAddressN, MsgOut]], Handler[MsgWithAddressN, MsgOut] -]: - def decorator( - func: HandlerWithKeychain[MsgWithAddressN, MsgOut] - ) -> Handler[MsgWithAddressN, MsgOut]: - async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut: +) -> Callable[[HandlerWithKeychain[MsgIn, MsgOut]], Handler[MsgIn, MsgOut]]: + def decorator(func: HandlerWithKeychain[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: + async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: schemas = _schemas_from_address_n(patterns, msg.address_n) keychain = await get_keychain(ctx, CURVE, schemas) with keychain: @@ -64,11 +74,9 @@ def with_keychain_from_path( return decorator -def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]: - if msg.chain_id is None: - return _schemas_from_address_n(PATTERNS_ADDRESS, msg.address_n) - +def _schemas_from_chain_id(msg: EthereumSignTxAny) -> Iterable[paths.PathSchema]: info = networks.by_chain_id(msg.chain_id) + slip44_id: tuple[int, ...] if info is None: # allow Ethereum or testnet paths for unknown networks slip44_id = (60, 1) @@ -85,10 +93,10 @@ def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]: def with_keychain_from_chain_id( - func: HandlerWithKeychain[EthereumSignTx, MsgOut] -) -> Handler[EthereumSignTx, MsgOut]: + func: HandlerWithKeychain[MsgInChainId, MsgOut] +) -> Handler[MsgInChainId, MsgOut]: # this is only for SignTx, and only PATTERN_ADDRESS is allowed - async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut: + async def wrapper(ctx: wire.Context, msg: MsgInChainId) -> MsgOut: schemas = _schemas_from_chain_id(msg) keychain = await get_keychain(ctx, CURVE, schemas) with keychain: diff --git a/core/src/apps/ethereum/layout.py b/core/src/apps/ethereum/layout.py index c1fba6d42..98b62eabd 100644 --- a/core/src/apps/ethereum/layout.py +++ b/core/src/apps/ethereum/layout.py @@ -14,13 +14,24 @@ from trezor.ui.layouts.tt.altcoin import confirm_total_ethereum from . import networks, tokens from .address import address_from_bytes +if False: + from typing import Awaitable -async def require_confirm_tx(ctx, to_bytes, value, chain_id, token=None): + from trezor.wire import Context + + +def require_confirm_tx( + ctx: Context, + to_bytes: bytes, + value: int, + chain_id: int, + token: tokens.TokenInfo | None = None, +) -> Awaitable[None]: if to_bytes: to_str = address_from_bytes(to_bytes, networks.by_chain_id(chain_id)) else: to_str = "new contract?" - await confirm_output( + return confirm_output( ctx, address=to_str, amount=format_ethereum_amount(value, token, chain_id), @@ -30,10 +41,15 @@ async def require_confirm_tx(ctx, to_bytes, value, chain_id, token=None): ) -async def require_confirm_fee( - ctx, spending, gas_price, gas_limit, chain_id, token=None -): - await confirm_total_ethereum( +def require_confirm_fee( + ctx: Context, + spending: int, + gas_price: int, + gas_limit: int, + chain_id: int, + token: tokens.TokenInfo | None = None, +) -> Awaitable[None]: + return confirm_total_ethereum( ctx, format_ethereum_amount(spending, token, chain_id), format_ethereum_amount(gas_price, None, chain_id), @@ -42,8 +58,8 @@ async def require_confirm_fee( async def require_confirm_eip1559_fee( - ctx, max_priority_fee, max_gas_fee, gas_limit, chain_id -): + ctx: Context, max_priority_fee: int, max_gas_fee: int, gas_limit: int, chain_id: int +) -> None: await confirm_amount( ctx, title="Confirm fee", @@ -64,9 +80,11 @@ async def require_confirm_eip1559_fee( ) -async def require_confirm_unknown_token(ctx, address_bytes): +def require_confirm_unknown_token( + ctx: Context, address_bytes: bytes +) -> Awaitable[None]: contract_address_hex = "0x" + hexlify(address_bytes).decode() - await confirm_address( + return confirm_address( ctx, "Unknown token", contract_address_hex, @@ -77,8 +95,8 @@ async def require_confirm_unknown_token(ctx, address_bytes): ) -async def require_confirm_data(ctx, data, data_total): - await confirm_blob( +def require_confirm_data(ctx: Context, data: bytes, data_total: int) -> Awaitable[None]: + return confirm_blob( ctx, "confirm_data", title="Confirm data", @@ -88,13 +106,6 @@ async def require_confirm_data(ctx, data, data_total): ) -def format_ethereum_amount(value: int, token, chain_id: int): - if token is tokens.UNKNOWN_TOKEN: - suffix = "Wei UNKN" - decimals = 0 - elif token: - suffix = token[2] - decimals = token[3] def format_ethereum_amount( value: int, token: tokens.TokenInfo | None, chain_id: int ) -> str: diff --git a/core/src/apps/ethereum/networks.py b/core/src/apps/ethereum/networks.py index 930e926e5..f17a087d9 100644 --- a/core/src/apps/ethereum/networks.py +++ b/core/src/apps/ethereum/networks.py @@ -1,8 +1,6 @@ # generated from networks.py.mako # do not edit manually! -from micropython import const - from apps.common.paths import HARDENED if False: diff --git a/core/src/apps/ethereum/networks.py.mako b/core/src/apps/ethereum/networks.py.mako index 53bcbd8cd..56bfef680 100644 --- a/core/src/apps/ethereum/networks.py.mako +++ b/core/src/apps/ethereum/networks.py.mako @@ -1,8 +1,6 @@ # generated from networks.py.mako # do not edit manually! -from micropython import const - from apps.common.paths import HARDENED if False: diff --git a/core/src/apps/ethereum/sign_message.py b/core/src/apps/ethereum/sign_message.py index c04c87a03..a7afcbd65 100644 --- a/core/src/apps/ethereum/sign_message.py +++ b/core/src/apps/ethereum/sign_message.py @@ -10,18 +10,26 @@ from apps.common.signverify import decode_message from . import address from .keychain import PATTERNS_ADDRESS, with_keychain_from_path +if False: + from trezor.messages import EthereumSignMessage + from trezor.wire import Context -def message_digest(message): + from apps.common.keychain import Keychain + + +def message_digest(message: bytes) -> bytes: h = HashWriter(sha3_256(keccak=True)) - signed_message_header = "\x19Ethereum Signed Message:\n" + signed_message_header = b"\x19Ethereum Signed Message:\n" h.extend(signed_message_header) - h.extend(str(len(message))) + h.extend(str(len(message)).encode()) h.extend(message) return h.get_digest() @with_keychain_from_path(*PATTERNS_ADDRESS) -async def sign_message(ctx, msg, keychain): +async def sign_message( + ctx: Context, msg: EthereumSignMessage, keychain: Keychain +) -> EthereumMessageSignature: await paths.validate_path(ctx, keychain, msg.address_n) await confirm_signverify(ctx, "ETH", decode_message(msg.message)) diff --git a/core/src/apps/ethereum/sign_tx.py b/core/src/apps/ethereum/sign_tx.py index d66bcf612..05f78a0b1 100644 --- a/core/src/apps/ethereum/sign_tx.py +++ b/core/src/apps/ethereum/sign_tx.py @@ -16,6 +16,14 @@ from .layout import ( require_confirm_unknown_token, ) +if False: + from typing import Tuple + + from apps.common.keychain import Keychain + + from .keychain import EthereumSignTxAny + + # Maximum chain_id which returns the full signature_v (which must fit into an uint32). # chain_ids larger than this will only return one bit and the caller must recalculate # the full value: v = 2 * chain_id + 35 + v_bit @@ -23,9 +31,9 @@ MAX_CHAIN_ID = (0xFFFF_FFFF - 36) // 2 @with_keychain_from_chain_id -async def sign_tx(ctx, msg, keychain): - msg = sanitize(msg) - +async def sign_tx( + ctx: wire.Context, msg: EthereumSignTx, keychain: Keychain +) -> EthereumTxRequest: check(msg) await paths.validate_path(ctx, keychain, msg.address_n) @@ -84,7 +92,9 @@ async def sign_tx(ctx, msg, keychain): return result -async def handle_erc20(ctx, msg): +async def handle_erc20( + ctx: wire.Context, msg: EthereumSignTxAny +) -> Tuple[tokens.TokenInfo | None, bytes, bytes, int]: token = None address_bytes = recipient = address.bytes_from_address(msg.to) value = int.from_bytes(msg.value, "big") @@ -111,7 +121,7 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int: if msg.tx_type is not None: length += rlp.length(msg.tx_type) - for item in ( + fields: Tuple[rlp.RLPItem, ...] = ( msg.nonce, msg.gas_price, msg.gas_limit, @@ -120,8 +130,10 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int: msg.chain_id, 0, 0, - ): - length += rlp.length(item) + ) + + for field in fields: + length += rlp.length(field) length += rlp.header_length(data_total, msg.data_initial_chunk) length += data_total @@ -129,7 +141,7 @@ def get_total_length(msg: EthereumSignTx, data_total: int) -> int: return length -async def send_request_chunk(ctx, data_left: int): +async def send_request_chunk(ctx: wire.Context, data_left: int) -> EthereumTxAck: # TODO: layoutProgress ? req = EthereumTxRequest() if data_left <= 1024: @@ -140,7 +152,9 @@ async def send_request_chunk(ctx, data_left: int): return await ctx.call(req, EthereumTxAck) -def sign_digest(msg: EthereumSignTx, keychain, digest): +def sign_digest( + msg: EthereumSignTx, keychain: Keychain, digest: bytes +) -> EthereumTxRequest: node = keychain.derive(msg.address_n) signature = secp256k1.sign( node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM @@ -159,7 +173,7 @@ def sign_digest(msg: EthereumSignTx, keychain, digest): return req -def check(msg: EthereumSignTx): +def check(msg: EthereumSignTx) -> None: if msg.tx_type not in [1, 6, None]: raise wire.DataError("tx_type out of bounds") @@ -170,7 +184,7 @@ def check(msg: EthereumSignTx): raise wire.DataError("Safety check failed") -def check_data(msg: EthereumSignTx): +def check_data(msg: EthereumSignTxAny) -> None: if msg.data_length > 0: if not msg.data_initial_chunk: raise wire.DataError("Data length provided, but no initial chunk") @@ -183,15 +197,13 @@ def check_data(msg: EthereumSignTx): def check_gas(msg: EthereumSignTx) -> bool: - if msg.gas_price is None or msg.gas_limit is None: - return False if len(msg.gas_price) + len(msg.gas_limit) > 30: # sanity check that fee doesn't overflow return False return True -def check_to(msg: EthereumTxRequest) -> bool: +def check_to(msg: EthereumSignTxAny) -> bool: if msg.to == "": if msg.data_length == 0: # sending transaction to address 0 (contract creation) without a data field @@ -200,17 +212,3 @@ def check_to(msg: EthereumTxRequest) -> bool: if len(msg.to) not in (40, 42): return False return True - - -def sanitize(msg): - if msg.value is None: - msg.value = b"" - if msg.data_initial_chunk is None: - msg.data_initial_chunk = b"" - if msg.data_length is None: - msg.data_length = 0 - if msg.to is None: - msg.to = "" - if msg.nonce is None: - msg.nonce = b"" - return msg diff --git a/core/src/apps/ethereum/sign_tx_eip1559.py b/core/src/apps/ethereum/sign_tx_eip1559.py index b2533757c..f08437360 100644 --- a/core/src/apps/ethereum/sign_tx_eip1559.py +++ b/core/src/apps/ethereum/sign_tx_eip1559.py @@ -2,7 +2,7 @@ from trezor import wire from trezor.crypto import rlp from trezor.crypto.curve import secp256k1 from trezor.crypto.hashlib import sha3_256 -from trezor.messages import EthereumAccessList, EthereumSignTxEIP1559, EthereumTxRequest +from trezor.messages import EthereumAccessList, EthereumTxRequest from trezor.utils import HashWriter from apps.common import paths @@ -14,7 +14,14 @@ from .layout import ( require_confirm_eip1559_fee, require_confirm_tx, ) -from .sign_tx import check_data, check_to, handle_erc20, sanitize, send_request_chunk +from .sign_tx import check_data, check_to, handle_erc20, send_request_chunk + +if False: + from typing import Tuple + + from trezor.messages import EthereumSignTxEIP1559 + + from apps.common.keychain import Keychain TX_TYPE = 2 @@ -45,9 +52,9 @@ def write_access_list(w: HashWriter, access_list: list[EthereumAccessList]) -> N @with_keychain_from_chain_id -async def sign_tx_eip1559(ctx, msg, keychain): - msg = sanitize(msg) - +async def sign_tx_eip1559( + ctx: wire.Context, msg: EthereumSignTxEIP1559, keychain: Keychain +) -> EthereumTxRequest: check(msg) await paths.validate_path(ctx, keychain, msg.address_n) @@ -81,7 +88,7 @@ async def sign_tx_eip1559(ctx, msg, keychain): rlp.write_header(sha, total_length, rlp.LIST_HEADER_BYTE) - for field in ( + fields: Tuple[rlp.RLPItem, ...] = ( msg.chain_id, msg.nonce, msg.max_priority_fee, @@ -89,7 +96,8 @@ async def sign_tx_eip1559(ctx, msg, keychain): msg.gas_limit, address_bytes, msg.value, - ): + ) + for field in fields: rlp.write(sha, field) if data_left == 0: @@ -114,7 +122,7 @@ async def sign_tx_eip1559(ctx, msg, keychain): def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int: length = 0 - for item in ( + fields: Tuple[rlp.RLPItem, ...] = ( msg.nonce, msg.gas_limit, address.bytes_from_address(msg.to), @@ -122,8 +130,9 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int: msg.chain_id, msg.max_gas_fee, msg.max_priority_fee, - ): - length += rlp.length(item) + ) + for field in fields: + length += rlp.length(field) length += rlp.header_length(data_total, msg.data_initial_chunk) length += data_total @@ -133,7 +142,9 @@ def get_total_length(msg: EthereumSignTxEIP1559, data_total: int) -> int: return length -def sign_digest(msg: EthereumSignTxEIP1559, keychain, digest): +def sign_digest( + msg: EthereumSignTxEIP1559, keychain: Keychain, digest: bytes +) -> EthereumTxRequest: node = keychain.derive(msg.address_n) signature = secp256k1.sign( node.private_key(), digest, False, secp256k1.CANONICAL_SIG_ETHEREUM @@ -147,7 +158,7 @@ def sign_digest(msg: EthereumSignTxEIP1559, keychain, digest): return req -def check(msg: EthereumSignTxEIP1559): +def check(msg: EthereumSignTxEIP1559) -> None: check_data(msg) if not check_to(msg): diff --git a/core/src/apps/ethereum/verify_message.py b/core/src/apps/ethereum/verify_message.py index f69ccc00a..be06f9bd3 100644 --- a/core/src/apps/ethereum/verify_message.py +++ b/core/src/apps/ethereum/verify_message.py @@ -9,8 +9,12 @@ from apps.common.signverify import decode_message from .address import address_from_bytes, bytes_from_address from .sign_message import message_digest +if False: + from trezor.messages import EthereumVerifyMessage + from trezor.wire import Context -async def verify_message(ctx, msg): + +async def verify_message(ctx: Context, msg: EthereumVerifyMessage) -> Success: digest = message_digest(msg.message) if len(msg.signature) != 65: raise wire.DataError("Invalid signature") diff --git a/core/tests/test_apps.ethereum.address.py b/core/tests/test_apps.ethereum.address.py index 0db45db87..7fdf753d4 100644 --- a/core/tests/test_apps.ethereum.address.py +++ b/core/tests/test_apps.ethereum.address.py @@ -22,10 +22,9 @@ class TestEthereumGetAddress(unittest.TestCase): '0xD1220A0cf47c7B9Be7A2E6BA89F429762e7b9aDb', ] for s in eip55: - s = s[2:] - b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) + b = unhexlify(s[2:]) h = address_from_bytes(b) - self.assertEqual(h, '0x' + s) + self.assertEqual(h, s) def test_address_from_bytes_rskip60(self): # https://github.com/rsksmart/RSKIPs/blob/master/IPs/RSKIP60.md @@ -43,16 +42,14 @@ class TestEthereumGetAddress(unittest.TestCase): ] n = NetworkInfo(chain_id=30, slip44=1, shortcut='T', name='T', rskip60=True) for s in rskip60_chain_30: - s = s[2:] - b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) + b = unhexlify(s[2:]) h = address_from_bytes(b, n) - self.assertEqual(h, '0x' + s) + self.assertEqual(h, s) n.chain_id = 31 for s in rskip60_chain_31: - s = s[2:] - b = bytes([int(s[i:i + 2], 16) for i in range(0, len(s), 2)]) + b = unhexlify(s[2:]) h = address_from_bytes(b, n) - self.assertEqual(h, '0x' + s) + self.assertEqual(h, s) if __name__ == '__main__': diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 15621b296..536885b03 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -118,6 +118,8 @@ class TestEthereumKeychain(unittest.TestCase): EthereumSignTx( address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED], chain_id=1, + gas_price=b"", + gas_limit=b"", ), ) ) @@ -128,6 +130,8 @@ class TestEthereumKeychain(unittest.TestCase): EthereumSignTx( address_n=[44 | HARDENED, 61 | HARDENED, 0 | HARDENED], chain_id=61, + gas_price=b"", + gas_limit=b"", ), ) ) @@ -140,6 +144,8 @@ class TestEthereumKeychain(unittest.TestCase): EthereumSignTx( address_n=[44 | HARDENED, 60 | HARDENED, 0 | HARDENED], chain_id=61, + gas_price=b"", + gas_limit=b"", ), ) ) @@ -151,6 +157,8 @@ class TestEthereumKeychain(unittest.TestCase): EthereumSignTx( address_n=[44 | HARDENED, 61 | HARDENED, 0 | HARDENED], chain_id=2, + gas_price=b"", + gas_limit=b"", ), ) ) diff --git a/core/tests/test_apps.ethereum.layout.py b/core/tests/test_apps.ethereum.layout.py index e720bec52..8ba20f723 100644 --- a/core/tests/test_apps.ethereum.layout.py +++ b/core/tests/test_apps.ethereum.layout.py @@ -6,7 +6,7 @@ if not utils.BITCOIN_ONLY: @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") -class TestEthereumLayout(unittest.TestCase): +class TestFormatEthereumAmount(unittest.TestCase): def test_format(self): text = format_ethereum_amount(1, None, 1) @@ -60,12 +60,14 @@ class TestEthereumLayout(unittest.TestCase): text = format_ethereum_amount(1000000000000000001, None, 31) self.assertEqual(text, '1.000000000000000001 tRBTC') + def test_unknown_chain(self): # unknown chain text = format_ethereum_amount(1, None, 9999) self.assertEqual(text, '1 Wei UNKN') text = format_ethereum_amount(10000000000000000001, None, 9999) self.assertEqual(text, '10.000000000000000001 UNKN') + def test_tokens(self): # tokens with low decimal values # USDC has 6 decimals usdc_token = token_by_chain_address(1, unhexlify("a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48")) @@ -87,6 +89,16 @@ class TestEthereumLayout(unittest.TestCase): text = format_ethereum_amount(11, ico_token, 1) self.assertEqual(text, '0.0000000011 ICO') + def test_unknown_token(self): + unknown_token = token_by_chain_address(1, b"hello") + text = format_ethereum_amount(1, unknown_token, 1) + self.assertEqual(text, '1 Wei UNKN') + text = format_ethereum_amount(0, unknown_token, 1) + self.assertEqual(text, '0 Wei UNKN') + # unknown token has 0 decimals so is always wei + text = format_ethereum_amount(1000000000000000000, unknown_token, 1) + self.assertEqual(text, '1000000000000000000 Wei UNKN') + if __name__ == '__main__': unittest.main()