1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-22 05:10:56 +00:00

signing/multisig: change check using multisig fingerprint

This commit is contained in:
Tomas Susanka 2018-02-08 13:14:36 +01:00 committed by Jan Pochyla
parent 26a89a8e5f
commit 985c01caf4
2 changed files with 57 additions and 8 deletions

View File

@ -1,9 +1,45 @@
from trezor.crypto.hashlib import sha256
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.messages.HDNodePathType import HDNodePathType from trezor.messages.HDNodePathType import HDNodePathType
from apps.wallet.sign_tx.writers import * from apps.wallet.sign_tx.writers import *
from apps.common.hash_writer import *
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
pubkeys = multisig.pubkeys
m = multisig.m
n = len(pubkeys)
if n < 1 or n > 15:
return None
if m < 1 or m > 15:
return None
for hd in pubkeys:
d = hd.node
if len(d.public_key) != 33:
return None
if len(d.chain_code) != 32:
return None
# casting to bytes(), sorting on bytearray() is not supported in MicroPython
pubkeys = sorted(pubkeys, key=lambda hd: bytes(hd.node.public_key))
h = HashWriter(sha256)
write_uint32(h, m)
write_uint32(h, n)
for hd in pubkeys:
d = hd.node
write_uint32(h, d.depth)
write_uint32(h, d.fingerprint)
write_uint32(h, d.child_num)
write_bytes(h, d.chain_code)
write_bytes(h, d.public_key)
return h.get_digest()
def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int: def multisig_pubkey_index(multisig: MultisigRedeemScriptType, pubkey: bytes) -> int:

View File

@ -64,6 +64,8 @@ async def check_tx_fee(tx: SignTx, root):
change_out = 0 # change output amount change_out = 0 # change output amount
wallet_path = [] # common prefix of input paths wallet_path = [] # common prefix of input paths
segwit = {} # dict of booleans stating if input is segwit segwit = {} # dict of booleans stating if input is segwit
multisig_fp = bytes() # multisig fingerprint
multisig_fp_mismatch = False # flag if multisig input fingerprints are equal
for i in range(tx.inputs_count): for i in range(tx.inputs_count):
# STAGE_REQUEST_1_INPUT # STAGE_REQUEST_1_INPUT
@ -73,8 +75,7 @@ async def check_tx_fee(tx: SignTx, root):
weight.add_input(txi) weight.add_input(txi)
bip143.add_prevouts(txi) # all inputs are included (non-segwit as well) bip143.add_prevouts(txi) # all inputs are included (non-segwit as well)
bip143.add_sequence(txi) bip143.add_sequence(txi)
is_segwit = (txi.script_type == InputScriptType.SPENDWITNESS or
txi.script_type == InputScriptType.SPENDP2SHWITNESS)
if coin.force_bip143: if coin.force_bip143:
is_bip143 = (txi.script_type == InputScriptType.SPENDADDRESS) is_bip143 = (txi.script_type == InputScriptType.SPENDADDRESS)
if not is_bip143: if not is_bip143:
@ -86,7 +87,8 @@ async def check_tx_fee(tx: SignTx, root):
segwit[i] = False segwit[i] = False
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
elif is_segwit:
elif txi.script_type in [InputScriptType.SPENDWITNESS, InputScriptType.SPENDP2SHWITNESS]:
if not coin.segwit: if not coin.segwit:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'Segwit not enabled on this coin') 'Segwit not enabled on this coin')
@ -96,10 +98,18 @@ async def check_tx_fee(tx: SignTx, root):
segwit[i] = True segwit[i] = True
segwit_in += txi.amount segwit_in += txi.amount
total_in += txi.amount total_in += txi.amount
elif txi.script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]: elif txi.script_type in [InputScriptType.SPENDADDRESS, InputScriptType.SPENDMULTISIG]:
segwit[i] = False segwit[i] = False
total_in += await get_prevtx_output_value( total_in += await get_prevtx_output_value(
tx_req, txi.prev_hash, txi.prev_index) tx_req, txi.prev_hash, txi.prev_index)
if txi.script_type == InputScriptType.SPENDMULTISIG:
fp = multisig_fingerprint(txi.multisig)
if not len(multisig_fp):
multisig_fp = fp
elif multisig_fp != fp:
multisig_fp_mismatch = True
else: else:
raise SigningError(FailureType.DataError, raise SigningError(FailureType.DataError,
'Wrong input script type') 'Wrong input script type')
@ -110,7 +120,7 @@ async def check_tx_fee(tx: SignTx, root):
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
txo_bin.script_pubkey = output_derive_script(txo, coin, root) txo_bin.script_pubkey = output_derive_script(txo, coin, root)
weight.add_output(txo_bin.script_pubkey) weight.add_output(txo_bin.script_pubkey)
if output_is_change(txo, wallet_path, segwit_in): if is_change(txo, wallet_path, segwit_in, multisig_fp, multisig_fp_mismatch):
if change_out != 0: if change_out != 0:
raise SigningError(FailureType.ProcessError, raise SigningError(FailureType.ProcessError,
'Only one change output is valid') 'Only one change output is valid')
@ -438,10 +448,6 @@ def output_derive_script(o: TxOutputType, coin: CoinType, root) -> bytes:
if o.address_n: # change output if o.address_n: # change output
if o.address: if o.address:
raise SigningError(FailureType.DataError, 'Address in change output') raise SigningError(FailureType.DataError, 'Address in change output')
if o.multisig:
if not check_address_n_against_pubkeys(o.multisig, o.address_n):
raise AddressError(FailureType.ProcessError,
'address_n must match one of the address_n in the MultisigRedeemScriptType pubkeys')
o.address = get_address_for_change(o, coin, root) o.address = get_address_for_change(o, coin, root)
else: else:
if not o.address: if not o.address:
@ -543,3 +549,10 @@ def ecdsa_sign(node, digest: bytes) -> bytes:
sig = secp256k1.sign(node.private_key(), digest) sig = secp256k1.sign(node.private_key(), digest)
sigder = der.encode_seq((sig[1:33], sig[33:65])) sigder = der.encode_seq((sig[1:33], sig[33:65]))
return sigder return sigder
def is_change(txo: TxOutputType, wallet_path, segwit_in: int, multisig_fp: bytes, multisig_fp_mismatch: bool) -> bool:
if txo.multisig:
if multisig_fp_mismatch or (multisig_fp != multisig_fingerprint(txo.multisig)):
return False
return output_is_change(txo, wallet_path, segwit_in)