core/sign_tx: validate prevout amount in all cases

release/2020-06
Andrew Kozlik 4 years ago committed by Tomas Susanka
parent 7db3e930d4
commit 42eddf8e04

@ -84,7 +84,6 @@ class Bitcoin:
# amounts
self.total_in = 0 # sum of input amounts
self.bip143_in = 0 # sum of segwit input amounts
self.total_out = 0 # sum of output amounts
self.change_out = 0 # change output amount
self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
@ -178,26 +177,15 @@ class Bitcoin:
if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type):
await helpers.confirm_foreign_address(txi.address_n)
if input_is_segwit(txi):
await self.process_segwit_input(txi)
elif input_is_nonsegwit(txi):
await self.process_nonsegwit_input(txi)
else:
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Wrong input script type")
async def process_segwit_input(self, txi: TxInputType) -> None:
await self.process_bip143_input(txi)
prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index)
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
self.total_in += await self.get_prevtx_output_value(
txi.prev_hash, txi.prev_index
)
if txi.amount is not None and prev_amount != txi.amount:
raise wire.DataError("Invalid amount specified")
async def process_bip143_input(self, txi: TxInputType) -> None:
if not txi.amount:
raise wire.DataError("Expected input with amount")
self.bip143_in += txi.amount
self.total_in += txi.amount
self.total_in += prev_amount
async def confirm_output(self, txo: TxOutputType, script_pubkey: bytes) -> None:
if self.change_out == 0 and self.output_is_change(txo):
@ -229,13 +217,12 @@ class Bitcoin:
self.write_tx_input(self.serialized_tx, txi, script_sig)
def sign_bip143_input(self, txi: TxInputType) -> Tuple[bytes, bytes]:
if txi.amount is None:
raise wire.DataError("Expected input with amount")
self.wallet_path.check_input(txi)
self.multisig_fingerprint.check_input(txi)
if txi.amount > self.bip143_in:
raise wire.ProcessError("Transaction has changed during signing")
self.bip143_in -= txi.amount
node = self.keychain.derive(txi.address_n)
public_key = node.public_key()
hash143_hash = self.hash143_preimage_hash(

@ -4,7 +4,6 @@ from micropython import const
from trezor import wire
from trezor.messages.SignTx import SignTx
from trezor.messages.TransactionType import TransactionType
from trezor.messages.TxInputType import TxInputType
from apps.common.writers import write_bitcoin_varint
@ -19,17 +18,6 @@ _SIGHASH_FORKID = const(0x40)
class Bitcoinlike(Bitcoin):
async def process_segwit_input(self, txi: TxInputType) -> None:
if not self.coin.segwit:
raise wire.DataError("Segwit not enabled on this coin")
await super().process_segwit_input(txi)
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
if self.coin.force_bip143:
await self.process_bip143_input(txi)
else:
await super().process_nonsegwit_input(txi)
async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None:
txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
@ -83,11 +71,11 @@ class Bitcoinlike(Bitcoin):
await super().write_prev_tx_footer(w, tx, prev_hash)
if self.coin.extra_data:
ofs = 0
while ofs < tx.extra_data_len:
size = min(1024, tx.extra_data_len - ofs)
offset = 0
while offset < tx.extra_data_len:
size = min(1024, tx.extra_data_len - offset)
data = await helpers.request_tx_extra_data(
self.tx_req, ofs, size, prev_hash
self.tx_req, offset, size, prev_hash
)
writers.write_bytes_unchecked(w, data)
ofs += len(data)
offset += len(data)

@ -52,10 +52,8 @@ class Overwintered(Bitcoinlike):
self.write_tx_footer(self.serialized_tx, self.tx)
if self.tx.version == 3:
write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight
write_bitcoin_varint(self.serialized_tx, 0) # nJoinSplit
elif self.tx.version == 4:
write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight
write_uint64(self.serialized_tx, 0) # valueBalance
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedOutput
@ -65,9 +63,6 @@ class Overwintered(Bitcoinlike):
await helpers.request_tx_finish(self.tx_req)
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
await self.process_bip143_input(txi)
async def sign_nonsegwit_input(self, i_sign: int) -> None:
await self.sign_nonsegwit_bip143_input(i_sign)
@ -78,6 +73,10 @@ class Overwintered(Bitcoinlike):
write_uint32(w, tx.version | OVERWINTERED)
write_uint32(w, tx.version_group_id) # nVersionGroupId
def write_tx_footer(self, w: Writer, tx: Union[SignTx, TransactionType]) -> None:
write_uint32(w, tx.lock_time)
write_uint32(w, tx.expiry) # expiryHeight
# ZIP-0143 / ZIP-0243
# ===

@ -343,7 +343,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
messages_count = int(len(messages) / 2)
for request, response in chunks(messages, 2):
if i == messages_count - 1: # last message should throw wire.Error
self.assertRaises(wire.ProcessError, signer.send, request)
self.assertRaises(wire.DataError, signer.send, request)
else:
self.assertEqual(signer.send(request), response)
i += 1

Loading…
Cancel
Save