From f10588b2b6d813bcab90db9fb244f66aba18db3e Mon Sep 17 00:00:00 2001 From: matejcik Date: Tue, 29 Apr 2025 12:11:57 +0200 Subject: [PATCH] tests(core): split up definition unit tests and add Solana cross-parseability check --- core/tests/ethereum_common.py | 40 ++++- core/tests/test_apps.common.definitions.py | 142 +++++++++++++++++ core/tests/test_apps.ethereum.definitions.py | 152 +++---------------- core/tests/test_apps.ethereum.helpers.py | 6 +- core/tests/test_apps.ethereum.keychain.py | 26 ++-- core/tests/test_apps.ethereum.layout.py | 8 +- 6 files changed, 216 insertions(+), 158 deletions(-) create mode 100644 core/tests/test_apps.common.definitions.py diff --git a/core/tests/ethereum_common.py b/core/tests/ethereum_common.py index f5053957ef..407e0baad4 100644 --- a/core/tests/ethereum_common.py +++ b/core/tests/ethereum_common.py @@ -9,7 +9,7 @@ from trezor.enums import DefinitionType PRIVATE_KEYS_DEV = [byte * 32 for byte in (b"\xdd", b"\xde", b"\xdf")] -def make_network( +def make_eth_network( chain_id: int = 0, slip44: int = 0, symbol: str = "FAKE", @@ -23,7 +23,7 @@ def make_network( ) -def make_token( +def make_eth_token( symbol: str = "FAKE", decimals: int = 18, address: bytes = b"", @@ -39,13 +39,24 @@ def make_token( ) +def make_solana_token( + symbol: str = "FAKE", + mint: bytes = b"\x00" * 32, + name: str = "Fake token", +) -> messages.SolanaTokenInfo: + return messages.SolanaTokenInfo(symbol=symbol, mint=mint, name=name) + + def make_payload( prefix: bytes = b"trzd1", data_type: DefinitionType = DefinitionType.ETHEREUM_NETWORK, timestamp: int = 0xFFFF_FFFF, message: ( - messages.EthereumNetworkInfo | messages.EthereumTokenInfo | bytes - ) = make_network(), + messages.EthereumNetworkInfo + | messages.EthereumTokenInfo + | messages.SolanaTokenInfo + | bytes + ) = make_eth_network(), ) -> bytes: payload = prefix payload += data_type.to_bytes(1, "little") @@ -97,7 +108,7 @@ def sign_payload( return merkle_proof, sigmask_byte + signature -def encode_network( +def encode_eth_network( network: messages.EthereumNetworkInfo | None = None, chain_id: int = 0, slip44: int = 0, @@ -105,13 +116,13 @@ def encode_network( name: str = "Fake network", ) -> bytes: if network is None: - network = make_network(chain_id, slip44, symbol, name) + network = make_eth_network(chain_id, slip44, symbol, name) payload = make_payload(data_type=DefinitionType.ETHEREUM_NETWORK, message=network) proof, signature = sign_payload(payload, []) return payload + proof + signature -def encode_token( +def encode_eth_token( token: messages.EthereumTokenInfo | None = None, symbol: str = "FAKE", decimals: int = 18, @@ -120,7 +131,20 @@ def encode_token( name: str = "Fake token", ) -> bytes: if token is None: - token = make_token(symbol, decimals, address, chain_id, name) + token = make_eth_token(symbol, decimals, address, chain_id, name) payload = make_payload(data_type=DefinitionType.ETHEREUM_TOKEN, message=token) proof, signature = sign_payload(payload, []) return payload + proof + signature + + +def encode_solana_token( + token: messages.SolanaTokenInfo | None = None, + symbol: str = "FAKE", + mint: bytes = b"\x00" * 32, + name: str = "Fake token", +) -> bytes: + if token is None: + token = make_solana_token(symbol, mint, name) + payload = make_payload(data_type=DefinitionType.SOLANA_TOKEN, message=token) + proof, signature = sign_payload(payload, []) + return payload + proof + signature diff --git a/core/tests/test_apps.common.definitions.py b/core/tests/test_apps.common.definitions.py new file mode 100644 index 0000000000..9dcda2ce61 --- /dev/null +++ b/core/tests/test_apps.common.definitions.py @@ -0,0 +1,142 @@ +# flake8: noqa: F403,F405 +from common import * # isort:skip + +import typing as t +import unittest + +from trezor import utils, wire + +if not utils.BITCOIN_ONLY: + + from ethereum_common import * + from trezor.enums import DefinitionType + from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo, SolanaTokenInfo + + from apps.common.definitions import decode_definition + + +@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") +class TestDecodeDefinition(unittest.TestCase): + def test_short_message(self): + for message in (EthereumNetworkInfo, EthereumTokenInfo, SolanaTokenInfo): + with self.assertRaises(wire.DataError): + decode_definition(b"\x00", message) + + # successful decode network + def test_network_definition(self): + network = make_eth_network( + chain_id=42, slip44=69, symbol="FAKE", name="Fakenet" + ) + encoded = encode_eth_network(network) + try: + self.assertEqual(decode_definition(encoded, EthereumNetworkInfo), network) + except Exception as e: + print(e.message) + + # successful decode token + def test_token_definition(self): + token = make_eth_token("FAKE", decimals=33, address=b"abcd" * 5, chain_id=42) + encoded = encode_eth_token(token) + self.assertEqual(decode_definition(encoded, EthereumTokenInfo), token) + + # successful decode solana token + def test_solana_token_definition(self): + token = make_solana_token("FAKE", mint=b"abcd" * 5, name="Fakenet") + encoded = encode_solana_token(token) + self.assertEqual(decode_definition(encoded, SolanaTokenInfo), token) + + def assertFailed(self, data: bytes) -> None: + with self.assertRaises(wire.DataError): + decode_definition(data, EthereumNetworkInfo) + + def test_mangled_signature(self): + payload = make_payload() + proof, signature = sign_payload(payload, []) + bad_signature = signature[:-1] + b"\xff" + self.assertFailed(payload + proof + bad_signature) + + def test_not_enough_signatures(self): + payload = make_payload() + proof, signature = sign_payload(payload, [], threshold=1) + self.assertFailed(payload + proof + signature) + + def test_missing_signature(self): + payload = make_payload() + proof, _ = sign_payload(payload, []) + self.assertFailed(payload + proof) + + def test_mangled_payload(self): + payload = make_payload() + proof, signature = sign_payload(payload, []) + bad_payload = payload[:-1] + b"\xff" + self.assertFailed(bad_payload + proof + signature) + + def test_proof_length_mismatch(self): + payload = make_payload() + _, signature = sign_payload(payload, []) + bad_proof = b"\x01" + self.assertFailed(payload + bad_proof + signature) + + def test_bad_proof(self): + payload = make_payload() + proof, signature = sign_payload(payload, [sha256(b"x").digest()]) + bad_proof = proof[:-1] + b"\xff" + self.assertFailed(payload + bad_proof + signature) + + def test_trimmed_proof(self): + payload = make_payload() + proof, signature = sign_payload(payload, []) + bad_proof = proof[:-1] + self.assertFailed(payload + bad_proof + signature) + + def test_bad_prefix(self): + payload = make_payload(prefix=b"trzd2") + proof, signature = sign_payload(payload, []) + self.assertFailed(payload + proof + signature) + + def test_bad_type(self): + payload = make_payload( + data_type=DefinitionType.ETHEREUM_TOKEN, message=make_eth_token() + ) + proof, signature = sign_payload(payload, []) + self.assertFailed(payload + proof + signature) + + def test_outdated(self): + payload = make_payload(timestamp=0) + proof, signature = sign_payload(payload, []) + self.assertFailed(payload + proof + signature) + + def test_malformed_protobuf(self): + payload = make_payload(message=b"\x00") + proof, signature = sign_payload(payload, []) + self.assertFailed(payload + proof + signature) + + def test_protobuf_mismatch(self): + variants = ( + (DefinitionType.ETHEREUM_NETWORK, EthereumTokenInfo, make_eth_network()), + (DefinitionType.ETHEREUM_TOKEN, EthereumNetworkInfo, make_eth_token()), + (DefinitionType.SOLANA_TOKEN, SolanaTokenInfo, make_solana_token()), + ) + for variant in variants: + ( + encode_type, + decode_type, + _, + ) = variant + for other in variants: + if other is variant: + continue + _, _, message = other + payload = make_payload(data_type=encode_type, message=message) + proof, signature = sign_payload(payload, []) + with self.assertRaises(wire.DataError): + decode_definition(payload + proof + signature, decode_type) + + def test_trailing_garbage(self): + payload = make_payload() + proof, signature = sign_payload(payload, []) + self.assertFailed(payload + proof + signature + b"\x00") + + +if __name__ == "__main__": + unittest.main() diff --git a/core/tests/test_apps.ethereum.definitions.py b/core/tests/test_apps.ethereum.definitions.py index 5ec5299af1..02ba5fb3ae 100644 --- a/core/tests/test_apps.ethereum.definitions.py +++ b/core/tests/test_apps.ethereum.definitions.py @@ -9,124 +9,14 @@ from trezor import utils, wire if not utils.BITCOIN_ONLY: from ethereum_common import * - from trezor.enums import DefinitionType from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo - from apps.common.definitions import decode_definition from apps.ethereum import networks, tokens from apps.ethereum.definitions import Definitions TETHER_ADDRESS = b"\xda\xc1\x7f\x95\x8d\x2e\xe5\x23\xa2\x20\x62\x06\x99\x45\x97\xc1\x3d\x83\x1e\xc7" -@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") -class TestDecodeDefinition(unittest.TestCase): - def test_short_message(self): - with self.assertRaises(wire.DataError): - decode_definition(b"\x00", EthereumNetworkInfo) - with self.assertRaises(wire.DataError): - decode_definition(b"\x00", EthereumTokenInfo) - - # successful decode network - def test_network_definition(self): - network = make_network(chain_id=42, slip44=69, symbol="FAKE", name="Fakenet") - encoded = encode_network(network) - try: - self.assertEqual(decode_definition(encoded, EthereumNetworkInfo), network) - except Exception as e: - print(e.message) - - # successful decode token - def test_token_definition(self): - token = make_token("FAKE", decimals=33, address=b"abcd" * 5, chain_id=42) - encoded = encode_token(token) - self.assertEqual(decode_definition(encoded, EthereumTokenInfo), token) - - def assertFailed(self, data: bytes) -> None: - with self.assertRaises(wire.DataError): - decode_definition(data, EthereumNetworkInfo) - - def test_mangled_signature(self): - payload = make_payload() - proof, signature = sign_payload(payload, []) - bad_signature = signature[:-1] + b"\xff" - self.assertFailed(payload + proof + bad_signature) - - def test_not_enough_signatures(self): - payload = make_payload() - proof, signature = sign_payload(payload, [], threshold=1) - self.assertFailed(payload + proof + signature) - - def test_missing_signature(self): - payload = make_payload() - proof, _ = sign_payload(payload, []) - self.assertFailed(payload + proof) - - def test_mangled_payload(self): - payload = make_payload() - proof, signature = sign_payload(payload, []) - bad_payload = payload[:-1] + b"\xff" - self.assertFailed(bad_payload + proof + signature) - - def test_proof_length_mismatch(self): - payload = make_payload() - _, signature = sign_payload(payload, []) - bad_proof = b"\x01" - self.assertFailed(payload + bad_proof + signature) - - def test_bad_proof(self): - payload = make_payload() - proof, signature = sign_payload(payload, [sha256(b"x").digest()]) - bad_proof = proof[:-1] + b"\xff" - self.assertFailed(payload + bad_proof + signature) - - def test_trimmed_proof(self): - payload = make_payload() - proof, signature = sign_payload(payload, []) - bad_proof = proof[:-1] - self.assertFailed(payload + bad_proof + signature) - - def test_bad_prefix(self): - payload = make_payload(prefix=b"trzd2") - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature) - - def test_bad_type(self): - payload = make_payload( - data_type=DefinitionType.ETHEREUM_TOKEN, message=make_token() - ) - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature) - - def test_outdated(self): - payload = make_payload(timestamp=0) - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature) - - def test_malformed_protobuf(self): - payload = make_payload(message=b"\x00") - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature) - - def test_protobuf_mismatch(self): - payload = make_payload( - data_type=DefinitionType.ETHEREUM_NETWORK, message=make_token() - ) - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature) - - payload = make_payload( - data_type=DefinitionType.ETHEREUM_TOKEN, message=make_network() - ) - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature) - - def test_trailing_garbage(self): - payload = make_payload() - proof, signature = sign_payload(payload, []) - self.assertFailed(payload + proof + signature + b"\x00") - - @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestEthereumDefinitions(unittest.TestCase): def assertUnknown(self, what: t.Any) -> None: @@ -173,71 +63,73 @@ class TestEthereumDefinitions(unittest.TestCase): self.assertUnknown(defs.get_token(b"\x00" * 20)) def test_external(self) -> None: - network = make_network(chain_id=42) - defs = Definitions.from_encoded(encode_network(network), None, chain_id=42) + network = make_eth_network(chain_id=42) + defs = Definitions.from_encoded(encode_eth_network(network), None, chain_id=42) self.assertEqual(defs.network, network) self.assertUnknown(defs.get_token(b"\x00" * 20)) - token = make_token(chain_id=42, address=b"\x00" * 20) + token = make_eth_token(chain_id=42, address=b"\x00" * 20) defs = Definitions.from_encoded( - encode_network(network), encode_token(token), chain_id=42 + encode_eth_network(network), encode_eth_token(token), chain_id=42 ) self.assertEqual(defs.network, network) self.assertEqual(defs.get_token(b"\x00" * 20), token) - token = make_token(chain_id=1, address=b"\x00" * 20) - defs = Definitions.from_encoded(None, encode_token(token), chain_id=1) + token = make_eth_token(chain_id=1, address=b"\x00" * 20) + defs = Definitions.from_encoded(None, encode_eth_token(token), chain_id=1) self.assertKnown(defs.network) self.assertEqual(defs.get_token(b"\x00" * 20), token) def test_external_token_mismatch(self) -> None: - network = make_network(chain_id=42) - token = make_token(chain_id=43, address=b"\x00" * 20) - defs = Definitions.from_encoded(encode_network(network), encode_token(token)) + network = make_eth_network(chain_id=42) + token = make_eth_token(chain_id=43, address=b"\x00" * 20) + defs = Definitions.from_encoded( + encode_eth_network(network), encode_eth_token(token) + ) self.assertUnknown(defs.get_token(b"\x00" * 20)) def test_external_chain_match(self) -> None: - network = make_network(chain_id=42) - token = make_token(chain_id=42, address=b"\x00" * 20) + network = make_eth_network(chain_id=42) + token = make_eth_token(chain_id=42, address=b"\x00" * 20) defs = Definitions.from_encoded( - encode_network(network), encode_token(token), chain_id=42 + encode_eth_network(network), encode_eth_token(token), chain_id=42 ) self.assertEqual(defs.network, network) self.assertEqual(defs.get_token(b"\x00" * 20), token) with self.assertRaises(wire.DataError): Definitions.from_encoded( - encode_network(network), encode_token(token), chain_id=333 + encode_eth_network(network), encode_eth_token(token), chain_id=333 ) def test_external_slip44_mismatch(self) -> None: - network = make_network(chain_id=42, slip44=1999) - token = make_token(chain_id=42, address=b"\x00" * 20) + network = make_eth_network(chain_id=42, slip44=1999) + token = make_eth_token(chain_id=42, address=b"\x00" * 20) defs = Definitions.from_encoded( - encode_network(network), encode_token(token), slip44=1999 + encode_eth_network(network), encode_eth_token(token), slip44=1999 ) self.assertEqual(defs.network, network) self.assertEqual(defs.get_token(b"\x00" * 20), token) with self.assertRaises(wire.DataError): Definitions.from_encoded( - encode_network(network), encode_token(token), slip44=333 + encode_eth_network(network), encode_eth_token(token), slip44=333 ) def test_ignore_encoded_network(self) -> None: # when network is builtin, ignore the encoded one - network = encode_network(chain_id=1, symbol="BAD") + network = encode_eth_network(chain_id=1, symbol="BAD") defs = Definitions.from_encoded(network, None, chain_id=1) self.assertNotEqual(defs.network, network) def test_ignore_encoded_token(self) -> None: # when token is builtin, ignore the encoded one - token = encode_token(chain_id=1, address=TETHER_ADDRESS, symbol="BAD") + token = encode_eth_token(chain_id=1, address=TETHER_ADDRESS, symbol="BAD") defs = Definitions.from_encoded(None, token, chain_id=1) self.assertNotEqual(defs.get_token(TETHER_ADDRESS), token) def test_ignore_with_no_match(self) -> None: - network = encode_network(chain_id=100_000, symbol="BAD") + network = encode_eth_network(chain_id=100_000, symbol="BAD") # smoke test: definition is accepted defs = Definitions.from_encoded(network, None, chain_id=100_000) self.assertKnown(defs.network) diff --git a/core/tests/test_apps.ethereum.helpers.py b/core/tests/test_apps.ethereum.helpers.py index 5c0bbdfda7..95de8cc44a 100644 --- a/core/tests/test_apps.ethereum.helpers.py +++ b/core/tests/test_apps.ethereum.helpers.py @@ -2,7 +2,7 @@ from common import * # isort:skip if not utils.BITCOIN_ONLY: - from ethereum_common import make_network + from ethereum_common import make_eth_network from apps.ethereum.helpers import address_from_bytes @@ -41,13 +41,13 @@ class TestEthereumGetAddress(unittest.TestCase): "0xd1220a0CF47c7B9Be7A2E6Ba89f429762E7b9adB", ] - n = make_network(chain_id=30) + n = make_eth_network(chain_id=30) for s in rskip60_chain_30: b = unhexlify(s[2:]) h = address_from_bytes(b, n) self.assertEqual(h, s) - n = make_network(chain_id=31) + n = make_eth_network(chain_id=31) for s in rskip60_chain_31: b = unhexlify(s[2:]) h = address_from_bytes(b, n) diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 3215aba267..669f2664c5 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -13,7 +13,7 @@ from apps.common.keychain import get_keychain from apps.common.paths import HARDENED if not utils.BITCOIN_ONLY: - from ethereum_common import encode_network, make_network + from ethereum_common import encode_eth_network, make_eth_network from trezor.messages import ( EthereumDefinitions, EthereumGetAddress, @@ -87,7 +87,7 @@ class TestEthereumKeychain(unittest.TestCase): def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n) - network = make_network(slip44=slip44) + network = make_eth_network(slip44=slip44) schemas = _schemas_from_network(PATTERNS_ADDRESS, network) return await_result(get_keychain(CURVE, schemas)) @@ -163,9 +163,9 @@ class TestEthereumKeychain(unittest.TestCase): # invalid network is ignored when there is a builtin (60, b"hello"), # valid network is ignored when there is a builtin - (60, encode_network(slip44=60, symbol=FORBIDDEN_SYMBOL)), + (60, encode_eth_network(slip44=60, symbol=FORBIDDEN_SYMBOL)), # valid network is accepted for unknown slip44 ids - (33333, encode_network(slip44=33333)), + (33333, encode_eth_network(slip44=33333)), ) for slip44, encoded_network in vectors_valid: @@ -182,7 +182,7 @@ class TestEthereumKeychain(unittest.TestCase): # invalid network is rejected (30000, b"hello"), # invalid network does not prove mismatched slip44 id - (30000, encode_network(slip44=666)), + (30000, encode_eth_network(slip44=666)), ) for slip44, encoded_network in vectors_invalid: @@ -262,25 +262,25 @@ class TestEthereumKeychain(unittest.TestCase): ( 1, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED], - encode_network(slip44=60, symbol=FORBIDDEN_SYMBOL), + encode_eth_network(slip44=60, symbol=FORBIDDEN_SYMBOL), ), # valid network is accepted for unknown chain ids ( 33333, [44 | HARDENED, 33333 | HARDENED, 0 | HARDENED], - encode_network(slip44=33333, chain_id=33333), + encode_eth_network(slip44=33333, chain_id=33333), ), # valid network is allowed to cross-sign for Ethereum slip44 ( 33333, [44 | HARDENED, 60 | HARDENED, 0 | HARDENED], - encode_network(slip44=33333, chain_id=33333), + encode_eth_network(slip44=33333, chain_id=33333), ), # valid network where slip44 and chain_id are different ( 44444, [44 | HARDENED, 33333 | HARDENED, 0 | HARDENED], - encode_network(slip44=33333, chain_id=44444), + encode_eth_network(slip44=33333, chain_id=44444), ), ) @@ -306,13 +306,13 @@ class TestEthereumKeychain(unittest.TestCase): ( 30000, [44 | HARDENED, 30000 | HARDENED, 0 | HARDENED], - encode_network(chain_id=30000, slip44=666), + encode_eth_network(chain_id=30000, slip44=666), ), # invalid network does not prove mismatched chain_id ( 30000, [44 | HARDENED, 30000 | HARDENED, 0 | HARDENED], - encode_network(chain_id=666, slip44=30000), + encode_eth_network(chain_id=666, slip44=30000), ), ) @@ -333,8 +333,8 @@ class TestEthereumKeychain(unittest.TestCase): ) def test_message_types(self) -> None: - network = make_network(symbol="Testing Network") - encoded_network = encode_network(network) + network = make_eth_network(symbol="Testing Network") + encoded_network = encode_eth_network(network) messages = ( EthereumSignTx( diff --git a/core/tests/test_apps.ethereum.layout.py b/core/tests/test_apps.ethereum.layout.py index b160ade4ed..5fba090a10 100644 --- a/core/tests/test_apps.ethereum.layout.py +++ b/core/tests/test_apps.ethereum.layout.py @@ -2,7 +2,7 @@ from common import * # isort:skip if not utils.BITCOIN_ONLY: - from ethereum_common import make_network, make_token + from ethereum_common import make_eth_network, make_eth_token from apps.ethereum import networks from apps.ethereum.helpers import format_ethereum_amount @@ -65,7 +65,7 @@ class TestFormatEthereumAmount(unittest.TestCase): self.assertEqual(text, "10.000000000000000001 ETH") def test_symbols(self): - fake_network = make_network(symbol="FAKE") + fake_network = make_eth_network(symbol="FAKE") text = format_ethereum_amount(1, None, fake_network) self.assertEqual(text, "1 Wei FAKE") text = format_ethereum_amount(1000000000000000000, None, fake_network) @@ -85,7 +85,7 @@ class TestFormatEthereumAmount(unittest.TestCase): def test_tokens(self): # tokens with low decimal values # USDC has 6 decimals - usdc_token = make_token(symbol="USDC", decimals=6) + usdc_token = make_eth_token(symbol="USDC", decimals=6) # when decimals < 10, should never display 'Wei' format text = format_ethereum_amount(1, usdc_token, ETH) self.assertEqual(text, "0.000001 USDC") @@ -93,7 +93,7 @@ class TestFormatEthereumAmount(unittest.TestCase): self.assertEqual(text, "0 USDC") # ICO has 10 decimals - ico_token = make_token(symbol="ICO", decimals=10) + ico_token = make_eth_token(symbol="ICO", decimals=10) text = format_ethereum_amount(1, ico_token, ETH) self.assertEqual(text, "1 Wei ICO") text = format_ethereum_amount(9, ico_token, ETH)