mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-24 15:28:10 +00:00
core/sign_tx: Check script_pubkeys of inputs.
This commit is contained in:
parent
cc655575c8
commit
99f01cd316
@ -1,10 +1,9 @@
|
|||||||
from trezor import utils, wire
|
from trezor import utils, wire
|
||||||
from trezor.crypto import base58, cashaddr
|
from trezor.crypto import base58, cashaddr
|
||||||
from trezor.crypto.hashlib import sha256
|
from trezor.crypto.hashlib import sha256
|
||||||
from trezor.messages import InputScriptType, OutputScriptType
|
from trezor.messages import InputScriptType
|
||||||
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
|
from trezor.messages.MultisigRedeemScriptType import MultisigRedeemScriptType
|
||||||
from trezor.messages.TxInputType import TxInputType
|
from trezor.messages.TxInputType import TxInputType
|
||||||
from trezor.messages.TxOutputType import TxOutputType
|
|
||||||
|
|
||||||
from apps.common import address_type
|
from apps.common import address_type
|
||||||
from apps.common.coininfo import CoinInfo
|
from apps.common.coininfo import CoinInfo
|
||||||
@ -65,21 +64,18 @@ def input_derive_script(
|
|||||||
raise wire.ProcessError("Invalid script type")
|
raise wire.ProcessError("Invalid script type")
|
||||||
|
|
||||||
|
|
||||||
def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
|
def output_derive_script(address: str, coin: CoinInfo) -> bytes:
|
||||||
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
|
if coin.bech32_prefix and address.startswith(coin.bech32_prefix):
|
||||||
return output_script_paytoopreturn(txo.op_return_data)
|
|
||||||
|
|
||||||
if coin.bech32_prefix and txo.address.startswith(coin.bech32_prefix):
|
|
||||||
# p2wpkh or p2wsh
|
# p2wpkh or p2wsh
|
||||||
witprog = common.decode_bech32_address(coin.bech32_prefix, txo.address)
|
witprog = common.decode_bech32_address(coin.bech32_prefix, address)
|
||||||
return output_script_native_p2wpkh_or_p2wsh(witprog)
|
return output_script_native_p2wpkh_or_p2wsh(witprog)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not utils.BITCOIN_ONLY
|
not utils.BITCOIN_ONLY
|
||||||
and coin.cashaddr_prefix is not None
|
and coin.cashaddr_prefix is not None
|
||||||
and txo.address.startswith(coin.cashaddr_prefix + ":")
|
and address.startswith(coin.cashaddr_prefix + ":")
|
||||||
):
|
):
|
||||||
prefix, addr = txo.address.split(":")
|
prefix, addr = address.split(":")
|
||||||
version, data = cashaddr.decode(prefix, addr)
|
version, data = cashaddr.decode(prefix, addr)
|
||||||
if version == cashaddr.ADDRESS_TYPE_P2KH:
|
if version == cashaddr.ADDRESS_TYPE_P2KH:
|
||||||
version = coin.address_type
|
version = coin.address_type
|
||||||
@ -90,7 +86,7 @@ def output_derive_script(txo: TxOutputType, coin: CoinInfo) -> bytes:
|
|||||||
raw_address = bytes([version]) + data
|
raw_address = bytes([version]) + data
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
raw_address = base58.decode_check(txo.address, coin.b58_hash)
|
raw_address = base58.decode_check(address, coin.b58_hash)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise wire.DataError("Invalid address")
|
raise wire.DataError("Invalid address")
|
||||||
|
|
||||||
|
@ -180,7 +180,9 @@ class Bitcoin:
|
|||||||
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
|
if txi.script_type not in helpers.INTERNAL_INPUT_SCRIPT_TYPES:
|
||||||
raise wire.DataError("Wrong input script type")
|
raise wire.DataError("Wrong input script type")
|
||||||
|
|
||||||
prev_amount = await self.get_prevtx_output_value(txi.prev_hash, txi.prev_index)
|
prev_amount, script_pubkey = await self.get_prevtx_output(
|
||||||
|
txi.prev_hash, txi.prev_index
|
||||||
|
)
|
||||||
|
|
||||||
if txi.amount is not None and prev_amount != txi.amount:
|
if txi.amount is not None and prev_amount != txi.amount:
|
||||||
raise wire.DataError("Invalid amount specified")
|
raise wire.DataError("Invalid amount specified")
|
||||||
@ -328,7 +330,9 @@ class Bitcoin:
|
|||||||
script_pubkey = self.output_derive_script(txo)
|
script_pubkey = self.output_derive_script(txo)
|
||||||
self.write_tx_output(self.serialized_tx, txo, script_pubkey)
|
self.write_tx_output(self.serialized_tx, txo, script_pubkey)
|
||||||
|
|
||||||
async def get_prevtx_output_value(self, prev_hash: bytes, prev_index: int) -> int:
|
async def get_prevtx_output(
|
||||||
|
self, prev_hash: bytes, prev_index: int
|
||||||
|
) -> Tuple[int, bytes]:
|
||||||
amount_out = 0 # output amount
|
amount_out = 0 # output amount
|
||||||
|
|
||||||
# STAGE_REQUEST_2_PREV_META in legacy
|
# STAGE_REQUEST_2_PREV_META in legacy
|
||||||
@ -358,6 +362,7 @@ class Bitcoin:
|
|||||||
self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey)
|
self.write_tx_output(txh, txo_bin, txo_bin.script_pubkey)
|
||||||
if i == prev_index:
|
if i == prev_index:
|
||||||
amount_out = txo_bin.amount
|
amount_out = txo_bin.amount
|
||||||
|
script_pubkey = txo_bin.script_pubkey
|
||||||
self.check_prevtx_output(txo_bin)
|
self.check_prevtx_output(txo_bin)
|
||||||
|
|
||||||
await self.write_prev_tx_footer(txh, tx, prev_hash)
|
await self.write_prev_tx_footer(txh, tx, prev_hash)
|
||||||
@ -368,7 +373,7 @@ class Bitcoin:
|
|||||||
):
|
):
|
||||||
raise wire.ProcessError("Encountered invalid prev_hash")
|
raise wire.ProcessError("Encountered invalid prev_hash")
|
||||||
|
|
||||||
return amount_out
|
return amount_out, script_pubkey
|
||||||
|
|
||||||
def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None:
|
def check_prevtx_output(self, txo_bin: TxOutputBinType) -> None:
|
||||||
# Validations to perform on the UTXO when checking the previous transaction output amount.
|
# Validations to perform on the UTXO when checking the previous transaction output amount.
|
||||||
@ -425,6 +430,9 @@ class Bitcoin:
|
|||||||
# ===
|
# ===
|
||||||
|
|
||||||
def output_derive_script(self, txo: TxOutputType) -> bytes:
|
def output_derive_script(self, txo: TxOutputType) -> bytes:
|
||||||
|
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
|
||||||
|
return scripts.output_script_paytoopreturn(txo.op_return_data)
|
||||||
|
|
||||||
if txo.address_n:
|
if txo.address_n:
|
||||||
# change output
|
# change output
|
||||||
try:
|
try:
|
||||||
@ -438,7 +446,7 @@ class Bitcoin:
|
|||||||
input_script_type, self.coin, node, txo.multisig
|
input_script_type, self.coin, node, txo.multisig
|
||||||
)
|
)
|
||||||
|
|
||||||
return scripts.output_derive_script(txo, self.coin)
|
return scripts.output_derive_script(txo.address, self.coin)
|
||||||
|
|
||||||
def output_is_change(self, txo: TxOutputType) -> bool:
|
def output_is_change(self, txo: TxOutputType) -> bool:
|
||||||
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:
|
if txo.script_type not in helpers.CHANGE_OUTPUT_SCRIPT_TYPES:
|
||||||
|
@ -69,7 +69,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
|
|||||||
for txo in [self.out1, self.out2]:
|
for txo in [self.out1, self.out2]:
|
||||||
txo_bin = TxOutputBinType()
|
txo_bin = TxOutputBinType()
|
||||||
txo_bin.amount = txo.amount
|
txo_bin.amount = txo.amount
|
||||||
script_pubkey = output_derive_script(txo, coin)
|
script_pubkey = output_derive_script(txo.address, coin)
|
||||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||||
|
|
||||||
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
|
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
|
||||||
@ -86,7 +86,7 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
|
|||||||
for txo in [self.out1, self.out2]:
|
for txo in [self.out1, self.out2]:
|
||||||
txo_bin = TxOutputBinType()
|
txo_bin = TxOutputBinType()
|
||||||
txo_bin.amount = txo.amount
|
txo_bin.amount = txo.amount
|
||||||
script_pubkey = output_derive_script(txo, coin)
|
script_pubkey = output_derive_script(txo.address, coin)
|
||||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||||
|
|
||||||
# test data public key hash
|
# test data public key hash
|
||||||
|
@ -58,7 +58,7 @@ class TestSegwitBip143(unittest.TestCase):
|
|||||||
for txo in [self.out1, self.out2]:
|
for txo in [self.out1, self.out2]:
|
||||||
txo_bin = TxOutputBinType()
|
txo_bin = TxOutputBinType()
|
||||||
txo_bin.amount = txo.amount
|
txo_bin.amount = txo.amount
|
||||||
script_pubkey = output_derive_script(txo, coin)
|
script_pubkey = output_derive_script(txo.address, coin)
|
||||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||||
|
|
||||||
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
|
outputs_hash = get_tx_hash(bip143.h_outputs, double=coin.sign_hash_double)
|
||||||
@ -72,7 +72,7 @@ class TestSegwitBip143(unittest.TestCase):
|
|||||||
for txo in [self.out1, self.out2]:
|
for txo in [self.out1, self.out2]:
|
||||||
txo_bin = TxOutputBinType()
|
txo_bin = TxOutputBinType()
|
||||||
txo_bin.amount = txo.amount
|
txo_bin.amount = txo.amount
|
||||||
script_pubkey = output_derive_script(txo, coin)
|
script_pubkey = output_derive_script(txo.address, coin)
|
||||||
bip143.hash143_add_output(txo_bin, script_pubkey)
|
bip143.hash143_add_output(txo_bin, script_pubkey)
|
||||||
|
|
||||||
# test data public key hash
|
# test data public key hash
|
||||||
|
@ -34,7 +34,7 @@ class TestCalculateTxWeight(unittest.TestCase):
|
|||||||
|
|
||||||
calculator = TxWeightCalculator(1, 1)
|
calculator = TxWeightCalculator(1, 1)
|
||||||
calculator.add_input(inp1)
|
calculator.add_input(inp1)
|
||||||
calculator.add_output(output_derive_script(out1, coin))
|
calculator.add_output(output_derive_script(out1.address, coin))
|
||||||
|
|
||||||
serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'
|
serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000'
|
||||||
tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4
|
tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4
|
||||||
@ -73,8 +73,8 @@ class TestCalculateTxWeight(unittest.TestCase):
|
|||||||
|
|
||||||
calculator = TxWeightCalculator(1, 2)
|
calculator = TxWeightCalculator(1, 2)
|
||||||
calculator.add_input(inp1)
|
calculator.add_input(inp1)
|
||||||
calculator.add_output(output_derive_script(out1, coin))
|
calculator.add_output(output_derive_script(out1.address, coin))
|
||||||
calculator.add_output(output_derive_script(out2, coin))
|
calculator.add_output(output_derive_script(out2.address, coin))
|
||||||
|
|
||||||
self.assertEqual(calculator.get_total(), 670)
|
self.assertEqual(calculator.get_total(), 670)
|
||||||
# non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560
|
# non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560
|
||||||
@ -113,8 +113,8 @@ class TestCalculateTxWeight(unittest.TestCase):
|
|||||||
|
|
||||||
calculator = TxWeightCalculator(1, 2)
|
calculator = TxWeightCalculator(1, 2)
|
||||||
calculator.add_input(inp1)
|
calculator.add_input(inp1)
|
||||||
calculator.add_output(output_derive_script(out1, coin))
|
calculator.add_output(output_derive_script(out1.address, coin))
|
||||||
calculator.add_output(output_derive_script(out2, coin))
|
calculator.add_output(output_derive_script(out2.address, coin))
|
||||||
|
|
||||||
self.assertEqual(calculator.get_total(), 566)
|
self.assertEqual(calculator.get_total(), 566)
|
||||||
# non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456
|
# non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456
|
||||||
|
Loading…
Reference in New Issue
Block a user