# Mirror of https://github.com/trezor/trezor-firmware.git
# (synced 2024-12-23 14:58:09 +00:00; 256 lines, 9.8 KiB, Python)
from common import *
|
|
import unittest
|
|
import typing as t
|
|
from trezor import utils, wire
|
|
from ubinascii import hexlify # noqa: F401
|
|
|
|
if not utils.BITCOIN_ONLY:
|
|
|
|
from apps.ethereum import networks, tokens
|
|
from apps.ethereum.definitions import decode_definition, Definitions
|
|
from ethereum_common import *
|
|
from trezor import protobuf
|
|
from trezor.enums import EthereumDefinitionType
|
|
from trezor.messages import (
|
|
EthereumDefinitions,
|
|
EthereumNetworkInfo,
|
|
EthereumTokenInfo,
|
|
EthereumSignTx,
|
|
EthereumSignTxEIP1559,
|
|
EthereumSignTypedData,
|
|
)
|
|
|
|
# 20-byte token address used by the tests below as an example of a token that
# is expected to be resolvable without any external definition when
# chain_id=1 or slip44=60 is requested (see TestEthereumDefinitions.test_builtin).
TETHER_ADDRESS = b"\xda\xc1\x7f\x95\x8d\x2e\xe5\x23\xa2\x20\x62\x06\x99\x45\x97\xc1\x3d\x83\x1e\xc7"
|
|
|
|
|
|
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestDecodeDefinition(unittest.TestCase):
    """Tests for ``decode_definition``.

    Two tests check that a correctly encoded network / token definition
    decodes back to the original message. Every other test builds a
    definition blob that is broken in exactly one way and expects
    ``wire.DataError``.

    ``make_network``, ``make_token``, ``make_payload``, ``sign_payload``,
    ``encode_network``, ``encode_token`` and ``sha256`` are not defined in
    this file; they are presumably provided by the star imports from
    ``common`` / ``ethereum_common`` -- confirm there for exact semantics.
    """

    def assertFailed(self, data: bytes) -> None:
        """Assert that decoding ``data`` as a network definition is rejected.

        Raises (via ``unittest``) an assertion failure unless
        ``decode_definition(data, EthereumNetworkInfo)`` raises
        ``wire.DataError``.
        """
        with self.assertRaises(wire.DataError):
            decode_definition(data, EthereumNetworkInfo)

    def test_short_message(self) -> None:
        # A single-byte blob must be rejected for both definition types.
        with self.assertRaises(wire.DataError):
            decode_definition(b"\x00", EthereumNetworkInfo)
        with self.assertRaises(wire.DataError):
            decode_definition(b"\x00", EthereumTokenInfo)

    # successful decode network
    def test_network_definition(self) -> None:
        network = make_network(chain_id=42, slip44=69, symbol="FAKE", name="Fakenet")
        encoded = encode_network(network)
        # Deliberately no try/except around the assertion: catching the
        # exception and printing it would let a decoding failure or a
        # mismatch pass silently.
        self.assertEqual(decode_definition(encoded, EthereumNetworkInfo), network)

    # successful decode token
    def test_token_definition(self) -> None:
        token = make_token("FAKE", decimals=33, address=b"abcd" * 5, chain_id=42)
        encoded = encode_token(token)
        self.assertEqual(decode_definition(encoded, EthereumTokenInfo), token)

    def test_mangled_signature(self) -> None:
        payload = make_payload()
        proof, signature = sign_payload(payload, [])
        # Overwrite the last byte of the signature.
        bad_signature = signature[:-1] + b"\xff"
        self.assertFailed(payload + proof + bad_signature)

    def test_not_enough_signatures(self) -> None:
        payload = make_payload()
        # Signed with threshold=1; the decoder is expected to reject this
        # (presumably it requires more signers -- confirm in sign_payload).
        proof, signature = sign_payload(payload, [], threshold=1)
        self.assertFailed(payload + proof + signature)

    def test_missing_signature(self) -> None:
        payload = make_payload()
        # The signature is intentionally left out of the blob.
        proof, _ = sign_payload(payload, [])
        self.assertFailed(payload + proof)

    def test_mangled_payload(self) -> None:
        payload = make_payload()
        proof, signature = sign_payload(payload, [])
        # Overwrite the last byte of the payload after it was signed.
        bad_payload = payload[:-1] + b"\xff"
        self.assertFailed(bad_payload + proof + signature)

    def test_proof_length_mismatch(self) -> None:
        payload = make_payload()
        # The real proof is replaced by a lone b"\x01" byte.
        _, signature = sign_payload(payload, [])
        bad_proof = b"\x01"
        self.assertFailed(payload + bad_proof + signature)

    def test_bad_proof(self) -> None:
        payload = make_payload()
        proof, signature = sign_payload(payload, [sha256(b"x").digest()])
        # Overwrite the last byte of a proof that has one entry.
        bad_proof = proof[:-1] + b"\xff"
        self.assertFailed(payload + bad_proof + signature)

    def test_trimmed_proof(self) -> None:
        payload = make_payload()
        proof, signature = sign_payload(payload, [])
        # Drop the last byte of the proof.
        bad_proof = proof[:-1]
        self.assertFailed(payload + bad_proof + signature)

    def test_bad_prefix(self) -> None:
        payload = make_payload(prefix=b"trzd2")
        proof, signature = sign_payload(payload, [])
        self.assertFailed(payload + proof + signature)

    def test_bad_type(self) -> None:
        # A token definition is handed to assertFailed, which decodes it as
        # a network definition.
        payload = make_payload(
            data_type=EthereumDefinitionType.TOKEN, message=make_token()
        )
        proof, signature = sign_payload(payload, [])
        self.assertFailed(payload + proof + signature)

    def test_outdated(self) -> None:
        payload = make_payload(timestamp=0)
        proof, signature = sign_payload(payload, [])
        self.assertFailed(payload + proof + signature)

    def test_malformed_protobuf(self) -> None:
        payload = make_payload(message=b"\x00")
        proof, signature = sign_payload(payload, [])
        self.assertFailed(payload + proof + signature)

    def test_protobuf_mismatch(self) -> None:
        # Declared type NETWORK, but the embedded message is a token.
        payload = make_payload(
            data_type=EthereumDefinitionType.NETWORK, message=make_token()
        )
        proof, signature = sign_payload(payload, [])
        self.assertFailed(payload + proof + signature)

        # Declared type TOKEN, but the embedded message is a network.
        payload = make_payload(
            data_type=EthereumDefinitionType.TOKEN, message=make_network()
        )
        proof, signature = sign_payload(payload, [])
        self.assertFailed(payload + proof + signature)

    def test_trailing_garbage(self) -> None:
        payload = make_payload()
        proof, signature = sign_payload(payload, [])
        # One extra byte after an otherwise valid blob.
        self.assertFailed(payload + proof + signature + b"\x00")
