1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-07 15:18:08 +00:00
trezor-firmware/core/tests/test_apps.ethereum.definitions.py

435 lines
19 KiB
Python

from common import *
from trezor import wire
from ubinascii import hexlify # noqa: F401
if not utils.BITCOIN_ONLY:
import apps.ethereum.definitions as dfs
from apps.ethereum import networks
from ethereum_common import *
from trezor import protobuf
from trezor.enums import EthereumDefinitionType
from trezor.messages import (
EthereumEncodedDefinitions,
EthereumNetworkInfo,
EthereumTokenInfo,
EthereumGetAddress,
EthereumGetPublicKey,
EthereumSignMessage,
EthereumSignTx,
EthereumSignTxEIP1559,
EthereumSignTypedData,
EthereumVerifyMessage,
)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumDefinitionParser(unittest.TestCase):
    """Feeds hand-built byte strings to `dfs.EthereumDefinitionParser`."""

    def setUp(self):
        """Assemble a synthetic definition: prefix, payload, then proof and signed root."""
        # prefix fields
        self.format_version = b'trzd1' + b'\x00' * 3
        self.definition_type = b'\x01'
        self.data_version = b'\x00\x00\x00\x02'
        self.payload_length_in_bytes = b'\x00\x03'
        prefix_fields = (
            self.format_version,
            self.definition_type,
            self.data_version,
            self.payload_length_in_bytes,
        )
        self.prefix = b''.join(prefix_fields)

        # payload
        self.payload = b'\x00\x00\x04'  # optional length
        self.payload_with_prefix = self.prefix + self.payload

        # suffix - Merkle tree proof and signed root hash
        self.proof_length = b'\x01'
        self.proof = b'\x00' * 31 + b'\x06'
        self.signed_tree_root = b'\x00' * 63 + b'\x07'
        definition_parts = (
            self.payload_with_prefix,
            self.proof_length,
            self.proof,
            self.signed_tree_root,
        )
        self.definition = b''.join(definition_parts)

    def test_short_message(self):
        """A single byte of input must be rejected with `wire.DataError`."""
        with self.assertRaises(wire.DataError):
            dfs.EthereumDefinitionParser(b'\x00')

    def test_ok_message(self):
        """Every parsed attribute must match the bytes put together in `setUp`."""
        parser = dfs.EthereumDefinitionParser(self.definition)

        def as_int(raw):
            # the numeric fields are compared as big-endian integers
            return int.from_bytes(raw, 'big')

        # (value reported by the parser, value expected from the setUp fixture)
        checks = (
            (parser.format_version, self.format_version.rstrip(b'\0').decode("utf-8")),
            (parser.definition_type, as_int(self.definition_type)),
            (parser.data_version, as_int(self.data_version)),
            (parser.payload_length_in_bytes, as_int(self.payload_length_in_bytes)),
            (parser.payload, self.payload),
            (parser.payload_with_prefix, self.payload_with_prefix),
            (parser.proof_length, as_int(self.proof_length)),
            (parser.proof, [self.proof]),
            (parser.signed_tree_root, self.signed_tree_root),
        )
        for parsed, expected in checks:
            self.assertEqual(parsed, expected)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestDecodeDefinition(unittest.TestCase):
    """Checks `dfs.decode_definition` on valid and on corrupted encoded definitions."""

    # successful decode network
    def test_network_definition(self):
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)
        self.assertEqual(
            dfs.decode_definition(rinkeby_network.definition, EthereumDefinitionType.NETWORK),
            rinkeby_network.info,
        )

    # successful decode token
    def test_token_definition(self):
        # Karma Token
        kc_token = get_ethereum_token_info_with_definition(chain_id=4)
        self.assertEqual(
            dfs.decode_definition(kc_token.definition, EthereumDefinitionType.TOKEN),
            kc_token.info,
        )

    def test_invalid_data(self):
        """Each corrupted copy of a valid network definition must raise `wire.DataError`."""
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)

        def bumped(offset):
            # Copy of the definition with one byte incremented. The increment wraps
            # 0xFF -> 0x00 because a plain "+= 1" on a bytearray item raises
            # ValueError for 0xFF, which would abort the test before the decoder runs.
            mangled = bytearray(rinkeby_network.definition)
            mangled[offset] = (mangled[offset] + 1) & 0xFF
            return mangled

        # change "trzd1" to "trzd2"
        wrong_format_version = bytearray(rinkeby_network.definition)
        wrong_format_version[:5] = b'trzd2'

        invalid_dataset = [
            bumped(-1),  # mangle signature
            bumped(16),  # mangle payload
            wrong_format_version,  # wrong format version
            bumped(8),  # wrong definition type
            # wrong data format version
            # NOTE(review): with the 8+1+4+2 byte prefix built in
            # TestEthereumDefinitionParser.setUp, offset 13 would fall into the
            # payload-length field (data version = offsets 9..12) - confirm the offset.
            bumped(13),
        ]

        for data in invalid_dataset:
            with self.assertRaises(wire.DataError):
                dfs.decode_definition(bytes(data), EthereumDefinitionType.NETWORK)

    def test_wrong_requested_type(self):
        """Decoding a network definition while requesting a token must fail."""
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)
        with self.assertRaises(wire.DataError):
            dfs.decode_definition(rinkeby_network.definition, EthereumDefinitionType.TOKEN)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestGetNetworkDefiniton(unittest.TestCase):
    """Checks `dfs._get_network_definiton` against built-in and encoded networks."""

    def setUp(self):
        # use mockup function for built-in networks; keep the replaced function
        # so that tearDown can put it back
        self._saved_networks_iterator = networks._networks_iterator
        networks._networks_iterator = builtin_networks_iterator

    def tearDown(self):
        # undo the module-level patch so it does not leak into other tests
        networks._networks_iterator = self._saved_networks_iterator

    def test_get_network_definition(self):
        eth_network = get_ethereum_network_info_with_definition(chain_id=1)
        self.assertEqual(dfs._get_network_definiton(None, 1), eth_network.info)

    def test_built_in_preference(self):
        # an encoded definition of another network is supplied, yet the network
        # matching the reference chain_id is the expected result
        eth_network = get_ethereum_network_info_with_definition(chain_id=1)
        eth_classic_network = get_ethereum_network_info_with_definition(chain_id=61)
        self.assertEqual(
            dfs._get_network_definiton(eth_classic_network.definition, 1),
            eth_network.info,
        )

    def test_no_built_in(self):
        ubiq_network = get_ethereum_network_info_with_definition(chain_id=8)

        # use provided (encoded) definition
        self.assertEqual(dfs._get_network_definiton(ubiq_network.definition, 8), ubiq_network.info)
        # here the result should be the same as above
        self.assertEqual(dfs._get_network_definiton(ubiq_network.definition, None), ubiq_network.info)
        # nothing should be found
        self.assertIsNone(dfs._get_network_definiton(None, 8))
        self.assertIsNone(dfs._get_network_definiton(None, None))

        # reference chain_id is used to check the encoded network chain_id - so in case they do not equal
        # error is raised
        with self.assertRaises(wire.DataError):
            dfs._get_network_definiton(ubiq_network.definition, ubiq_network.info.chain_id + 9999)

    def test_invalid_encoded_definition(self):
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)
        definition = bytearray(rinkeby_network.definition)
        # mangle signature - this should have the same effect as it has in "decode_definition" function
        # (wrapping increment: a plain "+= 1" raises ValueError when the byte is 0xFF)
        definition[-1] = (definition[-1] + 1) & 0xFF
        with self.assertRaises(wire.DataError):
            dfs._get_network_definiton(bytes(definition), None)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestGetTokenDefiniton(unittest.TestCase):
    """Checks `dfs._get_token_definiton` against built-in and encoded tokens."""

    def setUp(self):
        # use mockup function for built-in tokens; keep the replaced function
        # so that tearDown can put it back
        self._saved_token_by_chain_address = tokens.token_by_chain_address
        tokens.token_by_chain_address = builtin_token_by_chain_address

    def tearDown(self):
        # undo the module-level patch so it does not leak into other tests
        tokens.token_by_chain_address = self._saved_token_by_chain_address

    def test_get_token_definition(self):
        aave_token = get_ethereum_token_info_with_definition(chain_id=1, token_address="7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
        self.assertEqual(
            dfs._get_token_definiton(None, aave_token.info.chain_id, aave_token.info.address),
            aave_token.info,
        )

    def test_built_in_preference(self):
        # an encoded definition of another token is supplied, yet the token matching
        # the reference chain_id/address is the expected result
        aave_token = get_ethereum_token_info_with_definition(chain_id=1, token_address="7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
        taud_token = get_ethereum_token_info_with_definition(chain_id=1, token_address="00006100f7090010005f1bd7ae6122c3c2cf0090")
        self.assertEqual(
            dfs._get_token_definiton(taud_token.definition, aave_token.info.chain_id, aave_token.info.address),
            aave_token.info,
        )

    def test_no_built_in(self):
        kc_token = get_ethereum_token_info_with_definition(chain_id=4)
        chain_id = kc_token.info.chain_id
        address = kc_token.info.address

        # use provided (encoded) definition
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, chain_id, address), kc_token.info)
        # here the results should be the same as above
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, None, address), kc_token.info)
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, chain_id, None), kc_token.info)
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, None, None), kc_token.info)

        # nothing should be found
        self.assertEqual(dfs._get_token_definiton(None, chain_id, address), tokens.UNKNOWN_TOKEN)
        self.assertEqual(dfs._get_token_definiton(None, None, address), tokens.UNKNOWN_TOKEN)
        self.assertEqual(dfs._get_token_definiton(None, chain_id, None), tokens.UNKNOWN_TOKEN)

        # reference chain_id and/or token address is used to check the encoded token chain_id/address - so in case they do not equal
        # tokens.UNKNOWN_TOKEN is returned
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, chain_id + 1, address + b"\x00"), tokens.UNKNOWN_TOKEN)
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, chain_id, address + b"\x00"), tokens.UNKNOWN_TOKEN)
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, chain_id + 1, address), tokens.UNKNOWN_TOKEN)
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, None, address + b"\x00"), tokens.UNKNOWN_TOKEN)
        self.assertEqual(dfs._get_token_definiton(kc_token.definition, chain_id + 1, None), tokens.UNKNOWN_TOKEN)

    def test_invalid_encoded_definition(self):
        kc_token = get_ethereum_token_info_with_definition(chain_id=4)
        definition = bytearray(kc_token.definition)
        # mangle signature - this should have the same effect as it has in "decode_definition" function
        # (wrapping increment: a plain "+= 1" raises ValueError when the byte is 0xFF)
        definition[-1] = (definition[-1] + 1) & 0xFF
        with self.assertRaises(wire.DataError):
            dfs._get_token_definiton(bytes(definition), None, None)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumDefinitions(unittest.TestCase):
    """Checks which network/token end up in a `dfs.EthereumDefinitions` instance."""

    def setUp(self):
        # use mockup functions for built-in definitions; keep the replaced
        # functions so that tearDown can put them back
        self._saved_networks_iterator = networks._networks_iterator
        self._saved_token_by_chain_address = tokens.token_by_chain_address
        networks._networks_iterator = builtin_networks_iterator
        tokens.token_by_chain_address = builtin_token_by_chain_address

    def tearDown(self):
        # undo the module-level patches so they do not leak into other tests
        networks._networks_iterator = self._saved_networks_iterator
        tokens.token_by_chain_address = self._saved_token_by_chain_address

    def get_and_compare_ethereum_definitions(
        self,
        network_definition: bytes | None,
        token_definition: bytes | None,
        ref_chain_id: int | None,
        ref_token_address: bytes | None,
        network_info: EthereumNetworkInfo | None,
        token_info: EthereumTokenInfo | None,
    ):
        """Build `dfs.EthereumDefinitions` from the first four arguments and compare.

        `network_info` is the expected `network` attribute. `token_info`, when not
        None, is the only expected entry of `token_dict` (keyed by its address);
        when None, `token_dict` is expected to be empty.
        """
        # get
        definitions = dfs.EthereumDefinitions(network_definition, token_definition, ref_chain_id, ref_token_address)

        ref_token_dict = dict()
        if token_info is not None:
            ref_token_dict[token_info.address] = token_info

        # compare
        self.assertEqual(definitions.network, network_info)
        self.assertDictEqual(definitions.token_dict, ref_token_dict)

    def test_get_definitions(self):
        # built-in
        eth_network = get_ethereum_network_info_with_definition(chain_id=1)
        aave_token = get_ethereum_token_info_with_definition(chain_id=1, token_address="7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
        # not built-in
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)
        kc_token = get_ethereum_token_info_with_definition(chain_id=4)

        # these variations should have the same result - successfully load built-in or encoded network/token
        calls_params = [
            (None, None, eth_network.info.chain_id, aave_token.info.address, eth_network.info, aave_token.info),
            (rinkeby_network.definition, None, eth_network.info.chain_id, aave_token.info.address, eth_network.info, aave_token.info),
            (None, kc_token.definition, eth_network.info.chain_id, aave_token.info.address, eth_network.info, aave_token.info),
            (rinkeby_network.definition, kc_token.definition, eth_network.info.chain_id, aave_token.info.address, eth_network.info, aave_token.info),
            (rinkeby_network.definition, kc_token.definition, None, kc_token.info.address, rinkeby_network.info, kc_token.info),
            (rinkeby_network.definition, kc_token.definition, rinkeby_network.info.chain_id, None, rinkeby_network.info, kc_token.info),
            (rinkeby_network.definition, kc_token.definition, None, None, rinkeby_network.info, kc_token.info),
        ]
        for params in calls_params:
            self.get_and_compare_ethereum_definitions(*params)

    def test_no_network_or_token(self):
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)
        kc_token = get_ethereum_token_info_with_definition(chain_id=4)

        calls_params = [
            # without network there should be no token loaded
            (None, kc_token.definition, None, kc_token.info.address, None, None),
            (None, kc_token.definition, 0, kc_token.info.address, None, None),  # non-existing chain_id
            # also without token there should be no token loaded
            (rinkeby_network.definition, None, rinkeby_network.info.chain_id, None, rinkeby_network.info, None),
            (rinkeby_network.definition, None, rinkeby_network.info.chain_id, kc_token.info.address + b"\x00", rinkeby_network.info, None),  # non-existing token address
        ]
        for params in calls_params:
            self.get_and_compare_ethereum_definitions(*params)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestGetDefinitonsFromMsg(unittest.TestCase):
    """Checks what `dfs.get_definitions_from_msg` extracts from various messages."""

    def setUp(self):
        # use mockup functions for built-in definitions; keep the replaced
        # functions so that tearDown can put them back
        self._saved_networks_iterator = networks._networks_iterator
        self._saved_token_by_chain_address = tokens.token_by_chain_address
        networks._networks_iterator = builtin_networks_iterator
        tokens.token_by_chain_address = builtin_token_by_chain_address

    def tearDown(self):
        # undo the module-level patches so they do not leak into other tests
        networks._networks_iterator = self._saved_networks_iterator
        tokens.token_by_chain_address = self._saved_token_by_chain_address

    def get_and_compare_ethereum_definitions(
        self,
        msg: protobuf.MessageType,
        network_info: EthereumNetworkInfo | None,
        token_info: EthereumTokenInfo | None,
    ):
        """Run `dfs.get_definitions_from_msg(msg)` and compare with the expectation.

        `network_info` is the expected `network` attribute. `token_info`, when not
        None, is the only expected entry of `token_dict` (keyed by its address);
        when None, `token_dict` is expected to be empty.
        """
        # get
        definitions = dfs.get_definitions_from_msg(msg)

        ref_token_dict = dict()
        if token_info is not None:
            ref_token_dict[token_info.address] = token_info

        # compare
        self.assertEqual(definitions.network, network_info)
        self.assertDictEqual(definitions.token_dict, ref_token_dict)

    def test_get_definitions_SignTx_messages(self):
        # built-in
        eth_network = get_ethereum_network_info_with_definition(chain_id=1)
        aave_token = get_ethereum_token_info_with_definition(chain_id=1, token_address="7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
        # not built-in
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)
        kc_token = get_ethereum_token_info_with_definition(chain_id=4)

        # the factories below fill the remaining constructor arguments with empty
        # values so that each case only spells out the fields under test
        def create_EthereumSignTx_msg(**kwargs):
            return EthereumSignTx(
                gas_price=b'',
                gas_limit=b'',
                **kwargs
            )

        def create_EthereumSignTxEIP1559_msg(**kwargs):
            return EthereumSignTxEIP1559(
                nonce=b'',
                max_gas_fee=b'',
                max_priority_fee=b'',
                gas_limit=b'',
                value=b'',
                data_length=0,
                **kwargs
            )

        # both network and token should be loaded
        params_set = [
            (
                create_EthereumSignTx_msg(
                    chain_id=rinkeby_network.info.chain_id,
                    to=hexlify(kc_token.info.address),
                    definitions=EthereumEncodedDefinitions(
                        encoded_network=rinkeby_network.definition,
                        encoded_token=kc_token.definition,
                    ),
                ),
                rinkeby_network.info,
                kc_token.info,
            ),
            (
                create_EthereumSignTx_msg(
                    chain_id=eth_network.info.chain_id,
                    to=hexlify(aave_token.info.address),
                ),
                eth_network.info,
                aave_token.info,
            ),
            (
                create_EthereumSignTxEIP1559_msg(
                    chain_id=rinkeby_network.info.chain_id,
                    to=hexlify(kc_token.info.address),
                    definitions=EthereumEncodedDefinitions(
                        encoded_network=rinkeby_network.definition,
                        encoded_token=kc_token.definition,
                    ),
                ),
                rinkeby_network.info,
                kc_token.info,
            ),
            (
                create_EthereumSignTxEIP1559_msg(
                    chain_id=eth_network.info.chain_id,
                    to=hexlify(aave_token.info.address),
                ),
                eth_network.info,
                aave_token.info,
            ),
        ]
        for params in params_set:
            self.get_and_compare_ethereum_definitions(*params)

        # missing "to" parameter in messages should lead to no token is loaded if none was provided
        params_set = [
            (
                create_EthereumSignTx_msg(
                    chain_id=rinkeby_network.info.chain_id,
                    definitions=EthereumEncodedDefinitions(
                        encoded_network=rinkeby_network.definition,
                        encoded_token=None,
                    ),
                ),
                rinkeby_network.info,
                None,
            ),
            (
                create_EthereumSignTx_msg(
                    chain_id=eth_network.info.chain_id,
                ),
                eth_network.info,
                None,
            ),
            (
                create_EthereumSignTxEIP1559_msg(
                    chain_id=rinkeby_network.info.chain_id,
                    definitions=EthereumEncodedDefinitions(
                        encoded_network=rinkeby_network.definition,
                        encoded_token=None
                    ),
                ),
                rinkeby_network.info,
                None,
            ),
            (
                create_EthereumSignTxEIP1559_msg(
                    chain_id=eth_network.info.chain_id,
                ),
                eth_network.info,
                None,
            ),
        ]
        for params in params_set:
            self.get_and_compare_ethereum_definitions(*params)

    def test_other_messages(self):
        rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4)

        # only network should be loaded
        messages = [
            EthereumGetAddress(encoded_network=rinkeby_network.definition),
            EthereumGetPublicKey(encoded_network=rinkeby_network.definition),
            EthereumSignMessage(message=b'', encoded_network=rinkeby_network.definition),
            EthereumSignTypedData(primary_type="", encoded_network=rinkeby_network.definition),
            EthereumVerifyMessage(signature=b'', message=b'', address="", encoded_network=rinkeby_network.definition),
        ]
        for msg in messages:
            self.get_and_compare_ethereum_definitions(msg, rinkeby_network.info, None)

        # neither network nor token should be loaded
        messages = [
            EthereumGetAddress(),
            EthereumGetPublicKey(),
            EthereumSignMessage(message=b''),
            EthereumSignTypedData(primary_type=""),
            EthereumVerifyMessage(signature=b'', message=b'', address=""),
        ]
        for msg in messages:
            self.get_and_compare_ethereum_definitions(msg, None, None)

    def test_invalid_message(self):
        # msg without any of the required fields - chain_id, to, definitions, encoded_network
        class InvalidMsg():
            pass

        # expected outcome: no network and an empty token dict
        self.get_and_compare_ethereum_definitions(InvalidMsg(), None, None)
# Run the tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()