chore(core): Support script_pubkey parameter for Bitcoin inputs.

pull/1918/head
Andrew Kozlik 3 years ago committed by Andrew Kozlik
parent 630c06e782
commit 99e4ed6f42

@ -128,8 +128,8 @@ class Bitcoin:
for i in range(self.tx_info.tx.inputs_count):
# STAGE_REQUEST_1_INPUT in legacy
txi = await helpers.request_tx_input(self.tx_req, i, self.coin)
self.tx_info.add_input(txi)
script_pubkey = self.input_derive_script(txi)
self.tx_info.add_input(txi, script_pubkey)
if input_is_segwit(txi):
self.segwit.add(i)
@ -141,7 +141,7 @@ class Bitcoin:
await self.process_internal_input(txi)
if txi.orig_hash:
await self.process_original_input(txi)
await self.process_original_input(txi, script_pubkey)
self.h_inputs = self.tx_info.get_tx_check_digest()
@ -243,7 +243,7 @@ class Bitcoin:
async def process_external_input(self, txi: TxInput) -> None:
self.approver.add_external_input(txi)
async def process_original_input(self, txi: TxInput) -> None:
async def process_original_input(self, txi: TxInput, script_pubkey: bytes) -> None:
assert txi.orig_hash is not None
assert txi.orig_index is not None
@ -270,19 +270,24 @@ class Bitcoin:
)
# Verify that the original input matches:
#
# An input is characterized by its prev_hash and prev_index. We also check that the
# amounts match, so that we don't have to call get_prevtx_output() twice for the same
# prevtx output. Verifying that script_type matches is just a sanity check, because
# because we count both inputs as internal or external based only on txi.script_type.
# prevtx output. Verifying that script_type matches is just a sanity check, because we
# count both inputs as internal or external based only on txi.script_type.
#
# When all inputs are taproot, we don't check the prevtxs, so we have to ensure that the
# claims about the script_pubkey values and amounts remain consistent throughout.
if (
orig_txi.prev_hash != txi.prev_hash
or orig_txi.prev_index != txi.prev_index
or orig_txi.amount != txi.amount
or orig_txi.script_type != txi.script_type
or self.input_derive_script(orig_txi) != script_pubkey
):
raise wire.ProcessError("Original input does not match current input.")
orig.add_input(orig_txi)
orig.add_input(orig_txi, script_pubkey)
orig.index += 1
async def fetch_removed_original_outputs(
@ -358,13 +363,7 @@ class Bitcoin:
assert orig.verification_index is not None
txi = orig.verification_input
node = self.keychain.derive(txi.address_n)
address = addresses.get_address(
txi.script_type, self.coin, node, txi.multisig
)
script_pubkey = scripts.output_derive_script(address, self.coin)
script_pubkey = self.input_derive_script(txi)
verifier = SignatureVerifier(
script_pubkey, txi.script_sig, txi.witness, self.coin
)
@ -725,9 +724,18 @@ class Bitcoin:
self.tx_req.serialized.signature_index = index
self.tx_req.serialized.signature = signature
# Tx Outputs
# scriptPubKey derivation
# ===
def input_derive_script(self, txi: TxInput) -> bytes:
if input_is_external(txi):
assert txi.script_pubkey is not None # checked in sanitize_tx_input
return txi.script_pubkey
node = self.keychain.derive(txi.address_n)
address = addresses.get_address(txi.script_type, self.coin, node, txi.multisig)
return scripts.output_derive_script(address, self.coin)
def output_derive_script(self, txo: TxOutput) -> bytes:
if txo.script_type == OutputScriptType.PAYTOOPRETURN:
assert txo.op_return_data is not None # checked in sanitize_tx_output

@ -54,7 +54,7 @@ class DecredHash:
def __init__(self, h_prefix: HashWriter) -> None:
self.h_prefix = h_prefix
def add_input(self, txi: TxInput) -> None:
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
Decred.write_tx_input(self.h_prefix, txi, bytes())
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
@ -122,7 +122,7 @@ class Decred(Bitcoin):
async def process_external_input(self, txi: TxInput) -> None:
raise wire.DataError("External inputs not supported")
async def process_original_input(self, txi: TxInput) -> None:
async def process_original_input(self, txi: TxInput, script_pubkey: bytes) -> None:
raise wire.DataError("Replacement transactions not supported")
async def approve_output(

@ -5,12 +5,13 @@ from trezor.utils import HashWriter
from apps.common import coininfo
from .. import scripts, writers
from ..common import input_is_taproot, tagged_hashwriter
if False:
from typing import Protocol, Sequence
class Hash143(Protocol):
def add_input(self, txi: TxInput) -> None:
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
...
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
@ -35,7 +36,7 @@ class Bip143Hash:
self.h_sequence = HashWriter(sha256())
self.h_outputs = HashWriter(sha256())
def add_input(self, txi: TxInput) -> None:
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
writers.write_bytes_reversed(
self.h_prevouts, txi.prev_hash, writers.TX_HASH_SIZE
)

@ -392,8 +392,18 @@ def sanitize_tx_input(txi: TxInput, coin: CoinInfo) -> TxInput:
if not txi.multisig and txi.script_type == InputScriptType.SPENDMULTISIG:
raise wire.DataError("Multisig details required.")
if txi.address_n and txi.script_type not in common.INTERNAL_INPUT_SCRIPT_TYPES:
raise wire.DataError("Input's address_n provided but not expected.")
if txi.script_type in common.INTERNAL_INPUT_SCRIPT_TYPES:
if not txi.address_n:
raise wire.DataError("Missing address_n field.")
if txi.script_pubkey:
raise wire.DataError("Input's script_pubkey provided but not expected.")
else:
if txi.address_n:
raise wire.DataError("Input's address_n provided but not expected.")
if not txi.script_pubkey:
raise wire.DataError("Missing script_pubkey field.")
if not coin.decred and txi.decred_tree is not None:
raise wire.DataError("Decred details provided but Decred coin not specified.")

@ -77,8 +77,10 @@ class TxInfoBase:
# The minimum nSequence of all inputs.
self.min_sequence = _SEQUENCE_FINAL
def add_input(self, txi: TxInput) -> None:
self.hash143.add_input(txi) # all inputs are included (non-segwit as well)
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
self.hash143.add_input(
txi, script_pubkey
) # all inputs are included (non-segwit as well)
writers.write_tx_input_check(self.h_tx_check, txi)
self.min_sequence = min(self.min_sequence, txi.sequence)
@ -147,8 +149,8 @@ class OriginalTxInfo(TxInfoBase):
self.verification_input: TxInput | None = None
self.verification_index: int | None = None
def add_input(self, txi: TxInput) -> None:
super().add_input(txi)
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
super().add_input(txi, script_pubkey)
writers.write_tx_input(self.h_tx, txi, txi.script_sig or bytes())
# For verification use the first original input that specifies address_n.

@ -39,7 +39,7 @@ class Zip243Hash:
self.h_sequence = HashWriter(blake2b(outlen=32, personal=b"ZcashSequencHash"))
self.h_outputs = HashWriter(blake2b(outlen=32, personal=b"ZcashOutputsHash"))
def add_input(self, txi: TxInput) -> None:
def add_input(self, txi: TxInput, script_pubkey: bytes) -> None:
write_bytes_reversed(self.h_prevouts, txi.prev_hash, TX_HASH_SIZE)
write_uint32(self.h_prevouts, txi.prev_index)
write_uint32(self.h_sequence, txi.sequence)

@ -53,6 +53,7 @@ def write_tx_input_check(w: Writer, i: TxInput) -> None:
write_uint32(w, n)
write_uint32(w, i.sequence)
write_uint64(w, i.amount or 0)
write_bytes_prefixed(w, i.script_pubkey or b"")
def write_tx_output(w: Writer, o: TxOutput | PrevOutput, script_pubkey: bytes) -> None:

@ -40,6 +40,7 @@ class TestApprover(unittest.TestCase):
prev_hash=b"",
prev_index=0,
amount=denomination + 1000000 * (i + 1),
script_pubkey=bytes(22),
script_type=InputScriptType.EXTERNAL,
sequence=0xffffffff,
) for i in range(99)

@ -50,16 +50,16 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
def test_prevouts(self):
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1)
bip143.add_input(self.inp2)
bip143.add_input(self.inp1, b"")
bip143.add_input(self.inp2, b"")
prevouts_hash = get_tx_hash(bip143.h_prevouts, double=coin.sign_hash_double)
self.assertEqual(hexlify(prevouts_hash), b'96b827c8483d4e9b96712b6713a7b68d6e8003a781feba36c31143470b4efd37')
def test_sequence(self):
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1)
bip143.add_input(self.inp2)
bip143.add_input(self.inp1, b"")
bip143.add_input(self.inp2, b"")
sequence_hash = get_tx_hash(bip143.h_sequence, double=coin.sign_hash_double)
self.assertEqual(hexlify(sequence_hash), b'52b0a642eea2fb7ae638c36f6252b6750293dbe574a806984b8e4d8548339a3b')
@ -82,8 +82,8 @@ class TestSegwitBip143NativeP2WPKH(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1)
bip143.add_input(self.inp2)
bip143.add_input(self.inp1, b"")
bip143.add_input(self.inp2, b"")
for txo in [self.out1, self.out2]:
script_pubkey = output_derive_script(txo.address, coin)

