diff --git a/src/apps/wallet/sign_tx/tx_weight_calculator.py b/src/apps/wallet/sign_tx/tx_weight_calculator.py new file mode 100644 index 000000000..77ca66375 --- /dev/null +++ b/src/apps/wallet/sign_tx/tx_weight_calculator.py @@ -0,0 +1,109 @@ +# It assumes largest possible signature size for all inputs. For segwit +# multisig it can be .25 bytes off due to difference between segwit +# encoding (varint) vs. non-segwit encoding (op_push) of the multisig script. +# +# Heavily inspired by: +# https://github.com/trezor/trezor-mcu/commit/e1fa7af1da79e86ccaae5f3cd2a6c4644f546f8a + +from micropython import const + +from trezor.messages import InputScriptType +from trezor.messages.TxInputType import TxInputType + +# transaction header size: 4 byte version +_TXSIZE_HEADER = const(4) +# transaction footer size: 4 byte lock time +_TXSIZE_FOOTER = const(4) +# transaction segwit overhead 2 (marker, flag) +_TXSIZE_SEGWIT_OVERHEAD = const(2) + +# transaction input size (without script): 32 prevhash, 4 idx, 4 sequence +_TXSIZE_INPUT = const(40) +# transaction output size (without script): 8 amount +_TXSIZE_OUTPUT = const(8) +# size of a pubkey +_TXSIZE_PUBKEY = const(33) +# size of a DER signature (3 type bytes, 3 len bytes, 33 R, 32 S, 1 sighash +_TXSIZE_SIGNATURE = const(72) +# size of a multiscript without pubkey (1 M, 1 N, 1 checksig) +_TXSIZE_MULTISIGSCRIPT = const(3) +# size of a p2wpkh script (1 version, 1 push, 20 hash) +_TXSIZE_WITNESSPKHASH = const(22) +# size of a p2wsh script (1 version, 1 push, 32 hash) +_TXSIZE_WITNESSSCRIPT = const(34) + + +class TxWeightCalculator: + + def __init__(self, inputs_count: int, outputs_count: int): + self.inputs_count = inputs_count + self.counter = 4 * ( + _TXSIZE_HEADER + + _TXSIZE_FOOTER + + self.ser_length_size(inputs_count) + + self.ser_length_size(outputs_count)) + self.segwit = False + + def add_witness_header(self): + if not self.segwit: + self.counter += _TXSIZE_SEGWIT_OVERHEAD + self.counter += self.ser_length_size(self.inputs_count) + self.segwit = True + + def add_input(self, i: TxInputType): + + if i.multisig: + multisig_script_size = ( + _TXSIZE_MULTISIGSCRIPT + + i.multisig.pubkeys_count * (1 + _TXSIZE_PUBKEY)) + input_script_size = ( + 1 # the OP_FALSE bug in multisig + + i.multisig.m * (1 + _TXSIZE_SIGNATURE) + + self.op_push_size(multisig_script_size) + + multisig_script_size) + else: + input_script_size = 1 + _TXSIZE_SIGNATURE + 1 + _TXSIZE_PUBKEY + + self.counter += 4 * _TXSIZE_INPUT + + if (i.script_type == InputScriptType.SPENDADDRESS + or i.script_type == InputScriptType.SPENDMULTISIG): + input_script_size += self.ser_length_size(input_script_size) + self.counter += 4 * input_script_size + + elif (i.script_type == InputScriptType.SPENDWITNESS + or i.script_type == InputScriptType.SPENDP2SHWITNESS): + self.add_witness_header() + if i.script_type == InputScriptType.SPENDP2SHWITNESS: + if i.multisig: + self.counter += 4 * (2 + _TXSIZE_WITNESSSCRIPT) + else: + self.counter += 4 * (2 + _TXSIZE_WITNESSPKHASH) + else: + self.counter += 4 # empty + self.counter += input_script_size # discounted witness + + def add_output(self, script: bytes): + size = len(script) + self.ser_length_size(len(script)) + self.counter += 4 * (_TXSIZE_OUTPUT + size) + + def get_total(self) -> int: + return self.counter + + @staticmethod + def ser_length_size(length: int): + if length < 253: + return 1 + if length < 0x10000: + return 3 + return 5 + + @staticmethod + def op_push_size(length: int): + if length < 0x4c: + return 1 + if length < 0x100: + return 2 + if length < 0x10000: + return 3 + return 5 diff --git a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py index d9370b43a..a833844d8 100644 --- a/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py +++ b/tests/test_apps.wallet.segwit.signtx.native_p2wpkh.py @@ -41,7 +41,7 @@ class TestSignSegwitTxNativeP2WPKH(unittest.TestCase): address='2N4Q5FhU2497BryFfUgbqkAJE87aKHUhXMp', amount=5000000, script_type=OutputScriptType.PAYTOADDRESS, - address_n=None, # @todo ask honza about sanitizing + address_n=None, ) out2 = TxOutputType( address='tb1q694ccp5qcc0udmfwgp692u2s2hjpq5h407urtu', diff --git a/tests/test_apps.wallet.txweight.py b/tests/test_apps.wallet.txweight.py new file mode 100644 index 000000000..a99d5cdc8 --- /dev/null +++ b/tests/test_apps.wallet.txweight.py @@ -0,0 +1,131 @@ +from common import * + +from trezor.messages.TxOutputType import TxOutputType +from trezor.messages import OutputScriptType +from trezor.crypto import bip32, bip39 + +from apps.common import coins +from apps.wallet.sign_tx.tx_weight_calculator import * +from apps.wallet.sign_tx import signing + + +class TestCalculateTxWeight(unittest.TestCase): + # pylint: disable=C0301 + + def test_p2pkh_txweight(self): + + coin = coins.by_name('Bitcoin') + + seed = bip39.seed(' '.join(['all'] * 12), '') + root = bip32.from_seed(seed, 'secp256k1') + + inp1 = TxInputType(address_n=[0], # 14LmW5k4ssUrtbAB4255zdqv3b4w1TuX9e + # amount=390000, + prev_hash=unhexlify('d5f65ee80147b4bcc70b75e4bbf2d7382021b871bd8867ef8fa525ef50864882'), + prev_index=0, + amount=None, + script_type=InputScriptType.SPENDADDRESS, + sequence=None, + multisig=None) + out1 = TxOutputType(address='1MJ2tj2ThBE62zXbBYA5ZaN3fdve5CPAz1', + amount=390000 - 10000, + script_type=OutputScriptType.PAYTOADDRESS, + address_n=None, + multisig=None) + + calculator = TxWeightCalculator(1, 1) + calculator.add_input(inp1) + calculator.add_output(signing.output_derive_script(out1, coin, root)) + + serialized_tx = '010000000182488650ef25a58fef6788bd71b8212038d7f2bbe4750bc7bcb44701e85ef6d5000000006b4830450221009a0b7be0d4ed3146ee262b42202841834698bb3ee39c24e7437df208b8b7077102202b79ab1e7736219387dffe8d615bbdba87e11477104b867ef47afed1a5ede7810121023230848585885f63803a0a8aecdd6538792d5c539215c91698e315bf0253b43dffffffff0160cc0500000000001976a914de9b2a8da088824e8fe51debea566617d851537888ac00000000' + tx_weight = len(serialized_tx) / 2 * 4 # non-segwit tx's weight is simple length*4 + + self.assertEqual(calculator.get_total(), tx_weight) + + def test_p2wpkh_in_p2sh_txweight(self): + + coin = coins.by_name('Testnet') + + seed = bip39.seed(' '.join(['all'] * 12), '') + root = bip32.from_seed(seed, 'secp256k1') + + inp1 = TxInputType( + # 49'/1'/0'/1/0" - 2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX + address_n=[49 | 0x80000000, 1 | 0x80000000, 0 | 0x80000000, 1, 0], + amount=123456789, + prev_hash=unhexlify('20912f98ea3ed849042efed0fdac8cb4fc301961c5988cba56902d8ffb61c337'), + prev_index=0, + script_type=InputScriptType.SPENDP2SHWITNESS, + sequence=0xffffffff, + multisig=None, + ) + out1 = TxOutputType( + address='mhRx1CeVfaayqRwq5zgRQmD7W5aWBfD5mC', + amount=12300000, + script_type=OutputScriptType.PAYTOADDRESS, + address_n=None, + multisig=None, + ) + out2 = TxOutputType( + address='2N1LGaGg836mqSQqiuUBLfcyGBhyZbremDX', + script_type=OutputScriptType.PAYTOADDRESS, + amount=123456789 - 11000 - 12300000, + address_n=None, + multisig=None, + ) + + calculator = TxWeightCalculator(1, 2) + calculator.add_input(inp1) + calculator.add_output(signing.output_derive_script(out1, coin, root)) + calculator.add_output(signing.output_derive_script(out2, coin, root)) + + self.assertEqual(calculator.get_total(), 670) + # non-segwit: header, inputs, outputs, locktime 4*(4+65+67+4) = 560 + # segwit: segwit header, witness count, 2x witness 1*(2+1+107) = 110 + # total 670 + + def test_native_p2wpkh_txweight(self): + + coin = coins.by_name('Testnet') + + seed = bip39.seed(' '.join(['all'] * 12), '') + root = bip32.from_seed(seed, 'secp256k1') + + inp1 = TxInputType( + # 49'/1'/0'/0/0" - tb1qqzv60m9ajw8drqulta4ld4gfx0rdh82un5s65s + address_n=[49 | 0x80000000, 1 | 0x80000000, 0 | 0x80000000, 0, 0], + amount=12300000, + prev_hash=unhexlify('09144602765ce3dd8f4329445b20e3684e948709c5cdcaf12da3bb079c99448a'), + prev_index=0, + script_type=InputScriptType.SPENDWITNESS, + sequence=0xffffffff, + multisig=None, + ) + out1 = TxOutputType( + address='2N4Q5FhU2497BryFfUgbqkAJE87aKHUhXMp', + amount=5000000, + script_type=OutputScriptType.PAYTOADDRESS, + address_n=None, + multisig=None, + ) + out2 = TxOutputType( + address='tb1q694ccp5qcc0udmfwgp692u2s2hjpq5h407urtu', + script_type=OutputScriptType.PAYTOADDRESS, + amount=12300000 - 11000 - 5000000, + address_n=None, + multisig=None, + ) + + calculator = TxWeightCalculator(1, 2) + calculator.add_input(inp1) + calculator.add_output(signing.output_derive_script(out1, coin, root)) + calculator.add_output(signing.output_derive_script(out2, coin, root)) + + self.assertEqual(calculator.get_total(), 566) + # non-segwit: header, inputs, outputs, locktime 4*(4+42+64+4) = 456 + # segwit: segwit header, witness count, 2x witness 1*(2+1+107) = 110 + # total 566 + + +if __name__ == '__main__': + unittest.main()