1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-24 15:28:10 +00:00

core/sign_tx: Check script_pubkeys of inputs.

This commit is contained in:
Andrew Kozlik 2020-05-25 22:31:13 +02:00 committed by Andrew Kozlik
parent cc655575c8
commit 99f01cd316
5 changed files with 28 additions and 24 deletions

View File

@ -1,10 +1,9 @@
from trezor import utils, wire from trezor import utils, wire
from trezor.crypto import base58, cashaddr from trezor.crypto import base58, cashaddr
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages import InputScriptType, OutputScriptType from trezor.messages import InputScriptType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.messages.TxInputType import TxInputType from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
from apps.common import address_type from apps.common import address_type
from apps.common.coininfo import CoinInfo from apps.common.coininfo import CoinInfo
@ -65,21 +64,18 @@ def input_derive_script(
raise wire.ProcessError("Invalid script type") raise wire.ProcessError("Invalid script type")
def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes: def output_derive_script(address: str, coin: CoinInfo) -> bytes:
if txo.script_type == OutputScriptType.PAYTOOPRETURN: if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
return output_script_paytoopreturn(txo.op_return_data)
if coin.bech32_prefix and txo.address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh # p2wpkh or p2wsh
witprog = common.decode_bech32_address(coin.bech32_prefix, txo.address) witprog = common.decode_bech32_address(coin.bech32_prefix, address)
return output_script_native_p2wpkh_or_p2wsh(witprog) return output_script_native_p2wpkh_or_p2wsh(witprog)
if ( if (
not utils.BITCOIN_ONLY not utils.BITCOIN_ONLY
and coin.cashaddr_prefix is not None and coin.cashaddr_prefix is not None
and txo.address.startswith(coin.cashaddr_prefix + ":") and address.startswith(coin.cashaddr_prefix + ":")
): ):
prefix, addr = txo.address.split(":") prefix, addr = address.split(":")
version, data = cashaddr.decode(prefix, addr) version, data = cashaddr.decode(prefix, addr)
if version == cashaddr.ADDRESS_TYPE_P2KH: if version == cashaddr.ADDRESS_TYPE_P2KH:
version = coin.address_type version = coin.address_type
@ -90,7 +86,7 @@ def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
raw_address = bytes([version]) + data raw_address = bytes([version]) + data
else: else:
try: try:
raw_address = base58.decode_check(txo.address, coin.b58_hash) raw_address = base58.decode_check(address, coin.b58_hash)
except ValueError: except ValueError:
raise wire.DataError("Invalid address") raise wire.DataError("Invalid address")

View File

@ -180,7 +180,9 @@ class Bitcoin:
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES: if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Wrong input script type") raise wire.DataError("Wrong input script type")
prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index) prev_amount, script_pubkey = await self.get_prevtx_output(
txi.prev_hash, txi.prev_index
)
if txi.amount is not None and prev_amount != txi.amount: if txi.amount is not None and prev_amount != txi.amount:
raise wire.DataError("Invalid amount specified") raise wire.DataError("Invalid amount specified")
@ -328,7 +330,9 @@ class Bitcoin:
script_pubkey = self.output_derive_script(txo) script_pubkey = self.output_derive_script(txo)
self.write_tx_output(self.serialized_tx, txo, script_pubkey) self.write_tx_output(self.serialized_tx, txo, script_pubkey)
async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int: async def get_prevtx_output(
self, prev_hash: bytes, prev_index: int
) -> Tuple[int, bytes]:
amount_out = 0 # output amount amount_out = 0 # output amount
# STAGE_REQUEST_2_PREV_META in legacy # STAGE_REQUEST_2_PREV_META in legacy
@ -358,6 +362,7 @@ class Bitcoin:
self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey) self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey)
if i == prev_index: if i == prev_index:
amount_out = txo_bin.amount amount_out = txo_bin.amount
script_pubkey = txo_bin.script_pubkey
self.check_prevtx_output(txo_bin) self.check_prevtx_output(txo_bin)
await self.write_prev_tx_footer(txh, tx, prev_hash) await self.write_prev_tx_footer(txh, tx, prev_hash)
@ -368,7 +373,7 @@ class Bitcoin:
): ):
raise wire.ProcessError("Encountered invalid prev_hash") raise wire.ProcessError("Encountered invalid prev_hash")
return amount_out return amount_out, script_pubkey
def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None: def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None:
# Validations to perform on the UTXO when checking the previous transaction output amount. # Validations to perform on the UTXO when checking the previous transaction output amount.
@ -425,6 +430,9 @@ class Bitcoin:
# === # ===
def output_derive_script(self, txo: TxOutputType) -> bytes: def output_derive_script(self, txo: TxOutputType) -> bytes:
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
return scripts.output_script_paytoopreturn(txo.op_return_data)
if txo.address_n: if txo.address_n:
# change output # change output
try: try:
@ -438,7 +446,7 @@ class Bitcoin:
input_script_type, self.coin, node, txo.multisig input_script_type, self.coin, node, txo.multisig
) )
return scripts.output_derive_script(txo, self.coin) return scripts.output_derive_script(txo.address, self.coin)
def output_is_change(self, txo: TxOutputType) -> bool: def output_is_change(self, txo: TxOutputType) -> bool:
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES: if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:

