1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-16 03:18:09 +00:00

common: pass Keychain to path validation function

closes #519
This commit is contained in:
Tomas Susanka 2019-04-02 16:00:51 +02:00
parent e89699817f
commit 7cadefcdd0
31 changed files with 63 additions and 31 deletions

View File

@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg): async def get_address(ctx, msg):
keychain = await seed.get_keychain(ctx) keychain = await seed.get_keychain(ctx)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
try: try:
address, _ = derive_address_and_node(keychain, msg.address_n) address, _ = derive_address_and_node(keychain, msg.address_n)

View File

@ -14,7 +14,11 @@ async def get_public_key(ctx, msg):
keychain = await seed.get_keychain(ctx) keychain = await seed.get_keychain(ctx)
await paths.validate_path( await paths.validate_path(
ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815 ctx,
paths.validate_path_for_get_public_key,
keychain,
msg.address_n,
slip44_id=1815,
) )
try: try:

View File

@ -11,6 +11,10 @@ class Keychain:
self.path = path self.path = path
self.root = root self.root = root
def validate_path(self, checked_path: list):
if checked_path[:2] != SEED_NAMESPACE[0]:
raise wire.DataError("Forbidden key path")
def derive(self, node_path: list) -> bip32.HDNode: def derive(self, node_path: list) -> bip32.HDNode:
# check we are in the cardano namespace # check we are in the cardano namespace
prefix = node_path[: len(self.path)] prefix = node_path[: len(self.path)]

View File

@ -85,7 +85,7 @@ async def sign_tx(ctx, msg):
display_homescreen() display_homescreen()
for i in msg.inputs: for i in msg.inputs:
await validate_path(ctx, validate_full_path, path=i.address_n) await validate_path(ctx, validate_full_path, keychain, i.address_n)
# sign the transaction bundle and prepare the result # sign the transaction bundle and prepare the result
transaction = Transaction( transaction = Transaction(

View File

@ -8,9 +8,10 @@ from apps.common import HARDENED
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
async def validate_path(ctx, validate_func, **kwargs): async def validate_path(ctx, validate_func, keychain, path, **kwargs):
if not validate_func(**kwargs): keychain.validate_path(path)
await show_path_warning(ctx, kwargs["path"]) if not validate_func(path, **kwargs):
await show_path_warning(ctx, path)
async def show_path_warning(ctx, path: list): async def show_path_warning(ctx, path: list):

View File

@ -25,6 +25,22 @@ class Keychain:
del self.roots del self.roots
del self.seed del self.seed
def validate_path(self, checked_path: list):
empty = True
for _, *namespace_path in self.namespaces:
if empty and len(namespace_path):
empty = False
for i, p in enumerate(namespace_path):
if p != checked_path[i]:
# item did not match, move on to the next allowed path in namespace
break
if i == len(namespace_path) - 1:
# all items match in some namespace path -> success
return
if not empty:
# checked_path was not among the allowed ones
raise wire.DataError("Forbidden key path")
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode: def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
# find the root node index # find the root node index
root_index = 0 root_index = 0
@ -44,6 +60,7 @@ class Keychain:
root.derive_path(path) root.derive_path(path)
self.roots[root_index] = root self.roots[root_index] = root
# TODO check for ed25519?
# derive child node from the root # derive child node from the root
node = root.clone() node = root.clone()
node.derive_path(suffix) node.derive_path(suffix)

View File

@ -9,7 +9,7 @@ from apps.ethereum.address import address_from_bytes, validate_full_path
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
seckey = node.private_key() seckey = node.private_key()

View File

@ -7,7 +7,7 @@ from apps.ethereum import address
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path( await paths.validate_path(
ctx, address.validate_path_for_get_public_key, path=msg.address_n ctx, address.validate_path_for_get_public_key, keychain, msg.address_n
) )
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -20,7 +20,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain): async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n)
await require_confirm_sign_message(ctx, msg.message) await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)

View File

@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg) msg = sanitize(msg)
check(msg) check(msg)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
data_total = msg.data_length data_total = msg.data_length

View File

@ -7,7 +7,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key() pubkey = node.public_key()

View File

@ -6,7 +6,7 @@ from apps.common import layout, paths
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key() pubkey = node.public_key()

View File

@ -23,7 +23,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain): async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
await require_confirm_sign_message(ctx, msg.message) await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n, LISK_CURVE) node = keychain.derive(msg.address_n, LISK_CURVE)

View File

