from common import *  # isort:skip
from trezorcrypto import aesgcm, curve25519

import storage

if utils.USE_THP:
    import thp_common
    from trezor.wire.thp import crypto
    from trezor.wire.thp.crypto import IV_1, IV_2, Handshake


# Fixed 16-byte stand-in that `test_th1_crypto` installs in place of
# `storage.device.get_device_secret`.
def get_dummy_device_secret():
    return b"\x01\x02\x03\x04\x05\x06\x07\x08\x01\x02\x03\x04\x05\x06\x07\x08"


# Known-answer and handshake smoke tests for `trezor.wire.thp.crypto`.
@unittest.skipUnless(utils.USE_THP, "only needed for THP")
class TestTrezorHostProtocolCrypto(unittest.TestCase):
    # `crypto` and `Handshake` are only imported when THP is enabled, so every
    # class-level fixture lives under the same guard.
    if utils.USE_THP:
        # One Handshake instance is shared by all tests of the class:
        # `test_th2_crypto` reads the state (`k`, `h`, `ck`,
        # `trezor_ephemeral_privkey`) from it.
        handshake = Handshake()

        # 32-byte key used by all encryption vectors.
        key_1 = b"\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07\x00\x01\x02\x03\x04\x05\x06\x07"

        # 0:key, 1:nonce, 2:auth_data, 3:plaintext, 4:expected_ciphertext, 5:expected_tag
        # Expected ciphertext and tag are hex-encoded.
        # The last two vectors differ only in auth_data: the ciphertext is the
        # same, the tag is not.
        vectors_enc = [
            (
                key_1,
                0,
                b"\x55\x64",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
                b"e2c9dd152fbee5821ea7",
                b"10625812de81b14a46b9f1e5100a6d0c",
            ),
            (
                key_1,
                1,
                b"\x55\x64",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09",
                b"79811619ddb07c2b99f8",
                b"71c6b872cdc499a7e9a3c7441f053214",
            ),
            (
                key_1,
                369,
                b"\x55\x64",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
                b"03bd030390f2dfe815a61c2b157a064f",
                b"c1200f8a7ae9a6d32cef0fff878d55c2",
            ),
            (
                key_1,
                369,
                b"\x55\x64\x73\x82\x91",
                b"\x00\x01\x02\x03\x04\05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f",
                b"03bd030390f2dfe815a61c2b157a064f",
                b"693ac160cd93a20f7fc255f049d808d0",
            ),
        ]

        # 0:chaining key, 1:input, 2:output_1, 3:output_2
        # Both outputs are hex-encoded. The chaining key of the second vector
        # is output_1 of the first one.
        vectors_hkdf = [
            (
                crypto.PROTOCOL_NAME,
                b"\x01\x02",
                b"c784373a217d6be057cddc6068e6748f255fc8beb6f99b7b90cbc64aad947514",
                b"12695451e29bf08ffe5e4e6ab734b0c3d7cdd99b16cd409f57bd4eaa874944ba",
            ),
            (
                b"\xc7\x84\x37\x3a\x21\x7d\x6b\xe0\x57\xcd\xdc\x60\x68\xe6\x74\x8f\x25\x5f\xc8\xbe\xb6\xf9\x9b\x7b\x90\xcb\xc6\x4a\xad\x94\x75\x14",
                b"\x31\x41\x59\x26\x52\x12\x34\x56\x78\x89\x04\xaa",
                b"f88c1e08d5c3bae8f6e4a3d3324c8cbc60a805603e399e69c4bf4eacb27c2f48",
                b"5f0216bdb7110ee05372286974da8c9c8b96e2efa15b4af430755f462bd79a76",
            ),
        ]

        # 0:nonce, 1:expected 12-byte IV (the nonce, big-endian, left-padded
        # with zero bytes)
        vectors_iv = [
            (0, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"),
            (1, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"),
            (7, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x07"),
            (1025, b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x01"),
            (4294967295, b"\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xff\xff"),
            (0xFFFFFFFFFFFFFFFF, b"\x00\x00\x00\x00\xff\xff\xff\xff\xff\xff\xff\xff"),
        ]

    def __init__(self):
        if __debug__:
            thp_common.suppres_debug_log()
        super().__init__()

    # Force the encryption switch off "disabled" before every test.
    def setUp(self):
        utils.DISABLE_ENCRYPTION = False

    # Known-answer test for `crypto.enc`, followed by a `crypto.dec` round trip.
    def test_encryption(self):
        for key, nonce, auth_data, plaintext, expected_ct, expected_tag in self.vectors_enc:
            buffer = bytearray(plaintext)
            # `enc` transforms `buffer` in place and returns the tag.
            tag = crypto.enc(buffer, key, nonce, auth_data)
            self.assertEqual(hexlify(buffer), expected_ct)
            self.assertEqual(hexlify(tag), expected_tag)
            # `dec` must accept the tag and restore the plaintext in `buffer`.
            self.assertTrue(crypto.dec(buffer, tag, key, nonce, auth_data))
            self.assertEqual(buffer, plaintext)

    # Known-answer test for both outputs of `crypto._hkdf`.
    def test_hkdf(self):
        for chaining_key, hkdf_input, expected_ck, expected_k in self.vectors_hkdf:
            ck, k = crypto._hkdf(chaining_key, hkdf_input)
            self.assertEqual(hexlify(ck), expected_ck)
            self.assertEqual(hexlify(k), expected_k)

    # Nonce -> IV conversion, plus rejection of a nonce that needs more than
    # 64 bits.
    def test_iv_from_nonce(self):
        for nonce, expected_iv in self.vectors_iv:
            self.assertEqual(crypto._get_iv_from_nonce(nonce), expected_iv)

        # One past the largest accepted nonce.
        with self.assertRaises(AssertionError) as e:
            crypto._get_iv_from_nonce(0xFFFFFFFFFFFFFFFF + 1)
        self.assertEqual(e.value.value, "Nonce overflow, terminate the channel")

    # Placeholder: no negative vectors are implemented.
    def test_incorrect_vectors(self):
        pass

    # Feed a freshly generated host ephemeral public key into the first
    # handshake step.
    def test_th1_crypto(self):
        # NOTE: the replacement is not undone afterwards, so it stays in place
        # for whatever runs later in the same process.
        storage.device.get_device_secret = get_dummy_device_secret
        handshake = self.handshake

        host_ephemeral_privkey = curve25519.generate_secret()
        host_ephemeral_pubkey = curve25519.publickey(host_ephemeral_privkey)
        handshake.handle_th1_crypto(b"", host_ephemeral_pubkey)

    # Build the host's second handshake message by hand and let the shared
    # handshake object process it.
    # NOTE(review): reads state from the shared `handshake` that
    # `test_th1_crypto` presumably populates, so it relies on that test having
    # run first - confirm against the test runner's ordering.
    def test_th2_crypto(self):
        handshake = self.handshake

        host_static_privkey = curve25519.generate_secret()
        host_static_pubkey = curve25519.publickey(host_static_privkey)
        aes_ctx = aesgcm(handshake.k, IV_2)
        aes_ctx.auth(handshake.h)
        encrypted_host_static_pubkey = bytearray(
            aes_ctx.encrypt(host_static_pubkey) + aes_ctx.finish()
        )

        # Code to encrypt Host's noise encrypted payload correctly:
        protomsg = bytearray(b"\x10\x02\x10\x03")
        temp_h = crypto._hash_of_two(handshake.h, encrypted_host_static_pubkey)
        _, temp_k = crypto._hkdf(
            handshake.ck,
            curve25519.multiply(handshake.trezor_ephemeral_privkey, host_static_pubkey),
        )
        aes_ctx = aesgcm(temp_k, IV_1)
        aes_ctx.encrypt_in_place(protomsg)
        aes_ctx.auth(temp_h)
        tag = aes_ctx.finish()
        encrypted_payload = bytearray(protomsg + tag)
        # end of encrypted payload generation

        handshake.handle_th2_crypto(encrypted_host_static_pubkey, encrypted_payload)
        # The first four bytes must now equal the original plaintext message,
        # i.e. the payload was presumably decrypted in place.
        self.assertEqual(encrypted_payload[:4], b"\x10\x02\x10\x03")


if __name__ == "__main__":
    unittest.main()