1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-22 22:38:08 +00:00

feat(core): Check script type match for change-outputs in Bitcoin signing.

This commit is contained in:
Andrew Kozlik 2022-09-21 14:40:29 +02:00 committed by matejcik
parent 9528e2f9eb
commit 102ab3c7d6
3 changed files with 39 additions and 5 deletions

View File

@ -43,8 +43,7 @@ class MatchChecker(Generic[T]):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> T:
# Return the attribute from the txio, which is to be used for matching.
# If the txio is invalid for matching, then return an object which
# evaluates as a boolean False.
# If the txio is invalid for matching, then return None.
raise NotImplementedError
def add_input(self, txi: TxInput) -> None:
@ -56,7 +55,7 @@ class MatchChecker(Generic[T]):
return # There was a mismatch in previous inputs.
added_attribute = self.attribute_from_tx(txi)
if not added_attribute:
if added_attribute is None:
self.attribute = self.MISMATCH # The added input is invalid for matching.
elif self.attribute is self.UNDEFINED:
self.attribute = added_attribute # This is the first input.
@ -87,7 +86,7 @@ class WalletPathChecker(MatchChecker):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any:
from ..common import BIP32_WALLET_DEPTH
if len(txio.address_n) < BIP32_WALLET_DEPTH:
if len(txio.address_n) <= BIP32_WALLET_DEPTH:
return None
return txio.address_n[:-BIP32_WALLET_DEPTH]
@ -99,3 +98,24 @@ class MultisigFingerprintChecker(MatchChecker):
if not txio.multisig:
return None
return multisig.multisig_fingerprint(txio.multisig)
class ScriptTypeChecker(MatchChecker):
def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any:
from trezor.enums import InputScriptType
from trezor.messages import TxOutput
from ..common import CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES
if TxOutput.is_type_of(txio):
script_type = CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES[txio.script_type]
else:
script_type = txio.script_type
# SPENDMULTISIG is used only for non-SegWit and is effectively the same as SPENDADDRESS.
# For SegWit inputs and outputs multisig is indicated by the presence of the multisig
# structure. For both SegWit and non-SegWit we can rely on MultisigFingerprintChecker to
# check the multisig structure.
if script_type == InputScriptType.SPENDMULTISIG:
script_type = InputScriptType.SPENDADDRESS
return script_type

View File

@ -58,7 +58,11 @@ class TxInfoBase:
def __init__(self, signer: Signer, tx: SignTx | PrevTx) -> None:
from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter
from .matchcheck import MultisigFingerprintChecker, WalletPathChecker
from .matchcheck import (
MultisigFingerprintChecker,
WalletPathChecker,
ScriptTypeChecker,
)
# Checksum of multisig inputs, used to validate change-output.
self.multisig_fingerprint = MultisigFingerprintChecker()
@ -66,6 +70,9 @@ class TxInfoBase:
# Common prefix of input paths, used to validate change-output.
self.wallet_path = WalletPathChecker()
# Common script type, used to validate change-output.
self.script_type = ScriptTypeChecker()
# h_tx_check is used to make sure that the inputs and outputs streamed in
# different steps are the same every time, e.g. the ones streamed for approval
# in Steps 1 and 2 and the ones streamed for signing legacy inputs in Step 4.
@ -90,6 +97,7 @@ class TxInfoBase:
if not common.input_is_external(txi):
self.wallet_path.add_input(txi)
self.script_type.add_input(txi)
self.multisig_fingerprint.add_input(txi)
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
@ -98,6 +106,7 @@ class TxInfoBase:
def check_input(self, txi: TxInput) -> None:
self.wallet_path.check_input(txi)
self.script_type.check_input(txi)
self.multisig_fingerprint.check_input(txi)
def output_is_change(self, txo: TxOutput) -> bool:
@ -107,6 +116,7 @@ class TxInfoBase:
return False
return (
self.wallet_path.output_matches(txo)
and self.script_type.output_matches(txo)
and len(txo.address_n) >= common.BIP32_WALLET_DEPTH
and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN
and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT

View File

@ -605,6 +605,8 @@ def test_send_multisig_3_change(client: Client):
expected_responses = [
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
(tt, messages.ButtonRequest(code=B.SignTx)),
request_input(0),
@ -690,6 +692,8 @@ def test_send_multisig_4_change(client: Client):
expected_responses = [
request_input(0),
request_output(0),
messages.ButtonRequest(code=B.ConfirmOutput),
(tt, messages.ButtonRequest(code=B.ConfirmOutput)),
messages.ButtonRequest(code=B.SignTx),
(tt, messages.ButtonRequest(code=B.SignTx)),
request_input(0),