1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-18 12:28:09 +00:00

tests(core): Implement CoinJoin requests in unit tests.

This commit is contained in:
Andrew Kozlik 2022-10-19 18:26:02 +02:00 committed by Andrew Kozlik
parent 1df65d1a0c
commit 7a02be077f

View File

@ -2,13 +2,14 @@ from common import unittest, await_result, H_
import storage.cache
from trezor import wire
from trezor.crypto.curve import secp256k1
from trezor.crypto import bip32
from trezor.crypto.curve import bip340, secp256k1
from trezor.crypto.hashlib import sha256
from trezor.messages import AuthorizeCoinJoin
from trezor.messages import TxInput
from trezor.messages import TxOutput
from trezor.messages import SignTx
from trezor.messages import TxAckPaymentRequest
from trezor.messages import CoinJoinRequest
from trezor.enums import InputScriptType, OutputScriptType
from trezor.utils import HashWriter
@ -24,29 +25,82 @@ class TestApprover(unittest.TestCase):
def setUp(self):
self.coin = coins.by_name('Bitcoin')
self.max_fee_rate_percent = 0.3
self.fee_rate_percent = 0.3
self.no_fee_threshold=1000000
self.min_registrable_amount=5000
self.coordinator_name = "www.example.com"
# Private key for signing and masking CoinJoin requests.
# m/0h for "all all ... all" seed.
self.private_key = b'?S\ti\x8b\xc5o{,\xab\x03\x194\xea\xa8[_:\xeb\xdf\xce\xef\xe50\xf17D\x98`\xb9dj'
self.node = bip32.HDNode(
depth=0,
fingerprint=0,
child_num=0,
chain_code=bytearray(32),
private_key=b"\x01" * 32,
curve_name="secp256k1",
)
self.tweaked_node_pubkey = b"\x02" + bip340.tweak_public_key(self.node.public_key()[1:])
self.msg_auth = AuthorizeCoinJoin(
coordinator=self.coordinator_name,
max_rounds=10,
max_coordinator_fee_rate=int(self.max_fee_rate_percent * 10**8),
max_coordinator_fee_rate=int(self.fee_rate_percent * 10**8),
max_fee_per_kvbyte=7000,
address_n=[H_(84), H_(0), H_(0)],
address_n=[H_(10025), H_(0), H_(0), H_(1)],
coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDWITNESS,
script_type=InputScriptType.SPENDTAPROOT,
)
storage.cache.start_session()
def make_coinjoin_request(self, inputs):
mask_public_key = secp256k1.publickey(self.private_key)
coinjoin_flags = bytearray()
for txi in inputs:
shared_secret = secp256k1.multiply(self.private_key, self.tweaked_node_pubkey)[1:33]
h_mask = HashWriter(sha256())
writers.write_bytes_fixed(h_mask, shared_secret, 32)
writers.write_bytes_reversed(h_mask, txi.prev_hash, writers.TX_HASH_SIZE)
writers.write_uint32(h_mask, txi.prev_index)
mask = h_mask.get_digest()[0] & 1
signable = txi.script_type == InputScriptType.SPENDTAPROOT
txi.coinjoin_flags = signable ^ mask
coinjoin_flags.append(txi.coinjoin_flags)
# Compute CoinJoin request signature.
h_request = HashWriter(sha256(b"CJR1"))
writers.write_bytes_prefixed(
h_request, self.coordinator_name.encode()
)
writers.write_uint32(h_request, self.coin.slip44)
writers.write_uint32(h_request, int(self.fee_rate_percent * 10**8))
writers.write_uint64(h_request, self.no_fee_threshold)
writers.write_uint64(h_request, self.min_registrable_amount)
writers.write_bytes_fixed(h_request, mask_public_key, 33)
writers.write_bytes_prefixed(h_request, coinjoin_flags)
writers.write_bytes_fixed(h_request, sha256().digest(), 32)
writers.write_bytes_fixed(h_request, sha256().digest(), 32)
signature = secp256k1.sign(self.private_key, h_request.get_digest())
return CoinJoinRequest(
fee_rate=int(self.fee_rate_percent * 10**8),
no_fee_threshold=self.no_fee_threshold,
min_registrable_amount=self.min_registrable_amount,
mask_public_key=mask_public_key,
signature=signature,
)
def test_coinjoin_lots_of_inputs(self):
denomination = 10000000
coordinator_fee = int(self.max_fee_rate_percent / 100 * denomination)
denomination = 10_000_000
coordinator_fee = int(self.fee_rate_percent / 100 * denomination)
fees = coordinator_fee + 500
# Other's inputs.
inputs = [
TxInput(
prev_hash=b"",
prev_hash=bytes(32),
prev_index=0,
amount=denomination,
script_pubkey=bytes(22),
@ -59,11 +113,11 @@ class TestApprover(unittest.TestCase):
# Our input.
inputs.insert(30,
TxInput(
prev_hash=b"",
prev_hash=bytes(32),
prev_index=0,
address_n=[H_(84), H_(0), H_(0), 0, 1],
address_n=[H_(10025), H_(0), H_(0), H_(1), 0, 1],
amount=denomination,
script_type=InputScriptType.SPENDWITNESS,
script_type=InputScriptType.SPENDTAPROOT,
sequence=0xffffffff,
)
)
@ -83,7 +137,7 @@ class TestApprover(unittest.TestCase):
40,
TxOutput(
address="",
address_n=[H_(84), H_(0), H_(0), 0, 2],
address_n=[H_(10025), H_(0), H_(0), H_(1), 0, 2],
amount=denomination-fees,
script_type=OutputScriptType.PAYTOWITNESS,
payment_req_index=0,
@ -100,39 +154,18 @@ class TestApprover(unittest.TestCase):
)
)
coinjoin_req = self.make_coinjoin_request(inputs)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req)
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=len(outputs), inputs_count=len(inputs), coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
signer = Bitcoin(tx, None, self.coin, approver)
# Compute payment request signature.
# Private key of m/0h for "all all ... all" seed.
private_key = b'?S\ti\x8b\xc5o{,\xab\x03\x194\xea\xa8[_:\xeb\xdf\xce\xef\xe50\xf17D\x98`\xb9dj'
h_pr = HashWriter(sha256())
writers.write_bytes_fixed(h_pr, b"SL\x00\x24", 4)
writers.write_bytes_prefixed(h_pr, b"") # Empty nonce.
writers.write_bytes_prefixed(h_pr, self.coordinator_name.encode())
writers.write_compact_size(h_pr, 0) # No memos.
writers.write_uint32(h_pr, self.coin.slip44)
h_outputs = HashWriter(sha256())
for txo in outputs:
writers.write_uint64(h_outputs, txo.amount)
writers.write_bytes_prefixed(h_outputs, txo.address.encode())
writers.write_bytes_fixed(h_pr, h_outputs.get_digest(), 32)
signature = secp256k1.sign(private_key, h_pr.get_digest())
tx_ack_payment_req = TxAckPaymentRequest(
recipient_name=self.coordinator_name,
signature=signature,
)
for txi in inputs:
if txi.script_type == InputScriptType.EXTERNAL:
approver.add_external_input(txi)
else:
await_result(approver.add_internal_input(txi))
await_result(approver.add_internal_input(txi, self.node))
await_result(approver.add_payment_request(tx_ack_payment_req, None))
for txo in outputs:
if txo.address_n:
approver.add_change_output(txo, script_pubkey=bytes(22))
@ -142,36 +175,38 @@ class TestApprover(unittest.TestCase):
await_result(approver.approve_tx(TxInfo(signer, tx), []))
def test_coinjoin_input_account_depth_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
txi = TxInput(
prev_hash=b"",
prev_hash=bytes(32),
prev_index=0,
address_n=[H_(49), H_(0), H_(0), 0],
address_n=[H_(10025), H_(0), H_(0), H_(1), 0],
amount=10000000,
script_type=InputScriptType.SPENDWITNESS
script_type=InputScriptType.SPENDTAPROOT
)
coinjoin_req = self.make_coinjoin_request([txi])
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req)
authorization = CoinJoinAuthorization(self.msg_auth)
approver = CoinJoinApprover(tx, self.coin, authorization)
with self.assertRaises(wire.ProcessError):
await_result(approver.add_internal_input(txi))
await_result(approver.add_internal_input(txi, self.node))
def test_coinjoin_input_account_path_mismatch(self):
authorization = CoinJoinAuthorization(self.msg_auth)
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0)
approver = CoinJoinApprover(tx, self.coin, authorization)
txi = TxInput(
prev_hash=b"",
prev_hash=bytes(32),
prev_index=0,
address_n=[H_(49), H_(0), H_(0), 0, 2],
address_n=[H_(10025), H_(0), H_(1), H_(1), 0, 0],
amount=10000000,
script_type=InputScriptType.SPENDWITNESS
script_type=InputScriptType.SPENDTAPROOT
)
coinjoin_req = self.make_coinjoin_request([txi])
tx = SignTx(outputs_count=201, inputs_count=100, coin_name=self.coin.coin_name, lock_time=0, coinjoin_request=coinjoin_req)
authorization = CoinJoinAuthorization(self.msg_auth)
approver = CoinJoinApprover(tx, self.coin, authorization)
with self.assertRaises(wire.ProcessError):
await_result(approver.add_internal_input(txi))
await_result(approver.add_internal_input(txi, self.node))
if __name__ == '__main__':