diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index 8cf3a9207d..76288c1f9b 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -67,6 +67,7 @@ class ChannelState(IntEnum): TP4 = 6 TC1 = 7 ENCRYPTED_TRANSPORT = 8 + INVALIDATED = 9 class SessionState(IntEnum): diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index d2ddcbf96e..cfe7833730 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -8,7 +8,12 @@ from storage.cache_common import ( CHANNEL_NONCE_RECEIVE, CHANNEL_NONCE_SEND, ) -from storage.cache_thp import TAG_LENGTH, ChannelCache, clear_sessions_with_channel_id +from storage.cache_thp import ( + SESSION_ID_LENGTH, + TAG_LENGTH, + ChannelCache, + clear_sessions_with_channel_id, +) from trezor import log, loop, protobuf, utils, workflow from . import ENCRYPTED, ChannelState, PacketHeader, ThpDecryptionError, ThpError @@ -25,6 +30,7 @@ from .transmission_loop import TransmissionLoop from .writer import ( CONT_HEADER_LENGTH, INIT_HEADER_LENGTH, + MESSAGE_TYPE_LENGTH, PACKET_LENGTH, write_payload_to_wire_and_add_checksum, ) @@ -58,6 +64,7 @@ class Channel: # Shared variables self.buffer: utils.BufferType = bytearray(PACKET_LENGTH) + self.fallback_decrypt: bool = False self.bytes_read: int = 0 self.expected_payload_length: int = 0 self.is_cont_packet_expected: bool = False @@ -97,24 +104,25 @@ class Channel: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("set_channel_state: ", state_to_str(state)) - def set_buffer(self, buffer: utils.BufferType) -> None: - self.buffer = buffer - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - self._log("set_buffer: ", str(type(self.buffer))) - # READ and DECRYPT def receive_packet(self, packet: utils.BufferType) -> Awaitable[None] | None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("receive packet") + self._handle_received_packet(packet) + try: + buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) + except BufferError: + pass # TODO ?? + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - self._log("self.buffer: ", get_bytes_as_str(self.buffer)) + self._log("self.buffer: ", get_bytes_as_str(buffer)) if self.expected_payload_length + INIT_HEADER_LENGTH == self.bytes_read: self._finish_message() - return received_message_handler.handle_received_message(self, self.buffer) + return received_message_handler.handle_received_message(self, buffer) elif self.expected_payload_length + INIT_HEADER_LENGTH > self.bytes_read: self.is_cont_packet_expected = True else: @@ -136,7 +144,9 @@ class Channel: # ctrl_byte, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) # TODO use this with single packet decryption _, _, payload_length = ustruct.unpack(PacketHeader.format_str_init, packet) self.expected_payload_length = payload_length - packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:] + + # packet_payload = memoryview(packet)[INIT_HEADER_LENGTH:] + # The above could be used for single packet decryption # If the channel does not "own" the buffer lock, decrypt first packet # TODO do it only when needed! @@ -147,18 +157,22 @@ class Channel: # if control_byte.is_encrypted_transport(ctrl_byte): # packet_payload = self._decrypt_single_packet_payload(packet_payload) - self.buffer = memory_manager.select_buffer( - self.get_channel_state(), - self.buffer, - packet_payload, - payload_length, - ) + cid = self.get_channel_id_int() + length = payload_length + INIT_HEADER_LENGTH + try: + buffer = memory_manager.get_new_read_buffer(cid, length) + except BufferError: + self.fallback_decrypt = True + # TODO decrypt packet by packet, keep track of length, at the end call _finish_message to clear mess + + # if buffer is BufferError: + # pass # TODO handle deviceBUSY if __debug__ and utils.ALLOW_DEBUG_MESSAGES: self._log("handle_init_packet - payload len: ", str(payload_length)) - self._log("handle_init_packet - buffer len: ", str(len(self.buffer))) + self._log("handle_init_packet - buffer len: ", str(len(buffer))) - return self._buffer_packet_data(self.buffer, packet, 0) + self._buffer_packet_data(buffer, packet, 0) def _handle_cont_packet(self, packet: utils.BufferType) -> None: if __debug__ and utils.ALLOW_DEBUG_MESSAGES: @@ -166,7 +180,12 @@ class Channel: if not self.is_cont_packet_expected: raise ThpError("Continuation packet is not expected, ignoring") - return self._buffer_packet_data(self.buffer, packet, CONT_HEADER_LENGTH) + + try: + buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) + except BufferError: + pass # TODO handle device busy, channel kaput + return self._buffer_packet_data(buffer, packet, CONT_HEADER_LENGTH) def _buffer_packet_data( self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int @@ -174,6 +193,7 @@ class Channel: self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) def _finish_message(self) -> None: + self.fallback_decrypt = False self.bytes_read = 0 self.expected_payload_length = 0 self.is_cont_packet_expected = False @@ -187,15 +207,19 @@ class Channel: def decrypt_buffer( self, message_length: int, offset: int = INIT_HEADER_LENGTH ) -> None: - noise_buffer = memoryview(self.buffer)[ + buffer = memory_manager.get_existing_read_buffer(self.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle deviceBUSY + noise_buffer = memoryview(buffer)[ offset : message_length - CHECKSUM_LENGTH - TAG_LENGTH ] - tag = self.buffer[ + tag = buffer[ message_length - CHECKSUM_LENGTH - TAG_LENGTH : message_length - CHECKSUM_LENGTH ] + if utils.DISABLE_ENCRYPTION: is_tag_valid = tag == crypto.DUMMY_TAG else: @@ -235,10 +259,30 @@ class Channel: if __debug__ and utils.EMULATOR: self._log(f"write message: {msg.MESSAGE_NAME}\n", utils.dump_protobuf(msg)) - self.buffer = memory_manager.get_write_buffer(self.buffer, msg) - noise_payload_len = memory_manager.encode_into_buffer( - self.buffer, msg, session_id - ) + cid = self.get_channel_id_int() + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + length = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + INIT_HEADER_LENGTH + try: + buffer = memory_manager.get_new_write_buffer(cid, length) + noise_payload_len = memory_manager.encode_into_buffer( + buffer, msg, session_id + ) + except BufferError: + from trezor.messages import Failure, FailureType + + if __debug__ and utils.ALLOW_DEBUG_MESSAGES: + self._log("Failed to get write buffer, killing channel.") + + noise_payload_len = memory_manager.encode_into_buffer( + self.buffer, + Failure( + code=FailureType.FirmwareError, + message="Failed to obtain write buffer.", + ), + session_id, + ) + self.set_channel_state(ChannelState.INVALIDATED) task = self._write_and_encrypt(self.buffer[:noise_payload_len], force) if task is not None: await task @@ -259,7 +303,11 @@ class Channel: self, payload: bytes, force: bool = False ) -> Awaitable[None] | None: payload_length = len(payload) - self._encrypt(self.buffer, payload_length) + buffer = memory_manager.get_existing_write_buffer(self.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle deviceBUSY + + self._encrypt(buffer, payload_length) payload_length = payload_length + TAG_LENGTH if self.write_task_spawn is not None: @@ -275,7 +323,7 @@ class Channel: ) self.write_task_spawn = loop.spawn( self._write_encrypted_payload_loop( - ENCRYPTED, memoryview(self.buffer[:payload_length]) + ENCRYPTED, memoryview(buffer[:payload_length]) ) ) return None diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index d7fb633134..cf453f8992 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -1,70 +1,150 @@ -from storage.cache_thp import SESSION_ID_LENGTH, TAG_LENGTH -from trezor import log, protobuf, utils +import utime +from micropython import const + +from storage.cache_thp import SESSION_ID_LENGTH +from trezor import protobuf, utils from trezor.wire.message_handler import get_msg_type -from . import ChannelState, ThpError -from .checksum import CHECKSUM_LENGTH -from .writer import ( - INIT_HEADER_LENGTH, - MAX_PAYLOAD_LEN, - MESSAGE_TYPE_LENGTH, - PACKET_LENGTH, -) +from . import ThpError +from .writer import MAX_PAYLOAD_LEN, MESSAGE_TYPE_LENGTH _PROTOBUF_BUFFER_SIZE = 8192 READ_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) WRITE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) -def select_buffer( - channel_state: int, - channel_buffer: utils.BufferType, - packet_payload: utils.BufferType, - payload_length: int, -) -> utils.BufferType: +lock_owner_cid: int | None = None +lock_time: int = 0 - if channel_state is ChannelState.ENCRYPTED_TRANSPORT: - session_id = packet_payload[0] - if session_id == 0: - pass - # TODO use small buffer - else: - pass - # TODO use big buffer but only if the channel owns the buffer lock. - # Otherwise send BUSY message and return +READ_BUFFER_SLICE: memoryview | None = None +WRITE_BUFFER_SLICE: memoryview | None = None + +# Buffer types +_READ: int = const(0) +_WRITE: int = const(1) + + +# +# Access to buffer slices + + +def get_new_read_buffer(channel_id: int, length: int) -> memoryview: + return _get_new_buffer(_READ, channel_id, length) + + +def get_new_write_buffer(channel_id: int, length: int) -> memoryview: + return _get_new_buffer(_WRITE, channel_id, length) + + +def get_existing_read_buffer(channel_id: int) -> memoryview: + return _get_existing_buffer(_READ, channel_id) + + +def get_existing_write_buffer(channel_id: int) -> memoryview: + return _get_existing_buffer(_WRITE, channel_id) + + +def _get_new_buffer(buffer_type: int, channel_id: int, length: int) -> memoryview: + if is_locked(): + if not is_owner(channel_id): + raise BufferError + update_lock_time() else: - pass - # TODO use small buffer - try: - # TODO for now, we create a new big buffer every time. It should be changed - buffer: utils.BufferType = _get_buffer_for_read(payload_length, channel_buffer) - return buffer - except Exception as e: - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.exception(__name__, e) - raise Exception("Failed to create a buffer for channel") # TODO handle better + update_lock(channel_id) + + if buffer_type == _READ: + global READ_BUFFER + buffer = READ_BUFFER + elif buffer_type == _WRITE: + global WRITE_BUFFER + buffer = WRITE_BUFFER + else: + raise Exception("Invalid buffer_type") + + if length > MAX_PAYLOAD_LEN or length > len(buffer): + raise ThpError("Message is too large") # TODO reword + + if buffer_type == _READ: + global READ_BUFFER_SLICE + READ_BUFFER_SLICE = memoryview(READ_BUFFER)[:length] + return READ_BUFFER_SLICE + + if buffer_type == _WRITE: + global WRITE_BUFFER_SLICE + WRITE_BUFFER_SLICE = memoryview(WRITE_BUFFER)[:length] + return WRITE_BUFFER_SLICE + + raise Exception("Invalid buffer_type") -def get_write_buffer( - buffer: utils.BufferType, msg: protobuf.MessageType -) -> utils.BufferType: - msg_size = protobuf.encoded_length(msg) - payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size - required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH +def _get_existing_buffer(buffer_type: int, channel_id: int) -> memoryview: + if not is_owner(channel_id): + raise BufferError + update_lock_time() - if required_min_size > len(buffer): - return _get_buffer_for_write(required_min_size, buffer) - return buffer + if buffer_type == _READ: + global READ_BUFFER_SLICE + if READ_BUFFER_SLICE is None: + raise BufferError + return READ_BUFFER_SLICE + + if buffer_type == _WRITE: + global WRITE_BUFFER_SLICE + if WRITE_BUFFER_SLICE is None: + raise BufferError + return WRITE_BUFFER_SLICE + + raise Exception("Invalid buffer_type") + + +# +# Buffer locking + + +def is_locked() -> bool: + global lock_owner_cid + global lock_time + + time_diff = utime.ticks_diff(utime.ticks_ms(), lock_time) + return lock_owner_cid is not None and time_diff < 200 + + +def is_owner(channel_id: int) -> bool: + global lock_owner_cid + return lock_owner_cid is not None and lock_owner_cid == channel_id + + +def update_lock(channel_id: int) -> None: + set_owner(channel_id) + update_lock_time() + + +def set_owner(channel_id: int) -> None: + global lock_owner_cid + lock_owner_cid = channel_id + + +def update_lock_time() -> None: + global lock_time + lock_time = utime.ticks_ms() + + +# +# Helper for encoding messages into buffer def encode_into_buffer( buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int ) -> int: + """Encode protobuf message `msg` into the `buffer`, including session id + an messages's wire type. Will fail if provided message has no wire type.""" + # cannot write message without wire type msg_type = msg.MESSAGE_WIRE_TYPE if msg_type is None: msg_type = get_msg_type(msg.MESSAGE_NAME) - assert msg_type is not None + if msg_type is None: + raise Exception("Message has no wire type.") msg_size = protobuf.encoded_length(msg) payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size @@ -96,84 +176,3 @@ def _encode_message_into_buffer( buffer: memoryview, message: protobuf.MessageType, buffer_offset: int = 0 ) -> None: protobuf.encode(memoryview(buffer[buffer_offset:]), message) - - -def _get_buffer_for_read( - payload_length: int, - existing_buffer: utils.BufferType, - max_length: int = MAX_PAYLOAD_LEN, -) -> utils.BufferType: - length = payload_length + INIT_HEADER_LENGTH - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "get_buffer_for_read - length: %d, %s %s", - length, - "existing buffer type:", - type(existing_buffer), - ) - if length > max_length: - raise ThpError("Message too large") - - if length > len(existing_buffer): - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "Allocating a new buffer") - - if length > len(READ_BUFFER): - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "Required length is %d, where raw buffer has capacity only %d", - length, - len(READ_BUFFER), - ) - raise ThpError("Message is too large") - - try: - payload: utils.BufferType = memoryview(READ_BUFFER)[:length] - except MemoryError: - payload = memoryview(READ_BUFFER)[:PACKET_LENGTH] - raise ThpError("Message is too large") - return payload - - # reuse a part of the supplied buffer - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "Reusing already allocated buffer") - return memoryview(existing_buffer)[:length] - - -def _get_buffer_for_write( - payload_length: int, - existing_buffer: utils.BufferType, - max_length: int = MAX_PAYLOAD_LEN, -) -> utils.BufferType: - length = payload_length + INIT_HEADER_LENGTH - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug( - __name__, - "get_buffer_for_write - length: %d, %s %s", - length, - "existing buffer type:", - type(existing_buffer), - ) - if length > max_length: - raise ThpError("Message too large") - - if length > len(existing_buffer): - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "Creating a new write buffer from raw write buffer") - - if length > len(WRITE_BUFFER): - raise ThpError("Message is too large") - - try: - payload: utils.BufferType = memoryview(WRITE_BUFFER)[:length] - except MemoryError: - payload = memoryview(WRITE_BUFFER)[:PACKET_LENGTH] - raise ThpError("Message is too large") - return payload - - # reuse a part of the supplied buffer - if __debug__ and utils.ALLOW_DEBUG_MESSAGES: - log.debug(__name__, "Reusing already allocated buffer") - return memoryview(existing_buffer)[:length] diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 2a7fbbdf30..56a2c3bda4 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -18,6 +18,7 @@ from storage.cache_thp import ( from trezor import log, loop, protobuf, utils from trezor.enums import FailureType from trezor.messages import Failure +from trezor.wire.thp import memory_manager from .. import message_handler from ..errors import DataError @@ -227,8 +228,12 @@ async def _handle_state_TH1( ctx.handshake = Handshake() + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO buffer is gone :/ + host_ephemeral_pubkey = bytearray( - ctx.buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH] + buffer[INIT_HEADER_LENGTH : message_length - CHECKSUM_LENGTH] ) trezor_ephemeral_pubkey, encrypted_trezor_static_pubkey, tag = ( ctx.handshake.handle_th1_crypto( @@ -267,10 +272,13 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) - if ctx.handshake is None: raise Exception("Handshake object is not prepared. Retry handshake.") - host_encrypted_static_pubkey = memoryview(ctx.buffer)[ + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle + host_encrypted_static_pubkey = buffer[ INIT_HEADER_LENGTH : INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH ] - handshake_completion_request_noise_payload = memoryview(ctx.buffer)[ + handshake_completion_request_noise_payload = buffer[ INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length - CHECKSUM_LENGTH ] @@ -285,7 +293,7 @@ async def _handle_state_TH2(ctx: Channel, message_length: int, ctrl_byte: int) - ctx.channel_cache.set_int(CHANNEL_NONCE_SEND, 1) noise_payload = _decode_message( - ctx.buffer[ + buffer[ INIT_HEADER_LENGTH + KEY_LENGTH + TAG_LENGTH : message_length @@ -349,8 +357,12 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) - log.debug(__name__, "handle_state_ENCRYPTED_TRANSPORT") ctx.decrypt_buffer(message_length) + + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle session_id, message_type = ustruct.unpack( - ">BH", memoryview(ctx.buffer)[INIT_HEADER_LENGTH:] + ">BH", memoryview(buffer)[INIT_HEADER_LENGTH:] ) if session_id not in ctx.sessions: @@ -372,7 +384,7 @@ async def _handle_state_ENCRYPTED_TRANSPORT(ctx: Channel, message_length: int) - s.incoming_message.put( Message( message_type, - ctx.buffer[ + buffer[ INIT_HEADER_LENGTH + MESSAGE_TYPE_LENGTH + SESSION_ID_LENGTH : message_length @@ -391,14 +403,17 @@ async def _handle_pairing(ctx: Channel, message_length: int) -> None: loop.schedule(ctx.connection_context.handle()) ctx.decrypt_buffer(message_length) + buffer = memory_manager.get_existing_read_buffer(ctx.get_channel_id_int()) + # if buffer is BufferError: + # pass # TODO handle message_type = ustruct.unpack( - ">H", ctx.buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] + ">H", buffer[INIT_HEADER_LENGTH + SESSION_ID_LENGTH :] )[0] ctx.connection_context.incoming_message.put( Message( message_type, - ctx.buffer[ + buffer[ INIT_HEADER_LENGTH + MESSAGE_TYPE_LENGTH + SESSION_ID_LENGTH : message_length