From 665179b9795dbc7886f46158a8adfa00f34acbf7 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 3 Apr 2024 13:56:26 +0200 Subject: [PATCH] Fix debug log, crashing and mock noise tags --- core/src/all_modules.py | 2 + core/src/storage/cache_thp.py | 4 +- core/src/trezor/wire/thp/channel.py | 69 +++++++++++++++------ core/src/trezor/wire/thp/crypto.py | 34 ++++++++++ core/src/trezor/wire/thp/session_context.py | 8 +++ core/src/trezor/wire/thp_v1.py | 57 +++++++++-------- 6 files changed, 129 insertions(+), 45 deletions(-) create mode 100644 core/src/trezor/wire/thp/crypto.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 59e47a0d2..f5b2df94e 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -217,6 +217,8 @@ trezor.wire.thp.channel import trezor.wire.thp.channel trezor.wire.thp.checksum import trezor.wire.thp.checksum +trezor.wire.thp.crypto +import trezor.wire.thp.crypto trezor.wire.thp.session_context import trezor.wire.thp.session_context trezor.wire.thp.thp_messages diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 05acfc446..cf228b6f0 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -180,7 +180,7 @@ def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> Non def get_new_session(channel: ChannelCache): - + print("---------------get new session") new_sid = get_next_session_id(channel) index = _get_next_session_index() @@ -194,6 +194,8 @@ def get_new_session(channel: ChannelCache): _SESSIONS[index].state[:] = bytearray( _UNALLOCATED_STATE.to_bytes(_SESSION_STATE_LENGTH, "big") ) + for s in _SESSIONS: + print(s) return _SESSIONS[index] diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 6eba93347..663a724ad 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -12,7 +12,7 @@ from trezor.wire import message_handler from trezor.wire.thp import thp_messages from ..protocol_common import Context, MessageWithType -from . import ChannelState, SessionState, checksum +from . import ChannelState, SessionState, checksum, crypto from . import thp_session as THP from .checksum import CHECKSUM_LENGTH from .thp_messages import ( @@ -90,7 +90,8 @@ class Channel(Context): await self._handle_cont_packet(packet) else: await self._handle_init_packet(packet) - + print("receive packet", self.expected_payload_length, self.bytes_read) + printBytes(self.buffer) if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() await self._handle_completed_message() @@ -103,7 +104,7 @@ class Channel(Context): # If the channel does not "own" the buffer lock, decrypt first packet # TODO do it only when needed! if _is_ctrl_byte_encrypted_transport(ctrl_byte): - packet_payload = self._decrypt(packet_payload) + packet_payload = self._decrypt_single_packet_payload(packet_payload) state = self.get_channel_state() @@ -254,7 +255,7 @@ class Channel(Context): self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) def _handle_state_ENCRYPTED_TRANSPORT(self, message_length: int) -> None: - self._decrypt_buffer() + self._decrypt_buffer(message_length) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) if session_id == 0: self._handle_channel_comms(message_length, message_type) @@ -293,18 +294,43 @@ class Channel(Context): bufferrone = bytearray(2) message_size: int = thp_messages.get_new_session_message(bufferrone) print(message_size) # TODO adjust - loop.schedule(self._write_encrypted_payload_loop(bufferrone)) + loop.schedule(self.write_and_encrypt(bufferrone)) except Exception as e: print("Proč??") print(e) # TODO not finished - def _decrypt(self, payload) -> bytes: - return payload # TODO add decryption process + def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: + payload_buffer = bytearray(payload) + crypto.decrypt(b"\x00", b"\x00", payload_buffer, INIT_DATA_OFFSET, len(payload)) + return payload_buffer + + def _decrypt_buffer(self, message_length: int) -> None: + if not isinstance(self.buffer, bytearray): + self.buffer = bytearray(self.buffer) + crypto.decrypt( + b"\x00", + b"\x00", + self.buffer, + INIT_DATA_OFFSET, + message_length - INIT_DATA_OFFSET - CHECKSUM_LENGTH, + ) - def _decrypt_buffer(self) -> None: - pass - # TODO decode buffer in place + def _encrypt(self, buffer: bytearray, noise_payload_len: int) -> None: + print("\n Encrypting ") + min_required_length = noise_payload_len + TAG_LENGTH + CHECKSUM_LENGTH + if len(buffer) < min_required_length or not isinstance(buffer, bytearray): + new_buffer = bytearray(min_required_length) + utils.memcpy(new_buffer, 0, buffer, 0) + buffer = new_buffer + tag = crypto.encrypt( + b"\x00", + b"\x00", + buffer, + 0, + noise_payload_len, + ) + buffer[noise_payload_len : noise_payload_len + TAG_LENGTH] = tag async def _buffer_packet_data( self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int @@ -327,7 +353,7 @@ class Channel(Context): log.debug( __name__, "Writing ACK message to a channel with id: %d, sync bit: %d", - self.channel_id, + int.from_bytes(self.channel_id, "big"), ack_bit, ) await self._write_payload_to_wire(header, chksum, CHECKSUM_LENGTH) @@ -343,15 +369,18 @@ class Channel(Context): async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: print("write") - noise_payload_len = self._encode_into_buffer(msg, session_id) + await self.write_and_encrypt(self.buffer[:noise_payload_len]) - # trezor.crypto.noise.encode(key, payload=self.buffer) + async def write_and_encrypt(self, payload: bytes) -> None: + payload_length = len(payload) - # TODO payload_len should be output from trezor.crypto.noise.encode, I guess - payload_len = noise_payload_len # + TAG_LENGTH # TODO + if not isinstance(self.buffer, bytearray): + self.buffer = bytearray(self.buffer) + self._encrypt(self.buffer, payload_length) + payload_length = payload_length + TAG_LENGTH - loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_len])) + loop.schedule(self._write_encrypted_payload_loop(self.buffer[:payload_length])) async def _write_encrypted_payload_loop(self, payload: bytes) -> None: print("write loop before while") @@ -419,10 +448,13 @@ class Channel(Context): msg_size = protobuf.encoded_length(msg) offset = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH payload_size = offset + msg_size + required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH - if payload_size > len(self.buffer) or not isinstance(self.buffer, bytearray): + if required_min_size > len(self.buffer) or not isinstance( + self.buffer, bytearray + ): # message is too big or buffer is not bytearray, we need to allocate a new buffer - self.buffer = bytearray(payload_size) + self.buffer = bytearray(required_min_size) buffer = self.buffer session_id_bytes = int.to_bytes(session_id, SESSION_ID_LENGTH, "big") @@ -445,6 +477,7 @@ class Channel(Context): self.sessions[session.session_id] = session loop.schedule(session.handle()) print("new session created. Session id:", session.session_id) + print(self.sessions) def _todo_clear_buffer(self): # TODO Buffer clearing not implemented diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py new file mode 100644 index 000000000..bb8c6632e --- /dev/null +++ b/core/src/trezor/wire/thp/crypto.py @@ -0,0 +1,34 @@ +DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" + + +# TODO implement + + +def encrypt( + key: bytes, + nonce: bytes, + buffer: bytearray, + init_offset: int = 0, + payload_length: int = 0, +) -> bytes: + """ + Returns a 16-byte long encryption tag, the encryption itself is performed on the buffer provided. + """ + return DUMMY_TAG + + +def decrypt( + key: bytes, + nonce: bytes, + buffer: bytearray, + init_offset: int = 0, + payload_length: int = 0, +) -> None: + """ + Decryption in place. + """ + pass + + +def is_tag_valid(key: bytes, nonce: bytes, payload: bytes, noise_tag: bytes) -> bool: + return True diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 5360ff762..ef17f5d90 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -139,10 +139,18 @@ class SessionContext(Context): def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO + print("start loading sessions from cache") sessions: dict[int, SessionContext] = {} cached_sessions = cache_thp.get_all_allocated_sessions() + print( + "loaded a total of ", + len(cached_sessions), + "sessions from cache", + ) for session in cached_sessions: if session.channel_id == channel.channel_id: sid = int.from_bytes(session.session_id, "big") sessions[sid] = SessionContext(channel, session) + for i in sessions: + print("session", i) return sessions diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index d2069f930..8ff11a8b0 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -44,34 +44,39 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): read = loop.wait(iface.iface_num() | io.POLL_READ) while True: - print("main loop") - packet = await read - ctrl_byte, cid = ustruct.unpack(">BH", packet) - - if ctrl_byte == CODEC_V1: - pass - # TODO add handling of (unsupported) codec_v1 packets - # possibly ignore continuation packets, i.e. if the - # following bytes are not "##"", do not respond - - if cid == BROADCAST_CHANNEL_ID: - # TODO handle exceptions, try-catch? - await _handle_broadcast(iface, ctrl_byte, packet) - continue - - if cid in _CHANNEL_CONTEXTS: - channel = _CHANNEL_CONTEXTS[cid] - if channel is None: - raise ThpError("Invalid state of a channel") - if channel.iface is not iface: - raise ThpError("Channel has different WireInterface") - - if channel.get_channel_state() != ChannelState.UNALLOCATED: - print("packet type in loop:", type(packet)) - await channel.receive_packet(packet) + try: + print("main loop") + packet = await read + ctrl_byte, cid = ustruct.unpack(">BH", packet) + + if ctrl_byte == CODEC_V1: + pass + # TODO add handling of (unsupported) codec_v1 packets + # possibly ignore continuation packets, i.e. if the + # following bytes are not "##"", do not respond + + if cid == BROADCAST_CHANNEL_ID: + # TODO handle exceptions, try-catch? + await _handle_broadcast(iface, ctrl_byte, packet) continue - await _handle_unallocated(iface, cid) + if cid in _CHANNEL_CONTEXTS: + channel = _CHANNEL_CONTEXTS[cid] + if channel is None: + raise ThpError("Invalid state of a channel") + if channel.iface is not iface: + raise ThpError("Channel has different WireInterface") + + if channel.get_channel_state() != ChannelState.UNALLOCATED: + print("packet type in loop:", type(packet)) + await channel.receive_packet(packet) + continue + await _handle_unallocated(iface, cid) + + except ThpError as e: + if __debug__: + log.exception(__name__, e) + # TODO add cleaning sequence if no workflow/channel is active (or some condition like that)