diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index ccddf20f7..40e685f02 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -132,6 +132,7 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals # loop.clear() above. if __debug__: log.exception(__name__, exc) + print("Exception raised:", exc) async def handle_session( diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py index 27bddd863..74e93e02c 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel_context.py @@ -1,11 +1,13 @@ import ustruct # pyright: ignore[reportMissingModuleSource] from micropython import const # pyright: ignore[reportMissingModuleSource] from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] +from ubinascii import hexlify import usb from storage import cache_thp from storage.cache_thp import KEY_LENGTH, TAG_LENGTH, ChannelCache from trezor import loop, protobuf, utils +from trezor.wire.thp import thp_messages from ..protocol_common import Context from . import ChannelState, SessionState, checksum @@ -51,39 +53,49 @@ class ChannelContext(Context): self.sessions = load_cached_sessions(self) @classmethod - def create_new_channel(cls, iface: WireInterface) -> "ChannelContext": + def create_new_channel( + cls, iface: WireInterface, buffer: utils.BufferType + ) -> "ChannelContext": channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface)) - return cls(channel_cache) + r = cls(channel_cache) + r.set_buffer(buffer) + r.set_channel_state(ChannelState.TH1) + return r # ACCESS TO CHANNEL_DATA - def get_channel_state(self) -> ChannelState: + def get_channel_state(self) -> int: state = int.from_bytes(self.channel_cache.state, "big") - return ChannelState(state) + print("get_ch_state", state) + return state def set_channel_state(self, state: ChannelState) -> None: - self.channel_cache.state = bytearray(state.value.to_bytes(1, "big")) + print("set_ch_state", int.from_bytes(state.to_bytes(1, "big"), "big")) + self.channel_cache.state = bytearray(state.to_bytes(1, "big")) def set_buffer(self, buffer: utils.BufferType) -> None: self.buffer = buffer + print("set buffer channel", type(self.buffer)) # CALLED BY THP_MAIN_LOOP async def receive_packet(self, packet: utils.BufferType): + print("receive packet") ctrl_byte = packet[0] if _is_ctrl_byte_continuation(ctrl_byte): await self._handle_cont_packet(packet) else: await self._handle_init_packet(packet) - if self.expected_payload_length == self.bytes_read: + if self.expected_payload_length + INIT_DATA_OFFSET == self.bytes_read: self._finish_message() await self._handle_completed_message() - async def _handle_init_packet(self, packet): + async def _handle_init_packet(self, packet: utils.BufferType): + print("handle_init_packet") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) + self.expected_payload_length = payload_length packet_payload = packet[5:] - # If the channel does not "own" the buffer lock, decrypt first packet # TODO do it only when needed! if _is_ctrl_byte_encrypted_transport(ctrl_byte): @@ -103,20 +115,33 @@ class ChannelContext(Context): else: pass # TODO use small buffer - - # TODO for now, we create a new big buffer every time. It should be changed - self.buffer = _get_buffer_for_payload(payload_length, packet) - + print("self.buffer2") + try: + # TODO for now, we create a new big buffer every time. It should be changed + self.buffer: utils.BufferType = _get_buffer_for_message( + payload_length, self.buffer + ) + except Exception as e: + print(e) + print("payload len", payload_length) + print("self.buffer", self.buffer) + print("self.buuffer.type", type(self.buffer)) + print("len", len(self.buffer)) await self._buffer_packet_data(self.buffer, packet, 0) + print("end init") - async def _handle_cont_packet(self, packet): + async def _handle_cont_packet(self, packet: utils.BufferType): + print("cont") if not self.is_cont_packet_expected: return # Continuation packet is not expected, ignoring await self._buffer_packet_data(self.buffer, packet, CONT_DATA_OFFSET) async def _handle_completed_message(self): + print("handling completed message") ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) msg_len = payload_length + INIT_DATA_OFFSET + print("checksum check") + printBytes(self.buffer) if not checksum.is_valid( checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len], data=self.buffer[: msg_len - CHECKSUM_LENGTH], @@ -124,7 +149,7 @@ class ChannelContext(Context): # checksum is not valid -> ignore message self._todo_clear_buffer() return - + print("sync bit") sync_bit = (ctrl_byte & 0x10) >> 4 if _is_ctrl_byte_ack(ctrl_byte): self._handle_received_ACK(sync_bit) @@ -132,6 +157,7 @@ class ChannelContext(Context): return state = self.get_channel_state() + _print_state(state) if state is ChannelState.TH1: if not _is_ctrl_byte_handshake_init: @@ -152,15 +178,26 @@ class ChannelContext(Context): return if not _is_ctrl_byte_encrypted_transport(ctrl_byte): + print("message is not encrypted. Ignoring") # TODO ignore message self._todo_clear_buffer() return - if state is ChannelState.ENCRYPTED_TRANSPORT: self._decrypt_buffer() session_id, message_type = ustruct.unpack( ">BH", self.buffer[INIT_DATA_OFFSET:] ) + if session_id == 0: + try: + message = thp_messages.decode_message( + self.buffer[INIT_DATA_OFFSET + 3 :], message_type + ) + print(message) + except Exception as e: + print(e) + + # TODO not finished + if session_id not in self.sessions: raise Exception("Unalloacted session") @@ -174,6 +211,7 @@ class ChannelContext(Context): ) if state is ChannelState.TH2: + print("th2 branche") host_encrypted_static_pubkey = self.buffer[ INIT_DATA_OFFSET : INIT_DATA_OFFSET + KEY_LENGTH + TAG_LENGTH ] @@ -187,6 +225,7 @@ class ChannelContext(Context): # TODO send ack in response # TODO send hanshake completion response self.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + print("end completed message") def _decrypt(self, payload) -> bytes: return payload # TODO add decryption process @@ -196,9 +235,10 @@ class ChannelContext(Context): # TODO decode buffer in place async def _buffer_packet_data( - self, payload_buffer, packet: utils.BufferType, offset + self, payload_buffer: utils.BufferType, packet: utils.BufferType, offset: int ): self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) + print("bytes, read:", self.bytes_read) def _finish_message(self): self.bytes_read = 0 @@ -221,7 +261,8 @@ class ChannelContext(Context): # OTHER def _todo_clear_buffer(self): - raise NotImplementedError() + # TODO Buffer clearing not implemented + pass # TODO add debug logging to ACK handling def _handle_received_ACK(self, sync_bit: int) -> None: @@ -273,22 +314,26 @@ def _encode_iface(iface: WireInterface) -> bytes: raise Exception("Unknown WireInterface") -def _get_buffer_for_payload( +def _get_buffer_for_message( payload_length: int, existing_buffer: utils.BufferType, max_length=MAX_PAYLOAD_LEN ) -> utils.BufferType: - if payload_length > max_length: + length = payload_length + INIT_DATA_OFFSET + print("length", length) + print("existing buffer type", type(existing_buffer)) + if length > max_length: raise ThpError("Message too large") - if payload_length > len(existing_buffer): + + if length > len(existing_buffer): # allocate a new buffer to fit the message try: - payload: utils.BufferType = bytearray(payload_length) + payload: utils.BufferType = bytearray(length) except MemoryError: payload = bytearray(REPORT_LENGTH) raise ThpError("Message too large") return payload # reuse a part of the supplied buffer - return memoryview(existing_buffer)[:payload_length] + return memoryview(existing_buffer)[:length] def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool: @@ -305,3 +350,33 @@ def _is_ctrl_byte_handshake_init(ctrl_byte: int) -> bool: def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: return ctrl_byte & 0xEF == ACK_MESSAGE + + +def _print_state(cs: int) -> None: + if cs == ChannelState.ENCRYPTED_TRANSPORT: + print("state: encrypted transport") + elif cs == ChannelState.TH1: + print("state: th1") + elif cs == ChannelState.TH2: + print("state: th2") + elif cs == ChannelState.TP1: + print("state: tp1") + elif cs == ChannelState.TP2: + print("state: tp2") + elif cs == ChannelState.TP3: + print("state: tp3") + elif cs == ChannelState.TP4: + print("state: tp4") + elif cs == ChannelState.TP5: + print("state: tp5") + elif cs == ChannelState.UNALLOCATED: + print("state: unallocated") + elif cs == ChannelState.UNAUTHENTICATED: + print("state: unauthenticated") + else: + print(cs) + print("state: ") + + +def printBytes(a): + print(hexlify(a).decode("utf-8")) diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index 1c258997a..3de4a1434 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -1,9 +1,15 @@ import ustruct # pyright:ignore[reportMissingModuleSource] +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage.cache_thp import BROADCAST_CHANNEL_ID +from trezor import protobuf +from .. import message_handler from ..protocol_common import Message +if TYPE_CHECKING: + from typing import TypeVar # pyright: ignore[reportShadowedImports] + CODEC_V1 = 0x3F CONTINUATION_PACKET = 0x80 ENCRYPTED_TRANSPORT = 0x02 @@ -12,6 +18,8 @@ ACK_MESSAGE = 0x20 _ERROR = 0x41 _CHANNEL_ALLOCATION_RES = 0x40 +LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + class InitHeader: format_str = ">BHH" @@ -73,3 +81,12 @@ def get_channel_allocation_response(nonce: bytes, new_cid: bytes) -> bytes: def get_error_unallocated_channel() -> bytes: return _ERROR_UNALLOCATED_SESSION + + +def get_handshake_init_response() -> bytes: + return b"\x00" # TODO implement + + +def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: + expected_type = protobuf.type_for_wire(msg_type) + return message_handler.wrap_protobuf_load(buffer, expected_type) diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index c1f5b1621..714207499 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -54,6 +54,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag def set_buffer(buffer): global _BUFFER _BUFFER = buffer + print("setbuffer,", type(_BUFFER)) async def thp_main_loop(iface: WireInterface, is_debug_session=False): @@ -64,6 +65,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): read = loop.wait(iface.iface_num() | io.POLL_READ) while True: + print("main loop") packet = await read ctrl_byte, cid = ustruct.unpack(">BH", packet) @@ -86,6 +88,7 @@ async def thp_main_loop(iface: WireInterface, is_debug_session=False): raise ThpError("Channel has different WireInterface") if channel.get_channel_state() != ChannelState.UNALLOCATED: + print("packet type in loop:", type(packet)) await channel.receive_packet(packet) continue @@ -330,6 +333,7 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None: async def _handle_broadcast( iface: WireInterface, ctrl_byte, packet ) -> MessageWithId | None: + global _BUFFER if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") if __debug__: @@ -342,7 +346,7 @@ async def _handle_broadcast( if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") - new_context: ChannelContext = ChannelContext.create_new_channel(iface) + new_context: ChannelContext = ChannelContext.create_new_channel(iface, _BUFFER) cid = int.from_bytes(new_context.channel_id, "big") _CHANNEL_CONTEXTS[cid] = new_context