diff --git a/src/apps/wallet/sign_tx/multisig.py b/src/apps/wallet/sign_tx/multisig.py index 108b98a72..fb80c2434 100644 --- a/src/apps/wallet/sign_tx/multisig.py +++ b/src/apps/wallet/sign_tx/multisig.py @@ -8,22 +8,42 @@ from apps.wallet.sign_tx.writers import * from apps.common.hash_writer import * +class MultisigFingerprint: + def __init__(self): + self.fingerprint = None # multisig fingerprint bytes + self.mismatch = False # flag if multisig input fingerprints are equal + + def add(self, multisig: MultisigRedeemScriptType): + fp = multisig_fingerprint(multisig) + assert fp is not None + if self.fingerprint is None: + self.fingerprint = fp + elif self.fingerprint != fp: + self.mismatch = True + + def matches(self, multisig: MultisigRedeemScriptType): + fp = multisig_fingerprint(multisig) + assert fp is not None + if self.mismatch is False and self.fingerprint == fp: + return True + else: + return False + + def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: pubkeys = multisig.pubkeys m = multisig.m n = len(pubkeys) - if n < 1 or n > 15: - return None - if m < 1 or m > 15: - return None + if n < 1 or n > 15 or m < 1 or m > 15: + raise SigningError(FailureType.DataError, + 'Invalid multisig parameters') for hd in pubkeys: d = hd.node - if len(d.public_key) != 33: - return None - if len(d.chain_code) != 32: - return None + if len(d.public_key) != 33 or len(d.chain_code) != 32: + raise SigningError(FailureType.DataError, + 'Invalid multisig parameters') # casting to bytes(), sorting on bytearray() is not supported in MicroPython pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key)) diff --git a/src/apps/wallet/sign_tx/signing.py b/src/apps/wallet/sign_tx/signing.py index 81dd6763e..e5a0cd48e 100644 --- a/src/apps/wallet/sign_tx/signing.py +++ b/src/apps/wallet/sign_tx/signing.py @@ -51,7 +51,9 @@ async def check_tx_fee(tx: SignTx, root): # are the same as in Phase 2. it is thus not required to fully hash the # tx, as the SignTx info is streamed only once h_first = HashWriter(sha256) # not a real tx hash - bip143 = Bip143() + + bip143 = Bip143() # bip143 transaction hashing + multifp = MultisigFingerprint() # control fp of multisig inputs weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count) txo_bin = TxOutputBinType() @@ -64,8 +66,6 @@ async def check_tx_fee(tx: SignTx, root): change_out = 0 # change output amount wallet_path = [] # common prefix of input paths segwit = {} # dict of booleans stating if input is segwit - multisig_fp = bytes() # multisig fingerprint - multisig_fp_mismatch = False # flag if multisig input fingerprints are equal for i in range(tx.inputs_count): # STAGE_REQUEST_1_INPUT @@ -77,11 +77,7 @@ async def check_tx_fee(tx: SignTx, root): bip143.add_sequence(txi) if txi.multisig: - fp = multisig_fingerprint(txi.multisig) - if not len(multisig_fp): - multisig_fp = fp - elif multisig_fp != fp: - multisig_fp_mismatch = True + multifp.add(txi.multisig) if coin.force_bip143: is_bip143 = (txi.script_type == InputScriptType.SPENDADDRESS) @@ -95,7 +91,8 @@ async def check_tx_fee(tx: SignTx, root): segwit_in += txi.amount total_in += txi.amount - elif txi.script_type in [InputScriptType.SPENDWITNESS, InputScriptType.SPENDP2SHWITNESS]: + elif txi.script_type in (InputScriptType.SPENDWITNESS, + InputScriptType.SPENDP2SHWITNESS): if not coin.segwit: raise SigningError(FailureType.DataError, 'Segwit not enabled on this coin') @@ -106,7 +103,8 @@ async def check_tx_fee(tx: SignTx, root): segwit_in += txi.amount total_in += txi.amount - elif txi.script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]: + elif txi.script_type in (InputScriptType.SPENDADDRESS, + InputScriptType.SPENDMULTISIG): segwit[i] = False total_in += await get_prevtx_output_value( tx_req, txi.prev_hash, txi.prev_index) @@ -121,7 +119,7 @@ async def check_tx_fee(tx: SignTx, root): txo_bin.amount = txo.amount txo_bin.script_pubkey = output_derive_script(txo, coin, root) weight.add_output(txo_bin.script_pubkey) - if (change_out == 0) and is_change(txo, wallet_path, segwit_in, multisig_fp, multisig_fp_mismatch): + if (change_out == 0) and is_change(txo, wallet_path, segwit_in, multifp): change_out = txo.amount elif not await confirm_output(txo, coin): raise SigningError(FailureType.ActionCancelled, @@ -248,8 +246,9 @@ async def sign_tx(tx: SignTx, root): # for the signing process the script_sig is equal # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) if txi_sign.script_type == InputScriptType.SPENDMULTISIG: - txi_sign.script_sig = script_multisig(multisig_get_pubkeys(txi_sign.multisig), - txi_sign.multisig.m) + txi_sign.script_sig = script_multisig( + multisig_get_pubkeys(txi_sign.multisig), + txi_sign.multisig.m) elif txi_sign.script_type == InputScriptType.SPENDADDRESS: txi_sign.script_sig = output_script_p2pkh( ecdsa_hash_pubkey(key_sign_pub)) @@ -554,8 +553,8 @@ def ecdsa_sign(node, digest: bytes) -> bytes: return sigder -def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multisig_fp: bytes, multisig_fp_mismatch: bool) -> bool: +def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multifp: MultisigFingerprint) -> bool: if txo.multisig: - if multisig_fp_mismatch or (multisig_fp != multisig_fingerprint(txo.multisig)): + if not multifp.matches(txo.multisig): return False return output_is_change(txo, wallet_path, segwit_in)