mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-16 03:18:09 +00:00
Merge pull request #532 from trezor/tsusanka/path-validation
Pass Keychain to path validation function to throw error before warning
This commit is contained in:
commit
396f5f1937
@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||
async def get_address(ctx, msg):
|
||||
keychain = await seed.get_keychain(ctx)
|
||||
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
|
||||
try:
|
||||
address, _ = derive_address_and_node(keychain, msg.address_n)
|
||||
|
@ -14,7 +14,11 @@ async def get_public_key(ctx, msg):
|
||||
keychain = await seed.get_keychain(ctx)
|
||||
|
||||
await paths.validate_path(
|
||||
ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815
|
||||
ctx,
|
||||
paths.validate_path_for_get_public_key,
|
||||
keychain,
|
||||
msg.address_n,
|
||||
slip44_id=1815,
|
||||
)
|
||||
|
||||
try:
|
||||
|
@ -11,6 +11,10 @@ class Keychain:
|
||||
self.path = path
|
||||
self.root = root
|
||||
|
||||
def validate_path(self, checked_path: list):
|
||||
if checked_path[:2] != SEED_NAMESPACE[0]:
|
||||
raise wire.DataError("Forbidden key path")
|
||||
|
||||
def derive(self, node_path: list) -> bip32.HDNode:
|
||||
# check we are in the cardano namespace
|
||||
prefix = node_path[: len(self.path)]
|
||||
|
@ -85,7 +85,7 @@ async def sign_tx(ctx, msg):
|
||||
display_homescreen()
|
||||
|
||||
for i in msg.inputs:
|
||||
await validate_path(ctx, validate_full_path, path=i.address_n)
|
||||
await validate_path(ctx, validate_full_path, keychain, i.address_n)
|
||||
|
||||
# sign the transaction bundle and prepare the result
|
||||
transaction = Transaction(
|
||||
|
@ -8,9 +8,10 @@ from apps.common import HARDENED
|
||||
from apps.common.confirm import require_confirm
|
||||
|
||||
|
||||
async def validate_path(ctx, validate_func, **kwargs):
|
||||
if not validate_func(**kwargs):
|
||||
await show_path_warning(ctx, kwargs["path"])
|
||||
async def validate_path(ctx, validate_func, keychain, path, **kwargs):
|
||||
keychain.validate_path(path)
|
||||
if not validate_func(path, **kwargs):
|
||||
await show_path_warning(ctx, path)
|
||||
|
||||
|
||||
async def show_path_warning(ctx, path: list):
|
||||
|
@ -25,6 +25,12 @@ class Keychain:
|
||||
del self.roots
|
||||
del self.seed
|
||||
|
||||
def validate_path(self, checked_path: list):
|
||||
for curve, *path in self.namespaces:
|
||||
if path == checked_path[: len(path)]: # TODO: check curve_name
|
||||
return
|
||||
raise wire.DataError("Forbidden key path")
|
||||
|
||||
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
|
||||
# find the root node index
|
||||
root_index = 0
|
||||
@ -44,6 +50,7 @@ class Keychain:
|
||||
root.derive_path(path)
|
||||
self.roots[root_index] = root
|
||||
|
||||
# TODO check for ed25519?
|
||||
# derive child node from the root
|
||||
node = root.clone()
|
||||
node.derive_path(suffix)
|
||||
|
@ -9,7 +9,7 @@ from apps.ethereum.address import address_from_bytes, validate_full_path
|
||||
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
seckey = node.private_key()
|
||||
|
@ -7,7 +7,7 @@ from apps.ethereum import address
|
||||
|
||||
async def get_public_key(ctx, msg, keychain):
|
||||
await paths.validate_path(
|
||||
ctx, address.validate_path_for_get_public_key, path=msg.address_n
|
||||
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n
|
||||
)
|
||||
node = keychain.derive(msg.address_n)
|
||||
|
||||
|
@ -20,7 +20,7 @@ def message_digest(message):
|
||||
|
||||
|
||||
async def sign_message(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n)
|
||||
await require_confirm_sign_message(ctx, msg.message)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
|
@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629
|
||||
async def sign_tx(ctx, msg, keychain):
|
||||
msg = sanitize(msg)
|
||||
check(msg)
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
|
||||
data_total = msg.data_length
|
||||
|
||||
|
@ -7,7 +7,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
|
||||
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, LISK_CURVE)
|
||||
pubkey = node.public_key()
|
||||
|
@ -6,7 +6,7 @@ from apps.common import layout, paths
|
||||
|
||||
|
||||
async def get_public_key(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, LISK_CURVE)
|
||||
pubkey = node.public_key()
|
||||
|
@ -23,7 +23,7 @@ def message_digest(message):
|
||||
|
||||
|
||||
async def sign_message(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
await require_confirm_sign_message(ctx, msg.message)
|
||||
|
||||
node = keychain.derive(msg.address_n, LISK_CURVE)
|
||||
|
@ -17,7 +17,7 @@ from apps.lisk.helpers import (
|
||||
|
||||
|
||||
async def sign_tx(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||
|
||||
pubkey, seckey = _get_keys(keychain, msg)
|
||||
transaction = _update_raw_tx(msg.transaction, pubkey)
|
||||
|
@ -6,7 +6,7 @@ from apps.monero import misc
|
||||
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
||||
|
||||
|
@ -30,7 +30,7 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1
|
||||
|
||||
|
||||
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
|
||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
|
||||
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)
|
||||
|
@ -8,7 +8,7 @@ from apps.monero.xmr import crypto
|
||||
|
||||
|
||||
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
|
||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
await confirms.require_confirm_watchkey(ctx)
|
||||
|
||||
|
@ -47,7 +47,7 @@ class KeyImageSync:
|
||||
|
||||
|
||||
async def _init_step(s, ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
||||
|
||||
|
@ -44,7 +44,7 @@ class LiveRefreshState:
|
||||
async def _init_step(
|
||||
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
|
||||
):
|
||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
await confirms.require_confirm_live_refresh(ctx)
|
||||
|
||||
|
@ -24,7 +24,7 @@ async def init_transaction(
|
||||
from apps.monero.signing import offloading_keys
|
||||
from apps.common import paths
|
||||
|
||||
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n)
|
||||
await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n)
|
||||
|
||||
state.creds = misc.get_creds(keychain, address_n, network_type)
|
||||
state.client_version = tsx_data.client_version or 0
|
||||
|
@ -9,7 +9,7 @@ from apps.common.paths import validate_path
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
network = validate_network(msg.network)
|
||||
await validate_path(ctx, check_path, path=msg.address_n, network=network)
|
||||
await validate_path(ctx, check_path, keychain, msg.address_n, network=network)
|
||||
|
||||
node = keychain.derive(msg.address_n, NEM_CURVE)
|
||||
address = node.nem_address(network)
|
||||
|
@ -13,7 +13,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain):
|
||||
validate(msg)
|
||||
|
||||
await validate_path(
|
||||
ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network
|
||||
ctx,
|
||||
check_path,
|
||||
keychain,
|
||||
msg.transaction.address_n,
|
||||
network=msg.transaction.network,
|
||||
)
|
||||
|
||||
node = keychain.derive(msg.transaction.address_n, NEM_CURVE)
|
||||
|
@ -7,7 +7,7 @@ from apps.ripple import helpers
|
||||
|
||||
|
||||
async def get_address(ctx, msg: RippleGetAddress, keychain):
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
pubkey = node.public_key()
|
||||
|
@ -13,7 +13,7 @@ from apps.ripple.serialize import serialize
|
||||
async def sign_tx(ctx, msg: RippleSignTx, keychain):
|
||||
validate(msg)
|
||||
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n)
|
||||
source_address = helpers.address_from_public_key(node.public_key())
|
||||
|
@ -7,7 +7,7 @@ from apps.stellar import helpers
|
||||
|
||||
|
||||
async def get_address(ctx, msg: StellarGetAddress, keychain):
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
|
||||
pubkey = seed.remove_ed25519_prefix(node.public_key())
|
||||
|
@ -13,7 +13,7 @@ from apps.stellar.operations import process_operation
|
||||
|
||||
|
||||
async def sign_tx(ctx, msg: StellarSignTx, keychain):
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
|
||||
pubkey = seed.remove_ed25519_prefix(node.public_key())
|
||||
|
@ -7,7 +7,7 @@ from apps.tezos import helpers
|
||||
|
||||
|
||||
async def get_address(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
||||
|
||||
|
@ -10,7 +10,7 @@ from apps.tezos import helpers
|
||||
|
||||
|
||||
async def get_public_key(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
||||
pk = seed.remove_ed25519_prefix(node.public_key())
|
||||
|
@ -10,7 +10,7 @@ from apps.tezos import helpers, layout
|
||||
|
||||
|
||||
async def sign_tx(ctx, msg, keychain):
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
||||
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||
|
||||
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
||||
|
||||
|
@ -14,7 +14,8 @@ async def get_address(ctx, msg, keychain):
|
||||
await validate_path(
|
||||
ctx,
|
||||
addresses.validate_full_path,
|
||||
path=msg.address_n,
|
||||
keychain,
|
||||
msg.address_n,
|
||||
coin=coin,
|
||||
script_type=msg.script_type,
|
||||
)
|
||||
|
@ -22,7 +22,8 @@ async def sign_message(ctx, msg, keychain):
|
||||
await validate_path(
|
||||
ctx,
|
||||
validate_full_path,
|
||||
path=msg.address_n,
|
||||
keychain,
|
||||
msg.address_n,
|
||||
coin=coin,
|
||||
script_type=msg.script_type,
|
||||
validate_script_type=False,
|
||||
|
Loading…
Reference in New Issue
Block a user