You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
5.3 KiB
155 lines
5.3 KiB
from typing import TYPE_CHECKING
|
|
|
|
from trezor import protobuf, utils
|
|
from trezor.crypto.curve import ed25519
|
|
from trezor.crypto.hashlib import sha256
|
|
from trezor.enums import EthereumDefinitionType
|
|
from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo
|
|
from trezor.wire import DataError
|
|
|
|
from apps.common import readers
|
|
|
|
from . import definitions_constants as consts, networks, tokens
|
|
from .networks import UNKNOWN_NETWORK
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import TypeVar
|
|
from typing_extensions import Self
|
|
|
|
DefType = TypeVar("DefType", EthereumNetworkInfo, EthereumTokenInfo)
|
|
|
|
|
|
def decode_definition(definition: bytes, expected_type: type[DefType]) -> DefType:
|
|
# check network definition
|
|
r = utils.BufferReader(definition)
|
|
expected_type_number = EthereumDefinitionType.NETWORK
|
|
# TODO: can't check equality of MsgDefObjs now, so we check the name
|
|
if expected_type.MESSAGE_NAME == EthereumTokenInfo.MESSAGE_NAME:
|
|
expected_type_number = EthereumDefinitionType.TOKEN
|
|
|
|
try:
|
|
# first check format version
|
|
if r.read_memoryview(len(consts.FORMAT_VERSION)) != consts.FORMAT_VERSION:
|
|
raise DataError("Invalid Ethereum definition")
|
|
|
|
# second check the type of the data
|
|
if r.get() != expected_type_number:
|
|
raise DataError("Definition type mismatch")
|
|
|
|
# third check data version
|
|
if readers.read_uint32_le(r) < consts.MIN_DATA_VERSION:
|
|
raise DataError("Definition is outdated")
|
|
|
|
# get payload
|
|
payload_length = readers.read_uint16_le(r)
|
|
payload = r.read_memoryview(payload_length)
|
|
|
|
# at the end compute Merkle tree root hash using
|
|
# provided leaf data (payload with prefix) and proof
|
|
hasher = sha256(b"\x00")
|
|
hasher.update(memoryview(definition)[: r.offset])
|
|
hash = hasher.digest()
|
|
proof_length = r.get()
|
|
for _ in range(proof_length):
|
|
proof_entry = r.read_memoryview(32)
|
|
hash_a = min(hash, proof_entry)
|
|
hash_b = max(hash, proof_entry)
|
|
hasher = sha256(b"\x01")
|
|
hasher.update(hash_a)
|
|
hasher.update(hash_b)
|
|
hash = hasher.digest()
|
|
|
|
signed_tree_root = r.read_memoryview(64)
|
|
|
|
if r.remaining_count():
|
|
raise DataError("Invalid Ethereum definition")
|
|
|
|
except EOFError:
|
|
raise DataError("Invalid Ethereum definition")
|
|
|
|
# verify signature
|
|
if not ed25519.verify(consts.DEFINITIONS_PUBLIC_KEY, signed_tree_root, hash):
|
|
error_msg = DataError("Invalid definition signature")
|
|
if __debug__:
|
|
# check against dev key
|
|
if not ed25519.verify(
|
|
consts.DEFINITIONS_DEV_PUBLIC_KEY,
|
|
signed_tree_root,
|
|
hash,
|
|
):
|
|
raise error_msg
|
|
else:
|
|
raise error_msg
|
|
|
|
# decode it if it's OK
|
|
try:
|
|
return protobuf.decode(payload, expected_type, True)
|
|
except ValueError:
|
|
raise DataError("Invalid Ethereum definition")
|
|
|
|
|
|
class Definitions:
|
|
"""Class that holds Ethereum definitions - network and tokens.
|
|
Prefers built-in definitions over encoded ones.
|
|
"""
|
|
|
|
def __init__(
|
|
self, network: EthereumNetworkInfo, tokens: dict[bytes, EthereumTokenInfo]
|
|
) -> None:
|
|
self.network = network
|
|
self._tokens = tokens
|
|
|
|
@classmethod
|
|
def from_encoded(
|
|
cls,
|
|
encoded_network: bytes | None,
|
|
encoded_token: bytes | None,
|
|
chain_id: int | None = None,
|
|
slip44: int | None = None,
|
|
) -> Self:
|
|
network = UNKNOWN_NETWORK
|
|
tokens: dict[bytes, EthereumTokenInfo] = {}
|
|
|
|
# if we have a built-in definition, use it
|
|
if chain_id is not None:
|
|
network = networks.by_chain_id(chain_id)
|
|
elif slip44 is not None:
|
|
network = networks.by_slip44(slip44)
|
|
else:
|
|
# ignore encoded definitions if we can't match them to request details
|
|
return cls(UNKNOWN_NETWORK, {})
|
|
|
|
if network is UNKNOWN_NETWORK and encoded_network is not None:
|
|
network = decode_definition(encoded_network, EthereumNetworkInfo)
|
|
|
|
if network is UNKNOWN_NETWORK:
|
|
# ignore tokens if we don't have a network
|
|
return cls(UNKNOWN_NETWORK, {})
|
|
|
|
if chain_id is not None and network.chain_id != chain_id:
|
|
raise DataError("Network definition mismatch")
|
|
if slip44 is not None and network.slip44 != slip44:
|
|
raise DataError("Network definition mismatch")
|
|
|
|
# get token definition
|
|
if encoded_token is not None:
|
|
token = decode_definition(encoded_token, EthereumTokenInfo)
|
|
# Ignore token if it doesn't match the network instead of raising an error.
|
|
# This might help us in the future if we allow multiple networks/tokens
|
|
# in the same message.
|
|
if token.chain_id == network.chain_id:
|
|
tokens[token.address] = token
|
|
|
|
return cls(network, tokens)
|
|
|
|
def get_token(self, address: bytes) -> EthereumTokenInfo:
|
|
# if we have a built-in definition, use it
|
|
token = tokens.token_by_chain_address(self.network.chain_id, address)
|
|
if token is not None:
|
|
return token
|
|
|
|
if address in self._tokens:
|
|
return self._tokens[address]
|
|
|
|
return tokens.UNKNOWN_TOKEN
|