1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-26 09:28:13 +00:00

feat(tests): Implement CoinJoin requests in device tests.

This commit is contained in:
Andrew Kozlik 2022-10-12 17:05:59 +02:00 committed by Andrew Kozlik
parent 7a02be077f
commit c41ccdca76
2 changed files with 183 additions and 47 deletions

View File

@ -1,10 +1,12 @@
from collections import namedtuple from collections import namedtuple
from hashlib import sha256 from hashlib import sha256
from ecdsa import SECP256k1, SigningKey from ecdsa import ECDH, SECP256k1, SigningKey
from trezorlib import btc, messages from trezorlib import btc, messages
SLIP44 = 1 # Testnet
TextMemo = namedtuple("TextMemo", "text") TextMemo = namedtuple("TextMemo", "text")
RefundMemo = namedtuple("RefundMemo", "address_n") RefundMemo = namedtuple("RefundMemo", "address_n")
CoinPurchaseMemo = namedtuple( CoinPurchaseMemo = namedtuple(
@ -18,15 +20,19 @@ payment_req_signer = SigningKey.from_string(
def hash_bytes_prefixed(hasher, data): def hash_bytes_prefixed(hasher, data):
hasher.update(len(data).to_bytes(1, "little")) n = len(data)
if n < 253:
hasher.update(n.to_bytes(1, "little"))
elif n < 0x1_0000:
hasher.update(bytes([253]))
hasher.update(n.to_bytes(2, "little"))
hasher.update(data) hasher.update(data)
def make_payment_request( def make_payment_request(
client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None client, recipient_name, outputs, change_addresses=None, memos=None, nonce=None
): ):
slip44 = 1 # Testnet
h_pr = sha256(b"SL\x00\x24") h_pr = sha256(b"SL\x00\x24")
if nonce: if nonce:
@ -79,7 +85,7 @@ def make_payment_request(
else: else:
raise ValueError raise ValueError
h_pr.update(slip44.to_bytes(4, "little")) h_pr.update(SLIP44.to_bytes(4, "little"))
change_address = iter(change_addresses or []) change_address = iter(change_addresses or [])
h_outputs = sha256() h_outputs = sha256()
@ -98,3 +104,73 @@ def make_payment_request(
nonce=nonce, nonce=nonce,
signature=payment_req_signer.sign_digest_deterministic(h_pr.digest()), signature=payment_req_signer.sign_digest_deterministic(h_pr.digest()),
) )
def make_coinjoin_request(
coordinator_name,
inputs,
input_script_pubkeys,
outputs,
output_script_pubkeys,
no_fee_indices,
fee_rate=50_000_000, # 0.5 %
no_fee_threshold=1_000_000,
min_registrable_amount=5_000,
):
# Reuse the signing key as the masking key to ensure deterministic behavior.
# Note that in production the masking key should be generated randomly.
ecdh = ECDH(curve=SECP256k1)
ecdh.load_private_key(payment_req_signer)
mask_public_key = ecdh.get_public_key().to_string("compressed")
# Process inputs.
h_prevouts = sha256()
coinjoin_flags = bytearray()
for i, (txi, script_pubkey) in enumerate(zip(inputs, input_script_pubkeys)):
# Add input to prevouts hash.
h_prevouts.update(bytes(reversed(txi.prev_hash)))
h_prevouts.update(txi.prev_index.to_bytes(4, "little"))
# Set signable flag in coinjoin_flags.
if len(script_pubkey) == 34 and script_pubkey.startswith(b"\x51\x20"):
ecdh.load_received_public_key_bytes(b"\x02" + script_pubkey[2:])
shared_secret = ecdh.generate_sharedsecret_bytes()
h_mask = sha256(shared_secret)
h_mask.update(bytes(reversed(txi.prev_hash)))
h_mask.update(txi.prev_index.to_bytes(4, "little"))
mask = h_mask.digest()[0] & 1
signable = bool(txi.address_n)
txi.coinjoin_flags = signable ^ mask
else:
txi.coinjoin_flags = 0
# Set no_fee flag in coinjoin_flags.
txi.coinjoin_flags |= (i in no_fee_indices) << 1
coinjoin_flags.append(txi.coinjoin_flags)
# Process outputs.
h_outputs = sha256()
for txo, script_pubkey in zip(outputs, output_script_pubkeys):
h_outputs.update(txo.amount.to_bytes(8, "little"))
hash_bytes_prefixed(h_outputs, script_pubkey)
# Hash the CoinJoin request.
h_request = sha256(b"CJR1")
hash_bytes_prefixed(h_request, coordinator_name.encode())
h_request.update(SLIP44.to_bytes(4, "little"))
h_request.update(fee_rate.to_bytes(4, "little"))
h_request.update(no_fee_threshold.to_bytes(8, "little"))
h_request.update(min_registrable_amount.to_bytes(8, "little"))
h_request.update(mask_public_key)
hash_bytes_prefixed(h_request, coinjoin_flags)
h_request.update(h_prevouts.digest())
h_request.update(h_outputs.digest())
return messages.CoinJoinRequest(
fee_rate=fee_rate,
no_fee_threshold=no_fee_threshold,
min_registrable_amount=min_registrable_amount,
mask_public_key=mask_public_key,
signature=payment_req_signer.sign_digest_deterministic(h_request.digest()),
)

View File

@ -24,8 +24,8 @@ from trezorlib.exceptions import TrezorFailure
from trezorlib.tools import parse_path from trezorlib.tools import parse_path
from ...tx_cache import TxCache from ...tx_cache import TxCache
from .payment_req import make_payment_request from .payment_req import make_coinjoin_request
from .signtx import request_finished, request_input, request_output, request_payment_req from .signtx import request_finished, request_input, request_output
B = messages.ButtonRequestType B = messages.ButtonRequestType
@ -121,6 +121,15 @@ def test_sign_tx(client: Client):
), ),
] ]
input_script_pubkeys = [
bytes.fromhex(
"5120b3a2750e21facec36b2a56d76cca6019bf517a5c45e2ea8e5b4ed191090f3003"
),
bytes.fromhex(
"51202f436892d90fb2665519efa3d9f0f5182859124f179486862c2cd7a78ea9ac19"
),
]
outputs = [ outputs = [
# Other's coinjoined output. # Other's coinjoined output.
messages.TxOutputType( messages.TxOutputType(
@ -129,7 +138,6 @@ def test_sign_tx(client: Client):
address="tb1pupzczx9cpgyqgtvycncr2mvxscl790luqd8g88qkdt2w3kn7ymhsrdueu2", address="tb1pupzczx9cpgyqgtvycncr2mvxscl790luqd8g88qkdt2w3kn7ymhsrdueu2",
amount=50_000, amount=50_000,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
payment_req_index=0,
), ),
# Our coinjoined output. # Our coinjoined output.
messages.TxOutputType( messages.TxOutputType(
@ -137,7 +145,6 @@ def test_sign_tx(client: Client):
address_n=parse_path("m/10025h/1h/0h/1h/1/1"), address_n=parse_path("m/10025h/1h/0h/1h/1/1"),
amount=50_000, amount=50_000,
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
payment_req_index=0,
), ),
# Our change output. # Our change output.
messages.TxOutputType( messages.TxOutputType(
@ -145,7 +152,6 @@ def test_sign_tx(client: Client):
address_n=parse_path("m/10025h/1h/0h/1h/1/2"), address_n=parse_path("m/10025h/1h/0h/1h/1/2"),
amount=7_289_000 - 50_000 - 36_445 - 490, amount=7_289_000 - 50_000 - 36_445 - 490,
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
payment_req_index=0,
), ),
# Other's change output. # Other's change output.
messages.TxOutputType( messages.TxOutputType(
@ -154,27 +160,39 @@ def test_sign_tx(client: Client):
address="tb1pvt7lzserh8xd5m6mq0zu9s5wxkpe5wgf5ts56v44jhrr6578hz8saxup5m", address="tb1pvt7lzserh8xd5m6mq0zu9s5wxkpe5wgf5ts56v44jhrr6578hz8saxup5m",
amount=100_000 - 50_000 - 500 - 490, amount=100_000 - 50_000 - 500 - 490,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
payment_req_index=0,
), ),
# Coordinator's output. # Coordinator's output.
messages.TxOutputType( messages.TxOutputType(
address="mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q", address="mvbu1Gdy8SUjTenqerxUaZyYjmveZvt33q",
amount=36_945, amount=36_945,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
payment_req_index=0,
), ),
] ]
payment_req = make_payment_request( output_script_pubkeys = [
client, bytes.fromhex(
recipient_name="www.example.com", "5120e0458118b80a08042d84c4f0356d86863fe2bffc034e839c166ad4e8da7e26ef"
outputs=outputs, ),
change_addresses=[ bytes.fromhex(
"tb1phkcspf88hge86djxgtwx2wu7ddghsw77d6sd7txtcxncu0xpx22shcydyf", "5120bdb100a4e7ba327d364642dc653b9e6b51783bde6ea0df2ccbc1a78e3cc13295"
"tb1pchruvduckkwuzm5hmytqz85emften5dnmkqu9uhfxwfywaqhuu0qjggqyp", ),
], bytes.fromhex(
"5120c5c7c63798b59dc16e97d916011e99da5799d1b3dd81c2f2e93392477417e71e"
),
bytes.fromhex(
"512062fdf14323b9ccda6f5b03c5c2c28e35839a3909a2e14d32b595c63d53c7b88f"
),
bytes.fromhex("76a914a579388225827d9f2fe9014add644487808c695d88ac"),
]
coinjoin_req = make_coinjoin_request(
"www.example.com",
inputs,
input_script_pubkeys,
outputs,
output_script_pubkeys,
no_fee_indices=[],
) )
payment_req.amount = None
with client: with client:
client.set_expected_responses( client.set_expected_responses(
@ -183,7 +201,6 @@ def test_sign_tx(client: Client):
request_input(0), request_input(0),
request_input(1), request_input(1),
request_output(0), request_output(0),
request_payment_req(0),
request_output(1), request_output(1),
request_output(2), request_output(2),
request_output(3), request_output(3),
@ -198,7 +215,7 @@ def test_sign_tx(client: Client):
inputs, inputs,
outputs, outputs,
prev_txes=TX_CACHE_TESTNET, prev_txes=TX_CACHE_TESTNET,
payment_reqs=[payment_req], coinjoin_request=coinjoin_req,
preauthorized=True, preauthorized=True,
serialize=False, serialize=False,
) )
@ -218,7 +235,7 @@ def test_sign_tx(client: Client):
inputs, inputs,
outputs, outputs,
prev_txes=TX_CACHE_TESTNET, prev_txes=TX_CACHE_TESTNET,
payment_reqs=[payment_req], coinjoin_request=coinjoin_req,
preauthorized=True, preauthorized=True,
) )
@ -230,7 +247,7 @@ def test_sign_tx(client: Client):
inputs, inputs,
outputs, outputs,
prev_txes=TX_CACHE_TESTNET, prev_txes=TX_CACHE_TESTNET,
payment_reqs=[payment_req], coinjoin_request=coinjoin_req,
preauthorized=True, preauthorized=True,
) )
@ -277,17 +294,54 @@ def test_sign_tx_large(client: Client):
commitment_data=commitment_data, commitment_data=commitment_data,
) )
internal_input = messages.TxInputType( internal_inputs = [
address_n=parse_path("m/10025h/1h/0h/1h/1/0"), messages.TxInputType(
address_n=parse_path(f"m/10025h/1h/0h/1h/1/{i}"),
amount=output_denom * own_output_count // own_input_count, amount=output_denom * own_output_count // own_input_count,
prev_hash=FAKE_TXHASH_f982c0, prev_hash=FAKE_TXHASH_f982c0,
prev_index=1, prev_index=1,
script_type=messages.InputScriptType.SPENDTAPROOT, script_type=messages.InputScriptType.SPENDTAPROOT,
) )
for i in range(own_input_count)
]
internal_input_script_pubkeys = [
bytes.fromhex(
"51202f436892d90fb2665519efa3d9f0f5182859124f179486862c2cd7a78ea9ac19"
),
bytes.fromhex(
"5120bdb100a4e7ba327d364642dc653b9e6b51783bde6ea0df2ccbc1a78e3cc13295"
),
bytes.fromhex(
"5120c5c7c63798b59dc16e97d916011e99da5799d1b3dd81c2f2e93392477417e71e"
),
bytes.fromhex(
"5120148db939506345b047d945fff64691508c90da036ea3313b38b386ba3ec64ec5"
),
bytes.fromhex(
"51202cf0ba67bc759b413c0a36e33f5223aee574a979cfc1bc6e59b136cc43a8da8d"
),
bytes.fromhex(
"51202ad44db2df5b2a4d46e3655b1ab2402229676e35a3a43c4f7cae73e862c10775"
),
bytes.fromhex(
"51209e101215e14de4bece6cabd552f11e5931cb53119f43e52c10f9c1de0fd03390"
),
bytes.fromhex(
"5120f799c40379196e8507b8adf72c78b6cc12bb9fbae38f3ad744dfcd19a5777253"
),
bytes.fromhex(
"5120db0563942a92fb8c89ced9325c2660607605cd645027d64a9f641e6bc1694020"
),
bytes.fromhex(
"51208f1bbec30c355ec71f7a87c5ea06547c9b9b8a51c7834cd726e13cbb83226d16"
),
]
inputs = [internal_input] * own_input_count + [external_input] * ( inputs = internal_inputs + [external_input] * (total_input_count - own_input_count)
total_input_count - own_input_count
) input_script_pubkeys = internal_input_script_pubkeys + [
external_input.script_pubkey
] * (total_input_count - own_input_count)
# OUTPUTS. # OUTPUTS.
@ -297,7 +351,9 @@ def test_sign_tx_large(client: Client):
address="tb1pupzczx9cpgyqgtvycncr2mvxscl790luqd8g88qkdt2w3kn7ymhsrdueu2", address="tb1pupzczx9cpgyqgtvycncr2mvxscl790luqd8g88qkdt2w3kn7ymhsrdueu2",
amount=output_denom, amount=output_denom,
script_type=messages.OutputScriptType.PAYTOADDRESS, script_type=messages.OutputScriptType.PAYTOADDRESS,
payment_req_index=0, )
external_output_script_pubkey = bytes.fromhex(
"5120e0458118b80a08042d84c4f0356d86863fe2bffc034e839c166ad4e8da7e26ef"
) )
internal_output = messages.TxOutputType( internal_output = messages.TxOutputType(
@ -305,23 +361,27 @@ def test_sign_tx_large(client: Client):
address_n=parse_path("m/10025h/1h/0h/1h/1/1"), address_n=parse_path("m/10025h/1h/0h/1h/1/1"),
amount=output_denom, amount=output_denom,
script_type=messages.OutputScriptType.PAYTOTAPROOT, script_type=messages.OutputScriptType.PAYTOTAPROOT,
payment_req_index=0, )
internal_output_script_pubkey = bytes.fromhex(
"5120bdb100a4e7ba327d364642dc653b9e6b51783bde6ea0df2ccbc1a78e3cc13295"
) )
outputs = [internal_output] * own_output_count + [external_output] * ( outputs = [internal_output] * own_output_count + [external_output] * (
total_output_count - own_output_count total_output_count - own_output_count
) )
payment_req = make_payment_request( output_script_pubkeys = [internal_output_script_pubkey] * own_output_count + [
client, external_output_script_pubkey
recipient_name="www.example.com", ] * (total_output_count - own_output_count)
outputs=outputs,
change_addresses=[ coinjoin_req = make_coinjoin_request(
"tb1phkcspf88hge86djxgtwx2wu7ddghsw77d6sd7txtcxncu0xpx22shcydyf" "www.example.com",
] inputs,
* own_output_count, input_script_pubkeys,
outputs,
output_script_pubkeys,
no_fee_indices=[],
) )
payment_req.amount = None
start = time.time() start = time.time()
with client: with client:
@ -331,7 +391,7 @@ def test_sign_tx_large(client: Client):
inputs, inputs,
outputs, outputs,
prev_txes=TX_CACHE_TESTNET, prev_txes=TX_CACHE_TESTNET,
payment_reqs=[payment_req], coinjoin_request=coinjoin_req,
preauthorized=True, preauthorized=True,
serialize=False, serialize=False,
) )