diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 2cf66ece4..7e0604283 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -10,6 +10,9 @@ if TYPE_CHECKING: T = TypeVar("T") +if __debug__: + from trezor import log + # THP specific constants _MAX_CHANNELS_COUNT = 10 _MAX_SESSIONS_COUNT = const(20) @@ -170,6 +173,12 @@ def get_all_allocated_sessions() -> list[SessionThpCache]: for session in _SESSIONS: if _get_session_state(session) != _UNALLOCATED_STATE: _list.append(session) + if __debug__: + log.debug( + __name__, "session %s is not in UNALLOCATED state", str(session) + ) + elif __debug__: + log.debug(__name__, "session %s is in UNALLOCATED state", str(session)) return _list diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 729d31254..85b53e26a 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -125,7 +125,6 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals # loop.clear() above. if __debug__: log.exception(__name__, exc) - print("Exception raised:", exc) async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None: diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index 9538dc0af..0a9a18131 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -20,6 +20,7 @@ class ChannelState(IntEnum): class SessionState(IntEnum): UNALLOCATED = 0 + ALLOCATED = 1 class WireInterfaceType(IntEnum): diff --git a/core/src/trezor/wire/thp/ack_handler.py b/core/src/trezor/wire/thp/ack_handler.py index 992bda30e..1bf9d5a58 100644 --- a/core/src/trezor/wire/thp/ack_handler.py +++ b/core/src/trezor/wire/thp/ack_handler.py @@ -1,31 +1,40 @@ -from storage.cache_thp import SessionThpCache -from trezor import log +from storage.cache_thp import ChannelCache, SessionThpCache +from trezor import log, loop from . import thp_session as THP -def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None: +def handle_received_ACK( + cache: SessionThpCache | ChannelCache, + sync_bit: int, + waiting_for_ack_timeout: loop.spawn | None = None, +) -> None: if _ack_is_not_expected(cache): - _conditionally_log_debug(__name__, "Received unexpected ACK message") + _conditionally_log_debug("Received unexpected ACK message") return if _ack_has_incorrect_sync_bit(cache, sync_bit): - _conditionally_log_debug(__name__, "Received ACK message with wrong sync bit") + _conditionally_log_debug("Received ACK message with wrong sync bit") return # ACK is expected and it has correct sync bit - _conditionally_log_debug(__name__, "Received ACK message with correct sync bit") + _conditionally_log_debug("Received ACK message with correct sync bit") + if waiting_for_ack_timeout is not None: + waiting_for_ack_timeout.close() + _conditionally_log_debug('Closed "waiting for ack" task') THP.sync_set_can_send_message(cache, True) -def _ack_is_not_expected(cache: SessionThpCache) -> bool: +def _ack_is_not_expected(cache: SessionThpCache | ChannelCache) -> bool: return THP.sync_can_send_message(cache) -def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool: +def _ack_has_incorrect_sync_bit( + cache: SessionThpCache | ChannelCache, sync_bit: int +) -> bool: return THP.sync_get_send_bit(cache) != sync_bit -def _conditionally_log_debug(name, message): +def _conditionally_log_debug(message): if __debug__: - log.debug(name, message) + log.debug(__name__, message) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index aad76c222..2d414978e 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -9,12 +9,13 @@ from trezor import log, loop, protobuf, utils from trezor.enums import FailureType, MessageType from trezor.messages import Failure, ThpCreateNewSession from trezor.wire import message_handler -from trezor.wire.thp import thp_messages +from trezor.wire.thp import ack_handler, thp_messages from ..protocol_common import Context, MessageWithType from . import ChannelState, SessionState, checksum, crypto from . import thp_session as THP from .checksum import CHECKSUM_LENGTH +from .crypto import PUBKEY_LENGTH from .thp_messages import ( ACK_MESSAGE, CONTINUATION_PACKET, @@ -38,8 +39,6 @@ if TYPE_CHECKING: _WIRE_INTERFACE_USB = b"\x01" _MOCK_INTERFACE_HID = b"\x00" -_PUBKEY_LENGTH = const(32) - MESSAGE_TYPE_LENGTH = const(2) @@ -183,7 +182,9 @@ class Channel(Context): # 1: Handle ACKs if _is_ctrl_byte_ack(ctrl_byte): - self._handle_received_ACK(sync_bit) + ack_handler.handle_received_ACK( + self.channel_cache, sync_bit, self.waiting_for_ack_timeout + ) self._todo_clear_buffer() return @@ -251,7 +252,7 @@ class Channel(Context): ) -> None: if not _is_ctrl_byte_handshake_init: raise ThpError("Message received is not a handshake init request!") - if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: + if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH: raise ThpError("Message received is not a valid handshake init request!") host_ephemeral_key = bytearray( self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] @@ -301,7 +302,7 @@ class Channel(Context): self._decrypt_buffer(message_length) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) if session_id == 0: - self._handle_channel_message(message_length, message_type) + await self._handle_channel_message(message_length, message_type) return if session_id not in self.sessions: @@ -329,7 +330,9 @@ class Channel(Context): async def _handle_pairing(self, message_length: int) -> None: pass - def _handle_channel_message(self, message_length: int, message_type: int) -> None: + async def _handle_channel_message( + self, message_length: int, message_type: int + ) -> None: buf = self.buffer[ INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH ] @@ -361,7 +364,7 @@ class Channel(Context): log.debug( __name__, "handle_channel_message - message size: %d", message_size ) - loop.schedule(self.write_and_encrypt(bufferrone)) + await self.write_and_encrypt(bufferrone) # TODO not finished def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: @@ -537,6 +540,7 @@ class Channel(Context): from trezor.wire.thp.session_context import SessionContext session = SessionContext.create_new_session(self) + session.set_session_state(SessionState.ALLOCATED) self.sessions[session.session_id] = session loop.schedule(session.handle()) if __debug__: @@ -553,33 +557,6 @@ class Channel(Context): # TODO Buffer clearing not implemented pass - # TODO add debug logging to ACK handling - def _handle_received_ACK(self, sync_bit: int) -> None: - if self._ack_is_not_expected(): - if __debug__: - log.debug(__name__, "handle_received_ACK - ack is not expected") - return - if self._ack_has_incorrect_sync_bit(sync_bit): - if __debug__: - log.debug( - __name__, - "handle_received_ACK - ack has incorrect sync bit", - ) - return - - if self.waiting_for_ack_timeout is not None: - self.waiting_for_ack_timeout.close() - if __debug__: - log.debug(__name__, "handle_received_ACK - closed waiting for ack") - - THP.sync_set_can_send_message(self.channel_cache, True) - - def _ack_is_not_expected(self) -> bool: - return THP.sync_can_send_message(self.channel_cache) - - def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool: - return THP.sync_get_send_bit(self.channel_cache) != sync_bit - def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO channels: dict[int, Channel] = {} diff --git a/core/src/trezor/wire/thp/crypto.py b/core/src/trezor/wire/thp/crypto.py index bb8c6632e..221c64b92 100644 --- a/core/src/trezor/wire/thp/crypto.py +++ b/core/src/trezor/wire/thp/crypto.py @@ -1,4 +1,7 @@ +from micropython import const # pyright: ignore[reportMissingModuleSource] + DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" +PUBKEY_LENGTH = const(32) # TODO implement diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index d7cc2c753..b7c3abfa1 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -45,10 +45,12 @@ class SessionContext(Context): return cls(channel_context, session_cache) async def handle(self, is_debug_session: bool = False) -> None: - if __debug__ and is_debug_session: - import apps.debug + if __debug__: + log.debug(__name__, "handle - start") + if is_debug_session: + import apps.debug - apps.debug.DEBUG_CONTEXT = self + apps.debug.DEBUG_CONTEXT = self take = self.incoming_message.take() next_message: MessageWithType | None = None @@ -110,8 +112,19 @@ class SessionContext(Context): expected_types: Container[int], expected_type: type[protobuf.MessageType] | None = None, ) -> protobuf.MessageType: - + if __debug__: + exp_type: str = str(expected_type) + if expected_type is not None: + exp_type = expected_type.MESSAGE_NAME + log.debug( + __name__, + "Read - with expected types %s and expected type %s", + str(expected_types), + exp_type, + ) message: MessageWithType = await self.incoming_message.take() + if __debug__: + log.debug(__name__, "I'm here") if message.type not in expected_types: raise UnexpectedMessageWithType(message) @@ -130,27 +143,22 @@ class SessionContext(Context): return SessionState(state) def set_session_state(self, state: SessionState) -> None: - self.session_cache.state = bytearray(state.value.to_bytes(1, "big")) - - # Called by channel context - - async def receive_message(self, message_type, encoded_protobuf_message): - pass # TODO implement + self.session_cache.state = bytearray(state.to_bytes(1, "big")) def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO - print("session_context.load_cached_sessions") + if __debug__: + log.debug(__name__, "load_cached_sessions") sessions: dict[int, SessionContext] = {} cached_sessions = cache_thp.get_all_allocated_sessions() - print( - "session_context.load_cached_sessions - loaded a total of ", - len(cached_sessions), - "sessions from cache", - ) + if __debug__: + log.debug( + __name__, + "load_cached_sessions - loaded a total of %d sessions from cache", + len(cached_sessions), + ) for session in cached_sessions: if session.channel_id == channel.channel_id: sid = int.from_bytes(session.session_id, "big") sessions[sid] = SessionContext(channel, session) - for i in sessions: - print("session", i) return sessions