Solve race condition

M1nd3r/thp2
M1nd3r 2 weeks ago
parent 9869b42ce5
commit 50fe43646a

@ -60,7 +60,7 @@ EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed WireInterface."""
if utils.USE_THP:
if utils.USE_THP and not is_debug_session:
loop.schedule(handle_thp_session(iface, is_debug_session))
else:
loop.schedule(handle_session(iface, is_debug_session))

@ -1,40 +1,30 @@
from storage.cache_thp import ChannelCache, SessionThpCache
from trezor import log, loop
from trezor import log
from . import thp_session as THP
def handle_received_ACK(
cache: SessionThpCache | ChannelCache,
sync_bit: int,
waiting_for_ack_timeout: loop.spawn | None = None,
) -> None:
def is_ack_valid(cache: SessionThpCache | ChannelCache, sync_bit: int) -> bool:
if not _is_ack_expected(cache):
return False
if _ack_is_not_expected(cache):
_conditionally_log_debug("Received unexpected ACK message")
return
if _ack_has_incorrect_sync_bit(cache, sync_bit):
_conditionally_log_debug("Received ACK message with wrong sync bit")
return
if not _has_ack_correct_sync_bit(cache, sync_bit):
return False
# ACK is expected and it has correct sync bit
_conditionally_log_debug("Received ACK message with correct sync bit")
if waiting_for_ack_timeout is not None:
waiting_for_ack_timeout.close()
_conditionally_log_debug('Closed "waiting for ack" task')
THP.sync_set_can_send_message(cache, True)
return True
def _ack_is_not_expected(cache: SessionThpCache | ChannelCache) -> bool:
return THP.sync_can_send_message(cache)
def _is_ack_expected(cache: SessionThpCache | ChannelCache) -> bool:
is_expected: bool = not THP.sync_can_send_message(cache)
if not is_expected and __debug__:
log.debug(__name__, "Received unexpected ACK message")
return is_expected
def _ack_has_incorrect_sync_bit(
def _has_ack_correct_sync_bit(
cache: SessionThpCache | ChannelCache, sync_bit: int
) -> bool:
return THP.sync_get_send_bit(cache) != sync_bit
def _conditionally_log_debug(message):
if __debug__:
log.debug(__name__, message)
is_correct: bool = THP.sync_get_send_bit(cache) == sync_bit
if __debug__ and not is_correct:
log.debug(__name__, "Received ACK message with wrong sync bit")
return is_correct

@ -61,6 +61,7 @@ class Channel(Context):
self.channel_cache = channel_cache
self.buffer: utils.BufferType
self.waiting_for_ack_timeout: loop.spawn | None = None
self.write_task_spawn: loop.spawn | None = None
self.is_cont_packet_expected: bool = False
self.expected_payload_length: int = 0
self.bytes_read: int = 0
@ -199,10 +200,7 @@ class Channel(Context):
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
ack_handler.handle_received_ACK(
self.channel_cache, sync_bit, self.waiting_for_ack_timeout
)
self._todo_clear_buffer()
await self._handle_ack(sync_bit)
return
if (
@ -232,6 +230,27 @@ class Channel(Context):
if __debug__:
log.debug(__name__, "handle_completed_message - end")
async def _handle_ack(self, sync_bit: int):
if not ack_handler.is_ack_valid(self.channel_cache, sync_bit):
return
# ACK is expected and it has correct sync bit
if __debug__:
log.debug(__name__, "Received ACK message with correct sync bit")
if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, 'Closed "waiting for ack" task')
THP.sync_set_can_send_message(self.channel_cache, True)
if self.write_task_spawn is not None:
if __debug__:
log.debug(__name__, 'Control to "write_encrypted_payload_loop" task')
await self.write_task_spawn
# Note that no the write_task_spawn could result in loop.clear(),
# which will result in terminations of this function - any code after
# this await might not be executed
def _check_checksum(self, message_length: int):
if __debug__:
log.debug(__name__, "check_checksum")
@ -286,7 +305,8 @@ class Channel(Context):
cache_thp.set_channel_host_ephemeral_key(self.channel_cache, host_ephemeral_key)
# send handshake init response message
loop.schedule(
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
HANDSHAKE_INIT_RES, thp_messages.get_handshake_init_response()
)
@ -338,7 +358,8 @@ class Channel(Context):
paired: bool = False # TODO should be output from credential check
# send hanshake completion response
loop.schedule(
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
HANDSHAKE_COMP_RES,
thp_messages.get_handshake_completion_response(paired=paired),
@ -543,13 +564,19 @@ class Channel(Context):
self._encrypt(self.buffer, payload_length)
payload_length = payload_length + TAG_LENGTH
loop.schedule(
if self.write_task_spawn is not None:
self.write_task_spawn.close() # UPS TODO migh break something
print("\nCLOSED\n")
self._prepare_write()
self.write_task_spawn = loop.spawn(
self._write_encrypted_payload_loop(
ENCRYPTED_TRANSPORT, memoryview(self.buffer[:payload_length])
)
)
if __debug__:
log.debug(__name__, "Scheduled _write_encrypted_payload_loop")
def _prepare_write(self) -> None:
# TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False)
async def _write_encrypted_payload_loop(
self, ctrl_byte: int, payload: bytes
@ -563,8 +590,6 @@ class Channel(Context):
chksum = checksum.compute(header.to_bytes() + payload)
payload = payload + chksum
# TODO add condition that disallows to write when can_send_message is false
THP.sync_set_can_send_message(self.channel_cache, False)
while True:
if __debug__:
log.debug(
@ -576,7 +601,13 @@ class Channel(Context):
await write_payload_to_wire(self.iface, header, payload)
self.waiting_for_ack_timeout = loop.spawn(self._wait_for_ack())
try:
await self.waiting_for_ack_timeout
if THP.sync_can_send_message(self.channel_cache):
# TODO This can happen when ack is received before the message was sent,
# but after it was scheduled to be sent (i.e. ACK was already expected)
# This case should be removed or improved upon before production.
break
else:
await self.waiting_for_ack_timeout
except loop.TaskClosed:
break
@ -584,14 +615,16 @@ class Channel(Context):
# Let the main loop be restarted and clear loop, if there is no other
# workflow and the state is ENCRYPTED_TRANSPORT
if (
not workflow.tasks
and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
):
if self._can_clear_loop():
if __debug__:
log.debug(__name__, "Clearing loop from channel")
log.debug(__name__, "clearing loop from channel")
loop.clear()
def _can_clear_loop(self) -> bool:
return (
not workflow.tasks
) and self.get_channel_state() is ChannelState.ENCRYPTED_TRANSPORT
async def _wait_for_ack(self) -> None:
await loop.sleep(1000)

Loading…
Cancel
Save