From 0dff3853a76d8704bb7de7ce5b42759178b96a0e Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 20 Apr 2020 11:53:33 +0200 Subject: [PATCH] core/ethereum: introduce custom keychain decorators --- core/src/apps/ethereum/__init__.py | 14 ++---- core/src/apps/ethereum/get_address.py | 2 + core/src/apps/ethereum/get_public_key.py | 2 + core/src/apps/ethereum/keychain.py | 64 ++++++++++++++++++++++++ core/src/apps/ethereum/networks.py | 6 ++- core/src/apps/ethereum/networks.py.mako | 6 ++- core/src/apps/ethereum/sign_message.py | 2 + core/src/apps/ethereum/sign_tx.py | 2 + 8 files changed, 86 insertions(+), 12 deletions(-) create mode 100644 core/src/apps/ethereum/keychain.py diff --git a/core/src/apps/ethereum/__init__.py b/core/src/apps/ethereum/__init__.py index 4117dd4f00..1ce060d24e 100644 --- a/core/src/apps/ethereum/__init__.py +++ b/core/src/apps/ethereum/__init__.py @@ -1,18 +1,12 @@ from trezor import wire from trezor.messages import MessageType -from apps.common import HARDENED -from apps.ethereum.networks import all_slip44_ids_hardened - CURVE = "secp256k1" def boot() -> None: - ns = [] - for i in all_slip44_ids_hardened(): - ns.append([CURVE, HARDENED | 44, i]) - wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns) - wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns) - wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns) - wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns) + wire.add(MessageType.EthereumGetAddress, __name__, "get_address") + wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key") + wire.add(MessageType.EthereumSignTx, __name__, "sign_tx") + wire.add(MessageType.EthereumSignMessage, __name__, "sign_message") wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message") diff --git a/core/src/apps/ethereum/get_address.py b/core/src/apps/ethereum/get_address.py index 002cad70ae..d8f09adc72 100644 --- a/core/src/apps/ethereum/get_address.py +++ b/core/src/apps/ethereum/get_address.py @@ -6,8 +6,10 @@ from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr from apps.ethereum import CURVE, networks from apps.ethereum.address import address_from_bytes, validate_full_path +from apps.ethereum.keychain import with_keychain_from_path +@with_keychain_from_path async def get_address(ctx, msg, keychain): await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) diff --git a/core/src/apps/ethereum/get_public_key.py b/core/src/apps/ethereum/get_public_key.py index e463473a89..14f9e8c21e 100644 --- a/core/src/apps/ethereum/get_public_key.py +++ b/core/src/apps/ethereum/get_public_key.py @@ -3,8 +3,10 @@ from trezor.messages.HDNodeType import HDNodeType from apps.common import coins, layout, paths from apps.ethereum import CURVE, address +from apps.ethereum.keychain import with_keychain_from_path +@with_keychain_from_path async def get_public_key(ctx, msg, keychain): await paths.validate_path( ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE diff --git a/core/src/apps/ethereum/keychain.py b/core/src/apps/ethereum/keychain.py new file mode 100644 index 0000000000..4b94e07e91 --- /dev/null +++ b/core/src/apps/ethereum/keychain.py @@ -0,0 +1,64 @@ +from trezor import wire + +from . import CURVE, networks + +from apps.common import HARDENED, seed +from apps.common.seed import get_keychain + +if False: + from typing import List + from typing_extensions import Protocol + + from protobuf import MessageType + + from trezor.messages.EthereumSignTx import EthereumSignTx + + from apps.common.seed import MsgOut, Handler, HandlerWithKeychain + + class MsgWithAddressN(MessageType, Protocol): + address_n = ... # type: List[int] + + +async def from_address_n(ctx: wire.Context, address_n: List[int]) -> seed.Keychain: + if len(address_n) < 2: + raise wire.DataError("Forbidden key path") + slip44_hardened = address_n[1] + if slip44_hardened not in networks.all_slip44_ids_hardened(): + raise wire.DataError("Forbidden key path") + namespace = CURVE, [44 | HARDENED, slip44_hardened] + return await get_keychain(ctx, [namespace]) + + +def with_keychain_from_path( + func: HandlerWithKeychain[MsgWithAddressN, MsgOut] +) -> Handler[MsgWithAddressN, MsgOut]: + async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut: + keychain = await from_address_n(ctx, msg.address_n) + with keychain: + return await func(ctx, msg, keychain) + + return wrapper + + +def with_keychain_from_chain_id( + func: HandlerWithKeychain[EthereumSignTx, MsgOut] +) -> Handler[EthereumSignTx, MsgOut]: + async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut: + if msg.chain_id is None: + keychain = await from_address_n(ctx, msg.address_n) + else: + info = networks.by_chain_id(msg.chain_id) + if info is None: + raise wire.DataError("Unsupported chain id") + + slip44 = info.slip44 + if networks.is_wanchain(msg.chain_id, msg.tx_type): + slip44 = networks.SLIP44_WANCHAIN + + namespace = CURVE, [44 | HARDENED, slip44 | HARDENED] + keychain = await get_keychain(ctx, [namespace]) + + with keychain: + return await func(ctx, msg, keychain) + + return wrapper diff --git a/core/src/apps/ethereum/networks.py b/core/src/apps/ethereum/networks.py index d681674b3e..537658359f 100644 --- a/core/src/apps/ethereum/networks.py +++ b/core/src/apps/ethereum/networks.py @@ -12,8 +12,12 @@ if False: from typing import Iterator, Optional +def is_wanchain(chain_id: int, tx_type: int) -> bool: + return tx_type in (1, 6) and chain_id in (1, 3) + + def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str: - if tx_type in (1, 6) and chain_id in (1, 3): + if is_wanchain(chain_id, tx_type): return "WAN" else: n = by_chain_id(chain_id) diff --git a/core/src/apps/ethereum/networks.py.mako b/core/src/apps/ethereum/networks.py.mako index 93e16ca03f..bcbf9c58d0 100644 --- a/core/src/apps/ethereum/networks.py.mako +++ b/core/src/apps/ethereum/networks.py.mako @@ -12,8 +12,12 @@ if False: from typing import Iterator, Optional +def is_wanchain(chain_id: int, tx_type: int) -> bool: + return tx_type in (1, 6) and chain_id in (1, 3) + + def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str: - if tx_type in (1, 6) and chain_id in (1, 3): + if is_wanchain(chain_id, tx_type): return "WAN" else: n = by_chain_id(chain_id) diff --git a/core/src/apps/ethereum/sign_message.py b/core/src/apps/ethereum/sign_message.py index 47c978f42c..7f9619927e 100644 --- a/core/src/apps/ethereum/sign_message.py +++ b/core/src/apps/ethereum/sign_message.py @@ -6,6 +6,7 @@ from trezor.utils import HashWriter from apps.common import paths from apps.common.signverify import require_confirm_sign_message from apps.ethereum import CURVE, address +from apps.ethereum.keychain import with_keychain_from_path def message_digest(message): @@ -17,6 +18,7 @@ def message_digest(message): return h.get_digest() +@with_keychain_from_path async def sign_message(ctx, msg, keychain): await paths.validate_path( ctx, address.validate_full_path, keychain, msg.address_n, CURVE diff --git a/core/src/apps/ethereum/sign_tx.py b/core/src/apps/ethereum/sign_tx.py index 0f0ad5923b..839e777b5d 100644 --- a/core/src/apps/ethereum/sign_tx.py +++ b/core/src/apps/ethereum/sign_tx.py @@ -10,6 +10,7 @@ from trezor.utils import HashWriter from apps.common import paths from apps.ethereum import CURVE, address, tokens from apps.ethereum.address import validate_full_path +from apps.ethereum.keychain import with_keychain_from_chain_id from apps.ethereum.layout import ( require_confirm_data, require_confirm_fee, @@ -20,6 +21,7 @@ from apps.ethereum.layout import ( MAX_CHAIN_ID = 2147483629 +@with_keychain_from_chain_id async def sign_tx(ctx, msg, keychain): msg = sanitize(msg) check(msg)