mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 06:18:07 +00:00
core/ethereum: introduce custom keychain decorators
This commit is contained in:
parent
b594248ac2
commit
0dff3853a7
@ -1,18 +1,12 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common import HARDENED
|
||||
from apps.ethereum.networks import all_slip44_ids_hardened
|
||||
|
||||
CURVE = "secp256k1"
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
ns = []
|
||||
for i in all_slip44_ids_hardened():
|
||||
ns.append([CURVE, HARDENED | 44, i])
|
||||
wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns)
|
||||
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns)
|
||||
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns)
|
||||
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns)
|
||||
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
|
||||
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")
|
||||
|
@ -6,8 +6,10 @@ from apps.common import paths
|
||||
from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||
from apps.ethereum import CURVE, networks
|
||||
from apps.ethereum.address import address_from_bytes, validate_full_path
|
||||
from apps.ethereum.keychain import with_keychain_from_path
|
||||
|
||||
|
||||
@with_keychain_from_path
|
||||
async def get_address(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
|
||||
|
||||
|
@ -3,8 +3,10 @@ from trezor.messages.HDNodeType import HDNodeType
|
||||
|
||||
from apps.common import coins, layout, paths
|
||||
from apps.ethereum import CURVE, address
|
||||
from apps.ethereum.keychain import with_keychain_from_path
|
||||
|
||||
|
||||
@with_keychain_from_path
|
||||
async def get_public_key(ctx, msg, keychain):
|
||||
await paths.validate_path(
|
||||
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE
|
||||
|
64
core/src/apps/ethereum/keychain.py
Normal file
64
core/src/apps/ethereum/keychain.py
Normal file
@ -0,0 +1,64 @@
|
||||
from trezor import wire
|
||||
|
||||
from . import CURVE, networks
|
||||
|
||||
from apps.common import HARDENED, seed
|
||||
from apps.common.seed import get_keychain
|
||||
|
||||
if False:
|
||||
from typing import List
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from protobuf import MessageType
|
||||
|
||||
from trezor.messages.EthereumSignTx import EthereumSignTx
|
||||
|
||||
from apps.common.seed import MsgOut, Handler, HandlerWithKeychain
|
||||
|
||||
class MsgWithAddressN(MessageType, Protocol):
|
||||
address_n = ... # type: List[int]
|
||||
|
||||
|
||||
async def from_address_n(ctx: wire.Context, address_n: List[int]) -> seed.Keychain:
|
||||
if len(address_n) < 2:
|
||||
raise wire.DataError("Forbidden key path")
|
||||
slip44_hardened = address_n[1]
|
||||
if slip44_hardened not in networks.all_slip44_ids_hardened():
|
||||
raise wire.DataError("Forbidden key path")
|
||||
namespace = CURVE, [44 | HARDENED, slip44_hardened]
|
||||
return await get_keychain(ctx, [namespace])
|
||||
|
||||
|
||||
def with_keychain_from_path(
|
||||
func: HandlerWithKeychain[MsgWithAddressN, MsgOut]
|
||||
) -> Handler[MsgWithAddressN, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut:
|
||||
keychain = await from_address_n(ctx, msg.address_n)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def with_keychain_from_chain_id(
|
||||
func: HandlerWithKeychain[EthereumSignTx, MsgOut]
|
||||
) -> Handler[EthereumSignTx, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut:
|
||||
if msg.chain_id is None:
|
||||
keychain = await from_address_n(ctx, msg.address_n)
|
||||
else:
|
||||
info = networks.by_chain_id(msg.chain_id)
|
||||
if info is None:
|
||||
raise wire.DataError("Unsupported chain id")
|
||||
|
||||
slip44 = info.slip44
|
||||
if networks.is_wanchain(msg.chain_id, msg.tx_type):
|
||||
slip44 = networks.SLIP44_WANCHAIN
|
||||
|
||||
namespace = CURVE, [44 | HARDENED, slip44 | HARDENED]
|
||||
keychain = await get_keychain(ctx, [namespace])
|
||||
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain)
|
||||
|
||||
return wrapper
|
@ -12,8 +12,12 @@ if False:
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
def is_wanchain(chain_id: int, tx_type: int) -> bool:
|
||||
return tx_type in (1, 6) and chain_id in (1, 3)
|
||||
|
||||
|
||||
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
|
||||
if tx_type in (1, 6) and chain_id in (1, 3):
|
||||
if is_wanchain(chain_id, tx_type):
|
||||
return "WAN"
|
||||
else:
|
||||
n = by_chain_id(chain_id)
|
||||
|
@ -12,8 +12,12 @@ if False:
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
def is_wanchain(chain_id: int, tx_type: int) -> bool:
|
||||
return tx_type in (1, 6) and chain_id in (1, 3)
|
||||
|
||||
|
||||
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
|
||||
if tx_type in (1, 6) and chain_id in (1, 3):
|
||||
if is_wanchain(chain_id, tx_type):
|
||||
return "WAN"
|
||||
else:
|
||||
n = by_chain_id(chain_id)
|
||||
|
@ -6,6 +6,7 @@ from trezor.utils import HashWriter
|
||||
from apps.common import paths
|
||||
from apps.common.signverify import require_confirm_sign_message
|
||||
from apps.ethereum import CURVE, address
|
||||
from apps.ethereum.keychain import with_keychain_from_path
|
||||
|
||||
|
||||
def message_digest(message):
|
||||
@ -17,6 +18,7 @@ def message_digest(message):
|
||||
return h.get_digest()
|
||||
|
||||
|
||||
@with_keychain_from_path
|
||||
async def sign_message(ctx, msg, keychain):
|
||||
await paths.validate_path(
|
||||
ctx, address.validate_full_path, keychain, msg.address_n, CURVE
|
||||
|
@ -10,6 +10,7 @@ from trezor.utils import HashWriter
|
||||
from apps.common import paths
|
||||
from apps.ethereum import CURVE, address, tokens
|
||||
from apps.ethereum.address import validate_full_path
|
||||
from apps.ethereum.keychain import with_keychain_from_chain_id
|
||||
from apps.ethereum.layout import (
|
||||
require_confirm_data,
|
||||
require_confirm_fee,
|
||||
@ -20,6 +21,7 @@ from apps.ethereum.layout import (
|
||||
MAX_CHAIN_ID = 2147483629
|
||||
|
||||
|
||||
@with_keychain_from_chain_id
|
||||
async def sign_tx(ctx, msg, keychain):
|
||||
msg = sanitize(msg)
|
||||
check(msg)
|
||||
|
Loading…
Reference in New Issue
Block a user