1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-18 11:21:11 +00:00

wallet/signing: fee checking seperated

This commit is contained in:
Tomas Susanka 2017-10-20 13:09:05 +02:00
parent 679d024df0
commit 039f6bad43

View File

@ -137,21 +137,10 @@ def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType:
# Transaction signing
# ===
async def check_tx_fee(tx: SignTx, root, segwit=False):
async def sign_tx(tx: SignTx, root):
tx = sanitize_sign_tx(tx)
coin = coins.by_name(tx.coin_name)
# Phase 1
# - check inputs, previous transactions, and outputs
# - ask for confirmations
# - check fee
total_in = 0 # sum of input amounts
total_out = 0 # sum of output amounts
change_out = 0 # change output amount
# h_first is used to make sure the inputs and outputs streamed in Phase 1
# are the same as in Phase 2. it is thus not required to fully hash the
# tx, as the SignTx info is streamed only once
@ -161,10 +150,17 @@ async def sign_tx(tx: SignTx, root):
tx_req = TxRequest()
tx_req.details = TxRequestDetailsType()
total_in = 0 # sum of input amounts
total_out = 0 # sum of output amounts
change_out = 0 # change output amount
for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT
txi = await request_tx_input(tx_req, i)
write_tx_input_check(h_first, txi)
if segwit:
total_in += txi.amount
else:
total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index)
@ -186,7 +182,6 @@ async def sign_tx(tx: SignTx, root):
total_out += txo_bin.amount
fee = total_in - total_out
if fee < 0:
raise SigningError(FailureType.NotEnoughFunds,
'Not enough funds')
@ -200,10 +195,25 @@ async def sign_tx(tx: SignTx, root):
raise SigningError(FailureType.ActionCancelled,
'Total cancelled')
return h_first, tx_req, txo_bin
async def sign_tx(tx: SignTx, root):
tx = sanitize_sign_tx(tx)
# Phase 1
# - check inputs, previous transactions, and outputs
# - ask for confirmations
# - check fee
h_first, tx_req, txo_bin = await check_tx_fee(tx, root)
# Phase 2
# - sign inputs
# - check that nothing changed
coin = coins.by_name(tx.coin_name)
tx_ser = TxRequestSerializedType()
for i_sign in range(tx.inputs_count):