1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

core/wallet: implement keychain for apps.wallet

This commit is contained in:
matejcik 2020-05-06 12:42:52 +02:00 committed by matejcik
parent 0dff3853a7
commit a31b2cd1bc
11 changed files with 139 additions and 63 deletions

View File

@ -3,22 +3,12 @@ from trezor.messages import MessageType
def boot() -> None:
ns = [
["curve25519"],
["ed25519"],
["ed25519-keccak"],
["nist256p1"],
["secp256k1"],
["secp256k1-decred"],
["secp256k1-groestl"],
["secp256k1-smart"],
]
wire.add(MessageType.GetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.GetAddress, __name__, "get_address", ns)
wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
wire.add(MessageType.GetAddress, __name__, "get_address")
wire.add(MessageType.GetEntropy, __name__, "get_entropy")
wire.add(MessageType.SignTx, __name__, "sign_tx", ns)
wire.add(MessageType.SignMessage, __name__, "sign_message", ns)
wire.add(MessageType.SignTx, __name__, "sign_tx")
wire.add(MessageType.SignMessage, __name__, "sign_message")
wire.add(MessageType.VerifyMessage, __name__, "verify_message")
wire.add(MessageType.SignIdentity, __name__, "sign_identity", ns)
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key", ns)
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value", ns)
wire.add(MessageType.SignIdentity, __name__, "sign_identity")
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")

View File

@ -5,9 +5,12 @@ from trezor.messages.CipheredKeyValue import CipheredKeyValue
from trezor.ui.text import Text
from apps.common.confirm import require_confirm
from apps.common.seed import get_keychain
async def cipher_key_value(ctx, msg, keychain):
async def cipher_key_value(ctx, msg):
keychain = await get_keychain(ctx, [("secp256k1", [])])
if len(msg.value) % 16 > 0:
raise wire.DataError("Value length must be a multiple of 16")

View File

