mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 06:48:16 +00:00
core/wallet: implement keychain for apps.wallet
This commit is contained in:
parent
0dff3853a7
commit
a31b2cd1bc
@ -3,22 +3,12 @@ from trezor.messages import MessageType
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
ns = [
|
||||
["curve25519"],
|
||||
["ed25519"],
|
||||
["ed25519-keccak"],
|
||||
["nist256p1"],
|
||||
["secp256k1"],
|
||||
["secp256k1-decred"],
|
||||
["secp256k1-groestl"],
|
||||
["secp256k1-smart"],
|
||||
]
|
||||
wire.add(MessageType.GetPublicKey, __name__, "get_public_key", ns)
|
||||
wire.add(MessageType.GetAddress, __name__, "get_address", ns)
|
||||
wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.GetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.GetEntropy, __name__, "get_entropy")
|
||||
wire.add(MessageType.SignTx, __name__, "sign_tx", ns)
|
||||
wire.add(MessageType.SignMessage, __name__, "sign_message", ns)
|
||||
wire.add(MessageType.SignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.SignMessage, __name__, "sign_message")
|
||||
wire.add(MessageType.VerifyMessage, __name__, "verify_message")
|
||||
wire.add(MessageType.SignIdentity, __name__, "sign_identity", ns)
|
||||
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key", ns)
|
||||
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value", ns)
|
||||
wire.add(MessageType.SignIdentity, __name__, "sign_identity")
|
||||
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
|
||||
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")
|
||||
|
@ -5,9 +5,12 @@ from trezor.messages.CipheredKeyValue import CipheredKeyValue
|
||||
from trezor.ui.text import Text
|
||||
|
||||
from apps.common.confirm import require_confirm
|
||||
from apps.common.seed import get_keychain
|
||||
|
||||
|
||||
async def cipher_key_value(ctx, msg, keychain):
|
||||
async def cipher_key_value(ctx, msg):
|
||||
keychain = await get_keychain(ctx, [("secp256k1", [])])
|
||||
|
||||
if len(msg.value) % 16 > 0:
|
||||
raise wire.DataError("Value length must be a multiple of 16")
|
||||
|
||||
|
@ -2,11 +2,12 @@ from trezor.crypto import bip32
|
||||
from trezor.messages import InputScriptType
|
||||
from trezor.messages.Address import Address
|
||||
|
||||
from apps.common import coins
|
||||
from .keychain import with_keychain
|
||||
from .sign_tx import addresses
|
||||
from .sign_tx.multisig import multisig_pubkey_index
|
||||
|
||||
from apps.common.layout import address_n_to_str, show_address, show_qr, show_xpub
|
||||
from apps.common.paths import validate_path
|
||||
from apps.wallet.sign_tx import addresses
|
||||
from apps.wallet.sign_tx.multisig import multisig_pubkey_index
|
||||
|
||||
if False:
|
||||
from typing import List
|
||||
@ -36,10 +37,8 @@ async def show_xpubs(
|
||||
return False
|
||||
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
coin_name = msg.coin_name or "Bitcoin"
|
||||
coin = coins.by_name(coin_name)
|
||||
|
||||
@with_keychain
|
||||
async def get_address(ctx, msg, keychain, coin):
|
||||
await validate_path(
|
||||
ctx,
|
||||
addresses.validate_full_path,
|
||||
@ -50,7 +49,7 @@ async def get_address(ctx, msg, keychain):
|
||||
script_type=msg.script_type,
|
||||
)
|
||||
|
||||
node = keychain.derive(msg.address_n, coin.curve_name)
|
||||
node = keychain.derive(msg.address_n)
|
||||
address = addresses.get_address(msg.script_type, coin, node, msg.multisig)
|
||||
address_short = addresses.address_short(coin, address)
|
||||
if msg.script_type == InputScriptType.SPENDWITNESS:
|
||||
|
@ -8,22 +8,24 @@ from trezor.utils import chunks
|
||||
|
||||
from apps.common import HARDENED
|
||||
from apps.common.confirm import require_confirm
|
||||
from apps.common.seed import get_keychain
|
||||
from apps.wallet.sign_identity import (
|
||||
serialize_identity,
|
||||
serialize_identity_without_proto,
|
||||
)
|
||||
|
||||
|
||||
async def get_ecdh_session_key(ctx, msg, keychain):
|
||||
async def get_ecdh_session_key(ctx, msg):
|
||||
if msg.ecdsa_curve_name is None:
|
||||
msg.ecdsa_curve_name = "secp256k1"
|
||||
|
||||
keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])])
|
||||
identity = serialize_identity(msg.identity)
|
||||
|
||||
await require_confirm_ecdh_session_key(ctx, msg.identity)
|
||||
|
||||
address_n = get_ecdh_path(identity, msg.identity.index or 0)
|
||||
node = keychain.derive(address_n, msg.ecdsa_curve_name)
|
||||
node = keychain.derive(address_n)
|
||||
|
||||
session_key = ecdh(
|
||||
seckey=node.private_key(),
|
||||
|
@ -3,16 +3,42 @@ from trezor.messages import InputScriptType
|
||||
from trezor.messages.HDNodeType import HDNodeType
|
||||
from trezor.messages.PublicKey import PublicKey
|
||||
|
||||
from apps.common import coins, layout
|
||||
from .keychain import get_keychain_for_coin
|
||||
|
||||
from apps.common import HARDENED, coins, layout, seed
|
||||
|
||||
|
||||
async def get_public_key(ctx, msg, keychain):
|
||||
coin_name = msg.coin_name or "Bitcoin"
|
||||
coin = coins.by_name(coin_name)
|
||||
curve_name = msg.ecdsa_curve_name or coin.curve_name
|
||||
async def get_keychain_for_curve(ctx: wire.Context, curve_name: str) -> seed.Keychain:
|
||||
"""Set up a keychain for SLIP-13 and SLIP-17 namespaces with a specified curve."""
|
||||
namespaces = [
|
||||
(curve_name, [13 | HARDENED]),
|
||||
(curve_name, [17 | HARDENED]),
|
||||
]
|
||||
return await seed.get_keychain(ctx, namespaces)
|
||||
|
||||
|
||||
async def get_public_key(ctx, msg):
|
||||
script_type = msg.script_type or InputScriptType.SPENDADDRESS
|
||||
|
||||
node = keychain.derive(msg.address_n, curve_name=curve_name)
|
||||
if msg.ecdsa_curve_name is not None:
|
||||
# If a curve name is provided, disallow coin-specific features.
|
||||
if (
|
||||
msg.coin_name is not None
|
||||
or msg.script_type is not InputScriptType.SPENDADDRESS
|
||||
):
|
||||
raise wire.DataError(
|
||||
"Cannot use coin_name or script_type with ecdsa_curve_name"
|
||||
)
|
||||
|
||||
coin = coins.by_name("Bitcoin")
|
||||
# only allow SLIP-13/17 namespaces
|
||||
keychain = await get_keychain_for_curve(ctx, msg.ecdsa_curve_name)
|
||||
|
||||
else:
|
||||
# select curve and namespaces based on the requested coin properties
|
||||
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
|
||||
if (
|
||||
script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]
|
||||
|
61
core/src/apps/wallet/keychain.py
Normal file
61
core/src/apps/wallet/keychain.py
Normal file
@ -0,0 +1,61 @@
|
||||
from trezor import wire
|
||||
|
||||
from apps.common import HARDENED, coininfo
|
||||
from apps.common.seed import get_keychain
|
||||
|
||||
if False:
|
||||
from protobuf import MessageType
|
||||
from typing import Callable, Optional, Tuple, TypeVar
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from apps.common.seed import Keychain, MsgOut, Handler
|
||||
|
||||
class MsgWithCoinName(MessageType, Protocol):
|
||||
coin_name = ... # type: Optional[str]
|
||||
|
||||
MsgIn = TypeVar("MsgIn", bound=MsgWithCoinName)
|
||||
HandlerWithCoinInfo = Callable[
|
||||
[wire.Context, MsgIn, Keychain, coininfo.CoinInfo], MsgOut
|
||||
]
|
||||
|
||||
|
||||
async def get_keychain_for_coin(
|
||||
ctx: wire.Context, coin_name: Optional[str]
|
||||
) -> Tuple[Keychain, coininfo.CoinInfo]:
|
||||
if coin_name is None:
|
||||
coin_name = "Bitcoin"
|
||||
|
||||
try:
|
||||
coin = coininfo.by_name(coin_name)
|
||||
except ValueError:
|
||||
raise wire.DataError("Unsupported coin type")
|
||||
|
||||
namespaces = []
|
||||
curve = coin.curve_name
|
||||
slip44_id = coin.slip44 | HARDENED
|
||||
|
||||
# BIP-44 - legacy: m/44'/slip44' (/account'/change/addr)
|
||||
namespaces.append((curve, [44 | HARDENED, slip44_id]))
|
||||
# BIP-45 - multisig cosigners: m/45' (/cosigner/change/addr)
|
||||
namespaces.append((curve, [45 | HARDENED]))
|
||||
# "purpose48" - multisig as done by Electrum
|
||||
# m/48'/slip44' (/account'/script_type'/change/addr)
|
||||
namespaces.append((curve, [48 | HARDENED, slip44_id]))
|
||||
|
||||
if coin.segwit:
|
||||
# BIP-49 - p2sh segwit: m/49'/slip44' (/account'/change/addr)
|
||||
namespaces.append((curve, [49 | HARDENED, slip44_id]))
|
||||
# BIP-84 - native segwit: m/84'/slip44' (/account'/change/addr)
|
||||
namespaces.append((curve, [84 | HARDENED, slip44_id]))
|
||||
|
||||
keychain = await get_keychain(ctx, namespaces)
|
||||
return keychain, coin
|
||||
|
||||
|
||||
def with_keychain(func: HandlerWithCoinInfo[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]:
|
||||
async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut:
|
||||
keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name)
|
||||
with keychain:
|
||||
return await func(ctx, msg, keychain, coin)
|
||||
|
||||
return wrapper
|
@ -8,18 +8,20 @@ from trezor.utils import chunks
|
||||
|
||||
from apps.common import HARDENED, coins
|
||||
from apps.common.confirm import require_confirm
|
||||
from apps.common.seed import get_keychain
|
||||
|
||||
|
||||
async def sign_identity(ctx, msg, keychain):
|
||||
async def sign_identity(ctx, msg):
|
||||
if msg.ecdsa_curve_name is None:
|
||||
msg.ecdsa_curve_name = "secp256k1"
|
||||
|
||||
keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])])
|
||||
identity = serialize_identity(msg.identity)
|
||||
|
||||
await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual)
|
||||
|
||||
address_n = get_identity_path(identity, msg.identity.index or 0)
|
||||
node = keychain.derive(address_n, msg.ecdsa_curve_name)
|
||||
node = keychain.derive(address_n)
|
||||
|
||||
coin = coins.by_name("Bitcoin")
|
||||
if msg.ecdsa_curve_name == "secp256k1":
|
||||
|
@ -3,18 +3,18 @@ from trezor.crypto.curve import secp256k1
|
||||
from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS
|
||||
from trezor.messages.MessageSignature import MessageSignature
|
||||
|
||||
from apps.common import coins
|
||||
from .keychain import with_keychain
|
||||
from .sign_tx.addresses import get_address, validate_full_path
|
||||
|
||||
from apps.common.paths import validate_path
|
||||
from apps.common.signverify import message_digest, require_confirm_sign_message
|
||||
from apps.wallet.sign_tx.addresses import get_address, validate_full_path
|
||||
|
||||
|
||||
async def sign_message(ctx, msg, keychain):
|
||||
@with_keychain
|
||||
async def sign_message(ctx, msg, keychain, coin):
|
||||
message = msg.message
|
||||
address_n = msg.address_n
|
||||
coin_name = msg.coin_name or "Bitcoin"
|
||||
script_type = msg.script_type or 0
|
||||
coin = coins.by_name(coin_name)
|
||||
|
||||
await require_confirm_sign_message(ctx, "Sign message", message)
|
||||
await validate_path(
|
||||
@ -28,7 +28,7 @@ async def sign_message(ctx, msg, keychain):
|
||||
validate_script_type=False,
|
||||
)
|
||||
|
||||
node = keychain.derive(address_n, coin.curve_name)
|
||||
node = keychain.derive(address_n)
|
||||
seckey = node.private_key()
|
||||
|
||||
address = get_address(script_type, coin, node)
|
||||
|
@ -4,17 +4,10 @@ from trezor.messages.SignTx import SignTx
|
||||
from trezor.messages.TxAck import TxAck
|
||||
from trezor.messages.TxRequest import TxRequest
|
||||
|
||||
from apps.common import coins, paths, seed
|
||||
from apps.wallet.sign_tx import (
|
||||
addresses,
|
||||
bitcoin,
|
||||
common,
|
||||
helpers,
|
||||
layout,
|
||||
multisig,
|
||||
progress,
|
||||
scripts,
|
||||
)
|
||||
from ..keychain import with_keychain
|
||||
from . import addresses, bitcoin, common, helpers, layout, multisig, progress, scripts
|
||||
|
||||
from apps.common import coininfo, paths, seed
|
||||
|
||||
if not utils.BITCOIN_ONLY:
|
||||
from apps.wallet.sign_tx import bitcoinlike, decred, zcash
|
||||
@ -26,16 +19,16 @@ if False:
|
||||
BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet")
|
||||
|
||||
|
||||
async def sign_tx(ctx: wire.Context, msg: SignTx, keychain: seed.Keychain) -> TxRequest:
|
||||
coin_name = msg.coin_name if msg.coin_name is not None else "Bitcoin"
|
||||
coin = coins.by_name(coin_name)
|
||||
|
||||
@with_keychain
|
||||
async def sign_tx(
|
||||
ctx: wire.Context, msg: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo
|
||||
) -> TxRequest:
|
||||
if not utils.BITCOIN_ONLY:
|
||||
if coin.decred:
|
||||
signer_class = decred.Decred # type: Type[bitcoin.Bitcoin]
|
||||
elif coin.overwintered:
|
||||
signer_class = zcash.Overwintered
|
||||
elif coin_name not in BITCOIN_NAMES:
|
||||
elif coin.coin_name not in BITCOIN_NAMES:
|
||||
signer_class = bitcoinlike.Bitcoinlike
|
||||
else:
|
||||
signer_class = bitcoin.Bitcoin
|
||||
|
@ -235,7 +235,7 @@ class Bitcoin:
|
||||
# NOTE: No need to check the multisig fingerprint, because we won't be signing
|
||||
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
|
||||
|
||||
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
|
||||
node = self.keychain.derive(txi.address_n)
|
||||
key_sign_pub = node.public_key()
|
||||
script_sig = self.input_derive_script(txi, key_sign_pub)
|
||||
self.write_tx_input(self.serialized_tx, txi, script_sig)
|
||||
@ -250,7 +250,7 @@ class Bitcoin:
|
||||
)
|
||||
self.bip143_in -= txi.amount
|
||||
|
||||
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
|
||||
node = self.keychain.derive(txi.address_n)
|
||||
public_key = node.public_key()
|
||||
hash143_hash = self.hash143_preimage_hash(
|
||||
txi, addresses.ecdsa_hash_pubkey(public_key, self.coin)
|
||||
@ -302,7 +302,7 @@ class Bitcoin:
|
||||
self.wallet_path.check_input(txi)
|
||||
self.multisig_fingerprint.check_input(txi)
|
||||
# NOTE: wallet_path is checked in write_tx_input_check()
|
||||
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
|
||||
node = self.keychain.derive(txi.address_n)
|
||||
key_sign_pub = node.public_key()
|
||||
# if multisig, do a sanity check to ensure we are signing with a key that is included in the multisig
|
||||
if txi.multisig:
|
||||
@ -470,7 +470,7 @@ class Bitcoin:
|
||||
]
|
||||
except KeyError:
|
||||
raise SigningError(FailureType.DataError, "Invalid script type")
|
||||
node = self.keychain.derive(txo.address_n, self.coin.curve_name)
|
||||
node = self.keychain.derive(txo.address_n)
|
||||
txo.address = addresses.get_address(
|
||||
input_script_type, self.coin, node, txo.multisig
|
||||
)
|
||||
|
@ -80,7 +80,7 @@ class Decred(Bitcoin):
|
||||
self.wallet_path.check_input(txi_sign)
|
||||
self.multisig_fingerprint.check_input(txi_sign)
|
||||
|
||||
key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name)
|
||||
key_sign = self.keychain.derive(txi_sign.address_n)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
|
||||
if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
|
||||
|
Loading…
Reference in New Issue
Block a user