diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 2d414978e..15e3cf992 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -154,7 +154,7 @@ class Channel(Context): await self._buffer_packet_data(self.buffer, packet, 0) if __debug__: - log.debug(__name__, "channel._handle_init_packet - end") + log.debug(__name__, "handle_init_packet - end") async def _handle_cont_packet(self, packet: utils.BufferType): if __debug__: @@ -188,6 +188,12 @@ class Channel(Context): self._todo_clear_buffer() return + if self._should_be_encrypted() and not _is_ctrl_byte_encrypted_transport( + ctrl_byte + ): + self._todo_clear_buffer() + raise ThpError("Message is not encrypted. Ignoring") + # 2: Handle message with unexpected synchronization bit if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache): if __debug__: @@ -231,10 +237,6 @@ class Channel(Context): await self._handle_state_TH1(payload_length, message_length, sync_bit) return - if not _is_ctrl_byte_encrypted_transport(ctrl_byte): - self._todo_clear_buffer() - raise ThpError("Message is not encrypted. Ignoring") - if state is ChannelState.ENCRYPTED_TRANSPORT: await self._handle_state_ENCRYPTED_TRANSPORT(message_length) return @@ -304,7 +306,6 @@ class Channel(Context): if session_id == 0: await self._handle_channel_message(message_length, message_type) return - if session_id not in self.sessions: await self.write_error( FailureType.ThpUnallocatedSession, "Unallocated session" @@ -317,7 +318,6 @@ class Channel(Context): FailureType.ThpUnallocatedSession, "Unallocated session" ) raise ThpError("Unalloacted session") - self.sessions[session_id].incoming_message.publish( MessageWithType( message_type, @@ -330,6 +330,11 @@ class Channel(Context): async def _handle_pairing(self, message_length: int) -> None: pass + def _should_be_encrypted(self) -> bool: + if self.get_channel_state() in [ChannelState.UNALLOCATED, ChannelState.TH1]: + return False + return True + async def _handle_channel_message( self, message_length: int, message_type: int ) -> None: @@ -434,7 +439,7 @@ class Channel(Context): async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: if __debug__: - log.debug(__name__, "channel.write: %s", msg.MESSAGE_NAME) + log.debug(__name__, "write message: %s", msg.MESSAGE_NAME) noise_payload_len = self._encode_into_buffer(msg, session_id) await self.write_and_encrypt(self.buffer[:noise_payload_len]) diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index a9adaa93a..b6a67b53b 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -113,15 +113,14 @@ class SessionContext(Context): expected_type: type[protobuf.MessageType] | None = None, ) -> protobuf.MessageType: if __debug__: + exp_type: str = str(expected_type) log.debug( __name__, "Read - with expected types %s and expected type %s", str(expected_types), - str(expected_type), + exp_type, ) message: MessageWithType = await self.incoming_message.take() - if __debug__: - log.debug(__name__, "I'm here") if message.type not in expected_types: raise UnexpectedMessageWithType(message) @@ -158,4 +157,5 @@ def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO if session.channel_id == channel.channel_id: sid = int.from_bytes(session.session_id, "big") sessions[sid] = SessionContext(channel, session) + loop.schedule(sessions[sid].handle()) return sessions diff --git a/core/src/trezor/wire/thp/thp_session.py b/core/src/trezor/wire/thp/thp_session.py index 78f1fa7c9..12778f91f 100644 --- a/core/src/trezor/wire/thp/thp_session.py +++ b/core/src/trezor/wire/thp/thp_session.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports] from storage import cache_thp as storage_thp_cache from storage.cache_thp import ChannelCache, SessionThpCache +from trezor import log from trezor.wire.protocol_common import WireError if TYPE_CHECKING: @@ -85,6 +86,8 @@ def sync_set_can_send_message( def sync_set_receive_expected_bit( cache: SessionThpCache | ChannelCache, bit: int ) -> None: + if __debug__: + log.debug(__name__, "Set sync receive expected bit to %d", bit) if bit not in (0, 1): raise ThpError("Unexpected receive sync bit")