From 42eddf8e04c059832e0205925fd60f0943f3884b Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Fri, 15 May 2020 15:22:34 +0200 Subject: [PATCH] core/sign_tx: validate prevout amount in all cases --- core/src/apps/bitcoin/sign_tx/bitcoin.py | 29 +++++-------------- core/src/apps/bitcoin/sign_tx/bitcoinlike.py | 22 ++++---------- core/src/apps/bitcoin/sign_tx/zcash.py | 9 +++--- ...ps.bitcoin.segwit.signtx.p2wpkh_in_p2sh.py | 2 +- 4 files changed, 18 insertions(+), 44 deletions(-) diff --git a/core/src/apps/bitcoin/sign_tx/bitcoin.py b/core/src/apps/bitcoin/sign_tx/bitcoin.py index efa6f0c29..240096a1d 100644 --- a/core/src/apps/bitcoin/sign_tx/bitcoin.py +++ b/core/src/apps/bitcoin/sign_tx/bitcoin.py @@ -84,7 +84,6 @@ class Bitcoin: # amounts self.total_in = 0 # sum of input amounts - self.bip143_in = 0 # sum of segwit input amounts self.total_out = 0 # sum of output amounts self.change_out = 0 # change output amount self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count) @@ -178,26 +177,15 @@ class Bitcoin: if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type): await helpers.confirm_foreign_address(txi.address_n) - if input_is_segwit(txi): - await self.process_segwit_input(txi) - elif input_is_nonsegwit(txi): - await self.process_nonsegwit_input(txi) - else: + if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES: raise wire.DataError("Wrong input script type") - async def process_segwit_input(self, txi: TxInputType) -> None: - await self.process_bip143_input(txi) + prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index) - async def process_nonsegwit_input(self, txi: TxInputType) -> None: - self.total_in += await self.get_prevtx_output_value( - txi.prev_hash, txi.prev_index - ) + if txi.amount is not None and prev_amount != txi.amount: + raise wire.DataError("Invalid amount specified") - async def process_bip143_input(self, txi: TxInputType) -> None: - if not txi.amount: - raise wire.DataError("Expected input with amount") - self.bip143_in += txi.amount - self.total_in += txi.amount + self.total_in += prev_amount async def confirm_output(self, txo: TxOutputType, script_pubkey: bytes) -> None: if self.change_out == 0 and self.output_is_change(txo): @@ -229,13 +217,12 @@ class Bitcoin: self.write_tx_input(self.serialized_tx, txi, script_sig) def sign_bip143_input(self, txi: TxInputType) -> Tuple[bytes, bytes]: + if txi.amount is None: + raise wire.DataError("Expected input with amount") + self.wallet_path.check_input(txi) self.multisig_fingerprint.check_input(txi) - if txi.amount > self.bip143_in: - raise wire.ProcessError("Transaction has changed during signing") - self.bip143_in -= txi.amount - node = self.keychain.derive(txi.address_n) public_key = node.public_key() hash143_hash = self.hash143_preimage_hash( diff --git a/core/src/apps/bitcoin/sign_tx/bitcoinlike.py b/core/src/apps/bitcoin/sign_tx/bitcoinlike.py index a37b7509d..7b57c265a 100644 --- a/core/src/apps/bitcoin/sign_tx/bitcoinlike.py +++ b/core/src/apps/bitcoin/sign_tx/bitcoinlike.py @@ -4,7 +4,6 @@ from micropython import const from trezor import wire from trezor.messages.SignTx import SignTx from trezor.messages.TransactionType import TransactionType -from trezor.messages.TxInputType import TxInputType from apps.common.writers import write_bitcoin_varint @@ -19,17 +18,6 @@ _SIGHASH_FORKID = const(0x40) class Bitcoinlike(Bitcoin): - async def process_segwit_input(self, txi: TxInputType) -> None: - if not self.coin.segwit: - raise wire.DataError("Segwit not enabled on this coin") - await super().process_segwit_input(txi) - - async def process_nonsegwit_input(self, txi: TxInputType) -> None: - if self.coin.force_bip143: - await self.process_bip143_input(txi) - else: - await super().process_nonsegwit_input(txi) - async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None: txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) @@ -83,11 +71,11 @@ class Bitcoinlike(Bitcoin): await super().write_prev_tx_footer(w, tx, prev_hash) if self.coin.extra_data: - ofs = 0 - while ofs < tx.extra_data_len: - size = min(1024, tx.extra_data_len - ofs) + offset = 0 + while offset < tx.extra_data_len: + size = min(1024, tx.extra_data_len - offset) data = await helpers.request_tx_extra_data( - self.tx_req, ofs, size, prev_hash + self.tx_req, offset, size, prev_hash ) writers.write_bytes_unchecked(w, data) - ofs += len(data) + offset += len(data) diff --git a/core/src/apps/bitcoin/sign_tx/zcash.py b/core/src/apps/bitcoin/sign_tx/zcash.py index 748379128..63eecd7dc 100644 --- a/core/src/apps/bitcoin/sign_tx/zcash.py +++ b/core/src/apps/bitcoin/sign_tx/zcash.py @@ -52,10 +52,8 @@ class Overwintered(Bitcoinlike): self.write_tx_footer(self.serialized_tx, self.tx) if self.tx.version == 3: - write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight write_bitcoin_varint(self.serialized_tx, 0) # nJoinSplit elif self.tx.version == 4: - write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight write_uint64(self.serialized_tx, 0) # valueBalance write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend write_bitcoin_varint(self.serialized_tx, 0) # nShieldedOutput @@ -65,9 +63,6 @@ class Overwintered(Bitcoinlike): await helpers.request_tx_finish(self.tx_req) - async def process_nonsegwit_input(self, txi: TxInputType) -> None: - await self.process_bip143_input(txi) - async def sign_nonsegwit_input(self, i_sign: int) -> None: await self.sign_nonsegwit_bip143_input(i_sign) @@ -78,6 +73,10 @@ class Overwintered(Bitcoinlike): write_uint32(w, tx.version | OVERWINTERED) write_uint32(w, tx.version_group_id) # nVersionGroupId + def write_tx_footer(self, w: Writer, tx: Union[SignTx, TransactionType]) -> None: + write_uint32(w, tx.lock_time) + write_uint32(w, tx.expiry) # expiryHeight + # ZIP-0143 / ZIP-0243 # === diff --git a/core/tests/test_apps.bitcoin.segwit.signtx.p2wpkh_in_p2sh.py b/core/tests/test_apps.bitcoin.segwit.signtx.p2wpkh_in_p2sh.py index 7f9881672..e3d0de2d1 100644 --- a/core/tests/test_apps.bitcoin.segwit.signtx.p2wpkh_in_p2sh.py +++ b/core/tests/test_apps.bitcoin.segwit.signtx.p2wpkh_in_p2sh.py @@ -343,7 +343,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase): messages_count = int(len(messages) / 2) for request, response in chunks(messages, 2): if i == messages_count - 1: # last message should throw wire.Error - self.assertRaises(wire.ProcessError, signer.send, request) + self.assertRaises(wire.DataError, signer.send, request) else: self.assertEqual(signer.send(request), response) i += 1