diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index de8d9e360..6c952190f 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -60,7 +60,7 @@ EXPERIMENTAL_ENABLED = False def setup(iface: WireInterface, is_debug_session: bool = False) -> None: """Initialize the wire stack on passed WireInterface.""" - if utils.USE_THP: + if utils.USE_THP and not is_debug_session: loop.schedule(handle_thp_session(iface, is_debug_session)) else: loop.schedule(handle_session(iface, is_debug_session)) diff --git a/core/src/trezor/wire/thp/ack_handler.py b/core/src/trezor/wire/thp/ack_handler.py index 1bf9d5a58..10840823d 100644 --- a/core/src/trezor/wire/thp/ack_handler.py +++ b/core/src/trezor/wire/thp/ack_handler.py @@ -1,40 +1,30 @@ from storage.cache_thp import ChannelCache, SessionThpCache -from trezor import log, loop +from trezor import log from . import thp_session as THP -def handle_received_ACK( - cache: SessionThpCache | ChannelCache, - sync_bit: int, - waiting_for_ack_timeout: loop.spawn | None = None, -) -> None: +def is_ack_valid(cache: SessionThpCache | ChannelCache, sync_bit: int) -> bool: + if not _is_ack_expected(cache): + return False - if _ack_is_not_expected(cache): - _conditionally_log_debug("Received unexpected ACK message") - return - if _ack_has_incorrect_sync_bit(cache, sync_bit): - _conditionally_log_debug("Received ACK message with wrong sync bit") - return + if not _has_ack_correct_sync_bit(cache, sync_bit): + return False - # ACK is expected and it has correct sync bit - _conditionally_log_debug("Received ACK message with correct sync bit") - if waiting_for_ack_timeout is not None: - waiting_for_ack_timeout.close() - _conditionally_log_debug('Closed "waiting for ack" task') - THP.sync_set_can_send_message(cache, True) + return True -def _ack_is_not_expected(cache: SessionThpCache | ChannelCache) -> bool: - return THP.sync_can_send_message(cache) +def _is_ack_expected(cache: SessionThpCache | ChannelCache) -> bool: + is_expected: bool = not THP.sync_can_send_message(cache) + if not is_expected and __debug__: + log.debug(__name__, "Received unexpected ACK message") + return is_expected -def _ack_has_incorrect_sync_bit( +def _has_ack_correct_sync_bit( cache: SessionThpCache | ChannelCache, sync_bit: int ) -> bool: - return THP.sync_get_send_bit(cache) != sync_bit - - -def _conditionally_log_debug(message): - if __debug__: - log.debug(__name__, message) + is_correct: bool = THP.sync_get_send_bit(cache) == sync_bit + if __debug__ and not is_correct: + log.debug(__name__, "Received ACK message with wrong sync bit") + return is_correct diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 519f0c782..28969887c 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -61,6 +61,7 @@ class Channel(Context): self.channel_cache = channel_cache self.buffer: utils.BufferType self.waiting_for_ack_timeout: loop.spawn | None = None + self.write_task_spawn: loop.spawn | None = None self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read: int = 0 @@ -199,10 +200,7 @@ class Channel(Context): # 1: Handle ACKs if _is_ctrl_byte_ack(ctrl_byte): - ack_handler.handle_received_ACK( - self.channel_cache, sync_bit, self.waiting_for_ack_timeout - ) - self._todo_clear_buffer() + await self._handle_ack(sync_bit) return if ( @@ -232,6 +230,27 @@ class Channel(Context): if __debug__: log.debug(__name__, "handle_completed_message - end") + async def _handle_ack(self, sync_bit: int): + if not ack_handler.is_ack_valid(self.channel_cache, sync_bit): + return + # ACK is expected and it has correct sync bit + if __debug__: + log.debug(__name__, "Received ACK message with correct sync bit") + if self.waiting_for_ack_timeout is not None: + self.waiting_for_ack_timeout.close() + if __debug__: + log.debug(__name__, 'Closed "waiting for ack" task') + + THP.sync_set_can_send_message(self.channel_cache, True) + + if self.write_task_spawn is not None: + if __debug__: + log.debug(__name__, 'Control to "write_encrypted_payload_loop" task') + await self.write_task_spawn + # Note that no the write_task_spawn could result in loop.clear(), + # which will result in terminations of this function - any code after + # this await might not be executed + def _check_checksum(self, message_length: int): if __debug__: log.debug(__name__, "check_checksum") @@ -286,7 +305,8 @@ class Channel(Context): cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key) # send handshake init response message - loop.schedule( + self._prepare_write() + self.write_task_spawn = loop.spawn( self._write_encrypted_payload_loop( HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response() ) @@ -338,7 +358,8 @@ class Channel(Context): paired: bool = False # TODO should be output from credential check # send hanshake completion response - loop.schedule( + self._prepare_write() + self.write_task_spawn = loop.spawn( self._write_encrypted_payload_loop( HANDSHAKE_COMP_RES, thp_messages.get_handshake_completion_response(paired=paired), @@ -543,13 +564,19 @@ class Channel(Context): self._encrypt(self.buffer, payload_length) payload_length = payload_length + TAG_LENGTH - loop.schedule( + if self.write_task_spawn is not None: + self.write_task_spawn.close() # UPS TODO migh break something + print("\nCLOSED\n") + self._prepare_write() + self.write_task_spawn = loop.spawn( self._write_encrypted_payload_loop( ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length]) ) ) - if __debug__: - log.debug(__name__, "Scheduled _write_encrypted_payload_loop") + + def _prepare_write(self) -> None: + # TODO add condition that disallows to write when can_send_message is false + THP.sync_set_can_send_message(self.channel_cache, False) async def _write_encrypted_payload_loop( self, ctrl_byte: int, payload: bytes @@ -563,8 +590,6 @@ class Channel(Context): chksum = checksum.compute(header.to_bytes() + payload) payload = payload + chksum - # TODO add condition that disallows to write when can_send_message is false - THP.sync_set_can_send_message(self.channel_cache, False) while True: if __debug__: log.debug( @@ -576,7 +601,13 @@ class Channel(Context): await write_payload_to_wire(self.iface, header, payload) self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack()) try: - await self.waiting_for_ack_timeout + if THP.sync_can_send_message(self.channel_cache): + # TODO This can happen when ack is received before the message was sent, + # but after it was scheduled to be sent (i.e. ACK was already expected) + # This case should be removed or improved upon before production. + break + else: + await self.waiting_for_ack_timeout except loop.TaskClosed: break @@ -584,14 +615,16 @@ class Channel(Context): # Let the main loop be restarted and clear loop, if there is no other # workflow and the state is ENCRYPTED_TRANSPORT - if ( - not workflow.tasks - and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT - ): + if self._can_clear_loop(): if __debug__: - log.debug(__name__, "Clearing loop from channel") + log.debug(__name__, "clearing loop from channel") loop.clear() + def _can_clear_loop(self) -> bool: + return ( + not workflow.tasks + ) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT + async def _wait_for_ack(self) -> None: await loop.sleep(1000)