core/sign_tx: Consolidate wallet path and multisig fingerprint checking.

pull/985/head
Andrew Kozlik 4 years ago committed by Andrew Kozlik
parent 27e6720f3d
commit a07e125793

@ -46,8 +46,8 @@ class Bitcoinlike(signing.Bitcoin):
async def sign_bip143_input(self, i_sign: int) -> None:
# STAGE_REQUEST_SEGWIT_INPUT
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.input_check_wallet_path(txi_sign)
self.input_check_multisig_fingerprint(txi_sign)
self.wallet_path.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign)
is_bip143 = (
txi_sign.script_type == InputScriptType.SPENDADDRESS

@ -111,8 +111,8 @@ class Decred(Bitcoin):
txi_sign = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
self.input_check_wallet_path(txi_sign)
self.input_check_multisig_fingerprint(txi_sign)
self.wallet_path.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign)
key_sign = self.keychain.derive(txi_sign.address_n, self.coin.curve_name)
key_sign_pub = key_sign.public_key()

@ -3,40 +3,18 @@ from trezor.crypto.hashlib import sha256
from trezor.messages import FailureType
from trezor.messages.HDNodeType import HDNodeType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.utils import HashWriter, ensure
from trezor.utils import HashWriter
from apps.wallet.sign_tx.writers import write_bytes_fixed, write_uint32
if False:
from typing import List, Optional
from typing import List
class MultisigError(ValueError):
pass
class MultisigFingerprint:
def __init__(self) -> None:
self.fingerprint = None # type: Optional[bytes] # multisig fingerprint bytes
self.mismatch = False # flag if multisig input fingerprints are equal
def add(self, multisig: MultisigRedeemScriptType) -> None:
fp = multisig_fingerprint(multisig)
ensure(fp is not None)
if self.fingerprint is None:
self.fingerprint = fp
elif self.fingerprint != fp:
self.mismatch = True
def matches(self, multisig: MultisigRedeemScriptType) -> bool:
fp = multisig_fingerprint(multisig)
ensure(fp is not None)
if self.mismatch is False and self.fingerprint == fp:
return True
else:
return False
def multisig_fingerprint(multisig: MultisigRedeemScriptType) -> bytes:
if multisig.nodes:
pubnodes = multisig.nodes

