1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-03 03:50:58 +00:00

core/sign_tx: validate prevout amount in all cases

This commit is contained in:
Andrew Kozlik 2020-05-15 15:22:34 +02:00 committed by Tomas Susanka
parent 7db3e930d4
commit 42eddf8e04
4 changed files with 18 additions and 44 deletions

View File

@ -84,7 +84,6 @@ class Bitcoin:
# amounts # amounts
self.total_in = 0 # sum of input amounts self.total_in = 0 # sum of input amounts
self.bip143_in = 0 # sum of segwit input amounts
self.total_out = 0 # sum of output amounts self.total_out = 0 # sum of output amounts
self.change_out = 0 # change output amount self.change_out = 0 # change output amount
self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count) self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
@ -178,26 +177,15 @@ class Bitcoin:
if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type): if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n) await helpers.confirm_foreign_address(txi.address_n)
if input_is_segwit(txi): if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
await self.process_segwit_input(txi)
elif input_is_nonsegwit(txi):
await self.process_nonsegwit_input(txi)
else:
raise wire.DataError("Wrong input script type") raise wire.DataError("Wrong input script type")
async def process_segwit_input(self, txi: TxInputType) -> None: prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index)
await self.process_bip143_input(txi)
async def process_nonsegwit_input(self, txi: TxInputType) -> None: if txi.amount is not None and prev_amount != txi.amount:
self.total_in += await self.get_prevtx_output_value( raise wire.DataError("Invalid amount specified")
txi.prev_hash, txi.prev_index
)
async def process_bip143_input(self, txi: TxInputType) -> None: self.total_in += prev_amount
if not txi.amount:
raise wire.DataError("Expected input with amount")
self.bip143_in += txi.amount
self.total_in += txi.amount
async def confirm_output(self, txo: TxOutputType, script_pubkey: bytes) -> None: async def confirm_output(self, txo: TxOutputType, script_pubkey: bytes) -> None:
if self.change_out == 0 and self.output_is_change(txo): if self.change_out == 0 and self.output_is_change(txo):
@ -229,13 +217,12 @@ class Bitcoin:
self.write_tx_input(self.serialized_tx, txi, script_sig) self.write_tx_input(self.serialized_tx, txi, script_sig)
def sign_bip143_input(self, txi: TxInputType) -> Tuple[bytes, bytes]: def sign_bip143_input(self, txi: TxInputType) -> Tuple[bytes, bytes]:
if txi.amount is None:
raise wire.DataError("Expected input with amount")
self.wallet_path.check_input(txi) self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi) self.multisig_fingerprint.check_input(txi)
if txi.amount > self.bip143_in:
raise wire.ProcessError("Transaction has changed during signing")
self.bip143_in -= txi.amount
node = self.keychain.derive(txi.address_n) node = self.keychain.derive(txi.address_n)
public_key = node.public_key() public_key = node.public_key()
hash143_hash = self.hash143_preimage_hash( hash143_hash = self.hash143_preimage_hash(

View File

@ -4,7 +4,6 @@ from micropython import const
from trezor import wire from trezor import wire
from trezor.messages.SignTx import SignTx from trezor.messages.SignTx import SignTx
from trezor.messages.TransactionType import TransactionType from trezor.messages.TransactionType import TransactionType
from trezor.messages.TxInputType import TxInputType
from apps.common.writers import write_bitcoin_varint from apps.common.writers import write_bitcoin_varint
@ -19,17 +18,6 @@ _SIGHASH_FORKID = const(0x40)
class Bitcoinlike(Bitcoin): class Bitcoinlike(Bitcoin):
async def process_segwit_input(self, txi: TxInputType) -> None:
if not self.coin.segwit:
raise wire.DataError("Segwit not enabled on this coin")
await super().process_segwit_input(txi)
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
if self.coin.force_bip143:
await self.process_bip143_input(txi)
else:
await super().process_nonsegwit_input(txi)
async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None: async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None:
txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin) txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
@ -83,11 +71,11 @@ class Bitcoinlike(Bitcoin):
await super().write_prev_tx_footer(w, tx, prev_hash) await super().write_prev_tx_footer(w, tx, prev_hash)
if self.coin.extra_data: if self.coin.extra_data:
ofs = 0 offset = 0
while ofs < tx.extra_data_len: while offset < tx.extra_data_len:
size = min(1024, tx.extra_data_len - ofs) size = min(1024, tx.extra_data_len - offset)
data = await helpers.request_tx_extra_data( data = await helpers.request_tx_extra_data(
self.tx_req, ofs, size, prev_hash self.tx_req, offset, size, prev_hash
) )
writers.write_bytes_unchecked(w, data) writers.write_bytes_unchecked(w, data)
ofs += len(data) offset += len(data)

View File

@ -52,10 +52,8 @@ class Overwintered(Bitcoinlike):
self.write_tx_footer(self.serialized_tx, self.tx) self.write_tx_footer(self.serialized_tx, self.tx)
if self.tx.version == 3: if self.tx.version == 3:
write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight
write_bitcoin_varint(self.serialized_tx, 0) # nJoinSplit write_bitcoin_varint(self.serialized_tx, 0) # nJoinSplit
elif self.tx.version == 4: elif self.tx.version == 4:
write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight
write_uint64(self.serialized_tx, 0) # valueBalance write_uint64(self.serialized_tx, 0) # valueBalance
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedOutput write_bitcoin_varint(self.serialized_tx, 0) # nShieldedOutput
@ -65,9 +63,6 @@ class Overwintered(Bitcoinlike):
await helpers.request_tx_finish(self.tx_req) await helpers.request_tx_finish(self.tx_req)
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
await self.process_bip143_input(txi)
async def sign_nonsegwit_input(self, i_sign: int) -> None: async def sign_nonsegwit_input(self, i_sign: int) -> None:
await self.sign_nonsegwit_bip143_input(i_sign) await self.sign_nonsegwit_bip143_input(i_sign)
@ -78,6 +73,10 @@ class Overwintered(Bitcoinlike):
write_uint32(w, tx.version | OVERWINTERED) write_uint32(w, tx.version | OVERWINTERED)
write_uint32(w, tx.version_group_id) # nVersionGroupId write_uint32(w, tx.version_group_id) # nVersionGroupId
def write_tx_footer(self, w: Writer, tx: Union[SignTx, TransactionType]) -> None:
write_uint32(w, tx.lock_time)
write_uint32(w, tx.expiry) # expiryHeight
# ZIP-0143 / ZIP-0243 # ZIP-0143 / ZIP-0243
# === # ===

View File

@ -343,7 +343,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
messages_count = int(len(messages) / 2) messages_count = int(len(messages) / 2)
for request, response in chunks(messages, 2): for request, response in chunks(messages, 2):
if i == messages_count - 1: # last message should throw wire.Error if i == messages_count - 1: # last message should throw wire.Error
self.assertRaises(wire.ProcessError, signer.send, request) self.assertRaises(wire.DataError, signer.send, request)
else: else:
self.assertEqual(signer.send(request), response) self.assertEqual(signer.send(request), response)
i += 1 i += 1