You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
trezor-firmware/core/src/apps/ethereum/definitions.py

155 lines
5.3 KiB

from typing import TYPE_CHECKING
from trezor import protobuf, utils
from trezor.crypto.curve import ed25519
from trezor.crypto.hashlib import sha256
from trezor.enums import EthereumDefinitionType
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
from trezor.wire import DataError
from apps.common import readers
from . import definitions_constants as consts, networks, tokens
from .networks import UNKNOWN_NETWORK
if TYPE_CHECKING:
from typing import TypeVar
from typing_extensions import Self
DefType = TypeVar("DefType", EthereumNetworkInfo, EthereumTokenInfo)
def decode_definition(definition: bytes, expected_type: type[DefType]) -> DefType:
    """Authenticate an encoded Ethereum definition and decode its payload.

    The blob is consumed in this order: the format version tag, one byte
    with the definition type, the data version (`readers.read_uint32_le`),
    the payload length (`readers.read_uint16_le`), the payload itself, one
    byte with the number of Merkle proof entries, that many 32-byte proof
    entries, and a 64-byte signature. No bytes may remain afterwards.

    The Merkle tree root is recomputed from the leaf data (everything read
    up to and including the payload) and the proof entries; the signature is
    then verified over that root.

    Args:
        definition: the encoded definition.
        expected_type: message class the payload is decoded into. It also
            selects which definition type byte is accepted (token if its
            `MESSAGE_NAME` matches `EthereumTokenInfo`, network otherwise).

    Returns:
        The payload decoded as `expected_type`.

    Raises:
        DataError: if the blob is truncated, has trailing bytes, carries the
            wrong format version or definition type, is older than
            `consts.MIN_DATA_VERSION`, fails signature verification, or its
            payload cannot be decoded.
    """
    reader = utils.BufferReader(definition)

    # TODO: can't check equality of MsgDefObjs now, so we check the name
    wanted_type = (
        EthereumDefinitionType.TOKEN
        if expected_type.MESSAGE_NAME == EthereumTokenInfo.MESSAGE_NAME
        else EthereumDefinitionType.NETWORK
    )

    try:
        # 1) format version tag
        version_tag = reader.read_memoryview(len(consts.FORMAT_VERSION))
        if version_tag != consts.FORMAT_VERSION:
            raise DataError("Invalid Ethereum definition")

        # 2) type of the contained data
        if reader.get() != wanted_type:
            raise DataError("Definition type mismatch")

        # 3) data version
        data_version = readers.read_uint32_le(reader)
        if data_version < consts.MIN_DATA_VERSION:
            raise DataError("Definition is outdated")

        # 4) length-prefixed payload
        payload_size = readers.read_uint16_le(reader)
        payload = reader.read_memoryview(payload_size)

        # Leaf hash: covers everything consumed so far (prefix + payload),
        # domain-separated with a 0x00 byte.
        leaf_ctx = sha256(b"\x00")
        leaf_ctx.update(memoryview(definition)[: reader.offset])
        root = leaf_ctx.digest()

        # Fold each proof entry into the running hash. The pair is ordered
        # with min/max before hashing, and inner nodes are domain-separated
        # with a 0x01 byte.
        entry_count = reader.get()
        for _ in range(entry_count):
            sibling = reader.read_memoryview(32)
            lower, upper = min(root, sibling), max(root, sibling)
            node_ctx = sha256(b"\x01")
            node_ctx.update(lower)
            node_ctx.update(upper)
            root = node_ctx.digest()

        signature = reader.read_memoryview(64)

        # the signature must be the very last thing in the blob
        if reader.remaining_count():
            raise DataError("Invalid Ethereum definition")

    except EOFError:
        # the reader ran out of data somewhere above
        raise DataError("Invalid Ethereum definition")

    # verify the signature over the recomputed root
    if not ed25519.verify(consts.DEFINITIONS_PUBLIC_KEY, signature, root):
        accepted = False
        if __debug__:
            # when __debug__ is set, the dev key is accepted as a fallback
            accepted = ed25519.verify(
                consts.DEFINITIONS_DEV_PUBLIC_KEY,
                signature,
                root,
            )
        if not accepted:
            raise DataError("Invalid definition signature")

    # the payload is authenticated, decode it
    try:
        decoded = protobuf.decode(payload, expected_type, True)
    except ValueError:
        raise DataError("Invalid Ethereum definition")
    return decoded
class Definitions:
    """Container for the Ethereum network and token definitions of a request.

    Built-in definitions take precedence over externally supplied (encoded)
    ones.
    """

    def __init__(
        self, network: EthereumNetworkInfo, tokens: dict[bytes, EthereumTokenInfo]
    ) -> None:
        """Store the network and the externally supplied tokens, keyed by address."""
        self._tokens = tokens
        self.network = network

    @classmethod
    def from_encoded(
        cls,
        encoded_network: bytes | None,
        encoded_token: bytes | None,
        chain_id: int | None = None,
        slip44: int | None = None,
    ) -> Self:
        """Build definitions from optional encoded blobs and request details.

        The network is looked up among the built-in ones by `chain_id`, or by
        `slip44` when no `chain_id` is given. Only if that lookup yields
        `UNKNOWN_NETWORK` is `encoded_network` decoded and used instead.

        Args:
            encoded_network: encoded network definition, or None.
            encoded_token: encoded token definition, or None.
            chain_id: chain id from the request, or None.
            slip44: SLIP-44 coin type from the request, or None.

        Returns:
            A new instance. It holds `UNKNOWN_NETWORK` and no tokens when
            neither `chain_id` nor `slip44` is given, or when no network
            could be determined.

        Raises:
            DataError: if the resolved network does not match the requested
                `chain_id` or `slip44`, or if decoding an encoded definition
                fails.
        """
        # built-in lookup first
        if chain_id is not None:
            resolved = networks.by_chain_id(chain_id)
        elif slip44 is not None:
            resolved = networks.by_slip44(slip44)
        else:
            # ignore encoded definitions if we can't match them to request details
            return cls(UNKNOWN_NETWORK, {})

        # fall back to the encoded network only when nothing built-in matched
        if encoded_network is not None and resolved is UNKNOWN_NETWORK:
            resolved = decode_definition(encoded_network, EthereumNetworkInfo)

        if resolved is UNKNOWN_NETWORK:
            # ignore tokens if we don't have a network
            return cls(UNKNOWN_NETWORK, {})

        # the resolved network has to agree with whatever the request specified
        if (chain_id is not None and resolved.chain_id != chain_id) or (
            slip44 is not None and resolved.slip44 != slip44
        ):
            raise DataError("Network definition mismatch")

        extra_tokens: dict[bytes, EthereumTokenInfo] = {}
        if encoded_token is not None:
            candidate = decode_definition(encoded_token, EthereumTokenInfo)
            # Ignore token if it doesn't match the network instead of raising an error.
            # This might help us in the future if we allow multiple networks/tokens
            # in the same message.
            if candidate.chain_id == resolved.chain_id:
                extra_tokens[candidate.address] = candidate

        return cls(resolved, extra_tokens)

    def get_token(self, address: bytes) -> EthereumTokenInfo:
        """Return the token definition for `address` on this network.

        A built-in definition wins; otherwise an externally supplied one is
        used; otherwise `tokens.UNKNOWN_TOKEN` is returned.
        """
        builtin = tokens.token_by_chain_address(self.network.chain_id, address)
        if builtin is not None:
            return builtin

        return self._tokens.get(address, tokens.UNKNOWN_TOKEN)