1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-15 09:50:57 +00:00

Merge pull request #532 from trezor/tsusanka/path-validation

Pass Keychain to path validation function to throw error before warning
This commit is contained in:
Tomas Susanka 2019-04-04 13:40:59 +02:00 committed by GitHub
commit 396f5f1937
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 53 additions and 31 deletions

View File

@ -10,7 +10,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg):
keychain = await seed.get_keychain(ctx)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
try:
address, _ = derive_address_and_node(keychain, msg.address_n)

View File

@ -14,7 +14,11 @@ async def get_public_key(ctx, msg):
keychain = await seed.get_keychain(ctx)
await paths.validate_path(
ctx, paths.validate_path_for_get_public_key, path=msg.address_n, slip44_id=1815
ctx,
paths.validate_path_for_get_public_key,
keychain,
msg.address_n,
slip44_id=1815,
)
try:

View File

@ -11,6 +11,10 @@ class Keychain:
self.path = path
self.root = root
def validate_path(self, checked_path: list):
if checked_path[:2] != SEED_NAMESPACE[0]:
raise wire.DataError("Forbidden key path")
def derive(self, node_path: list) -> bip32.HDNode:
# check we are in the cardano namespace
prefix = node_path[: len(self.path)]

View File

@ -85,7 +85,7 @@ async def sign_tx(ctx, msg):
display_homescreen()
for i in msg.inputs:
await validate_path(ctx, validate_full_path, path=i.address_n)
await validate_path(ctx, validate_full_path, keychain, i.address_n)
# sign the transaction bundle and prepare the result
transaction = Transaction(

View File

@ -8,9 +8,10 @@ from apps.common import HARDENED
from apps.common.confirm import require_confirm
async def validate_path(ctx, validate_func, **kwargs):
if not validate_func(**kwargs):
await show_path_warning(ctx, kwargs["path"])
async def validate_path(ctx, validate_func, keychain, path, **kwargs):
keychain.validate_path(path)
if not validate_func(path, **kwargs):
await show_path_warning(ctx, path)
async def show_path_warning(ctx, path: list):

View File

@ -25,6 +25,12 @@ class Keychain:
del self.roots
del self.seed
def validate_path(self, checked_path: list):
for curve, *path in self.namespaces:
if path == checked_path[: len(path)]: # TODO: check curve_name
return
raise wire.DataError("Forbidden key path")
def derive(self, node_path: list, curve_name: str = "secp256k1") -> bip32.HDNode:
# find the root node index
root_index = 0
@ -44,6 +50,7 @@ class Keychain:
root.derive_path(path)
self.roots[root_index] = root
# TODO check for ed25519?
# derive child node from the root
node = root.clone()
node.derive_path(suffix)

View File

@ -9,7 +9,7 @@ from apps.ethereum.address import address_from_bytes, validate_full_path
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
seckey = node.private_key()

View File

@ -7,7 +7,7 @@ from apps.ethereum import address
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(
ctx, address.validate_path_for_get_public_key, path=msg.address_n
ctx, address.validate_path_for_get_public_key, keychain, msg.address_n
)
node = keychain.derive(msg.address_n)

View File

@ -20,7 +20,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, address.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, address.validate_full_path, keychain, msg.address_n)
await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n)

View File

@ -23,7 +23,7 @@ MAX_CHAIN_ID = 2147483629
async def sign_tx(ctx, msg, keychain):
msg = sanitize(msg)
check(msg)
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
data_total = msg.data_length

View File

@ -7,7 +7,7 @@ from apps.common.layout import address_n_to_str, show_address, show_qr
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key()

View File

@ -6,7 +6,7 @@ from apps.common import layout, paths
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, LISK_CURVE)
pubkey = node.public_key()

View File

