From 9b579094c096f4f95c2e61085ff2060f0bf740fc Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Sun, 31 Oct 2021 09:29:54 +0100 Subject: [PATCH] fix(core): Fix TxWeightCalculator. - Refactor TxWeightCalculator to count inputs and outputs itself. - Fix witness data weight by adding the weight of the witness stack item count for each input in segwit transactions and removing the weight of the nonsensical extra inputs count. - Get multisig pubkey count from multisig.nodes or multisig.pubkeys like in multisig_get_pubkeys(). - Fix size of multisig script length encoding in segwit (varint vs. OP_PUSH). - Improve comments. --- core/src/apps/bitcoin/sign_tx/approvers.py | 6 +- core/src/apps/bitcoin/sign_tx/tx_weight.py | 65 +++++++++++----------- core/tests/test_apps.bitcoin.txweight.py | 10 ++-- 3 files changed, 41 insertions(+), 40 deletions(-) diff --git a/core/src/apps/bitcoin/sign_tx/approvers.py b/core/src/apps/bitcoin/sign_tx/approvers.py index 56b67293d..7187c59cd 100644 --- a/core/src/apps/bitcoin/sign_tx/approvers.py +++ b/core/src/apps/bitcoin/sign_tx/approvers.py @@ -27,7 +27,7 @@ if False: class Approver: def __init__(self, tx: SignTx, coin: CoinInfo) -> None: self.coin = coin - self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count) + self.weight = tx_weight.TxWeightCalculator() # amounts in the current transaction self.total_in = 0 # sum of input amounts @@ -267,9 +267,7 @@ class CoinJoinApprover(Approver): raise wire.DataError("Coin name does not match authorization.") # Upper bound on the user's contribution to the weight of the transaction. - self.our_weight = tx_weight.TxWeightCalculator( - tx.inputs_count, tx.outputs_count - ) + self.our_weight = tx_weight.TxWeightCalculator() # base for coordinator fee to be multiplied by fee_per_anonymity self.coordinator_fee_base = 0 diff --git a/core/src/apps/bitcoin/sign_tx/tx_weight.py b/core/src/apps/bitcoin/sign_tx/tx_weight.py index 5401d9f05..a80a67c91 100644 --- a/core/src/apps/bitcoin/sign_tx/tx_weight.py +++ b/core/src/apps/bitcoin/sign_tx/tx_weight.py @@ -27,7 +27,7 @@ _TXSIZE_INPUT = const(40) _TXSIZE_OUTPUT = const(8) # size of a pubkey _TXSIZE_PUBKEY = const(33) -# size of a DER signature (3 type bytes, 3 len bytes, 33 R, 32 S, 1 sighash +# maximum size of a DER signature (3 type bytes, 3 len bytes, 33 R, 32 S, 1 sighash) _TXSIZE_SIGNATURE = const(72) # size of a multiscript without pubkey (1 M, 1 N, 1 checksig) _TXSIZE_MULTISIGSCRIPT = const(3) @@ -38,32 +38,26 @@ _TXSIZE_WITNESSSCRIPT = const(34) class TxWeightCalculator: - def __init__(self, inputs_count: int, outputs_count: int): - self.inputs_count = inputs_count - self.counter = 4 * ( - _TXSIZE_HEADER - + _TXSIZE_FOOTER - + self.ser_length_size(inputs_count) - + self.ser_length_size(outputs_count) - ) - self.segwit = False - - def add_witness_header(self) -> None: - if not self.segwit: - self.counter += _TXSIZE_SEGWIT_OVERHEAD - self.counter += self.ser_length_size(self.inputs_count) - self.segwit = True + def __init__(self) -> None: + self.inputs_count = 0 + self.outputs_count = 0 + self.counter = 4 * (_TXSIZE_HEADER + _TXSIZE_FOOTER) + self.segwit_inputs_count = 0 def add_input(self, i: TxInput) -> None: + self.inputs_count += 1 if i.multisig: - multisig_script_size = _TXSIZE_MULTISIGSCRIPT + len(i.multisig.pubkeys) * ( - 1 + _TXSIZE_PUBKEY - ) + n = len(i.multisig.nodes) if i.multisig.nodes else len(i.multisig.pubkeys) + multisig_script_size = _TXSIZE_MULTISIGSCRIPT + n * (1 + _TXSIZE_PUBKEY) + if i.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES: + multisig_script_size += self.varint_size(multisig_script_size) + else: + multisig_script_size += self.op_push_size(multisig_script_size) + input_script_size = ( - 1 - + i.multisig.m * (1 + _TXSIZE_SIGNATURE) # the OP_FALSE bug in multisig - + self.op_push_size(multisig_script_size) + 1 # the OP_FALSE bug in multisig + + i.multisig.m * (1 + _TXSIZE_SIGNATURE) + multisig_script_size ) else: @@ -72,29 +66,38 @@ class TxWeightCalculator: self.counter += 4 * _TXSIZE_INPUT if i.script_type in common.NONSEGWIT_INPUT_SCRIPT_TYPES: - input_script_size += self.ser_length_size(input_script_size) + input_script_size += self.varint_size(input_script_size) self.counter += 4 * input_script_size - elif i.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES: - self.add_witness_header() + self.segwit_inputs_count += 1 if i.script_type == InputScriptType.SPENDP2SHWITNESS: + # add script_sig size if i.multisig: self.counter += 4 * (2 + _TXSIZE_WITNESSSCRIPT) else: self.counter += 4 * (2 + _TXSIZE_WITNESSPKHASH) else: - self.counter += 4 # empty - self.counter += input_script_size # discounted witness + self.counter += 4 # empty script_sig (1 byte) + self.counter += 1 + input_script_size # discounted witness def add_output(self, script: bytes) -> None: - size = len(script) + self.ser_length_size(len(script)) - self.counter += 4 * (_TXSIZE_OUTPUT + size) + self.outputs_count += 1 + script_size = self.varint_size(len(script)) + len(script) + self.counter += 4 * (_TXSIZE_OUTPUT + script_size) def get_total(self) -> int: - return self.counter + total = self.counter + total += 4 * self.varint_size(self.inputs_count) + total += 4 * self.varint_size(self.outputs_count) + if self.segwit_inputs_count: + total += _TXSIZE_SEGWIT_OVERHEAD + # add one byte of witness stack item count per non-segwit input + total += self.inputs_count - self.segwit_inputs_count + + return total @staticmethod - def ser_length_size(length: int) -> int: + def varint_size(length: int) -> int: if length < 253: return 1 if length < 0x1_0000: diff --git a/core/tests/test_apps.bitcoin.txweight.py b/core/tests/test_apps.bitcoin.txweight.py index ad46b0b79..e864dc98d 100644 --- a/core/tests/test_apps.bitcoin.txweight.py +++ b/core/tests/test_apps.bitcoin.txweight.py @@ -32,7 +32,7 @@ class TestCalculateTxWeight(unittest.TestCase): address_n=[], multisig=None) - calculator = TxWeightCalculator(1, 1) + calculator = TxWeightCalculator() calculator.add_input(inp1) calculator.add_output(output_derive_script(out1.address, coin)) @@ -71,14 +71,14 @@ class TestCalculateTxWeight(unittest.TestCase): multisig=None, ) - calculator = TxWeightCalculator(1, 2) + calculator = TxWeightCalculator() calculator.add_input(inp1) calculator.add_output(output_derive_script(out1.address, coin)) calculator.add_output(output_derive_script(out2.address, coin)) self.assertEqual(calculator.get_total(), 670) # non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560 - # segwit: segwit header, witness count, 2x witness 1*(2+1+107) = 110 + # segwit: segwit header, witness stack item count, witness 1*(2+1+107) = 110 # total 670 def test_native_p2wpkh_txweight(self): @@ -111,14 +111,14 @@ class TestCalculateTxWeight(unittest.TestCase): multisig=None, ) - calculator = TxWeightCalculator(1, 2) + calculator = TxWeightCalculator() calculator.add_input(inp1) calculator.add_output(output_derive_script(out1.address, coin)) calculator.add_output(output_derive_script(out2.address, coin)) self.assertEqual(calculator.get_total(), 566) # non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456 - # segwit: segwit header, witness count, 2x witness 1*(2+1+107) = 110 + # segwit: segwit header, witness stack item count, witness 1*(2+1+107) = 110 # total 566