diff --git a/src/apps/cardano/get_address.py b/src/apps/cardano/get_address.py index 5415386541..cb772b4a70 100644 --- a/src/apps/cardano/get_address.py +++ b/src/apps/cardano/get_address.py @@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr async def get_address(ctx, msg): keychain = await seed.get_keychain(ctx) - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) try: address, _ = derive_address_and_node(keychain, msg.address_n) diff --git a/src/apps/cardano/get_public_key.py b/src/apps/cardano/get_public_key.py index 8ebe9e3ef7..f648406f69 100644 --- a/src/apps/cardano/get_public_key.py +++ b/src/apps/cardano/get_public_key.py @@ -14,7 +14,11 @@ async def get_public_key(ctx, msg): keychain = await seed.get_keychain(ctx) await paths.validate_path( - ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815 + ctx, + paths.validate_path_for_get_public_key, + keychain, + msg.address_n, + slip44_id=1815, ) try: diff --git a/src/apps/cardano/seed.py b/src/apps/cardano/seed.py index 26b32ada70..60cc994ac3 100644 --- a/src/apps/cardano/seed.py +++ b/src/apps/cardano/seed.py @@ -11,6 +11,10 @@ class Keychain: self.path = path self.root = root + def validate_path(self, checked_path: list): + if checked_path[:2] != SEED_NAMESPACE[0]: + raise wire.DataError("Forbidden key path") + def derive(self, node_path: list) -> bip32.HDNode: # check we are in the cardano namespace prefix = node_path[: len(self.path)] diff --git a/src/apps/cardano/sign_tx.py b/src/apps/cardano/sign_tx.py index 12cfa5b39e..99c4243a59 100644 --- a/src/apps/cardano/sign_tx.py +++ b/src/apps/cardano/sign_tx.py @@ -85,7 +85,7 @@ async def sign_tx(ctx, msg): display_homescreen() for i in msg.inputs: - await validate_path(ctx, validate_full_path, path=i.address_n) + await validate_path(ctx, validate_full_path, keychain, i.address_n) # sign the transaction bundle and prepare the result transaction = Transaction( diff --git a/src/apps/common/paths.py b/src/apps/common/paths.py index 566171aeb2..0f3cae73b3 100644 --- a/src/apps/common/paths.py +++ b/src/apps/common/paths.py @@ -8,9 +8,10 @@ from apps.common import HARDENED from apps.common.confirm import require_confirm -async def validate_path(ctx, validate_func, **kwargs): - if not validate_func(**kwargs): - await show_path_warning(ctx, kwargs["path"]) +async def validate_path(ctx, validate_func, keychain, path, **kwargs): + keychain.validate_path(path) + if not validate_func(path, **kwargs): + await show_path_warning(ctx, path) async def show_path_warning(ctx, path: list): diff --git a/src/apps/common/seed.py b/src/apps/common/seed.py index 208988c899..260bbb5c7d 100644 --- a/src/apps/common/seed.py +++ b/src/apps/common/seed.py @@ -25,6 +25,12 @@ class Keychain: del self.roots del self.seed + def validate_path(self, checked_path: list): + for curve, *path in self.namespaces: + if path == checked_path[: len(path)]: # TODO: check curve_name + return + raise wire.DataError("Forbidden key path") + def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode: # find the root node index root_index = 0 @@ -44,6 +50,7 @@ class Keychain: root.derive_path(path) self.roots[root_index] = root + # TODO check for ed25519? # derive child node from the root node = root.clone() node.derive_path(suffix) diff --git a/src/apps/ethereum/get_address.py b/src/apps/ethereum/get_address.py index f2a3adac6d..b7e58db6fb 100644 --- a/src/apps/ethereum/get_address.py +++ b/src/apps/ethereum/get_address.py @@ -9,7 +9,7 @@ from apps.ethereum.address import address_from_bytes, validate_full_path async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n) seckey = node.private_key() diff --git a/src/apps/ethereum/get_public_key.py b/src/apps/ethereum/get_public_key.py index b16a8e40af..2521a1ebf7 100644 --- a/src/apps/ethereum/get_public_key.py +++ b/src/apps/ethereum/get_public_key.py @@ -7,7 +7,7 @@ from apps.ethereum import address async def get_public_key(ctx, msg, keychain): await paths.validate_path( - ctx, address.validate_path_for_get_public_key, path=msg.address_n + ctx, address.validate_path_for_get_public_key, keychain, msg.address_n ) node = keychain.derive(msg.address_n) diff --git a/src/apps/ethereum/sign_message.py b/src/apps/ethereum/sign_message.py index 54b85320a3..abffef5226 100644 --- a/src/apps/ethereum/sign_message.py +++ b/src/apps/ethereum/sign_message.py @@ -20,7 +20,7 @@ def message_digest(message): async def sign_message(ctx, msg, keychain): - await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n) await require_confirm_sign_message(ctx, msg.message) node = keychain.derive(msg.address_n) diff --git a/src/apps/ethereum/sign_tx.py b/src/apps/ethereum/sign_tx.py index b7ef05b70c..0c5c86ea33 100644 --- a/src/apps/ethereum/sign_tx.py +++ b/src/apps/ethereum/sign_tx.py @@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629 async def sign_tx(ctx, msg, keychain): msg = sanitize(msg) check(msg) - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) data_total = msg.data_length diff --git a/src/apps/lisk/get_address.py b/src/apps/lisk/get_address.py index c77d0045b5..c0b8dc4d50 100644 --- a/src/apps/lisk/get_address.py +++ b/src/apps/lisk/get_address.py @@ -7,7 +7,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, LISK_CURVE) pubkey = node.public_key() diff --git a/src/apps/lisk/get_public_key.py b/src/apps/lisk/get_public_key.py index 93e3907cf0..3baa21f465 100644 --- a/src/apps/lisk/get_public_key.py +++ b/src/apps/lisk/get_public_key.py @@ -6,7 +6,7 @@ from apps.common import layout, paths async def get_public_key(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, LISK_CURVE) pubkey = node.public_key() diff --git a/src/apps/lisk/sign_message.py b/src/apps/lisk/sign_message.py index 9b1c99834d..9650f888f1 100644 --- a/src/apps/lisk/sign_message.py +++ b/src/apps/lisk/sign_message.py @@ -23,7 +23,7 @@ def message_digest(message): async def sign_message(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) await require_confirm_sign_message(ctx, msg.message) node = keychain.derive(msg.address_n, LISK_CURVE) diff --git a/src/apps/lisk/sign_tx.py b/src/apps/lisk/sign_tx.py index 0ae7c0e732..8c72322738 100644 --- a/src/apps/lisk/sign_tx.py +++ b/src/apps/lisk/sign_tx.py @@ -17,7 +17,7 @@ from apps.lisk.helpers import ( async def sign_tx(ctx, msg, keychain): - await paths.validate_path(ctx, validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n) pubkey, seckey = _get_keys(keychain, msg) transaction = _update_raw_tx(msg.transaction, pubkey) diff --git a/src/apps/monero/get_address.py b/src/apps/monero/get_address.py index aced5b7dc4..83c069157c 100644 --- a/src/apps/monero/get_address.py +++ b/src/apps/monero/get_address.py @@ -6,7 +6,7 @@ from apps.monero import misc async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/src/apps/monero/get_tx_keys.py b/src/apps/monero/get_tx_keys.py index 0e4a22a729..40c7d6c4d6 100644 --- a/src/apps/monero/get_tx_keys.py +++ b/src/apps/monero/get_tx_keys.py @@ -30,7 +30,7 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1 async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain): - await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv) diff --git a/src/apps/monero/get_watch_only.py b/src/apps/monero/get_watch_only.py index 8a636a0a86..fe44f0b1fa 100644 --- a/src/apps/monero/get_watch_only.py +++ b/src/apps/monero/get_watch_only.py @@ -8,7 +8,7 @@ from apps.monero.xmr import crypto async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain): - await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await confirms.require_confirm_watchkey(ctx) diff --git a/src/apps/monero/key_image_sync.py b/src/apps/monero/key_image_sync.py index e025642310..55421632db 100644 --- a/src/apps/monero/key_image_sync.py +++ b/src/apps/monero/key_image_sync.py @@ -47,7 +47,7 @@ class KeyImageSync: async def _init_step(s, ctx, msg, keychain): - await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/src/apps/monero/live_refresh.py b/src/apps/monero/live_refresh.py index 0deecbafc7..ea20dbcbce 100644 --- a/src/apps/monero/live_refresh.py +++ b/src/apps/monero/live_refresh.py @@ -44,7 +44,7 @@ class LiveRefreshState: async def _init_step( s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain ): - await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n) await confirms.require_confirm_live_refresh(ctx) diff --git a/src/apps/monero/signing/step_01_init_transaction.py b/src/apps/monero/signing/step_01_init_transaction.py index c8c8b8ab9c..3c12bd7936 100644 --- a/src/apps/monero/signing/step_01_init_transaction.py +++ b/src/apps/monero/signing/step_01_init_transaction.py @@ -24,7 +24,7 @@ async def init_transaction( from apps.monero.signing import offloading_keys from apps.common import paths - await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n) + await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n) state.creds = misc.get_creds(keychain, address_n, network_type) state.client_version = tsx_data.client_version or 0 diff --git a/src/apps/nem/get_address.py b/src/apps/nem/get_address.py index 676e43657e..3d67fd5329 100644 --- a/src/apps/nem/get_address.py +++ b/src/apps/nem/get_address.py @@ -9,7 +9,7 @@ from apps.common.paths import validate_path async def get_address(ctx, msg, keychain): network = validate_network(msg.network) - await validate_path(ctx, check_path, path=msg.address_n, network=network) + await validate_path(ctx, check_path, keychain, msg.address_n, network=network) node = keychain.derive(msg.address_n, NEM_CURVE) address = node.nem_address(network) diff --git a/src/apps/nem/sign_tx.py b/src/apps/nem/sign_tx.py index 1a76f163b7..60d133c98e 100644 --- a/src/apps/nem/sign_tx.py +++ b/src/apps/nem/sign_tx.py @@ -13,7 +13,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain): validate(msg) await validate_path( - ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network + ctx, + check_path, + keychain, + msg.transaction.address_n, + network=msg.transaction.network, ) node = keychain.derive(msg.transaction.address_n, NEM_CURVE) diff --git a/src/apps/ripple/get_address.py b/src/apps/ripple/get_address.py index a5b26ba762..1b0bc69f4d 100644 --- a/src/apps/ripple/get_address.py +++ b/src/apps/ripple/get_address.py @@ -7,7 +7,7 @@ from apps.ripple import helpers async def get_address(ctx, msg: RippleGetAddress, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n) pubkey = node.public_key() diff --git a/src/apps/ripple/sign_tx.py b/src/apps/ripple/sign_tx.py index 4174dc0777..60b68a91d0 100644 --- a/src/apps/ripple/sign_tx.py +++ b/src/apps/ripple/sign_tx.py @@ -13,7 +13,7 @@ from apps.ripple.serialize import serialize async def sign_tx(ctx, msg: RippleSignTx, keychain): validate(msg) - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n) source_address = helpers.address_from_public_key(node.public_key()) diff --git a/src/apps/stellar/get_address.py b/src/apps/stellar/get_address.py index 8b87d34f8b..648f433ec1 100644 --- a/src/apps/stellar/get_address.py +++ b/src/apps/stellar/get_address.py @@ -7,7 +7,7 @@ from apps.stellar import helpers async def get_address(ctx, msg: StellarGetAddress, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE) pubkey = seed.remove_ed25519_prefix(node.public_key()) diff --git a/src/apps/stellar/sign_tx.py b/src/apps/stellar/sign_tx.py index cc905b9216..516f6cfc7e 100644 --- a/src/apps/stellar/sign_tx.py +++ b/src/apps/stellar/sign_tx.py @@ -13,7 +13,7 @@ from apps.stellar.operations import process_operation async def sign_tx(ctx, msg: StellarSignTx, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, consts.STELLAR_CURVE) pubkey = seed.remove_ed25519_prefix(node.public_key()) diff --git a/src/apps/tezos/get_address.py b/src/apps/tezos/get_address.py index 4ff5635129..c311bdbece 100644 --- a/src/apps/tezos/get_address.py +++ b/src/apps/tezos/get_address.py @@ -7,7 +7,7 @@ from apps.tezos import helpers async def get_address(ctx, msg, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) diff --git a/src/apps/tezos/get_public_key.py b/src/apps/tezos/get_public_key.py index c3f75ad97b..f4777d7461 100644 --- a/src/apps/tezos/get_public_key.py +++ b/src/apps/tezos/get_public_key.py @@ -10,7 +10,7 @@ from apps.tezos import helpers async def get_public_key(ctx, msg, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) pk = seed.remove_ed25519_prefix(node.public_key()) diff --git a/src/apps/tezos/sign_tx.py b/src/apps/tezos/sign_tx.py index 5cf0e11d3b..4f857578eb 100644 --- a/src/apps/tezos/sign_tx.py +++ b/src/apps/tezos/sign_tx.py @@ -10,7 +10,7 @@ from apps.tezos import helpers, layout async def sign_tx(ctx, msg, keychain): - await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) + await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n) node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) diff --git a/src/apps/wallet/get_address.py b/src/apps/wallet/get_address.py index 80096d37c5..2a40c05e95 100644 --- a/src/apps/wallet/get_address.py +++ b/src/apps/wallet/get_address.py @@ -14,7 +14,8 @@ async def get_address(ctx, msg, keychain): await validate_path( ctx, addresses.validate_full_path, - path=msg.address_n, + keychain, + msg.address_n, coin=coin, script_type=msg.script_type, ) diff --git a/src/apps/wallet/sign_message.py b/src/apps/wallet/sign_message.py index ab9e58f879..79fedba782 100644 --- a/src/apps/wallet/sign_message.py +++ b/src/apps/wallet/sign_message.py @@ -22,7 +22,8 @@ async def sign_message(ctx, msg, keychain): await validate_path( ctx, validate_full_path, - path=msg.address_n, + keychain, + msg.address_n, coin=coin, script_type=msg.script_type, validate_script_type=False,