diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 720d6bb23e..cd01c95de9 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -19,6 +19,7 @@ from trezor import log, loop, protobuf, utils, workflow from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError from . import alternating_bit_protocol as ABP from . import ( + checksum, control_byte, crypto, interface_manager, @@ -75,9 +76,13 @@ class Channel: self.transmission_loop: TransmissionLoop | None = None self.write_task_spawn: loop.spawn | None = None - # Temporary objects for handshake and pairing + # Temporary objects self.handshake: crypto.Handshake | None = None self.connection_context: PairingContext | None = None + self.busy_decoder: crypto.BusyDecoder | None = None + self.temp_crc: int | None = None + self.temp_crc_compare: bytes | None = None + self.temp_tag: bytes | None = None def clear(self) -> None: clear_sessions_with_channel_id(self.channel_id) @@ -134,14 +139,18 @@ class Channel: def _handle_received_packet(self, packet: utils.BufferType) -> None: ctrl_byte = packet[0] if control_byte.is_continuation(ctrl_byte): - return self._handle_cont_packet(packet) - return self._handle_init_packet(packet) + self._handle_cont_packet(packet) + return + self._handle_init_packet(packet) def _handle_init_packet(self, packet: utils.BufferType) -> None: + self.fallback_decrypt = False + self.bytes_read = 0 + self.expected_payload_length = 0 + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_init_packet") - # ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption _, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) self.expected_payload_length = payload_length @@ -162,11 +171,66 @@ class Channel: try: buffer = memory_manager.get_new_read_buffer(cid, length) except BufferError: - self.fallback_decrypt = True - # TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess + # TODO handle not encrypted/(short??), eg. ACK - # if buffer is BufferError: - # pass # TODO handle deviceBUSY + self.fallback_decrypt = True + self._prepare_busy_decoder() + + to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length) + buf = memoryview(self.buffer)[:to_read_len] + self.temp_crc = checksum.compute_int(data=packet[:INIT_HEADER_LENGTH]) + self.temp_crc_compare = bytearray(4) + self.temp_tag = bytearray(16) + utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH) + + # TODO handle: CRC in init packet, CRC partially in init packet, CRC in some cont packet + # instead of whole buf use only part without CRC + # + # bytes_read=0, buffer_len, payload_len + # crc: + # 1) payload_len >= buffer_len + CHKSUM_LEN -> return buffer_len + # 2) payload_len == buffer_len -> return payload_len - CHKSUM_LEN + # 3) payload_len > buffer_len -> return payload_len - CHKSUM_LEN + # + # noise tag: + # 1) payload_len >= buffer_len + TAG_LEN + CHKSUM_LEN -> return buffer_len + # 2) payload_len == buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN + # 3) payload_len > buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN + # + + # CRC CHECK + crc_copy_len: int = 0 + if payload_length > len(buf) + CHECKSUM_LENGTH: + crc_copy_len = len(buf) + elif payload_length == len(buf): + crc_copy_len = payload_length - CHECKSUM_LENGTH + crc_checksum_last_part = buf[-CHECKSUM_LENGTH:] + offset = CHECKSUM_LENGTH - len(crc_checksum_last_part) + utils.memcpy(self.temp_crc_compare, offset, crc_checksum_last_part, 0) + elif payload_length > len(buf): + crc_copy_len = payload_length - CHECKSUM_LENGTH + crc_checksum_first_part = buf[ + -CHECKSUM_LENGTH + payload_length - len(buf) + ] + utils.memcpy(self.temp_crc_compare, 0, crc_checksum_first_part, 0) + else: + raise Exception("Buffer should not be bigger than payload") + self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc) + + # TAG CHECK + assert self.busy_decoder is not None + + if payload_length > len(buf) + TAG_LENGTH + CHECKSUM_LENGTH: + self.busy_decoder.decrypt_part(buf) + elif payload_length > len(buf): + self.busy_decoder.decrypt_part( + buf[: payload_length - TAG_LENGTH - CHECKSUM_LENGTH] + ) + # TODO add part of the "tag from message" to compare + else: + raise Exception("Buffer should not be bigger than payload") + + # TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_init_packet - payload len: ", str(payload_length)) @@ -174,18 +238,65 @@ class Channel: self._buffer_packet_data(buffer, packet, 0) + def _handle_fallback_crc(self, payload_length: int, buf: memoryview): + if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH: + # The CRC checksum is not in this packet, compute crc over whole buffer + self.temp_crc = checksum.compute_int(buf, self.temp_crc) + elif payload_length >= len(buf) + self.bytes_read: + # At least a part of the CRC checksum is in this packet, compute CRC over + # first (max(0, crc_copy_len)) bytes and add the rest of the bytes + # as the checksum from message into temp_crc_compare + crc_copy_len = payload_length - self.bytes_read - CHECKSUM_LENGTH + self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc) + + crc_checksum = buf[ + payload_length - CHECKSUM_LENGTH - len(buf) - self.bytes_read : + ] + offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:]) + utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0) + else: + raise Exception("Buffer (+bytes_read) should not be bigger than payload") + + def _handle_fallback_decryption(self, payload_length: int, buf: memoryview): + if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH: + # The noise tag is not in this packet, decrypt the whole buffer + self.busy_decoder.decrypt_part(buf) + elif payload_length >= len(buf) + self.bytes_read: + # At least a part of the CRC checksum is in this packet, compute CRC over + # first (max(0, crc_copy_len)) bytes and add the rest of the bytes + # as the checksum from message into temp_crc_compare + dec_len = payload_length - self.bytes_read - TAG_LENGTH - CHECKSUM_LENGTH + self.busy_decoder.decrypt_part(buf[:dec_len]) + + noise_tag = buf[ + payload_length + - CHECKSUM_LENGTH + - TAG_LENGTH + - len(buf) + - self.bytes_read : + ] + offset = ( + TAG_LENGTH + CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH - TAG_LENGTH :]) + ) + utils.memcpy(self.temp_tag, offset, noise_tag, 0) + else: + raise Exception("Buffer (+bytes_read) should not be bigger than payload") + def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_cont_packet") if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") - + if self.fallback_decrypt: + pass # TODO + return try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) except BufferError: + self.set_channel_state(ChannelState.INVALIDATED) pass # TODO handle device busy, channel kaput - return self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH) + self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH) def _buffer_packet_data( self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int @@ -193,17 +304,28 @@ class Channel: self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) def _finish_message(self) -> None: - self.fallback_decrypt = False self.bytes_read = 0 self.expected_payload_length = 0 self.is_cont_packet_expected = False + self.fallback_decrypt = False + self.busy_decoder = None + def _decrypt_single_packet_payload( self, payload: utils.BufferType ) -> utils.BufferType: # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) return payload + def _prepare_busy_decoder(self) -> None: + key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) + nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) + + assert key_receive is not None + assert nonce_receive is not None + + self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive) + def decrypt_buffer( self, message_length: int, offset: int = INIT_HEADER_LENGTH ) -> None: @@ -232,9 +354,7 @@ class Channel: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("Buffer before decryption: ", get_bytes_as_str(noise_buffer)) - is_tag_valid = crypto.dec( - noise_buffer, tag, key_receive, nonce_receive, b"" - ) + is_tag_valid = crypto.dec(noise_buffer, tag, key_receive, nonce_receive) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("Buffer after decryption: ", get_bytes_as_str(noise_buffer)) @@ -362,7 +482,7 @@ class Channel: assert key_send is not None assert nonce_send is not None - tag = crypto.enc(noise_buffer, key_send, nonce_send, b"") + tag = crypto.enc(noise_buffer, key_send, nonce_send) self.channel_cache.set_int(CHANNEL_NONCE_SEND, nonce_send + 1) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: diff --git a/core/src/trezor/wire/thp/checksum.py b/core/src/trezor/wire/thp/checksum.py index 9c28f2e78d..44aab46630 100644 --- a/core/src/trezor/wire/thp/checksum.py +++ b/core/src/trezor/wire/thp/checksum.py @@ -6,11 +6,22 @@ from trezor.crypto import crc CHECKSUM_LENGTH = const(4) -def compute(data: bytes | utils.BufferType) -> bytes: +def compute(data: bytes | utils.BufferType, crc_chain: int = 0) -> bytes: """ - Returns a CRC-32 checksum of the provided `data`. + Returns a CRC-32 checksum of the provided `data`. Allows for for chaining + computations over multiple data segments using `crc_chain` (optional). """ - return crc.crc32(data).to_bytes(CHECKSUM_LENGTH, "big") + return crc.crc32(data, crc_chain).to_bytes(CHECKSUM_LENGTH, "big") + + +def compute_int(data: bytes | utils.BufferType, crc_chain: int = 0) -> int: + """ + Returns a CRC-32 checksum of the provided `data`. Allows for for chaining + computations over multiple data segments using `crc_chain` (optional). + + Returns checksum in the form of `int`. + """ + return crc.crc32(data, crc_chain) def is_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool: diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py index aa7d9c146e..4ba7fc71c9 100644 --- a/core/src/trezor/wire/thp/crypto.py +++ b/core/src/trezor/wire/thp/crypto.py @@ -14,16 +14,18 @@ if utils.DISABLE_ENCRYPTION: DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" if __debug__: - from ubinascii import hexlify + from trezor.utils import get_bytes_as_str -def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> bytes: +def enc( + buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes = b"" +) -> bytes: """ Encrypts the provided `buffer` with AES-GCM (in place). Returns a 16-byte long encryption tag. """ if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "enc (key: %s, nonce: %d)", hexlify(key), nonce) + log.debug(__name__, "enc (key: %s, nonce: %d)", get_bytes_as_str(key), nonce) iv = _get_iv_from_nonce(nonce) aes_ctx = aesgcm(key, iv) aes_ctx.auth(auth_data) @@ -32,7 +34,7 @@ def enc(buffer: utils.BufferType, key: bytes, nonce: int, auth_data: bytes) -> b def dec( - buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes + buffer: utils.BufferType, tag: bytes, key: bytes, nonce: int, auth_data: bytes = b"" ) -> bool: """ Decrypts the provided buffer (in place). Returns `True` if the provided authentication `tag` is the same as @@ -40,7 +42,7 @@ def dec( """ iv = _get_iv_from_nonce(nonce) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "dec (key: %s, nonce: %d)", hexlify(key), nonce) + log.debug(__name__, "dec (key: %s, nonce: %d)", get_bytes_as_str(key), nonce) aes_ctx = aesgcm(key, iv) aes_ctx.auth(auth_data) aes_ctx.decrypt_in_place(buffer) @@ -49,7 +51,8 @@ def dec( class BusyDecoder: - def __init__(self, key: bytes, nonce: int, auth_data: bytes) -> None: + + def __init__(self, key: bytes, nonce: int, auth_data: bytes = b"") -> None: iv = _get_iv_from_nonce(nonce) self.aes_ctx = aesgcm(key, iv) self.aes_ctx.auth(auth_data) @@ -105,7 +108,9 @@ class Handshake: aes_ctx = aesgcm(self.k, IV_1) encrypted_trezor_static_pubkey = aes_ctx.encrypt(trezor_masked_static_pubkey) if __debug__: - log.debug(__name__, "th1 - enc (key: %s, nonce: %d)", hexlify(self.k), 0) + log.debug( + __name__, "th1 - enc (key: %s, nonce: %d)", get_bytes_as_str(self.k), 0 + ) aes_ctx.auth(self.h) tag_to_encrypted_key = aes_ctx.finish() encrypted_trezor_static_pubkey = ( @@ -137,7 +142,9 @@ class Handshake: memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] ) if __debug__: - log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 1) + log.debug( + __name__, "th2 - dec (key: %s, nonce: %d)", get_bytes_as_str(self.k), 1 + ) host_static_pubkey = memoryview(encrypted_host_static_pubkey)[:PUBKEY_LENGTH] tag = aes_ctx.finish() if tag != encrypted_host_static_pubkey[-16:]: @@ -151,7 +158,9 @@ class Handshake: aes_ctx.auth(self.h) aes_ctx.decrypt_in_place(memoryview(encrypted_payload)[:-16]) if __debug__: - log.debug(__name__, "th2 - dec (key: %s, nonce: %d)", hexlify(self.k), 0) + log.debug( + __name__, "th2 - dec (key: %s, nonce: %d)", get_bytes_as_str(self.k), 0 + ) tag = aes_ctx.finish() if tag != encrypted_payload[-16:]: raise ThpDecryptionError() @@ -162,8 +171,8 @@ class Handshake: log.debug( __name__, "(key_receive: %s, key_send: %s)", - hexlify(self.key_receive), - hexlify(self.key_send), + get_bytes_as_str(self.key_receive), + get_bytes_as_str(self.key_send), ) def get_handshake_completion_response(self, trezor_state: bytes) -> bytes: diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index cf453f8992..ea5d5477b5 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -11,6 +11,7 @@ from .writer import MAX_PAYLOAD_LEN, MESSAGE_TYPE_LENGTH _PROTOBUF_BUFFER_SIZE = 8192 READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) +LOCK_TIMEOUT = 200 # miliseconds lock_owner_cid: int | None = None @@ -106,7 +107,7 @@ def is_locked() -> bool: global lock_time time_diff = utime.ticks_diff(utime.ticks_ms(), lock_time) - return lock_owner_cid is not None and time_diff < 200 + return lock_owner_cid is not None and time_diff < LOCK_TIMEOUT def is_owner(channel_id: int) -> bool: