1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 22:40:59 +00:00

feat(core): add AES-GCM in-place encryption and decryption

[no changelog]
This commit is contained in:
Ondřej Vejpustek 2024-05-29 13:25:22 +02:00
parent 662f13136f
commit 67ac4078f7
3 changed files with 110 additions and 0 deletions

View File

@ -110,6 +110,28 @@ STATIC mp_obj_t mod_trezorcrypto_AesGcm_encrypt(mp_obj_t self, mp_obj_t data) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_obj,
mod_trezorcrypto_AesGcm_encrypt); mod_trezorcrypto_AesGcm_encrypt);
/// def encrypt_in_place(self, data: bytearray | memoryview) -> int:
/// """
/// Encrypt data chunk in place. Returns the length of the encrypted data.
/// """
STATIC mp_obj_t mod_trezorcrypto_AesGcm_encrypt_in_place(mp_obj_t self,
mp_obj_t data) {
mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self);
if (o->state != STATE_INIT && o->state != STATE_ENCRYPTING) {
mp_raise_msg(&mp_type_RuntimeError, "Invalid state.");
}
o->state = STATE_ENCRYPTING;
mp_buffer_info_t in = {0};
mp_get_buffer_raise(data, &in, MP_BUFFER_READ | MP_BUFFER_WRITE);
if (gcm_encrypt((unsigned char *)in.buf, in.len, &(o->ctx)) != RETURN_GOOD) {
o->state = STATE_FAILED;
mp_raise_type(&mp_type_RuntimeError);
}
return mp_obj_new_int(in.len);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_in_place_obj,
mod_trezorcrypto_AesGcm_encrypt_in_place);
/// def decrypt(self, data: bytes) -> bytes: /// def decrypt(self, data: bytes) -> bytes:
/// """ /// """
/// Decrypt data chunk. /// Decrypt data chunk.
@ -135,6 +157,28 @@ STATIC mp_obj_t mod_trezorcrypto_AesGcm_decrypt(mp_obj_t self, mp_obj_t data) {
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_obj, STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_obj,
mod_trezorcrypto_AesGcm_decrypt); mod_trezorcrypto_AesGcm_decrypt);
/// def decrypt_in_place(self, data: bytearray | memoryview) -> int:
/// """
/// Decrypt data chunk in place. Returns the length of the decrypted data.
/// """
STATIC mp_obj_t mod_trezorcrypto_AesGcm_decrypt_in_place(mp_obj_t self,
mp_obj_t data) {
mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self);
if (o->state != STATE_INIT && o->state != STATE_DECRYPTING) {
mp_raise_msg(&mp_type_RuntimeError, "Invalid state.");
}
o->state = STATE_DECRYPTING;
mp_buffer_info_t in = {0};
mp_get_buffer_raise(data, &in, MP_BUFFER_READ | MP_BUFFER_WRITE);
if (gcm_decrypt((unsigned char *)in.buf, in.len, &(o->ctx)) != RETURN_GOOD) {
o->state = STATE_FAILED;
mp_raise_type(&mp_type_RuntimeError);
}
return mp_obj_new_int(in.len);
}
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_in_place_obj,
mod_trezorcrypto_AesGcm_decrypt_in_place);
/// def auth(self, data: bytes) -> None: /// def auth(self, data: bytes) -> None:
/// """ /// """
/// Include authenticated data chunk in the GCM authentication tag. This can /// Include authenticated data chunk in the GCM authentication tag. This can
@ -194,8 +238,12 @@ STATIC const mp_rom_map_elem_t mod_trezorcrypto_AesGcm_locals_dict_table[] = {
MP_ROM_PTR(&mod_trezorcrypto_AesGcm_reset_obj)}, MP_ROM_PTR(&mod_trezorcrypto_AesGcm_reset_obj)},
{MP_ROM_QSTR(MP_QSTR_encrypt), {MP_ROM_QSTR(MP_QSTR_encrypt),
MP_ROM_PTR(&mod_trezorcrypto_AesGcm_encrypt_obj)}, MP_ROM_PTR(&mod_trezorcrypto_AesGcm_encrypt_obj)},
{MP_ROM_QSTR(MP_QSTR_encrypt_in_place),
MP_ROM_PTR(&mod_trezorcrypto_AesGcm_encrypt_in_place_obj)},
{MP_ROM_QSTR(MP_QSTR_decrypt), {MP_ROM_QSTR(MP_QSTR_decrypt),
MP_ROM_PTR(&mod_trezorcrypto_AesGcm_decrypt_obj)}, MP_ROM_PTR(&mod_trezorcrypto_AesGcm_decrypt_obj)},
{MP_ROM_QSTR(MP_QSTR_decrypt_in_place),
MP_ROM_PTR(&mod_trezorcrypto_AesGcm_decrypt_in_place_obj)},
{MP_ROM_QSTR(MP_QSTR_auth), MP_ROM_PTR(&mod_trezorcrypto_AesGcm_auth_obj)}, {MP_ROM_QSTR(MP_QSTR_auth), MP_ROM_PTR(&mod_trezorcrypto_AesGcm_auth_obj)},
{MP_ROM_QSTR(MP_QSTR_finish), {MP_ROM_QSTR(MP_QSTR_finish),
MP_ROM_PTR(&mod_trezorcrypto_AesGcm_finish_obj)}, MP_ROM_PTR(&mod_trezorcrypto_AesGcm_finish_obj)},

