1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-18 10:32:02 +00:00

signing/multisig: minor refactoring

This commit is contained in:
Jan Pochyla 2018-02-23 19:55:28 +01:00
parent 1b64088957
commit a46934459a
2 changed files with 42 additions and 23 deletions

View File

@ -8,22 +8,42 @@ from apps.wallet.sign_tx.writers import *
from apps.common.hash_writer import * from apps.common.hash_writer import *
class MultisigFingerprint:
def __init__(self):
self.fingerprint = None # multisig fingerprint bytes
self.mismatch = False # flag if multisig input fingerprints are equal
def add(self, multisig: MultisigRedeemScriptType):
fp = multisig_fingerprint(multisig)
assert fp is not None
if self.fingerprint is None:
self.fingerprint = fp
elif self.fingerprint != fp:
self.mismatch = True
def matches(self, multisig: MultisigRedeemScriptType):
fp = multisig_fingerprint(multisig)
assert fp is not None
if self.mismatch is False and self.fingerprint == fp:
return True
else:
return False
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
pubkeys = multisig.pubkeys pubkeys = multisig.pubkeys
m = multisig.m m = multisig.m
n = len(pubkeys) n = len(pubkeys)
if n < 1 or n > 15: if n < 1 or n > 15 or m < 1 or m > 15:
return None raise SigningError(FailureType.DataError,
if m < 1 or m > 15: 'Invalid multisig parameters')
return None
for hd in pubkeys: for hd in pubkeys:
d = hd.node d = hd.node
if len(d.public_key) != 33: if len(d.public_key) != 33 or len(d.chain_code) != 32:
return None raise SigningError(FailureType.DataError,
if len(d.chain_code) != 32: 'Invalid multisig parameters')
return None
# casting to bytes(), sorting on bytearray() is not supported in MicroPython # casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key)) pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key))

View File

@ -51,7 +51,9 @@ async def check_tx_fee(tx: SignTx, root):
# are the same as in Phase 2. it is thus not required to fully hash the # are the same as in Phase 2. it is thus not required to fully hash the
# tx, as the SignTx info is streamed only once # tx, as the SignTx info is streamed only once
h_first = HashWriter(sha256) # not a real tx hash h_first = HashWriter(sha256) # not a real tx hash
bip143 = Bip143()
bip143 = Bip143() # bip143 transaction hashing
multifp = MultisigFingerprint() # control fp of multisig inputs
weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count) weight = TxWeightCalculator(tx.inputs_count, tx.outputs_count)
txo_bin = TxOutputBinType() txo_bin = TxOutputBinType()
@ -64,8 +66,6 @@ async def check_tx_fee(tx: SignTx, root):
change_out = 0 # change output amount change_out = 0 # change output amount
wallet_path = [] # common prefix of input paths wallet_path = [] # common prefix of input paths
segwit = {} # dict of booleans stating if input is segwit segwit = {} # dict of booleans stating if input is segwit
multisig_fp = bytes() # multisig fingerprint
multisig_fp_mismatch = False # flag if multisig input fingerprints are equal
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT # STAGE_REQUEST_1_INPUT
@ -77,11 +77,7 @@ async def check_tx_fee(tx: SignTx, root):
bip143.add_sequence(txi) bip143.add_sequence(txi)
if txi.multisig: if txi.multisig:
fp = multisig_fingerprint(txi.multisig) multifp.add(txi.multisig)
if not len(multisig_fp):
multisig_fp = fp
elif multisig_fp != fp:
multisig_fp_mismatch = True
if coin.force_bip143: if coin.force_bip143:
is_bip143 = (txi.script_type == InputScriptType.SPENDADDRESS) is_bip143 = (txi.script_type == InputScriptType.SPENDADDRESS)
@ -95,7 +91,8 @@ async def check_tx_fee(tx: SignTx, root):
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
elif txi.script_type in [InputScriptType.SPENDWITNESS, InputScriptType.SPENDP2SHWITNESS]: elif txi.script_type in (InputScriptType.SPENDWITNESS,
InputScriptType.SPENDP2SHWITNESS):
if not coin.segwit: if not coin.segwit:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'Segwit not enabled on this coin') 'Segwit not enabled on this coin')
@ -106,7 +103,8 @@ async def check_tx_fee(tx: SignTx, root):
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
elif txi.script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]: elif txi.script_type in (InputScriptType.SPENDADDRESS,
InputScriptType.SPENDMULTISIG):
segwit[i] = False segwit[i] = False
total_in += await get_prevtx_output_value( total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index) tx_req, txi.prev_hash, txi.prev_index)
@ -121,7 +119,7 @@ async def check_tx_fee(tx: SignTx, root):
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
weight.add_output(txo_bin.script_pubkey) weight.add_output(txo_bin.script_pubkey)
if (change_out == 0) and is_change(txo, wallet_path, segwit_in, multisig_fp, multisig_fp_mismatch): if (change_out == 0) and is_change(txo, wallet_path, segwit_in, multifp):
change_out = txo.amount change_out = txo.amount
elif not await confirm_output(txo, coin): elif not await confirm_output(txo, coin):
raise SigningError(FailureType.ActionCancelled, raise SigningError(FailureType.ActionCancelled,
@ -248,8 +246,9 @@ async def sign_tx(tx: SignTx, root):
# for the signing process the script_sig is equal # for the signing process the script_sig is equal
# to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH) # to the previous tx's scriptPubKey (P2PKH) or a redeem script (P2SH)
if txi_sign.script_type == InputScriptType.SPENDMULTISIG: if txi_sign.script_type == InputScriptType.SPENDMULTISIG:
txi_sign.script_sig = script_multisig(multisig_get_pubkeys(txi_sign.multisig), txi_sign.script_sig = script_multisig(
txi_sign.multisig.m) multisig_get_pubkeys(txi_sign.multisig),
txi_sign.multisig.m)
elif txi_sign.script_type == InputScriptType.SPENDADDRESS: elif txi_sign.script_type == InputScriptType.SPENDADDRESS:
txi_sign.script_sig = output_script_p2pkh( txi_sign.script_sig = output_script_p2pkh(
ecdsa_hash_pubkey(key_sign_pub)) ecdsa_hash_pubkey(key_sign_pub))
@ -554,8 +553,8 @@ def ecdsa_sign(node, digest: bytes) -> bytes:
return sigder return sigder
def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multisig_fp: bytes, multisig_fp_mismatch: bool) -> bool: def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multifp: MultisigFingerprint) -> bool:
if txo.multisig: if txo.multisig:
if multisig_fp_mismatch or (multisig_fp != multisig_fingerprint(txo.multisig)): if not multifp.matches(txo.multisig):
return False return False
return output_is_change(txo, wallet_path, segwit_in) return output_is_change(txo, wallet_path, segwit_in)