feat(core): In apps.bitcoin implement replacement transaction flow.

pull/1331/head
Andrew Kozlik 4 years ago committed by Andrew Kozlik
parent bd3fe1d789
commit 4a0c5c371a

@ -1,15 +1,17 @@
from micropython import const
from trezor import wire
from trezor.messages import OutputScriptType
from apps.common import safety_checks
from .. import addresses
from ..authorization import FEE_PER_ANONYMITY_DECIMALS
from . import helpers, tx_weight
from .tx_info import TxInfo
from .tx_info import OriginalTxInfo, TxInfo
if False:
from typing import List, Optional
from trezor.messages.SignTx import SignTx
from trezor.messages.TxInput import TxInput
from trezor.messages.TxOutput import TxOutput
@ -34,25 +36,48 @@ class Approver:
self.total_out = 0 # sum of output amounts
self.change_out = 0 # sum of change output amounts
# amounts in original transactions when this is a replacement transaction
self.orig_total_in = 0 # sum of original input amounts
self.orig_external_in = 0 # sum of original external input amounts
self.orig_total_out = 0 # sum of original output amounts
self.orig_change_out = 0 # sum of original change output amounts
async def add_internal_input(self, txi: TxInput) -> None:
self.weight.add_input(txi)
self.total_in += txi.amount
if txi.orig_hash:
self.orig_total_in += txi.amount
def add_external_input(self, txi: TxInput) -> None:
self.weight.add_input(txi)
self.total_in += txi.amount
self.external_in += txi.amount
if txi.orig_hash:
self.orig_total_in += txi.amount
self.orig_external_in += txi.amount
def add_change_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
self.weight.add_output(script_pubkey)
self.total_out += txo.amount
self.change_out += txo.amount
async def add_external_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
def add_orig_change_output(self, txo: TxOutput) -> None:
self.orig_total_out += txo.amount
self.orig_change_out += txo.amount
async def add_external_output(
self,
txo: TxOutput,
script_pubkey: bytes,
orig_txo: Optional[TxOutput] = None,
) -> None:
self.weight.add_output(script_pubkey)
self.total_out += txo.amount
async def approve_tx(self, tx_info: TxInfo) -> None:
def add_orig_external_output(self, txo: TxOutput) -> None:
self.orig_total_out += txo.amount
async def approve_tx(self, tx_info: TxInfo, orig_txs: List[OriginalTxInfo]) -> None:
raise NotImplementedError
@ -74,11 +99,31 @@ class BasicApprover(Approver):
super().add_change_output(txo, script_pubkey)
self.change_count += 1
async def add_external_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
await super().add_external_output(txo, script_pubkey)
await helpers.confirm_output(txo, self.coin)
async def add_external_output(
self,
txo: TxOutput,
script_pubkey: bytes,
orig_txo: Optional[TxOutput] = None,
) -> None:
await super().add_external_output(txo, script_pubkey, orig_txo)
# Replacement transactions must not decrease the value of any external outputs.
if orig_txo and txo.amount < orig_txo.amount:
raise wire.ProcessError(
"Reducing original output amounts is not supported."
)
if self.orig_total_in:
# Skip output confirmation for replacement transactions,
# but don't allow adding new OP_RETURN outputs.
if txo.script_type == OutputScriptType.PAYTOOPRETURN and not orig_txo:
raise wire.ProcessError(
"Adding new OP_RETURN outputs in replacement transactions is not supported."
)
else:
await helpers.confirm_output(txo, self.coin)
async def approve_tx(self, tx_info: TxInfo) -> None:
async def approve_tx(self, tx_info: TxInfo, orig_txs: List[OriginalTxInfo]) -> None:
fee = self.total_in - self.total_out
# some coins require negative fees for reward TX
@ -95,16 +140,56 @@ class BasicApprover(Approver):
if fee > 10 * fee_threshold and safety_checks.is_strict():
raise wire.DataError("The fee is unexpectedly large")
await helpers.confirm_feeoverthreshold(fee, self.coin)
if self.change_count > self.MAX_SILENT_CHANGE_COUNT:
await helpers.confirm_change_count_over_threshold(self.change_count)
if tx_info.tx.lock_time > 0:
await helpers.confirm_nondefault_locktime(
tx_info.tx.lock_time, tx_info.lock_time_disabled()
if orig_txs:
# Replacement transaction.
orig_spending = (
self.orig_total_in - self.orig_change_out - self.orig_external_in
)
if not self.external_in:
await helpers.confirm_total(total, fee, self.coin)
orig_fee = self.orig_total_in - self.orig_total_out
# Replacement transactions are only allowed to make amendments which
# do not increase the amount that we are spending on external outputs.
# In other words, the total amount being sent out of the wallet must
# not increase by more than the fee difference (so additional funds
# can only go towards the fee, which is confirmed by the user).
if spending - orig_spending > fee - orig_fee:
raise wire.ProcessError("Invalid replacement transaction.")
# Replacement transactions must not change the effective nLockTime.
lock_time = 0 if tx_info.lock_time_disabled() else tx_info.tx.lock_time
for orig in orig_txs:
orig_lock_time = 0 if orig.lock_time_disabled() else orig.tx.lock_time
if lock_time != orig_lock_time:
raise wire.ProcessError(
"Original transactions must have same effective nLockTime as replacement transaction."
)
if self.external_in > self.orig_external_in:
description = "PayJoin"
elif len(orig_txs) > 1:
description = "Transaction meld"
else:
description = "Fee modification"
for orig in orig_txs:
await helpers.confirm_replacement(description, orig.orig_hash)
await helpers.confirm_modify_fee(spending - orig_spending, fee, self.coin)
else:
await helpers.confirm_joint_total(spending, total, self.coin)
# Standard transaction.
if tx_info.tx.lock_time > 0:
await helpers.confirm_nondefault_locktime(
tx_info.tx.lock_time, tx_info.lock_time_disabled()
)
if not self.external_in:
await helpers.confirm_total(total, fee, self.coin)
else:
await helpers.confirm_joint_total(spending, total, self.coin)
class CoinJoinApprover(Approver):
@ -147,11 +232,16 @@ class CoinJoinApprover(Approver):
self.our_weight.add_output(script_pubkey)
self.group_our_count += 1
async def add_external_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
await super().add_external_output(txo, script_pubkey)
async def add_external_output(
self,
txo: TxOutput,
script_pubkey: bytes,
orig_txo: Optional[TxOutput] = None,
) -> None:
await super().add_external_output(txo, script_pubkey, orig_txo)
self._add_output(txo, script_pubkey)
async def approve_tx(self, tx_info: TxInfo) -> None:
async def approve_tx(self, tx_info: TxInfo, orig_txs: List[OriginalTxInfo]) -> None:
# The mining fee of the transaction as a whole.
mining_fee = self.total_in - self.total_out

@ -16,7 +16,7 @@ from ..ownership import verify_nonownership
from ..verification import SignatureVerifier
from . import approvers, helpers, progress
from .hash143 import Hash143
from .tx_info import TxInfo
from .tx_info import OriginalTxInfo, TxInfo
if False:
from typing import List, Optional, Set, Tuple, Union
@ -47,7 +47,7 @@ class Bitcoin:
await self.step2_approve_outputs()
# Check fee, approve lock_time and total.
await self.approver.approve_tx(self.tx_info)
await self.approver.approve_tx(self.tx_info, self.orig_txs)
# Verify the transaction input amounts by requesting each previous transaction
# and checking its output amount. Verify external inputs which have already
@ -91,6 +91,12 @@ class Bitcoin:
self.tx_req.serialized = TxRequestSerializedType()
self.tx_req.serialized.serialized_tx = self.serialized_tx
# List of original transactions which are being replaced by the current transaction.
# Note: A List is better than a Dict of TXID -> OriginalTxInfo. Dict ordering is
# undefined so we would need to convert to a sorted list in several places to ensure
# stable device tests.
self.orig_txs = [] # type: List[OriginalTxInfo]
progress.init(tx.inputs_count, tx.outputs_count)
def create_hash_writer(self) -> HashWriter:
@ -115,14 +121,35 @@ class Bitcoin:
else:
await self.process_internal_input(txi)
if txi.orig_hash:
await self.process_original_input(txi)
self.tx_info.h_inputs = self.tx_info.h_tx_check.get_digest()
# Finalize original inputs.
for orig in self.orig_txs:
if orig.index != orig.tx.inputs_count:
raise wire.ProcessError("Removal of original inputs is not supported.")
orig.index = 0 # Reset counter for outputs.
async def step2_approve_outputs(self) -> None:
for i in range(self.tx_info.tx.outputs_count):
# STAGE_REQUEST_2_OUTPUT in legacy
txo = await helpers.request_tx_output(self.tx_req, i, self.coin)
script_pubkey = self.output_derive_script(txo)
await self.approve_output(txo, script_pubkey)
orig_txo = None # type: Optional[TxOutput]
if txo.orig_hash:
orig_txo = await self.get_original_output(txo, script_pubkey)
await self.approve_output(txo, script_pubkey, orig_txo)
# Finalize original outputs.
for orig in self.orig_txs:
# Fetch remaining removed original outputs.
await self.fetch_removed_original_outputs(
orig, orig.orig_hash, orig.tx.outputs_count
)
await orig.finalize_tx_hash()
async def step3_verify_inputs(self) -> None:
# should come out the same as h_inputs, checked before continuing
@ -146,6 +173,9 @@ class Bitcoin:
if h_check.get_digest() != self.tx_info.h_inputs:
raise wire.ProcessError("Transaction has changed during signing")
# verify the signature of one SIGHASH_ALL input in each original transaction
await self.verify_original_txs()
async def step4_serialize_inputs(self) -> None:
self.write_tx_header(self.serialized_tx, self.tx_info.tx, bool(self.segwit))
write_bitcoin_varint(self.serialized_tx, self.tx_info.tx.inputs_count)
@ -194,12 +224,153 @@ class Bitcoin:
async def process_external_input(self, txi: TxInput) -> None:
self.approver.add_external_input(txi)
async def approve_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
async def process_original_input(self, txi: TxInput) -> None:
assert txi.orig_hash is not None
assert txi.orig_index is not None
for orig in self.orig_txs:
if orig.orig_hash == txi.orig_hash:
break
else:
orig_meta = await helpers.request_tx_meta(
self.tx_req, self.coin, txi.orig_hash
)
orig = OriginalTxInfo(self, orig_meta, txi.orig_hash)
self.orig_txs.append(orig)
if txi.orig_index >= orig.tx.inputs_count:
raise wire.ProcessError("Not enough inputs in original transaction.")
if orig.index != txi.orig_index:
raise wire.ProcessError(
"Rearranging or removal of original inputs is not supported."
)
orig_txi = await helpers.request_tx_input(
self.tx_req, txi.orig_index, self.coin, txi.orig_hash
)
# Verify that the original input matches:
# An input is characterized by its prev_hash and prev_index. We also check that the
# amounts match, so that we don't have to call get_prevtx_output() twice for the same
# prevtx output. Verifying that script_type matches is just a sanity check, because
# because we count both inputs as internal or external based only on txi.script_type.
if (
orig_txi.prev_hash != txi.prev_hash
or orig_txi.prev_index != txi.prev_index
or orig_txi.amount != txi.amount
or orig_txi.script_type != txi.script_type
):
raise wire.ProcessError("Original input does not match current input.")
orig.add_input(orig_txi)
orig.index += 1
async def fetch_removed_original_outputs(
self, orig: OriginalTxInfo, orig_hash: bytes, last_index: int
) -> None:
while orig.index < last_index:
txo = await helpers.request_tx_output(
self.tx_req, orig.index, self.coin, orig_hash
)
orig.add_output(txo, self.output_derive_script(txo))
if orig.output_is_change(txo):
# Removal of change-outputs is allowed.
self.approver.add_orig_change_output(txo)
else:
# Removal of external outputs requires prompting the user. Not implemented.
raise wire.ProcessError(
"Removal of original external outputs is not supported."
)
orig.index += 1
async def get_original_output(
self, txo: TxOutput, script_pubkey: bytes
) -> TxOutput:
assert txo.orig_hash is not None
assert txo.orig_index is not None
for orig in self.orig_txs:
if orig.orig_hash == txo.orig_hash:
break
else:
raise wire.ProcessError("Unknown original transaction.")
if txo.orig_index >= orig.tx.outputs_count:
raise wire.ProcessError("Not enough outputs in original transaction.")
if orig.index > txo.orig_index:
raise wire.ProcessError("Rearranging of original outputs is not supported.")
# First fetch any removed original outputs which precede the one we want.
await self.fetch_removed_original_outputs(orig, txo.orig_hash, txo.orig_index)
orig_txo = await helpers.request_tx_output(
self.tx_req, orig.index, self.coin, txo.orig_hash
)
if script_pubkey != self.output_derive_script(orig_txo):
raise wire.ProcessError("Not an original output.")
if self.tx_info.output_is_change(txo) and not orig.output_is_change(orig_txo):
raise wire.ProcessError(
"Original output is missing change-output parameters."
)
orig.add_output(orig_txo, script_pubkey)
if orig.output_is_change(orig_txo):
self.approver.add_orig_change_output(orig_txo)
else:
self.approver.add_orig_external_output(orig_txo)
orig.index += 1
return orig_txo
async def verify_original_txs(self) -> None:
for orig in self.orig_txs:
if orig.verification_input is None:
raise wire.ProcessError(
"Each original transaction must specify address_n for at least one input."
)
assert orig.verification_index is not None
txi = orig.verification_input
node = self.keychain.derive(txi.address_n)
address = addresses.get_address(
txi.script_type, self.coin, node, txi.multisig
)
script_pubkey = scripts.output_derive_script(address, self.coin)
verifier = SignatureVerifier(
script_pubkey, txi.script_sig, txi.witness, self.coin
)
verifier.ensure_hash_type(SIGHASH_ALL)
tx_digest = await self.get_tx_digest(
orig.verification_index,
txi,
orig,
verifier.public_keys,
verifier.threshold,
script_pubkey,
)
verifier.verify(tx_digest)
async def approve_output(
self,
txo: TxOutput,
script_pubkey: bytes,
orig_txo: Optional[TxOutput],
) -> None:
if self.tx_info.output_is_change(txo):
# Output is change and does not need approval.
self.approver.add_change_output(txo, script_pubkey)
else:
await self.approver.add_external_output(txo, script_pubkey)
await self.approver.add_external_output(txo, script_pubkey, orig_txo)
self.tx_info.add_output(txo, script_pubkey)
@ -207,11 +378,10 @@ class Bitcoin:
self,
i: int,
txi: TxInput,
tx_info: TxInfo,
tx_info: Union[TxInfo, OriginalTxInfo],
public_keys: List[bytes],
threshold: int,
script_pubkey: bytes,
tx_hash: Optional[bytes] = None,
) -> bytes:
if txi.witness:
return tx_info.hash143.preimage_hash(
@ -223,9 +393,7 @@ class Bitcoin:
self.get_sighash_type(txi),
)
else:
digest, _, _ = await self.get_legacy_tx_digest(
i, tx, h_approved, script_pubkey, tx_hash
)
digest, _, _ = await self.get_legacy_tx_digest(i, tx_info, script_pubkey)
return digest
async def verify_external_input(
@ -330,8 +498,8 @@ class Bitcoin:
index: int,
tx_info: Union[TxInfo, OriginalTxInfo],
script_pubkey: Optional[bytes] = None,
tx_hash: Optional[bytes] = None,
) -> Tuple[bytes, TxInput, Optional[bip32.HDNode]]:
tx_hash = tx_info.orig_hash if isinstance(tx_info, OriginalTxInfo) else None
# the transaction digest which gets signed for this input
h_sign = self.create_hash_writer()

@ -88,6 +88,9 @@ class Decred(Bitcoin):
async def process_external_input(self, txi: TxInput) -> None:
raise wire.DataError("External inputs not supported")
async def process_original_input(self, txi: TxInput) -> None:
raise wire.DataError("Replacement transactions not supported")
async def approve_output(
self,
txo: TxOutput,

@ -334,6 +334,8 @@ def sanitize_tx_input(txi: TxInput, coin: CoinInfo) -> TxInput:
raise wire.DataError("Segwit not enabled on this coin.")
if txi.commitment_data and not txi.ownership_proof:
raise wire.DataError("commitment_data field provided but not expected.")
if txi.orig_hash and txi.orig_index is None:
raise wire.DataError("Missing orig_index field.")
return txi
@ -369,4 +371,6 @@ def sanitize_tx_output(txo: TxOutput, coin: CoinInfo) -> TxOutput:
raise wire.DataError("Both address and address_n provided.")
if not txo.address_n and not txo.address:
raise wire.DataError("Missing address")
if txo.orig_hash and txo.orig_index is None:
raise wire.DataError("Missing orig_index field.")
return txo

@ -133,3 +133,51 @@ class TxInfo(TxInfoBase):
# is used to ensure that the inputs streamed for verification in Step 3 are
# the same as those in Step 1.
self.h_inputs = None # type: Optional[bytes]
# Used to keep track of any original transactions which are being replaced by the current transaction.
class OriginalTxInfo(TxInfoBase):
def __init__(self, signer: Signer, tx: PrevTx, orig_hash: bytes) -> None:
super().__init__(signer)
self.tx = tx
self.signer = signer
self.orig_hash = orig_hash
# Index of the next input or output to be added by add_input or add_output. Signer uses this
# value to check that original transaction inputs and outputs are streamed in order, and to
# check whether any have been skipped. Incrementing and resetting this variable is the
# responsibility of the signer class.
self.index = 0
# Transaction hasher to compute the TXID.
self.h_tx = signer.create_hash_writer()
signer.write_tx_header(self.h_tx, tx, witness_marker=False)
writers.write_bitcoin_varint(self.h_tx, tx.inputs_count)
# The input which will be used for verification and its index in the original transaction.
self.verification_input = None # type: Optional[TxInput]
self.verification_index = None # type: Optional[int]
def add_input(self, txi: TxInput) -> None:
super().add_input(txi)
self.signer.write_tx_input(self.h_tx, txi, txi.script_sig or bytes())
# For verification use the first original input that specifies address_n.
if not self.verification_input and txi.address_n:
self.verification_input = txi
self.verification_index = self.index
def add_output(self, txo: TxOutput, script_pubkey: bytes) -> None:
super().add_output(txo, script_pubkey)
if self.index == 0:
writers.write_bitcoin_varint(self.h_tx, self.tx.outputs_count)
self.signer.write_tx_output(self.h_tx, txo, script_pubkey)
async def finalize_tx_hash(self) -> None:
await self.signer.write_prev_tx_footer(self.h_tx, self.tx, self.orig_hash)
if self.orig_hash != writers.get_tx_hash(
self.h_tx, double=self.signer.coin.sign_hash_double, reverse=True
):
raise wire.ProcessError("Invalid original TXID.")

@ -120,7 +120,7 @@ class TestApprover(unittest.TestCase):
else:
await_result(approver.add_external_output(txo, script_pubkey=bytes(22)))
await_result(approver.approve_tx(TxInfo(signer, tx)))
await_result(approver.approve_tx(TxInfo(signer, tx), []))
def test_coinjoin_input_account_depth_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth, None, self.coin)

Loading…
Cancel
Save