diff --git a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h index c29de6d734..b7edf784b8 100644 --- a/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h +++ b/core/embed/extmod/modtrezorcrypto/modtrezorcrypto-aesgcm.h @@ -110,6 +110,28 @@ STATIC mp_obj_t mod_trezorcrypto_AesGcm_encrypt(mp_obj_t self, mp_obj_t data) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_obj, mod_trezorcrypto_AesGcm_encrypt); +/// def encrypt_in_place(self, data: bytearray | memoryview) -> int: +/// """ +/// Encrypt data chunk in place. Returns the length of the encrypted data. +/// """ +STATIC mp_obj_t mod_trezorcrypto_AesGcm_encrypt_in_place(mp_obj_t self, + mp_obj_t data) { + mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self); + if (o->state != STATE_INIT && o->state != STATE_ENCRYPTING) { + mp_raise_msg(&mp_type_RuntimeError, "Invalid state."); + } + o->state = STATE_ENCRYPTING; + mp_buffer_info_t in = {0}; + mp_get_buffer_raise(data, &in, MP_BUFFER_READ | MP_BUFFER_WRITE); + if (gcm_encrypt((unsigned char *)in.buf, in.len, &(o->ctx)) != RETURN_GOOD) { + o->state = STATE_FAILED; + mp_raise_type(&mp_type_RuntimeError); + } + return mp_obj_new_int(in.len); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_encrypt_in_place_obj, + mod_trezorcrypto_AesGcm_encrypt_in_place); + /// def decrypt(self, data: bytes) -> bytes: /// """ /// Decrypt data chunk. @@ -135,6 +157,28 @@ STATIC mp_obj_t mod_trezorcrypto_AesGcm_decrypt(mp_obj_t self, mp_obj_t data) { STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_obj, mod_trezorcrypto_AesGcm_decrypt); +/// def decrypt_in_place(self, data: bytearray | memoryview) -> int: +/// """ +/// Decrypt data chunk in place. Returns the length of the decrypted data. +/// """ +STATIC mp_obj_t mod_trezorcrypto_AesGcm_decrypt_in_place(mp_obj_t self, + mp_obj_t data) { + mp_obj_AesGcm_t *o = MP_OBJ_TO_PTR(self); + if (o->state != STATE_INIT && o->state != STATE_DECRYPTING) { + mp_raise_msg(&mp_type_RuntimeError, "Invalid state."); + } + o->state = STATE_DECRYPTING; + mp_buffer_info_t in = {0}; + mp_get_buffer_raise(data, &in, MP_BUFFER_READ | MP_BUFFER_WRITE); + if (gcm_decrypt((unsigned char *)in.buf, in.len, &(o->ctx)) != RETURN_GOOD) { + o->state = STATE_FAILED; + mp_raise_type(&mp_type_RuntimeError); + } + return mp_obj_new_int(in.len); +} +STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_trezorcrypto_AesGcm_decrypt_in_place_obj, + mod_trezorcrypto_AesGcm_decrypt_in_place); + /// def auth(self, data: bytes) -> None: /// """ /// Include authenticated data chunk in the GCM authentication tag. This can @@ -194,8 +238,12 @@ STATIC const mp_rom_map_elem_t mod_trezorcrypto_AesGcm_locals_dict_table[] = { MP_ROM_PTR(&mod_trezorcrypto_AesGcm_reset_obj)}, {MP_ROM_QSTR(MP_QSTR_encrypt), MP_ROM_PTR(&mod_trezorcrypto_AesGcm_encrypt_obj)}, + {MP_ROM_QSTR(MP_QSTR_encrypt_in_place), + MP_ROM_PTR(&mod_trezorcrypto_AesGcm_encrypt_in_place_obj)}, {MP_ROM_QSTR(MP_QSTR_decrypt), MP_ROM_PTR(&mod_trezorcrypto_AesGcm_decrypt_obj)}, + {MP_ROM_QSTR(MP_QSTR_decrypt_in_place), + MP_ROM_PTR(&mod_trezorcrypto_AesGcm_decrypt_in_place_obj)}, {MP_ROM_QSTR(MP_QSTR_auth), MP_ROM_PTR(&mod_trezorcrypto_AesGcm_auth_obj)}, {MP_ROM_QSTR(MP_QSTR_finish), MP_ROM_PTR(&mod_trezorcrypto_AesGcm_finish_obj)}, diff --git a/core/mocks/generated/trezorcrypto/__init__.pyi b/core/mocks/generated/trezorcrypto/__init__.pyi index 1a0861dacc..a7e0d95f3d 100644 --- a/core/mocks/generated/trezorcrypto/__init__.pyi +++ b/core/mocks/generated/trezorcrypto/__init__.pyi @@ -54,11 +54,21 @@ class aesgcm: Encrypt data chunk. """ + def encrypt_in_place(self, data: bytearray | memoryview) -> int: + """ + Encrypt data chunk in place. Returns the length of the encrypted data. + """ + def decrypt(self, data: bytes) -> bytes: """ Decrypt data chunk. """ + def decrypt_in_place(self, data: bytearray | memoryview) -> int: + """ + Decrypt data chunk in place. Returns the length of the decrypted data. + """ + def auth(self, data: bytes) -> None: """ Include authenticated data chunk in the GCM authentication tag. This can diff --git a/core/tests/test_trezor.crypto.aesgcm.py b/core/tests/test_trezor.crypto.aesgcm.py index fee6a9b2ee..78e834bb97 100644 --- a/core/tests/test_trezor.crypto.aesgcm.py +++ b/core/tests/test_trezor.crypto.aesgcm.py @@ -61,6 +61,29 @@ class TestCryptoAes(unittest.TestCase): self.assertEqual(ctx.decrypt(ct), pt) self.assertEqual(ctx.finish(), tag) + def test_gcm_in_place(self): + for vector in self.vectors: + key, iv, pt, aad, ct, tag = map(unhexlify, vector) + buffer = bytearray(pt) + + # Test encryption. + ctx = aesgcm(key, iv) + if aad: + ctx.auth(aad) + returned = ctx.encrypt_in_place(buffer) + self.assertEqual(buffer, ct) + self.assertEqual(returned, len(buffer)) + self.assertEqual(ctx.finish(), tag) + + # Test decryption. + ctx.reset(iv) + if aad: + ctx.auth(aad) + returned = ctx.decrypt_in_place(buffer) + self.assertEqual(buffer, pt) + self.assertEqual(returned, len(buffer)) + self.assertEqual(ctx.finish(), tag) + def test_gcm_chunks(self): for vector in self.vectors: key, iv, pt, aad, ct, tag = map(unhexlify, vector) @@ -83,6 +106,35 @@ class TestCryptoAes(unittest.TestCase): self.assertEqual(ctx.encrypt(pt[chunk1:]), ct[chunk1:]) self.assertEqual(ctx.finish(), tag) + def test_gcm_chunks_in_place(self): + for vector in self.vectors: + key, iv, pt, aad, ct, tag = map(unhexlify, vector) + buffer = bytearray(ct) + chunk1_length = len(pt) // 3 + chunk2_length = len(pt) - chunk1_length + + # Decrypt by chunks and add authenticated data by chunks. + ctx = aesgcm(key, iv) + returned = ctx.decrypt_in_place(memoryview(buffer)[: chunk1_length]) + self.assertEqual(returned, chunk1_length) + ctx.auth(aad[:17]) + returned = ctx.decrypt_in_place(memoryview(buffer)[chunk1_length:]) + ctx.auth(aad[17:]) + self.assertEqual(returned, chunk2_length) + self.assertEqual(buffer, pt) + self.assertEqual(ctx.finish(), tag) + + # Encrypt by chunks and add authenticated data by chunks. + ctx.reset(iv) + ctx.auth(aad[:7]) + returned = ctx.encrypt_in_place(memoryview(buffer)[: chunk1_length]) + self.assertEqual(returned, chunk1_length) + ctx.auth(aad[7:]) + returned = ctx.encrypt_in_place(memoryview(buffer)[chunk1_length:]) + self.assertEqual(returned, chunk2_length) + self.assertEqual(buffer, ct) + self.assertEqual(ctx.finish(), tag) + if __name__ == "__main__": unittest.main()