|
|
|
|
|
|
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestEthereumDefinitions(unittest.TestCase):
    """Tests for ``Definitions.from_encoded`` and ``Definitions.get_token``.

    The tests combine optional encoded network / token definitions with an
    optional ``chain_id`` or ``slip44`` and check which network and token
    the resulting ``Definitions`` object reports.

    ``make_network``, ``make_token``, ``encode_network`` and ``encode_token``
    are not defined in this file; they are presumably provided by the star
    import from ``ethereum_common`` -- confirm there for exact semantics.
    """

    def assertUnknown(self, what: t.Any) -> None:
        """Fail unless ``what`` is one of the UNKNOWN_* sentinel objects."""
        if what is networks.UNKNOWN_NETWORK:
            return
        if what is tokens.UNKNOWN_TOKEN:
            return
        # 1-tuple so that a tuple-valued `what` is formatted, not unpacked.
        self.fail("Expected UNKNOWN_*, got %r" % (what,))

    def assertKnown(self, what: t.Any) -> None:
        """Fail unless ``what`` is a network / token info that is not UNKNOWN_*."""
        is_network = EthereumNetworkInfo.is_type_of(what)
        if not is_network and not EthereumTokenInfo.is_type_of(what):
            # 1-tuple so that a tuple-valued `what` is formatted, not unpacked.
            self.fail("Expected network / token info, got %r" % (what,))
        if what is networks.UNKNOWN_NETWORK:
            self.fail("Expected known network, got UNKNOWN_NETWORK")
        if what is tokens.UNKNOWN_TOKEN:
            self.fail("Expected known token, got UNKNOWN_TOKEN")

    def test_empty(self) -> None:
        # no slip44 nor chain_id -- should short-circuit and always be unknown
        defs = Definitions.from_encoded(None, None)
        self.assertUnknown(defs.network)
        self.assertFalse(defs._tokens)
        self.assertUnknown(defs.get_token(TETHER_ADDRESS))

        # chain_id provided, no definition
        defs = Definitions.from_encoded(None, None, chain_id=100_000)
        self.assertUnknown(defs.network)
        self.assertFalse(defs._tokens)
        self.assertUnknown(defs.get_token(TETHER_ADDRESS))

    def test_builtin(self) -> None:
        # Looked up by chain_id, without any encoded definitions.
        defs = Definitions.from_encoded(None, None, chain_id=1)
        self.assertKnown(defs.network)
        self.assertFalse(defs._tokens)
        self.assertKnown(defs.get_token(TETHER_ADDRESS))
        self.assertUnknown(defs.get_token(b"\x00" * 20))

        # Same expectations when looked up by slip44 instead.
        defs = Definitions.from_encoded(None, None, slip44=60)
        self.assertKnown(defs.network)
        self.assertFalse(defs._tokens)
        self.assertKnown(defs.get_token(TETHER_ADDRESS))
        self.assertUnknown(defs.get_token(b"\x00" * 20))

    def test_external(self) -> None:
        zero_address = b"\x00" * 20

        # External network only: the token stays unknown.
        network = make_network(chain_id=42)
        defs = Definitions.from_encoded(encode_network(network), None, chain_id=42)
        self.assertEqual(defs.network, network)
        self.assertUnknown(defs.get_token(zero_address))

        # External network and an external token on the same chain.
        token = make_token(chain_id=42, address=zero_address)
        defs = Definitions.from_encoded(
            encode_network(network), encode_token(token), chain_id=42
        )
        self.assertEqual(defs.network, network)
        self.assertEqual(defs.get_token(zero_address), token)

        # External token only, for chain_id=1 with no encoded network.
        token = make_token(chain_id=1, address=zero_address)
        defs = Definitions.from_encoded(None, encode_token(token), chain_id=1)
        self.assertKnown(defs.network)
        self.assertEqual(defs.get_token(zero_address), token)

    def test_external_token_mismatch(self) -> None:
        # Token is for chain 43, network for chain 42: token must be unknown.
        network = make_network(chain_id=42)
        token = make_token(chain_id=43, address=b"\x00" * 20)
        defs = Definitions.from_encoded(encode_network(network), encode_token(token))
        self.assertUnknown(defs.get_token(b"\x00" * 20))

    def test_external_chain_match(self) -> None:
        network = make_network(chain_id=42)
        token = make_token(chain_id=42, address=b"\x00" * 20)
        defs = Definitions.from_encoded(
            encode_network(network), encode_token(token), chain_id=42
        )
        self.assertEqual(defs.network, network)
        self.assertEqual(defs.get_token(b"\x00" * 20), token)

        # Requesting a different chain_id than the definitions carry fails.
        with self.assertRaises(wire.DataError):
            Definitions.from_encoded(
                encode_network(network), encode_token(token), chain_id=333
            )

    def test_external_slip44_mismatch(self) -> None:
        network = make_network(chain_id=42, slip44=1999)
        token = make_token(chain_id=42, address=b"\x00" * 20)
        defs = Definitions.from_encoded(
            encode_network(network), encode_token(token), slip44=1999
        )
        self.assertEqual(defs.network, network)
        self.assertEqual(defs.get_token(b"\x00" * 20), token)

        # Requesting a different slip44 than the network definition has fails.
        with self.assertRaises(wire.DataError):
            Definitions.from_encoded(
                encode_network(network), encode_token(token), slip44=333
            )

    def test_ignore_encoded_network(self) -> None:
        # when network is builtin, ignore the encoded one
        encoded = encode_network(chain_id=1, symbol="BAD")
        defs = Definitions.from_encoded(encoded, None, chain_id=1)
        # Comparing defs.network with `encoded` itself could never be equal
        # (decoded message vs. encoded blob), so check the marker symbol.
        # NOTE(review): relies on EthereumNetworkInfo exposing `symbol`.
        self.assertKnown(defs.network)
        self.assertNotEqual(defs.network.symbol, "BAD")

    def test_ignore_encoded_token(self) -> None:
        # when token is builtin, ignore the encoded one
        encoded = encode_token(chain_id=1, address=TETHER_ADDRESS, symbol="BAD")
        defs = Definitions.from_encoded(None, encoded, chain_id=1)
        # Comparing the token with `encoded` itself could never be equal
        # (decoded message vs. encoded blob), so check the marker symbol.
        # NOTE(review): relies on EthereumTokenInfo exposing `symbol`.
        token = defs.get_token(TETHER_ADDRESS)
        self.assertKnown(token)
        self.assertNotEqual(token.symbol, "BAD")

    def test_ignore_with_no_match(self) -> None:
        network = encode_network(chain_id=100_000, symbol="BAD")
        # smoke test: definition is accepted
        defs = Definitions.from_encoded(network, None, chain_id=100_000)
        self.assertKnown(defs.network)

        # same definition but nothing to match it to
        defs = Definitions.from_encoded(network, None)
        self.assertUnknown(defs.network)
|
|
|
|
|
|
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|