@ -2,11 +2,12 @@ from trezor.crypto import bip32
from trezor.messages import InputScriptType
from trezor.messages.Address import Address
from apps.common import coins
from .keychain import with_keychain
from .sign_tx import addresses
from .sign_tx.multisig import multisig_pubkey_index
from apps.common.layout import address_n_to_str, show_address, show_qr, show_xpub
from apps.common.paths import validate_path
from apps.wallet.sign_tx import addresses
from apps.wallet.sign_tx.multisig import multisig_pubkey_index
if False:
from typing import List
@ -36,10 +37,8 @@ async def show_xpubs(
return False
async def get_address(ctx, msg, keychain):
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
@with_keychain
async def get_address(ctx, msg, keychain, coin):
await validate_path(
ctx,
addresses.validate_full_path,
@ -50,7 +49,7 @@ async def get_address(ctx, msg, keychain):
script_type=msg.script_type,
)
node = keychain.derive(msg.address_n, coin.curve_name)
node = keychain.derive(msg.address_n)
address = addresses.get_address(msg.script_type, coin, node, msg.multisig)
address_short = addresses.address_short(coin, address)
if msg.script_type == InputScriptType.SPENDWITNESS:

View File

@ -8,22 +8,24 @@ from trezor.utils import chunks
from apps.common import HARDENED
from apps.common.confirm import require_confirm
from apps.common.seed import get_keychain
from apps.wallet.sign_identity import (
serialize_identity,
serialize_identity_without_proto,
)
async def get_ecdh_session_key(ctx, msg, keychain):
async def get_ecdh_session_key(ctx, msg):
if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = "secp256k1"
keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])])
identity = serialize_identity(msg.identity)
await require_confirm_ecdh_session_key(ctx, msg.identity)
address_n = get_ecdh_path(identity, msg.identity.index or 0)
node = keychain.derive(address_n, msg.ecdsa_curve_name)
node = keychain.derive(address_n)
session_key = ecdh(
seckey=node.private_key(),

View File

@ -3,16 +3,42 @@ from trezor.messages import InputScriptType
from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.PublicKey import PublicKey
from apps.common import coins, layout
from .keychain import get_keychain_for_coin
from apps.common import HARDENED, coins, layout, seed
async def get_public_key(ctx, msg, keychain):
coin_name = msg.coin_name or "Bitcoin"
coin = coins.by_name(coin_name)
curve_name = msg.ecdsa_curve_name or coin.curve_name
async def get_keychain_for_curve(ctx: wire.Context, curve_name: str) -> seed.Keychain:
"""Set up a keychain for SLIP-13 and SLIP-17 namespaces with a specified curve."""
namespaces = [
(curve_name, [13 | HARDENED]),
(curve_name, [17 | HARDENED]),
]
return await seed.get_keychain(ctx, namespaces)
async def get_public_key(ctx, msg):
script_type = msg.script_type or InputScriptType.SPENDADDRESS
node = keychain.derive(msg.address_n, curve_name=curve_name)
if msg.ecdsa_curve_name is not None:
# If a curve name is provided, disallow coin-specific features.
if (
msg.coin_name is not None
or msg.script_type is not InputScriptType.SPENDADDRESS
):
raise wire.DataError(
"Cannot use coin_name or script_type with ecdsa_curve_name"
)
coin = coins.by_name("Bitcoin")
# only allow SLIP-13/17 namespaces
keychain = await get_keychain_for_curve(ctx, msg.ecdsa_curve_name)
else:
# select curve and namespaces based on the requested coin properties
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
node = keychain.derive(msg.address_n)
if (
script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]

View File

@ -0,0 +1,61 @@
from trezor import wire
from apps.common import HARDENED, coininfo
from apps.common.seed import get_keychain
if False:
from protobuf import MessageType
from typing import Callable, Optional, Tuple, TypeVar
from typing_extensions import Protocol
from apps.common.seed import Keychain, MsgOut, Handler
class MsgWithCoinName(MessageType, Protocol):
coin_name = ... # type: Optional[str]
MsgIn = TypeVar("MsgIn", bound=MsgWithCoinName)
HandlerWithCoinInfo = Callable[
[wire.Context, MsgIn, Keychain, coininfo.CoinInfo], MsgOut
]
async def get_keychain_for_coin(
ctx: wire.Context, coin_name: Optional[str]
) -> Tuple[Keychain, coininfo.CoinInfo]:
if coin_name is None:
coin_name = "Bitcoin"
try:
coin = coininfo.by_name(coin_name)
except ValueError:
raise wire.DataError("Unsupported coin type")
namespaces = []
curve = coin.curve_name
slip44_id = coin.slip44 | HARDENED
# BIP-44 - legacy: m/44'/slip44' (/account'/change/addr)
namespaces.append((curve, [44 | HARDENED, slip44_id]))
# BIP-45 - multisig cosigners: m/45' (/cosigner/change/addr)
namespaces.append((curve, [45 | HARDENED]))
# "purpose48" - multisig as done by Electrum
# m/48'/slip44' (/account'/script_type'/change/addr)
namespaces.append((curve, [48 | HARDENED, slip44_id]))
if coin.segwit:
# BIP-49 - p2sh segwit: m/49'/slip44' (/account'/change/addr)
namespaces.append((curve, [49 | HARDENED, slip44_id]))
# BIP-84 - native segwit: m/84'/slip44' (/account'/change/addr)
namespaces.append((curve, [84 | HARDENED, slip44_id]))
keychain = await get_keychain(ctx, namespaces)
return keychain, coin
def with_keychain(func: HandlerWithCoinInfo[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
with keychain:
return await func(ctx, msg, keychain, coin)
return wrapper

View File

@ -8,18 +8,20 @@ from trezor.utils import chunks
from apps.common import HARDENED, coins
from apps.common.confirm import require_confirm
from apps.common.seed import get_keychain
async def sign_identity(ctx, msg, keychain):
async def sign_identity(ctx, msg):
if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = "secp256k1"
keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])])
identity = serialize_identity(msg.identity)
await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual)
address_n = get_identity_path(identity, msg.identity.index or 0)
node = keychain.derive(address_n, msg.ecdsa_curve_name)
node = keychain.derive(address_n)
coin = coins.by_name("Bitcoin")
if msg.ecdsa_curve_name == "secp256k1":

View File

@ -3,18 +3,18 @@ from trezor.crypto.curve import secp256k1
from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS
from trezor.messages.MessageSignature import MessageSignature
from apps.common import coins
from .keychain import with_keychain
from .sign_tx.addresses import get_address, validate_full_path
from apps.common.paths import validate_path
from apps.common.signverify import message_digest, require_confirm_sign_message
from apps.wallet.sign_tx.addresses import get_address, validate_full_path
async def sign_message(ctx, msg, keychain):
@with_keychain
async def sign_message(ctx, msg, keychain, coin):
message = msg.message
address_n = msg.address_n
coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or 0
coin = coins.by_name(coin_name)
await require_confirm_sign_message(ctx, "Sign message", message)
await validate_path(
@ -28,7 +28,7 @@ async def sign_message(ctx, msg, keychain):
validate_script_type=False,
)
node = keychain.derive(address_n, coin.curve_name)
node = keychain.derive(address_n)
seckey = node.private_key()
address = get_address(script_type, coin, node)

View File

@ -4,17 +4,10 @@ from trezor.messages.SignTx import SignTx
from trezor.messages.TxAck import TxAck
from trezor.messages.TxRequest import TxRequest
from apps.common import coins, paths, seed
from apps.wallet.sign_tx import (
addresses,
bitcoin,
common,
helpers,
layout,
multisig,
progress,
scripts,
)
from ..keychain import with_keychain
from . import addresses, bitcoin, common, helpers, layout, multisig, progress, scripts
from apps.common import coininfo, paths, seed
if not utils.BITCOIN_ONLY:
from apps.wallet.sign_tx import bitcoinlike, decred, zcash
@ -26,16 +19,16 @@ if False:
BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet")
async def sign_tx(ctx: wire.Context, msg: SignTx, keychain: seed.Keychain) -> TxRequest:
coin_name = msg.coin_name if msg.coin_name is not None else "Bitcoin"
coin = coins.by_name(coin_name)
@with_keychain
async def sign_tx(
ctx: wire.Context, msg: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo
) -> TxRequest:
if not utils.BITCOIN_ONLY:
if coin.decred:
signer_class = decred.Decred # type: Type[bitcoin.Bitcoin]
elif coin.overwintered:
signer_class = zcash.Overwintered
elif coin_name not in BITCOIN_NAMES:
elif coin.coin_name not in BITCOIN_NAMES:
signer_class = bitcoinlike.Bitcoinlike
else:
signer_class = bitcoin.Bitcoin

View File

@ -235,7 +235,7 @@ class Bitcoin:
# NOTE: No need to check the multisig fingerprint, because we won't be signing
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key()
script_sig = self.input_derive_script(txi, key_sign_pub)
self.write_tx_input(self.serialized_tx, txi, script_sig)
@ -250,7 +250,7 @@ class Bitcoin:
)
self.bip143_in -= txi.amount
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
node = self.keychain.derive(txi.address_n)
public_key = node.public_key()
hash143_hash = self.hash143_preimage_hash(
txi, addresses.ecdsa_hash_pubkey(public_key, self.coin)
@ -302,7 +302,7 @@ class Bitcoin:
self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi)
# NOTE: wallet_path is checked in write_tx_input_check()
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key()
# if multisig, do a sanity check to ensure we are signing with a key that is included in the multisig
if txi.multisig:
@ -470,7 +470,7 @@ class Bitcoin:
]
except KeyError:
raise SigningError(FailureType.DataError, "Invalid script type")
node = self.keychain.derive(txo.address_n, self.coin.curve_name)
node = self.keychain.derive(txo.address_n)
txo.address = addresses.get_address(
input_script_type, self.coin, node, txo.multisig
)

View File

@ -80,7 +80,7 @@ class Decred(Bitcoin):
self.wallet_path.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign)
key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name)
key_sign = self.keychain.derive(txi_sign.address_n)
key_sign_pub = key_sign.public_key()
if txi_sign.script_type == InputScriptType.SPENDMULTISIG: