diff --git a/core/src/apps/wallet/__init__.py b/core/src/apps/wallet/__init__.py index 03d5321c30..6d15518bdd 100644 --- a/core/src/apps/wallet/__init__.py +++ b/core/src/apps/wallet/__init__.py @@ -3,22 +3,12 @@ from trezor.messages import MessageType def boot() -> None: - ns = [ - ["curve25519"], - ["ed25519"], - ["ed25519-keccak"], - ["nist256p1"], - ["secp256k1"], - ["secp256k1-decred"], - ["secp256k1-groestl"], - ["secp256k1-smart"], - ] - wire.add(MessageType.GetPublicKey, __name__, "get_public_key", ns) - wire.add(MessageType.GetAddress, __name__, "get_address", ns) + wire.add(MessageType.GetPublicKey, __name__, "get_public_key") + wire.add(MessageType.GetAddress, __name__, "get_address") wire.add(MessageType.GetEntropy, __name__, "get_entropy") - wire.add(MessageType.SignTx, __name__, "sign_tx", ns) - wire.add(MessageType.SignMessage, __name__, "sign_message", ns) + wire.add(MessageType.SignTx, __name__, "sign_tx") + wire.add(MessageType.SignMessage, __name__, "sign_message") wire.add(MessageType.VerifyMessage, __name__, "verify_message") - wire.add(MessageType.SignIdentity, __name__, "sign_identity", ns) - wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key", ns) - wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value", ns) + wire.add(MessageType.SignIdentity, __name__, "sign_identity") + wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key") + wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value") diff --git a/core/src/apps/wallet/cipher_key_value.py b/core/src/apps/wallet/cipher_key_value.py index b6c14663bf..abecd79a8a 100644 --- a/core/src/apps/wallet/cipher_key_value.py +++ b/core/src/apps/wallet/cipher_key_value.py @@ -5,9 +5,12 @@ from trezor.messages.CipheredKeyValue import CipheredKeyValue from trezor.ui.text import Text from apps.common.confirm import require_confirm +from apps.common.seed import get_keychain -async def cipher_key_value(ctx, msg, keychain): +async def cipher_key_value(ctx, msg): + keychain = await get_keychain(ctx, [("secp256k1", [])]) + if len(msg.value) % 16 > 0: raise wire.DataError("Value length must be a multiple of 16") diff --git a/core/src/apps/wallet/get_address.py b/core/src/apps/wallet/get_address.py index c6c13f40e8..a63b90839b 100644 --- a/core/src/apps/wallet/get_address.py +++ b/core/src/apps/wallet/get_address.py @@ -2,11 +2,12 @@ from trezor.crypto import bip32 from trezor.messages import InputScriptType from trezor.messages.Address import Address -from apps.common import coins +from .keychain import with_keychain +from .sign_tx import addresses +from .sign_tx.multisig import multisig_pubkey_index + from apps.common.layout import address_n_to_str, show_address, show_qr, show_xpub from apps.common.paths import validate_path -from apps.wallet.sign_tx import addresses -from apps.wallet.sign_tx.multisig import multisig_pubkey_index if False: from typing import List @@ -36,10 +37,8 @@ async def show_xpubs( return False -async def get_address(ctx, msg, keychain): - coin_name = msg.coin_name or "Bitcoin" - coin = coins.by_name(coin_name) - +@with_keychain +async def get_address(ctx, msg, keychain, coin): await validate_path( ctx, addresses.validate_full_path, @@ -50,7 +49,7 @@ async def get_address(ctx, msg, keychain): script_type=msg.script_type, ) - node = keychain.derive(msg.address_n, coin.curve_name) + node = keychain.derive(msg.address_n) address = addresses.get_address(msg.script_type, coin, node, msg.multisig) address_short = addresses.address_short(coin, address) if msg.script_type == InputScriptType.SPENDWITNESS: diff --git a/core/src/apps/wallet/get_ecdh_session_key.py b/core/src/apps/wallet/get_ecdh_session_key.py index c3816fcf91..52d5fa5984 100644 --- a/core/src/apps/wallet/get_ecdh_session_key.py +++ b/core/src/apps/wallet/get_ecdh_session_key.py @@ -8,22 +8,24 @@ from trezor.utils import chunks from apps.common import HARDENED from apps.common.confirm import require_confirm +from apps.common.seed import get_keychain from apps.wallet.sign_identity import ( serialize_identity, serialize_identity_without_proto, ) -async def get_ecdh_session_key(ctx, msg, keychain): +async def get_ecdh_session_key(ctx, msg): if msg.ecdsa_curve_name is None: msg.ecdsa_curve_name = "secp256k1" + keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])]) identity = serialize_identity(msg.identity) await require_confirm_ecdh_session_key(ctx, msg.identity) address_n = get_ecdh_path(identity, msg.identity.index or 0) - node = keychain.derive(address_n, msg.ecdsa_curve_name) + node = keychain.derive(address_n) session_key = ecdh( seckey=node.private_key(), diff --git a/core/src/apps/wallet/get_public_key.py b/core/src/apps/wallet/get_public_key.py index e779ff02dc..8fad540fb5 100644 --- a/core/src/apps/wallet/get_public_key.py +++ b/core/src/apps/wallet/get_public_key.py @@ -3,16 +3,42 @@ from trezor.messages import InputScriptType from trezor.messages.HDNodeType import HDNodeType from trezor.messages.PublicKey import PublicKey -from apps.common import coins, layout +from .keychain import get_keychain_for_coin + +from apps.common import HARDENED, coins, layout, seed -async def get_public_key(ctx, msg, keychain): - coin_name = msg.coin_name or "Bitcoin" - coin = coins.by_name(coin_name) - curve_name = msg.ecdsa_curve_name or coin.curve_name +async def get_keychain_for_curve(ctx: wire.Context, curve_name: str) -> seed.Keychain: + """Set up a keychain for SLIP-13 and SLIP-17 namespaces with a specified curve.""" + namespaces = [ + (curve_name, [13 | HARDENED]), + (curve_name, [17 | HARDENED]), + ] + return await seed.get_keychain(ctx, namespaces) + + +async def get_public_key(ctx, msg): script_type = msg.script_type or InputScriptType.SPENDADDRESS - node = keychain.derive(msg.address_n, curve_name=curve_name) + if msg.ecdsa_curve_name is not None: + # If a curve name is provided, disallow coin-specific features. + if ( + msg.coin_name is not None + or msg.script_type is not InputScriptType.SPENDADDRESS + ): + raise wire.DataError( + "Cannot use coin_name or script_type with ecdsa_curve_name" + ) + + coin = coins.by_name("Bitcoin") + # only allow SLIP-13/17 namespaces + keychain = await get_keychain_for_curve(ctx, msg.ecdsa_curve_name) + + else: + # select curve and namespaces based on the requested coin properties + keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) + + node = keychain.derive(msg.address_n) if ( script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG] diff --git a/core/src/apps/wallet/keychain.py b/core/src/apps/wallet/keychain.py new file mode 100644 index 0000000000..99caa2dfe8 --- /dev/null +++ b/core/src/apps/wallet/keychain.py @@ -0,0 +1,61 @@ +from trezor import wire + +from apps.common import HARDENED, coininfo +from apps.common.seed import get_keychain + +if False: + from protobuf import MessageType + from typing import Callable, Optional, Tuple, TypeVar + from typing_extensions import Protocol + + from apps.common.seed import Keychain, MsgOut, Handler + + class MsgWithCoinName(MessageType, Protocol): + coin_name = ... # type: Optional[str] + + MsgIn = TypeVar("MsgIn", bound=MsgWithCoinName) + HandlerWithCoinInfo = Callable[ + [wire.Context, MsgIn, Keychain, coininfo.CoinInfo], MsgOut + ] + + +async def get_keychain_for_coin( + ctx: wire.Context, coin_name: Optional[str] +) -> Tuple[Keychain, coininfo.CoinInfo]: + if coin_name is None: + coin_name = "Bitcoin" + + try: + coin = coininfo.by_name(coin_name) + except ValueError: + raise wire.DataError("Unsupported coin type") + + namespaces = [] + curve = coin.curve_name + slip44_id = coin.slip44 | HARDENED + + # BIP-44 - legacy: m/44'/slip44' (/account'/change/addr) + namespaces.append((curve, [44 | HARDENED, slip44_id])) + # BIP-45 - multisig cosigners: m/45' (/cosigner/change/addr) + namespaces.append((curve, [45 | HARDENED])) + # "purpose48" - multisig as done by Electrum + # m/48'/slip44' (/account'/script_type'/change/addr) + namespaces.append((curve, [48 | HARDENED, slip44_id])) + + if coin.segwit: + # BIP-49 - p2sh segwit: m/49'/slip44' (/account'/change/addr) + namespaces.append((curve, [49 | HARDENED, slip44_id])) + # BIP-84 - native segwit: m/84'/slip44' (/account'/change/addr) + namespaces.append((curve, [84 | HARDENED, slip44_id])) + + keychain = await get_keychain(ctx, namespaces) + return keychain, coin + + +def with_keychain(func: HandlerWithCoinInfo[MsgIn, MsgOut]) -> Handler[MsgIn, MsgOut]: + async def wrapper(ctx: wire.Context, msg: MsgIn) -> MsgOut: + keychain, coin = await get_keychain_for_coin(ctx, msg.coin_name) + with keychain: + return await func(ctx, msg, keychain, coin) + + return wrapper diff --git a/core/src/apps/wallet/sign_identity.py b/core/src/apps/wallet/sign_identity.py index 234a7db04e..fc02ad736d 100644 --- a/core/src/apps/wallet/sign_identity.py +++ b/core/src/apps/wallet/sign_identity.py @@ -8,18 +8,20 @@ from trezor.utils import chunks from apps.common import HARDENED, coins from apps.common.confirm import require_confirm +from apps.common.seed import get_keychain -async def sign_identity(ctx, msg, keychain): +async def sign_identity(ctx, msg): if msg.ecdsa_curve_name is None: msg.ecdsa_curve_name = "secp256k1" + keychain = await get_keychain(ctx, [(msg.ecdsa_curve_name, [])]) identity = serialize_identity(msg.identity) await require_confirm_sign_identity(ctx, msg.identity, msg.challenge_visual) address_n = get_identity_path(identity, msg.identity.index or 0) - node = keychain.derive(address_n, msg.ecdsa_curve_name) + node = keychain.derive(address_n) coin = coins.by_name("Bitcoin") if msg.ecdsa_curve_name == "secp256k1": diff --git a/core/src/apps/wallet/sign_message.py b/core/src/apps/wallet/sign_message.py index fb4c508b66..f750384870 100644 --- a/core/src/apps/wallet/sign_message.py +++ b/core/src/apps/wallet/sign_message.py @@ -3,18 +3,18 @@ from trezor.crypto.curve import secp256k1 from trezor.messages.InputScriptType import SPENDADDRESS, SPENDP2SHWITNESS, SPENDWITNESS from trezor.messages.MessageSignature import MessageSignature -from apps.common import coins +from .keychain import with_keychain +from .sign_tx.addresses import get_address, validate_full_path + from apps.common.paths import validate_path from apps.common.signverify import message_digest, require_confirm_sign_message -from apps.wallet.sign_tx.addresses import get_address, validate_full_path -async def sign_message(ctx, msg, keychain): +@with_keychain +async def sign_message(ctx, msg, keychain, coin): message = msg.message address_n = msg.address_n - coin_name = msg.coin_name or "Bitcoin" script_type = msg.script_type or 0 - coin = coins.by_name(coin_name) await require_confirm_sign_message(ctx, "Sign message", message) await validate_path( @@ -28,7 +28,7 @@ async def sign_message(ctx, msg, keychain): validate_script_type=False, ) - node = keychain.derive(address_n, coin.curve_name) + node = keychain.derive(address_n) seckey = node.private_key() address = get_address(script_type, coin, node) diff --git a/core/src/apps/wallet/sign_tx/__init__.py b/core/src/apps/wallet/sign_tx/__init__.py index 6253bd0890..50b5a75144 100644 --- a/core/src/apps/wallet/sign_tx/__init__.py +++ b/core/src/apps/wallet/sign_tx/__init__.py @@ -4,17 +4,10 @@ from trezor.messages.SignTx import SignTx from trezor.messages.TxAck import TxAck from trezor.messages.TxRequest import TxRequest -from apps.common import coins, paths, seed -from apps.wallet.sign_tx import ( - addresses, - bitcoin, - common, - helpers, - layout, - multisig, - progress, - scripts, -) +from ..keychain import with_keychain +from . import addresses, bitcoin, common, helpers, layout, multisig, progress, scripts + +from apps.common import coininfo, paths, seed if not utils.BITCOIN_ONLY: from apps.wallet.sign_tx import bitcoinlike, decred, zcash @@ -26,16 +19,16 @@ if False: BITCOIN_NAMES = ("Bitcoin", "Regtest", "Testnet") -async def sign_tx(ctx: wire.Context, msg: SignTx, keychain: seed.Keychain) -> TxRequest: - coin_name = msg.coin_name if msg.coin_name is not None else "Bitcoin" - coin = coins.by_name(coin_name) - +@with_keychain +async def sign_tx( + ctx: wire.Context, msg: SignTx, keychain: seed.Keychain, coin: coininfo.CoinInfo +) -> TxRequest: if not utils.BITCOIN_ONLY: if coin.decred: signer_class = decred.Decred # type: Type[bitcoin.Bitcoin] elif coin.overwintered: signer_class = zcash.Overwintered - elif coin_name not in BITCOIN_NAMES: + elif coin.coin_name not in BITCOIN_NAMES: signer_class = bitcoinlike.Bitcoinlike else: signer_class = bitcoin.Bitcoin diff --git a/core/src/apps/wallet/sign_tx/bitcoin.py b/core/src/apps/wallet/sign_tx/bitcoin.py index 3336b47114..38df36a410 100644 --- a/core/src/apps/wallet/sign_tx/bitcoin.py +++ b/core/src/apps/wallet/sign_tx/bitcoin.py @@ -235,7 +235,7 @@ class Bitcoin: # NOTE: No need to check the multisig fingerprint, because we won't be signing # the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS. - node = self.keychain.derive(txi.address_n, self.coin.curve_name) + node = self.keychain.derive(txi.address_n) key_sign_pub = node.public_key() script_sig = self.input_derive_script(txi, key_sign_pub) self.write_tx_input(self.serialized_tx, txi, script_sig) @@ -250,7 +250,7 @@ class Bitcoin: ) self.bip143_in -= txi.amount - node = self.keychain.derive(txi.address_n, self.coin.curve_name) + node = self.keychain.derive(txi.address_n) public_key = node.public_key() hash143_hash = self.hash143_preimage_hash( txi, addresses.ecdsa_hash_pubkey(public_key, self.coin) @@ -302,7 +302,7 @@ class Bitcoin: self.wallet_path.check_input(txi) self.multisig_fingerprint.check_input(txi) # NOTE: wallet_path is checked in write_tx_input_check() - node = self.keychain.derive(txi.address_n, self.coin.curve_name) + node = self.keychain.derive(txi.address_n) key_sign_pub = node.public_key() # if multisig, do a sanity check to ensure we are signing with a key that is included in the multisig if txi.multisig: @@ -470,7 +470,7 @@ class Bitcoin: ] except KeyError: raise SigningError(FailureType.DataError, "Invalid script type") - node = self.keychain.derive(txo.address_n, self.coin.curve_name) + node = self.keychain.derive(txo.address_n) txo.address = addresses.get_address( input_script_type, self.coin, node, txo.multisig ) diff --git a/core/src/apps/wallet/sign_tx/decred.py b/core/src/apps/wallet/sign_tx/decred.py index eee0dd4570..9084c939f7 100644 --- a/core/src/apps/wallet/sign_tx/decred.py +++ b/core/src/apps/wallet/sign_tx/decred.py @@ -80,7 +80,7 @@ class Decred(Bitcoin): self.wallet_path.check_input(txi_sign) self.multisig_fingerprint.check_input(txi_sign) - key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name) + key_sign = self.keychain.derive(txi_sign.address_n) key_sign_pub = key_sign.public_key() if txi_sign.script_type == InputScriptType.SPENDMULTISIG: