1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-23 08:28:46 +00:00

core/sign_tx: Consolidate wallet path and multisig fingerprint checking.

This commit is contained in:
Andrew Kozlik 2020-04-15 17:16:20 +02:00 committed by Andrew Kozlik
parent 27e6720f3d
commit a07e125793
4 changed files with 87 additions and 74 deletions

View File

@ -46,8 +46,8 @@ class Bitcoinlike(signing.Bitcoin):
async def sign_bip143_input(self, i_sign: int) -> None: async def sign_bip143_input(self, i_sign: int) -> None:
# STAGE_REQUEST_SEGWIT_INPUT # STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.input_check_wallet_path(txi_sign) self.wallet_path.check_input(txi_sign)
self.input_check_multisig_fingerprint(txi_sign) self.multisig_fingerprint.check_input(txi_sign)
is_bip143 = ( is_bip143 = (
txi_sign.script_type == InputScriptType.SPENDADDRESS txi_sign.script_type == InputScriptType.SPENDADDRESS

View File

@ -111,8 +111,8 @@ class Decred(Bitcoin):
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.input_check_wallet_path(txi_sign) self.wallet_path.check_input(txi_sign)
self.input_check_multisig_fingerprint(txi_sign) self.multisig_fingerprint.check_input(txi_sign)
key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name) key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name)
key_sign_pub = key_sign.public_key() key_sign_pub = key_sign.public_key()

View File

@ -3,40 +3,18 @@ from trezor.crypto.hashlib import sha256
from trezor.messages import FailureType from trezor.messages import FailureType
from trezor.messages.HDNodeType import HDNodeType from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.utils import HashWriter, ensure from trezor.utils import HashWriter
from apps.wallet.sign_tx.writers import write_bytes_fixed, write_uint32 from apps.wallet.sign_tx.writers import write_bytes_fixed, write_uint32
if False: if False:
from typing import List, Optional from typing import List
class MultisigError(ValueError): class MultisigError(ValueError):
pass pass
class MultisigFingerprint:
def __init__(self) -> None:
self.fingerprint = None # type: Optional[bytes] # multisig fingerprint bytes
self.mismatch = False # flag if multisig input fingerprints are equal
def add(self, multisig: MultisigRedeemScriptType) -> None:
fp = multisig_fingerprint(multisig)
ensure(fp is not None)
if self.fingerprint is None:
self.fingerprint = fp
elif self.fingerprint != fp:
self.mismatch = True
def matches(self, multisig: MultisigRedeemScriptType) -> bool:
fp = multisig_fingerprint(multisig)
ensure(fp is not None)
if self.mismatch is False and self.fingerprint == fp:
return True
else:
return False
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes: def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
if multisig.nodes: if multisig.nodes:
pubnodes = multisig.nodes pubnodes = multisig.nodes

View File

@ -28,7 +28,7 @@ from apps.wallet.sign_tx import (
) )
if False: if False:
from typing import Dict, List, Optional, Union from typing import Dict, Union
# the number of bip32 levels used in a wallet (chain and address) # the number of bip32 levels used in a wallet (chain and address)
_BIP32_WALLET_DEPTH = const(2) _BIP32_WALLET_DEPTH = const(2)
@ -48,6 +48,75 @@ class SigningError(ValueError):
pass pass
class MatchChecker:
"""
MatchCheckers are used to identify the change-output in a transaction. An output is a change-output
if it has certain matching attributes with all inputs.
1. When inputs are first processed, add_input() is called on each one to determine if they all match.
2. Outputs are tested using output_matches() to tell whether they are admissible as a change-output.
3. Before signing each input, check_input() is used to ensure that the attribute has not changed.
"""
MISMATCH = object()
UNDEFINED = object()
def __init__(self) -> None:
self.attribute = self.UNDEFINED # type: object
self.read_only = False # Failsafe to ensure that add_input() is not accidentally called after output_matches().
def attribute_from_tx(self, txio: Union[TxInputType, TxOutputType]) -> object:
# Return the attribute from the txio, which is to be used for matching.
# If the txio is invalid for matching, then return an object which
# evaluates as a boolean False.
raise NotImplementedError
def add_input(self, txi: TxInputType) -> None:
ensure(not self.read_only)
if self.attribute is self.MISMATCH:
return # There was a mismatch in previous inputs.
added_attribute = self.attribute_from_tx(txi)
if not added_attribute:
self.attribute = self.MISMATCH # The added input is invalid for matching.
elif self.attribute is self.UNDEFINED:
self.attribute = added_attribute # This is the first input.
elif self.attribute != added_attribute:
self.attribute = self.MISMATCH
def check_input(self, txi: TxInputType) -> None:
if self.attribute is self.MISMATCH:
return # There was already a mismatch when adding inputs, ignore it now.
# All added inputs had a matching attribute, allowing a change-output.
# Ensure that this input still has the same attribute.
if self.attribute != self.attribute_from_tx(txi):
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def output_matches(self, txo: TxOutputType) -> bool:
self.read_only = True
if self.attribute is self.MISMATCH:
return False
return self.attribute_from_tx(txo) == self.attribute
class WalletPathChecker(MatchChecker):
def attribute_from_tx(self, txio: Union[TxInputType, TxOutputType]) -> object:
if not txio.address_n:
return None
return txio.address_n[:-_BIP32_WALLET_DEPTH]
class MultisigFingerprintChecker(MatchChecker):
def attribute_from_tx(self, txio: Union[TxInputType, TxOutputType]) -> object:
if not txio.multisig:
return None
return multisig.multisig_fingerprint(txio.multisig)
# Transaction signing # Transaction signing
# === # ===
# see https://github.com/trezor/trezor-mcu/blob/master/firmware/signing.c#L84 # see https://github.com/trezor/trezor-mcu/blob/master/firmware/signing.c#L84
@ -93,10 +162,10 @@ class Bitcoin:
self.keychain = keychain self.keychain = keychain
# checksum of multisig inputs, used to validate change-output # checksum of multisig inputs, used to validate change-output
self.multisig_fp = multisig.MultisigFingerprint() self.multisig_fingerprint = MultisigFingerprintChecker()
# common prefix of input paths, used to validate change-output # common prefix of input paths, used to validate change-output
self.wallet_path = [] # type: Optional[List[int]] self.wallet_path = WalletPathChecker()
# dict of booleans stating if input is segwit # dict of booleans stating if input is segwit
self.segwit = {} # type: Dict[int, bool] self.segwit = {} # type: Dict[int, bool]
@ -196,7 +265,8 @@ class Bitcoin:
await helpers.request_tx_finish(self.tx_req) await helpers.request_tx_finish(self.tx_req)
async def process_input(self, i: int, txi: TxInputType) -> None: async def process_input(self, i: int, txi: TxInputType) -> None:
self.input_extract_wallet_path(txi) self.wallet_path.add_input(txi)
self.multisig_fingerprint.add_input(txi)
writers.write_tx_input_check(self.h_confirmed, txi) writers.write_tx_input_check(self.h_confirmed, txi)
self.hash143.add_prevouts(txi) # all inputs are included (non-segwit as well) self.hash143.add_prevouts(txi) # all inputs are included (non-segwit as well)
self.hash143.add_sequence(txi) self.hash143.add_sequence(txi)
@ -204,11 +274,6 @@ class Bitcoin:
if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type): if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n) await helpers.confirm_foreign_address(txi.address_n)
if txi.multisig:
self.multisig_fp.add(txi.multisig)
else:
self.multisig_fp.mismatch = True
if input_is_segwit(txi): if input_is_segwit(txi):
await self.process_segwit_input(i, txi) await self.process_segwit_input(i, txi)
elif input_is_nonsegwit(txi): elif input_is_nonsegwit(txi):
@ -253,7 +318,7 @@ class Bitcoin:
raise SigningError( raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing" FailureType.ProcessError, "Transaction has changed during signing"
) )
self.input_check_wallet_path(txi) self.wallet_path.check_input(txi)
# NOTE: No need to check the multisig fingerprint, because we won't be signing # NOTE: No need to check the multisig fingerprint, because we won't be signing
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS. # the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
@ -267,8 +332,8 @@ class Bitcoin:
# STAGE_REQUEST_SEGWIT_WITNESS # STAGE_REQUEST_SEGWIT_WITNESS
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.input_check_wallet_path(txi) self.wallet_path.check_input(txi)
self.input_check_multisig_fingerprint(txi) self.multisig_fingerprint.check_input(txi)
if not input_is_segwit(txi) or txi.amount > self.bip143_in: if not input_is_segwit(txi) or txi.amount > self.bip143_in:
raise SigningError( raise SigningError(
@ -314,11 +379,11 @@ class Bitcoin:
for i in range(self.tx.inputs_count): for i in range(self.tx.inputs_count):
# STAGE_REQUEST_4_INPUT # STAGE_REQUEST_4_INPUT
txi = await helpers.request_tx_input(self.tx_req, i, self.coin) txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.input_check_wallet_path(txi)
writers.write_tx_input_check(h_check, txi) writers.write_tx_input_check(h_check, txi)
if i == i_sign: if i == i_sign:
txi_sign = txi txi_sign = txi
self.input_check_multisig_fingerprint(txi_sign) self.wallet_path.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign)
node = self.keychain.derive(txi.address_n, self.coin.curve_name) node = self.keychain.derive(txi.address_n, self.coin.curve_name)
key_sign_pub = node.public_key() key_sign_pub = node.public_key()
# for the signing process the script_sig is equal # for the signing process the script_sig is equal
@ -519,11 +584,10 @@ class Bitcoin:
def output_is_change(self, txo: TxOutputType) -> bool: def output_is_change(self, txo: TxOutputType) -> bool:
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES: if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:
return False return False
if txo.multisig and not self.multisig_fp.matches(txo.multisig): if txo.multisig and not self.multisig_fingerprint.output_matches(txo):
return False return False
return ( return (
self.wallet_path is not None self.wallet_path.output_matches(txo)
and self.wallet_path == txo.address_n[:-_BIP32_WALLET_DEPTH]
and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
) )
@ -573,35 +637,6 @@ class Bitcoin:
else: else:
raise SigningError(FailureType.ProcessError, "Invalid script type") raise SigningError(FailureType.ProcessError, "Invalid script type")
def input_extract_wallet_path(self, txi: TxInputType) -> None:
if self.wallet_path is None:
return # there was a mismatch in previous inputs
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if not address_n:
self.wallet_path = None # input path is too short
elif not self.wallet_path:
self.wallet_path = address_n # this is the first input
elif self.wallet_path != address_n:
self.wallet_path = None # paths don't match
def input_check_wallet_path(self, txi: TxInputType) -> None:
if self.wallet_path is None:
return # there was a mismatch in Step 1, ignore it now
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if self.wallet_path != address_n:
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def input_check_multisig_fingerprint(self, txi: TxInputType) -> None:
if self.multisig_fp.mismatch is False:
# All inputs in Step 1 had matching multisig fingerprints, allowing a multisig change-output.
if not txi.multisig or not self.multisig_fp.matches(txi.multisig):
# This input no longer has a matching multisig fingerprint.
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def input_is_segwit(txi: TxInputType) -> bool: def input_is_segwit(txi: TxInputType) -> bool:
return txi.script_type in helpers.SEGWIT_INPUT_SCRIPT_TYPES return txi.script_type in helpers.SEGWIT_INPUT_SCRIPT_TYPES