From 685cadf8c931db2404528f9ac73f055d280fe71d Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Tue, 2 Apr 2024 16:47:47 +0200 Subject: [PATCH] Fix synchronization --- core/src/trezor/wire/thp/channel.py | 108 ++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 28 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 35bacc464..d76037a42 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -6,7 +6,7 @@ from ubinascii import hexlify import usb from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache -from trezor import io, loop, protobuf, utils +from trezor import io, log, loop, protobuf, utils from trezor.messages import ThpCreateNewSession from trezor.wire import message_handler from trezor.wire.thp import thp_messages @@ -41,6 +41,8 @@ MESSAGE_TYPE_LENGTH = const(2) REPORT_LENGTH = const(64) MAX_PAYLOAD_LEN = const(60000) +_ACK_MESSAGE = 0x20 + class Channel(Context): def __init__(self, channel_cache: ChannelCache) -> None: @@ -139,10 +141,13 @@ class Channel(Context): async def _handle_completed_message(self): print("handling completed message") + print("send snyc bit::", THP.sync_get_send_bit(self.channel_cache)) ctrl_byte, _, payload_length = ustruct.unpack(">BHH", self.buffer) msg_len = payload_length + INIT_DATA_OFFSET + print("checksum check") - printBytes(self.buffer) + # printBytes(self.buffer) + if not checksum.is_valid( checksum=self.buffer[msg_len - CHECKSUM_LENGTH : msg_len], data=self.buffer[: msg_len - CHECKSUM_LENGTH], @@ -150,16 +155,40 @@ class Channel(Context): # checksum is not valid -> ignore message self._todo_clear_buffer() return + + # Synchronization process sync_bit = (ctrl_byte & 0x10) >> 4 print("sync bit:", sync_bit) + # 1: Handle ACKs if _is_ctrl_byte_ack(ctrl_byte): self._handle_received_ACK(sync_bit) self._todo_clear_buffer() return + # 2: Handle message with unexpected synchronization bit + if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache): + if __debug__: + log.debug( + __name__, "Received message with an unexpected synchronization bit" + ) + raise ThpError("Received message with an unexpected synchronization bit") + + # 3: Send ACK in response + if __debug__: + log.debug( + __name__, + "Writing ACK message to a channel with id: %d, sync bit: %d", + self.channel_id, + sync_bit, + ) + await self._sendAck(sync_bit) + print("___set receive bit to", 1 - sync_bit) + THP.sync_set_receive_expected_bit(self.channel_cache, 1 - sync_bit) + state = self.get_channel_state() - _print_state(state) + if __debug__: + log.debug(__name__, _state_to_str(state)) if state is ChannelState.TH1: if not _is_ctrl_byte_handshake_init: @@ -270,6 +299,21 @@ class Channel(Context): self.expected_payload_length = 0 self.is_cont_packet_expected = False + async def _sendAck(self, ack_bit: int) -> None: + ctrl_byte = self._add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit) + header = InitHeader( + ctrl_byte, int.from_bytes(self.channel_id, "big"), CHECKSUM_LENGTH + ) + chksum = checksum.compute(header.to_bytes()) + await self._write_encrypted_payload(header, chksum, CHECKSUM_LENGTH) + + def _add_sync_bit_to_ctrl_byte(self, ctrl_byte, sync_bit): + if sync_bit == 0: + return ctrl_byte & 0xEF + if sync_bit == 1: + return ctrl_byte | 0x10 + raise ThpError("Unexpected synchronization bit") + # CALLED BY WORKFLOW / SESSION CONTEXT async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: @@ -287,17 +331,26 @@ class Channel(Context): async def _write_encrypted_payload_loop(self, payload: bytes) -> None: print("write loop before while") payload_len = len(payload) + sync_bit = THP.sync_get_send_bit(self.channel_cache) + ctrl_byte = self._add_sync_bit_to_ctrl_byte(ENCRYPTED_TRANSPORT, sync_bit) header = InitHeader( - ENCRYPTED_TRANSPORT, int.from_bytes(self.channel_id, "big"), payload_len + ctrl_byte, int.from_bytes(self.channel_id, "big"), payload_len ) + # TODO add condition that disallows to write when can_send_message is false THP.sync_set_can_send_message(self.channel_cache, False) while True: - print("write encrypted payload loop - start") + print( + "write encrypted payload loop - start, sync_bit:", + header.ctrl_byte & 0x10, + " send_sync_bit:", + THP.sync_get_send_bit(self.channel_cache), + ) await self._write_encrypted_payload(header, payload, payload_len) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) try: await self.waiting_for_ack_timeout except loop.TaskClosed: + THP.sync_set_send_bit_to_opposite(self.channel_cache) break async def _write_encrypted_payload( @@ -459,30 +512,29 @@ def _is_ctrl_byte_ack(ctrl_byte: int) -> bool: return ctrl_byte & 0xEF == ACK_MESSAGE -def _print_state(cs: int) -> None: - if cs == ChannelState.ENCRYPTED_TRANSPORT: - print("state: encrypted transport") - elif cs == ChannelState.TH1: - print("state: th1") - elif cs == ChannelState.TH2: - print("state: th2") - elif cs == ChannelState.TP1: - print("state: tp1") - elif cs == ChannelState.TP2: - print("state: tp2") - elif cs == ChannelState.TP3: - print("state: tp3") - elif cs == ChannelState.TP4: - print("state: tp4") - elif cs == ChannelState.TP5: - print("state: tp5") - elif cs == ChannelState.UNALLOCATED: - print("state: unallocated") - elif cs == ChannelState.UNAUTHENTICATED: - print("state: unauthenticated") +def _state_to_str(state: int) -> str: + if state == ChannelState.ENCRYPTED_TRANSPORT: + return "state: encrypted transport" + elif state == ChannelState.TH1: + return "state: th1" + elif state == ChannelState.TH2: + return "state: th2" + elif state == ChannelState.TP1: + return "state: tp1" + elif state == ChannelState.TP2: + return "state: tp2" + elif state == ChannelState.TP3: + return "state: tp3" + elif state == ChannelState.TP4: + return "state: tp4" + elif state == ChannelState.TP5: + return "state: tp5" + elif state == ChannelState.UNALLOCATED: + return "state: unallocated" + elif state == ChannelState.UNAUTHENTICATED: + return "state: unauthenticated" else: - print(cs) - print("state: ") + return "state: " def printBytes(a):