1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-07-23 23:18:16 +00:00
trezor-firmware/core/tests/test_apps.ethereum.definitions.py
2022-12-06 10:10:14 +01:00

345 lines
15 KiB
Python

from common import *
from trezor import wire
from ubinascii import hexlify # noqa: F401
if not utils.BITCOIN_ONLY:
import apps.ethereum.definitions as dfs
from apps.ethereum import networks, tokens
from ethereum_common import *
from trezor import protobuf
from trezor.messages import (
EthereumDefinitions,
EthereumNetworkInfo,
EthereumTokenInfo,
EthereumSignTx,
EthereumSignTxEIP1559,
EthereumSignTypedData,
)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestDecodeDefinition(unittest.TestCase):
def test_short_message(self):
with self.assertRaises(wire.DataError):
dfs.decode_definition(b'\x00', EthereumNetworkInfo)
with self.assertRaises(wire.DataError):
dfs.decode_definition(b'\x00', EthereumTokenInfo)
# successful decode network
def test_network_definition(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
self.assertEqual(dfs.decode_definition(ubiq_network_encoded, EthereumNetworkInfo), ubiq_network)
# successful decode token
def test_token_definition(self):
# Sphere Token
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
sphr_token_encoded = get_encoded_token_definition(sphr_token)
self.assertEqual(dfs.decode_definition(sphr_token_encoded, EthereumTokenInfo), sphr_token)
def test_invalid_data(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
invalid_dataset = []
# mangle Merkle tree proof
invalid_dataset.append(bytearray(ubiq_network_encoded))
invalid_dataset[-1][-65] += 1
# mangle signature
invalid_dataset.append(bytearray(ubiq_network_encoded))
invalid_dataset[-1][-1] += 1
# mangle payload
invalid_dataset.append(bytearray(ubiq_network_encoded))
invalid_dataset[-1][16] += 1
# wrong format version
invalid_dataset.append(bytearray(ubiq_network_encoded))
invalid_dataset[-1][:5] = b'trzd2' # change "trzd1" to "trzd2"
# wrong definition type
invalid_dataset.append(bytearray(ubiq_network_encoded))
invalid_dataset[-1][8] += 1
# wrong data format version
invalid_dataset.append(bytearray(ubiq_network_encoded))
invalid_dataset[-1][13] += 1
for data in invalid_dataset:
with self.assertRaises(wire.DataError):
dfs.decode_definition(bytes(data), EthereumNetworkInfo)
def test_wrong_requested_type(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
with self.assertRaises(wire.DataError):
dfs.decode_definition(ubiq_network_encoded, EthereumTokenInfo)
def test_outdated_definition(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network, 0)
with self.assertRaises(wire.DataError):
dfs.decode_definition(ubiq_network_encoded, EthereumNetworkInfo)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestGetAndCheckDefinition(unittest.TestCase):
def test_get_network_definition(self):
eth_network = get_reference_ethereum_network_info(1)
eth_network_encoded = get_encoded_network_definition(eth_network)
self.assertEqual(dfs.get_and_check_definition(eth_network_encoded, EthereumNetworkInfo, 1), eth_network)
self.assertEqual(dfs.get_and_check_definition(eth_network_encoded, EthereumNetworkInfo), eth_network)
def test_get_token_definition(self):
aave_token = get_reference_ethereum_token_info(1, "7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
aave_token_encoded = get_encoded_token_definition(aave_token)
self.assertEqual(dfs.get_and_check_definition(aave_token_encoded, EthereumTokenInfo, 1), aave_token)
self.assertEqual(dfs.get_and_check_definition(aave_token_encoded, EthereumTokenInfo), aave_token)
def test_invalid_expected_type(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
with self.assertRaises(wire.DataError):
dfs.get_and_check_definition(ubiq_network_encoded, EthereumTokenInfo, 8)
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
sphr_token_encoded = get_encoded_token_definition(sphr_token)
with self.assertRaises(wire.DataError):
dfs.get_and_check_definition(sphr_token_encoded, EthereumNetworkInfo, 8)
def test_fail_check_chain_id(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
with self.assertRaises(wire.DataError):
dfs.get_and_check_definition(ubiq_network_encoded, EthereumNetworkInfo, 1)
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
sphr_token_encoded = get_encoded_token_definition(sphr_token)
with self.assertRaises(wire.DataError):
dfs.get_and_check_definition(sphr_token_encoded, EthereumTokenInfo, 1)
def test_invalid_encoded_definition(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
definition = bytearray(ubiq_network_encoded)
# mangle signature - this should have the same effect as it has in "decode_definition" function
definition[-1] += 1
with self.assertRaises(wire.DataError):
dfs.get_and_check_definition(bytes(definition), EthereumNetworkInfo, 8)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumDefinitions(unittest.TestCase):
def get_and_compare_ethereum_definitions(
self,
network_definition: bytes | None,
token_definition: bytes | None,
ref_chain_id: int | None,
ref_token_address: bytes,
network_info: EthereumNetworkInfo,
token_info: EthereumTokenInfo,
):
# get
definitions = dfs.Definitions(network_definition, token_definition, ref_chain_id)
# compare
self.assertEqual(definitions.network, network_info)
self.assertEqual(definitions.get_token(ref_token_address), token_info)
def test_get_definitions(self):
# built-in
eth_network = get_reference_ethereum_network_info(1)
aave_token = get_reference_ethereum_token_info(1, "7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
# not built-in
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
sphr_token_encoded = get_encoded_token_definition(sphr_token)
calls_params = [
# no network
(None, None, None, aave_token.address, networks.UNKNOWN_NETWORK, tokens.UNKNOWN_TOKEN),
# no encoded definitions
(None, None, eth_network.chain_id, aave_token.address, eth_network, aave_token),
# no encoded definitions - token address from other chain
(None, None, eth_network.chain_id, sphr_token.address, eth_network, tokens.UNKNOWN_TOKEN),
# with encoded network definition
(ubiq_network_encoded, None, None, aave_token.address, ubiq_network, tokens.UNKNOWN_TOKEN),
(ubiq_network_encoded, None, None, sphr_token.address, ubiq_network, tokens.UNKNOWN_TOKEN),
(ubiq_network_encoded, None, eth_network.chain_id, aave_token.address, eth_network, aave_token),
(ubiq_network_encoded, None, ubiq_network.chain_id, sphr_token.address, ubiq_network, tokens.UNKNOWN_TOKEN),
# with encoded network definition - token address from other chain
(ubiq_network_encoded, None, eth_network.chain_id, sphr_token.address, eth_network, tokens.UNKNOWN_TOKEN),
# with encoded network and token definition
(ubiq_network_encoded, sphr_token_encoded, None, sphr_token.address, ubiq_network, sphr_token),
(ubiq_network_encoded, sphr_token_encoded, ubiq_network.chain_id, sphr_token.address, ubiq_network, sphr_token),
# with encoded network and token definition - token address from other chain
(ubiq_network_encoded, sphr_token_encoded, None, aave_token.address, ubiq_network, tokens.UNKNOWN_TOKEN),
(ubiq_network_encoded, sphr_token_encoded, ubiq_network.chain_id, aave_token.address, ubiq_network, tokens.UNKNOWN_TOKEN),
]
for params in calls_params:
self.get_and_compare_ethereum_definitions(*params)
def test_get_definitions_chain_id_mismatch(self):
# built-in
eth_network = get_reference_ethereum_network_info(1)
eth_network_encoded = get_encoded_network_definition(eth_network)
aave_token = get_reference_ethereum_token_info(1, "7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
aave_token_encoded = get_encoded_token_definition(aave_token)
# not built-in
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
sphr_token_encoded = get_encoded_token_definition(sphr_token)
# these variations should have the same result - error on chain id check in encoded definition
calls_params = [
(None, sphr_token_encoded, None),
(None, aave_token_encoded, None),
(None, sphr_token_encoded, eth_network.chain_id),
(None, aave_token_encoded, ubiq_network.chain_id),
(eth_network_encoded, None, ubiq_network.chain_id),
(ubiq_network_encoded, sphr_token_encoded, eth_network.chain_id),
(eth_network_encoded, aave_token_encoded, ubiq_network.chain_id),
]
for params in calls_params:
with self.assertRaises(wire.DataError):
dfs.Definitions(*params)
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestGetDefinitonsFromMsg(unittest.TestCase):
def get_and_compare_ethereum_definitions(
self,
msg: protobuf.MessageType,
network_info: EthereumNetworkInfo,
token_info: EthereumTokenInfo,
ref_token_address: bytes,
):
# get
definitions = dfs.get_definitions_from_msg(msg)
# compare
self.assertEqual(definitions.network, network_info)
self.assertEqual(definitions.get_token(ref_token_address), token_info)
def test_get_definitions_SignTx_messages(self):
# built-in
eth_network = get_reference_ethereum_network_info(1)
aave_token = get_reference_ethereum_token_info(1, "7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9")
# not built-in
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
sphr_token_encoded = get_encoded_token_definition(sphr_token)
def create_EthereumSignTx_msg(**kwargs):
return EthereumSignTx(
gas_price=b'',
gas_limit=b'',
**kwargs
)
def create_EthereumSignTxEIP1559_msg(**kwargs):
return EthereumSignTxEIP1559(
nonce=b'',
max_gas_fee=b'',
max_priority_fee=b'',
gas_limit=b'',
value=b'',
data_length=0,
**kwargs
)
# both network and token should be loaded
params_set = [
(
create_EthereumSignTx_msg(
chain_id=ubiq_network.chain_id,
to=hexlify(sphr_token.address),
definitions=EthereumDefinitions(
encoded_network=ubiq_network_encoded,
encoded_token=sphr_token_encoded,
),
),
ubiq_network,
sphr_token,
sphr_token.address,
),
(
create_EthereumSignTx_msg(
chain_id=eth_network.chain_id,
to=hexlify(aave_token.address),
),
eth_network,
aave_token,
aave_token.address,
),
(
create_EthereumSignTxEIP1559_msg(
chain_id=ubiq_network.chain_id,
to=hexlify(sphr_token.address),
definitions=EthereumDefinitions(
encoded_network=ubiq_network_encoded,
encoded_token=sphr_token_encoded,
),
),
ubiq_network,
sphr_token,
sphr_token.address,
),
(
create_EthereumSignTxEIP1559_msg(
chain_id=eth_network.chain_id,
to=hexlify(aave_token.address),
),
eth_network,
aave_token,
aave_token.address,
),
]
for params in params_set:
self.get_and_compare_ethereum_definitions(*params)
def test_EthereumSignTypedData_message(self):
ubiq_network = get_reference_ethereum_network_info(8)
ubiq_network_encoded = get_encoded_network_definition(ubiq_network)
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
msg = EthereumSignTypedData(
primary_type="",
definitions=EthereumDefinitions(
encoded_network=ubiq_network_encoded,
encoded_token=None,
)
)
self.get_and_compare_ethereum_definitions(msg, ubiq_network, tokens.UNKNOWN_TOKEN, sphr_token.address)
# neither network nor token should be loaded
msg = EthereumSignTypedData(primary_type="")
self.get_and_compare_ethereum_definitions(msg, networks.UNKNOWN_NETWORK, tokens.UNKNOWN_TOKEN, sphr_token.address)
def test_invalid_message(self):
sphr_token = get_reference_ethereum_token_info(8, "20e3dd746ddf519b23ffbbb6da7a5d33ea6349d6")
# msg without any of the required fields - chain_id, definitions, encoded_network
class InvalidMsg():
pass
self.get_and_compare_ethereum_definitions(InvalidMsg(), networks.UNKNOWN_NETWORK, tokens.UNKNOWN_TOKEN, sphr_token.address)
if __name__ == "__main__":
unittest.main()