1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 03:50:58 +00:00

core/ethereum: introduce custom keychain decorators

This commit is contained in:
matejcik 2020-04-20 11:53:33 +02:00 committed by matejcik
parent b594248ac2
commit 0dff3853a7
8 changed files with 86 additions and 12 deletions

View File

@ -1,18 +1,12 @@
from trezor import wire from trezor import wire
from trezor.messages import MessageType from trezor.messages import MessageType
from apps.common import HARDENED
from apps.ethereum.networks import all_slip44_ids_hardened
CURVE = "secp256k1" CURVE = "secp256k1"
def boot() -> None: def boot() -> None:
ns = [] wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
for i in all_slip44_ids_hardened(): wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
ns.append([CURVE, HARDENED | 44, i]) wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns) wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns)
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns)
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message") wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")

View File

@ -6,8 +6,10 @@ from apps.common import paths
from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.layout import address_n_to_str, show_address, show_qr
from apps.ethereum import CURVE, networks from apps.ethereum import CURVE, networks
from apps.ethereum.address import address_from_bytes, validate_full_path from apps.ethereum.address import address_from_bytes, validate_full_path
from apps.ethereum.keychain import with_keychain_from_path
@with_keychain_from_path
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)

View File

@ -3,8 +3,10 @@ from trezor.messages.HDNodeType import HDNodeType
from apps.common import coins, layout, paths from apps.common import coins, layout, paths
from apps.ethereum import CURVE, address from apps.ethereum import CURVE, address
from apps.ethereum.keychain import with_keychain_from_path
@with_keychain_from_path
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path( await paths.validate_path(
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE

View File

@ -0,0 +1,64 @@
from trezor import wire
from . import CURVE, networks
from apps.common import HARDENED, seed
from apps.common.seed import get_keychain
if False:
from typing import List
from typing_extensions import Protocol
from protobuf import MessageType
from trezor.messages.EthereumSignTx import EthereumSignTx
from apps.common.seed import MsgOut, Handler, HandlerWithKeychain
class MsgWithAddressN(MessageType, Protocol):
address_n = ... # type: List[int]
async def from_address_n(ctx: wire.Context, address_n: List[int]) -> seed.Keychain:
if len(address_n) < 2:
raise wire.DataError("Forbidden key path")
slip44_hardened = address_n[1]
if slip44_hardened not in networks.all_slip44_ids_hardened():
raise wire.DataError("Forbidden key path")
namespace = CURVE, [44 | HARDENED, slip44_hardened]
return await get_keychain(ctx, [namespace])
def with_keychain_from_path(
func: HandlerWithKeychain[MsgWithAddressN, MsgOut]
) -> Handler[MsgWithAddressN, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut:
keychain = await from_address_n(ctx, msg.address_n)
with keychain:
return await func(ctx, msg, keychain)
return wrapper
def with_keychain_from_chain_id(
func: HandlerWithKeychain[EthereumSignTx, MsgOut]
) -> Handler[EthereumSignTx, MsgOut]:
async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut:
if msg.chain_id is None:
keychain = await from_address_n(ctx, msg.address_n)
else:
info = networks.by_chain_id(msg.chain_id)
if info is None:
raise wire.DataError("Unsupported chain id")
slip44 = info.slip44
if networks.is_wanchain(msg.chain_id, msg.tx_type):
slip44 = networks.SLIP44_WANCHAIN
namespace = CURVE, [44 | HARDENED, slip44 | HARDENED]
keychain = await get_keychain(ctx, [namespace])
with keychain:
return await func(ctx, msg, keychain)
return wrapper

View File

@ -12,8 +12,12 @@ if False:
from typing import Iterator, Optional from typing import Iterator, Optional
def is_wanchain(chain_id: int, tx_type: int) -> bool:
return tx_type in (1, 6) and chain_id in (1, 3)
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str: def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
if tx_type in (1, 6) and chain_id in (1, 3): if is_wanchain(chain_id, tx_type):
return "WAN" return "WAN"
else: else:
n = by_chain_id(chain_id) n = by_chain_id(chain_id)

View File

@ -12,8 +12,12 @@ if False:
from typing import Iterator, Optional from typing import Iterator, Optional
def is_wanchain(chain_id: int, tx_type: int) -> bool:
return tx_type in (1, 6) and chain_id in (1, 3)
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str: def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
if tx_type in (1, 6) and chain_id in (1, 3): if is_wanchain(chain_id, tx_type):
return "WAN" return "WAN"
else: else:
n = by_chain_id(chain_id) n = by_chain_id(chain_id)

View File

@ -6,6 +6,7 @@ from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.common.signverify import require_confirm_sign_message from apps.common.signverify import require_confirm_sign_message
from apps.ethereum import CURVE, address from apps.ethereum import CURVE, address
from apps.ethereum.keychain import with_keychain_from_path
def message_digest(message): def message_digest(message):
@ -17,6 +18,7 @@ def message_digest(message):
return h.get_digest() return h.get_digest()
@with_keychain_from_path
async def sign_message(ctx, msg, keychain): async def sign_message(ctx, msg, keychain):
await paths.validate_path( await paths.validate_path(
ctx, address.validate_full_path, keychain, msg.address_n, CURVE ctx, address.validate_full_path, keychain, msg.address_n, CURVE

View File

@ -10,6 +10,7 @@ from trezor.utils import HashWriter
from apps.common import paths from apps.common import paths
from apps.ethereum import CURVE, address, tokens from apps.ethereum import CURVE, address, tokens
from apps.ethereum.address import validate_full_path from apps.ethereum.address import validate_full_path
from apps.ethereum.keychain import with_keychain_from_chain_id
from apps.ethereum.layout import ( from apps.ethereum.layout import (
require_confirm_data, require_confirm_data,
require_confirm_fee, require_confirm_fee,
@ -20,6 +21,7 @@ from apps.ethereum.layout import (
MAX_CHAIN_ID = 2147483629 MAX_CHAIN_ID = 2147483629
@with_keychain_from_chain_id
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg) msg = sanitize(msg)
check(msg) check(msg)