diff --git a/common/tools/ethereum_definitions.py b/common/tools/ethereum_definitions.py index 04a2780481..189d300e97 100755 --- a/common/tools/ethereum_definitions.py +++ b/common/tools/ethereum_definitions.py @@ -20,7 +20,6 @@ from urllib3.util.retry import Retry from coin_info import ( Coin, - CoinBuckets, Coins, _load_builtin_erc20_tokens, _load_builtin_ethereum_networks, @@ -389,7 +388,8 @@ def print_definitions_collision( Returns a tuple composed from the prompt result if prompted otherwise None and the default value.""" if old_definitions: old_defs_hash_no_metadata = [ - hash_dict_on_keys(d, exclude_keys=["metadata"]) for d in old_definitions + hash_dict_on_keys(d, exclude_keys=["metadata", "coingecko_id"]) + for d in old_definitions ] default_index = None @@ -399,7 +399,7 @@ def print_definitions_collision( found = "" if ( old_definitions - and hash_dict_on_keys(definition, exclude_keys=["metadata"]) + and hash_dict_on_keys(definition, exclude_keys=["metadata", "coingecko_id"]) in old_defs_hash_no_metadata ): found = " (found in old definitions)" @@ -566,6 +566,8 @@ def check_definitions_list( modified_definitions.remove((orig_def, new_def)) def any_in_top_100(*definitions) -> bool: + if top100_coingecko_ids is None: + return True if definitions is not None: for d in definitions: if d is not None and d.get("coingecko_id") in top100_coingecko_ids: @@ -625,10 +627,7 @@ def check_definitions_list( accept_change = True print_change = any_in_top_100(old_def, new_def) # if the change contains symbol change "--force" parameter must be used to be able to accept this change - if ( - old_def.get("shortcut") != new_def.get("shortcut") - and not force - ): + if old_def.get("shortcut") != new_def.get("shortcut") and not force: print( "\nERROR: Symbol change in this definition! To be able to approve this change re-run with `--force` argument." ) @@ -677,16 +676,22 @@ def check_definitions_list( _set_definition_metadata(definition) -def _load_prepared_definitions(definitions_file: pathlib.Path) -> tuple[list[dict], list[dict]]: +def _load_prepared_definitions( + definitions_file: pathlib.Path, +) -> tuple[list[dict], list[dict]]: if not definitions_file.is_file(): - click.ClickException(f"File {definitions_file} with prepared definitions does not exists or is not a file.") + click.ClickException( + f"File {definitions_file} with prepared definitions does not exists or is not a file." + ) prepared_definitions_data = load_json(definitions_file) try: networks_data = prepared_definitions_data["networks"] tokens_data = prepared_definitions_data["tokens"] except KeyError: - click.ClickException(f"File with prepared definitions is not complete. Whole \"networks\" and/or \"tokens\" section are missing.") + click.ClickException( + 'File with prepared definitions is not complete. Whole "networks" and/or "tokens" section are missing.' + ) networks: Coins = [] for network_data in networks_data: @@ -872,6 +877,15 @@ def prepare_definitions( cg_tokens = _load_erc20_tokens_from_coingecko(downloader, networks) repo_tokens = _load_erc20_tokens_from_repo(tokens_dir, networks) + # get data used in further processing now to be able to save cache before we do any + # token collision process and others + # get CoinGecko coin list + cg_coin_list = downloader.get_coingecko_coins_list() + # get top 100 coins + cg_top100 = downloader.get_coingecko_top100() + # save cache + downloader.save_cache() + # merge tokens tokens: List[Dict] = [] cg_tokens_chain_id_and_address = [] @@ -901,7 +915,6 @@ def prepare_definitions( # map coingecko ids to tokens tokens_by_chain_id_and_address = {(t["chain_id"], t["address"]): t for t in tokens} - cg_coin_list = downloader.get_coingecko_coins_list() for coin in cg_coin_list: for platform_name, address in coin.get("platforms", dict()).items(): key = (coingecko_id_to_chain_id.get(platform_name), address) @@ -909,10 +922,7 @@ def prepare_definitions( tokens_by_chain_id_and_address[key]["coingecko_id"] = coin["id"] # load top 100 (by market cap) definitions from CoinGecko - cg_top100_ids = [d["id"] for d in downloader.get_coingecko_top100()] - - # save cache - downloader.save_cache() + cg_top100_ids = [d["id"] for d in cg_top100] # check changes in definitions if old_defs is not None: @@ -970,9 +980,11 @@ def prepare_definitions( "--deffile", type=click.Path(resolve_path=True, dir_okay=False, path_type=pathlib.Path), default="./definitions-latest.json", - help="File where the prepared definitions are saved in json format." + help="File where the prepared definitions are saved in json format.", ) -def sign_definitions(outdir: pathlib.Path, privatekey: TextIO, deffile: pathlib.Path) -> None: +def sign_definitions( + outdir: pathlib.Path, privatekey: TextIO, deffile: pathlib.Path +) -> None: """Generate signed Ethereum definitions for python-trezor and others.""" hex_key = None if privatekey is None: @@ -992,7 +1004,7 @@ def sign_definitions(outdir: pathlib.Path, privatekey: TextIO, deffile: pathlib. if complete_file_path.exists(): raise click.ClickException( - f"Definition \"{complete_file_path}\" already generated - attempt to generate another definition." + f'Definition "{complete_file_path}" already generated - attempt to generate another definition.' ) directory.mkdir(parents=True, exist_ok=True) @@ -1039,7 +1051,8 @@ def sign_definitions(outdir: pathlib.Path, privatekey: TextIO, deffile: pathlib. definitions_by_serialization: dict[bytes, dict] = dict() for network in networks: ser = serialize_eth_info( - eth_info_from_dict(network, EthereumNetworkInfo), EthereumDefinitionType.NETWORK + eth_info_from_dict(network, EthereumNetworkInfo), + EthereumDefinitionType.NETWORK, ) network["serialized"] = ser definitions_by_serialization[ser] = network @@ -1052,8 +1065,8 @@ def sign_definitions(outdir: pathlib.Path, privatekey: TextIO, deffile: pathlib. # build Merkle tree mt = MerkleTree( - [network["serialized"] for network in networks] + - [token["serialized"] for token in tokens] + [network["serialized"] for network in networks] + + [token["serialized"] for token in tokens] ) # sign tree root hash diff --git a/common/tools/merkle_tree.py b/common/tools/merkle_tree.py index d1c9b63eb0..682abfad39 100755 --- a/common/tools/merkle_tree.py +++ b/common/tools/merkle_tree.py @@ -6,12 +6,15 @@ except ImportError: from hashlib import sha256 -class Node(): +class Node: """ Single node of Merkle tree. """ - def __init__(self: "Node", left: Union[bytes, "Node"], right: Optional["Node"] = None) -> None: - self.is_leaf = (left is None) != (right is None) # XOR + + def __init__( + self: "Node", left: Union[bytes, "Node"], right: Optional["Node"] = None + ) -> None: + self.is_leaf = (left is None) != (right is None) # XOR if self.is_leaf: self.raw_value = left self.hash = None @@ -43,11 +46,12 @@ class Node(): self.right_child.add_to_proof(proof) -class MerkleTree(): +class MerkleTree: """ Simple Merkle tree that implements the building of Merkle tree itself and generate proofs for leaf nodes. """ + def __init__(self, values: list[bytes]) -> None: self.leaves = [Node(v) for v in values] diff --git a/core/src/apps/ethereum/definitions.py b/core/src/apps/ethereum/definitions.py index ab6fddbe77..4e43317081 100644 --- a/core/src/apps/ethereum/definitions.py +++ b/core/src/apps/ethereum/definitions.py @@ -1,28 +1,23 @@ +from typing import Any from ubinascii import unhexlify -from typing import TYPE_CHECKING - -from apps.ethereum import tokens from trezor import protobuf, wire from trezor.crypto.curve import ed25519 from trezor.enums import EthereumDefinitionType from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo +from apps.ethereum import tokens + from . import helpers, networks -if TYPE_CHECKING: - from trezor.protobuf import MessageType - - from .networks import NetworkInfo - from .tokens import TokenInfo - - DEFINITIONS_PUBLIC_KEY = b"" MIN_DATA_VERSION = 1 FORMAT_VERSION = "trzd1" if __debug__: - DEFINITIONS_DEV_PUBLIC_KEY = unhexlify("db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d") + DEFINITIONS_DEV_PUBLIC_KEY = unhexlify( + "db995fe25169d141cab9bbba92baa01f9f2e1ece7df4cb2ac05190f37fcc1f9d" + ) class EthereumDefinitionParser: @@ -31,15 +26,21 @@ class EthereumDefinitionParser: try: # prefix - self.format_version = definition_bytes[:8].rstrip(b'\0').decode("utf-8") + self.format_version = definition_bytes[:8].rstrip(b"\0").decode("utf-8") self.definition_type: int = definition_bytes[8] - self.data_version = int.from_bytes(definition_bytes[9:13], 'big') - self.payload_length_in_bytes = int.from_bytes(definition_bytes[13:15], 'big') + self.data_version = int.from_bytes(definition_bytes[9:13], "big") + self.payload_length_in_bytes = int.from_bytes( + definition_bytes[13:15], "big" + ) actual_position += 8 + 1 + 4 + 2 # payload - self.payload = definition_bytes[actual_position:(actual_position + self.payload_length_in_bytes)] - self.payload_with_prefix = definition_bytes[:(actual_position + self.payload_length_in_bytes)] + self.payload = definition_bytes[ + actual_position : (actual_position + self.payload_length_in_bytes) + ] + self.payload_with_prefix = definition_bytes[ + : (actual_position + self.payload_length_in_bytes) + ] actual_position += self.payload_length_in_bytes # suffix - Merkle tree proof and signed root hash @@ -47,16 +48,20 @@ class EthereumDefinitionParser: actual_position += 1 self.proof: list[bytes] = [] for _ in range(self.proof_length): - self.proof.append(definition_bytes[actual_position:(actual_position + 32)]) + self.proof.append( + definition_bytes[actual_position : (actual_position + 32)] + ) actual_position += 32 - self.signed_tree_root = definition_bytes[actual_position:(actual_position + 64)] + self.signed_tree_root = definition_bytes[ + actual_position : (actual_position + 64) + ] except IndexError: raise wire.DataError("Invalid Ethereum definition.") def decode_definition( definition: bytes, expected_type: EthereumDefinitionType -) -> NetworkInfo | TokenInfo: +) -> EthereumNetworkInfo | EthereumTokenInfo: # check network definition parsed_definition = EthereumDefinitionParser(definition) @@ -75,6 +80,7 @@ def decode_definition( # at the end verify the signature - compute Merkle tree root hash using provided leaf data and proof def compute_mt_root_hash(data: bytes, proof: list[bytes]) -> bytes: from trezor.crypto.hashlib import sha256 + hash = sha256(b"\x00" + data).digest() for p in proof: hash_a = min(hash, p) @@ -84,13 +90,21 @@ def decode_definition( return hash # verify Merkle proof - root_hash = compute_mt_root_hash(parsed_definition.payload_with_prefix, parsed_definition.proof) + root_hash = compute_mt_root_hash( + parsed_definition.payload_with_prefix, parsed_definition.proof + ) - if not ed25519.verify(DEFINITIONS_PUBLIC_KEY, parsed_definition.signed_tree_root, root_hash): + if not ed25519.verify( + DEFINITIONS_PUBLIC_KEY, parsed_definition.signed_tree_root, root_hash + ): error_msg = wire.DataError("Ethereum definition signature is invalid.") if __debug__: # check against dev key - if not ed25519.verify(DEFINITIONS_DEV_PUBLIC_KEY, parsed_definition.signed_tree_root, root_hash): + if not ed25519.verify( + DEFINITIONS_DEV_PUBLIC_KEY, + parsed_definition.signed_tree_root, + root_hash, + ): raise error_msg else: raise error_msg @@ -98,35 +112,15 @@ def decode_definition( # decode it if it's OK if expected_type == EthereumDefinitionType.NETWORK: info = protobuf.decode(parsed_definition.payload, EthereumNetworkInfo, True) - - # TODO: temporarily convert to internal class - if info is not None: - from .networks import NetworkInfo - info = NetworkInfo( - chain_id=info.chain_id, - slip44=info.slip44, - shortcut=info.shortcut, - name=info.name, - rskip60=info.rskip60 - ) else: info = protobuf.decode(parsed_definition.payload, EthereumTokenInfo, True) - # TODO: temporarily convert to internal class - if info is not None: - from .tokens import TokenInfo - info = TokenInfo( - symbol=info.symbol, - decimals=info.decimals, - address=info.address, - chain_id=info.chain_id, - name=info.name, - ) - return info -def _get_network_definiton(encoded_network_definition: bytes | None, ref_chain_id: int | None = None) -> NetworkInfo | None: +def _get_network_definiton( + encoded_network_definition: bytes | None, ref_chain_id: int | None = None +) -> EthereumNetworkInfo | None: if encoded_network_definition is None and ref_chain_id is None: return None @@ -134,23 +128,31 @@ def _get_network_definiton(encoded_network_definition: bytes | None, ref_chain_i # if we have a built-in definition, use it network = networks.by_chain_id(ref_chain_id) if network is not None: - return network + return network # type: EthereumNetworkInfo if encoded_network_definition is not None: # get definition if it was send - network = decode_definition(encoded_network_definition, EthereumDefinitionType.NETWORK) + network = decode_definition( + encoded_network_definition, EthereumDefinitionType.NETWORK + ) # check referential chain_id with encoded chain_id if ref_chain_id is not None and network.chain_id != ref_chain_id: raise wire.DataError("Invalid network definition - chain IDs not equal.") - return network + return network # type: ignore [Expression of type "EthereumNetworkInfo | EthereumTokenInfo" cannot be assigned to return type "EthereumNetworkInfo | None"] return None -def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: int | None = None, ref_address: bytes | None = None) -> TokenInfo: - if encoded_token_definition is None and (ref_chain_id is None or ref_address is None): +def _get_token_definiton( + encoded_token_definition: bytes | None, + ref_chain_id: int | None = None, + ref_address: bytes | None = None, +) -> EthereumTokenInfo: + if encoded_token_definition is None and ( + ref_chain_id is None or ref_address is None + ): return tokens.UNKNOWN_TOKEN # if we have a built-in definition, use it @@ -161,12 +163,13 @@ def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: i if encoded_token_definition is not None: # get definition if it was send - token = decode_definition(encoded_token_definition, EthereumDefinitionType.TOKEN) + token: EthereumTokenInfo = decode_definition( # type: ignore [Expression of type "EthereumNetworkInfo | EthereumTokenInfo" cannot be assigned to declared type "EthereumTokenInfo"] + encoded_token_definition, EthereumDefinitionType.TOKEN + ) # check token against ref_chain_id and ref_address - if ( - (ref_chain_id is None or token.chain_id == ref_chain_id) - and (ref_address is None or token.address == ref_address) + if (ref_chain_id is None or token.chain_id == ref_chain_id) and ( + ref_address is None or token.address == ref_address ): return token @@ -175,6 +178,7 @@ def _get_token_definiton(encoded_token_definition: bytes | None, ref_chain_id: i class EthereumDefinitions: """Class that holds Ethereum definitions - network and tokens. Prefers built-in definitions over encoded ones.""" + def __init__( self, encoded_network_definition: bytes | None = None, @@ -183,20 +187,22 @@ class EthereumDefinitions: ref_token_address: bytes | None = None, ) -> None: self.network = _get_network_definiton(encoded_network_definition, ref_chain_id) - self.token_dict: dict[bytes, TokenInfo] = dict() + self.token_dict: dict[bytes, EthereumTokenInfo] = dict() # if we have some network, we can try to get token if self.network is not None: - token = _get_token_definiton(encoded_token_definition, self.network.chain_id, ref_token_address) + token = _get_token_definiton( + encoded_token_definition, self.network.chain_id, ref_token_address + ) if token is not tokens.UNKNOWN_TOKEN: self.token_dict[token.address] = token -def get_definitions_from_msg(msg: MessageType) -> EthereumDefinitions: +def get_definitions_from_msg(msg: Any) -> EthereumDefinitions: encoded_network_definition: bytes | None = None encoded_token_definition: bytes | None = None chain_id: int | None = None - token_address: str | None = None + token_address: bytes | None = None # first try to get both definitions try: @@ -225,4 +231,6 @@ def get_definitions_from_msg(msg: MessageType) -> EthereumDefinitions: except AttributeError: pass - return EthereumDefinitions(encoded_network_definition, encoded_token_definition, chain_id, token_address) + return EthereumDefinitions( + encoded_network_definition, encoded_token_definition, chain_id, token_address + ) diff --git a/core/src/apps/ethereum/get_address.py b/core/src/apps/ethereum/get_address.py index f031a77bd3..a92c2302f5 100644 --- a/core/src/apps/ethereum/get_address.py +++ b/core/src/apps/ethereum/get_address.py @@ -20,7 +20,10 @@ if TYPE_CHECKING: @with_keychain_from_path_and_defs(*PATTERNS_ADDRESS) async def get_address( - ctx: Context, msg: EthereumGetAddress, keychain: Keychain, defs: definitions.EthereumDefinitions + ctx: Context, + msg: EthereumGetAddress, + keychain: Keychain, + defs: definitions.EthereumDefinitions, ) -> EthereumAddress: from trezor.messages import EthereumAddress from trezor.ui.layouts import show_address @@ -36,7 +39,7 @@ async def get_address( if len(msg.address_n) > 1: # path has slip44 network identifier slip44 = msg.address_n[1] & 0x7FFF_FFFF - if slip44 == defs.network.slip44: + if defs.network is not None and slip44 == defs.network.slip44: network = defs.network else: network = networks.by_slip44(slip44) diff --git a/core/src/apps/ethereum/helpers.py b/core/src/apps/ethereum/helpers.py index 5f931f4593..b29cb64da7 100644 --- a/core/src/apps/ethereum/helpers.py +++ b/core/src/apps/ethereum/helpers.py @@ -5,10 +5,12 @@ from .networks import by_chain_id if TYPE_CHECKING: from trezor.messages import EthereumFieldType - from .networks import NetworkInfo + from .networks import EthereumNetworkInfo -def address_from_bytes(address_bytes: bytes, network: NetworkInfo = by_chain_id(1)) -> str: +def address_from_bytes( + address_bytes: bytes, network: EthereumNetworkInfo | None = by_chain_id(1) +) -> str: """ Converts address in bytes to a checksummed string as defined in https://github.com/ethereum/EIPs/blob/master/EIPS/eip-55.md diff --git a/core/src/apps/ethereum/keychain.py b/core/src/apps/ethereum/keychain.py index 81a4748c79..d102c57738 100644 --- a/core/src/apps/ethereum/keychain.py +++ b/core/src/apps/ethereum/keychain.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING from apps.common import paths from apps.common.keychain import get_keychain -from . import CURVE, networks, definitions +from . import CURVE, definitions, networks if TYPE_CHECKING: from typing import Awaitable, Callable, Iterable, TypeVar @@ -15,20 +15,31 @@ if TYPE_CHECKING: from trezor.messages import ( EthereumGetAddress, EthereumGetPublicKey, + EthereumNetworkInfo, EthereumSignMessage, EthereumSignTx, EthereumSignTxEIP1559, EthereumSignTypedData, ) - from apps.common.keychain import MsgIn as MsgInGeneric, MsgOut, Handler, HandlerWithKeychain + from apps.common.keychain import ( + MsgIn as MsgInGeneric, + MsgOut, + Handler, + HandlerWithKeychain, + ) # messages for "with_keychain_from_path" decorator MsgInKeychainPath = TypeVar("MsgInKeychainPath", bound=EthereumGetPublicKey) # messages for "with_keychain_from_path_and_defs" decorator - MsgInKeychainPathDefs = TypeVar("MsgInKeychainPathDefs", bound=EthereumGetAddress | EthereumSignMessage | EthereumSignTypedData) + MsgInKeychainPathDefs = TypeVar( + "MsgInKeychainPathDefs", + bound=EthereumGetAddress | EthereumSignMessage | EthereumSignTypedData, + ) # messages for "with_keychain_from_chain_id_and_defs" decorator - MsgInKeychainChainIdDefs = TypeVar("MsgInKeychainChainIdDefs", bound=EthereumSignTx | EthereumSignTxEIP1559) + MsgInKeychainChainIdDefs = TypeVar( + "MsgInKeychainChainIdDefs", bound=EthereumSignTx | EthereumSignTxEIP1559 + ) # TODO: check the types of messages HandlerWithKeychainAndDefinitions = Callable[[Context, MsgInGeneric, Keychain, definitions.EthereumDefinitions], Awaitable[MsgOut]] @@ -48,7 +59,9 @@ PATTERNS_ADDRESS = ( def _schemas_from_address_n( - patterns: Iterable[str], address_n: paths.Bip32Path, network_info: networks.NetworkInfo | None + patterns: Iterable[str], + address_n: paths.Bip32Path, + network_info: EthereumNetworkInfo | None, ) -> Iterable[paths.PathSchema]: if len(address_n) < 2: return () @@ -104,7 +117,9 @@ def with_keychain_from_path_and_defs( return decorator -def _schemas_from_chain_id(network_info: networks.NetworkInfo | None) -> Iterable[paths.PathSchema]: +def _schemas_from_chain_id( + network_info: EthereumNetworkInfo | None, +) -> Iterable[paths.PathSchema]: slip44_id: tuple[int, ...] if network_info is None: # allow Ethereum or testnet paths for unknown networks diff --git a/core/src/apps/ethereum/layout.py b/core/src/apps/ethereum/layout.py index 825a818614..8717e96abd 100644 --- a/core/src/apps/ethereum/layout.py +++ b/core/src/apps/ethereum/layout.py @@ -21,15 +21,15 @@ if TYPE_CHECKING: from trezor.wire import Context from . import tokens - from . import tokens + from trezor.messages import EthereumNetworkInfo, EthereumTokenInfo def require_confirm_tx( ctx: Context, to_bytes: bytes, value: int, - network: networks.NetworkInfo, - token: tokens.TokenInfo, + network: EthereumNetworkInfo | None, + token: EthereumTokenInfo | None, ) -> Awaitable[None]: from .helpers import address_from_bytes from trezor.ui.layouts import confirm_output @@ -53,8 +53,8 @@ async def require_confirm_fee( spending: int, gas_price: int, gas_limit: int, - network: networks.NetworkInfo, - token: tokens.TokenInfo, + network: EthereumNetworkInfo | None, + token: EthereumTokenInfo | None, ) -> None: await confirm_amount( ctx, @@ -77,8 +77,8 @@ async def require_confirm_eip1559_fee( max_priority_fee: int, max_gas_fee: int, gas_limit: int, - network: networks.NetworkInfo, - token: tokens.TokenInfo, + network: EthereumNetworkInfo | None, + token: EthereumTokenInfo | None, ) -> None: await confirm_amount( ctx, @@ -253,7 +253,9 @@ async def confirm_typed_value( def format_ethereum_amount( - value: int, token: tokens.TokenInfo | None, network_info: networks.NetworkInfo | None + value: int, + token: EthereumTokenInfo | None, + network_info: EthereumNetworkInfo | None, ) -> str: from trezor.strings import format_amount diff --git a/core/src/apps/ethereum/networks.py b/core/src/apps/ethereum/networks.py index 69c0bf02fc..128f755958 100644 --- a/core/src/apps/ethereum/networks.py +++ b/core/src/apps/ethereum/networks.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING from apps.common.paths import HARDENED +from trezor.messages import EthereumNetworkInfo if TYPE_CHECKING: from typing import Iterator @@ -25,11 +26,11 @@ if TYPE_CHECKING: UNKNOWN_NETWORK_SHORTCUT = "UNKN" -def by_chain_id(chain_id: int) -> "NetworkInfo" | None: +def by_chain_id(chain_id: int) -> EthereumNetworkInfo | None: for n in _networks_iterator(): n_chain_id = n[0] if n_chain_id == chain_id: - return NetworkInfo( + return EthereumNetworkInfo( chain_id=n[0], slip44=n[1], shortcut=n[2], @@ -39,11 +40,11 @@ def by_chain_id(chain_id: int) -> "NetworkInfo" | None: return None -def by_slip44(slip44: int) -> "NetworkInfo" | None: +def by_slip44(slip44: int) -> EthereumNetworkInfo | None: for n in _networks_iterator(): n_slip44 = n[1] if n_slip44 == slip44: - return NetworkInfo( + return EthereumNetworkInfo( chain_id=n[0], slip44=n[1], shortcut=n[2], @@ -59,17 +60,6 @@ def all_slip44_ids_hardened() -> Iterator[int]: yield n[1] | HARDENED -class NetworkInfo: - def __init__( - self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool - ) -> None: - self.chain_id = chain_id - self.slip44 = slip44 - self.shortcut = shortcut - self.name = name - self.rskip60 = rskip60 - - # fmt: off def _networks_iterator() -> Iterator[NetworkInfoTuple]: yield ( diff --git a/core/src/apps/ethereum/networks.py.mako b/core/src/apps/ethereum/networks.py.mako index 83ba4b667e..440813e327 100644 --- a/core/src/apps/ethereum/networks.py.mako +++ b/core/src/apps/ethereum/networks.py.mako @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING from apps.common.paths import HARDENED +from trezor.messages import EthereumNetworkInfo if TYPE_CHECKING: from typing import Iterator @@ -25,11 +26,11 @@ if TYPE_CHECKING: UNKNOWN_NETWORK_SHORTCUT = "UNKN" -def by_chain_id(chain_id: int) -> "NetworkInfo" | None: +def by_chain_id(chain_id: int) -> EthereumNetworkInfo | None: for n in _networks_iterator(): n_chain_id = n[0] if n_chain_id == chain_id: - return NetworkInfo( + return EthereumNetworkInfo( chain_id=n[0], slip44=n[1], shortcut=n[2], @@ -39,11 +40,11 @@ def by_chain_id(chain_id: int) -> "NetworkInfo" | None: return None -def by_slip44(slip44: int) -> "NetworkInfo" | None: +def by_slip44(slip44: int) -> EthereumNetworkInfo | None: for n in _networks_iterator(): n_slip44 = n[1] if n_slip44 == slip44: - return NetworkInfo( + return EthereumNetworkInfo( chain_id=n[0], slip44=n[1], shortcut=n[2], @@ -59,17 +60,6 @@ def all_slip44_ids_hardened() -> Iterator[int]: yield n[1] | HARDENED -class NetworkInfo: - def __init__( - self, chain_id: int, slip44: int, shortcut: str, name: str, rskip60: bool - ) -> None: - self.chain_id = chain_id - self.slip44 = slip44 - self.shortcut = shortcut - self.name = name - self.rskip60 = rskip60 - - # fmt: off def _networks_iterator() -> Iterator[NetworkInfoTuple]: % for n in supported_on("trezor2", eth): diff --git a/core/src/apps/ethereum/sign_message.py b/core/src/apps/ethereum/sign_message.py index a14d39d6a6..452128a92a 100644 --- a/core/src/apps/ethereum/sign_message.py +++ b/core/src/apps/ethereum/sign_message.py @@ -25,7 +25,10 @@ def message_digest(message: bytes) -> bytes: @with_keychain_from_path_and_defs(*PATTERNS_ADDRESS) async def sign_message( - ctx: Context, msg: EthereumSignMessage, keychain: Keychain, defs: definitions.EthereumDefinitions + ctx: Context, + msg: EthereumSignMessage, + keychain: Keychain, + defs: definitions.EthereumDefinitions, ) -> EthereumMessageSignature: from trezor.crypto.curve import secp256k1 from trezor.messages import EthereumMessageSignature diff --git a/core/src/apps/ethereum/sign_tx.py b/core/src/apps/ethereum/sign_tx.py index dce8cf4421..f15077af3d 100644 --- a/core/src/apps/ethereum/sign_tx.py +++ b/core/src/apps/ethereum/sign_tx.py @@ -9,7 +9,7 @@ from .keychain import with_keychain_from_chain_id_and_defs if TYPE_CHECKING: from apps.common.keychain import Keychain - from trezor.messages import EthereumSignTx, EthereumTxAck + from trezor.messages import EthereumSignTx, EthereumTxAck, EthereumTokenInfo from trezor.wire import Context from .keychain import MsgInKeychainChainIdDefs @@ -45,7 +45,9 @@ async def sign_tx( await paths.validate_path(ctx, keychain, msg.address_n) # Handle ERC20s - token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict) + token, address_bytes, recipient, value = await handle_erc20( + ctx, msg, defs.token_dict + ) data_total = msg.data_length @@ -100,13 +102,14 @@ async def sign_tx( async def handle_erc20( - ctx: Context, msg: MsgInKeychainChainIdDefs, token_dict: dict[bytes, tokens.TokenInfo] -) -> tuple[tokens.TokenInfo | None, bytes, bytes, int]: + ctx: Context, + msg: MsgInKeychainChainIdDefs, # type: ignore [TypeVar "MsgInKeychainChainIdDefs" appears only once in generic function signature] + token_dict: dict[bytes, EthereumTokenInfo], +) -> tuple[EthereumTokenInfo | None, bytes, bytes, int]: from .layout import require_confirm_unknown_token from . import tokens data_initial_chunk = msg.data_initial_chunk # local_cache_attribute - token = None address_bytes = recipient = bytes_from_address(msg.to) value = int.from_bytes(msg.value, "big") @@ -185,7 +188,7 @@ def _sign_digest( return req -def check_common_fields(msg: MsgInKeychainChainIdDefs) -> None: +def check_common_fields(msg: MsgInKeychainChainIdDefs) -> None: # type: ignore [TypeVar "MsgInKeychainChainIdDefs" appears only once in generic function signature] data_length = msg.data_length # local_cache_attribute if data_length > 0: diff --git a/core/src/apps/ethereum/sign_tx_eip1559.py b/core/src/apps/ethereum/sign_tx_eip1559.py index 961501b998..f842bd5f79 100644 --- a/core/src/apps/ethereum/sign_tx_eip1559.py +++ b/core/src/apps/ethereum/sign_tx_eip1559.py @@ -57,7 +57,9 @@ async def sign_tx_eip1559( await paths.validate_path(ctx, keychain, msg.address_n) # Handle ERC20s - token, address_bytes, recipient, value = await handle_erc20(ctx, msg, defs.token_dict) + token, address_bytes, recipient, value = await handle_erc20( + ctx, msg, defs.token_dict + ) data_total = msg.data_length diff --git a/core/src/apps/ethereum/sign_typed_data.py b/core/src/apps/ethereum/sign_typed_data.py index 56478f0d06..cbf60c905c 100644 --- a/core/src/apps/ethereum/sign_typed_data.py +++ b/core/src/apps/ethereum/sign_typed_data.py @@ -29,7 +29,10 @@ _MAX_VALUE_BYTE_SIZE = const(1024) @with_keychain_from_path_and_defs(*PATTERNS_ADDRESS) async def sign_typed_data( - ctx: Context, msg: EthereumSignTypedData, keychain: Keychain, defs: definitions.EthereumDefinitions + ctx: Context, + msg: EthereumSignTypedData, + keychain: Keychain, + defs: definitions.EthereumDefinitions, ) -> EthereumTypedDataSignature: from trezor.crypto.curve import secp256k1 from apps.common import paths diff --git a/core/src/apps/ethereum/tokens.py b/core/src/apps/ethereum/tokens.py index 4de7234431..1ac6c67eff 100644 --- a/core/src/apps/ethereum/tokens.py +++ b/core/src/apps/ethereum/tokens.py @@ -35,6 +35,7 @@ class TokenInfo: UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0, b"", 0) +# TODO: delete completely def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo: for addr, symbol, decimal in _token_iterator(chain_id): if address == addr: diff --git a/core/src/apps/ethereum/tokens.py.mako b/core/src/apps/ethereum/tokens.py.mako index 523ac693f3..3ece60810a 100644 --- a/core/src/apps/ethereum/tokens.py.mako +++ b/core/src/apps/ethereum/tokens.py.mako @@ -14,6 +14,7 @@ # of it has enough collision-resistance.) # (In the if-tree approach the address length did not have any effect whatsoever.) +from trezor.messages import EthereumTokenInfo from typing import Iterator <% @@ -26,39 +27,38 @@ def group_tokens(tokens): return r %>\ -class TokenInfo: - def __init__( - self, - symbol: str, - decimals: int, - address: bytes, - chain_id: int, - name: str = None, - ) -> None: - self.symbol = symbol - self.decimals = decimals - self.address = address - self.chain_id = chain_id - self.name = name +UNKNOWN_TOKEN = EthereumTokenInfo( + symbol="Wei UNKN", + decimals=0, + address=b"", + chain_id=0, + name="Unknown token", +) -UNKNOWN_TOKEN = TokenInfo("Wei UNKN", 0, b"", 0) -def token_by_chain_address(chain_id: int, address: bytes) -> TokenInfo: - for addr, symbol, decimal in _token_iterator(chain_id): +def token_by_chain_address(chain_id: int, address: bytes) -> EthereumTokenInfo: + for addr, symbol, decimal, name in _token_iterator(chain_id): if address == addr: - return TokenInfo(symbol, decimal) + return EthereumTokenInfo( + symbol=symbol, + decimals=decimal, + address=address, + chain_id=chain_id, + name=name, + ) return UNKNOWN_TOKEN def _token_iterator(chain_id: int) -> Iterator[tuple[bytes, str, int]]: % for token_chain_id, tokens in group_tokens(supported_on("trezor2", erc20)).items(): - if chain_id == ${token_chain_id}: + if chain_id == ${token_chain_id}: # ${tokens[0].chain} % for t in tokens: - yield ( # address, symbol, decimals + yield ( # address, symbol, decimals, name ${black_repr(t.address_bytes)}, ${black_repr(t.symbol)}, ${t.decimals}, + ${black_repr(t.name.strip())}, ) % endfor % endfor diff --git a/core/tests/ethereum_common.py b/core/tests/ethereum_common.py index 3c44418f46..0dc00b5e0e 100644 --- a/core/tests/ethereum_common.py +++ b/core/tests/ethereum_common.py @@ -3,14 +3,14 @@ from trezor.utils import ensure from ubinascii import hexlify, unhexlify # noqa: F401 from trezor import messages -from apps.ethereum import networks, tokens +from apps.ethereum import tokens EXPECTED_FORMAT_VERSION = 1 EXPECTED_DATA_VERSION = 1663054984 # unix epoch time class InfoWithDefinition(): - def __init__(self, definition, info): + def __init__(self, definition: bytes | None, info: messages.EthereumNetworkInfo | messages.EthereumTokenInfo): self.definition = definition self.info = info @@ -21,7 +21,7 @@ NETWORKS = { # Ethereum 1: InfoWithDefinition( definition=None, # built-in definitions are not encoded - info=networks.NetworkInfo( + info=messages.EthereumNetworkInfo( chain_id=1, slip44=60, shortcut="ETH", @@ -32,7 +32,7 @@ NETWORKS = { # Rinkeby 4: InfoWithDefinition( definition=unhexlify("74727a643100000000632034880015080410011a047452494e220752696e6b65627928000e8cc47ed4e657d9a9b98e1dd02164320c54a9724e17f91d1d79f6760169582c98ec70ca6f4e94d27e574175c59d2ae04e0cd30b65fb19acd8d2c5fb90bcb7db96f6102e4182c0cef5f412ac3c5fa94f9505b4df2633a0f7bdffa309588d722415624adeb8f329b1572ff9dfc81fbc86e61f1fcb2369f51ba85ea765c908ac254ba996f842a6277583f8d02f149c78bc0eeb8f3d41240403f85785dc3a3925ea768d76aae12342c8a24de223c1ea75e5f07f6b94b8f22189413631eed3c9a362b4501f68b645aa487b9d159a8161404a218507641453ebf045cec56710bb7d873e102777695b56903766e1af16f95576ec4f41874bdaf80cec02ee067d30e721515564d4f30fa74a6c61eb784ea65cc881ead7af2ffac02d5bf1fe1a756918fe37b74828a24b640025cd79443ada60063e3034444fc49ed6055dbba6a09fa4484c42cb85abb49103dc8c781c8f190c4632e2dec30081770448021313955dbb49e8a02fd49b34d030280452fe0a5c3bcba4958bc287c67e12519be4f4aec7ab0c8e574e53a663f635f75508f23d92c77b2147f29feb79c38d0f793fba295aae605c7e8226523edefc6ad1eefe088e5b8376028bf90116ece4fb876510b4ae1c89686dbcaacbbac8225baba429ca376fafac50f4bd1ff4ce1c61dd53318d0718bf513ea6f770cce81e07a653622e4dbd03bdaa570bfe43219eb0d4fab725c9a8da04"), - info=networks.NetworkInfo( + info=messages.EthereumNetworkInfo( chain_id=4, slip44=1, shortcut="tRIN", @@ -43,7 +43,7 @@ NETWORKS = { # Ubiq 8: InfoWithDefinition( definition=unhexlify("74727a6431000000006320348800110808106c1a0355425122045562697128000e5641d82e3622b4e6addd4354efd933cf15947d1d608a60d324d1156b5a4999f70c41beb85bd866aa3059123447dfeef2e1b6c009b66ac8d04ebbca854ad30049edbbb2fbfda3bfedc6fdb4a76f1db8a4f210bd89d3c3ec1761157b0ec2b13e2f624adeb8f329b1572ff9dfc81fbc86e61f1fcb2369f51ba85ea765c908ac254ba996f842a6277583f8d02f149c78bc0eeb8f3d41240403f85785dc3a3925ea768d76aae12342c8a24de223c1ea75e5f07f6b94b8f22189413631eed3c9a362b4501f68b645aa487b9d159a8161404a218507641453ebf045cec56710bb7d873e102777695b56903766e1af16f95576ec4f41874bdaf80cec02ee067d30e721515564d4f30fa74a6c61eb784ea65cc881ead7af2ffac02d5bf1fe1a756918fe37b74828a24b640025cd79443ada60063e3034444fc49ed6055dbba6a09fa4484c42cb85abb49103dc8c781c8f190c4632e2dec30081770448021313955dbb49e8a02fd49b34d030280452fe0a5c3bcba4958bc287c67e12519be4f4aec7ab0c8e574e53a663f635f75508f23d92c77b2147f29feb79c38d0f793fba295aae605c7e8226523edefc6ad1eefe088e5b8376028bf90116ece4fb876510b4ae1c89686dbcaacbbac8225baba429ca376fafac50f4bd1ff4ce1c61dd53318d0718bf513ea6f770cce81e07a653622e4dbd03bdaa570bfe43219eb0d4fab725c9a8da04"), - info=networks.NetworkInfo( + info=messages.EthereumNetworkInfo( chain_id=8, slip44=108, shortcut="UBQ", @@ -54,7 +54,7 @@ NETWORKS = { # Ethereum Classic 61: InfoWithDefinition( definition=unhexlify("74727a64310000000063203488001d083d103d1a034554432210457468657265756d20436c617373696328000e6b891a57fe4c38c54b475f22f0d9242dd8ddab0b4f360bd86e37e2e8b79de5ef29237436351f7bc924cd110716b5adde7c28c03d76ac83b091dbce1b5d7d0edbddb221bd894806f7ea1b195443176e06830a83c0204e33f19c51d2fccc3a9f80ac2cca38822db998ddf76778dada240d39b3c6193c6335d7c693dea90d19a41f86855375c2f48c18cdc012ccac771aa316d776c8721c2b1f6d5980808337dfdae13b5be07e3cbc3526119b88c5eb44be0b1dab1094a5ec5215b47daf91736d16501f68b645aa487b9d159a8161404a218507641453ebf045cec56710bb7d873e102777695b56903766e1af16f95576ec4f41874bdaf80cec02ee067d30e721515564d4f30fa74a6c61eb784ea65cc881ead7af2ffac02d5bf1fe1a756918fe37b74828a24b640025cd79443ada60063e3034444fc49ed6055dbba6a09fa4484c42cb85abb49103dc8c781c8f190c4632e2dec30081770448021313955dbb49e8a02fd49b34d030280452fe0a5c3bcba4958bc287c67e12519be4f4aec7ab0c8e574e53a663f635f75508f23d92c77b2147f29feb79c38d0f793fba295aae605c7e8226523edefc6ad1eefe088e5b8376028bf90116ece4fb876510b4ae1c89686dbcaacbbac8225baba429ca376fafac50f4bd1ff4ce1c61dd53318d0718bf513ea6f770cce81e07a653622e4dbd03bdaa570bfe43219eb0d4fab725c9a8da04"), - info=networks.NetworkInfo( + info=messages.EthereumNetworkInfo( chain_id=61, slip44=61, shortcut="ETC", @@ -80,21 +80,23 @@ TOKENS = { # AAVE "7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9": InfoWithDefinition( definition=None, # built-in definitions are not encoded - info=tokens.TokenInfo( + info=messages.EthereumTokenInfo( symbol="AAVE", decimals=18, address=unhexlify("7fc66500c84a76ad7e9c93437bfc5ac33e2ddae9"), chain_id=1, + name="Aave", ), ), # TrueAUD "00006100f7090010005f1bd7ae6122c3c2cf0090": InfoWithDefinition( definition=unhexlify("74727a6431000000016320348800290a045441554410121a1400006100f7090010005f1bd7ae6122c3c2cf009020012a07547275654155440e310dad13f7d3012903a9a457134c9f38c62c04370cb92c7a528838e30a032dffbceeaa2aa849e590c4e6dbc69b0ea5359f3527b95b56ab59a33dc584105b35ea7c06afc296cc1c1e58cc3d6b461631c4c770b9409837ab3d29bc1b666fb9cf5245c4c218b0e9521c185d102f596905ba860e6f56a0a8b394f943855c74eea6fcac87210a9988ac02803f4cc61cf78e7e2409175a75f4f3a82eb84b1f2d1ea8177d5dccd62949d80d7942105e22a452be01859fe816736e803b120fb9bcc0c1117180dbda19e1ad1aafb9b9f1555c75275820bf7c1e568bcb265bdc4dfdae0511782026e11a151f6894d11128327c8c42958c9ae900af970fec13a11ffdeba6ac10733ca55a906142e0b9130312e8e85606108612581aca9087c452f38f14185db74828a24b640025cd79443ada60063e3034444fc49ed6055dbba6a09fa4484c42cb85abb49103dc8c781c8f190c4632e2dec30081770448021313955dbb49e8a02fd49b34d030280452fe0a5c3bcba4958bc287c67e12519be4f4aec7ab0c8e574e53a663f635f75508f23d92c77b2147f29feb79c38d0f793fba295aae605c7e8226523edefc6ad1eefe088e5b8376028bf90116ece4fb876510b4ae1c89686dbcaacbbac8225baba429ca376fafac50f4bd1ff4ce1c61dd53318d0718bf513ea6f770cce81e07a653622e4dbd03bdaa570bfe43219eb0d4fab725c9a8da04"), - info=tokens.TokenInfo( + info=messages.EthereumTokenInfo( symbol="TAUD", decimals=18, address=unhexlify("00006100f7090010005f1bd7ae6122c3c2cf0090"), chain_id=1, + name="TrueAUD", ), ), }, @@ -102,44 +104,20 @@ TOKENS = { # Karma Token "275a5b346599b56917e7b1c9de019dcf9ead861a": InfoWithDefinition( definition=unhexlify("74727a64310000000163203488002b0a024b4310121a14275a5b346599b56917e7b1c9de019dcf9ead861a20042a0b4b61726d6120546f6b656e0e2b3cb176ff5a2cf431620c1a7eee9aa297f5de36d29ae6d423166cf7391e41c5826c57f30b11421a4bf10f336f12050f6d959e02bfb17a8ce7ae15087d4f083124c0cebed2ce45b15b2608b1a8f0ee443e8c4f33111d880a6a3c09a77c627f82d68b62a1bd39975b2a2c86f196b9a3dcb62bdc3554fbf85b75331bc0d39f23a46f5ed91f208757d1136bb20b3618294fbfb0a826e9c09e392fe8109181bc6c28cad78db1987947f461bfc1042b88a91d6d61297d0cf194dfeea981b4515c2ed09dc2966671f5c715c64ceb25e53e1df3c7234e3e0ddf0dcd54d40fde0c51903685f9dc7fa69c71184f17af852e74490ea7286e89a0aa4770629664f7dd8eab8c4e009ff4c24682f85f7e01d4e10ae5c06212d5a4f43bac2b4f0e79383666ef12054ddbf757809aa6b446d65f7fd1bdd76fb1d7770398bd17af50635027e680801d244bd7b4f14c57edc3cd961722315e076120bf1d35db8520edb812bfbb5bab8ff57cc2dc1b3d1f9d95b33dba5d759aef1123f2ef346b6328973fba204fd745e644c8e492f9a76c0019b2cf21715fba682b46b9c58013e0b0927e5272c808a67e8226523edefc6ad1eefe088e5b8376028bf90116ece4fb876510b4ae1c89686dbcaacbbac8225baba429ca376fafac50f4bd1ff4ce1c61dd53318d0718bf513ea6f770cce81e07a653622e4dbd03bdaa570bfe43219eb0d4fab725c9a8da04"), - info=tokens.TokenInfo( + info=messages.EthereumTokenInfo( symbol="KC", decimals=18, address=unhexlify("275a5b346599b56917e7b1c9de019dcf9ead861a"), chain_id=4, + name="Karma Token", ), ), }, } -def equalNetworkInfo(n1: networks.NetworkInfo, n2: networks.NetworkInfo, msg: str = '') -> bool: - ensure( - cond=( - n1.chain_id == n2.chain_id - and n1.slip44 == n2.slip44 - and n1.shortcut == n2.shortcut - and n1.name == n2.name - and n1.rskip60 == n2.rskip60 - ), - msg=msg, - ) - - -def equalTokenInfo(t1: tokens.TokenInfo, t2: tokens.TokenInfo, msg: str = '') -> bool: - ensure( - cond=( - t1.symbol == t2.symbol - and t1.decimals == t2.decimals - and t1.address == t2.address - and t1.chain_id == t2.chain_id - ), - msg=msg, - ) - - -def construct_network_info(chain_id: int = 0, slip44: int = 0, shortcut: str = "", name: str = "", rskip60: bool = False) -> networks.NetworkInfo: - return networks.NetworkInfo( +def construct_network_info(chain_id: int = 0, slip44: int = 0, shortcut: str = "", name: str = "", rskip60: bool = False) -> messages.EthereumNetworkInfo: + return messages.EthereumNetworkInfo( chain_id=chain_id, slip44=slip44, shortcut=shortcut, @@ -154,8 +132,8 @@ def construct_token_info( address: bytes = b'', chain_id: int = 0, name: str = "", - ) -> tokens.TokenInfo: - return tokens.TokenInfo( + ) -> messages.EthereumTokenInfo: + return messages.EthereumTokenInfo( symbol=symbol, decimals=decimals, address=address, @@ -202,14 +180,14 @@ def get_ethereum_encoded_definition(chain_id: int | None = None, slip44: int | N ) -def builtin_networks_iterator() -> Iterator[networks.NetworkInfo]: +def builtin_networks_iterator() -> Iterator[messages.EthereumNetworkInfo]: """Mockup function replaces original function from core/src/apps/ethereum/networks.py used to get built-in network definitions.""" for _, network in NETWORKS.items(): if network.definition is None: yield network.info -def builtin_token_by_chain_address(chain_id: int, address: bytes) -> tokens.TokenInfo: +def builtin_token_by_chain_address(chain_id: int, address: bytes) -> messages.EthereumTokenInfo: """Mockup function replaces original function from core/src/apps/ethereum/tokens.py used to get built-in token definitions.""" address_str = hexlify(address).decode('hex') try: diff --git a/core/tests/test_apps.ethereum.definitions.py b/core/tests/test_apps.ethereum.definitions.py index b0ea4ab646..e49002715a 100644 --- a/core/tests/test_apps.ethereum.definitions.py +++ b/core/tests/test_apps.ethereum.definitions.py @@ -5,11 +5,14 @@ from ubinascii import hexlify # noqa: F401 if not utils.BITCOIN_ONLY: import apps.ethereum.definitions as dfs + from apps.ethereum import networks from ethereum_common import * from trezor import protobuf from trezor.enums import EthereumDefinitionType from trezor.messages import ( EthereumEncodedDefinitions, + EthereumNetworkInfo, + EthereumTokenInfo, EthereumGetAddress, EthereumGetPublicKey, EthereumSignMessage, @@ -59,10 +62,6 @@ class TestEthereumDefinitionParser(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestDecodeDefinition(unittest.TestCase): - def setUp(self): - self.addTypeEqualityFunc(networks.NetworkInfo, equalNetworkInfo) - self.addTypeEqualityFunc(tokens.TokenInfo, equalTokenInfo) - # successful decode network def test_network_definition(self): rinkeby_network = get_ethereum_network_info_with_definition(chain_id=4) @@ -112,7 +111,6 @@ class TestDecodeDefinition(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestGetNetworkDefiniton(unittest.TestCase): def setUp(self): - self.addTypeEqualityFunc(networks.NetworkInfo, equalNetworkInfo) # use mockup function for built-in networks networks._networks_iterator = builtin_networks_iterator @@ -153,7 +151,6 @@ class TestGetNetworkDefiniton(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestGetTokenDefiniton(unittest.TestCase): def setUp(self): - self.addTypeEqualityFunc(tokens.TokenInfo, equalTokenInfo) # use mockup function for built-in tokens tokens.token_by_chain_address = builtin_token_by_chain_address @@ -200,8 +197,6 @@ class TestGetTokenDefiniton(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestEthereumDefinitions(unittest.TestCase): def setUp(self): - self.addTypeEqualityFunc(networks.NetworkInfo, equalNetworkInfo) - self.addTypeEqualityFunc(tokens.TokenInfo, equalTokenInfo) # use mockup functions for built-in definitions networks._networks_iterator = builtin_networks_iterator tokens.token_by_chain_address = builtin_token_by_chain_address @@ -212,8 +207,8 @@ class TestEthereumDefinitions(unittest.TestCase): token_definition: bytes | None, ref_chain_id: int | None, ref_token_address: bytes | None, - network_info: networks.NetworkInfo | None, - token_info: tokens.TokenInfo | None, + network_info: EthereumNetworkInfo | None, + token_info: EthereumTokenInfo | None, ): # get definitions = dfs.EthereumDefinitions(network_definition, token_definition, ref_chain_id, ref_token_address) @@ -267,8 +262,6 @@ class TestEthereumDefinitions(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestGetDefinitonsFromMsg(unittest.TestCase): def setUp(self): - self.addTypeEqualityFunc(networks.NetworkInfo, equalNetworkInfo) - self.addTypeEqualityFunc(tokens.TokenInfo, equalTokenInfo) # use mockup functions for built-in definitions networks._networks_iterator = builtin_networks_iterator tokens.token_by_chain_address = builtin_token_by_chain_address @@ -276,8 +269,8 @@ class TestGetDefinitonsFromMsg(unittest.TestCase): def get_and_compare_ethereum_definitions( self, msg: protobuf.MessageType, - network_info: networks.NetworkInfo | None, - token_info: tokens.TokenInfo | None, + network_info: EthereumNetworkInfo | None, + token_info: EthereumTokenInfo | None, ): # get definitions = dfs.get_definitions_from_msg(msg) diff --git a/core/tests/test_apps.ethereum.helpers.py b/core/tests/test_apps.ethereum.helpers.py index a8cdf995da..a8bb2d851d 100644 --- a/core/tests/test_apps.ethereum.helpers.py +++ b/core/tests/test_apps.ethereum.helpers.py @@ -3,7 +3,7 @@ from apps.common.paths import HARDENED if not utils.BITCOIN_ONLY: from apps.ethereum.helpers import address_from_bytes - from apps.ethereum.networks import NetworkInfo + from trezor.messages import EthereumNetworkInfo @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @@ -40,7 +40,7 @@ class TestEthereumGetAddress(unittest.TestCase): '0xdbF03B407C01E7cd3cbEa99509D93f8dDDc8C6fB', '0xd1220a0CF47c7B9Be7A2E6Ba89f429762E7b9adB' ] - n = NetworkInfo(chain_id=30, slip44=1, shortcut='T', name='T', rskip60=True) + n = EthereumNetworkInfo(chain_id=30, slip44=1, shortcut='T', name='T', rskip60=True) for s in rskip60_chain_30: b = unhexlify(s[2:]) h = address_from_bytes(b, n) diff --git a/python/src/trezorlib/cli/ethereum.py b/python/src/trezorlib/cli/ethereum.py index c47639eb9c..40d6aabc41 100644 --- a/python/src/trezorlib/cli/ethereum.py +++ b/python/src/trezorlib/cli/ethereum.py @@ -19,14 +19,15 @@ import pathlib import re import sys import tarfile -from io import BytesIO from decimal import Decimal +from io import BytesIO from typing import ( + NoReturn, TYPE_CHECKING, Any, + BinaryIO, Dict, List, - NoReturn, Optional, Sequence, TextIO, @@ -168,30 +169,37 @@ def _format_access_list( def _get_ethereum_definitions( - definitions_dir: pathlib.Path = None, - network_def_file: TextIO = None, - token_def_file: TextIO = None, - download_definitions: bool = False, - chain_id: Optional[int] = None, - slip44_hardened: Optional[int] = None, - token_address: Optional[str] = None, - ) -> ethereum.messages.EthereumEncodedDefinitions: + definitions_dir: pathlib.Path = None, + network_def_file: BinaryIO = None, + token_def_file: BinaryIO = None, + download_definitions: bool = False, + chain_id: Optional[int] = None, + slip44_hardened: Optional[int] = None, + token_address: Optional[str] = None, +) -> ethereum.messages.EthereumEncodedDefinitions: count_of_options_used = sum( - bool(o) for o in ( + bool(o) + for o in ( definitions_dir, (network_def_file or token_def_file), - download_definitions + download_definitions, ) ) if count_of_options_used > 1: - raise click.ClickException("More than one mutually exclusive option for definitions was used. See --help for more info.") + raise click.ClickException( + "More than one mutually exclusive option for definitions was used. See --help for more info." + ) defs = ethereum.messages.EthereumEncodedDefinitions() if definitions_dir is not None: if chain_id is not None or slip44_hardened is not None: - defs.encoded_network = ethereum.network_definition_from_dir(definitions_dir, chain_id, slip44_hardened) + defs.encoded_network = ethereum.network_definition_from_dir( + definitions_dir, chain_id, slip44_hardened + ) if chain_id is not None and token_address is not None: - defs.encoded_token = ethereum.token_definition_from_dir(definitions_dir, chain_id, token_address) + defs.encoded_token = ethereum.token_definition_from_dir( + definitions_dir, chain_id, token_address + ) elif network_def_file is not None or token_def_file is not None: if network_def_file is not None: with network_def_file: @@ -201,9 +209,13 @@ def _get_ethereum_definitions( defs.encoded_token = token_def_file.read() elif download_definitions: if chain_id is not None or slip44_hardened is not None: - defs.encoded_network = ethereum.download_network_definition(chain_id, slip44_hardened) + defs.encoded_network = ethereum.download_network_definition( + chain_id, slip44_hardened + ) if chain_id is not None and token_address is not None: - defs.encoded_token = ethereum.download_token_definition(chain_id, token_address) + defs.encoded_token = ethereum.download_token_definition( + chain_id, token_address + ) return defs @@ -215,26 +227,28 @@ def _get_ethereum_definitions( definitions_dir_option = click.option( "--definitions-dir", - type=click.Path(exists=True, file_okay=False, resolve_path=True, path_type=pathlib.Path), - help="Directory with stored definitions. Directory structure should be the same as it is in downloaded archive from " \ - "`https:\\data.trezor.io\definitions\???`. Mutually exclusive with `--network-def`, `--token-def` and " \ - "`--download-definitions`.", # TODO: add link?, replace this ur with function used to download defs + type=click.Path( + exists=True, file_okay=False, resolve_path=True, path_type=pathlib.Path + ), + help="Directory with stored definitions. Directory structure should be the same as it is in downloaded archive from " + r"`https:\\data.trezor.io\definitions\???`. Mutually exclusive with `--network-def`, `--token-def` and " + "`--download-definitions`.", # TODO: add link?, replace this ur with function used to download defs ) network_def_option = click.option( "--network-def", type=click.File(mode="rb"), - help="Binary file with network definition. Mutually exclusive with `--definitions-dir` and `--download-definitions`." + help="Binary file with network definition. Mutually exclusive with `--definitions-dir` and `--download-definitions`.", ) token_def_options = click.option( "--token-def", type=click.File(mode="rb"), - help="Binary file with token definition. Mutually exclusive with `--definitions-dir` and `--download-definitions`." + help="Binary file with token definition. Mutually exclusive with `--definitions-dir` and `--download-definitions`.", ) download_definitions_option = click.option( "--download-definitions", is_flag=True, - help="Automatically download required definitions from `data.trezor.io\definitions` and use them. " \ - "Mutually exclusive with `--definitions-dir`, `--network-def` and `--token-def`." + help=r"Automatically download required definitions from `data.trezor.io\definitions` and use them. " + "Mutually exclusive with `--definitions-dir`, `--network-def` and `--token-def`.", ) @@ -244,14 +258,21 @@ def cli() -> None: @cli.command() -@click.option("-o", "--outdir", type=click.Path(resolve_path=True, file_okay=False, path_type=pathlib.Path), default="./definitions-latest") +@click.option( + "-o", + "--outdir", + type=click.Path(resolve_path=True, file_okay=False, path_type=pathlib.Path), + default="./definitions-latest", +) @click.option("-u", "--unpack", is_flag=True) -def download_definitions(outdir: pathlib.Path, unpack: bool) -> str: +def download_definitions(outdir: pathlib.Path, unpack: bool) -> None: """Download all Ethereum network and token definitions and save them.""" archive_filename = "definitions.tar.gz" # TODO: change once we know the urls - archived_definitions = ethereum.download_from_url("https://data.trezor.io/eth_definitions/" + archive_filename) + archived_definitions = ethereum.download_from_url( + "https://data.trezor.io/eth_definitions/" + archive_filename + ) # unpack and/or save if unpack: @@ -271,7 +292,14 @@ def download_definitions(outdir: pathlib.Path, unpack: bool) -> str: @network_def_option @download_definitions_option @with_client -def get_address(client: "TrezorClient", address: str, show_display: bool, definitions_dir: pathlib.Path, network_def: TextIO, download_definitions: bool) -> str: +def get_address( + client: "TrezorClient", + address: str, + show_display: bool, + definitions_dir: pathlib.Path, + network_def: BinaryIO, + download_definitions: bool, +) -> str: """Get Ethereum address in hex encoding.""" address_n = tools.parse_path(address) defs = _get_ethereum_definitions( @@ -290,7 +318,14 @@ def get_address(client: "TrezorClient", address: str, show_display: bool, defini @network_def_option @download_definitions_option @with_client -def get_public_node(client: "TrezorClient", address: str, show_display: bool, definitions_dir: pathlib.Path, network_def: TextIO, download_definitions: bool) -> dict: +def get_public_node( + client: "TrezorClient", + address: str, + show_display: bool, + definitions_dir: pathlib.Path, + network_def: BinaryIO, + download_definitions: bool, +) -> dict: """Get Ethereum public node of given path.""" address_n = tools.parse_path(address) defs = _get_ethereum_definitions( @@ -299,7 +334,12 @@ def get_public_node(client: "TrezorClient", address: str, show_display: bool, de download_definitions=download_definitions, slip44_hardened=address_n[1], ) - result = ethereum.get_public_node(client, address_n, show_display=show_display, encoded_network=defs.encoded_network) + result = ethereum.get_public_node( + client, + address_n, + show_display=show_display, + encoded_network=defs.encoded_network, + ) return { "node": { "depth": result.node.depth, @@ -380,8 +420,8 @@ def sign_tx( access_list: List[ethereum.messages.EthereumAccessList], eip2718_type: Optional[int], definitions_dir: pathlib.Path, - network_def: TextIO, - token_def: TextIO, + network_def: BinaryIO, + token_def: BinaryIO, download_definitions: bool, ) -> str: """Sign (and optionally publish) Ethereum transaction. @@ -457,7 +497,7 @@ def sign_tx( token_def_file=token_def, download_definitions=download_definitions, chain_id=chain_id, - token_address=to_address + token_address=to_address, ) if is_eip1559: @@ -547,7 +587,14 @@ def sign_tx( @network_def_option @download_definitions_option @with_client -def sign_message(client: "TrezorClient", address: str, message: str, definitions_dir: pathlib.Path, network_def: TextIO, download_definitions: bool) -> Dict[str, str]: +def sign_message( + client: "TrezorClient", + address: str, + message: str, + definitions_dir: pathlib.Path, + network_def: BinaryIO, + download_definitions: bool, +) -> Dict[str, str]: """Sign message with Ethereum address.""" address_n = tools.parse_path(address) defs = _get_ethereum_definitions( @@ -578,7 +625,13 @@ def sign_message(client: "TrezorClient", address: str, message: str, definitions @download_definitions_option @with_client def sign_typed_data( - client: "TrezorClient", address: str, metamask_v4_compat: bool, file: TextIO, definitions_dir: pathlib.Path, network_def: TextIO, download_definitions: bool + client: "TrezorClient", + address: str, + metamask_v4_compat: bool, + file: TextIO, + definitions_dir: pathlib.Path, + network_def: BinaryIO, + download_definitions: bool, ) -> Dict[str, str]: """Sign typed data (EIP-712) with Ethereum address. @@ -595,7 +648,11 @@ def sign_typed_data( slip44_hardened=address_n[1], ) ret = ethereum.sign_typed_data( - client, address_n, data, metamask_v4_compat=metamask_v4_compat, encoded_network=defs.encoded_network + client, + address_n, + data, + metamask_v4_compat=metamask_v4_compat, + encoded_network=defs.encoded_network, ) output = { "address": ret.address, @@ -613,7 +670,13 @@ def sign_typed_data( @download_definitions_option @with_client def verify_message( - client: "TrezorClient", address: str, signature: str, message: str, definitions_dir: pathlib.Path, network_def: TextIO, download_definitions: bool + client: "TrezorClient", + address: str, + signature: str, + message: str, + definitions_dir: pathlib.Path, + network_def: BinaryIO, + download_definitions: bool, ) -> bool: """Verify message signed with Ethereum address.""" chain_id = 1 @@ -624,7 +687,9 @@ def verify_message( download_definitions=download_definitions, chain_id=chain_id, ) - return ethereum.verify_message(client, address, signature_bytes, message, chain_id, defs.encoded_network) + return ethereum.verify_message( + client, address, signature_bytes, message, chain_id, defs.encoded_network + ) @cli.command() diff --git a/python/src/trezorlib/ethereum.py b/python/src/trezorlib/ethereum.py index 4aaef2cefb..fe804d7214 100644 --- a/python/src/trezorlib/ethereum.py +++ b/python/src/trezorlib/ethereum.py @@ -14,12 +14,14 @@ # You should have received a copy of the License along with this library. # If not, see . -from itertools import chain -import pathlib, re, requests -from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, TextIO, Tuple +import pathlib +import re +from typing import TYPE_CHECKING, Any, AnyStr, Dict, List, Optional, Tuple + +import requests from . import exceptions, messages -from .tools import expect, UH_, prepare_message_bytes, session +from .tools import UH_, expect, prepare_message_bytes, session if TYPE_CHECKING: from .client import TrezorClient @@ -28,11 +30,11 @@ if TYPE_CHECKING: # TODO: change once we know the urls -DEFS_BASE_URL="https://data.trezor.io/eth_definitions/{lookup_type}/{id}/{name}.dat" -DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE="by_chain_id" -DEFS_NETWORK_BY_SLIP44_LOOKUP_TYPE="by_slip44" -DEFS_NETWORK_URI_NAME="network" -DEFS_TOKEN_URI_NAME="token_{hex_address}" +DEFS_BASE_URL = "https://data.trezor.io/eth_definitions/{lookup_type}/{id}/{name}.dat" +DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE = "by_chain_id" +DEFS_NETWORK_BY_SLIP44_LOOKUP_TYPE = "by_slip44" +DEFS_NETWORK_URI_NAME = "network" +DEFS_TOKEN_URI_NAME = "token_{hex_address}" def int_to_big_endian(value: int) -> bytes: @@ -159,9 +161,13 @@ def download_from_url(url: str, error_msg: str = "") -> bytes: raise RuntimeError(f"{error_msg}{err}") -def download_network_definition(chain_id: Optional[int] = None, slip44_hardened: Optional[int] = None) -> Optional[bytes]: - if not ((chain_id is None) != (slip44_hardened is None)): # not XOR - raise RuntimeError(f"Exactly one of chain_id or slip44_hardened parameters are needed to load network definition from directory.") +def download_network_definition( + chain_id: Optional[int] = None, slip44_hardened: Optional[int] = None +) -> Optional[bytes]: + if not ((chain_id is None) != (slip44_hardened is None)): # not XOR + raise RuntimeError( + "Exactly one of chain_id or slip44_hardened parameters are needed to load network definition from directory." + ) if chain_id is not None: url = DEFS_BASE_URL.format( @@ -172,17 +178,21 @@ def download_network_definition(chain_id: Optional[int] = None, slip44_hardened: else: url = DEFS_BASE_URL.format( lookup_type=DEFS_NETWORK_BY_SLIP44_LOOKUP_TYPE, - id=UH_(slip44_hardened), + id=UH_(slip44_hardened), # type: ignore [Argument of type "int | None" cannot be assigned to parameter "x" of type "int" in function "UH_"] name=DEFS_NETWORK_URI_NAME, ) - error_msg = f"While downloading network definition from \"{url}\" following HTTP error occured: " + error_msg = f'While downloading network definition from "{url}" following HTTP error occured: ' return download_from_url(url, error_msg) -def download_token_definition(chain_id: Optional[int] = None, token_address: Optional[str] = None) -> Optional[bytes]: +def download_token_definition( + chain_id: Optional[int] = None, token_address: Optional[str] = None +) -> Optional[bytes]: if chain_id is None or token_address is None: - raise RuntimeError(f"Both chain_id and token_address parameters are needed to download token definition.") + raise RuntimeError( + "Both chain_id and token_address parameters are needed to download token definition." + ) url = DEFS_BASE_URL.format( lookup_type=DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE, @@ -190,37 +200,64 @@ def download_token_definition(chain_id: Optional[int] = None, token_address: Opt name=DEFS_TOKEN_URI_NAME.format(hex_address=token_address), ) - error_msg = f"While downloading token definition from \"{url}\" following HTTP error occured: " + error_msg = f'While downloading token definition from "{url}" following HTTP error occured: ' return download_from_url(url, error_msg) -def network_definition_from_dir(path: pathlib.Path, chain_id: Optional[int] = None, slip44_hardened: Optional[int] = None) -> Optional[bytes]: - if not ((chain_id is None) != (slip44_hardened is None)): # not XOR - raise RuntimeError(f"Exactly one of chain_id or slip44_hardened parameters are needed to load network definition from directory.") +def network_definition_from_dir( + path: pathlib.Path, + chain_id: Optional[int] = None, + slip44_hardened: Optional[int] = None, +) -> Optional[bytes]: + if not ((chain_id is None) != (slip44_hardened is None)): # not XOR + raise RuntimeError( + "Exactly one of chain_id or slip44_hardened parameters are needed to load network definition from directory." + ) def read_definition(path: pathlib.Path) -> Optional[bytes]: if not path.exists() or not path.is_file(): return None with open(path, mode="rb") as f: - return f.read() + return f.read() if chain_id is not None: - return read_definition(path / DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE / str(chain_id) / (DEFS_NETWORK_URI_NAME + ".dat")) + return read_definition( + path + / DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE + / str(chain_id) + / (DEFS_NETWORK_URI_NAME + ".dat") + ) else: - return read_definition(path / DEFS_NETWORK_BY_SLIP44_LOOKUP_TYPE / str(UH_(slip44_hardened)) / (DEFS_NETWORK_URI_NAME + ".dat")) + return read_definition( + path + / DEFS_NETWORK_BY_SLIP44_LOOKUP_TYPE + / str(UH_(slip44_hardened)) # type: ignore [Argument of type "int | None" cannot be assigned to parameter "x" of type "int" in function "UH_"] + / (DEFS_NETWORK_URI_NAME + ".dat") + ) -def token_definition_from_dir(path: pathlib.Path, chain_id: Optional[int] = None, token_address: Optional[str] = None) -> Optional[bytes]: +def token_definition_from_dir( + path: pathlib.Path, + chain_id: Optional[int] = None, + token_address: Optional[str] = None, +) -> Optional[bytes]: if chain_id is None or token_address is None: - raise RuntimeError(f"Both chain_id and token_address parameters are needed to load token definition from directory.") + raise RuntimeError( + "Both chain_id and token_address parameters are needed to load token definition from directory." + ) - path = path / DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE / str(chain_id) / (DEFS_TOKEN_URI_NAME.format(hex_address=token_address) + ".dat") + path = ( + path + / DEFS_NETWORK_BY_CHAINID_LOOKUP_TYPE + / str(chain_id) + / (DEFS_TOKEN_URI_NAME.format(hex_address=token_address) + ".dat") + ) if not path.exists() or not path.is_file(): return None with open(path, mode="rb") as f: - return f.read() + return f.read() # ====== Client functions ====== # @@ -228,7 +265,10 @@ def token_definition_from_dir(path: pathlib.Path, chain_id: Optional[int] = None @expect(messages.EthereumAddress, field="address", ret_type=str) def get_address( - client: "TrezorClient", n: "Address", show_display: bool = False, encoded_network: bytes = None + client: "TrezorClient", + n: "Address", + show_display: bool = False, + encoded_network: bytes = None, ) -> "MessageType": return client.call( messages.EthereumGetAddress( @@ -241,7 +281,10 @@ def get_address( @expect(messages.EthereumPublicKey) def get_public_node( - client: "TrezorClient", n: "Address", show_display: bool = False, encoded_network: bytes = None + client: "TrezorClient", + n: "Address", + show_display: bool = False, + encoded_network: bytes = None, ) -> "MessageType": return client.call( messages.EthereumGetPublicKey( @@ -446,7 +489,12 @@ def sign_typed_data( def verify_message( - client: "TrezorClient", address: str, signature: bytes, message: AnyStr, chain_id: int = 1, encoded_network: bytes = None + client: "TrezorClient", + address: str, + signature: bytes, + message: AnyStr, + chain_id: int = 1, + encoded_network: bytes = None, ) -> bool: try: resp = client.call(