@ -28,7 +28,7 @@ from apps.wallet.sign_tx import (
)
if False:
from typing import Dict, List, Optional, Union
from typing import Dict, Union
# the number of bip32 levels used in a wallet (chain and address)
_BIP32_WALLET_DEPTH = const(2)
@ -48,6 +48,75 @@ class SigningError(ValueError):
pass
class MatchChecker:
"""
MatchCheckers are used to identify the change-output in a transaction. An output is a change-output
if it has certain matching attributes with all inputs.
1. When inputs are first processed, add_input() is called on each one to determine if they all match.
2. Outputs are tested using output_matches() to tell whether they are admissible as a change-output.
3. Before signing each input, check_input() is used to ensure that the attribute has not changed.
"""
MISMATCH = object()
UNDEFINED = object()
def __init__(self) -> None:
self.attribute = self.UNDEFINED # type: object
self.read_only = False # Failsafe to ensure that add_input() is not accidentally called after output_matches().
def attribute_from_tx(self, txio: Union[TxInputType, TxOutputType]) -> object:
# Return the attribute from the txio, which is to be used for matching.
# If the txio is invalid for matching, then return an object which
# evaluates as a boolean False.
raise NotImplementedError
def add_input(self, txi: TxInputType) -> None:
ensure(not self.read_only)
if self.attribute is self.MISMATCH:
return # There was a mismatch in previous inputs.
added_attribute = self.attribute_from_tx(txi)
if not added_attribute:
self.attribute = self.MISMATCH # The added input is invalid for matching.
elif self.attribute is self.UNDEFINED:
self.attribute = added_attribute # This is the first input.
elif self.attribute != added_attribute:
self.attribute = self.MISMATCH
def check_input(self, txi: TxInputType) -> None:
if self.attribute is self.MISMATCH:
return # There was already a mismatch when adding inputs, ignore it now.
# All added inputs had a matching attribute, allowing a change-output.
# Ensure that this input still has the same attribute.
if self.attribute != self.attribute_from_tx(txi):
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def output_matches(self, txo: TxOutputType) -> bool:
self.read_only = True
if self.attribute is self.MISMATCH:
return False
return self.attribute_from_tx(txo) == self.attribute
class WalletPathChecker(MatchChecker):
def attribute_from_tx(self, txio: Union[TxInputType, TxOutputType]) -> object:
if not txio.address_n:
return None
return txio.address_n[:-_BIP32_WALLET_DEPTH]
class MultisigFingerprintChecker(MatchChecker):
def attribute_from_tx(self, txio: Union[TxInputType, TxOutputType]) -> object:
if not txio.multisig:
return None
return multisig.multisig_fingerprint(txio.multisig)
# Transaction signing
# ===
# see https://github.com/trezor/trezor-mcu/blob/master/firmware/signing.c#L84
@ -93,10 +162,10 @@ class Bitcoin:
self.keychain = keychain
# checksum of multisig inputs, used to validate change-output
self.multisig_fp = multisig.MultisigFingerprint()
self.multisig_fingerprint = MultisigFingerprintChecker()
# common prefix of input paths, used to validate change-output
self.wallet_path = [] # type: Optional[List[int]]
self.wallet_path = WalletPathChecker()
# dict of booleans stating if input is segwit
self.segwit = {} # type: Dict[int, bool]
@ -196,7 +265,8 @@ class Bitcoin:
await helpers.request_tx_finish(self.tx_req)
async def process_input(self, i: int, txi: TxInputType) -> None:
self.input_extract_wallet_path(txi)
self.wallet_path.add_input(txi)
self.multisig_fingerprint.add_input(txi)
writers.write_tx_input_check(self.h_confirmed, txi)
self.hash143.add_prevouts(txi) # all inputs are included (non-segwit as well)
self.hash143.add_sequence(txi)
@ -204,11 +274,6 @@ class Bitcoin:
if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n)
if txi.multisig:
self.multisig_fp.add(txi.multisig)
else:
self.multisig_fp.mismatch = True
if input_is_segwit(txi):
await self.process_segwit_input(i, txi)
elif input_is_nonsegwit(txi):
@ -253,7 +318,7 @@ class Bitcoin:
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
self.input_check_wallet_path(txi)
self.wallet_path.check_input(txi)
# NOTE: No need to check the multisig fingerprint, because we won't be signing
# the script here. Signatures are produced in STAGE_REQUEST_SEGWIT_WITNESS.
@ -267,8 +332,8 @@ class Bitcoin:
# STAGE_REQUEST_SEGWIT_WITNESS
txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.input_check_wallet_path(txi)
self.input_check_multisig_fingerprint(txi)
self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi)
if not input_is_segwit(txi) or txi.amount > self.bip143_in:
raise SigningError(
@ -314,11 +379,11 @@ class Bitcoin:
for i in range(self.tx.inputs_count):
# STAGE_REQUEST_4_INPUT
txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.input_check_wallet_path(txi)
writers.write_tx_input_check(h_check, txi)
if i == i_sign:
txi_sign = txi
self.input_check_multisig_fingerprint(txi_sign)
self.wallet_path.check_input(txi_sign)
self.multisig_fingerprint.check_input(txi_sign)
node = self.keychain.derive(txi.address_n, self.coin.curve_name)
key_sign_pub = node.public_key()
# for the signing process the script_sig is equal
@ -519,11 +584,10 @@ class Bitcoin:
def output_is_change(self, txo: TxOutputType) -> bool:
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:
return False
if txo.multisig and not self.multisig_fp.matches(txo.multisig):
if txo.multisig and not self.multisig_fingerprint.output_matches(txo):
return False
return (
self.wallet_path is not None
and self.wallet_path == txo.address_n[:-_BIP32_WALLET_DEPTH]
self.wallet_path.output_matches(txo)
and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT
)
@ -573,35 +637,6 @@ class Bitcoin:
else:
raise SigningError(FailureType.ProcessError, "Invalid script type")
def input_extract_wallet_path(self, txi: TxInputType) -> None:
if self.wallet_path is None:
return # there was a mismatch in previous inputs
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if not address_n:
self.wallet_path = None # input path is too short
elif not self.wallet_path:
self.wallet_path = address_n # this is the first input
elif self.wallet_path != address_n:
self.wallet_path = None # paths don't match
def input_check_wallet_path(self, txi: TxInputType) -> None:
if self.wallet_path is None:
return # there was a mismatch in Step 1, ignore it now
address_n = txi.address_n[:-_BIP32_WALLET_DEPTH]
if self.wallet_path != address_n:
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def input_check_multisig_fingerprint(self, txi: TxInputType) -> None:
if self.multisig_fp.mismatch is False:
# All inputs in Step 1 had matching multisig fingerprints, allowing a multisig change-output.
if not txi.multisig or not self.multisig_fp.matches(txi.multisig):
# This input no longer has a matching multisig fingerprint.
raise SigningError(
FailureType.ProcessError, "Transaction has changed during signing"
)
def input_is_segwit(txi: TxInputType) -> bool:
return txi.script_type in helpers.SEGWIT_INPUT_SCRIPT_TYPES

Loading…
Cancel
Save