mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-05 04:50:57 +00:00
core/sign_tx: validate prevout amount in all cases
This commit is contained in:
parent
7db3e930d4
commit
42eddf8e04
@ -84,7 +84,6 @@ class Bitcoin:
|
|||||||
|
|
||||||
# amounts
|
# amounts
|
||||||
self.total_in = 0 # sum of input amounts
|
self.total_in = 0 # sum of input amounts
|
||||||
self.bip143_in = 0 # sum of segwit input amounts
|
|
||||||
self.total_out = 0 # sum of output amounts
|
self.total_out = 0 # sum of output amounts
|
||||||
self.change_out = 0 # change output amount
|
self.change_out = 0 # change output amount
|
||||||
self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
|
self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
|
||||||
@ -178,26 +177,15 @@ class Bitcoin:
|
|||||||
if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type):
|
if not addresses.validate_full_path(txi.address_n, self.coin, txi.script_type):
|
||||||
await helpers.confirm_foreign_address(txi.address_n)
|
await helpers.confirm_foreign_address(txi.address_n)
|
||||||
|
|
||||||
if input_is_segwit(txi):
|
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
|
||||||
await self.process_segwit_input(txi)
|
|
||||||
elif input_is_nonsegwit(txi):
|
|
||||||
await self.process_nonsegwit_input(txi)
|
|
||||||
else:
|
|
||||||
raise wire.DataError("Wrong input script type")
|
raise wire.DataError("Wrong input script type")
|
||||||
|
|
||||||
async def process_segwit_input(self, txi: TxInputType) -> None:
|
prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index)
|
||||||
await self.process_bip143_input(txi)
|
|
||||||
|
|
||||||
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
|
if txi.amount is not None and prev_amount != txi.amount:
|
||||||
self.total_in += await self.get_prevtx_output_value(
|
raise wire.DataError("Invalid amount specified")
|
||||||
txi.prev_hash, txi.prev_index
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_bip143_input(self, txi: TxInputType) -> None:
|
self.total_in += prev_amount
|
||||||
if not txi.amount:
|
|
||||||
raise wire.DataError("Expected input with amount")
|
|
||||||
self.bip143_in += txi.amount
|
|
||||||
self.total_in += txi.amount
|
|
||||||
|
|
||||||
async def confirm_output(self, txo: TxOutputType, script_pubkey: bytes) -> None:
|
async def confirm_output(self, txo: TxOutputType, script_pubkey: bytes) -> None:
|
||||||
if self.change_out == 0 and self.output_is_change(txo):
|
if self.change_out == 0 and self.output_is_change(txo):
|
||||||
@ -229,13 +217,12 @@ class Bitcoin:
|
|||||||
self.write_tx_input(self.serialized_tx, txi, script_sig)
|
self.write_tx_input(self.serialized_tx, txi, script_sig)
|
||||||
|
|
||||||
def sign_bip143_input(self, txi: TxInputType) -> Tuple[bytes, bytes]:
|
def sign_bip143_input(self, txi: TxInputType) -> Tuple[bytes, bytes]:
|
||||||
|
if txi.amount is None:
|
||||||
|
raise wire.DataError("Expected input with amount")
|
||||||
|
|
||||||
self.wallet_path.check_input(txi)
|
self.wallet_path.check_input(txi)
|
||||||
self.multisig_fingerprint.check_input(txi)
|
self.multisig_fingerprint.check_input(txi)
|
||||||
|
|
||||||
if txi.amount > self.bip143_in:
|
|
||||||
raise wire.ProcessError("Transaction has changed during signing")
|
|
||||||
self.bip143_in -= txi.amount
|
|
||||||
|
|
||||||
node = self.keychain.derive(txi.address_n)
|
node = self.keychain.derive(txi.address_n)
|
||||||
public_key = node.public_key()
|
public_key = node.public_key()
|
||||||
hash143_hash = self.hash143_preimage_hash(
|
hash143_hash = self.hash143_preimage_hash(
|
||||||
|
@ -4,7 +4,6 @@ from micropython import const
|
|||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.messages.SignTx import SignTx
|
from trezor.messages.SignTx import SignTx
|
||||||
from trezor.messages.TransactionType import TransactionType
|
from trezor.messages.TransactionType import TransactionType
|
||||||
from trezor.messages.TxInputType import TxInputType
|
|
||||||
|
|
||||||
from apps.common.writers import write_bitcoin_varint
|
from apps.common.writers import write_bitcoin_varint
|
||||||
|
|
||||||
@ -19,17 +18,6 @@ _SIGHASH_FORKID = const(0x40)
|
|||||||
|
|
||||||
|
|
||||||
class Bitcoinlike(Bitcoin):
|
class Bitcoinlike(Bitcoin):
|
||||||
async def process_segwit_input(self, txi: TxInputType) -> None:
|
|
||||||
if not self.coin.segwit:
|
|
||||||
raise wire.DataError("Segwit not enabled on this coin")
|
|
||||||
await super().process_segwit_input(txi)
|
|
||||||
|
|
||||||
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
|
|
||||||
if self.coin.force_bip143:
|
|
||||||
await self.process_bip143_input(txi)
|
|
||||||
else:
|
|
||||||
await super().process_nonsegwit_input(txi)
|
|
||||||
|
|
||||||
async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None:
|
async def sign_nonsegwit_bip143_input(self, i_sign: int) -> None:
|
||||||
txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
|
txi = await helpers.request_tx_input(self.tx_req, i_sign, self.coin)
|
||||||
|
|
||||||
@ -83,11 +71,11 @@ class Bitcoinlike(Bitcoin):
|
|||||||
await super().write_prev_tx_footer(w, tx, prev_hash)
|
await super().write_prev_tx_footer(w, tx, prev_hash)
|
||||||
|
|
||||||
if self.coin.extra_data:
|
if self.coin.extra_data:
|
||||||
ofs = 0
|
offset = 0
|
||||||
while ofs < tx.extra_data_len:
|
while offset < tx.extra_data_len:
|
||||||
size = min(1024, tx.extra_data_len - ofs)
|
size = min(1024, tx.extra_data_len - offset)
|
||||||
data = await helpers.request_tx_extra_data(
|
data = await helpers.request_tx_extra_data(
|
||||||
self.tx_req, ofs, size, prev_hash
|
self.tx_req, offset, size, prev_hash
|
||||||
)
|
)
|
||||||
writers.write_bytes_unchecked(w, data)
|
writers.write_bytes_unchecked(w, data)
|
||||||
ofs += len(data)
|
offset += len(data)
|
||||||
|
@ -52,10 +52,8 @@ class Overwintered(Bitcoinlike):
|
|||||||
self.write_tx_footer(self.serialized_tx, self.tx)
|
self.write_tx_footer(self.serialized_tx, self.tx)
|
||||||
|
|
||||||
if self.tx.version == 3:
|
if self.tx.version == 3:
|
||||||
write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight
|
|
||||||
write_bitcoin_varint(self.serialized_tx, 0) # nJoinSplit
|
write_bitcoin_varint(self.serialized_tx, 0) # nJoinSplit
|
||||||
elif self.tx.version == 4:
|
elif self.tx.version == 4:
|
||||||
write_uint32(self.serialized_tx, self.tx.expiry) # expiryHeight
|
|
||||||
write_uint64(self.serialized_tx, 0) # valueBalance
|
write_uint64(self.serialized_tx, 0) # valueBalance
|
||||||
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend
|
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedSpend
|
||||||
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedOutput
|
write_bitcoin_varint(self.serialized_tx, 0) # nShieldedOutput
|
||||||
@ -65,9 +63,6 @@ class Overwintered(Bitcoinlike):
|
|||||||
|
|
||||||
await helpers.request_tx_finish(self.tx_req)
|
await helpers.request_tx_finish(self.tx_req)
|
||||||
|
|
||||||
async def process_nonsegwit_input(self, txi: TxInputType) -> None:
|
|
||||||
await self.process_bip143_input(txi)
|
|
||||||
|
|
||||||
async def sign_nonsegwit_input(self, i_sign: int) -> None:
|
async def sign_nonsegwit_input(self, i_sign: int) -> None:
|
||||||
await self.sign_nonsegwit_bip143_input(i_sign)
|
await self.sign_nonsegwit_bip143_input(i_sign)
|
||||||
|
|
||||||
@ -78,6 +73,10 @@ class Overwintered(Bitcoinlike):
|
|||||||
write_uint32(w, tx.version | OVERWINTERED)
|
write_uint32(w, tx.version | OVERWINTERED)
|
||||||
write_uint32(w, tx.version_group_id) # nVersionGroupId
|
write_uint32(w, tx.version_group_id) # nVersionGroupId
|
||||||
|
|
||||||
|
def write_tx_footer(self, w: Writer, tx: Union[SignTx, TransactionType]) -> None:
|
||||||
|
write_uint32(w, tx.lock_time)
|
||||||
|
write_uint32(w, tx.expiry) # expiryHeight
|
||||||
|
|
||||||
# ZIP-0143 / ZIP-0243
|
# ZIP-0143 / ZIP-0243
|
||||||
# ===
|
# ===
|
||||||
|
|
||||||
|
@ -343,7 +343,7 @@ class TestSignSegwitTxP2WPKHInP2SH(unittest.TestCase):
|
|||||||
messages_count = int(len(messages) / 2)
|
messages_count = int(len(messages) / 2)
|
||||||
for request, response in chunks(messages, 2):
|
for request, response in chunks(messages, 2):
|
||||||
if i == messages_count - 1: # last message should throw wire.Error
|
if i == messages_count - 1: # last message should throw wire.Error
|
||||||
self.assertRaises(wire.ProcessError, signer.send, request)
|
self.assertRaises(wire.DataError, signer.send, request)
|
||||||
else:
|
else:
|
||||||
self.assertEqual(signer.send(request), response)
|
self.assertEqual(signer.send(request), response)
|
||||||
i += 1
|
i += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user