1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-09 06:50:58 +00:00

core/wallet: implement keychain for apps.wallet

This commit is contained in:
matejcik 2020-05-06 12:42:52 +02:00 committed by matejcik
parent 0dff3853a7
commit a31b2cd1bc
11 changed files with 139 additions and 63 deletions

View File

@ -3,22 +3,12 @@ from trezor.messages import MessageType
def boot() -> None: def boot() -> None:
ns = [ wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
["curve25519"], wire.add(MessageType.GetAddress, __name__, "get_address")
["ed25519"],
["ed25519-keccak"],
["nist256p1"],
["secp256k1"],
["secp256k1-decred"],
["secp256k1-groestl"],
["secp256k1-smart"],
]
wire.add(MessageType.GetPublicKey, __name__, "get_public_key", ns)
wire.add(MessageType.GetAddress, __name__, "get_address", ns)
wire.add(MessageType.GetEntropy, __name__, "get_entropy") wire.add(MessageType.GetEntropy, __name__, "get_entropy")
wire.add(MessageType.SignTx, __name__, "sign_tx", ns) wire.add(MessageType.SignTx, __name__, "sign_tx")
wire.add(MessageType.SignMessage, __name__, "sign_message", ns) wire.add(MessageType.SignMessage, __name__, "sign_message")
wire.add(MessageType.VerifyMessage, __name__, "verify_message") wire.add(MessageType.VerifyMessage, __name__, "verify_message")
wire.add(MessageType.SignIdentity, __name__, "sign_identity", ns) wire.add(MessageType.SignIdentity, __name__, "sign_identity")
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key", ns) wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value", ns) wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")

View File

@ -5,9 +5,12 @@ from trezor.messages.CipheredKeyValue import CipheredKeyValue
from trezor.ui.text import Text from trezor.ui.text import Text
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.seed import get_keychain
async def cipher_key_value(ctx, msg, keychain): async def cipher_key_value(ctx, msg):
keychain = await get_keychain(ctx, [("secp256k1", [])])
if len(msg.value) % 16 > 0: if len(msg.value) % 16 > 0:
raise wire.DataError("Value length must be a multiple of 16") raise wire.DataError("Value length must be a multiple of 16")

View File

@ -2,11 +2,12 @@ from trezor.crypto import bip32
from trezor.messages import InputScriptType from trezor.messages import InputScriptType
from trezor.messages.Address import Address from trezor.messages.Address import Address
from apps.common import coins from .keychain import with_keychain
from .sign_tx import addresses
from .sign_tx.multisig import multisig_pubkey_index
from apps.common.layout import address_n_to_str, show_address, show_qr, show_xpub from apps.common.layout import address_n_to_str, show_address, show_qr, show_xpub
from apps.common.paths import validate_path from apps.common.paths import validate_path
from apps.wallet.sign_tx import addresses
from apps.wallet.sign_tx.multisig import multisig_pubkey_index
if False: if False:
from typing import List from typing import List
@ -36,10 +37,8 @@ async def show_xpubs(
return False return False
async def get_address(ctx, msg, keychain): @with_keychain
coin_name = msg.coin_name or "Bitcoin" async def get_address(ctx, msg, keychain, coin):
coin = coins.by_name(coin_name)
await validate_path( await validate_path(
ctx, ctx,
addresses.validate_full_path, addresses.validate_full_path,
@ -50,7 +49,7 @@ async def get_address(ctx, msg, keychain):
script_type=msg.script_type, script_type=msg.script_type,
) )
node = keychain.derive(msg.address_n, coin.curve_name) node = keychain.derive(msg.address_n)
address = addresses.get_address(msg.script_type, coin, node, msg.multisig) address = addresses.get_address(msg.script_type, coin, node, msg.multisig)
address_short = addresses.address_short(coin, address) address_short = addresses.address_short(coin, address)
if msg.script_type == InputScriptType.SPENDWITNESS: if msg.script_type == InputScriptType.SPENDWITNESS:

View File

@ -8,22 +8,24 @@ from trezor.utils import chunks
from apps.common import HARDENED from apps.common import HARDENED
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.seed import get_keychain
from apps.wallet.sign_identity import ( from apps.wallet.sign_identity import (
serialize_identity, serialize_identity,
serialize_identity_without_proto, serialize_identity_without_proto,
) )
async def get_ecdh_session_key(ctx, msg, keychain): async def get_ecdh_session_key(ctx, msg):
if msg.ecdsa_curve_name is None: if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = "secp256k1" msg.ecdsa_curve_name = "secp256k1"
keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])])
identity = serialize_identity(msg.identity) identity = serialize_identity(msg.identity)
await require_confirm_ecdh_session_key(ctx, msg.identity) await require_confirm_ecdh_session_key(ctx, msg.identity)
address_n = get_ecdh_path(identity, msg.identity.index or 0) address_n = get_ecdh_path(identity, msg.identity.index or 0)
node = keychain.derive(address_n, msg.ecdsa_curve_name) node = keychain.derive(address_n)
session_key = ecdh( session_key = ecdh(
seckey=node.private_key(), seckey=node.private_key(),

View File

@ -3,16 +3,42 @@ from trezor.messages import InputScriptType
from trezor.messages.HDNodeType import HDNodeType from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.PublicKey import PublicKey from trezor.messages.PublicKey import PublicKey
from apps.common import coins, layout from .keychain import get_keychain_for_coin
from apps.common import HARDENED, coins, layout, seed
async def get_public_key(ctx, msg, keychain): async def get_keychain_for_curve(ctx: wire.Context, curve_name: str) -> seed.Keychain:
coin_name = msg.coin_name or "Bitcoin" """Set up a keychain for SLIP-13 and SLIP-17 namespaces with a specified curve."""
coin = coins.by_name(coin_name) namespaces = [
curve_name = msg.ecdsa_curve_name or coin.curve_name (curve_name, [13 | HARDENED]),
(curve_name, [17 | HARDENED]),
]
return await seed.get_keychain(ctx, namespaces)
async def get_public_key(ctx, msg):
script_type = msg.script_type or InputScriptType.SPENDADDRESS script_type = msg.script_type or InputScriptType.SPENDADDRESS
node = keychain.derive(msg.address_n, curve_name=curve_name) if msg.ecdsa_curve_name is not None:
# If a curve name is provided, disallow coin-specific features.
if (
msg.coin_name is not None
or msg.script_type is not InputScriptType.SPENDADDRESS
):
raise wire.DataError(
"Cannot use coin_name or script_type with ecdsa_curve_name"
)
coin = coins.by_name("Bitcoin")
# only allow SLIP-13/17 namespaces
keychain = await get_keychain_for_curve(ctx, msg.ecdsa_curve_name)
else:
# select curve and namespaces based on the requested coin properties
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
node = keychain.derive(msg.address_n)
if ( if (
script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG] script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]

View File

@ -0,0 +1,61 @@
from trezor import wire
from apps.common import HARDENED, coininfo
from apps.common.seed import get_keychain
if False:
from protobuf import MessageType
from typing import Callable, Optional, Tuple, TypeVar
from typing_extensions import Protocol
from apps.common.seed import Keychain, MsgOut, Handler
class MsgWithCoinName(MessageType, Protocol):
coin_name = ... # type: Optional[str]
MsgIn = TypeVar("MsgIn", bound=MsgWithCoinName)
HandlerWithCoinInfo = Callable[
[wire.Context, MsgIn, Keychain, coininfo.CoinInfo], MsgOut
]
async def get_keychain_for_coin(
ctx: wire.Context, coin_name: Optional[str]
) -> Tuple[Keychain, coininfo.CoinInfo]:
if coin_name is None:
coin_name = "Bitcoin"
try:
coin = coininfo.by_name(coin_name)
except ValueError:
raise wire.DataError("Unsupported coin type")
namespaces = []
curve = coin.curve_name
slip44_id = coin.slip44 | HARDENED
# BIP-44 - legacy: m/44'/slip44' (/account'/change/addr)
namespaces.append((curve, [44 | HARDENED, slip44_id]))
# BIP-45 - multisig cosigners: m/45' (/cosigner/change/addr)
namespaces.append((curve, [45 | HARDENED]))
# "purpose48" - multisig as done by Electrum
# m/48'/slip44' (/account'/script_type'/change/addr)
namespaces.append((curve, [48 | HARDENED, slip44_id]))
if coin.segwit:
# BIP-49 - p2sh segwit: m/49'/slip44' (/account'/change/addr)
namespaces.append((curve, [49 | HARDENED, slip44_id]))
# BIP-84 - native segwit: m/84'/slip44' (/account'/change/addr)
namespaces.append((curve, [84 | HARDENED, slip44_id]))
keychain = await get_keychain(ctx, namespaces)
return keychain, coin
def with_keychain(func: HandlerWithCoinInfo[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
with keychain:
return await func(ctx, msg, keychain, coin)
return wrapper

View File

@ -8,18 +8,20 @@ from trezor.utils import chunks
from apps.common import HARDENED, coins from apps.common import HARDENED, coins
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
from apps.common.seed import get_keychain
async def sign_identity(ctx, msg, keychain): async def sign_identity(ctx, msg):
if msg.ecdsa_curve_name is None: if msg.ecdsa_curve_name is None:
msg.ecdsa_curve_name = "secp256k1" msg.ecdsa_curve_name = "secp256k1"
keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])])
identity = serialize_identity(msg.identity) identity = serialize_identity(msg.identity)
await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual) await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual)
address_n = get_identity_path(identity, msg.identity.index or 0) address_n = get_identity_path(identity, msg.identity.index or 0)
node = keychain.derive(address_n, msg.ecdsa_curve_name) node = keychain.derive(address_n)
coin = coins.by_name("Bitcoin") coin = coins.by_name("Bitcoin")
if msg.ecdsa_curve_name == "secp256k1": if msg.ecdsa_curve_name == "secp256k1":

