mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 23:08:14 +00:00
core/sign_tx: Check script_pubkeys of inputs.
This commit is contained in:
parent
cc655575c8
commit
99f01cd316
@ -1,10 +1,9 @@
|
||||
from trezor import utils, wire
|
||||
from trezor.crypto import base58, cashaddr
|
||||
from trezor.crypto.hashlib import sha256
|
||||
from trezor.messages import InputScriptType, OutputScriptType
|
||||
from trezor.messages import InputScriptType
|
||||
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
|
||||
from trezor.messages.TxInputType import TxInputType
|
||||
from trezor.messages.TxOutputType import TxOutputType
|
||||
|
||||
from apps.common import address_type
|
||||
from apps.common.coininfo import CoinInfo
|
||||
@ -65,21 +64,18 @@ def input_derive_script(
|
||||
raise wire.ProcessError("Invalid script type")
|
||||
|
||||
|
||||
def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
|
||||
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
|
||||
return output_script_paytoopreturn(txo.op_return_data)
|
||||
|
||||
if coin.bech32_prefix and txo.address.startswith(coin.bech32_prefix):
|
||||
def output_derive_script(address: str, coin: CoinInfo) -> bytes:
|
||||
if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
|
||||
# p2wpkh or p2wsh
|
||||
witprog = common.decode_bech32_address(coin.bech32_prefix, txo.address)
|
||||
witprog = common.decode_bech32_address(coin.bech32_prefix, address)
|
||||
return output_script_native_p2wpkh_or_p2wsh(witprog)
|
||||
|
||||
if (
|
||||
not utils.BITCOIN_ONLY
|
||||
and coin.cashaddr_prefix is not None
|
||||
and txo.address.startswith(coin.cashaddr_prefix + ":")
|
||||
and address.startswith(coin.cashaddr_prefix + ":")
|
||||
):
|
||||
prefix, addr = txo.address.split(":")
|
||||
prefix, addr = address.split(":")
|
||||
version, data = cashaddr.decode(prefix, addr)
|
||||
if version == cashaddr.ADDRESS_TYPE_P2KH:
|
||||
version = coin.address_type
|
||||
@ -90,7 +86,7 @@ def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
|
||||
raw_address = bytes([version]) + data
|
||||
else:
|
||||
try:
|
||||
raw_address = base58.decode_check(txo.address, coin.b58_hash)
|
||||
raw_address = base58.decode_check(address, coin.b58_hash)
|
||||
except ValueError:
|
||||
raise wire.DataError("Invalid address")
|
||||
|
||||
|
@ -180,7 +180,9 @@ class Bitcoin:
|
||||
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
|
||||
raise wire.DataError("Wrong input script type")
|
||||
|
||||
prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index)
|
||||
prev_amount, script_pubkey = await self.get_prevtx_output(
|
||||
txi.prev_hash, txi.prev_index
|
||||
)
|
||||
|
||||
if txi.amount is not None and prev_amount != txi.amount:
|
||||
raise wire.DataError("Invalid amount specified")
|
||||
@ -328,7 +330,9 @@ class Bitcoin:
|
||||
script_pubkey = self.output_derive_script(txo)
|
||||
self.write_tx_output(self.serialized_tx, txo, script_pubkey)
|
||||
|
||||
async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int:
|
||||
async def get_prevtx_output(
|
||||
self, prev_hash: bytes, prev_index: int
|
||||
) -> Tuple[int, bytes]:
|
||||
amount_out = 0 # output amount
|
||||
|
||||
# STAGE_REQUEST_2_PREV_META in legacy
|
||||
@ -358,6 +362,7 @@ class Bitcoin:
|
||||
self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey)
|
||||
if i == prev_index:
|
||||
amount_out = txo_bin.amount
|
||||
script_pubkey = txo_bin.script_pubkey
|
||||
self.check_prevtx_output(txo_bin)
|
||||
|
||||
await self.write_prev_tx_footer(txh, tx, prev_hash)
|
||||
@ -368,7 +373,7 @@ class Bitcoin:
|
||||
):
|
||||
raise wire.ProcessError("Encountered invalid prev_hash")
|
||||
|
||||
return amount_out
|
||||
return amount_out, script_pubkey
|
||||
|
||||
def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None:
|
||||
# Validations to perform on the UTXO when checking the previous transaction output amount.
|
||||
@ -425,6 +430,9 @@ class Bitcoin:
|
||||
# ===
|
||||
|
||||
def output_derive_script(self, txo: TxOutputType) -> bytes:
|
||||
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
|
||||
return scripts.output_script_paytoopreturn(txo.op_return_data)
|
||||
|
||||
if txo.address_n:
|
||||
# change output
|
||||
try:
|
||||
@ -438,7 +446,7 @@ class Bitcoin:
|
||||
input_script_type, self.coin, node, txo.multisig
|
||||
)
|
||||
|
||||
return scripts.output_derive_script(txo, self.coin)
|
||||
return scripts.output_derive_script(txo.address, self.coin)
|
||||
|
||||
def output_is_change(self, txo: TxOutputType) -> bool:
|
||||
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:
|
||||
|
@ -69,7 +69,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
|
||||
for txo in [self.out1, self.out2]:
|
||||
txo_bin = TxOutputBinType()
|
||||
txo_bin.amount = txo.amount
|
||||
script_pubkey = output_derive_script(txo, coin)
|
||||
script_pubkey = output_derive_script(txo.address, coin)
|
||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||
|
||||
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
|
||||
@ -86,7 +86,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
|
||||
for txo in [self.out1, self.out2]:
|
||||
txo_bin = TxOutputBinType()
|
||||
txo_bin.amount = txo.amount
|
||||
script_pubkey = output_derive_script(txo, coin)
|
||||
script_pubkey = output_derive_script(txo.address, coin)
|
||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||
|
||||
# test data public key hash
|
||||
|
@ -58,7 +58,7 @@ class TestSegwitBip143(unittest.TestCase):
|
||||
for txo in [self.out1, self.out2]:
|
||||
txo_bin = TxOutputBinType()
|
||||
txo_bin.amount = txo.amount
|
||||
script_pubkey = output_derive_script(txo, coin)
|
||||
script_pubkey = output_derive_script(txo.address, coin)
|
||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||
|
||||
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
|
||||
@ -72,7 +72,7 @@ class TestSegwitBip143(unittest.TestCase):
|
||||
for txo in [self.out1, self.out2]:
|
||||
txo_bin = TxOutputBinType()
|
||||
txo_bin.amount = txo.amount
|
||||
script_pubkey = output_derive_script(txo, coin)
|
||||
script_pubkey = output_derive_script(txo.address, coin)
|
||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||
|
||||
# test data public key hash
|
||||
|
@ -34,7 +34,7 @@ class TestCalculateTxWeight(unittest.TestCase):
|
||||
|
||||
calculator = TxWeightCalculator(1, 1)
|
||||
calculator.add_input(inp1)
|
||||
calculator.add_output(output_derive_script(out1, coin))
|
||||
calculator.add_output(output_derive_script(out1.address, coin))
|
||||
|
||||
serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'
|
||||
tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4
|
||||
@ -73,8 +73,8 @@ class TestCalculateTxWeight(unittest.TestCase):
|
||||
|
||||
calculator = TxWeightCalculator(1, 2)
|
||||
calculator.add_input(inp1)
|
||||
calculator.add_output(output_derive_script(out1, coin))
|
||||
calculator.add_output(output_derive_script(out2, coin))
|
||||
calculator.add_output(output_derive_script(out1.address, coin))
|
||||
calculator.add_output(output_derive_script(out2.address, coin))
|
||||
|
||||
self.assertEqual(calculator.get_total(), 670)
|
||||
# non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560
|
||||
@ -113,8 +113,8 @@ class TestCalculateTxWeight(unittest.TestCase):
|
||||
|
||||
calculator = TxWeightCalculator(1, 2)
|
||||
calculator.add_input(inp1)
|
||||
calculator.add_output(output_derive_script(out1, coin))
|
||||
calculator.add_output(output_derive_script(out2, coin))
|
||||
calculator.add_output(output_derive_script(out1.address, coin))
|
||||
calculator.add_output(output_derive_script(out2.address, coin))
|
||||
|
||||
self.assertEqual(calculator.get_total(), 566)
|
||||
# non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456
|
||||
|
Loading…
Reference in New Issue
Block a user