From 102ab3c7d60712400a354804761646c41621e230 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Wed, 21 Sep 2022 14:40:29 +0200 Subject: [PATCH] feat(core): Check script type match for change-outputs in Bitcoin signing. --- core/src/apps/bitcoin/sign_tx/matchcheck.py | 28 ++++++++++++++++--- core/src/apps/bitcoin/sign_tx/tx_info.py | 12 +++++++- .../bitcoin/test_signtx_segwit_native.py | 4 +++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/core/src/apps/bitcoin/sign_tx/matchcheck.py b/core/src/apps/bitcoin/sign_tx/matchcheck.py index d0b6ebf60c..8e97dc553e 100644 --- a/core/src/apps/bitcoin/sign_tx/matchcheck.py +++ b/core/src/apps/bitcoin/sign_tx/matchcheck.py @@ -43,8 +43,7 @@ class MatchChecker(Generic[T]): def attribute_from_tx(self, txio: TxInput | TxOutput) -> T: # Return the attribute from the txio, which is to be used for matching. - # If the txio is invalid for matching, then return an object which - # evaluates as a boolean False. + # If the txio is invalid for matching, then return None. raise NotImplementedError def add_input(self, txi: TxInput) -> None: @@ -56,7 +55,7 @@ class MatchChecker(Generic[T]): return # There was a mismatch in previous inputs. added_attribute = self.attribute_from_tx(txi) - if not added_attribute: + if added_attribute is None: self.attribute = self.MISMATCH # The added input is invalid for matching. elif self.attribute is self.UNDEFINED: self.attribute = added_attribute # This is the first input. @@ -87,7 +86,7 @@ class WalletPathChecker(MatchChecker): def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any: from ..common import BIP32_WALLET_DEPTH - if len(txio.address_n) < BIP32_WALLET_DEPTH: + if len(txio.address_n) <= BIP32_WALLET_DEPTH: return None return txio.address_n[:-BIP32_WALLET_DEPTH] @@ -99,3 +98,24 @@ class MultisigFingerprintChecker(MatchChecker): if not txio.multisig: return None return multisig.multisig_fingerprint(txio.multisig) + + +class ScriptTypeChecker(MatchChecker): + def attribute_from_tx(self, txio: TxInput | TxOutput) -> Any: + from trezor.enums import InputScriptType + from trezor.messages import TxOutput + from ..common import CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES + + if TxOutput.is_type_of(txio): + script_type = CHANGE_OUTPUT_TO_INPUT_SCRIPT_TYPES[txio.script_type] + else: + script_type = txio.script_type + + # SPENDMULTISIG is used only for non-SegWit and is effectively the same as SPENDADDRESS. + # For SegWit inputs and outputs multisig is indicated by the presence of the multisig + # structure. For both SegWit and non-SegWit we can rely on MultisigFingerprintChecker to + # check the multisig structure. + if script_type == InputScriptType.SPENDMULTISIG: + script_type = InputScriptType.SPENDADDRESS + + return script_type diff --git a/core/src/apps/bitcoin/sign_tx/tx_info.py b/core/src/apps/bitcoin/sign_tx/tx_info.py index 65cf2cd3e2..b13ee715bf 100644 --- a/core/src/apps/bitcoin/sign_tx/tx_info.py +++ b/core/src/apps/bitcoin/sign_tx/tx_info.py @@ -58,7 +58,11 @@ class TxInfoBase: def __init__(self, signer: Signer, tx: SignTx | PrevTx) -> None: from trezor.crypto.hashlib import sha256 from trezor.utils import HashWriter - from .matchcheck import MultisigFingerprintChecker, WalletPathChecker + from .matchcheck import ( + MultisigFingerprintChecker, + WalletPathChecker, + ScriptTypeChecker, + ) # Checksum of multisig inputs, used to validate change-output. self.multisig_fingerprint = MultisigFingerprintChecker() @@ -66,6 +70,9 @@ class TxInfoBase: # Common prefix of input paths, used to validate change-output. self.wallet_path = WalletPathChecker() + # Common script type, used to validate change-output. + self.script_type = ScriptTypeChecker() + # h_tx_check is used to make sure that the inputs and outputs streamed in # different steps are the same every time, e.g. the ones streamed for approval # in Steps 1 and 2 and the ones streamed for signing legacy inputs in Step 4. @@ -90,6 +97,7 @@ class TxInfoBase: if not common.input_is_external(txi): self.wallet_path.add_input(txi) + self.script_type.add_input(txi) self.multisig_fingerprint.add_input(txi) def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None: @@ -98,6 +106,7 @@ class TxInfoBase: def check_input(self, txi: TxInput) -> None: self.wallet_path.check_input(txi) + self.script_type.check_input(txi) self.multisig_fingerprint.check_input(txi) def output_is_change(self, txo: TxOutput) -> bool: @@ -107,6 +116,7 @@ class TxInfoBase: return False return ( self.wallet_path.output_matches(txo) + and self.script_type.output_matches(txo) and len(txo.address_n) >= common.BIP32_WALLET_DEPTH and txo.address_n[-2] <= _BIP32_CHANGE_CHAIN and txo.address_n[-1] <= _BIP32_MAX_LAST_ELEMENT diff --git a/tests/device_tests/bitcoin/test_signtx_segwit_native.py b/tests/device_tests/bitcoin/test_signtx_segwit_native.py index e6404a5e6c..0a3e994e65 100644 --- a/tests/device_tests/bitcoin/test_signtx_segwit_native.py +++ b/tests/device_tests/bitcoin/test_signtx_segwit_native.py @@ -605,6 +605,8 @@ def test_send_multisig_3_change(client: Client): expected_responses = [ request_input(0), request_output(0), + messages.ButtonRequest(code=B.ConfirmOutput), + (tt, messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), (tt, messages.ButtonRequest(code=B.SignTx)), request_input(0), @@ -690,6 +692,8 @@ def test_send_multisig_4_change(client: Client): expected_responses = [ request_input(0), request_output(0), + messages.ButtonRequest(code=B.ConfirmOutput), + (tt, messages.ButtonRequest(code=B.ConfirmOutput)), messages.ButtonRequest(code=B.SignTx), (tt, messages.ButtonRequest(code=B.SignTx)), request_input(0),