common: pass Keychain to path validation function

closes #519
pull/25/head
Tomas Susanka 5 years ago
parent e89699817f
commit 7cadefcdd0

@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg):
keychain = await seed.get_keychain(ctx)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
try:
address, _ = derive_address_and_node(keychain, msg.address_n)

@ -14,7 +14,11 @@ async def get_public_key(ctx, msg):
keychain = await seed.get_keychain(ctx)
await paths.validate_path(
ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815
ctx,
paths.validate_path_for_get_public_key,
keychain,
msg.address_n,
slip44_id=1815,
)
try:

@ -11,6 +11,10 @@ class Keychain:
self.path = path
self.root = root
def validate_path(self, checked_path: list):
if checked_path[:2] != SEED_NAMESPACE[0]:
raise wire.DataError("Forbidden key path")
def derive(self, node_path: list) -> bip32.HDNode:
# check we are in the cardano namespace
prefix = node_path[: len(self.path)]

@ -85,7 +85,7 @@ async def sign_tx(ctx, msg):
display_homescreen()
for i in msg.inputs:
await validate_path(ctx, validate_full_path, path=i.address_n)
await validate_path(ctx, validate_full_path, keychain, i.address_n)
# sign the transaction bundle and prepare the result
transaction = Transaction(

@ -8,9 +8,10 @@ from apps.common import HARDENED
from apps.common.confirm import require_confirm
async def validate_path(ctx, validate_func, **kwargs):
if not validate_func(**kwargs):
await show_path_warning(ctx, kwargs["path"])
async def validate_path(ctx, validate_func, keychain, path, **kwargs):
keychain.validate_path(path)
if not validate_func(path, **kwargs):
await show_path_warning(ctx, path)
async def show_path_warning(ctx, path: list):

@ -25,6 +25,22 @@ class Keychain:
del self.roots
del self.seed
def validate_path(self, checked_path: list):
empty = True
for _, *namespace_path in self.namespaces:
if empty and len(namespace_path):
empty = False
for i, p in enumerate(namespace_path):
if p != checked_path[i]:
# item did not match, move on to the next allowed path in namespace
break
if i == len(namespace_path) - 1:
# all items match in some namespace path -> success
return
if not empty:
# checked_path was not among the allowed ones
raise wire.DataError("Forbidden key path")
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
# find the root node index
root_index = 0
@ -44,6 +60,7 @@ class Keychain:
root.derive_path(path)
self.roots[root_index] = root
# TODO check for ed25519?
# derive child node from the root
node = root.clone()
node.derive_path(suffix)

@ -9,7 +9,7 @@ from apps.ethereum.address import address_from_bytes, validate_full_path
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
seckey = node.private_key()

@ -7,7 +7,7 @@ from apps.ethereum import address
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(
ctx, address.validate_path_for_get_public_key, path=msg.address_n
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n
)
node = keychain.derive(msg.address_n)

@ -20,7 +20,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n)
await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n)

@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629
async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg)
check(msg)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
data_total = msg.data_length

@ -7,7 +7,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key()

@ -6,7 +6,7 @@ from apps.common import layout, paths
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key()

@ -23,7 +23,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n, LISK_CURVE)

@ -17,7 +17,7 @@ from apps.lisk.helpers import (
async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
pubkey, seckey = _get_keys(keychain, msg)
transaction = _update_raw_tx(msg.transaction, pubkey)

@ -6,7 +6,7 @@ from apps.monero import misc
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -30,7 +30,7 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)

@ -8,7 +8,7 @@ from apps.monero.xmr import crypto
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
await confirms.require_confirm_watchkey(ctx)

@ -47,7 +47,7 @@ class KeyImageSync:
async def _init_step(s, ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -44,7 +44,7 @@ class LiveRefreshState:
async def _init_step(
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
await confirms.require_confirm_live_refresh(ctx)

@ -24,7 +24,7 @@ async def init_transaction(
from apps.monero.signing import offloading_keys
from apps.common import paths
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n)
await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n)
state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0

@ -9,7 +9,7 @@ from apps.common.paths import validate_path
async def get_address(ctx, msg, keychain):
network = validate_network(msg.network)
await validate_path(ctx, check_path, path=msg.address_n, network=network)
await validate_path(ctx, check_path, keychain, msg.address_n, network=network)
node = keychain.derive(msg.address_n, NEM_CURVE)
address = node.nem_address(network)

@ -13,7 +13,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain):
validate(msg)
await validate_path(
ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network
ctx,
check_path,
keychain,
msg.transaction.address_n,
network=msg.transaction.network,
)
node = keychain.derive(msg.transaction.address_n, NEM_CURVE)

@ -7,7 +7,7 @@ from apps.ripple import helpers
async def get_address(ctx, msg: RippleGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
pubkey = node.public_key()

@ -13,7 +13,7 @@ from apps.ripple.serialize import serialize
async def sign_tx(ctx, msg: RippleSignTx, keychain):
validate(msg)
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key())

@ -7,7 +7,7 @@ from apps.stellar import helpers
async def get_address(ctx, msg: StellarGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key())

@ -13,7 +13,7 @@ from apps.stellar.operations import process_operation
async def sign_tx(ctx, msg: StellarSignTx, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key())

@ -7,7 +7,7 @@ from apps.tezos import helpers
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)

@ -10,7 +10,7 @@ from apps.tezos import helpers
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
pk = seed.remove_ed25519_prefix(node.public_key())

@ -10,7 +10,7 @@ from apps.tezos import helpers, layout
async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)

@ -14,7 +14,8 @@ async def get_address(ctx, msg, keychain):
await validate_path(
ctx,
addresses.validate_full_path,
path=msg.address_n,
keychain,
msg.address_n,
coin=coin,
script_type=msg.script_type,
)

@ -22,7 +22,8 @@ async def sign_message(ctx, msg, keychain):
await validate_path(
ctx,
validate_full_path,
path=msg.address_n,
keychain,
msg.address_n,
coin=coin,
script_type=msg.script_type,
validate_script_type=False,

Loading…
Cancel
Save