@ -17,7 +17,7 @@ from apps.lisk.helpers import (
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n) await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
pubkey, seckey = _get_keys(keychain, msg) pubkey, seckey = _get_keys(keychain, msg)
transaction = _update_raw_tx(msg.transaction, pubkey) transaction = _update_raw_tx(msg.transaction, pubkey)

View File

@ -6,7 +6,7 @@ from apps.monero import misc
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type) creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -30,7 +30,7 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain): async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv) await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)

View File

@ -8,7 +8,7 @@ from apps.monero.xmr import crypto
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain): async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
await confirms.require_confirm_watchkey(ctx) await confirms.require_confirm_watchkey(ctx)

View File

@ -47,7 +47,7 @@ class KeyImageSync:
async def _init_step(s, ctx, msg, keychain): async def _init_step(s, ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -44,7 +44,7 @@ class LiveRefreshState:
async def _init_step( async def _init_step(
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
): ):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
await confirms.require_confirm_live_refresh(ctx) await confirms.require_confirm_live_refresh(ctx)

View File

@ -24,7 +24,7 @@ async def init_transaction(
from apps.monero.signing import offloading_keys from apps.monero.signing import offloading_keys
from apps.common import paths from apps.common import paths
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n) await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n)
state.creds = misc.get_creds(keychain, address_n, network_type) state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0 state.client_version = tsx_data.client_version or 0

View File

@ -9,7 +9,7 @@ from apps.common.paths import validate_path
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
network = validate_network(msg.network) network = validate_network(msg.network)
await validate_path(ctx, check_path, path=msg.address_n, network=network) await validate_path(ctx, check_path, keychain, msg.address_n, network=network)
node = keychain.derive(msg.address_n, NEM_CURVE) node = keychain.derive(msg.address_n, NEM_CURVE)
address = node.nem_address(network) address = node.nem_address(network)

View File

@ -13,7 +13,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain):
validate(msg) validate(msg)
await validate_path( await validate_path(
ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network ctx,
check_path,
keychain,
msg.transaction.address_n,
network=msg.transaction.network,
) )
node = keychain.derive(msg.transaction.address_n, NEM_CURVE) node = keychain.derive(msg.transaction.address_n, NEM_CURVE)

View File

@ -7,7 +7,7 @@ from apps.ripple import helpers
async def get_address(ctx, msg: RippleGetAddress, keychain): async def get_address(ctx, msg: RippleGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
pubkey = node.public_key() pubkey = node.public_key()

View File

@ -13,7 +13,7 @@ from apps.ripple.serialize import serialize
async def sign_tx(ctx, msg: RippleSignTx, keychain): async def sign_tx(ctx, msg: RippleSignTx, keychain):
validate(msg) validate(msg)
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n) node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key()) source_address = helpers.address_from_public_key(node.public_key())

View File

@ -7,7 +7,7 @@ from apps.stellar import helpers
async def get_address(ctx, msg: StellarGetAddress, keychain): async def get_address(ctx, msg: StellarGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE) node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key()) pubkey = seed.remove_ed25519_prefix(node.public_key())

View File

@ -13,7 +13,7 @@ from apps.stellar.operations import process_operation
async def sign_tx(ctx, msg: StellarSignTx, keychain): async def sign_tx(ctx, msg: StellarSignTx, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE) node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key()) pubkey = seed.remove_ed25519_prefix(node.public_key())

View File

@ -7,7 +7,7 @@ from apps.tezos import helpers
async def get_address(ctx, msg, keychain): async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)

View File

@ -10,7 +10,7 @@ from apps.tezos import helpers
async def get_public_key(ctx, msg, keychain): async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
pk = seed.remove_ed25519_prefix(node.public_key()) pk = seed.remove_ed25519_prefix(node.public_key())

View File

@ -10,7 +10,7 @@ from apps.tezos import helpers, layout
async def sign_tx(ctx, msg, keychain): async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n) await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE) node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)

View File

@ -14,7 +14,8 @@ async def get_address(ctx, msg, keychain):
await validate_path( await validate_path(
ctx, ctx,
addresses.validate_full_path, addresses.validate_full_path,
path=msg.address_n, keychain,
msg.address_n,
coin=coin, coin=coin,
script_type=msg.script_type, script_type=msg.script_type,
) )

View File

@ -22,7 +22,8 @@ async def sign_message(ctx, msg, keychain):
await validate_path( await validate_path(
ctx, ctx,
validate_full_path, validate_full_path,
path=msg.address_n, keychain,
msg.address_n,
coin=coin, coin=coin,
script_type=msg.script_type, script_type=msg.script_type,
validate_script_type=False, validate_script_type=False,