diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index dfd05a24a..6c955498a 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -137,21 +137,10 @@ def sanitize_tx_binoutput(tx: TransactionType) -> TxOutputBinType: # Transaction signing # === +async def check_tx_fee(tx: SignTx, root, segwit=False): -async def sign_tx(tx: SignTx, root): - - tx = sanitize_sign_tx(tx) coin = coins.by_name(tx.coin_name) - # Phase 1 - # - check inputs, previous transactions, and outputs - # - ask for confirmations - # - check fee - - total_in = 0 # sum of input amounts - total_out = 0 # sum of output amounts - change_out = 0 # change output amount - # h_first is used to make sure the inputs and outputs streamed in Phase 1 # are the same as in Phase 2. it is thus not required to fully hash the # tx, as the SignTx info is streamed only once @@ -161,12 +150,19 @@ async def sign_tx(tx: SignTx, root): tx_req = TxRequest() tx_req.details = TxRequestDetailsType() + total_in = 0 # sum of input amounts + total_out = 0 # sum of output amounts + change_out = 0 # change output amount + for i in range(tx.inputs_count): # STAGE_REQUEST_1_INPUT txi = await request_tx_input(tx_req, i) write_tx_input_check(h_first, txi) - total_in += await get_prevtx_output_value( - tx_req, txi.prev_hash, txi.prev_index) + if segwit: + total_in += txi.amount + else: + total_in += await get_prevtx_output_value( + tx_req, txi.prev_hash, txi.prev_index) for o in range(tx.outputs_count): # STAGE_REQUEST_3_OUTPUT @@ -186,7 +182,6 @@ async def sign_tx(tx: SignTx, root): total_out += txo_bin.amount fee = total_in - total_out - if fee < 0: raise SigningError(FailureType.NotEnoughFunds, 'Not enough funds') @@ -200,10 +195,25 @@ async def sign_tx(tx: SignTx, root): raise SigningError(FailureType.ActionCancelled, 'Total cancelled') + return h_first, tx_req, txo_bin + + +async def sign_tx(tx: SignTx, root): + + tx = sanitize_sign_tx(tx) + + # Phase 1 + # - check inputs, previous transactions, and outputs + # - ask for confirmations + # - check fee + + h_first, tx_req, txo_bin = await check_tx_fee(tx, root) + # Phase 2 # - sign inputs # - check that nothing changed + coin = coins.by_name(tx.coin_name) tx_ser = TxRequestSerializedType() for i_sign in range(tx.inputs_count):