from common import H_, await_result, unittest # isort:skip import storage.cache_codec from trezor import wire from trezor.crypto import bip32 from trezor.enums import InputScriptType, OutputScriptType from trezor.messages import ( AuthorizeCoinJoin, CoinJoinRequest, SignTx, TxInput, TxOutput, ) from trezor.wire import context from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.sign_tx.approvers import CoinJoinApprover from apps.bitcoin.sign_tx.bitcoin import Bitcoin from apps.bitcoin.sign_tx.tx_info import TxInfo from apps.common import coins class TestApprover(unittest.TestCase): def setUpClass(self): context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) def tearDownClass(self): context.CURRENT_CONTEXT = None def setUp(self): self.coin = coins.by_name("Bitcoin") self.fee_rate_percent = 0.3 self.no_fee_threshold = 1000000 self.min_registrable_amount = 5000 self.coordinator_name = "www.example.com" self.node = bip32.HDNode( depth=0, fingerprint=0, child_num=0, chain_code=bytearray(32), private_key=b"\x01" * 32, curve_name="secp256k1", ) self.msg_auth = AuthorizeCoinJoin( coordinator=self.coordinator_name, max_rounds=10, max_coordinator_fee_rate=int(self.fee_rate_percent * 10**FEE_RATE_DECIMALS), max_fee_per_kvbyte=7000, address_n=[H_(10025), H_(0), H_(0), H_(1)], coin_name=self.coin.coin_name, script_type=InputScriptType.SPENDTAPROOT, ) storage.cache_codec.start_session() def make_coinjoin_request(self, inputs): return CoinJoinRequest( fee_rate=int(self.fee_rate_percent * 10**FEE_RATE_DECIMALS), no_fee_threshold=self.no_fee_threshold, min_registrable_amount=self.min_registrable_amount, mask_public_key=bytearray(), signature=bytearray(), ) def test_coinjoin_lots_of_inputs(self): denomination = 10_000_000 coordinator_fee = int(self.fee_rate_percent / 100 * denomination) fees = coordinator_fee + 500 # Other's inputs. inputs = [ TxInput( prev_hash=bytes(32), prev_index=0, amount=denomination, script_pubkey=bytes(22), script_type=InputScriptType.EXTERNAL, sequence=0xFFFFFFFF, witness="", ) for i in range(99) ] # Our input. inputs.insert( 30, TxInput( prev_hash=bytes(32), prev_index=0, address_n=[H_(10025), H_(0), H_(0), H_(1), 0, 1], amount=denomination, script_type=InputScriptType.SPENDTAPROOT, sequence=0xFFFFFFFF, ), ) # Other's CoinJoined outputs. outputs = [ TxOutput( address="", amount=denomination - fees, script_type=OutputScriptType.PAYTOTAPROOT, payment_req_index=0, ) for i in range(99) ] # Our CoinJoined output. outputs.insert( 40, TxOutput( address="", address_n=[H_(10025), H_(0), H_(0), H_(1), 0, 2], amount=denomination - fees, script_type=OutputScriptType.PAYTOTAPROOT, payment_req_index=0, ), ) # Coordinator's output. outputs.append( TxOutput( address="", amount=coordinator_fee * len(outputs), script_type=OutputScriptType.PAYTOTAPROOT, payment_req_index=0, ) ) coinjoin_req = self.make_coinjoin_request(inputs) tx = SignTx( outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req, ) authorization = CoinJoinAuthorization(self.msg_auth) approver = CoinJoinApprover(tx, self.coin, authorization) signer = Bitcoin(tx, None, self.coin, approver) tx_info = TxInfo(signer, tx) for txi in inputs: if txi.script_type == InputScriptType.EXTERNAL: approver.add_external_input(txi) else: await_result(approver.add_internal_input(txi, self.node)) for txo in outputs: if txo.address_n: await_result(approver.add_change_output(txo, script_pubkey=bytes(22))) else: await_result( approver.add_external_output( txo, script_pubkey=bytes(22), tx_info=tx_info ) ) await_result(approver.approve_tx(tx_info, [], None)) def test_coinjoin_input_account_depth_mismatch(self): txi = TxInput( prev_hash=bytes(32), prev_index=0, address_n=[H_(10025), H_(0), H_(0), H_(1), 0], amount=10000000, script_type=InputScriptType.SPENDTAPROOT, ) coinjoin_req = self.make_coinjoin_request([txi]) tx = SignTx( outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req, ) authorization = CoinJoinAuthorization(self.msg_auth) approver = CoinJoinApprover(tx, self.coin, authorization) with self.assertRaises(wire.ProcessError): await_result(approver.add_internal_input(txi, self.node)) def test_coinjoin_input_account_path_mismatch(self): txi = TxInput( prev_hash=bytes(32), prev_index=0, address_n=[H_(10025), H_(0), H_(1), H_(1), 0, 0], amount=10000000, script_type=InputScriptType.SPENDTAPROOT, ) coinjoin_req = self.make_coinjoin_request([txi]) tx = SignTx( outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req, ) authorization = CoinJoinAuthorization(self.msg_auth) approver = CoinJoinApprover(tx, self.coin, authorization) with self.assertRaises(wire.ProcessError): await_result(approver.add_internal_input(txi, self.node)) if __name__ == "__main__": unittest.main()