View File

@ -69,7 +69,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
for txo in [self.out1, self.out2]: for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType() txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin) script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey) bip143.hash143_add_output(txo_bin, script_pubkey)
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double) outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
@ -86,7 +86,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
for txo in [self.out1, self.out2]: for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType() txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin) script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey) bip143.hash143_add_output(txo_bin, script_pubkey)
# test data public key hash # test data public key hash

View File

@ -58,7 +58,7 @@ class TestSegwitBip143(unittest.TestCase):
for txo in [self.out1, self.out2]: for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType() txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin) script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey) bip143.hash143_add_output(txo_bin, script_pubkey)
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double) outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
@ -72,7 +72,7 @@ class TestSegwitBip143(unittest.TestCase):
for txo in [self.out1, self.out2]: for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType() txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin) script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey) bip143.hash143_add_output(txo_bin, script_pubkey)
# test data public key hash # test data public key hash

View File

@ -34,7 +34,7 @@ class TestCalculateTxWeight(unittest.TestCase):
calculator = TxWeightCalculator(1, 1) calculator = TxWeightCalculator(1, 1)
calculator.add_input(inp1) calculator.add_input(inp1)
calculator.add_output(output_derive_script(out1, coin)) calculator.add_output(output_derive_script(out1.address, coin))
serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000' serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'
tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4 tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4
@ -73,8 +73,8 @@ class TestCalculateTxWeight(unittest.TestCase):
calculator = TxWeightCalculator(1, 2) calculator = TxWeightCalculator(1, 2)
calculator.add_input(inp1) calculator.add_input(inp1)
calculator.add_output(output_derive_script(out1, coin)) calculator.add_output(output_derive_script(out1.address, coin))
calculator.add_output(output_derive_script(out2, coin)) calculator.add_output(output_derive_script(out2.address, coin))
self.assertEqual(calculator.get_total(), 670) self.assertEqual(calculator.get_total(), 670)
# non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560 # non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560
@ -113,8 +113,8 @@ class TestCalculateTxWeight(unittest.TestCase):
calculator = TxWeightCalculator(1, 2) calculator = TxWeightCalculator(1, 2)
calculator.add_input(inp1) calculator.add_input(inp1)
calculator.add_output(output_derive_script(out1, coin)) calculator.add_output(output_derive_script(out1.address, coin))
calculator.add_output(output_derive_script(out2, coin)) calculator.add_output(output_derive_script(out2.address, coin))
self.assertEqual(calculator.get_total(), 566) self.assertEqual(calculator.get_total(), 566)
# non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456 # non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456