From 03d5751f3f1e095a7aeaa2da38c45703d2257b51 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Mon, 3 Mar 2025 10:02:02 +0100 Subject: [PATCH] wip --- core/src/trezor/wire/thp/crypto.py | 7 +- .../trezorlib/transport/thp/protocol_v2.py | 157 +++++------------- tests/device_tests/thp/test_thp.py | 81 ++++----- 3 files changed, 87 insertions(+), 158 deletions(-) diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py index 05cfa54247..ed23ff8761 100644 --- a/core/src/trezor/wire/thp/crypto.py +++ b/core/src/trezor/wire/thp/crypto.py @@ -110,8 +110,13 @@ class Handshake: encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey) if __debug__: log.debug( - __name__, "th1 - enc (key: %s, nonce: %d)", get_bytes_as_str(self.k), 0 + __name__, + "th1 - enc (key: %s, nonce: %d, handshake_hash %s)", + get_bytes_as_str(self.k), + 0, + get_bytes_as_str(self.h), ) + aes_ctx.auth(self.h) tag_to_encrypted_key = aes_ctx.finish() encrypted_trezor_static_pubkey = ( diff --git a/python/src/trezorlib/transport/thp/protocol_v2.py b/python/src/trezorlib/transport/thp/protocol_v2.py index 814df7c0ee..8d382f5deb 100644 --- a/python/src/trezorlib/transport/thp/protocol_v2.py +++ b/python/src/trezorlib/transport/thp/protocol_v2.py @@ -1,19 +1,17 @@ from __future__ import annotations -import hashlib -import hmac import logging import os import typing as t from binascii import hexlify import click -from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from noise.connection import Keypair, NoiseConnection from ... import exceptions, messages, protobuf from ...mapping import ProtobufMapping from .. import Transport -from ..thp import checksum, curve25519, thp_io +from ..thp import checksum, thp_io from ..thp.channel_data import ChannelData from ..thp.checksum import CHECKSUM_LENGTH from ..thp.message_header import MessageHeader @@ -30,27 +28,6 @@ if t.TYPE_CHECKING: MT = t.TypeVar("MT", bound=protobuf.MessageType) -def _sha256_of_two(val_1: bytes, val_2: bytes) -> bytes: - hash = hashlib.sha256(val_1) - hash.update(val_2) - return hash.digest() - - -def _hkdf(chaining_key: bytes, input: bytes): - temp_key = hmac.new(chaining_key, input, hashlib.sha256).digest() - output_1 = hmac.new(temp_key, b"\x01", hashlib.sha256).digest() - ctx_output_2 = hmac.new(temp_key, output_1, hashlib.sha256) - ctx_output_2.update(b"\x02") - output_2 = ctx_output_2.digest() - return (output_1, output_2) - - -def _get_iv_from_nonce(nonce: int) -> bytes: - if not nonce <= 0xFFFFFFFFFFFFFFFF: - raise ValueError("Nonce overflow, terminate the channel") - return bytes(4) + nonce.to_bytes(8, "big") - - class ProtocolV2Channel(Channel): channel_id: int channel_database: ChannelDatabase @@ -95,8 +72,8 @@ class ProtocolV2Channel(Channel): protocol_version_minor=2, transport_path=self.transport.get_path(), channel_id=self.channel_id, - key_request=self.key_request, - key_response=self.key_response, + key_request=self.noise.noise_protocol.cipher_state_encrypt.k, + key_response=self.noise.noise_protocol.cipher_state_decrypt.k, nonce_request=self.nonce_request, nonce_response=self.nonce_response, sync_bit_receive=self.sync_bit_receive, @@ -188,39 +165,40 @@ class ProtocolV2Channel(Channel): device_properties = payload[10:] return (channel_id, device_properties) + def _init_noise(self, randomness_static: bytes) -> None: + self.noise = NoiseConnection.from_name(b"Noise_XX_25519_AESGCM_SHA256") + self.noise.set_as_initiator() + self.noise.set_keypair_from_private_bytes(Keypair.STATIC, randomness_static) + + prologue = bytes(self.device_properties) + self.noise.set_prologue(prologue) + self.noise.start_handshake() + def _do_handshake( - self, credential: bytes | None = None, host_static_privkey: bytes | None = None + self, + credential: bytes | None = None, + host_static_randomness: bytes | None = None, ): - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) - self._send_handshake_init_request(host_ephemeral_pubkey) + randomness_static = host_static_randomness or os.urandom(32) + + self._init_noise(randomness_static) + self._send_handshake_init_request() self._read_ack() - init_response = self._read_handshake_init_response() - - trezor_ephemeral_pubkey = init_response[:32] - encrypted_trezor_static_pubkey = init_response[32:80] - noise_tag = init_response[80:96] - LOG.debug("noise_tag: %s", hexlify(noise_tag).decode()) - - # TODO check noise_tag is valid - - ck = self._send_handshake_completion_request( - host_ephemeral_pubkey, - host_ephemeral_privkey, - trezor_ephemeral_pubkey, - encrypted_trezor_static_pubkey, + self._read_handshake_init_response() + self._send_handshake_completion_request( credential, - host_static_privkey, ) self._read_ack() self._read_handshake_completion_response() - self.key_request, self.key_response = _hkdf(ck, b"") + self.key_request = self.noise.noise_protocol.cipher_state_encrypt.k + self.key_response = self.noise.noise_protocol.cipher_state_decrypt.k self.nonce_request = 0 self.nonce_response = 1 - def _send_handshake_init_request(self, host_ephemeral_pubkey: bytes) -> None: + def _send_handshake_init_request(self) -> None: ha_init_req_header = MessageHeader(0, self.channel_id, 36) + host_ephemeral_pubkey = self.noise.write_message() thp_io.write_payload_to_wire_and_add_checksum( self.transport, ha_init_req_header, host_ephemeral_pubkey @@ -241,90 +219,49 @@ class ProtocolV2Channel(Channel): "Received message is not a valid handshake init response message", err=True, ) + self.noise.read_message(payload) return payload def _send_handshake_completion_request( self, - host_ephemeral_pubkey: bytes, - host_ephemeral_privkey: bytes, - trezor_ephemeral_pubkey: bytes, - encrypted_trezor_static_pubkey: bytes, credential: bytes | None = None, - host_static_privkey: bytes | None = None, - ) -> bytes: - PROTOCOL_NAME = b"Noise_XX_25519_AESGCM_SHA256\x00\x00\x00\x00" - IV_1 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" - IV_2 = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" - h = _sha256_of_two(PROTOCOL_NAME, self.device_properties) - h = _sha256_of_two(h, host_ephemeral_pubkey) - h = _sha256_of_two(h, trezor_ephemeral_pubkey) - ck, k = _hkdf( - PROTOCOL_NAME, - curve25519.multiply(host_ephemeral_privkey, trezor_ephemeral_pubkey), - ) + ) -> None: + # TODO implement key recognition + # print( + # "TREZOR's static pubkey:\n", + # self.noise.noise_protocol.handshake_state.rs.public.public_bytes_raw(), + # ) - aes_ctx = AESGCM(k) - try: - trezor_masked_static_pubkey = aes_ctx.decrypt( - IV_1, encrypted_trezor_static_pubkey, h - ) - except Exception as e: - click.echo( - f"Exception of type{type(e)}", err=True - ) # TODO how to handle potential exceptions? Q for Matejcik - h = _sha256_of_two(h, encrypted_trezor_static_pubkey) - ck, k = _hkdf( - ck, curve25519.multiply(host_ephemeral_privkey, trezor_masked_static_pubkey) - ) - aes_ctx = AESGCM(k) - - tag_of_empty_string = aes_ctx.encrypt(IV_1, b"", h) - h = _sha256_of_two(h, tag_of_empty_string) - - # TODO: search for saved credentials - if host_static_privkey is None: - host_static_privkey = curve25519.get_private_key(os.urandom(32)) - host_static_pubkey = curve25519.get_public_key(host_static_privkey) - - aes_ctx = AESGCM(k) - encrypted_host_static_pubkey = aes_ctx.encrypt(IV_2, host_static_pubkey, h) - h = _sha256_of_two(h, encrypted_host_static_pubkey) - ck, k = _hkdf( - ck, curve25519.multiply(host_static_privkey, trezor_ephemeral_pubkey) - ) msg_data = self.mapping.encode_without_wire_type( messages.ThpHandshakeCompletionReqNoisePayload( host_pairing_credential=credential, ) ) + message2 = self.noise.write_message(payload=msg_data) - aes_ctx = AESGCM(k) - - encrypted_payload = aes_ctx.encrypt(IV_1, msg_data, h) - h = _sha256_of_two(h, encrypted_payload[:-16]) ha_completion_req_header = MessageHeader( 0x12, self.channel_id, - len(encrypted_host_static_pubkey) - + len(encrypted_payload) - + CHECKSUM_LENGTH, + len(message2) + CHECKSUM_LENGTH, ) thp_io.write_payload_to_wire_and_add_checksum( self.transport, ha_completion_req_header, - encrypted_host_static_pubkey + encrypted_payload, + message2, # encrypted_host_static_pubkey + encrypted_payload, ) - self.handshake_hash = h - return ck + self.handshake_hash = self.noise.get_handshake_hash() def _read_handshake_completion_response(self) -> None: # Read handshake completion response, ignore payload as we do not care about the state - header, _ = self._read_until_valid_crc_check() + header, data = self._read_until_valid_crc_check() if not header.is_handshake_comp_response(): click.echo( "Received message is not a valid handshake completion response", err=True, ) + trezor_state = self.noise.decrypt(bytes(data)) + # TODO handle trezor_state + print("trezor state:", trezor_state) self._send_ack_1() def _do_pairing(self, helper_debug: DebugLink | None): @@ -369,7 +306,6 @@ class ProtocolV2Channel(Channel): ctrl_byte: int | None = None, ) -> None: assert self.key_request is not None - aes_ctx = AESGCM(self.key_request) if ctrl_byte is None: ctrl_byte = control_byte.add_seq_bit_to_ctrl_byte(0x04, self.sync_bit_send) @@ -378,9 +314,9 @@ class ProtocolV2Channel(Channel): sid = session_id.to_bytes(1, "big") msg_type = message_type.to_bytes(2, "big") data = sid + msg_type + message_data - nonce = _get_iv_from_nonce(self.nonce_request) - self.nonce_request += 1 - encrypted_message = aes_ctx.encrypt(nonce, data, b"") + + encrypted_message = self.noise.encrypt(data) + header = MessageHeader( ctrl_byte, self.channel_id, len(encrypted_message) + CHECKSUM_LENGTH ) @@ -417,11 +353,8 @@ class ProtocolV2Channel(Channel): self._send_ack_1() else: self._send_ack_0() - aes_ctx = AESGCM(self.key_response) - nonce = _get_iv_from_nonce(self.nonce_response) - self.nonce_response += 1 - message = aes_ctx.decrypt(nonce, raw_payload, b"") + message = self.noise.decrypt(bytes(raw_payload)) session_id = message[0] message_type = message[1:3] message_data = message[3:] diff --git a/tests/device_tests/thp/test_thp.py b/tests/device_tests/thp/test_thp.py index 6347b4e12a..61fe26acb1 100644 --- a/tests/device_tests/thp/test_thp.py +++ b/tests/device_tests/thp/test_thp.py @@ -34,7 +34,6 @@ from trezorlib.messages import ( ) from trezorlib.transport.thp import curve25519 from trezorlib.transport.thp.cpace import Cpace -from trezorlib.transport.thp.protocol_v2 import _hkdf if t.TYPE_CHECKING: P = tx.ParamSpec("P") @@ -53,18 +52,18 @@ def _prepare_protocol(client: Client) -> ProtocolV2Channel: def _prepare_protocol_for_pairing( - client: Client, host_static_privkey: bytes | None = None + client: Client, host_static_randomness: bytes | None = None ) -> ProtocolV2Channel: protocol = _prepare_protocol(client) - protocol._do_handshake(host_static_privkey=host_static_privkey) + protocol._do_handshake(host_static_randomness=host_static_randomness) return protocol def _get_encrypted_transport_protocol( - client: Client, host_static_privkey: bytes | None = None + client: Client, host_static_randomness: bytes | None = None ) -> ProtocolV2Channel: protocol = _prepare_protocol_for_pairing( - client, host_static_privkey=host_static_privkey + client, host_static_randomness=host_static_randomness ) protocol._do_pairing(client.debug) return protocol @@ -105,39 +104,24 @@ def test_allocate_channel(client: Client) -> None: def test_handshake(client: Client) -> None: protocol = _prepare_protocol(client) - host_ephemeral_privkey = curve25519.get_private_key(os.urandom(32)) - host_ephemeral_pubkey = curve25519.get_public_key(host_ephemeral_privkey) + randomness_static = os.urandom(32) protocol._do_channel_allocation() - protocol._send_handshake_init_request(host_ephemeral_pubkey) - protocol._read_ack() - init_response = protocol._read_handshake_init_response() - - trezor_ephemeral_pubkey = init_response[:32] - encrypted_trezor_static_pubkey = init_response[32:80] - noise_tag = init_response[80:96] - - # TODO check noise_tag is valid - - ck = protocol._send_handshake_completion_request( - host_ephemeral_pubkey, - host_ephemeral_privkey, - trezor_ephemeral_pubkey, - encrypted_trezor_static_pubkey, + protocol._init_noise( + randomness_static=randomness_static, ) + protocol._send_handshake_init_request() + protocol._read_ack() + protocol._read_handshake_init_response() + + protocol._send_handshake_completion_request() protocol._read_ack() protocol._read_handshake_completion_response() - protocol.key_request, protocol.key_response = _hkdf(ck, b"") - protocol.nonce_request = 0 - protocol.nonce_response = 1 # TODO - without pairing, the client is damaged and results in fail of the following test # so far no luck in solving it - it should be also tackled in FW, as it causes unexpected FW error protocol._do_pairing(client.debug) - # TODO the following is just to make style checker happy - assert noise_tag is not None - def test_pairing_qr_code(client: Client) -> None: protocol = _prepare_protocol_for_pairing(client) @@ -293,7 +277,8 @@ def test_credential_phase(client: Client) -> None: _nfc_pairing(client, protocol) # Request credential with confirmation after pairing - host_static_privkey = curve25519.get_private_key(os.urandom(32)) + randomness_static = os.urandom(32) + host_static_privkey = curve25519.get_private_key(randomness_static) host_static_pubkey = curve25519.get_public_key(host_static_privkey) protocol._send_message( ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=False) @@ -308,7 +293,7 @@ def test_credential_phase(client: Client) -> None: # Connect using credential with confirmation protocol = _prepare_protocol(client) protocol._do_channel_allocation() - protocol._do_handshake(credential, host_static_privkey) + protocol._do_handshake(credential, randomness_static) protocol._send_message(ThpEndRequest()) button_req = protocol._read_message(ButtonRequest) assert button_req.name == "connection_request" @@ -318,7 +303,8 @@ def test_credential_phase(client: Client) -> None: # Delete channel from the device by sending badly encrypted message # This is done to prevent channel replacement and trigerring of autoconnect false -> true - protocol.nonce_request = 250 + protocol.noise.noise_protocol.cipher_state_encrypt.n = 250 + protocol._send_message(ButtonAck()) with pytest.raises(Exception) as e: protocol.read(1) @@ -327,7 +313,7 @@ def test_credential_phase(client: Client) -> None: # Connect using credential with confirmation and ask for autoconnect credential. protocol = _prepare_protocol(client) protocol._do_channel_allocation() - protocol._do_handshake(credential, host_static_privkey) + protocol._do_handshake(credential, randomness_static) protocol._send_message( ThpCredentialRequest(host_static_pubkey=host_static_pubkey, autoconnect=True) ) @@ -345,7 +331,7 @@ def test_credential_phase(client: Client) -> None: # Connect using credential with confirmation protocol = _prepare_protocol(client) protocol._do_channel_allocation() - protocol._do_handshake(credential, host_static_privkey) + protocol._do_handshake(credential, randomness_static) # Confirmation dialog is not shown as channel in ENCRYPTED TRANSPORT state with the same # host static public key is still available in Trezor's cache. (Channel replacement is triggered.) protocol._send_message(ThpEndRequest()) @@ -354,13 +340,14 @@ def test_credential_phase(client: Client) -> None: # Connect using autoconnect credential protocol = _prepare_protocol(client) protocol._do_channel_allocation() - protocol._do_handshake(credential_auto, host_static_privkey) + protocol._do_handshake(credential_auto, randomness_static) protocol._send_message(ThpEndRequest()) protocol._read_message(ThpEndResponse) # Delete channel from the device by sending badly encrypted message # This is done to prevent channel replacement and trigerring of autoconnect false -> true - protocol.nonce_request = 250 + protocol.noise.noise_protocol.cipher_state_encrypt.n = 100 + protocol._send_message(ButtonAck()) with pytest.raises(Exception) as e: protocol.read(1) @@ -369,7 +356,7 @@ def test_credential_phase(client: Client) -> None: # Connect using autoconnect credential - should work the same as above protocol = _prepare_protocol(client) protocol._do_channel_allocation() - protocol._do_handshake(credential_auto, host_static_privkey) + protocol._do_handshake(credential_auto, randomness_static) protocol._send_message(ThpEndRequest()) protocol._read_message(ThpEndResponse) @@ -378,23 +365,25 @@ def test_credential_phase(client: Client) -> None: def test_channel_replacement(client: Client) -> None: assert client.features.passphrase_protection is True - host_static_privkey = curve25519.get_private_key(os.urandom(32)) - host_static_privkey_2 = curve25519.get_private_key(os.urandom(32)) + host_static_randomness = os.urandom(32) + host_static_randomness_2 = os.urandom(32) + host_static_privkey = curve25519.get_private_key(host_static_randomness) + host_static_privkey_2 = curve25519.get_private_key(host_static_randomness_2) assert host_static_privkey != host_static_privkey_2 - client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey) + client.protocol = _get_encrypted_transport_protocol(client, host_static_randomness) - session = client.get_session(passphrase="TREZOR", session_id=20) + session = client.get_session(passphrase="TREZOR", session_id=b"\x10") address = get_test_address(session) - session_2 = client.get_session(passphrase="ROZERT", session_id=30) + session_2 = client.get_session(passphrase="ROZERT", session_id=b"\x20") address_2 = get_test_address(session_2) assert address != address_2 # create new channel using the same host_static_privkey - client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey) - session_3 = client.get_session(passphrase="OKIDOKI", session_id=40) + client.protocol = _get_encrypted_transport_protocol(client, host_static_randomness) + session_3 = client.get_session(passphrase="OKIDOKI", session_id=b"\x30") address_3 = get_test_address(session_3) assert address_3 != address_2 @@ -405,7 +394,9 @@ def test_channel_replacement(client: Client) -> None: assert address_3 == new_address_3 # create new channel using different host_static_privkey - client.protocol = _get_encrypted_transport_protocol(client, host_static_privkey_2) + client.protocol = _get_encrypted_transport_protocol( + client, host_static_randomness_2 + ) with pytest.raises(exceptions.TrezorFailure) as e_1: _ = get_test_address(session) assert str(e_1.value.message) == "Invalid session" @@ -414,6 +405,6 @@ def test_channel_replacement(client: Client) -> None: _ = get_test_address(session_3) assert str(e_2.value.message) == "Invalid session" - session_4 = client.get_session(passphrase="TREZOR", session_id=80) + session_4 = client.get_session(passphrase="TREZOR", session_id=b"\x40") super_new_address = get_test_address(session_4) assert address == super_new_address