diff --git a/core/src/trezor/enums/__init__.py b/core/src/trezor/enums/__init__.py index 62ef86b810..f0ab5e5d55 100644 --- a/core/src/trezor/enums/__init__.py +++ b/core/src/trezor/enums/__init__.py @@ -39,8 +39,10 @@ if TYPE_CHECKING: PinMismatch = 12 WipeCodeMismatch = 13 InvalidSession = 14 - ThpUnallocatedSession = 15 - InvalidProtocol = 16 + DeviceIsBusy = 15 + ThpUnallocatedSession = 16 + InvalidProtocol = 17 + BufferError = 18 FirmwareError = 99 class ButtonRequestType(IntEnum): diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index e8b2d3feb4..8f572fcf0c 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -14,6 +14,11 @@ class SilentError(Exception): self.message = message +class WireBufferError(Error): + def __init__(self, message: str = "Buffer error") -> None: + super().__init__(FailureType.BufferError, message) + + class UnexpectedMessage(Error): def __init__(self, message: str) -> None: super().__init__(FailureType.UnexpectedMessage, message) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index cd01c95de9..a5daa08c0b 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -15,6 +15,7 @@ from storage.cache_thp import ( clear_sessions_with_channel_id, ) from trezor import log, loop, protobuf, utils, workflow +from trezor.wire.errors import WireBufferError from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError from . import alternating_bit_protocol as ABP @@ -81,8 +82,8 @@ class Channel: self.connection_context: PairingContext | None = None self.busy_decoder: crypto.BusyDecoder | None = None self.temp_crc: int | None = None - self.temp_crc_compare: bytes | None = None - self.temp_tag: bytes | None = None + self.temp_crc_compare: bytearray | None = None + self.temp_tag: bytearray | None = None def clear(self) -> None: clear_sessions_with_channel_id(self.channel_id) @@ -119,7 +120,7 @@ class Channel: try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) - except BufferError: + except WireBufferError: pass # TODO ?? if __debug__ and utils.ALLOW_DEBUG_MESSAGES: @@ -170,67 +171,22 @@ class Channel: length = payload_length + INIT_HEADER_LENGTH try: buffer = memory_manager.get_new_read_buffer(cid, length) - except BufferError: + except WireBufferError: # TODO handle not encrypted/(short??), eg. ACK self.fallback_decrypt = True - self._prepare_busy_decoder() + + self._prepare_fallback() to_read_len = min(len(packet) - INIT_HEADER_LENGTH, payload_length) buf = memoryview(self.buffer)[:to_read_len] - self.temp_crc = checksum.compute_int(data=packet[:INIT_HEADER_LENGTH]) - self.temp_crc_compare = bytearray(4) - self.temp_tag = bytearray(16) utils.memcpy(buf, 0, packet, INIT_HEADER_LENGTH) - # TODO handle: CRC in init packet, CRC partially in init packet, CRC in some cont packet - # instead of whole buf use only part without CRC - # - # bytes_read=0, buffer_len, payload_len - # crc: - # 1) payload_len >= buffer_len + CHKSUM_LEN -> return buffer_len - # 2) payload_len == buffer_len -> return payload_len - CHKSUM_LEN - # 3) payload_len > buffer_len -> return payload_len - CHKSUM_LEN - # - # noise tag: - # 1) payload_len >= buffer_len + TAG_LEN + CHKSUM_LEN -> return buffer_len - # 2) payload_len == buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN - # 3) payload_len > buffer_len -> return payload_len - TAG_LEN - CHKSUM_LEN - # - # CRC CHECK - crc_copy_len: int = 0 - if payload_length > len(buf) + CHECKSUM_LENGTH: - crc_copy_len = len(buf) - elif payload_length == len(buf): - crc_copy_len = payload_length - CHECKSUM_LENGTH - crc_checksum_last_part = buf[-CHECKSUM_LENGTH:] - offset = CHECKSUM_LENGTH - len(crc_checksum_last_part) - utils.memcpy(self.temp_crc_compare, offset, crc_checksum_last_part, 0) - elif payload_length > len(buf): - crc_copy_len = payload_length - CHECKSUM_LENGTH - crc_checksum_first_part = buf[ - -CHECKSUM_LENGTH + payload_length - len(buf) - ] - utils.memcpy(self.temp_crc_compare, 0, crc_checksum_first_part, 0) - else: - raise Exception("Buffer should not be bigger than payload") - self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc) + self._handle_fallback_crc(buf) # TAG CHECK - assert self.busy_decoder is not None - - if payload_length > len(buf) + TAG_LENGTH + CHECKSUM_LENGTH: - self.busy_decoder.decrypt_part(buf) - elif payload_length > len(buf): - self.busy_decoder.decrypt_part( - buf[: payload_length - TAG_LENGTH - CHECKSUM_LENGTH] - ) - # TODO add part of the "tag from message" to compare - else: - raise Exception("Buffer should not be bigger than payload") - - # TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess + self._handle_fallback_decryption(buf) if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_init_packet - payload len: ", str(payload_length)) @@ -238,38 +194,55 @@ class Channel: self._buffer_packet_data(buffer, packet, 0) - def _handle_fallback_crc(self, payload_length: int, buf: memoryview): - if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH: + def _handle_fallback_crc(self, buf: memoryview) -> None: + assert self.temp_crc is not None + assert self.temp_crc_compare is not None + if self.expected_payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH: # The CRC checksum is not in this packet, compute crc over whole buffer self.temp_crc = checksum.compute_int(buf, self.temp_crc) - elif payload_length >= len(buf) + self.bytes_read: + elif self.expected_payload_length >= len(buf) + self.bytes_read: # At least a part of the CRC checksum is in this packet, compute CRC over - # first (max(0, crc_copy_len)) bytes and add the rest of the bytes + # first (max(0, crc_copy_len)) bytes and add the rest of the bytes (max 4) # as the checksum from message into temp_crc_compare - crc_copy_len = payload_length - self.bytes_read - CHECKSUM_LENGTH + crc_copy_len = ( + self.expected_payload_length - self.bytes_read - CHECKSUM_LENGTH + ) self.temp_crc = checksum.compute_int(buf[:crc_copy_len], self.temp_crc) crc_checksum = buf[ - payload_length - CHECKSUM_LENGTH - len(buf) - self.bytes_read : + self.expected_payload_length + - CHECKSUM_LENGTH + - len(buf) + - self.bytes_read : ] offset = CHECKSUM_LENGTH - len(buf[-CHECKSUM_LENGTH:]) utils.memcpy(self.temp_crc_compare, offset, crc_checksum, 0) else: raise Exception("Buffer (+bytes_read) should not be bigger than payload") - def _handle_fallback_decryption(self, payload_length: int, buf: memoryview): - if payload_length > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH: + def _handle_fallback_decryption(self, buf: memoryview) -> None: + assert self.busy_decoder is not None + assert self.temp_tag is not None + if ( + self.expected_payload_length + > len(buf) + self.bytes_read + CHECKSUM_LENGTH + TAG_LENGTH + ): # The noise tag is not in this packet, decrypt the whole buffer self.busy_decoder.decrypt_part(buf) - elif payload_length >= len(buf) + self.bytes_read: + elif self.expected_payload_length >= len(buf) + self.bytes_read: # At least a part of the CRC checksum is in this packet, compute CRC over # first (max(0, crc_copy_len)) bytes and add the rest of the bytes # as the checksum from message into temp_crc_compare - dec_len = payload_length - self.bytes_read - TAG_LENGTH - CHECKSUM_LENGTH + dec_len = ( + self.expected_payload_length + - self.bytes_read + - TAG_LENGTH + - CHECKSUM_LENGTH + ) self.busy_decoder.decrypt_part(buf[:dec_len]) noise_tag = buf[ - payload_length + self.expected_payload_length - CHECKSUM_LENGTH - TAG_LENGTH - len(buf) @@ -293,7 +266,7 @@ class Channel: return try: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) - except BufferError: + except WireBufferError: self.set_channel_state(ChannelState.INVALIDATED) pass # TODO handle device busy, channel kaput self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH) @@ -317,7 +290,8 @@ class Channel: # crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) return payload - def _prepare_busy_decoder(self) -> None: + def _prepare_fallback(self) -> None: + # prepare busy decoder key_receive = self.channel_cache.get(CHANNEL_KEY_RECEIVE) nonce_receive = self.channel_cache.get_int(CHANNEL_NONCE_RECEIVE) @@ -326,11 +300,16 @@ class Channel: self.busy_decoder = crypto.BusyDecoder(key_receive, nonce_receive) + # prepare temp channel values + self.temp_crc = 0 + self.temp_crc_compare = bytearray(4) + self.temp_tag = bytearray(16) + def decrypt_buffer( self, message_length: int, offset: int = INIT_HEADER_LENGTH ) -> None: buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) - # if buffer is BufferError: + # if buffer is WireBufferError: # pass # TODO handle deviceBUSY noise_buffer = memoryview(buffer)[ offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH @@ -387,7 +366,7 @@ class Channel: noise_payload_len = memory_manager.encode_into_buffer( buffer, msg, session_id ) - except BufferError: + except WireBufferError: from trezor.messages import Failure, FailureType if __debug__ and utils.ALLOW_DEBUG_MESSAGES: @@ -421,7 +400,7 @@ class Channel: def _write_and_encrypt(self, payload: bytes) -> Awaitable[None]: payload_length = len(payload) buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) - # if buffer is BufferError: + # if buffer is WireBufferError: # pass # TODO handle deviceBUSY self._encrypt(buffer, payload_length) diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index ea5d5477b5..681642d8fc 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -3,6 +3,7 @@ from micropython import const from storage.cache_thp import SESSION_ID_LENGTH from trezor import protobuf, utils +from trezor.wire.errors import WireBufferError from trezor.wire.message_handler import get_msg_type from . import ThpError @@ -48,7 +49,7 @@ def get_existing_write_buffer(channel_id: int) -> memoryview: def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview: if is_locked(): if not is_owner(channel_id): - raise BufferError + raise WireBufferError update_lock_time() else: update_lock(channel_id) @@ -80,19 +81,19 @@ def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryvie def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview: if not is_owner(channel_id): - raise BufferError + raise WireBufferError update_lock_time() if buffer_type == _READ: global READ_BUFFER_SLICE if READ_BUFFER_SLICE is None: - raise BufferError + raise WireBufferError return READ_BUFFER_SLICE if buffer_type == _WRITE: global WRITE_BUFFER_SLICE if WRITE_BUFFER_SLICE is None: - raise BufferError + raise WireBufferError return WRITE_BUFFER_SLICE raise Exception("Invalid buffer_type")