mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-22 22:38:08 +00:00
wallet/signing: add change output restrictions
This commit is contained in:
parent
1d97077343
commit
b1164077e9
@ -56,12 +56,14 @@ def input_script_native_p2wpkh_or_p2wsh() -> bytearray:
|
||||
return bytearray(0)
|
||||
|
||||
|
||||
# output script consists of 00 14 <20-byte-key-hash>
|
||||
def output_script_native_p2wpkh_or_p2wsh(pubkeyhash: bytes) -> bytearray:
|
||||
w = bytearray_with_cap(3 + len(pubkeyhash))
|
||||
# output script is either:
|
||||
# 00 14 <20-byte-key-hash>
|
||||
# 00 20 <32-byte-script-hash>
|
||||
def output_script_native_p2wpkh_or_p2wsh(witprog: bytes) -> bytearray:
|
||||
w = bytearray_with_cap(3 + len(witprog))
|
||||
w.append(0x00) # witness version byte
|
||||
w.append(len(pubkeyhash)) # pub key hash length is 20 (P2WPKH) or 32 (P2WSH) bytes
|
||||
write_bytes(w, pubkeyhash) # pub key hash
|
||||
w.append(len(witprog)) # pub key hash length is 20 (P2WPKH) or 32 (P2WSH) bytes
|
||||
write_bytes(w, witprog) # pub key hash
|
||||
return w
|
||||
|
||||
|
||||
|
@ -13,6 +13,16 @@ from apps.wallet.sign_tx.segwit_bip143 import *
|
||||
from apps.wallet.sign_tx.helpers import *
|
||||
from apps.wallet.sign_tx.scripts import *
|
||||
|
||||
# the number of bip32 levels used in a wallet (chain and address)
|
||||
_BIP32_WALLET_DEPTH = const(2)
|
||||
|
||||
# the chain id used for change
|
||||
_BIP32_CHANGE_CHAIN = const(1)
|
||||
|
||||
# the maximum allowed change address. this should be large enough for normal
|
||||
# use and still allow to quickly brute-force the correct bip32 path
|
||||
_BIP32_MAX_LAST_ELEMENT = const(1000000)
|
||||
|
||||
|
||||
class SigningError(ValueError):
|
||||
pass
|
||||
@ -24,6 +34,7 @@ class SigningError(ValueError):
|
||||
# for pseudo code overview
|
||||
# ===
|
||||
|
||||
|
||||
# Phase 1
|
||||
# - check inputs, previous transactions, and outputs
|
||||
# - ask for confirmations
|
||||
@ -45,37 +56,47 @@ async def check_tx_fee(tx: SignTx, root):
|
||||
total_in = 0 # sum of input amounts
|
||||
total_out = 0 # sum of output amounts
|
||||
change_out = 0 # change output amount
|
||||
wallet_path = [] # common prefix of input paths
|
||||
segwit = {} # dict of booleans stating if input is segwit
|
||||
|
||||
for i in range(tx.inputs_count):
|
||||
# STAGE_REQUEST_1_INPUT
|
||||
txi = await request_tx_input(tx_req, i)
|
||||
wallet_path = input_extract_wallet_path(txi, wallet_path)
|
||||
write_tx_input_check(h_first, txi)
|
||||
if txi.script_type in (InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDWITNESS):
|
||||
if (txi.script_type == InputScriptType.SPENDWITNESS or
|
||||
txi.script_type == InputScriptType.SPENDP2SHWITNESS):
|
||||
if not coin.segwit:
|
||||
raise SigningError(FailureType.DataError,
|
||||
'Segwit not enabled on this coin')
|
||||
if not txi.amount:
|
||||
raise SigningError(FailureType.DataError,
|
||||
'Segwit input without amount')
|
||||
segwit[i] = True
|
||||
# Add I to segwit hash_prevouts, hash_sequence
|
||||
bip143.add_prevouts(txi)
|
||||
bip143.add_sequence(txi)
|
||||
total_in += txi.amount
|
||||
else:
|
||||
elif txi.script_type == InputScriptType.SPENDADDRESS:
|
||||
segwit[i] = False
|
||||
total_in += await get_prevtx_output_value(
|
||||
tx_req, txi.prev_hash, txi.prev_index)
|
||||
else:
|
||||
raise SigningError(FailureType.DataError,
|
||||
'Wrong input script type')
|
||||
|
||||
for o in range(tx.outputs_count):
|
||||
# STAGE_REQUEST_3_OUTPUT
|
||||
txo = await request_tx_output(tx_req, o)
|
||||
if output_is_change(txo):
|
||||
txo_bin.amount = txo.amount
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
|
||||
if output_is_change(txo, wallet_path):
|
||||
if change_out != 0:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Only one change output is valid')
|
||||
change_out = txo.amount
|
||||
else:
|
||||
if not await confirm_output(txo, coin):
|
||||
raise SigningError(FailureType.ActionCancelled,
|
||||
'Output cancelled')
|
||||
txo_bin.amount = txo.amount
|
||||
txo_bin.script_pubkey = output_derive_script(txo, coin, root)
|
||||
elif not await confirm_output(txo, coin):
|
||||
raise SigningError(FailureType.ActionCancelled,
|
||||
'Output cancelled')
|
||||
write_tx_output(h_first, txo_bin)
|
||||
bip143.add_output(txo_bin)
|
||||
total_out += txo_bin.amount
|
||||
@ -85,7 +106,8 @@ async def check_tx_fee(tx: SignTx, root):
|
||||
raise SigningError(FailureType.NotEnoughFunds,
|
||||
'Not enough funds')
|
||||
|
||||
if fee > coin.maxfee_kb * ((estimate_tx_size(tx.inputs_count, tx.outputs_count) + 999) // 1000):
|
||||
tx_size_b = estimate_tx_size(tx.inputs_count, tx.outputs_count)
|
||||
if fee > coin.maxfee_kb * ((tx_size_b + 999) // 1000):
|
||||
if not await confirm_feeoverthreshold(fee, coin):
|
||||
raise SigningError(FailureType.ActionCancelled,
|
||||
'Signing cancelled')
|
||||
@ -94,7 +116,7 @@ async def check_tx_fee(tx: SignTx, root):
|
||||
raise SigningError(FailureType.ActionCancelled,
|
||||
'Total cancelled')
|
||||
|
||||
return h_first, tx_req, txo_bin, bip143, segwit, total_in
|
||||
return h_first, tx_req, txo_bin, bip143, segwit, total_in, wallet_path
|
||||
|
||||
|
||||
async def sign_tx(tx: SignTx, root):
|
||||
@ -103,7 +125,8 @@ async def sign_tx(tx: SignTx, root):
|
||||
|
||||
# Phase 1
|
||||
|
||||
h_first, tx_req, txo_bin, bip143, segwit, authorized_in = await check_tx_fee(tx, root)
|
||||
h_first, tx_req, txo_bin, bip143, segwit, authorized_in, wallet_path = \
|
||||
await check_tx_fee(tx, root)
|
||||
|
||||
# Phase 2
|
||||
# - sign inputs
|
||||
@ -129,32 +152,37 @@ async def sign_tx(tx: SignTx, root):
|
||||
if segwit[i_sign]:
|
||||
# STAGE_REQUEST_SEGWIT_INPUT
|
||||
txi_sign = await request_tx_input(tx_req, i_sign)
|
||||
write_tx_input_check(h_second, txi_sign)
|
||||
if txi_sign.script_type in (InputScriptType.SPENDP2SHWITNESS, InputScriptType.SPENDWITNESS):
|
||||
key_sign = node_derive(root, txi_sign.address_n)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
|
||||
w_txi = bytearray_with_cap(
|
||||
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
|
||||
if i_sign == 0: # serializing first input => prepend headers
|
||||
write_bytes(w_txi, get_tx_header(tx, True))
|
||||
write_tx_input(w_txi, txi_sign)
|
||||
tx_ser.serialized_tx = w_txi
|
||||
|
||||
if (txi_sign.script_type != InputScriptType.SPENDWITNESS and
|
||||
txi_sign.script_type != InputScriptType.SPENDP2SHWITNESS):
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Transaction has changed during signing')
|
||||
input_check_wallet_path(txi_sign, wallet_path)
|
||||
write_tx_input_check(h_second, txi_sign)
|
||||
|
||||
key_sign = node_derive(root, txi_sign.address_n)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
txi_sign.script_sig = input_derive_script(txi_sign, key_sign_pub)
|
||||
w_txi = bytearray_with_cap(
|
||||
7 + len(txi_sign.prev_hash) + 4 + len(txi_sign.script_sig) + 4)
|
||||
if i_sign == 0: # serializing first input => prepend headers
|
||||
write_bytes(w_txi, get_tx_header(tx, True))
|
||||
write_tx_input(w_txi, txi_sign)
|
||||
tx_ser.serialized_tx = w_txi
|
||||
tx_req.serialized = tx_ser
|
||||
|
||||
else:
|
||||
for i in range(tx.inputs_count):
|
||||
# STAGE_REQUEST_4_INPUT
|
||||
txi = await request_tx_input(tx_req, i)
|
||||
input_check_wallet_path(txi, wallet_path)
|
||||
write_tx_input_check(h_second, txi)
|
||||
if i == i_sign:
|
||||
txi_sign = txi
|
||||
key_sign = node_derive(root, txi.address_n)
|
||||
key_sign_pub = key_sign.public_key()
|
||||
# the signature has to be also over the output script to prevent modification
|
||||
# todo this should fail for p2sh
|
||||
txi_sign.script_sig = output_script_p2pkh(ecdsa_hash_pubkey(key_sign_pub))
|
||||
txi_sign.script_sig = output_script_p2pkh(
|
||||
ecdsa_hash_pubkey(key_sign_pub))
|
||||
else:
|
||||
txi.script_sig = bytes()
|
||||
write_tx_input(h_sign, txi)
|
||||
@ -219,9 +247,9 @@ async def sign_tx(tx: SignTx, root):
|
||||
if segwit[i]:
|
||||
# STAGE_REQUEST_SEGWIT_WITNESS
|
||||
txi = await request_tx_input(tx_req, i)
|
||||
input_check_wallet_path(txi, wallet_path)
|
||||
|
||||
# Check amount and the control digests
|
||||
if txi.amount > authorized_in or (get_tx_hash(h_first, False) != get_tx_hash(h_second, False)):
|
||||
if txi.amount > authorized_in:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Transaction has changed during signing')
|
||||
authorized_in -= txi.amount
|
||||
@ -237,6 +265,8 @@ async def sign_tx(tx: SignTx, root):
|
||||
tx_ser.signature = signature
|
||||
tx_ser.serialized_tx = witness
|
||||
tx_req.serialized = tx_ser
|
||||
else:
|
||||
pass # TODO: empty witness
|
||||
|
||||
write_uint32(tx_ser.serialized_tx, tx.lock_time)
|
||||
|
||||
@ -310,13 +340,13 @@ def get_address(script_type: InputScriptType, coin: CoinType, node) -> bytes:
|
||||
elif script_type == InputScriptType.SPENDWITNESS: # native p2wpkh
|
||||
if not coin.segwit or not coin.bech32_prefix:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Coin does not support segwit')
|
||||
'Segwit not enabled on this coin')
|
||||
return address_p2wpkh(node.public_key(), coin.bech32_prefix)
|
||||
|
||||
elif script_type == InputScriptType.SPENDP2SHWITNESS: # p2wpkh using p2sh
|
||||
if not coin.segwit or not coin.address_type_p2sh:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Coin does not support segwit')
|
||||
'Segwit not enabled on this coin')
|
||||
return address_p2wpkh_in_p2sh(node.public_key(), coin.address_type_p2sh)
|
||||
|
||||
else:
|
||||
@ -340,9 +370,12 @@ def address_p2wpkh_in_p2sh_raw(pubkey: bytes) -> bytes:
|
||||
return h
|
||||
|
||||
|
||||
_BECH32_WITVER = const(0x00)
|
||||
|
||||
|
||||
def address_p2wpkh(pubkey: bytes, hrp: str) -> str:
|
||||
pubkeyhash = ecdsa_hash_pubkey(pubkey)
|
||||
address = bech32.encode(hrp, 0, pubkeyhash) # TODO: constant?
|
||||
address = bech32.encode(hrp, _BECH32_WITVER, pubkeyhash)
|
||||
if address is None:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Invalid address')
|
||||
@ -351,8 +384,8 @@ def address_p2wpkh(pubkey: bytes, hrp: str) -> str:
|
||||
|
||||
def decode_bech32_address(prefix: str, address: str) -> bytes:
|
||||
witver, raw = bech32.decode(prefix, address)
|
||||
if witver != 0: # TODO: constant?
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
if witver != _BECH32_WITVER:
|
||||
raise SigningError(FailureType.DataError,
|
||||
'Invalid address witness program')
|
||||
return bytes(raw)
|
||||
|
||||
@ -365,26 +398,23 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
|
||||
|
||||
if o.script_type == OutputScriptType.PAYTOOPRETURN:
|
||||
if o.amount != 0:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
raise SigningError(FailureType.DataError,
|
||||
'OP_RETURN output with non-zero amount')
|
||||
return output_script_paytoopreturn(o.op_return_data)
|
||||
|
||||
if o.address_n: # change output
|
||||
if o.address:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Both address_n and address provided')
|
||||
address = get_address_for_change(o, coin, root)
|
||||
raise SigningError(FailureType.DataError, 'Address in change output')
|
||||
o.address = get_address_for_change(o, coin, root)
|
||||
else:
|
||||
if not o.address:
|
||||
raise SigningError(FailureType.ProcessError, 'Missing address')
|
||||
address = o.address
|
||||
raise SigningError(FailureType.DataError, 'Missing address')
|
||||
|
||||
if coin.bech32_prefix and address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh
|
||||
# todo check if p2wsh works
|
||||
pubkeyhash = decode_bech32_address(coin.bech32_prefix, address)
|
||||
return output_script_native_p2wpkh_or_p2wsh(pubkeyhash)
|
||||
if coin.bech32_prefix and o.address.startswith(coin.bech32_prefix): # p2wpkh or p2wsh
|
||||
witprog = decode_bech32_address(coin.bech32_prefix, o.address)
|
||||
return output_script_native_p2wpkh_or_p2wsh(witprog)
|
||||
|
||||
raw_address = base58.decode_check(address)
|
||||
raw_address = base58.decode_check(o.address)
|
||||
|
||||
if address_type.check(coin.address_type, raw_address): # p2pkh
|
||||
pubkeyhash = address_type.strip(coin.address_type, raw_address)
|
||||
@ -394,11 +424,10 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
|
||||
scripthash = address_type.strip(coin.address_type_p2sh, raw_address)
|
||||
return output_script_p2sh(scripthash)
|
||||
|
||||
raise SigningError(FailureType.ProcessError, 'Invalid address type')
|
||||
raise SigningError(FailureType.DataError, 'Invalid address type')
|
||||
|
||||
|
||||
def get_address_for_change(o: TxOutputType, coin: CoinType, root):
|
||||
|
||||
if o.script_type == OutputScriptType.PAYTOADDRESS:
|
||||
input_script_type = InputScriptType.SPENDADDRESS
|
||||
elif o.script_type == OutputScriptType.PAYTOMULTISIG:
|
||||
@ -408,12 +437,16 @@ def get_address_for_change(o: TxOutputType, coin: CoinType, root):
|
||||
elif o.script_type == OutputScriptType.PAYTOP2SHWITNESS:
|
||||
input_script_type = InputScriptType.SPENDP2SHWITNESS
|
||||
else:
|
||||
raise SigningError(FailureType.ProcessError, 'Invalid script type')
|
||||
raise SigningError(FailureType.DataError, 'Invalid script type')
|
||||
return get_address(input_script_type, coin, node_derive(root, o.address_n))
|
||||
|
||||
|
||||
def output_is_change(o: TxOutputType) -> bool:
|
||||
return bool(o.address_n)
|
||||
def output_is_change(o: TxOutputType, wallet_path: list) -> bool:
|
||||
address_n = o.address_n
|
||||
return (address_n is not None and wallet_path is not None
|
||||
and wallet_path == address_n[:-_BIP32_WALLET_DEPTH]
|
||||
and address_n[-2] == _BIP32_CHANGE_CHAIN
|
||||
and address_n[-1] <= _BIP32_MAX_LAST_ELEMENT)
|
||||
|
||||
|
||||
# Tx Inputs
|
||||
@ -434,6 +467,28 @@ def input_derive_script(i: TxInputType, pubkey: bytes, signature: bytes=None) ->
|
||||
raise SigningError(FailureType.ProcessError, 'Invalid script type')
|
||||
|
||||
|
||||
def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list:
|
||||
if wallet_path is None:
|
||||
return None # there was a mismatch in previous inputs
|
||||
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
|
||||
if not address_n:
|
||||
return None # input path is too short
|
||||
if not wallet_path:
|
||||
return address_n # this is the first input
|
||||
if wallet_path == address_n:
|
||||
return address_n # paths match
|
||||
return None # paths don't match
|
||||
|
||||
|
||||
def input_check_wallet_path(txi: TxInputType, wallet_path: list) -> list:
|
||||
if wallet_path is None:
|
||||
return # there was a mismatch in Phase 1, ignore it now
|
||||
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
|
||||
if wallet_path != address_n:
|
||||
raise SigningError(FailureType.ProcessError,
|
||||
'Transaction has changed during signing')
|
||||
|
||||
|
||||
def node_derive(root, address_n: list):
|
||||
node = root.clone()
|
||||
node.derive_path(address_n)
|
||||
|
Loading…
Reference in New Issue
Block a user