View File

@ -54,11 +54,21 @@ class aesgcm:
Encrypt data chunk. Encrypt data chunk.
""" """
def encrypt_in_place(self, data: bytearray | memoryview) -> int:
"""
Encrypt data chunk in place. Returns the length of the encrypted data.
"""
def decrypt(self, data: bytes) -> bytes: def decrypt(self, data: bytes) -> bytes:
""" """
Decrypt data chunk. Decrypt data chunk.
""" """
def decrypt_in_place(self, data: bytearray | memoryview) -> int:
"""
Decrypt data chunk in place. Returns the length of the decrypted data.
"""
def auth(self, data: bytes) -> None: def auth(self, data: bytes) -> None:
""" """
Include authenticated data chunk in the GCM authentication tag. This can Include authenticated data chunk in the GCM authentication tag. This can

View File

@ -61,6 +61,29 @@ class TestCryptoAes(unittest.TestCase):
self.assertEqual(ctx.decrypt(ct), pt) self.assertEqual(ctx.decrypt(ct), pt)
self.assertEqual(ctx.finish(), tag) self.assertEqual(ctx.finish(), tag)
def test_gcm_in_place(self):
for vector in self.vectors:
key, iv, pt, aad, ct, tag = map(unhexlify, vector)
buffer = bytearray(pt)
# Test encryption.
ctx = aesgcm(key, iv)
if aad:
ctx.auth(aad)
returned = ctx.encrypt_in_place(buffer)
self.assertEqual(buffer, ct)
self.assertEqual(returned, len(buffer))
self.assertEqual(ctx.finish(), tag)
# Test decryption.
ctx.reset(iv)
if aad:
ctx.auth(aad)
returned = ctx.decrypt_in_place(buffer)
self.assertEqual(buffer, pt)
self.assertEqual(returned, len(buffer))
self.assertEqual(ctx.finish(), tag)
def test_gcm_chunks(self): def test_gcm_chunks(self):
for vector in self.vectors: for vector in self.vectors:
key, iv, pt, aad, ct, tag = map(unhexlify, vector) key, iv, pt, aad, ct, tag = map(unhexlify, vector)
@ -83,6 +106,35 @@ class TestCryptoAes(unittest.TestCase):
self.assertEqual(ctx.encrypt(pt[chunk1:]), ct[chunk1:]) self.assertEqual(ctx.encrypt(pt[chunk1:]), ct[chunk1:])
self.assertEqual(ctx.finish(), tag) self.assertEqual(ctx.finish(), tag)
def test_gcm_chunks_in_place(self):
for vector in self.vectors:
key, iv, pt, aad, ct, tag = map(unhexlify, vector)
buffer = bytearray(ct)
chunk1_length = len(pt) // 3
chunk2_length = len(pt) - chunk1_length
# Decrypt by chunks and add authenticated data by chunks.
ctx = aesgcm(key, iv)
returned = ctx.decrypt_in_place(memoryview(buffer)[: chunk1_length])
self.assertEqual(returned, chunk1_length)
ctx.auth(aad[:17])
returned = ctx.decrypt_in_place(memoryview(buffer)[chunk1_length:])
ctx.auth(aad[17:])
self.assertEqual(returned, chunk2_length)
self.assertEqual(buffer, pt)
self.assertEqual(ctx.finish(), tag)
# Encrypt by chunks and add authenticated data by chunks.
ctx.reset(iv)
ctx.auth(aad[:7])
returned = ctx.encrypt_in_place(memoryview(buffer)[: chunk1_length])
self.assertEqual(returned, chunk1_length)
ctx.auth(aad[7:])
returned = ctx.encrypt_in_place(memoryview(buffer)[chunk1_length:])
self.assertEqual(returned, chunk2_length)
self.assertEqual(buffer, ct)
self.assertEqual(ctx.finish(), tag)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()