1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

wallet/signing: add change output restrictions

This commit is contained in:
Jan Pochyla 2017-11-20 12:47:39 +01:00 committed by Tomas Susanka
parent 1d97077343
commit b1164077e9
2 changed files with 112 additions and 55 deletions

View File

@ -56,12 +56,14 @@ def input_script_native_p2wpkh_or_p2wsh() -> bytearray:
return bytearray(0)
# output script consists of 00 14 <20-byte-key-hash>
def output_script_native_p2wpkh_or_p2wsh(pubkeyhash: bytes) -> bytearray:
w = bytearray_with_cap(3 + len(pubkeyhash))
# output script is either:
# 00 14 <20-byte-key-hash>
# 00 20 <32-byte-script-hash>
def output_script_native_p2wpkh_or_p2wsh(witprog: bytes) -> bytearray:
w = bytearray_with_cap(3 + len(witprog))
w.append(0x00) # witness version byte
w.append(len(pubkeyhash)) # pub key hash length is 20 (P2WPKH) or 32 (P2WSH) bytes
write_bytes(w, pubkeyhash) # pub key hash
w.append(len(witprog)) # pub key hash length is 20 (P2WPKH) or 32 (P2WSH) bytes
write_bytes(w, witprog) # pub key hash
return w

View File

@ -13,6 +13,16 @@ from apps.wallet.sign_tx.segwit_bip143 import *
from apps.wallet.sign_tx.helpers import *
from apps.wallet.sign_tx.scripts import *
# the number of bip32 levels used in a wallet (chain and address)
_BIP32_WALLET_DEPTH = const(2)
# the chain id used for change
_BIP32_CHANGE_CHAIN = const(1)
# the maximum allowed change address. this should be large enough for normal
# use and still allow to quickly brute-force the correct bip32 path
_BIP32_MAX_LAST_ELEMENT = const(1000000)
class SigningError(ValueError):
pass
@ -24,6 +34,7 @@ class SigningError(ValueError):
# for pseudo code overview
# ===
# Phase 1
# - check inputs, previous transactions, and outputs
# - ask for confirmations
@ -45,37 +56,47 @@ async def check_tx_fee(tx: SignTx, root):
total_in = 0 # sum of input amounts
total_out = 0 # sum of output amounts
change_out = 0 # change output amount
wallet_path = [] # common prefix of input paths
segwit = {} # dict of booleans stating if input is segwit
for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
wallet_path = input_extract_wallet_path(txi, wallet_path)
write_tx_input_check(h_first, txi)
if txi.script_type in (InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDWITNESS):
if (txi.script_type == InputScriptType.SPENDWITNESS or
txi.script_type == InputScriptType.SPENDP2SHWITNESS):
if not coin.segwit:
raise SigningError(FailureType.DataError,
'Segwit not enabled on this coin')
if not txi.amount:
raise SigningError(FailureType.DataError,
'Segwit input without amount')
segwit[i] = True
# Add I to segwit hash_prevouts, hash_sequence
bip143.add_prevouts(txi)
bip143.add_sequence(txi)
total_in += txi.amount
else:
elif txi.script_type == InputScriptType.SPENDADDRESS:
segwit[i] = False
total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index)
else:
raise SigningError(FailureType.DataError,
'Wrong input script type')
for o in range(tx.outputs_count):
# STAGE_REQUEST_3_OUTPUT
txo = await request_tx_output(tx_req, o)
if output_is_change(txo):
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
if output_is_change(txo, wallet_path):
if change_out != 0:
raise SigningError(FailureType.ProcessError,
'Only one change output is valid')
change_out = txo.amount
else:
if not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled,
'Output cancelled')
txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
elif not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled,
'Output cancelled')
write_tx_output(h_first, txo_bin)
bip143.add_output(txo_bin)
total_out += txo_bin.amount
@ -85,7 +106,8 @@ async def check_tx_fee(tx: SignTx, root):
raise SigningError(FailureType.NotEnoughFunds,
'Not enough funds')
if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000):
tx_size_b = estimate_tx_size(tx.inputs_count, tx.outputs_count)
if fee > coin.maxfee_kb * ((tx_size_b + 999) // 1000):
if not await confirm_feeoverthreshold(fee, coin):
raise SigningError(FailureType.ActionCancelled,
'Signing cancelled')
@ -94,7 +116,7 @@ async def check_tx_fee(tx: SignTx, root):
raise SigningError(FailureType.ActionCancelled,
'Total cancelled')
return h_first, tx_req, txo_bin, bip143, segwit, total_in
return h_first, tx_req, txo_bin, bip143, segwit, total_in, wallet_path
async def sign_tx(tx: SignTx, root):
@ -103,7 +125,8 @@ async def sign_tx(tx: SignTx, root):
# Phase 1
h_first, tx_req, txo_bin, bip143, segwit, authorized_in = await check_tx_fee(tx, root)
h_first, tx_req, txo_bin, bip143, segwit, authorized_in, wallet_path = \
await check_tx_fee(tx, root)
# Phase 2
# - sign inputs
@ -129,32 +152,37 @@ async def sign_tx(tx: SignTx, root):
if segwit[i_sign]:
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await request_tx_input(tx_req, i_sign)
write_tx_input_check(h_second, txi_sign)
if txi_sign.script_type in (InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDWITNESS):
key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(tx, True))
write_tx_input(w_txi, txi_sign)
tx_ser.serialized_tx = w_txi
if (txi_sign.script_type != InputScriptType.SPENDWITNESS and
txi_sign.script_type != InputScriptType.SPENDP2SHWITNESS):
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
input_check_wallet_path(txi_sign, wallet_path)
write_tx_input_check(h_second, txi_sign)
key_sign = node_derive(root, txi_sign.address_n)
key_sign_pub = key_sign.public_key()
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
w_txi = bytearray_with_cap(
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
if i_sign == 0: # serializing first input => prepend headers
write_bytes(w_txi, get_tx_header(tx, True))
write_tx_input(w_txi, txi_sign)
tx_ser.serialized_tx = w_txi
tx_req.serialized = tx_ser
else:
for i in range(tx.inputs_count):
# STAGE_REQUEST_4_INPUT
txi = await request_tx_input(tx_req, i)
input_check_wallet_path(txi, wallet_path)
write_tx_input_check(h_second, txi)
if i == i_sign:
txi_sign = txi
key_sign = node_derive(root, txi.address_n)
key_sign_pub = key_sign.public_key()
# the signature has to be also over the output script to prevent modification
# todo this should fail for p2sh
txi_sign.script_sig = output_script_p2pkh(ecdsa_hash_pubkey(key_sign_pub))
txi_sign.script_sig = output_script_p2pkh(
ecdsa_hash_pubkey(key_sign_pub))
else:
txi.script_sig = bytes()
write_tx_input(h_sign, txi)
@ -219,9 +247,9 @@ async def sign_tx(tx: SignTx, root):
if segwit[i]:
# STAGE_REQUEST_SEGWIT_WITNESS
txi = await request_tx_input(tx_req, i)
input_check_wallet_path(txi, wallet_path)
# Check amount and the control digests
if txi.amount > authorized_in or (get_tx_hash(h_first, False) != get_tx_hash(h_second, False)):
if txi.amount > authorized_in:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
authorized_in -= txi.amount
@ -237,6 +265,8 @@ async def sign_tx(tx: SignTx, root):
tx_ser.signature = signature
tx_ser.serialized_tx = witness
tx_req.serialized = tx_ser
else:
pass # TODO: empty witness
write_uint32(tx_ser.serialized_tx, tx.lock_time)
@ -310,13 +340,13 @@ def get_address(script_type: InputScriptType, coin: CoinType, node) -> bytes:
elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh
if not coin.segwit or not coin.bech32_prefix:
raise SigningError(FailureType.ProcessError,
'Coin does not support segwit')
'Segwit not enabled on this coin')
return address_p2wpkh(node.public_key(), coin.bech32_prefix)
elif script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh using p2sh
if not coin.segwit or not coin.address_type_p2sh:
raise SigningError(FailureType.ProcessError,
'Coin does not support segwit')
'Segwit not enabled on this coin')
return address_p2wpkh_in_p2sh(node.public_key(), coin.address_type_p2sh)
else:
@ -340,9 +370,12 @@ def address_p2wpkh_in_p2sh_raw(pubkey: bytes) -> bytes:
return h
_BECH32_WITVER = const(0x00)
def address_p2wpkh(pubkey: bytes, hrp: str) -> str:
pubkeyhash = ecdsa_hash_pubkey(pubkey)
address = bech32.encode(hrp, 0, pubkeyhash) # TODO: constant?
address = bech32.encode(hrp, _BECH32_WITVER, pubkeyhash)
if address is None:
raise SigningError(FailureType.ProcessError,
'Invalid address')
@ -351,8 +384,8 @@ def address_p2wpkh(pubkey: bytes, hrp: str) -> str:
def decode_bech32_address(prefix: str, address: str) -> bytes:
witver, raw = bech32.decode(prefix, address)
if witver != 0: # TODO: constant?
raise SigningError(FailureType.ProcessError,
if witver != _BECH32_WITVER:
raise SigningError(FailureType.DataError,
'Invalid address witness program')
return bytes(raw)
@ -365,26 +398,23 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
if o.script_type == OutputScriptType.PAYTOOPRETURN:
if o.amount != 0:
raise SigningError(FailureType.ProcessError,
raise SigningError(FailureType.DataError,
'OP_RETURN output with non-zero amount')
return output_script_paytoopreturn(o.op_return_data)
if o.address_n: # change output
if o.address:
raise SigningError(FailureType.ProcessError,
'Both address_n and address provided')
address = get_address_for_change(o, coin, root)
raise SigningError(FailureType.DataError, 'Address in change output')
o.address = get_address_for_change(o, coin, root)
else:
if not o.address:
raise SigningError(FailureType.ProcessError, 'Missing address')
address = o.address
raise SigningError(FailureType.DataError, 'Missing address')
if coin.bech32_prefix and address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh
# todo check if p2wsh works
pubkeyhash = decode_bech32_address(coin.bech32_prefix, address)
return output_script_native_p2wpkh_or_p2wsh(pubkeyhash)
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh
witprog = decode_bech32_address(coin.bech32_prefix, o.address)
return output_script_native_p2wpkh_or_p2wsh(witprog)
raw_address = base58.decode_check(address)
raw_address = base58.decode_check(o.address)
if address_type.check(coin.address_type, raw_address): # p2pkh
pubkeyhash = address_type.strip(coin.address_type, raw_address)
@ -394,11 +424,10 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
return output_script_p2sh(scripthash)
raise SigningError(FailureType.ProcessError, 'Invalid address type')
raise SigningError(FailureType.DataError, 'Invalid address type')
def get_address_for_change(o: TxOutputType, coin: CoinType, root):
if o.script_type == OutputScriptType.PAYTOADDRESS:
input_script_type = InputScriptType.SPENDADDRESS
elif o.script_type == OutputScriptType.PAYTOMULTISIG:
@ -408,12 +437,16 @@ def get_address_for_change(o: TxOutputType, coin: CoinType, root):
elif o.script_type == OutputScriptType.PAYTOP2SHWITNESS:
input_script_type = InputScriptType.SPENDP2SHWITNESS
else:
raise SigningError(FailureType.ProcessError, 'Invalid script type')
raise SigningError(FailureType.DataError, 'Invalid script type')
return get_address(input_script_type, coin, node_derive(root, o.address_n))
def output_is_change(o: TxOutputType) -> bool:
return bool(o.address_n)
def output_is_change(o: TxOutputType, wallet_path: list) -> bool:
address_n = o.address_n
return (address_n is not None and wallet_path is not None
and wallet_path == address_n[:-_BIP32_WALLET_DEPTH]
and address_n[-2] == _BIP32_CHANGE_CHAIN
and address_n[-1] <= _BIP32_MAX_LAST_ELEMENT)
# Tx Inputs
@ -434,6 +467,28 @@ def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) ->
raise SigningError(FailureType.ProcessError, 'Invalid script type')
def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list:
if wallet_path is None:
return None # there was a mismatch in previous inputs
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if not address_n:
return None # input path is too short
if not wallet_path:
return address_n # this is the first input
if wallet_path == address_n:
return address_n # paths match
return None # paths don't match
def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
if wallet_path is None:
return # there was a mismatch in Phase 1, ignore it now
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if wallet_path != address_n:
raise SigningError(FailureType.ProcessError,
'Transaction has changed during signing')
def node_derive(root, address_n: list):
node = root.clone()
node.derive_path(address_n)