@ -23,7 +23,7 @@ def message_digest(message):
async def sign_message(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
await require_confirm_sign_message(ctx, msg.message)
node = keychain.derive(msg.address_n, LISK_CURVE)

View File

@ -17,7 +17,7 @@ from apps.lisk.helpers import (
async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, validate_full_path, keychain, msg.address_n)
pubkey, seckey = _get_keys(keychain, msg)
transaction = _update_raw_tx(msg.transaction, pubkey)

View File

@ -6,7 +6,7 @@ from apps.monero import misc
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -30,7 +30,7 @@ _GET_TX_KEY_REASON_TX_DERIVATION = 1
async def get_tx_keys(ctx, msg: MoneroGetTxKeyRequest, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
do_deriv = msg.reason == _GET_TX_KEY_REASON_TX_DERIVATION
await confirms.require_confirm_tx_key(ctx, export_key=not do_deriv)

View File

@ -8,7 +8,7 @@ from apps.monero.xmr import crypto
async def get_watch_only(ctx, msg: MoneroGetWatchKey, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
await confirms.require_confirm_watchkey(ctx)

View File

@ -47,7 +47,7 @@ class KeyImageSync:
async def _init_step(s, ctx, msg, keychain):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -44,7 +44,7 @@ class LiveRefreshState:
async def _init_step(
s: LiveRefreshState, ctx, msg: MoneroLiveRefreshStartRequest, keychain
):
await paths.validate_path(ctx, misc.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, misc.validate_full_path, keychain, msg.address_n)
await confirms.require_confirm_live_refresh(ctx)

View File

@ -24,7 +24,7 @@ async def init_transaction(
from apps.monero.signing import offloading_keys
from apps.common import paths
await paths.validate_path(state.ctx, misc.validate_full_path, path=address_n)
await paths.validate_path(state.ctx, misc.validate_full_path, keychain, address_n)
state.creds = misc.get_creds(keychain, address_n, network_type)
state.client_version = tsx_data.client_version or 0

View File

@ -9,7 +9,7 @@ from apps.common.paths import validate_path
async def get_address(ctx, msg, keychain):
network = validate_network(msg.network)
await validate_path(ctx, check_path, path=msg.address_n, network=network)
await validate_path(ctx, check_path, keychain, msg.address_n, network=network)
node = keychain.derive(msg.address_n, NEM_CURVE)
address = node.nem_address(network)

View File

@ -13,7 +13,11 @@ async def sign_tx(ctx, msg: NEMSignTx, keychain):
validate(msg)
await validate_path(
ctx, check_path, path=msg.transaction.address_n, network=msg.transaction.network
ctx,
check_path,
keychain,
msg.transaction.address_n,
network=msg.transaction.network,
)
node = keychain.derive(msg.transaction.address_n, NEM_CURVE)

View File

@ -7,7 +7,7 @@ from apps.ripple import helpers
async def get_address(ctx, msg: RippleGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
pubkey = node.public_key()

View File

@ -13,7 +13,7 @@ from apps.ripple.serialize import serialize
async def sign_tx(ctx, msg: RippleSignTx, keychain):
validate(msg)
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n)
source_address = helpers.address_from_public_key(node.public_key())

View File

@ -7,7 +7,7 @@ from apps.stellar import helpers
async def get_address(ctx, msg: StellarGetAddress, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key())

View File

@ -13,7 +13,7 @@ from apps.stellar.operations import process_operation
async def sign_tx(ctx, msg: StellarSignTx, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, consts.STELLAR_CURVE)
pubkey = seed.remove_ed25519_prefix(node.public_key())

View File

@ -7,7 +7,7 @@ from apps.tezos import helpers
async def get_address(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)

View File

@ -10,7 +10,7 @@ from apps.tezos import helpers
async def get_public_key(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)
pk = seed.remove_ed25519_prefix(node.public_key())

View File

@ -10,7 +10,7 @@ from apps.tezos import helpers, layout
async def sign_tx(ctx, msg, keychain):
await paths.validate_path(ctx, helpers.validate_full_path, path=msg.address_n)
await paths.validate_path(ctx, helpers.validate_full_path, keychain, msg.address_n)
node = keychain.derive(msg.address_n, helpers.TEZOS_CURVE)

View File

@ -14,7 +14,8 @@ async def get_address(ctx, msg, keychain):
await validate_path(
ctx,
addresses.validate_full_path,
path=msg.address_n,
keychain,
msg.address_n,
coin=coin,
script_type=msg.script_type,
)

View File

@ -22,7 +22,8 @@ async def sign_message(ctx, msg, keychain):
await validate_path(
ctx,
validate_full_path,
path=msg.address_n,
keychain,
msg.address_n,
coin=coin,
script_type=msg.script_type,
validate_script_type=False,