diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index 199e13461..32d829edb 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -161,7 +161,7 @@ async def check_tx_fee(tx: SignTx, root: bip32.HDNode): txo_bin.script_pubkey = output_derive_script(txo, coin, root) weight.add_output(txo_bin.script_pubkey) - if change_out == 0 and is_change(txo, wallet_path, segwit_in, multifp): + if change_out == 0 and output_is_change(txo, wallet_path, segwit_in, multifp): # output is change and does not need confirmation change_out = txo.amount elif not await helpers.confirm_output(txo, coin): @@ -241,11 +241,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): # STAGE_REQUEST_SEGWIT_INPUT txi_sign = await helpers.request_tx_input(tx_req, i_sign) - is_segwit = ( - txi_sign.script_type == InputScriptType.SPENDWITNESS - or txi_sign.script_type == InputScriptType.SPENDP2SHWITNESS - ) - if not is_segwit: + if not input_is_segwit(txi_sign): raise SigningError( FailureType.ProcessError, "Transaction has changed during signing" ) @@ -289,7 +285,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): get_hash_type(coin), ) - # if multisig, check if singing with a key that is included in multisig + # if multisig, check if signing with a key that is included in multisig if txi_sign.multisig: multisig.multisig_pubkey_index(txi_sign.multisig, key_sign_pub) @@ -452,7 +448,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): FailureType.ProcessError, "Transaction has changed during signing" ) - # if multisig, check if singing with a key that is included in multisig + # if multisig, check if signing with a key that is included in multisig if txi_sign.multisig: multisig.multisig_pubkey_index(txi_sign.multisig, key_sign_pub) @@ -508,11 +504,7 @@ async def sign_tx(tx: SignTx, root: bip32.HDNode): txi = await helpers.request_tx_input(tx_req, i) input_check_wallet_path(txi, wallet_path) - is_segwit = ( - txi.script_type == InputScriptType.SPENDWITNESS - or txi.script_type == InputScriptType.SPENDP2SHWITNESS - ) - if not is_segwit or txi.amount > authorized_in: + if not input_is_segwit(txi) or txi.amount > authorized_in: raise SigningError( FailureType.ProcessError, "Transaction has changed during signing" ) @@ -764,12 +756,15 @@ def get_address_for_change( ) -def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool: - is_segwit = ( - o.script_type == OutputScriptType.PAYTOWITNESS - or o.script_type == OutputScriptType.PAYTOP2SHWITNESS - ) - if is_segwit and o.amount > segwit_in: +def output_is_change( + o: TxOutputType, + wallet_path: list, + segwit_in: int, + multifp: multisig.MultisigFingerprint, +) -> bool: + if o.multisig and not multifp.matches(o.multisig): + return False + if output_is_segwit(o) and o.amount > segwit_in: # if the output is segwit, make sure it doesn't spend more than what the # segwit inputs paid. this is to prevent user being tricked into # creating ANYONECANSPEND outputs before full segwit activation. @@ -782,6 +777,13 @@ def output_is_change(o: TxOutputType, wallet_path: list, segwit_in: int) -> bool ) +def output_is_segwit(o: TxOutputType) -> bool: + return ( + o.script_type == OutputScriptType.PAYTOWITNESS + or o.script_type == OutputScriptType.PAYTOP2SHWITNESS + ) + + # Tx Inputs # === @@ -825,6 +827,13 @@ def input_derive_script( raise SigningError(FailureType.ProcessError, "Invalid script type") +def input_is_segwit(i: TxInputType) -> bool: + return ( + i.script_type == InputScriptType.SPENDWITNESS + or i.script_type == InputScriptType.SPENDP2SHWITNESS + ) + + def input_extract_wallet_path(txi: TxInputType, wallet_path: list) -> list: if wallet_path is None: return None # there was a mismatch in previous inputs @@ -858,15 +867,3 @@ def ecdsa_sign(node: bip32.HDNode, digest: bytes) -> bytes: sig = secp256k1.sign(node.private_key(), digest) sigder = der.encode_seq((sig[1:33], sig[33:65])) return sigder - - -def is_change( - txo: TxOutputType, - wallet_path: list, - segwit_in: int, - multifp: multisig.MultisigFingerprint, -) -> bool: - if txo.multisig: - if not multifp.matches(txo.multisig): - return False - return output_is_change(txo, wallet_path, segwit_in)