From 8aa60e6cfd84ba35613e5eed0088fb43e97aa8e0 Mon Sep 17 00:00:00 2001 From: Tomas Susanka Date: Fri, 5 Apr 2019 11:23:06 +0200 Subject: [PATCH] paths: validate curve as well --- src/apps/cardano/__init__.py | 3 +- src/apps/cardano/get_address.py | 4 +- src/apps/cardano/get_public_key.py | 3 +- src/apps/cardano/seed.py | 10 ++-- src/apps/cardano/sign_tx.py | 4 +- src/apps/common/paths.py | 4 +- src/apps/common/seed.py | 16 ++++- src/apps/ethereum/__init__.py | 4 +- src/apps/ethereum/get_address.py | 4 +- src/apps/ethereum/get_public_key.py | 4 +- src/apps/ethereum/sign_message.py | 6 +- src/apps/ethereum/sign_tx.py | 4 +- src/apps/lisk/__init__.py | 4 +- src/apps/lisk/get_address.py | 7 ++- src/apps/lisk/get_public_key.py | 8 +-- src/apps/lisk/helpers.py | 2 - src/apps/lisk/sign_message.py | 8 +-- src/apps/lisk/sign_tx.py | 12 ++-- src/apps/monero/get_address.py | 6 +- src/apps/monero/get_tx_keys.py | 6 +- src/apps/monero/get_watch_only.py | 6 +- src/apps/monero/key_image_sync.py | 6 +- src/apps/monero/live_refresh.py | 6 +- .../signing/step_01_init_transaction.py | 6 +- src/apps/nem/__init__.py | 7 +-- src/apps/nem/get_address.py | 12 ++-- src/apps/nem/helpers.py | 1 - src/apps/nem/sign_tx.py | 7 ++- src/apps/ripple/__init__.py | 4 +- src/apps/ripple/get_address.py | 6 +- src/apps/ripple/sign_tx.py | 6 +- src/apps/stellar/__init__.py | 4 +- src/apps/stellar/consts.py | 1 - src/apps/stellar/get_address.py | 8 ++- src/apps/stellar/helpers.py | 2 - src/apps/stellar/sign_tx.py | 8 ++- src/apps/tezos/__init__.py | 4 +- src/apps/tezos/get_address.py | 8 ++- src/apps/tezos/get_public_key.py | 8 ++- src/apps/tezos/helpers.py | 1 - src/apps/tezos/sign_tx.py | 8 ++- src/apps/wallet/get_address.py | 1 + src/apps/wallet/sign_message.py | 1 + tests/test_apps.common.seed.py | 60 +++++++++++++++++++ tests/test_apps.nem.hdnode.py | 7 ++- 45 files changed, 206 insertions(+), 101 deletions(-) create mode 100644 tests/test_apps.common.seed.py diff --git a/src/apps/cardano/__init__.py b/src/apps/cardano/__init__.py index dcfa3b5f4..648340077 100644 --- a/src/apps/cardano/__init__.py +++ b/src/apps/cardano/__init__.py @@ -3,7 +3,8 @@ from trezor.messages import MessageType from apps.common import HARDENED -SEED_NAMESPACE = [[HARDENED | 44, HARDENED | 1815]] +CURVE = "ed25519" +SEED_NAMESPACE = [HARDENED | 44, HARDENED | 1815] def boot(): diff --git a/src/apps/cardano/get_address.py b/src/apps/cardano/get_address.py index cb772b4a7..e84de8e65 100644 --- a/src/apps/cardano/get_address.py +++ b/src/apps/cardano/get_address.py @@ -1,7 +1,7 @@ from trezor import log, wire from trezor.messages.CardanoAddress import CardanoAddress -from apps.cardano import seed +from apps.cardano import CURVE, seed from apps.cardano.address import derive_address_and_node, validate_full_path from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr @@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr async def get_address(ctx, msg): keychain = await seed.get_keychain(ctx) - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) try: address, _ = derive_address_and_node(keychain, msg.address_n) diff --git a/src/apps/cardano/get_public_key.py b/src/apps/cardano/get_public_key.py index f648406f6..362a84d5a 100644 --- a/src/apps/cardano/get_public_key.py +++ b/src/apps/cardano/get_public_key.py @@ -4,7 +4,7 @@ from trezor import log, wire from trezor.messages.CardanoPublicKey import CardanoPublicKey from trezor.messages.HDNodeType import HDNodeType -from apps.cardano import seed +from apps.cardano import CURVE, seed from apps.cardano.address import derive_address_and_node from apps.common import layout, paths from apps.common.seed import remove_ed25519_prefix @@ -18,6 +18,7 @@ async def get_public_key(ctx, msg): paths.validate_path_for_get_public_key, keychain, msg.address_n, + CURVE, slip44_id=1815, ) diff --git a/src/apps/cardano/seed.py b/src/apps/cardano/seed.py index 60cc994ac..f8efa24b4 100644 --- a/src/apps/cardano/seed.py +++ b/src/apps/cardano/seed.py @@ -1,7 +1,7 @@ from trezor import wire from trezor.crypto import bip32 -from apps.cardano import SEED_NAMESPACE +from apps.cardano import CURVE, SEED_NAMESPACE from apps.common import cache, mnemonic, storage from apps.common.request_passphrase import protect_by_passphrase @@ -11,8 +11,8 @@ class Keychain: self.path = path self.root = root - def validate_path(self, checked_path: list): - if checked_path[:2] != SEED_NAMESPACE[0]: + def validate_path(self, checked_path: list, checked_curve: str): + if checked_curve != CURVE or checked_path[:2] != SEED_NAMESPACE: raise wire.DataError("Forbidden key path") def derive(self, node_path: list) -> bip32.HDNode: @@ -40,8 +40,8 @@ async def get_keychain(ctx: wire.Context) -> Keychain: root = bip32.from_mnemonic_cardano(mnemonic.restore(), passphrase) # derive the namespaced root node - for i in SEED_NAMESPACE[0]: + for i in SEED_NAMESPACE: root.derive_cardano(i) - keychain = Keychain(SEED_NAMESPACE[0], root) + keychain = Keychain(SEED_NAMESPACE, root) return keychain diff --git a/src/apps/cardano/sign_tx.py b/src/apps/cardano/sign_tx.py index 99c4243a5..8493c470a 100644 --- a/src/apps/cardano/sign_tx.py +++ b/src/apps/cardano/sign_tx.py @@ -7,7 +7,7 @@ from trezor.messages.CardanoSignedTx import CardanoSignedTx from trezor.messages.CardanoTxRequest import CardanoTxRequest from trezor.messages.MessageType import CardanoTxAck -from apps.cardano import cbor, seed +from apps.cardano import CURVE, cbor, seed from apps.cardano.address import ( derive_address_and_node, is_safe_output_address, @@ -85,7 +85,7 @@ async def sign_tx(ctx, msg): display_homescreen() for i in msg.inputs: - await validate_path(ctx, validate_full_path, keychain, i.address_n) + await validate_path(ctx, validate_full_path, keychain, i.address_n, CURVE) # sign the transaction bundle and prepare the result transaction = Transaction( diff --git a/src/apps/common/paths.py b/src/apps/common/paths.py index 0f3cae73b..2494bf7f3 100644 --- a/src/apps/common/paths.py +++ b/src/apps/common/paths.py @@ -8,8 +8,8 @@ from apps.common import HARDENED from apps.common.confirm import require_confirm -async def validate_path(ctx, validate_func, keychain, path, **kwargs): - keychain.validate_path(path) +async def validate_path(ctx, validate_func, keychain, path, curve, **kwargs): + keychain.validate_path(path, curve) if not validate_func(path, **kwargs): await show_path_warning(ctx, path) diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 260bbb5c7..6a2c084a7 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -1,7 +1,7 @@ from trezor import ui, wire from trezor.crypto import bip32 -from apps.common import cache, mnemonic, storage +from apps.common import HARDENED, cache, mnemonic, storage from apps.common.request_passphrase import protect_by_passphrase allow = list @@ -25,9 +25,11 @@ class Keychain: del self.roots del self.seed - def validate_path(self, checked_path: list): + def validate_path(self, checked_path: list, checked_curve: str): for curve, *path in self.namespaces: - if path == checked_path[: len(path)]: # TODO: check curve_name + if path == checked_path[: len(path)] and curve == checked_curve: + if curve == "ed25519" and not _path_hardened(checked_path): + break return raise wire.DataError("Forbidden key path") @@ -67,6 +69,14 @@ async def get_keychain(ctx: wire.Context, namespaces: list) -> Keychain: return keychain +def _path_hardened(path: list) -> bool: + # TODO: move to paths.py after #538 is fixed + for i in path: + if not (i & HARDENED): + return False + return True + + @ui.layout_no_slide async def _compute_seed(ctx: wire.Context) -> bytes: passphrase = cache.get_passphrase() diff --git a/src/apps/ethereum/__init__.py b/src/apps/ethereum/__init__.py index 8c9be34c7..553f35ecd 100644 --- a/src/apps/ethereum/__init__.py +++ b/src/apps/ethereum/__init__.py @@ -4,11 +4,13 @@ from trezor.messages import MessageType from apps.common import HARDENED from apps.ethereum.networks import all_slip44_ids_hardened +CURVE = "secp256k1" + def boot(): ns = [] for i in all_slip44_ids_hardened(): - ns.append(["secp256k1", HARDENED | 44, i]) + ns.append([CURVE, HARDENED | 44, i]) wire.add(MessageType.EthereumGetAddress, __name__, "get_address", ns) wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.EthereumSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/ethereum/get_address.py b/src/apps/ethereum/get_address.py index b7e58db6f..b44b7ace4 100644 --- a/src/apps/ethereum/get_address.py +++ b/src/apps/ethereum/get_address.py @@ -4,12 +4,12 @@ from trezor.messages.EthereumAddress import EthereumAddress from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr -from apps.ethereum import networks +from apps.ethereum import CURVE, networks from apps.ethereum.address import address_from_bytes, validate_full_path async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) node = keychain.derive(msg.address_n) seckey = node.private_key() diff --git a/src/apps/ethereum/get_public_key.py b/src/apps/ethereum/get_public_key.py index 2521a1ebf..e463473a8 100644 --- a/src/apps/ethereum/get_public_key.py +++ b/src/apps/ethereum/get_public_key.py @@ -2,12 +2,12 @@ from trezor.messages.EthereumPublicKey import EthereumPublicKey from trezor.messages.HDNodeType import HDNodeType from apps.common import coins, layout, paths -from apps.ethereum import address +from apps.ethereum import CURVE, address async def get_public_key(ctx, msg, keychain): await paths.validate_path( - ctx, address.validate_path_for_get_public_key, keychain, msg.address_n + ctx, address.validate_path_for_get_public_key, keychain, msg.address_n, CURVE ) node = keychain.derive(msg.address_n) diff --git a/src/apps/ethereum/sign_message.py b/src/apps/ethereum/sign_message.py index abffef522..e029f6951 100644 --- a/src/apps/ethereum/sign_message.py +++ b/src/apps/ethereum/sign_message.py @@ -7,7 +7,7 @@ from trezor.utils import HashWriter from apps.common import paths from apps.common.confirm import require_confirm from apps.common.signverify import split_message -from apps.ethereum import address +from apps.ethereum import CURVE, address def message_digest(message): @@ -20,7 +20,9 @@ def message_digest(message): async def sign_message(ctx, msg, keychain): - await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, address.validate_full_path, keychain, msg.address_n, CURVE + ) await require_confirm_sign_message(ctx, msg.message) node = keychain.derive(msg.address_n) diff --git a/src/apps/ethereum/sign_tx.py b/src/apps/ethereum/sign_tx.py index 0c5c86ea3..16985d40e 100644 --- a/src/apps/ethereum/sign_tx.py +++ b/src/apps/ethereum/sign_tx.py @@ -8,7 +8,7 @@ from trezor.messages.MessageType import EthereumTxAck from trezor.utils import HashWriter from apps.common import paths -from apps.ethereum import address, tokens +from apps.ethereum import CURVE, address, tokens from apps.ethereum.address import validate_full_path from apps.ethereum.layout import ( require_confirm_data, @@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629 async def sign_tx(ctx, msg, keychain): msg = sanitize(msg) check(msg) - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) data_total = msg.data_length diff --git a/src/apps/lisk/__init__.py b/src/apps/lisk/__init__.py index f4a092cec..4032a1951 100644 --- a/src/apps/lisk/__init__.py +++ b/src/apps/lisk/__init__.py @@ -3,9 +3,11 @@ from trezor.messages import MessageType from apps.common import HARDENED +CURVE = "ed25519" + def boot(): - ns = [["ed25519", HARDENED | 44, HARDENED | 134]] + ns = [[CURVE, HARDENED | 44, HARDENED | 134]] wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key", ns) wire.add(MessageType.LiskGetAddress, __name__, "get_address", ns) wire.add(MessageType.LiskSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/lisk/get_address.py b/src/apps/lisk/get_address.py index c0b8dc4d5..94b06820f 100644 --- a/src/apps/lisk/get_address.py +++ b/src/apps/lisk/get_address.py @@ -1,15 +1,16 @@ from trezor.messages.LiskAddress import LiskAddress -from .helpers import LISK_CURVE, get_address_from_public_key, validate_full_path +from .helpers import get_address_from_public_key, validate_full_path from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr +from apps.lisk import CURVE async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) - node = keychain.derive(msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, CURVE) pubkey = node.public_key() pubkey = pubkey[1:] # skip ed25519 pubkey marker address = get_address_from_public_key(pubkey) diff --git a/src/apps/lisk/get_public_key.py b/src/apps/lisk/get_public_key.py index 3baa21f46..ad7a0864e 100644 --- a/src/apps/lisk/get_public_key.py +++ b/src/apps/lisk/get_public_key.py @@ -1,14 +1,14 @@ from trezor.messages.LiskPublicKey import LiskPublicKey -from .helpers import LISK_CURVE, validate_full_path - from apps.common import layout, paths +from apps.lisk import CURVE +from apps.lisk.helpers import validate_full_path async def get_public_key(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) - node = keychain.derive(msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, CURVE) pubkey = node.public_key() pubkey = pubkey[1:] # skip ed25519 pubkey marker diff --git a/src/apps/lisk/helpers.py b/src/apps/lisk/helpers.py index cc32a1f50..ee801f942 100644 --- a/src/apps/lisk/helpers.py +++ b/src/apps/lisk/helpers.py @@ -2,8 +2,6 @@ from trezor.crypto.hashlib import sha256 from apps.common import HARDENED -LISK_CURVE = "ed25519" - def get_address_from_public_key(pubkey): pubkeyhash = sha256(pubkey).digest() diff --git a/src/apps/lisk/sign_message.py b/src/apps/lisk/sign_message.py index 9650f888f..cbd1f4e07 100644 --- a/src/apps/lisk/sign_message.py +++ b/src/apps/lisk/sign_message.py @@ -4,11 +4,11 @@ from trezor.messages.LiskMessageSignature import LiskMessageSignature from trezor.ui.text import Text from trezor.utils import HashWriter -from .helpers import LISK_CURVE, validate_full_path - from apps.common import paths from apps.common.confirm import require_confirm from apps.common.signverify import split_message +from apps.lisk import CURVE +from apps.lisk.helpers import validate_full_path from apps.wallet.sign_tx.writers import write_varint @@ -23,10 +23,10 @@ def message_digest(message): async def sign_message(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) await require_confirm_sign_message(ctx, msg.message) - node = keychain.derive(msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, CURVE) seckey = node.private_key() pubkey = node.public_key() pubkey = pubkey[1:] # skip ed25519 pubkey marker diff --git a/src/apps/lisk/sign_tx.py b/src/apps/lisk/sign_tx.py index 8c7232273..05e729276 100644 --- a/src/apps/lisk/sign_tx.py +++ b/src/apps/lisk/sign_tx.py @@ -8,16 +8,12 @@ from trezor.messages.LiskSignedTx import LiskSignedTx from trezor.utils import HashWriter from apps.common import paths -from apps.lisk import layout -from apps.lisk.helpers import ( - LISK_CURVE, - get_address_from_public_key, - validate_full_path, -) +from apps.lisk import CURVE, layout +from apps.lisk.helpers import get_address_from_public_key, validate_full_path async def sign_tx(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n, CURVE) pubkey, seckey = _get_keys(keychain, msg) transaction = _update_raw_tx(msg.transaction, pubkey) @@ -41,7 +37,7 @@ async def sign_tx(ctx, msg, keychain): def _get_keys(keychain, msg): - node = keychain.derive(msg.address_n, LISK_CURVE) + node = keychain.derive(msg.address_n, CURVE) seckey = node.private_key() pubkey = node.public_key() diff --git a/src/apps/monero/get_address.py b/src/apps/monero/get_address.py index 83c069157..5448f8a4c 100644 --- a/src/apps/monero/get_address.py +++ b/src/apps/monero/get_address.py @@ -2,11 +2,13 @@ from trezor.messages.MoneroAddress import MoneroAddress from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr -from apps.monero import misc +from apps.monero import CURVE, misc async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, misc.validate_full_path, keychain, msg.address_n, CURVE + ) creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/src/apps/monero/get_tx_keys.py b/src/apps/monero/get_tx_keys.py index 40c7d6c4d..7d690495b 100644 --- a/src/apps/monero/get_tx_keys.py +++ b/src/apps/monero/get_tx_keys.py @@ -20,7 +20,7 @@ from trezor.messages.MoneroGetTxKeyAck import MoneroGetTxKeyAck from trezor.messages.MoneroGetTxKeyRequest import MoneroGetTxKeyRequest from apps.common import paths -from apps.monero import misc +from apps.monero import CURVE, misc from apps.monero.layout import confirms from apps.monero.xmr import crypto from apps.monero.xmr.crypto import chacha_poly @@ -30,7 +30,9 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1 async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain): - await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, misc.validate_full_path, keychain, msg.address_n, CURVE + ) do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv) diff --git a/src/apps/monero/get_watch_only.py b/src/apps/monero/get_watch_only.py index fe44f0b1f..5739fca24 100644 --- a/src/apps/monero/get_watch_only.py +++ b/src/apps/monero/get_watch_only.py @@ -2,13 +2,15 @@ from trezor.messages.MoneroGetWatchKey import MoneroGetWatchKey from trezor.messages.MoneroWatchKey import MoneroWatchKey from apps.common import paths -from apps.monero import misc +from apps.monero import CURVE, misc from apps.monero.layout import confirms from apps.monero.xmr import crypto async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain): - await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, misc.validate_full_path, keychain, msg.address_n, CURVE + ) await confirms.require_confirm_watchkey(ctx) diff --git a/src/apps/monero/key_image_sync.py b/src/apps/monero/key_image_sync.py index 55421632d..f4c8c70fd 100644 --- a/src/apps/monero/key_image_sync.py +++ b/src/apps/monero/key_image_sync.py @@ -8,7 +8,7 @@ from trezor.messages.MoneroKeyImageSyncFinalAck import MoneroKeyImageSyncFinalAc from trezor.messages.MoneroKeyImageSyncStepAck import MoneroKeyImageSyncStepAck from apps.common import paths -from apps.monero import misc +from apps.monero import CURVE, misc from apps.monero.layout import confirms from apps.monero.xmr import crypto, key_image, monero from apps.monero.xmr.crypto import chacha_poly @@ -47,7 +47,9 @@ class KeyImageSync: async def _init_step(s, ctx, msg, keychain): - await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, misc.validate_full_path, keychain, msg.address_n, CURVE + ) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/src/apps/monero/live_refresh.py b/src/apps/monero/live_refresh.py index ea20dbcbc..8315be970 100644 --- a/src/apps/monero/live_refresh.py +++ b/src/apps/monero/live_refresh.py @@ -9,7 +9,7 @@ from trezor.messages.MoneroLiveRefreshStepAck import MoneroLiveRefreshStepAck from trezor.messages.MoneroLiveRefreshStepRequest import MoneroLiveRefreshStepRequest from apps.common import paths -from apps.monero import misc +from apps.monero import CURVE, misc from apps.monero.layout import confirms from apps.monero.xmr import crypto, key_image, monero from apps.monero.xmr.crypto import chacha_poly @@ -44,7 +44,9 @@ class LiveRefreshState: async def _init_step( s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain ): - await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, misc.validate_full_path, keychain, msg.address_n, CURVE + ) await confirms.require_confirm_live_refresh(ctx) diff --git a/src/apps/monero/signing/step_01_init_transaction.py b/src/apps/monero/signing/step_01_init_transaction.py index 3c12bd793..9778e53c2 100644 --- a/src/apps/monero/signing/step_01_init_transaction.py +++ b/src/apps/monero/signing/step_01_init_transaction.py @@ -4,7 +4,7 @@ Initializes a new transaction. import gc -from apps.monero import misc, signing +from apps.monero import CURVE, misc, signing from apps.monero.layout import confirms from apps.monero.signing.state import State from apps.monero.xmr import crypto, monero @@ -24,7 +24,9 @@ async def init_transaction( from apps.monero.signing import offloading_keys from apps.common import paths - await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n) + await paths.validate_path( + state.ctx, misc.validate_full_path, keychain, address_n, CURVE + ) state.creds = misc.get_creds(keychain, address_n, network_type) state.client_version = tsx_data.client_version or 0 diff --git a/src/apps/nem/__init__.py b/src/apps/nem/__init__.py index 8433f2448..fb3d13a11 100644 --- a/src/apps/nem/__init__.py +++ b/src/apps/nem/__init__.py @@ -3,11 +3,10 @@ from trezor.messages import MessageType from apps.common import HARDENED +CURVE = "ed25519-keccak" + def boot(): - ns = [ - ["ed25519-keccak", HARDENED | 44, HARDENED | 43], - ["ed25519-keccak", HARDENED | 44, HARDENED | 1], - ] + ns = [[CURVE, HARDENED | 44, HARDENED | 43], [CURVE, HARDENED | 44, HARDENED | 1]] wire.add(MessageType.NEMGetAddress, __name__, "get_address", ns) wire.add(MessageType.NEMSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/nem/get_address.py b/src/apps/nem/get_address.py index 3d67fd532..4bdb1c4f3 100644 --- a/src/apps/nem/get_address.py +++ b/src/apps/nem/get_address.py @@ -1,17 +1,19 @@ from trezor.messages.NEMAddress import NEMAddress -from .helpers import NEM_CURVE, check_path, get_network_str -from .validators import validate_network - from apps.common.layout import address_n_to_str, show_address, show_qr from apps.common.paths import validate_path +from apps.nem import CURVE +from apps.nem.helpers import check_path, get_network_str +from apps.nem.validators import validate_network async def get_address(ctx, msg, keychain): network = validate_network(msg.network) - await validate_path(ctx, check_path, keychain, msg.address_n, network=network) + await validate_path( + ctx, check_path, keychain, msg.address_n, CURVE, network=network + ) - node = keychain.derive(msg.address_n, NEM_CURVE) + node = keychain.derive(msg.address_n, CURVE) address = node.nem_address(network) if msg.show_display: diff --git a/src/apps/nem/helpers.py b/src/apps/nem/helpers.py index 917a7b160..27327dbee 100644 --- a/src/apps/nem/helpers.py +++ b/src/apps/nem/helpers.py @@ -5,7 +5,6 @@ from apps.common import HARDENED NEM_NETWORK_MAINNET = const(0x68) NEM_NETWORK_TESTNET = const(0x98) NEM_NETWORK_MIJIN = const(0x60) -NEM_CURVE = "ed25519-keccak" NEM_TRANSACTION_TYPE_TRANSFER = const(0x0101) NEM_TRANSACTION_TYPE_IMPORTANCE_TRANSFER = const(0x0801) diff --git a/src/apps/nem/sign_tx.py b/src/apps/nem/sign_tx.py index 60d133c98..13e1169a0 100644 --- a/src/apps/nem/sign_tx.py +++ b/src/apps/nem/sign_tx.py @@ -4,8 +4,8 @@ from trezor.messages.NEMSignTx import NEMSignTx from apps.common import seed from apps.common.paths import validate_path -from apps.nem import mosaic, multisig, namespace, transfer -from apps.nem.helpers import NEM_CURVE, NEM_HASH_ALG, check_path +from apps.nem import CURVE, mosaic, multisig, namespace, transfer +from apps.nem.helpers import NEM_HASH_ALG, check_path from apps.nem.validators import validate @@ -17,10 +17,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain): check_path, keychain, msg.transaction.address_n, + CURVE, network=msg.transaction.network, ) - node = keychain.derive(msg.transaction.address_n, NEM_CURVE) + node = keychain.derive(msg.transaction.address_n, CURVE) if msg.multisig: public_key = msg.multisig.signer diff --git a/src/apps/ripple/__init__.py b/src/apps/ripple/__init__.py index 2e24cf581..ac116bf7e 100644 --- a/src/apps/ripple/__init__.py +++ b/src/apps/ripple/__init__.py @@ -3,8 +3,10 @@ from trezor.messages import MessageType from apps.common import HARDENED +CURVE = "secp256k1" + def boot(): - ns = [["secp256k1", HARDENED | 44, HARDENED | 144]] + ns = [[CURVE, HARDENED | 44, HARDENED | 144]] wire.add(MessageType.RippleGetAddress, __name__, "get_address", ns) wire.add(MessageType.RippleSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/ripple/get_address.py b/src/apps/ripple/get_address.py index 1b0bc69f4..5ae24abd1 100644 --- a/src/apps/ripple/get_address.py +++ b/src/apps/ripple/get_address.py @@ -3,11 +3,13 @@ from trezor.messages.RippleGetAddress import RippleGetAddress from apps.common import paths from apps.common.layout import address_n_to_str, show_address, show_qr -from apps.ripple import helpers +from apps.ripple import CURVE, helpers async def get_address(ctx, msg: RippleGetAddress, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) node = keychain.derive(msg.address_n) pubkey = node.public_key() diff --git a/src/apps/ripple/sign_tx.py b/src/apps/ripple/sign_tx.py index 60b68a91d..6ee3bafec 100644 --- a/src/apps/ripple/sign_tx.py +++ b/src/apps/ripple/sign_tx.py @@ -6,14 +6,16 @@ from trezor.messages.RippleSignTx import RippleSignTx from trezor.wire import ProcessError from apps.common import paths -from apps.ripple import helpers, layout +from apps.ripple import CURVE, helpers, layout from apps.ripple.serialize import serialize async def sign_tx(ctx, msg: RippleSignTx, keychain): validate(msg) - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) node = keychain.derive(msg.address_n) source_address = helpers.address_from_public_key(node.public_key()) diff --git a/src/apps/stellar/__init__.py b/src/apps/stellar/__init__.py index ae5e40c88..f868d8054 100644 --- a/src/apps/stellar/__init__.py +++ b/src/apps/stellar/__init__.py @@ -3,8 +3,10 @@ from trezor.messages import MessageType from apps.common import HARDENED +CURVE = "ed25519" + def boot(): - ns = [["ed25519", HARDENED | 44, HARDENED | 148]] + ns = [[CURVE, HARDENED | 44, HARDENED | 148]] wire.add(MessageType.StellarGetAddress, __name__, "get_address", ns) wire.add(MessageType.StellarSignTx, __name__, "sign_tx", ns) diff --git a/src/apps/stellar/consts.py b/src/apps/stellar/consts.py index e179dac09..e24ddbb4d 100644 --- a/src/apps/stellar/consts.py +++ b/src/apps/stellar/consts.py @@ -2,7 +2,6 @@ from micropython import const from trezor.messages import MessageType -STELLAR_CURVE = "ed25519" TX_TYPE = bytearray("\x00\x00\x00\x02") # source: https://github.com/stellar/go/blob/3d2c1defe73dbfed00146ebe0e8d7e07ce4bb1b6/xdr/Stellar-transaction.x#L16 diff --git a/src/apps/stellar/get_address.py b/src/apps/stellar/get_address.py index 648f433ec..f91ac941f 100644 --- a/src/apps/stellar/get_address.py +++ b/src/apps/stellar/get_address.py @@ -3,13 +3,15 @@ from trezor.messages.StellarGetAddress import StellarGetAddress from apps.common import paths, seed from apps.common.layout import address_n_to_str, show_address, show_qr -from apps.stellar import helpers +from apps.stellar import CURVE, helpers async def get_address(ctx, msg: StellarGetAddress, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) - node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE) + node = keychain.derive(msg.address_n, CURVE) pubkey = seed.remove_ed25519_prefix(node.public_key()) address = helpers.address_from_public_key(pubkey) diff --git a/src/apps/stellar/helpers.py b/src/apps/stellar/helpers.py index 0062b7bce..f14544690 100644 --- a/src/apps/stellar/helpers.py +++ b/src/apps/stellar/helpers.py @@ -5,8 +5,6 @@ from trezor.wire import ProcessError from apps.common import HARDENED -STELLAR_CURVE = "ed25519" - def public_key_from_address(address: str) -> bytes: """Extracts public key from an address diff --git a/src/apps/stellar/sign_tx.py b/src/apps/stellar/sign_tx.py index 516f6cfc7..ad32a57f5 100644 --- a/src/apps/stellar/sign_tx.py +++ b/src/apps/stellar/sign_tx.py @@ -8,14 +8,16 @@ from trezor.messages.StellarTxOpRequest import StellarTxOpRequest from trezor.wire import ProcessError from apps.common import paths, seed -from apps.stellar import consts, helpers, layout, writers +from apps.stellar import CURVE, consts, helpers, layout, writers from apps.stellar.operations import process_operation async def sign_tx(ctx, msg: StellarSignTx, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) - node = keychain.derive(msg.address_n, consts.STELLAR_CURVE) + node = keychain.derive(msg.address_n, CURVE) pubkey = seed.remove_ed25519_prefix(node.public_key()) if msg.num_operations == 0: diff --git a/src/apps/tezos/__init__.py b/src/apps/tezos/__init__.py index 1c5aa3699..0022075be 100644 --- a/src/apps/tezos/__init__.py +++ b/src/apps/tezos/__init__.py @@ -3,9 +3,11 @@ from trezor.messages import MessageType from apps.common import HARDENED +CURVE = "ed25519" + def boot(): - ns = [["ed25519", HARDENED | 44, HARDENED | 1729]] + ns = [[CURVE, HARDENED | 44, HARDENED | 1729]] wire.add(MessageType.TezosGetAddress, __name__, "get_address", ns) wire.add(MessageType.TezosSignTx, __name__, "sign_tx", ns) wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key", ns) diff --git a/src/apps/tezos/get_address.py b/src/apps/tezos/get_address.py index c311bdbec..0f1e8bd80 100644 --- a/src/apps/tezos/get_address.py +++ b/src/apps/tezos/get_address.py @@ -3,13 +3,15 @@ from trezor.messages.TezosAddress import TezosAddress from apps.common import paths, seed from apps.common.layout import address_n_to_str, show_address, show_qr -from apps.tezos import helpers +from apps.tezos import CURVE, helpers async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) - node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) + node = keychain.derive(msg.address_n, CURVE) pk = seed.remove_ed25519_prefix(node.public_key()) pkh = hashlib.blake2b(pk, outlen=20).digest() diff --git a/src/apps/tezos/get_public_key.py b/src/apps/tezos/get_public_key.py index f4777d746..ecb06bc0e 100644 --- a/src/apps/tezos/get_public_key.py +++ b/src/apps/tezos/get_public_key.py @@ -6,13 +6,15 @@ from trezor.utils import chunks from apps.common import paths, seed from apps.common.confirm import require_confirm -from apps.tezos import helpers +from apps.tezos import CURVE, helpers async def get_public_key(ctx, msg, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) - node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) + node = keychain.derive(msg.address_n, CURVE) pk = seed.remove_ed25519_prefix(node.public_key()) pk_prefixed = helpers.base58_encode_check(pk, prefix=helpers.TEZOS_PUBLICKEY_PREFIX) diff --git a/src/apps/tezos/helpers.py b/src/apps/tezos/helpers.py index 0050b93fd..4bba52008 100644 --- a/src/apps/tezos/helpers.py +++ b/src/apps/tezos/helpers.py @@ -4,7 +4,6 @@ from trezor.crypto import base58 from apps.common import HARDENED -TEZOS_CURVE = "ed25519" TEZOS_AMOUNT_DIVISIBILITY = const(6) TEZOS_ED25519_ADDRESS_PREFIX = "tz1" TEZOS_ORIGINATED_ADDRESS_PREFIX = "KT1" diff --git a/src/apps/tezos/sign_tx.py b/src/apps/tezos/sign_tx.py index 4f857578e..10227e2bd 100644 --- a/src/apps/tezos/sign_tx.py +++ b/src/apps/tezos/sign_tx.py @@ -6,13 +6,15 @@ from trezor.messages.TezosSignedTx import TezosSignedTx from apps.common import paths from apps.common.writers import write_bytes, write_uint8 -from apps.tezos import helpers, layout +from apps.tezos import CURVE, helpers, layout async def sign_tx(ctx, msg, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) + await paths.validate_path( + ctx, helpers.validate_full_path, keychain, msg.address_n, CURVE + ) - node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) + node = keychain.derive(msg.address_n, CURVE) if msg.transaction is not None: to = _get_address_from_contract(msg.transaction.destination) diff --git a/src/apps/wallet/get_address.py b/src/apps/wallet/get_address.py index 2a40c05e9..e8fedcd3b 100644 --- a/src/apps/wallet/get_address.py +++ b/src/apps/wallet/get_address.py @@ -16,6 +16,7 @@ async def get_address(ctx, msg, keychain): addresses.validate_full_path, keychain, msg.address_n, + coin.curve_name, coin=coin, script_type=msg.script_type, ) diff --git a/src/apps/wallet/sign_message.py b/src/apps/wallet/sign_message.py index 79fedba78..376d16396 100644 --- a/src/apps/wallet/sign_message.py +++ b/src/apps/wallet/sign_message.py @@ -24,6 +24,7 @@ async def sign_message(ctx, msg, keychain): validate_full_path, keychain, msg.address_n, + coin.curve_name, coin=coin, script_type=msg.script_type, validate_script_type=False, diff --git a/tests/test_apps.common.seed.py b/tests/test_apps.common.seed.py new file mode 100644 index 000000000..0a3eeee22 --- /dev/null +++ b/tests/test_apps.common.seed.py @@ -0,0 +1,60 @@ +from common import * +from apps.common import HARDENED +from apps.common.seed import Keychain, _path_hardened +from trezor import wire + + +class TestKeychain(unittest.TestCase): + + def test_validate_path(self): + n = [ + ["ed25519", 44 | HARDENED, 134 | HARDENED], + ["secp256k1", 44 | HARDENED, 11 | HARDENED], + ] + k = Keychain(b"", n) + + correct = ( + ([44 | HARDENED, 134 | HARDENED], "ed25519"), + ([44 | HARDENED, 11 | HARDENED], "secp256k1"), + ([44 | HARDENED, 11 | HARDENED, 12], "secp256k1"), + ) + for c in correct: + self.assertEqual(None, k.validate_path(*c)) + + fails = [ + ([44 | HARDENED, 134], "ed25519"), # path does not match + ([44 | HARDENED, 134], "secp256k1"), # curve and path does not match + ([44 | HARDENED, 134 | HARDENED], "nist256p"), # curve not included + ([44, 134], "ed25519"), # path does not match (non-hardened items) + ([44 | HARDENED, 134 | HARDENED, 123], "ed25519"), # non-hardened item in ed25519 + ([44 | HARDENED, 13 | HARDENED], "secp256k1"), # invalid second item + ] + for f in fails: + with self.assertRaises(wire.DataError): + k.validate_path(*f) + + def test_validate_path_empty_namespace(self): + k = Keychain(b"", [["secp256k1"]]) + correct = ( + ([], "secp256k1"), + ([1, 2, 3, 4], "secp256k1"), + ([44 | HARDENED, 11 | HARDENED], "secp256k1"), + ([44 | HARDENED, 11 | HARDENED, 12], "secp256k1"), + ) + for c in correct: + self.assertEqual(None, k.validate_path(*c)) + + with self.assertRaises(wire.DataError): + k.validate_path([1, 2, 3, 4], "ed25519") + k.validate_path([], "ed25519") + + def test_path_hardened(self): + self.assertTrue(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED])) + self.assertTrue(_path_hardened([0 | HARDENED, ])) + + self.assertFalse(_path_hardened([44, 44 | HARDENED, 0 | HARDENED])) + self.assertFalse(_path_hardened([0, ])) + self.assertFalse(_path_hardened([44 | HARDENED, 1 | HARDENED, 0 | HARDENED, 0 | HARDENED, 0])) + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_apps.nem.hdnode.py b/tests/test_apps.nem.hdnode.py index 69db1a6c6..a2082d775 100644 --- a/tests/test_apps.nem.hdnode.py +++ b/tests/test_apps.nem.hdnode.py @@ -1,7 +1,8 @@ from common import * from ubinascii import unhexlify from trezor.crypto import bip32 -from apps.nem.helpers import NEM_NETWORK_MAINNET, NEM_CURVE +from apps.nem import CURVE +from apps.nem.helpers import NEM_NETWORK_MAINNET class TestNemHDNode(unittest.TestCase): @@ -81,7 +82,7 @@ class TestNemHDNode(unittest.TestCase): child_num=0, chain_code=bytearray(32), private_key=private_key, - curve_name=NEM_CURVE + curve_name=CURVE ) self.assertEqual(node.nem_address(NEM_NETWORK_MAINNET), test[2]) @@ -222,7 +223,7 @@ class TestNemHDNode(unittest.TestCase): child_num=0, chain_code=bytearray(32), private_key=private_key, - curve_name=NEM_CURVE + curve_name=CURVE ) encrypted = node.nem_encrypt(unhexlify(test['public']),