View File

@ -3,18 +3,18 @@ from trezor.crypto.curve import secp256k1
from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS
from trezor.messages.MessageSignature import MessageSignature from trezor.messages.MessageSignature import MessageSignature
from apps.common import coins from .keychain import with_keychain
from .sign_tx.addresses import get_address, validate_full_path
from apps.common.paths import validate_path from apps.common.paths import validate_path
from apps.common.signverify import message_digest, require_confirm_sign_message from apps.common.signverify import message_digest, require_confirm_sign_message
from apps.wallet.sign_tx.addresses import get_address, validate_full_path
async def sign_message(ctx, msg, keychain): @with_keychain
async def sign_message(ctx, msg, keychain, coin):
message = msg.message message = msg.message
address_n = msg.address_n address_n = msg.address_n
coin_name = msg.coin_name or "Bitcoin"
script_type = msg.script_type or 0 script_type = msg.script_type or 0
coin = coins.by_name(coin_name)
await require_confirm_sign_message(ctx, "Sign message", message) await require_confirm_sign_message(ctx, "Sign message", message)
await validate_path( await validate_path(
@ -28,7 +28,7 @@ async def sign_message(ctx, msg, keychain):
validate_script_type=False, validate_script_type=False,
) )
node = keychain.derive(address_n, coin.curve_name) node = keychain.derive(address_n)
seckey = node.private_key() seckey = node.private_key()
address = get_address(script_type, coin, node) address = get_address(script_type, coin, node)

View File

@ -4,17 +4,10 @@ from trezor.messages.SignTx import SignTx
from trezor.messages.TxAck import TxAck from trezor.messages.TxAck import TxAck
from trezor.messages.TxRequest import TxRequest from trezor.messages.TxRequest import TxRequest
from apps.common import coins, paths, seed from ..keychain import with_keychain
from apps.wallet.sign_tx import ( from . import addresses, bitcoin, common, helpers, layout, multisig, progress, scripts
addresses,
bitcoin, from apps.common import coininfo, paths, seed
common,
helpers,
layout,
multisig,
progress,
scripts,
)
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from apps.wallet.sign_tx import bitcoinlike, decred, zcash from apps.wallet.sign_tx import bitcoinlike, decred, zcash
@ -26,16 +19,16 @@ if False:
BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet") BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet")
async def sign_tx(ctx: wire.Context, msg: SignTx, keychain: seed.Keychain) -> TxRequest: @with_keychain
coin_name = msg.coin_name if msg.coin_name is not None else "Bitcoin" async def sign_tx(
coin = coins.by_name(coin_name) ctx: wire.Context, msg: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo
) -> TxRequest:
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
if coin.decred: if coin.decred:
signer_class = decred.Decred # type: Type[bitcoin.Bitcoin] signer_class = decred.Decred # type: Type[bitcoin.Bitcoin]
elif coin.overwintered: elif coin.overwintered:
signer_class = zcash.Overwintered signer_class = zcash.Overwintered
elif coin_name not in BITCOIN_NAMES: elif coin.coin_name not in BITCOIN_NAMES:
signer_class = bitcoinlike.Bitcoinlike signer_class = bitcoinlike.Bitcoinlike
else: else:
signer_class = bitcoin.Bitcoin signer_class = bitcoin.Bitcoin

View File

@ -235,7 +235,7 @@ class Bitcoin:
# NOTE: No need to check the multisig fingerprint, because we won't be signing # NOTE: No need to check the multisig fingerprint, because we won't be signing
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS. # the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
node = self.keychain.derive(txi.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key() key_sign_pub = node.public_key()
script_sig = self.input_derive_script(txi, key_sign_pub) script_sig = self.input_derive_script(txi, key_sign_pub)
self.write_tx_input(self.serialized_tx, txi, script_sig) self.write_tx_input(self.serialized_tx, txi, script_sig)
@ -250,7 +250,7 @@ class Bitcoin:
) )
self.bip143_in -= txi.amount self.bip143_in -= txi.amount
node = self.keychain.derive(txi.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n)
public_key = node.public_key() public_key = node.public_key()
hash143_hash = self.hash143_preimage_hash( hash143_hash = self.hash143_preimage_hash(
txi, addresses.ecdsa_hash_pubkey(public_key, self.coin) txi, addresses.ecdsa_hash_pubkey(public_key, self.coin)
@ -302,7 +302,7 @@ class Bitcoin:
self.wallet_path.check_input(txi) self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi) self.multisig_fingerprint.check_input(txi)
# NOTE: wallet_path is checked in write_tx_input_check() # NOTE: wallet_path is checked in write_tx_input_check()
node = self.keychain.derive(txi.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n)
key_sign_pub = node.public_key() key_sign_pub = node.public_key()
# if multisig, do a sanity check to ensure we are signing with a key that is included in the multisig # if multisig, do a sanity check to ensure we are signing with a key that is included in the multisig
if txi.multisig: if txi.multisig:
@ -470,7 +470,7 @@ class Bitcoin:
] ]
except KeyError: except KeyError:
raise SigningError(FailureType.DataError, "Invalid script type") raise SigningError(FailureType.DataError, "Invalid script type")
node = self.keychain.derive(txo.address_n, self.coin.curve_name) node = self.keychain.derive(txo.address_n)
txo.address = addresses.get_address( txo.address = addresses.get_address(
input_script_type, self.coin, node, txo.multisig input_script_type, self.coin, node, txo.multisig
) )

View File

@ -80,7 +80,7 @@ class Decred(Bitcoin):
self.wallet_path.check_input(txi_sign) self.wallet_path.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign) self.multisig_fingerprint.check_input(txi_sign)
key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name) key_sign = self.keychain.derive(txi_sign.address_n)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()
if txi_sign.script_type == InputScriptType.SPENDMULTISIG: if txi_sign.script_type == InputScriptType.SPENDMULTISIG: