Fix order of checks

M1nd3r/thp5
M1nd3r 1 month ago
parent 3bd33de778
commit c724990c02

@ -154,7 +154,7 @@ class Channel(Context):
await self._buffer_packet_data(self.buffer, packet, 0)
if __debug__:
log.debug(__name__, "channel._handle_init_packet - end")
log.debug(__name__, "handle_init_packet - end")
async def _handle_cont_packet(self, packet: utils.BufferType):
if __debug__:
@ -188,6 +188,12 @@ class Channel(Context):
self._todo_clear_buffer()
return
if self._should_be_encrypted() and not _is_ctrl_byte_encrypted_transport(
ctrl_byte
):
self._todo_clear_buffer()
raise ThpError("Message is not encrypted. Ignoring")
# 2: Handle message with unexpected synchronization bit
if sync_bit != THP.sync_get_receive_expected_bit(self.channel_cache):
if __debug__:
@ -231,10 +237,6 @@ class Channel(Context):
await self._handle_state_TH1(payload_length, message_length, sync_bit)
return
if not _is_ctrl_byte_encrypted_transport(ctrl_byte):
self._todo_clear_buffer()
raise ThpError("Message is not encrypted. Ignoring")
if state is ChannelState.ENCRYPTED_TRANSPORT:
await self._handle_state_ENCRYPTED_TRANSPORT(message_length)
return
@ -304,7 +306,6 @@ class Channel(Context):
if session_id == 0:
await self._handle_channel_message(message_length, message_type)
return
if session_id not in self.sessions:
await self.write_error(
FailureType.ThpUnallocatedSession, "Unallocated session"
@ -317,7 +318,6 @@ class Channel(Context):
FailureType.ThpUnallocatedSession, "Unallocated session"
)
raise ThpError("Unalloacted session")
self.sessions[session_id].incoming_message.publish(
MessageWithType(
message_type,
@ -330,6 +330,11 @@ class Channel(Context):
async def _handle_pairing(self, message_length: int) -> None:
pass
def _should_be_encrypted(self) -> bool:
if self.get_channel_state() in [ChannelState.UNALLOCATED, ChannelState.TH1]:
return False
return True
async def _handle_channel_message(
self, message_length: int, message_type: int
) -> None:
@ -434,7 +439,7 @@ class Channel(Context):
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
if __debug__:
log.debug(__name__, "channel.write: %s", msg.MESSAGE_NAME)
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
noise_payload_len = self._encode_into_buffer(msg, session_id)
await self.write_and_encrypt(self.buffer[:noise_payload_len])

@ -113,15 +113,14 @@ class SessionContext(Context):
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
exp_type: str = str(expected_type)
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
str(expected_type),
exp_type,
)
message: MessageWithType = await self.incoming_message.take()
if __debug__:
log.debug(__name__, "I'm here")
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
@ -158,4 +157,5 @@ def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session)
loop.schedule(sessions[sid].handle())
return sessions

@ -3,6 +3,7 @@ from typing import TYPE_CHECKING # pyright:ignore[reportShadowedImports]
from storage import cache_thp as storage_thp_cache
from storage.cache_thp import ChannelCache, SessionThpCache
from trezor import log
from trezor.wire.protocol_common import WireError
if TYPE_CHECKING:
@ -85,6 +86,8 @@ def sync_set_can_send_message(
def sync_set_receive_expected_bit(
cache: SessionThpCache | ChannelCache, bit: int
) -> None:
if __debug__:
log.debug(__name__, "Set sync receive expected bit to %d", bit)
if bit not in (0, 1):
raise ThpError("Unexpected receive sync bit")

Loading…
Cancel
Save