mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-02 03:20:59 +00:00
113 lines
3.6 KiB
Python
113 lines
3.6 KiB
Python
|
from cryptography.hazmat.backends import default_backend
|
||
|
from cryptography.hazmat.primitives import hashes, hmac
|
||
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms
|
||
|
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
|
||
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
||
|
|
||
|
from . import consts, prng
|
||
|
|
||
|
|
||
|
def derive_kek_keiv(salt: bytes, pin: int) -> (bytes, bytes):
|
||
|
kdf = PBKDF2HMAC(
|
||
|
algorithm=hashes.SHA256(),
|
||
|
length=consts.KEK_SIZE + consts.KEIV_SIZE,
|
||
|
salt=bytes(salt),
|
||
|
iterations=10000,
|
||
|
backend=default_backend(),
|
||
|
)
|
||
|
pbkdf_output = kdf.derive(pin.to_bytes(4, "little"))
|
||
|
# the first 256b is Key Encryption Key
|
||
|
kek = pbkdf_output[: consts.KEK_SIZE]
|
||
|
# following with 96b of Initialization Vector
|
||
|
keiv = pbkdf_output[consts.KEK_SIZE :]
|
||
|
|
||
|
return kek, keiv
|
||
|
|
||
|
|
||
|
def chacha_poly_encrypt(
|
||
|
key: bytes, iv: bytes, data: bytes, additional_data: bytes = None
|
||
|
) -> (bytes, bytes):
|
||
|
chacha = ChaCha20Poly1305(key)
|
||
|
chacha_output = chacha.encrypt(iv, bytes(data), additional_data)
|
||
|
# cipher text and 128b authentication tag
|
||
|
return chacha_output[: len(data)], chacha_output[len(data) :]
|
||
|
|
||
|
|
||
|
def chacha_poly_decrypt(
|
||
|
key: bytes, app_key: int, iv: bytes, data: bytes, additional_data: bytes = None
|
||
|
) -> bytes:
|
||
|
chacha = ChaCha20Poly1305(key)
|
||
|
chacha_output = chacha.decrypt(bytes(iv), bytes(data), additional_data)
|
||
|
return chacha_output
|
||
|
|
||
|
|
||
|
def decrypt_edek_esak(
|
||
|
pin: int, salt: bytes, edek_esak: bytes, pvc: bytes
|
||
|
) -> (bytes, bytes):
|
||
|
"""
|
||
|
Decrypts EDEK, ESAK to DEK, SAK and checks PIN in the process.
|
||
|
Raises:
|
||
|
InvalidPinError: if PIN is invalid
|
||
|
"""
|
||
|
kek, keiv = derive_kek_keiv(salt, pin)
|
||
|
|
||
|
algorithm = algorithms.ChaCha20(kek, (1).to_bytes(4, "little") + keiv)
|
||
|
cipher = Cipher(algorithm, mode=None, backend=default_backend())
|
||
|
decryptor = cipher.decryptor()
|
||
|
dek_sak = decryptor.update(bytes(edek_esak))
|
||
|
dek = dek_sak[: consts.DEK_SIZE]
|
||
|
sak = dek_sak[consts.DEK_SIZE :]
|
||
|
|
||
|
if not validate_pin(kek, keiv, dek_sak, pvc):
|
||
|
raise InvalidPinError("Invalid PIN")
|
||
|
|
||
|
return dek, sak
|
||
|
|
||
|
|
||
|
def validate_pin(kek: bytes, keiv: bytes, dek_sak: bytes, pvc: bytes) -> bool:
|
||
|
"""
|
||
|
This a little bit hackish. We do not store the whole
|
||
|
authentication tag so we can't decrypt using ChaCha20Poly1305
|
||
|
because it obviously checks the tag first and fails.
|
||
|
So we are using the sole ChaCha20 cipher to decipher and then encrypt
|
||
|
again with Chacha20Poly1305 here to get the tag and compare it to PVC.
|
||
|
"""
|
||
|
_, tag = chacha_poly_encrypt(kek, keiv, dek_sak)
|
||
|
prng.random32()
|
||
|
prng.random32()
|
||
|
return tag[: consts.PVC_SIZE] == pvc
|
||
|
|
||
|
|
||
|
def calculate_hmacs(sak: bytes, keys: bytes) -> bytes:
|
||
|
"""
|
||
|
This calculates HMAC-SHA-256(SAK, (XOR_i) HMAC-SHA-256(SAK, KEY_i)).
|
||
|
In other words, it does HMAC for every KEY and XORs it all together.
|
||
|
One more final HMAC is then performed on the result.
|
||
|
"""
|
||
|
hmacs = _hmac(sak, keys[0])
|
||
|
for key in keys[1:]:
|
||
|
hmacs = _xor(hmacs, _hmac(sak, key))
|
||
|
return _final_hmac(sak, hmacs)
|
||
|
|
||
|
|
||
|
def init_hmacs(sak: bytes) -> bytes:
|
||
|
return _final_hmac(sak, b"\x00" * hashes.SHA256.digest_size)
|
||
|
|
||
|
|
||
|
def _final_hmac(sak: bytes, data: bytes) -> bytes:
|
||
|
return _hmac(sak, data)[: consts.SAT_SIZE]
|
||
|
|
||
|
|
||
|
def _hmac(key: bytes, data: bytes) -> bytes:
|
||
|
h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend())
|
||
|
h.update(data)
|
||
|
return h.finalize()
|
||
|
|
||
|
|
||
|
def _xor(first: bytes, second: bytes) -> bytes:
|
||
|
return bytes(a ^ b for a, b in zip(first, second))
|
||
|
|
||
|
|
||
|
class InvalidPinError(ValueError):
|
||
|
pass
|