diff --git a/core/src/apps/ethereum/get_address.py b/core/src/apps/ethereum/get_address.py index 562992e63..ec4212c07 100644 --- a/core/src/apps/ethereum/get_address.py +++ b/core/src/apps/ethereum/get_address.py @@ -7,10 +7,10 @@ from apps.common.layout import address_n_to_str, show_address, show_qr from . import networks from .address import address_from_bytes -from .keychain import PATTERN_ADDRESS, with_keychain_from_path +from .keychain import PATTERNS_ADDRESS, with_keychain_from_path -@with_keychain_from_path(PATTERN_ADDRESS) +@with_keychain_from_path(*PATTERNS_ADDRESS) async def get_address(ctx, msg, keychain): await paths.validate_path(ctx, keychain, msg.address_n) diff --git a/core/src/apps/ethereum/get_public_key.py b/core/src/apps/ethereum/get_public_key.py index 8d692bd3c..1c78198d5 100644 --- a/core/src/apps/ethereum/get_public_key.py +++ b/core/src/apps/ethereum/get_public_key.py @@ -3,10 +3,10 @@ from trezor.messages.HDNodeType import HDNodeType from apps.common import coins, layout, paths -from .keychain import PATTERN_PUBKEY, with_keychain_from_path +from .keychain import with_keychain_from_path -@with_keychain_from_path(PATTERN_PUBKEY) +@with_keychain_from_path(paths.PATTERN_BIP44_PUBKEY) async def get_public_key(ctx, msg, keychain): await paths.validate_path(ctx, keychain, msg.address_n) node = keychain.derive(msg.address_n) diff --git a/core/src/apps/ethereum/keychain.py b/core/src/apps/ethereum/keychain.py index 48f3e384c..2426c7747 100644 --- a/core/src/apps/ethereum/keychain.py +++ b/core/src/apps/ethereum/keychain.py @@ -6,7 +6,7 @@ from apps.common.keychain import get_keychain from . import CURVE, networks if False: - from typing import Callable + from typing import Callable, Iterable from typing_extensions import Protocol from protobuf import MessageType @@ -22,32 +22,30 @@ if False: # We believe Ethereum should use 44'/60'/a' for everything, because it is # account-based, rather than UTXO-based. Unfortunately, lot of Ethereum # tools (MEW, Metamask) do not use such scheme and set a = 0 and then -# iterate the address index i. Therefore for compatibility reasons we use -# the same scheme: 44'/60'/0'/0/i and only the i is being iterated. +# iterate the address index i. For compatibility, we allow this scheme as well. -PATTERN_ADDRESS = "m/44'/coin_type'/0'/0/address_index" -PATTERN_PUBKEY = "m/44'/coin_type'/0'/*" +PATTERNS_ADDRESS = (paths.PATTERN_BIP44, paths.PATTERN_SEP5) -def _schema_from_address_n( - pattern: str, address_n: paths.Bip32Path -) -> paths.PathSchema: +def _schemas_from_address_n( + patterns: Iterable[str], address_n: paths.Bip32Path +) -> Iterable[paths.PathSchema]: if len(address_n) < 2: - return paths.SCHEMA_NO_MATCH + return () slip44_hardened = address_n[1] if slip44_hardened not in networks.all_slip44_ids_hardened(): - return paths.SCHEMA_NO_MATCH + return () if not slip44_hardened & HARDENED: - return paths.SCHEMA_ANY_PATH + return () slip44_id = slip44_hardened - HARDENED - return paths.PathSchema(pattern, slip44_id) + return (paths.PathSchema(pattern, slip44_id) for pattern in patterns) def with_keychain_from_path( - pattern: str, + *patterns: str, ) -> Callable[ [HandlerWithKeychain[MsgWithAddressN, MsgOut]], Handler[MsgWithAddressN, MsgOut] ]: @@ -55,8 +53,8 @@ def with_keychain_from_path( func: HandlerWithKeychain[MsgWithAddressN, MsgOut] ) -> Handler[MsgWithAddressN, MsgOut]: async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut: - schema = _schema_from_address_n(pattern, msg.address_n) - keychain = await get_keychain(ctx, CURVE, [schema]) + schemas = _schemas_from_address_n(patterns, msg.address_n) + keychain = await get_keychain(ctx, CURVE, schemas) with keychain: return await func(ctx, msg, keychain) @@ -65,18 +63,23 @@ def with_keychain_from_path( return decorator -def _schema_from_chain_id(msg: EthereumSignTx) -> paths.PathSchema: +def _schemas_from_chain_id(msg: EthereumSignTx) -> Iterable[paths.PathSchema]: if msg.chain_id is None: - return _schema_from_address_n(PATTERN_ADDRESS, msg.address_n) + return _schemas_from_address_n(PATTERNS_ADDRESS, msg.address_n) info = networks.by_chain_id(msg.chain_id) if info is None: - return paths.SCHEMA_NO_MATCH + return () - slip44_id = info.slip44 if networks.is_wanchain(msg.chain_id, msg.tx_type): - slip44_id = networks.SLIP44_WANCHAIN - return paths.PathSchema(PATTERN_ADDRESS, slip44_id) + slip44_id = (networks.SLIP44_WANCHAIN,) + elif info.slip44 != 60 and info.slip44 != 1: + # allow cross-signing with Ethereum unless it's testnet + slip44_id = (info.slip44, 60) + else: + slip44_id = (info.slip44,) + + return (paths.PathSchema(pattern, slip44_id) for pattern in PATTERNS_ADDRESS) def with_keychain_from_chain_id( @@ -84,8 +87,8 @@ def with_keychain_from_chain_id( ) -> Handler[EthereumSignTx, MsgOut]: # this is only for SignTx, and only PATTERN_ADDRESS is allowed async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut: - schema = _schema_from_chain_id(msg) - keychain = await get_keychain(ctx, CURVE, [schema]) + schemas = _schemas_from_chain_id(msg) + keychain = await get_keychain(ctx, CURVE, schemas) with keychain: return await func(ctx, msg, keychain) diff --git a/core/src/apps/ethereum/sign_message.py b/core/src/apps/ethereum/sign_message.py index 0052f365c..0d1a4c89c 100644 --- a/core/src/apps/ethereum/sign_message.py +++ b/core/src/apps/ethereum/sign_message.py @@ -7,7 +7,7 @@ from apps.common import paths from apps.common.signverify import require_confirm_sign_message from . import address -from .keychain import PATTERN_ADDRESS, with_keychain_from_path +from .keychain import PATTERNS_ADDRESS, with_keychain_from_path def message_digest(message): @@ -19,7 +19,7 @@ def message_digest(message): return h.get_digest() -@with_keychain_from_path(PATTERN_ADDRESS) +@with_keychain_from_path(*PATTERNS_ADDRESS) async def sign_message(ctx, msg, keychain): await paths.validate_path(ctx, keychain, msg.address_n) await require_confirm_sign_message(ctx, "ETH", msg.message)