mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-27 06:42:02 +00:00
fix(core): Fix TxWeightCalculator.
- Refactor TxWeightCalculator to count inputs and outputs itself. - Fix witness data weight by adding the weight of the witness stack item count for each input in segwit transactions and removing the weight of the nonsensical extra inputs count. - Get multisig pubkey count from multisig.nodes or multisig.pubkeys like in multisig_get_pubkeys(). - Fix size of multisig script length encoding in segwit (varint vs. OP_PUSH). - Improve comments.
This commit is contained in:
parent
a5bd1643fc
commit
9b579094c0
@ -27,7 +27,7 @@ if False:
|
||||
class Approver:
|
||||
def __init__(self, tx: SignTx, coin: CoinInfo) -> None:
|
||||
self.coin = coin
|
||||
self.weight = tx_weight.TxWeightCalculator(tx.inputs_count, tx.outputs_count)
|
||||
self.weight = tx_weight.TxWeightCalculator()
|
||||
|
||||
# amounts in the current transaction
|
||||
self.total_in = 0 # sum of input amounts
|
||||
@ -267,9 +267,7 @@ class CoinJoinApprover(Approver):
|
||||
raise wire.DataError("Coin name does not match authorization.")
|
||||
|
||||
# Upper bound on the user's contribution to the weight of the transaction.
|
||||
self.our_weight = tx_weight.TxWeightCalculator(
|
||||
tx.inputs_count, tx.outputs_count
|
||||
)
|
||||
self.our_weight = tx_weight.TxWeightCalculator()
|
||||
|
||||
# base for coordinator fee to be multiplied by fee_per_anonymity
|
||||
self.coordinator_fee_base = 0
|
||||
|
@ -27,7 +27,7 @@ _TXSIZE_INPUT = const(40)
|
||||
_TXSIZE_OUTPUT = const(8)
|
||||
# size of a pubkey
|
||||
_TXSIZE_PUBKEY = const(33)
|
||||
# size of a DER signature (3 type bytes, 3 len bytes, 33 R, 32 S, 1 sighash
|
||||
# maximum size of a DER signature (3 type bytes, 3 len bytes, 33 R, 32 S, 1 sighash)
|
||||
_TXSIZE_SIGNATURE = const(72)
|
||||
# size of a multiscript without pubkey (1 M, 1 N, 1 checksig)
|
||||
_TXSIZE_MULTISIGSCRIPT = const(3)
|
||||
@ -38,32 +38,26 @@ _TXSIZE_WITNESSSCRIPT = const(34)
|
||||
|
||||
|
||||
class TxWeightCalculator:
|
||||
def __init__(self, inputs_count: int, outputs_count: int):
|
||||
self.inputs_count = inputs_count
|
||||
self.counter = 4 * (
|
||||
_TXSIZE_HEADER
|
||||
+ _TXSIZE_FOOTER
|
||||
+ self.ser_length_size(inputs_count)
|
||||
+ self.ser_length_size(outputs_count)
|
||||
)
|
||||
self.segwit = False
|
||||
|
||||
def add_witness_header(self) -> None:
|
||||
if not self.segwit:
|
||||
self.counter += _TXSIZE_SEGWIT_OVERHEAD
|
||||
self.counter += self.ser_length_size(self.inputs_count)
|
||||
self.segwit = True
|
||||
def __init__(self) -> None:
|
||||
self.inputs_count = 0
|
||||
self.outputs_count = 0
|
||||
self.counter = 4 * (_TXSIZE_HEADER + _TXSIZE_FOOTER)
|
||||
self.segwit_inputs_count = 0
|
||||
|
||||
def add_input(self, i: TxInput) -> None:
|
||||
self.inputs_count += 1
|
||||
|
||||
if i.multisig:
|
||||
multisig_script_size = _TXSIZE_MULTISIGSCRIPT + len(i.multisig.pubkeys) * (
|
||||
1 + _TXSIZE_PUBKEY
|
||||
)
|
||||
n = len(i.multisig.nodes) if i.multisig.nodes else len(i.multisig.pubkeys)
|
||||
multisig_script_size = _TXSIZE_MULTISIGSCRIPT + n * (1 + _TXSIZE_PUBKEY)
|
||||
if i.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES:
|
||||
multisig_script_size += self.varint_size(multisig_script_size)
|
||||
else:
|
||||
multisig_script_size += self.op_push_size(multisig_script_size)
|
||||
|
||||
input_script_size = (
|
||||
1
|
||||
+ i.multisig.m * (1 + _TXSIZE_SIGNATURE) # the OP_FALSE bug in multisig
|
||||
+ self.op_push_size(multisig_script_size)
|
||||
1 # the OP_FALSE bug in multisig
|
||||
+ i.multisig.m * (1 + _TXSIZE_SIGNATURE)
|
||||
+ multisig_script_size
|
||||
)
|
||||
else:
|
||||
@ -72,29 +66,38 @@ class TxWeightCalculator:
|
||||
self.counter += 4 * _TXSIZE_INPUT
|
||||
|
||||
if i.script_type in common.NONSEGWIT_INPUT_SCRIPT_TYPES:
|
||||
input_script_size += self.ser_length_size(input_script_size)
|
||||
input_script_size += self.varint_size(input_script_size)
|
||||
self.counter += 4 * input_script_size
|
||||
|
||||
elif i.script_type in common.SEGWIT_INPUT_SCRIPT_TYPES:
|
||||
self.add_witness_header()
|
||||
self.segwit_inputs_count += 1
|
||||
if i.script_type == InputScriptType.SPENDP2SHWITNESS:
|
||||
# add script_sig size
|
||||
if i.multisig:
|
||||
self.counter += 4 * (2 + _TXSIZE_WITNESSSCRIPT)
|
||||
else:
|
||||
self.counter += 4 * (2 + _TXSIZE_WITNESSPKHASH)
|
||||
else:
|
||||
self.counter += 4 # empty
|
||||
self.counter += input_script_size # discounted witness
|
||||
self.counter += 4 # empty script_sig (1 byte)
|
||||
self.counter += 1 + input_script_size # discounted witness
|
||||
|
||||
def add_output(self, script: bytes) -> None:
|
||||
size = len(script) + self.ser_length_size(len(script))
|
||||
self.counter += 4 * (_TXSIZE_OUTPUT + size)
|
||||
self.outputs_count += 1
|
||||
script_size = self.varint_size(len(script)) + len(script)
|
||||
self.counter += 4 * (_TXSIZE_OUTPUT + script_size)
|
||||
|
||||
def get_total(self) -> int:
|
||||
return self.counter
|
||||
total = self.counter
|
||||
total += 4 * self.varint_size(self.inputs_count)
|
||||
total += 4 * self.varint_size(self.outputs_count)
|
||||
if self.segwit_inputs_count:
|
||||
total += _TXSIZE_SEGWIT_OVERHEAD
|
||||
# add one byte of witness stack item count per non-segwit input
|
||||
total += self.inputs_count - self.segwit_inputs_count
|
||||
|
||||
return total
|
||||
|
||||
@staticmethod
|
||||
def ser_length_size(length: int) -> int:
|
||||
def varint_size(length: int) -> int:
|
||||
if length < 253:
|
||||
return 1
|
||||
if length < 0x1_0000:
|
||||
|
@ -32,7 +32,7 @@ class TestCalculateTxWeight(unittest.TestCase):
|
||||
address_n=[],
|
||||
multisig=None)
|
||||
|
||||
calculator = TxWeightCalculator(1, 1)
|
||||
calculator = TxWeightCalculator()
|
||||
calculator.add_input(inp1)
|
||||
calculator.add_output(output_derive_script(out1.address, coin))
|
||||
|
||||
@ -71,14 +71,14 @@ class TestCalculateTxWeight(unittest.TestCase):
|
||||
multisig=None,
|
||||
)
|
||||
|
||||
calculator = TxWeightCalculator(1, 2)
|
||||
calculator = TxWeightCalculator()
|
||||
calculator.add_input(inp1)
|
||||
calculator.add_output(output_derive_script(out1.address, coin))
|
||||
calculator.add_output(output_derive_script(out2.address, coin))
|
||||
|
||||
self.assertEqual(calculator.get_total(), 670)
|
||||
# non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560
|
||||
# segwit: segwit header, witness count, 2x witness 1*(2+1+107) = 110
|
||||
# segwit: segwit header, witness stack item count, witness 1*(2+1+107) = 110
|
||||
# total 670
|
||||
|
||||
def test_native_p2wpkh_txweight(self):
|
||||
@ -111,14 +111,14 @@ class TestCalculateTxWeight(unittest.TestCase):
|
||||
multisig=None,
|
||||
)
|
||||
|
||||
calculator = TxWeightCalculator(1, 2)
|
||||
calculator = TxWeightCalculator()
|
||||
calculator.add_input(inp1)
|
||||
calculator.add_output(output_derive_script(out1.address, coin))
|
||||
calculator.add_output(output_derive_script(out2.address, coin))
|
||||
|
||||
self.assertEqual(calculator.get_total(), 566)
|
||||
# non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456
|
||||
# segwit: segwit header, witness count, 2x witness 1*(2+1+107) = 110
|
||||
# segwit: segwit header, witness stack item count, witness 1*(2+1+107) = 110
|
||||
# total 566
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user