1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 12:28:09 +00:00

tests(core): Implement CoinJoin requests in unit tests.

This commit is contained in:
Andrew Kozlik 2022-10-19 18:26:02 +02:00 committed by Andrew Kozlik
parent 1df65d1a0c
commit 7a02be077f

View File

@ -2,13 +2,14 @@ from common import unittest, await_result, H_
import storage.cache import storage.cache
from trezor import wire from trezor import wire
from trezor.crypto.curve import secp256k1 from trezor.crypto import bip32
from trezor.crypto.curve import bip340, secp256k1
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.messages import AuthorizeCoinJoin from trezor.messages import AuthorizeCoinJoin
from trezor.messages import TxInput from trezor.messages import TxInput
from trezor.messages import TxOutput from trezor.messages import TxOutput
from trezor.messages import SignTx from trezor.messages import SignTx
from trezor.messages import TxAckPaymentRequest from trezor.messages import CoinJoinRequest
from trezor.enums import InputScriptType, OutputScriptType from trezor.enums import InputScriptType, OutputScriptType
from trezor.utils import HashWriter from trezor.utils import HashWriter
@ -24,29 +25,82 @@ class TestApprover(unittest.TestCase):
def setUp(self): def setUp(self):
self.coin = coins.by_name('Bitcoin') self.coin = coins.by_name('Bitcoin')
self.max_fee_rate_percent = 0.3 self.fee_rate_percent = 0.3
self.no_fee_threshold=1000000
self.min_registrable_amount=5000
self.coordinator_name = "www.example.com" self.coordinator_name = "www.example.com"
# Private key for signing and masking CoinJoin requests.
# m/0h for "all all ... all" seed.
self.private_key = b'?S\ti\x8b\xc5o{,\xab\x03\x194\xea\xa8[_:\xeb\xdf\xce\xef\xe50\xf17D\x98`\xb9dj'
self.node = bip32.HDNode(
depth=0,
fingerprint=0,
child_num=0,
chain_code=bytearray(32),
private_key=b"\x01" * 32,
curve_name="secp256k1",
)
self.tweaked_node_pubkey = b"\x02" + bip340.tweak_public_key(self.node.public_key()[1:])
self.msg_auth = AuthorizeCoinJoin( self.msg_auth = AuthorizeCoinJoin(
coordinator=self.coordinator_name, coordinator=self.coordinator_name,
max_rounds=10, max_rounds=10,
max_coordinator_fee_rate=int(self.max_fee_rate_percent * 10**8), max_coordinator_fee_rate=int(self.fee_rate_percent * 10**8),
max_fee_per_kvbyte=7000, max_fee_per_kvbyte=7000,
address_n=[H_(84), H_(0), H_(0)], address_n=[H_(10025), H_(0), H_(0), H_(1)],
coin_name=self.coin.coin_name, coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDWITNESS, script_type=InputScriptType.SPENDTAPROOT,
) )
storage.cache.start_session() storage.cache.start_session()
def make_coinjoin_request(self, inputs):
mask_public_key = secp256k1.publickey(self.private_key)
coinjoin_flags = bytearray()
for txi in inputs:
shared_secret = secp256k1.multiply(self.private_key, self.tweaked_node_pubkey)[1:33]
h_mask = HashWriter(sha256())
writers.write_bytes_fixed(h_mask, shared_secret, 32)
writers.write_bytes_reversed(h_mask, txi.prev_hash, writers.TX_HASH_SIZE)
writers.write_uint32(h_mask, txi.prev_index)
mask = h_mask.get_digest()[0] & 1
signable = txi.script_type == InputScriptType.SPENDTAPROOT
txi.coinjoin_flags = signable ^ mask
coinjoin_flags.append(txi.coinjoin_flags)
# Compute CoinJoin request signature.
h_request = HashWriter(sha256(b"CJR1"))
writers.write_bytes_prefixed(
h_request, self.coordinator_name.encode()
)
writers.write_uint32(h_request, self.coin.slip44)
writers.write_uint32(h_request, int(self.fee_rate_percent * 10**8))
writers.write_uint64(h_request, self.no_fee_threshold)
writers.write_uint64(h_request, self.min_registrable_amount)
writers.write_bytes_fixed(h_request, mask_public_key, 33)
writers.write_bytes_prefixed(h_request, coinjoin_flags)
writers.write_bytes_fixed(h_request, sha256().digest(), 32)
writers.write_bytes_fixed(h_request, sha256().digest(), 32)
signature = secp256k1.sign(self.private_key, h_request.get_digest())
return CoinJoinRequest(
fee_rate=int(self.fee_rate_percent * 10**8),
no_fee_threshold=self.no_fee_threshold,
min_registrable_amount=self.min_registrable_amount,
mask_public_key=mask_public_key,
signature=signature,
)
def test_coinjoin_lots_of_inputs(self): def test_coinjoin_lots_of_inputs(self):
denomination = 10000000 denomination = 10_000_000
coordinator_fee = int(self.max_fee_rate_percent / 100 * denomination) coordinator_fee = int(self.fee_rate_percent / 100 * denomination)
fees = coordinator_fee + 500 fees = coordinator_fee + 500
# Other's inputs. # Other's inputs.
inputs = [ inputs = [
TxInput( TxInput(
prev_hash=b"", prev_hash=bytes(32),
prev_index=0, prev_index=0,
amount=denomination, amount=denomination,
script_pubkey=bytes(22), script_pubkey=bytes(22),
@ -59,11 +113,11 @@ class TestApprover(unittest.TestCase):
# Our input. # Our input.
inputs.insert(30, inputs.insert(30,
TxInput( TxInput(
prev_hash=b"", prev_hash=bytes(32),
prev_index=0, prev_index=0,
address_n=[H_(84), H_(0), H_(0), 0, 1], address_n=[H_(10025), H_(0), H_(0), H_(1), 0, 1],
amount=denomination, amount=denomination,
script_type=InputScriptType.SPENDWITNESS, script_type=InputScriptType.SPENDTAPROOT,
sequence=0xffffffff, sequence=0xffffffff,
) )
) )
@ -83,7 +137,7 @@ class TestApprover(unittest.TestCase):
40, 40,
TxOutput( TxOutput(
address="", address="",
address_n=[H_(84), H_(0), H_(0), 0, 2], address_n=[H_(10025), H_(0), H_(0), H_(1), 0, 2],
amount=denomination-fees, amount=denomination-fees,
script_type=OutputScriptType.PAYTOWITNESS, script_type=OutputScriptType.PAYTOWITNESS,
payment_req_index=0, payment_req_index=0,
@ -100,39 +154,18 @@ class TestApprover(unittest.TestCase):
) )
) )
coinjoin_req = self.make_coinjoin_request(inputs)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req)
authorization = CoinJoinAuthorization(self.msg_auth) authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization) approver = CoinJoinApprover(tx, self.coin, authorization)
signer = Bitcoin(tx, None, self.coin, approver) signer = Bitcoin(tx, None, self.coin, approver)
# Compute payment request signature.
# Private key of m/0h for "all all ... all" seed.
private_key = b'?S\ti\x8b\xc5o{,\xab\x03\x194\xea\xa8[_:\xeb\xdf\xce\xef\xe50\xf17D\x98`\xb9dj'
h_pr = HashWriter(sha256())
writers.write_bytes_fixed(h_pr, b"SL\x00\x24", 4)
writers.write_bytes_prefixed(h_pr, b"") # Empty nonce.
writers.write_bytes_prefixed(h_pr, self.coordinator_name.encode())
writers.write_compact_size(h_pr, 0) # No memos.
writers.write_uint32(h_pr, self.coin.slip44)
h_outputs = HashWriter(sha256())
for txo in outputs:
writers.write_uint64(h_outputs, txo.amount)
writers.write_bytes_prefixed(h_outputs, txo.address.encode())
writers.write_bytes_fixed(h_pr, h_outputs.get_digest(), 32)
signature = secp256k1.sign(private_key, h_pr.get_digest())
tx_ack_payment_req = TxAckPaymentRequest(
recipient_name=self.coordinator_name,
signature=signature,
)
for txi in inputs: for txi in inputs:
if txi.script_type == InputScriptType.EXTERNAL: if txi.script_type == InputScriptType.EXTERNAL:
approver.add_external_input(txi) approver.add_external_input(txi)
else: else:
await_result(approver.add_internal_input(txi)) await_result(approver.add_internal_input(txi, self.node))
await_result(approver.add_payment_request(tx_ack_payment_req, None))
for txo in outputs: for txo in outputs:
if txo.address_n: if txo.address_n:
approver.add_change_output(txo, script_pubkey=bytes(22)) approver.add_change_output(txo, script_pubkey=bytes(22))
@ -142,36 +175,38 @@ class TestApprover(unittest.TestCase):
await_result(approver.approve_tx(TxInfo(signer, tx), [])) await_result(approver.approve_tx(TxInfo(signer, tx), []))
def test_coinjoin_input_account_depth_mismatch(self): def test_coinjoin_input_account_depth_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
txi = TxInput( txi = TxInput(
prev_hash=b"", prev_hash=bytes(32),
prev_index=0, prev_index=0,
address_n=[H_(49), H_(0), H_(0), 0], address_n=[H_(10025), H_(0), H_(0), H_(1), 0],
amount=10000000, amount=10000000,
script_type=InputScriptType.SPENDWITNESS script_type=InputScriptType.SPENDTAPROOT
) )
coinjoin_req = self.make_coinjoin_request([txi])
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req)
authorization = CoinJoinAuthorization(self.msg_auth)
approver = CoinJoinApprover(tx, self.coin, authorization)
with self.assertRaises(wire.ProcessError): with self.assertRaises(wire.ProcessError):
await_result(approver.add_internal_input(txi)) await_result(approver.add_internal_input(txi, self.node))
def test_coinjoin_input_account_path_mismatch(self): def test_coinjoin_input_account_path_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
txi = TxInput( txi = TxInput(
prev_hash=b"", prev_hash=bytes(32),
prev_index=0, prev_index=0,
address_n=[H_(49), H_(0), H_(0), 0, 2], address_n=[H_(10025), H_(0), H_(1), H_(1), 0, 0],
amount=10000000, amount=10000000,
script_type=InputScriptType.SPENDWITNESS script_type=InputScriptType.SPENDTAPROOT
) )
coinjoin_req = self.make_coinjoin_request([txi])
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req)
authorization = CoinJoinAuthorization(self.msg_auth)
approver = CoinJoinApprover(tx, self.coin, authorization)
with self.assertRaises(wire.ProcessError): with self.assertRaises(wire.ProcessError):
await_result(approver.add_internal_input(txi)) await_result(approver.add_internal_input(txi, self.node))
if __name__ == '__main__': if __name__ == '__main__':