From b1164077e915d5b37a3c90e426ef4822f32409db Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Mon, 20 Nov 2017 12:47:39 +0100 Subject: [PATCH] wallet/signing: add change output restrictions --- src/apps/wallet/sign_tx/scripts.py | 12 ++- src/apps/wallet/sign_tx/signing.py | 155 +++++++++++++++++++---------- 2 files changed, 112 insertions(+), 55 deletions(-) diff --git a/src/apps/wallet/sign_tx/scripts.py b/src/apps/wallet/sign_tx/scripts.py index 5c7a794c98..7dc630b0de 100644 --- a/src/apps/wallet/sign_tx/scripts.py +++ b/src/apps/wallet/sign_tx/scripts.py @@ -56,12 +56,14 @@ def input_script_native_p2wpkh_or_p2wsh() -> bytearray: return bytearray(0) -# output script consists of 00 14 <20-byte-key-hash> -def output_script_native_p2wpkh_or_p2wsh(pubkeyhash: bytes) -> bytearray: - w = bytearray_with_cap(3 + len(pubkeyhash)) +# output script is either: +# 00 14 <20-byte-key-hash> +# 00 20 <32-byte-script-hash> +def output_script_native_p2wpkh_or_p2wsh(witprog: bytes) -> bytearray: + w = bytearray_with_cap(3 + len(witprog)) w.append(0x00) # witness version byte - w.append(len(pubkeyhash)) # pub key hash length is 20 (P2WPKH) or 32 (P2WSH) bytes - write_bytes(w, pubkeyhash) # pub key hash + w.append(len(witprog)) # pub key hash length is 20 (P2WPKH) or 32 (P2WSH) bytes + write_bytes(w, witprog) # pub key hash return w diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index 50f0eae544..8ecebd130d 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -13,6 +13,16 @@ from apps.wallet.sign_tx.segwit_bip143 import * from apps.wallet.sign_tx.helpers import * from apps.wallet.sign_tx.scripts import * +# the number of bip32 levels used in a wallet (chain and address) +_BIP32_WALLET_DEPTH = const(2) + +# the chain id used for change +_BIP32_CHANGE_CHAIN = const(1) + +# the maximum allowed change address. this should be large enough for normal +# use and still allow to quickly brute-force the correct bip32 path +_BIP32_MAX_LAST_ELEMENT = const(1000000) + class SigningError(ValueError): pass @@ -24,6 +34,7 @@ class SigningError(ValueError): # for pseudo code overview # === + # Phase 1 # - check inputs, previous transactions, and outputs # - ask for confirmations @@ -45,37 +56,47 @@ async def check_tx_fee(tx: SignTx, root): total_in = 0 # sum of input amounts total_out = 0 # sum of output amounts change_out = 0 # change output amount + wallet_path = [] # common prefix of input paths segwit = {} # dict of booleans stating if input is segwit for i in range(tx.inputs_count): # STAGE_REQUEST_1_INPUT txi = await request_tx_input(tx_req, i) + wallet_path = input_extract_wallet_path(txi, wallet_path) write_tx_input_check(h_first, txi) - if txi.script_type in (InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDWITNESS): + if (txi.script_type == InputScriptType.SPENDWITNESS or + txi.script_type == InputScriptType.SPENDP2SHWITNESS): + if not coin.segwit: + raise SigningError(FailureType.DataError, + 'Segwit not enabled on this coin') + if not txi.amount: + raise SigningError(FailureType.DataError, + 'Segwit input without amount') segwit[i] = True - # Add I to segwit hash_prevouts, hash_sequence bip143.add_prevouts(txi) bip143.add_sequence(txi) total_in += txi.amount - else: + elif txi.script_type == InputScriptType.SPENDADDRESS: segwit[i] = False total_in += await get_prevtx_output_value( tx_req, txi.prev_hash, txi.prev_index) + else: + raise SigningError(FailureType.DataError, + 'Wrong input script type') for o in range(tx.outputs_count): # STAGE_REQUEST_3_OUTPUT txo = await request_tx_output(tx_req, o) - if output_is_change(txo): + txo_bin.amount = txo.amount + txo_bin.script_pubkey = output_derive_script(txo, coin, root) + if output_is_change(txo, wallet_path): if change_out != 0: raise SigningError(FailureType.ProcessError, 'Only one change output is valid') change_out = txo.amount - else: - if not await confirm_output(txo, coin): - raise SigningError(FailureType.ActionCancelled, - 'Output cancelled') - txo_bin.amount = txo.amount - txo_bin.script_pubkey = output_derive_script(txo, coin, root) + elif not await confirm_output(txo, coin): + raise SigningError(FailureType.ActionCancelled, + 'Output cancelled') write_tx_output(h_first, txo_bin) bip143.add_output(txo_bin) total_out += txo_bin.amount @@ -85,7 +106,8 @@ async def check_tx_fee(tx: SignTx, root): raise SigningError(FailureType.NotEnoughFunds, 'Not enough funds') - if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000): + tx_size_b = estimate_tx_size(tx.inputs_count, tx.outputs_count) + if fee > coin.maxfee_kb * ((tx_size_b + 999) // 1000): if not await confirm_feeoverthreshold(fee, coin): raise SigningError(FailureType.ActionCancelled, 'Signing cancelled') @@ -94,7 +116,7 @@ async def check_tx_fee(tx: SignTx, root): raise SigningError(FailureType.ActionCancelled, 'Total cancelled') - return h_first, tx_req, txo_bin, bip143, segwit, total_in + return h_first, tx_req, txo_bin, bip143, segwit, total_in, wallet_path async def sign_tx(tx: SignTx, root): @@ -103,7 +125,8 @@ async def sign_tx(tx: SignTx, root): # Phase 1 - h_first, tx_req, txo_bin, bip143, segwit, authorized_in = await check_tx_fee(tx, root) + h_first, tx_req, txo_bin, bip143, segwit, authorized_in, wallet_path = \ + await check_tx_fee(tx, root) # Phase 2 # - sign inputs @@ -129,32 +152,37 @@ async def sign_tx(tx: SignTx, root): if segwit[i_sign]: # STAGE_REQUEST_SEGWIT_INPUT txi_sign = await request_tx_input(tx_req, i_sign) - write_tx_input_check(h_second, txi_sign) - if txi_sign.script_type in (InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDWITNESS): - key_sign = node_derive(root, txi_sign.address_n) - key_sign_pub = key_sign.public_key() - txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub) - w_txi = bytearray_with_cap( - 7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) - if i_sign == 0: # serializing first input => prepend headers - write_bytes(w_txi, get_tx_header(tx, True)) - write_tx_input(w_txi, txi_sign) - tx_ser.serialized_tx = w_txi + if (txi_sign.script_type != InputScriptType.SPENDWITNESS and + txi_sign.script_type != InputScriptType.SPENDP2SHWITNESS): + raise SigningError(FailureType.ProcessError, + 'Transaction has changed during signing') + input_check_wallet_path(txi_sign, wallet_path) + write_tx_input_check(h_second, txi_sign) + + key_sign = node_derive(root, txi_sign.address_n) + key_sign_pub = key_sign.public_key() + txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub) + w_txi = bytearray_with_cap( + 7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4) + if i_sign == 0: # serializing first input => prepend headers + write_bytes(w_txi, get_tx_header(tx, True)) + write_tx_input(w_txi, txi_sign) + tx_ser.serialized_tx = w_txi tx_req.serialized = tx_ser else: for i in range(tx.inputs_count): # STAGE_REQUEST_4_INPUT txi = await request_tx_input(tx_req, i) + input_check_wallet_path(txi, wallet_path) write_tx_input_check(h_second, txi) if i == i_sign: txi_sign = txi key_sign = node_derive(root, txi.address_n) key_sign_pub = key_sign.public_key() - # the signature has to be also over the output script to prevent modification - # todo this should fail for p2sh - txi_sign.script_sig = output_script_p2pkh(ecdsa_hash_pubkey(key_sign_pub)) + txi_sign.script_sig = output_script_p2pkh( + ecdsa_hash_pubkey(key_sign_pub)) else: txi.script_sig = bytes() write_tx_input(h_sign, txi) @@ -219,9 +247,9 @@ async def sign_tx(tx: SignTx, root): if segwit[i]: # STAGE_REQUEST_SEGWIT_WITNESS txi = await request_tx_input(tx_req, i) + input_check_wallet_path(txi, wallet_path) - # Check amount and the control digests - if txi.amount > authorized_in or (get_tx_hash(h_first, False) != get_tx_hash(h_second, False)): + if txi.amount > authorized_in: raise SigningError(FailureType.ProcessError, 'Transaction has changed during signing') authorized_in -= txi.amount @@ -237,6 +265,8 @@ async def sign_tx(tx: SignTx, root): tx_ser.signature = signature tx_ser.serialized_tx = witness tx_req.serialized = tx_ser + else: + pass # TODO: empty witness write_uint32(tx_ser.serialized_tx, tx.lock_time) @@ -310,13 +340,13 @@ def get_address(script_type: InputScriptType, coin: CoinType, node) -> bytes: elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh if not coin.segwit or not coin.bech32_prefix: raise SigningError(FailureType.ProcessError, - 'Coin does not support segwit') + 'Segwit not enabled on this coin') return address_p2wpkh(node.public_key(), coin.bech32_prefix) elif script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh using p2sh if not coin.segwit or not coin.address_type_p2sh: raise SigningError(FailureType.ProcessError, - 'Coin does not support segwit') + 'Segwit not enabled on this coin') return address_p2wpkh_in_p2sh(node.public_key(), coin.address_type_p2sh) else: @@ -340,9 +370,12 @@ def address_p2wpkh_in_p2sh_raw(pubkey: bytes) -> bytes: return h +_BECH32_WITVER = const(0x00) + + def address_p2wpkh(pubkey: bytes, hrp: str) -> str: pubkeyhash = ecdsa_hash_pubkey(pubkey) - address = bech32.encode(hrp, 0, pubkeyhash) # TODO: constant? + address = bech32.encode(hrp, _BECH32_WITVER, pubkeyhash) if address is None: raise SigningError(FailureType.ProcessError, 'Invalid address') @@ -351,8 +384,8 @@ def address_p2wpkh(pubkey: bytes, hrp: str) -> str: def decode_bech32_address(prefix: str, address: str) -> bytes: witver, raw = bech32.decode(prefix, address) - if witver != 0: # TODO: constant? - raise SigningError(FailureType.ProcessError, + if witver != _BECH32_WITVER: + raise SigningError(FailureType.DataError, 'Invalid address witness program') return bytes(raw) @@ -365,26 +398,23 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: if o.script_type == OutputScriptType.PAYTOOPRETURN: if o.amount != 0: - raise SigningError(FailureType.ProcessError, + raise SigningError(FailureType.DataError, 'OP_RETURN output with non-zero amount') return output_script_paytoopreturn(o.op_return_data) if o.address_n: # change output if o.address: - raise SigningError(FailureType.ProcessError, - 'Both address_n and address provided') - address = get_address_for_change(o, coin, root) + raise SigningError(FailureType.DataError, 'Address in change output') + o.address = get_address_for_change(o, coin, root) else: if not o.address: - raise SigningError(FailureType.ProcessError, 'Missing address') - address = o.address + raise SigningError(FailureType.DataError, 'Missing address') - if coin.bech32_prefix and address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh - # todo check if p2wsh works - pubkeyhash = decode_bech32_address(coin.bech32_prefix, address) - return output_script_native_p2wpkh_or_p2wsh(pubkeyhash) + if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh + witprog = decode_bech32_address(coin.bech32_prefix, o.address) + return output_script_native_p2wpkh_or_p2wsh(witprog) - raw_address = base58.decode_check(address) + raw_address = base58.decode_check(o.address) if address_type.check(coin.address_type, raw_address): # p2pkh pubkeyhash = address_type.strip(coin.address_type, raw_address) @@ -394,11 +424,10 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes: scripthash = address_type.strip(coin.address_type_p2sh, raw_address) return output_script_p2sh(scripthash) - raise SigningError(FailureType.ProcessError, 'Invalid address type') + raise SigningError(FailureType.DataError, 'Invalid address type') def get_address_for_change(o: TxOutputType, coin: CoinType, root): - if o.script_type == OutputScriptType.PAYTOADDRESS: input_script_type = InputScriptType.SPENDADDRESS elif o.script_type == OutputScriptType.PAYTOMULTISIG: @@ -408,12 +437,16 @@ def get_address_for_change(o: TxOutputType, coin: CoinType, root): elif o.script_type == OutputScriptType.PAYTOP2SHWITNESS: input_script_type = InputScriptType.SPENDP2SHWITNESS else: - raise SigningError(FailureType.ProcessError, 'Invalid script type') + raise SigningError(FailureType.DataError, 'Invalid script type') return get_address(input_script_type, coin, node_derive(root, o.address_n)) -def output_is_change(o: TxOutputType) -> bool: - return bool(o.address_n) +def output_is_change(o: TxOutputType, wallet_path: list) -> bool: + address_n = o.address_n + return (address_n is not None and wallet_path is not None + and wallet_path == address_n[:-_BIP32_WALLET_DEPTH] + and address_n[-2] == _BIP32_CHANGE_CHAIN + and address_n[-1] <= _BIP32_MAX_LAST_ELEMENT) # Tx Inputs @@ -434,6 +467,28 @@ def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) -> raise SigningError(FailureType.ProcessError, 'Invalid script type') +def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list: + if wallet_path is None: + return None # there was a mismatch in previous inputs + address_n = txi.address_n[:-_BIP32_WALLET_DEPTH] + if not address_n: + return None # input path is too short + if not wallet_path: + return address_n # this is the first input + if wallet_path == address_n: + return address_n # paths match + return None # paths don't match + + +def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list: + if wallet_path is None: + return # there was a mismatch in Phase 1, ignore it now + address_n = txi.address_n[:-_BIP32_WALLET_DEPTH] + if wallet_path != address_n: + raise SigningError(FailureType.ProcessError, + 'Transaction has changed during signing') + + def node_derive(root, address_n: list): node = root.clone() node.derive_path(address_n)