mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-03 03:50:58 +00:00
core/ethereum: introduce custom keychain decorators
This commit is contained in:
parent
b594248ac2
commit
0dff3853a7
@ -1,18 +1,12 @@
|
|||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.messages import MessageType
|
from trezor.messages import MessageType
|
||||||
|
|
||||||
from apps.common import HARDENED
|
|
||||||
from apps.ethereum.networks import all_slip44_ids_hardened
|
|
||||||
|
|
||||||
CURVE = "secp256k1"
|
CURVE = "secp256k1"
|
||||||
|
|
||||||
|
|
||||||
def boot() -> None:
|
def boot() -> None:
|
||||||
ns = []
|
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
|
||||||
for i in all_slip44_ids_hardened():
|
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
|
||||||
ns.append([CURVE, HARDENED | 44, i])
|
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
|
||||||
wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns)
|
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
|
||||||
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns)
|
|
||||||
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns)
|
|
||||||
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message", ns)
|
|
||||||
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")
|
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")
|
||||||
|
@ -6,8 +6,10 @@ from apps.common import paths
|
|||||||
from apps.common.layout import address_n_to_str, show_address, show_qr
|
from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||||
from apps.ethereum import CURVE, networks
|
from apps.ethereum import CURVE, networks
|
||||||
from apps.ethereum.address import address_from_bytes, validate_full_path
|
from apps.ethereum.address import address_from_bytes, validate_full_path
|
||||||
|
from apps.ethereum.keychain import with_keychain_from_path
|
||||||
|
|
||||||
|
|
||||||
|
@with_keychain_from_path
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE)
|
||||||
|
|
||||||
|
@ -3,8 +3,10 @@ from trezor.messages.HDNodeType import HDNodeType
|
|||||||
|
|
||||||
from apps.common import coins, layout, paths
|
from apps.common import coins, layout, paths
|
||||||
from apps.ethereum import CURVE, address
|
from apps.ethereum import CURVE, address
|
||||||
|
from apps.ethereum.keychain import with_keychain_from_path
|
||||||
|
|
||||||
|
|
||||||
|
@with_keychain_from_path
|
||||||
async def get_public_key(ctx, msg, keychain):
|
async def get_public_key(ctx, msg, keychain):
|
||||||
await paths.validate_path(
|
await paths.validate_path(
|
||||||
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE
|
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE
|
||||||
|
64
core/src/apps/ethereum/keychain.py
Normal file
64
core/src/apps/ethereum/keychain.py
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
from trezor import wire
|
||||||
|
|
||||||
|
from . import CURVE, networks
|
||||||
|
|
||||||
|
from apps.common import HARDENED, seed
|
||||||
|
from apps.common.seed import get_keychain
|
||||||
|
|
||||||
|
if False:
|
||||||
|
from typing import List
|
||||||
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
from protobuf import MessageType
|
||||||
|
|
||||||
|
from trezor.messages.EthereumSignTx import EthereumSignTx
|
||||||
|
|
||||||
|
from apps.common.seed import MsgOut, Handler, HandlerWithKeychain
|
||||||
|
|
||||||
|
class MsgWithAddressN(MessageType, Protocol):
|
||||||
|
address_n = ... # type: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
async def from_address_n(ctx: wire.Context, address_n: List[int]) -> seed.Keychain:
|
||||||
|
if len(address_n) < 2:
|
||||||
|
raise wire.DataError("Forbidden key path")
|
||||||
|
slip44_hardened = address_n[1]
|
||||||
|
if slip44_hardened not in networks.all_slip44_ids_hardened():
|
||||||
|
raise wire.DataError("Forbidden key path")
|
||||||
|
namespace = CURVE, [44 | HARDENED, slip44_hardened]
|
||||||
|
return await get_keychain(ctx, [namespace])
|
||||||
|
|
||||||
|
|
||||||
|
def with_keychain_from_path(
|
||||||
|
func: HandlerWithKeychain[MsgWithAddressN, MsgOut]
|
||||||
|
) -> Handler[MsgWithAddressN, MsgOut]:
|
||||||
|
async def wrapper(ctx: wire.Context, msg: MsgWithAddressN) -> MsgOut:
|
||||||
|
keychain = await from_address_n(ctx, msg.address_n)
|
||||||
|
with keychain:
|
||||||
|
return await func(ctx, msg, keychain)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def with_keychain_from_chain_id(
|
||||||
|
func: HandlerWithKeychain[EthereumSignTx, MsgOut]
|
||||||
|
) -> Handler[EthereumSignTx, MsgOut]:
|
||||||
|
async def wrapper(ctx: wire.Context, msg: EthereumSignTx) -> MsgOut:
|
||||||
|
if msg.chain_id is None:
|
||||||
|
keychain = await from_address_n(ctx, msg.address_n)
|
||||||
|
else:
|
||||||
|
info = networks.by_chain_id(msg.chain_id)
|
||||||
|
if info is None:
|
||||||
|
raise wire.DataError("Unsupported chain id")
|
||||||
|
|
||||||
|
slip44 = info.slip44
|
||||||
|
if networks.is_wanchain(msg.chain_id, msg.tx_type):
|
||||||
|
slip44 = networks.SLIP44_WANCHAIN
|
||||||
|
|
||||||
|
namespace = CURVE, [44 | HARDENED, slip44 | HARDENED]
|
||||||
|
keychain = await get_keychain(ctx, [namespace])
|
||||||
|
|
||||||
|
with keychain:
|
||||||
|
return await func(ctx, msg, keychain)
|
||||||
|
|
||||||
|
return wrapper
|
@ -12,8 +12,12 @@ if False:
|
|||||||
from typing import Iterator, Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def is_wanchain(chain_id: int, tx_type: int) -> bool:
|
||||||
|
return tx_type in (1, 6) and chain_id in (1, 3)
|
||||||
|
|
||||||
|
|
||||||
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
|
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
|
||||||
if tx_type in (1, 6) and chain_id in (1, 3):
|
if is_wanchain(chain_id, tx_type):
|
||||||
return "WAN"
|
return "WAN"
|
||||||
else:
|
else:
|
||||||
n = by_chain_id(chain_id)
|
n = by_chain_id(chain_id)
|
||||||
|
@ -12,8 +12,12 @@ if False:
|
|||||||
from typing import Iterator, Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def is_wanchain(chain_id: int, tx_type: int) -> bool:
|
||||||
|
return tx_type in (1, 6) and chain_id in (1, 3)
|
||||||
|
|
||||||
|
|
||||||
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
|
def shortcut_by_chain_id(chain_id: int, tx_type: int = None) -> str:
|
||||||
if tx_type in (1, 6) and chain_id in (1, 3):
|
if is_wanchain(chain_id, tx_type):
|
||||||
return "WAN"
|
return "WAN"
|
||||||
else:
|
else:
|
||||||
n = by_chain_id(chain_id)
|
n = by_chain_id(chain_id)
|
||||||
|
@ -6,6 +6,7 @@ from trezor.utils import HashWriter
|
|||||||
from apps.common import paths
|
from apps.common import paths
|
||||||
from apps.common.signverify import require_confirm_sign_message
|
from apps.common.signverify import require_confirm_sign_message
|
||||||
from apps.ethereum import CURVE, address
|
from apps.ethereum import CURVE, address
|
||||||
|
from apps.ethereum.keychain import with_keychain_from_path
|
||||||
|
|
||||||
|
|
||||||
def message_digest(message):
|
def message_digest(message):
|
||||||
@ -17,6 +18,7 @@ def message_digest(message):
|
|||||||
return h.get_digest()
|
return h.get_digest()
|
||||||
|
|
||||||
|
|
||||||
|
@with_keychain_from_path
|
||||||
async def sign_message(ctx, msg, keychain):
|
async def sign_message(ctx, msg, keychain):
|
||||||
await paths.validate_path(
|
await paths.validate_path(
|
||||||
ctx, address.validate_full_path, keychain, msg.address_n, CURVE
|
ctx, address.validate_full_path, keychain, msg.address_n, CURVE
|
||||||
|
@ -10,6 +10,7 @@ from trezor.utils import HashWriter
|
|||||||
from apps.common import paths
|
from apps.common import paths
|
||||||
from apps.ethereum import CURVE, address, tokens
|
from apps.ethereum import CURVE, address, tokens
|
||||||
from apps.ethereum.address import validate_full_path
|
from apps.ethereum.address import validate_full_path
|
||||||
|
from apps.ethereum.keychain import with_keychain_from_chain_id
|
||||||
from apps.ethereum.layout import (
|
from apps.ethereum.layout import (
|
||||||
require_confirm_data,
|
require_confirm_data,
|
||||||
require_confirm_fee,
|
require_confirm_fee,
|
||||||
@ -20,6 +21,7 @@ from apps.ethereum.layout import (
|
|||||||
MAX_CHAIN_ID = 2147483629
|
MAX_CHAIN_ID = 2147483629
|
||||||
|
|
||||||
|
|
||||||
|
@with_keychain_from_chain_id
|
||||||
async def sign_tx(ctx, msg, keychain):
|
async def sign_tx(ctx, msg, keychain):
|
||||||
msg = sanitize(msg)
|
msg = sanitize(msg)
|
||||||
check(msg)
|
check(msg)
|
||||||
|
Loading…
Reference in New Issue
Block a user