1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-02 12:38:43 +00:00
trezor-firmware/core/tests/test_trezor.wire.thp.crypto.py

157 lines
5.8 KiB
Python
Raw Normal View History

2024-11-15 16:31:22 +00:00
from common import * # isort:skip
from trezorcrypto import aesgcm, curve25519
import storage
if utils.USE_THP:
import thp_common
from trezor.wire.thp import crypto
from trezor.wire.thp.crypto import IV_1, IV_2, Handshake
def get_dummy_device_secret():
    """Return a fixed 16-byte value used as a stand-in device secret in tests."""
    dummy_secret = b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"
    return dummy_secret
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCrypto(unittest.TestCase):
    """Tests for `trezor.wire.thp.crypto`.

    Covers `crypto.enc` / `crypto.dec` against fixed vectors, `crypto._hkdf`,
    `crypto._get_iv_from_nonce`, and the `Handshake.handle_th1_crypto` /
    `Handshake.handle_th2_crypto` handlers.
    """

    # All fixtures live under the USE_THP guard: `Handshake` and `crypto` are
    # only imported when THP is enabled, and the class is skipped otherwise.
    if utils.USE_THP:
        # Shared by test_th1_crypto and test_th2_crypto.
        handshake = Handshake()

        key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"

        # 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
        # Expected ciphertext and tag are hex-encoded (compared via hexlify).
        vectors_enc = [
            (
                key_1,
                0,
                b"\x55\x64",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
                b"e2c9dd152fbee5821ea7",
                b"10625812de81b14a46b9f1e5100a6d0c",
            ),
            (
                key_1,
                1,
                b"\x55\x64",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
                b"79811619ddb07c2b99f8",
                b"71c6b872cdc499a7e9a3c7441f053214",
            ),
            (
                key_1,
                369,
                b"\x55\x64",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
                b"03bd030390f2dfe815a61c2b157a064f",
                b"c1200f8a7ae9a6d32cef0fff878d55c2",
            ),
            (
                key_1,
                369,
                b"\x55\x64\x73\x82\x91",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
                b"03bd030390f2dfe815a61c2b157a064f",
                b"693ac160cd93a20f7fc255f049d808d0",
            ),
        ]

        # 0:chaining key, 1:input, 2:output_1, 3:output_2
        # Both outputs are hex-encoded (compared via hexlify).
        vectors_hkdf = [
            (
                crypto.PROTOCOL_NAME,
                b"\x01\x02",
                b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
                b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
            ),
            (
                b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
                b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
                b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
                b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
            ),
        ]

        # 0:nonce, 1:expected 12-byte IV
        vectors_iv = [
            (0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
            (1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
            (7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
            (1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
            (4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
            (0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
        ]

    def __init__(self):
        # `suppres_debug_log` is spelled as the helper is called in this file;
        # presumably it silences THP debug logging - confirm in thp_common.
        if __debug__:
            thp_common.suppres_debug_log()
        super().__init__()

    def setUp(self):
        """Reset the encryption switch to enabled before every test."""
        utils.DISABLE_ENCRYPTION = False

    def test_encryption(self):
        """Encrypt each vector in place, check ciphertext and tag, then round-trip."""
        for key, nonce, auth_data, plaintext, expected_ct, expected_tag in self.vectors_enc:
            buffer = bytearray(plaintext)
            tag = crypto.enc(buffer, key, nonce, auth_data)
            self.assertEqual(hexlify(buffer), expected_ct)
            self.assertEqual(hexlify(tag), expected_tag)
            # Decrypting with the produced tag must succeed and restore the plaintext.
            self.assertTrue(crypto.dec(buffer, tag, key, nonce, auth_data))
            self.assertEqual(buffer, plaintext)

    def test_hkdf(self):
        """Check both outputs of `crypto._hkdf` against the fixed vectors."""
        for chaining_key, hkdf_input, expected_ck, expected_k in self.vectors_hkdf:
            ck, k = crypto._hkdf(chaining_key, hkdf_input)
            self.assertEqual(hexlify(ck), expected_ck)
            self.assertEqual(hexlify(k), expected_k)

    def test_iv_from_nonce(self):
        """Check nonce-to-IV conversion and that a nonce above 2**64 - 1 is rejected."""
        for nonce, expected_iv in self.vectors_iv:
            self.assertEqual(crypto._get_iv_from_nonce(nonce), expected_iv)

        with self.assertRaises(AssertionError) as e:
            crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
        # Must stay outside the `with` block: the call above raises, so any
        # statement after it inside the block would never execute.
        # NOTE(review): `e.value.value` relies on the test harness exposing the
        # caught exception as `.value` and the exception exposing its message
        # as `.value` - confirm against the project's unittest module.
        self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")

    def test_incorrect_vectors(self):
        # TODO: placeholder - no negative vectors are exercised yet.
        pass

    def test_th1_crypto(self):
        """Run the first handshake step with a freshly generated host ephemeral key."""
        # NOTE(review): this replaces storage.device.get_device_secret for the
        # rest of the process and is never restored - confirm that is intended.
        storage.device.get_device_secret = get_dummy_device_secret
        handshake = self.handshake

        host_ephemeral_privkey = curve25519.generate_secret()
        host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
        handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)

    def test_th2_crypto(self):
        """Build the host's second handshake message by hand and feed it to the handler.

        Reads `k`, `h`, `ck` and `trezor_ephemeral_privkey` from the shared
        class-level handshake; presumably these are populated by
        `test_th1_crypto` running first - verify the runner's test ordering.
        """
        handshake = self.handshake

        host_static_privkey = curve25519.generate_secret()
        host_static_pubkey = curve25519.publickey(host_static_privkey)

        # Encrypt the host static public key under handshake.k / IV_2,
        # authenticated with handshake.h; the tag is appended to the ciphertext.
        aes_ctx = aesgcm(handshake.k, IV_2)
        aes_ctx.auth(handshake.h)
        encrypted_host_static_pubkey = bytearray(
            aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
        )

        # Code to encrypt Host's noise encrypted payload correctly:
        protomsg = bytearray(b"\x10\x02\x10\x03")
        temp_h = crypto._hash_of_two(handshake.h, encrypted_host_static_pubkey)
        _, temp_k = crypto._hkdf(
            handshake.ck,
            curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
        )
        aes_ctx = aesgcm(temp_k, IV_1)
        aes_ctx.encrypt_in_place(protomsg)
        aes_ctx.auth(temp_h)
        tag = aes_ctx.finish()
        encrypted_payload = bytearray(protomsg + tag)
        # end of encrypted payload generation

        handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
        # After the handler runs, the first four bytes of the payload buffer
        # must equal the original plaintext message.
        self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")
# Run the test suite only when this file is executed directly, not on import.
if __name__ == "__main__":
    unittest.main()