@ -42,14 +42,14 @@ class TestSegwitBip143(unittest.TestCase):
def test_bip143_prevouts(self):
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1)
bip143.add_input(self.inp1, b"")
prevouts_hash = get_tx_hash(bip143.h_prevouts, double=coin.sign_hash_double)
self.assertEqual(hexlify(prevouts_hash), b'b0287b4a252ac05af83d2dcef00ba313af78a3e9c329afa216eb3aa2a7b4613a')
def test_bip143_sequence(self):
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1)
bip143.add_input(self.inp1, b"")
sequence_hash = get_tx_hash(bip143.h_sequence, double=coin.sign_hash_double)
self.assertEqual(hexlify(sequence_hash), b'18606b350cd8bf565266bc352f0caddcf01e8fa789dd8a15386327cf8cabe198')
@ -70,7 +70,7 @@ class TestSegwitBip143(unittest.TestCase):
seed = bip39.seed('alcohol woman abuse must during monitor noble actual mixed trade anger aisle', '')
coin = coins.by_name(self.tx.coin_name)
bip143 = Bip143Hash()
bip143.add_input(self.inp1)
bip143.add_input(self.inp1, b"")
for txo in [self.out1, self.out2]:
script_pubkey = output_derive_script(txo.address, coin)
txo_bin = PrevOutput(amount=txo.amount, script_pubkey=script_pubkey)

@ -37,12 +37,13 @@ class TestWriters(unittest.TestCase):
prev_index=0,
script_type=InputScriptType.SPENDWITNESS,
sequence=0xffffffff,
script_pubkey=unhexlify("76a91424a56db43cf6f2b02e838ea493f95d8d6047423188ac"),
script_sig=b"0123456789",
)
b = bytearray()
writers.write_tx_input_check(b, inp)
self.assertEqual(len(b), 32 + 4 + 4 + 4 + 4 + 4 + 8)
self.assertEqual(len(b), 32 + 4 + 4 + 4 + 4 + 4 + 8 + 26)
for bad_prevhash in (b"", b"x", b"hello", b"x" * 33):
inp.prev_hash = bad_prevhash

@ -201,7 +201,7 @@ class TestZcashZip243(unittest.TestCase):
script_type = i["script_type"],
sequence = i["sequence"],
)
zip243.add_input(txi)
zip243.add_input(txi, b"")
for o in v["outputs"]:
txo = PrevOutput(

Loading…
Cancel
Save