mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-16 03:18:09 +00:00
Merge pull request #532 from trezor/tsusanka/path-validation
Pass Keychain to path validation function to throw error before warning
This commit is contained in:
commit
396f5f1937
@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
|
|||||||
async def get_address(ctx, msg):
|
async def get_address(ctx, msg):
|
||||||
keychain = await seed.get_keychain(ctx)
|
keychain = await seed.get_keychain(ctx)
|
||||||
|
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
address, _ = derive_address_and_node(keychain, msg.address_n)
|
address, _ = derive_address_and_node(keychain, msg.address_n)
|
||||||
|
@ -14,7 +14,11 @@ async def get_public_key(ctx, msg):
|
|||||||
keychain = await seed.get_keychain(ctx)
|
keychain = await seed.get_keychain(ctx)
|
||||||
|
|
||||||
await paths.validate_path(
|
await paths.validate_path(
|
||||||
ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815
|
ctx,
|
||||||
|
paths.validate_path_for_get_public_key,
|
||||||
|
keychain,
|
||||||
|
msg.address_n,
|
||||||
|
slip44_id=1815,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -11,6 +11,10 @@ class Keychain:
|
|||||||
self.path = path
|
self.path = path
|
||||||
self.root = root
|
self.root = root
|
||||||
|
|
||||||
|
def validate_path(self, checked_path: list):
|
||||||
|
if checked_path[:2] != SEED_NAMESPACE[0]:
|
||||||
|
raise wire.DataError("Forbidden key path")
|
||||||
|
|
||||||
def derive(self, node_path: list) -> bip32.HDNode:
|
def derive(self, node_path: list) -> bip32.HDNode:
|
||||||
# check we are in the cardano namespace
|
# check we are in the cardano namespace
|
||||||
prefix = node_path[: len(self.path)]
|
prefix = node_path[: len(self.path)]
|
||||||
|
@ -85,7 +85,7 @@ async def sign_tx(ctx, msg):
|
|||||||
display_homescreen()
|
display_homescreen()
|
||||||
|
|
||||||
for i in msg.inputs:
|
for i in msg.inputs:
|
||||||
await validate_path(ctx, validate_full_path, path=i.address_n)
|
await validate_path(ctx, validate_full_path, keychain, i.address_n)
|
||||||
|
|
||||||
# sign the transaction bundle and prepare the result
|
# sign the transaction bundle and prepare the result
|
||||||
transaction = Transaction(
|
transaction = Transaction(
|
||||||
|
@ -8,9 +8,10 @@ from apps.common import HARDENED
|
|||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
|
|
||||||
|
|
||||||
async def validate_path(ctx, validate_func, **kwargs):
|
async def validate_path(ctx, validate_func, keychain, path, **kwargs):
|
||||||
if not validate_func(**kwargs):
|
keychain.validate_path(path)
|
||||||
await show_path_warning(ctx, kwargs["path"])
|
if not validate_func(path, **kwargs):
|
||||||
|
await show_path_warning(ctx, path)
|
||||||
|
|
||||||
|
|
||||||
async def show_path_warning(ctx, path: list):
|
async def show_path_warning(ctx, path: list):
|
||||||
|
@ -25,6 +25,12 @@ class Keychain:
|
|||||||
del self.roots
|
del self.roots
|
||||||
del self.seed
|
del self.seed
|
||||||
|
|
||||||
|
def validate_path(self, checked_path: list):
|
||||||
|
for curve, *path in self.namespaces:
|
||||||
|
if path == checked_path[: len(path)]: # TODO: check curve_name
|
||||||
|
return
|
||||||
|
raise wire.DataError("Forbidden key path")
|
||||||
|
|
||||||
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
|
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
|
||||||
# find the root node index
|
# find the root node index
|
||||||
root_index = 0
|
root_index = 0
|
||||||
@ -44,6 +50,7 @@ class Keychain:
|
|||||||
root.derive_path(path)
|
root.derive_path(path)
|
||||||
self.roots[root_index] = root
|
self.roots[root_index] = root
|
||||||
|
|
||||||
|
# TODO check for ed25519?
|
||||||
# derive child node from the root
|
# derive child node from the root
|
||||||
node = root.clone()
|
node = root.clone()
|
||||||
node.derive_path(suffix)
|
node.derive_path(suffix)
|
||||||
|
@ -9,7 +9,7 @@ from apps.ethereum.address import address_from_bytes, validate_full_path
|
|||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
seckey = node.private_key()
|
seckey = node.private_key()
|
||||||
|
@ -7,7 +7,7 @@ from apps.ethereum import address
|
|||||||
|
|
||||||
async def get_public_key(ctx, msg, keychain):
|
async def get_public_key(ctx, msg, keychain):
|
||||||
await paths.validate_path(
|
await paths.validate_path(
|
||||||
ctx, address.validate_path_for_get_public_key, path=msg.address_n
|
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n
|
||||||
)
|
)
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ def message_digest(message):
|
|||||||
|
|
||||||
|
|
||||||
async def sign_message(ctx, msg, keychain):
|
async def sign_message(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n)
|
||||||
await require_confirm_sign_message(ctx, msg.message)
|
await require_confirm_sign_message(ctx, msg.message)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
|
@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629
|
|||||||
async def sign_tx(ctx, msg, keychain):
|
async def sign_tx(ctx, msg, keychain):
|
||||||
msg = sanitize(msg)
|
msg = sanitize(msg)
|
||||||
check(msg)
|
check(msg)
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
data_total = msg.data_length
|
data_total = msg.data_length
|
||||||
|
|
||||||
|
@ -7,7 +7,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
|
|||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, LISK_CURVE)
|
node = keychain.derive(msg.address_n, LISK_CURVE)
|
||||||
pubkey = node.public_key()
|
pubkey = node.public_key()
|
||||||
|
@ -6,7 +6,7 @@ from apps.common import layout, paths
|
|||||||
|
|
||||||
|
|
||||||
async def get_public_key(ctx, msg, keychain):
|
async def get_public_key(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, LISK_CURVE)
|
node = keychain.derive(msg.address_n, LISK_CURVE)
|
||||||
pubkey = node.public_key()
|
pubkey = node.public_key()
|
||||||
|
@ -23,7 +23,7 @@ def message_digest(message):
|
|||||||
|
|
||||||
|
|
||||||
async def sign_message(ctx, msg, keychain):
|
async def sign_message(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
await require_confirm_sign_message(ctx, msg.message)
|
await require_confirm_sign_message(ctx, msg.message)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, LISK_CURVE)
|
node = keychain.derive(msg.address_n, LISK_CURVE)
|
||||||
|
@ -17,7 +17,7 @@ from apps.lisk.helpers import (
|
|||||||
|
|
||||||
|
|
||||||
async def sign_tx(ctx, msg, keychain):
|
async def sign_tx(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
pubkey, seckey = _get_keys(keychain, msg)
|
pubkey, seckey = _get_keys(keychain, msg)
|
||||||
transaction = _update_raw_tx(msg.transaction, pubkey)
|
transaction = _update_raw_tx(msg.transaction, pubkey)
|
||||||
|
@ -6,7 +6,7 @@ from apps.monero import misc
|
|||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1
|
|||||||
|
|
||||||
|
|
||||||
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
|
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
|
||||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
|
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
|
||||||
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)
|
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)
|
||||||
|
@ -8,7 +8,7 @@ from apps.monero.xmr import crypto
|
|||||||
|
|
||||||
|
|
||||||
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
|
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
|
||||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
await confirms.require_confirm_watchkey(ctx)
|
await confirms.require_confirm_watchkey(ctx)
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class KeyImageSync:
|
|||||||
|
|
||||||
|
|
||||||
async def _init_step(s, ctx, msg, keychain):
|
async def _init_step(s, ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
||||||
|
|
||||||
|
@ -44,7 +44,7 @@ class LiveRefreshState:
|
|||||||
async def _init_step(
|
async def _init_step(
|
||||||
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
|
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
|
||||||
):
|
):
|
||||||
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
await confirms.require_confirm_live_refresh(ctx)
|
await confirms.require_confirm_live_refresh(ctx)
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ async def init_transaction(
|
|||||||
from apps.monero.signing import offloading_keys
|
from apps.monero.signing import offloading_keys
|
||||||
from apps.common import paths
|
from apps.common import paths
|
||||||
|
|
||||||
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n)
|
await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n)
|
||||||
|
|
||||||
state.creds = misc.get_creds(keychain, address_n, network_type)
|
state.creds = misc.get_creds(keychain, address_n, network_type)
|
||||||
state.client_version = tsx_data.client_version or 0
|
state.client_version = tsx_data.client_version or 0
|
||||||
|
@ -9,7 +9,7 @@ from apps.common.paths import validate_path
|
|||||||
|
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
network = validate_network(msg.network)
|
network = validate_network(msg.network)
|
||||||
await validate_path(ctx, check_path, path=msg.address_n, network=network)
|
await validate_path(ctx, check_path, keychain, msg.address_n, network=network)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, NEM_CURVE)
|
node = keychain.derive(msg.address_n, NEM_CURVE)
|
||||||
address = node.nem_address(network)
|
address = node.nem_address(network)
|
||||||
|
@ -13,7 +13,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain):
|
|||||||
validate(msg)
|
validate(msg)
|
||||||
|
|
||||||
await validate_path(
|
await validate_path(
|
||||||
ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network
|
ctx,
|
||||||
|
check_path,
|
||||||
|
keychain,
|
||||||
|
msg.transaction.address_n,
|
||||||
|
network=msg.transaction.network,
|
||||||
)
|
)
|
||||||
|
|
||||||
node = keychain.derive(msg.transaction.address_n, NEM_CURVE)
|
node = keychain.derive(msg.transaction.address_n, NEM_CURVE)
|
||||||
|
@ -7,7 +7,7 @@ from apps.ripple import helpers
|
|||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg: RippleGetAddress, keychain):
|
async def get_address(ctx, msg: RippleGetAddress, keychain):
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
pubkey = node.public_key()
|
pubkey = node.public_key()
|
||||||
|
@ -13,7 +13,7 @@ from apps.ripple.serialize import serialize
|
|||||||
async def sign_tx(ctx, msg: RippleSignTx, keychain):
|
async def sign_tx(ctx, msg: RippleSignTx, keychain):
|
||||||
validate(msg)
|
validate(msg)
|
||||||
|
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n)
|
node = keychain.derive(msg.address_n)
|
||||||
source_address = helpers.address_from_public_key(node.public_key())
|
source_address = helpers.address_from_public_key(node.public_key())
|
||||||
|
@ -7,7 +7,7 @@ from apps.stellar import helpers
|
|||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg: StellarGetAddress, keychain):
|
async def get_address(ctx, msg: StellarGetAddress, keychain):
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
|
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
|
||||||
pubkey = seed.remove_ed25519_prefix(node.public_key())
|
pubkey = seed.remove_ed25519_prefix(node.public_key())
|
||||||
|
@ -13,7 +13,7 @@ from apps.stellar.operations import process_operation
|
|||||||
|
|
||||||
|
|
||||||
async def sign_tx(ctx, msg: StellarSignTx, keychain):
|
async def sign_tx(ctx, msg: StellarSignTx, keychain):
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
|
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
|
||||||
pubkey = seed.remove_ed25519_prefix(node.public_key())
|
pubkey = seed.remove_ed25519_prefix(node.public_key())
|
||||||
|
@ -7,7 +7,7 @@ from apps.tezos import helpers
|
|||||||
|
|
||||||
|
|
||||||
async def get_address(ctx, msg, keychain):
|
async def get_address(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from apps.tezos import helpers
|
|||||||
|
|
||||||
|
|
||||||
async def get_public_key(ctx, msg, keychain):
|
async def get_public_key(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
||||||
pk = seed.remove_ed25519_prefix(node.public_key())
|
pk = seed.remove_ed25519_prefix(node.public_key())
|
||||||
|
@ -10,7 +10,7 @@ from apps.tezos import helpers, layout
|
|||||||
|
|
||||||
|
|
||||||
async def sign_tx(ctx, msg, keychain):
|
async def sign_tx(ctx, msg, keychain):
|
||||||
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
|
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
|
||||||
|
|
||||||
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
|
||||||
|
|
||||||
|
@ -14,7 +14,8 @@ async def get_address(ctx, msg, keychain):
|
|||||||
await validate_path(
|
await validate_path(
|
||||||
ctx,
|
ctx,
|
||||||
addresses.validate_full_path,
|
addresses.validate_full_path,
|
||||||
path=msg.address_n,
|
keychain,
|
||||||
|
msg.address_n,
|
||||||
coin=coin,
|
coin=coin,
|
||||||
script_type=msg.script_type,
|
script_type=msg.script_type,
|
||||||
)
|
)
|
||||||
|
@ -22,7 +22,8 @@ async def sign_message(ctx, msg, keychain):
|
|||||||
await validate_path(
|
await validate_path(
|
||||||
ctx,
|
ctx,
|
||||||
validate_full_path,
|
validate_full_path,
|
||||||
path=msg.address_n,
|
keychain,
|
||||||
|
msg.address_n,
|
||||||
coin=coin,
|
coin=coin,
|
||||||
script_type=msg.script_type,
|
script_type=msg.script_type,
|
||||||
validate_script_type=False,
|
validate_script_type=False,
|
||||||
|
Loading…
Reference in New Issue
Block a user