core/sign_tx: Check script_pubkeys of inputs.

pull/1091/head
Andrew Kozlik 4 years ago committed by Andrew Kozlik
parent cc655575c8
commit 99f01cd316

@ -1,10 +1,9 @@
from trezor import utils, wire
from trezor.crypto import base58, cashaddr
from trezor.crypto.hashlib import sha256
from trezor.messages import InputScriptType, OutputScriptType
from trezor.messages import InputScriptType
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
from trezor.messages.TxInputType import TxInputType
from trezor.messages.TxOutputType import TxOutputType
from apps.common import address_type
from apps.common.coininfo import CoinInfo
@ -65,21 +64,18 @@ def input_derive_script(
raise wire.ProcessError("Invalid script type")
def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
return output_script_paytoopreturn(txo.op_return_data)
if coin.bech32_prefix and txo.address.startswith(coin.bech32_prefix):
def output_derive_script(address: str, coin: CoinInfo) -> bytes:
if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
# p2wpkh or p2wsh
witprog = common.decode_bech32_address(coin.bech32_prefix, txo.address)
witprog = common.decode_bech32_address(coin.bech32_prefix, address)
return output_script_native_p2wpkh_or_p2wsh(witprog)
if (
not utils.BITCOIN_ONLY
and coin.cashaddr_prefix is not None
and txo.address.startswith(coin.cashaddr_prefix + ":")
and address.startswith(coin.cashaddr_prefix + ":")
):
prefix, addr = txo.address.split(":")
prefix, addr = address.split(":")
version, data = cashaddr.decode(prefix, addr)
if version == cashaddr.ADDRESS_TYPE_P2KH:
version = coin.address_type
@ -90,7 +86,7 @@ def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
raw_address = bytes([version]) + data
else:
try:
raw_address = base58.decode_check(txo.address, coin.b58_hash)
raw_address = base58.decode_check(address, coin.b58_hash)
except ValueError:
raise wire.DataError("Invalid address")

@ -180,7 +180,9 @@ class Bitcoin:
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Wrong input script type")
prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index)
prev_amount, script_pubkey = await self.get_prevtx_output(
txi.prev_hash, txi.prev_index
)
if txi.amount is not None and prev_amount != txi.amount:
raise wire.DataError("Invalid amount specified")
@ -328,7 +330,9 @@ class Bitcoin:
script_pubkey = self.output_derive_script(txo)
self.write_tx_output(self.serialized_tx, txo, script_pubkey)
async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int:
async def get_prevtx_output(
self, prev_hash: bytes, prev_index: int
) -> Tuple[int, bytes]:
amount_out = 0 # output amount
# STAGE_REQUEST_2_PREV_META in legacy
@ -358,6 +362,7 @@ class Bitcoin:
self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey)
if i == prev_index:
amount_out = txo_bin.amount
script_pubkey = txo_bin.script_pubkey
self.check_prevtx_output(txo_bin)
await self.write_prev_tx_footer(txh, tx, prev_hash)
@ -368,7 +373,7 @@ class Bitcoin:
):
raise wire.ProcessError("Encountered invalid prev_hash")
return amount_out
return amount_out, script_pubkey
def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None:
# Validations to perform on the UTXO when checking the previous transaction output amount.
@ -425,6 +430,9 @@ class Bitcoin:
# ===
def output_derive_script(self, txo: TxOutputType) -> bytes:
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
return scripts.output_script_paytoopreturn(txo.op_return_data)
if txo.address_n:
# change output
try:
@ -438,7 +446,7 @@ class Bitcoin:
input_script_type, self.coin, node, txo.multisig
)
return scripts.output_derive_script(txo, self.coin)
return scripts.output_derive_script(txo.address, self.coin)
def output_is_change(self, txo: TxOutputType) -> bool:
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:

@ -69,7 +69,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin)
script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey)
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
@ -86,7 +86,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin)
script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey)
# test data public key hash

@ -58,7 +58,7 @@ class TestSegwitBip143(unittest.TestCase):
for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin)
script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey)
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
@ -72,7 +72,7 @@ class TestSegwitBip143(unittest.TestCase):
for txo in [self.out1, self.out2]:
txo_bin = TxOutputBinType()
txo_bin.amount = txo.amount
script_pubkey = output_derive_script(txo, coin)
script_pubkey = output_derive_script(txo.address, coin)
bip143.hash143_add_output(txo_bin, script_pubkey)
# test data public key hash

@ -34,7 +34,7 @@ class TestCalculateTxWeight(unittest.TestCase):
calculator = TxWeightCalculator(1, 1)
calculator.add_input(inp1)
calculator.add_output(output_derive_script(out1, coin))
calculator.add_output(output_derive_script(out1.address, coin))
serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'
tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4
@ -73,8 +73,8 @@ class TestCalculateTxWeight(unittest.TestCase):
calculator = TxWeightCalculator(1, 2)
calculator.add_input(inp1)
calculator.add_output(output_derive_script(out1, coin))
calculator.add_output(output_derive_script(out2, coin))
calculator.add_output(output_derive_script(out1.address, coin))
calculator.add_output(output_derive_script(out2.address, coin))
self.assertEqual(calculator.get_total(), 670)
# non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560
@ -113,8 +113,8 @@ class TestCalculateTxWeight(unittest.TestCase):
calculator = TxWeightCalculator(1, 2)
calculator.add_input(inp1)
calculator.add_output(output_derive_script(out1, coin))
calculator.add_output(output_derive_script(out2, coin))
calculator.add_output(output_derive_script(out1.address, coin))
calculator.add_output(output_derive_script(out2.address, coin))
self.assertEqual(calculator.get_total(), 566)
# non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456

Loading…
Cancel
Save