Fix loading of sessions from cache, improve logging

M1nd3r/thp5
M1nd3r 2 months ago
parent 27765bfb78
commit 3bd33de778

@ -10,6 +10,9 @@ if TYPE_CHECKING:
T = TypeVar("T") T = TypeVar("T")
if __debug__:
from trezor import log
# THP specific constants # THP specific constants
_MAX_CHANNELS_COUNT = 10 _MAX_CHANNELS_COUNT = 10
_MAX_SESSIONS_COUNT = const(20) _MAX_SESSIONS_COUNT = const(20)
@ -170,6 +173,12 @@ def get_all_allocated_sessions() -> list[SessionThpCache]:
for session in _SESSIONS: for session in _SESSIONS:
if _get_session_state(session) != _UNALLOCATED_STATE: if _get_session_state(session) != _UNALLOCATED_STATE:
_list.append(session) _list.append(session)
if __debug__:
log.debug(
__name__, "session %s is not in UNALLOCATED state", str(session)
)
elif __debug__:
log.debug(__name__, "session %s is in UNALLOCATED state", str(session))
return _list return _list

@ -125,7 +125,6 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
# loop.clear() above. # loop.clear() above.
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
print("Exception raised:", exc)
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None: async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:

@ -20,6 +20,7 @@ class ChannelState(IntEnum):
class SessionState(IntEnum): class SessionState(IntEnum):
UNALLOCATED = 0 UNALLOCATED = 0
ALLOCATED = 1
class WireInterfaceType(IntEnum): class WireInterfaceType(IntEnum):

@ -1,31 +1,40 @@
from storage.cache_thp import SessionThpCache from storage.cache_thp import ChannelCache, SessionThpCache
from trezor import log from trezor import log, loop
from . import thp_session as THP from . import thp_session as THP
def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None: def handle_received_ACK(
cache: SessionThpCache | ChannelCache,
sync_bit: int,
waiting_for_ack_timeout: loop.spawn | None = None,
) -> None:
if _ack_is_not_expected(cache): if _ack_is_not_expected(cache):
_conditionally_log_debug(__name__, "Received unexpected ACK message") _conditionally_log_debug("Received unexpected ACK message")
return return
if _ack_has_incorrect_sync_bit(cache, sync_bit): if _ack_has_incorrect_sync_bit(cache, sync_bit):
_conditionally_log_debug(__name__, "Received ACK message with wrong sync bit") _conditionally_log_debug("Received ACK message with wrong sync bit")
return return
# ACK is expected and it has correct sync bit # ACK is expected and it has correct sync bit
_conditionally_log_debug(__name__, "Received ACK message with correct sync bit") _conditionally_log_debug("Received ACK message with correct sync bit")
if waiting_for_ack_timeout is not None:
waiting_for_ack_timeout.close()
_conditionally_log_debug('Closed "waiting for ack" task')
THP.sync_set_can_send_message(cache, True) THP.sync_set_can_send_message(cache, True)
def _ack_is_not_expected(cache: SessionThpCache) -> bool: def _ack_is_not_expected(cache: SessionThpCache | ChannelCache) -> bool:
return THP.sync_can_send_message(cache) return THP.sync_can_send_message(cache)
def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool: def _ack_has_incorrect_sync_bit(
cache: SessionThpCache | ChannelCache, sync_bit: int
) -> bool:
return THP.sync_get_send_bit(cache) != sync_bit return THP.sync_get_send_bit(cache) != sync_bit
def _conditionally_log_debug(name, message): def _conditionally_log_debug(message):
if __debug__: if __debug__:
log.debug(name, message) log.debug(__name__, message)

@ -9,12 +9,13 @@ from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType, MessageType from trezor.enums import FailureType, MessageType
from trezor.messages import Failure, ThpCreateNewSession from trezor.messages import Failure, ThpCreateNewSession
from trezor.wire import message_handler from trezor.wire import message_handler
from trezor.wire.thp import thp_messages from trezor.wire.thp import ack_handler, thp_messages
from ..protocol_common import Context, MessageWithType from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto from . import ChannelState, SessionState, checksum, crypto
from . import thp_session as THP from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH
from .thp_messages import ( from .thp_messages import (
ACK_MESSAGE, ACK_MESSAGE,
CONTINUATION_PACKET, CONTINUATION_PACKET,
@ -38,8 +39,6 @@ if TYPE_CHECKING:
_WIRE_INTERFACE_USB = b"\x01" _WIRE_INTERFACE_USB = b"\x01"
_MOCK_INTERFACE_HID = b"\x00" _MOCK_INTERFACE_HID = b"\x00"
_PUBKEY_LENGTH = const(32)
MESSAGE_TYPE_LENGTH = const(2) MESSAGE_TYPE_LENGTH = const(2)
@ -183,7 +182,9 @@ class Channel(Context):
# 1: Handle ACKs # 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte): if _is_ctrl_byte_ack(ctrl_byte):
self._handle_received_ACK(sync_bit) ack_handler.handle_received_ACK(
self.channel_cache, sync_bit, self.waiting_for_ack_timeout
)
self._todo_clear_buffer() self._todo_clear_buffer()
return return
@ -251,7 +252,7 @@ class Channel(Context):
) -> None: ) -> None:
if not _is_ctrl_byte_handshake_init: if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!") raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH: if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!") raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray( host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH] self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
@ -301,7 +302,7 @@ class Channel(Context):
self._decrypt_buffer(message_length) self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:]) session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0: if session_id == 0:
self._handle_channel_message(message_length, message_type) await self._handle_channel_message(message_length, message_type)
return return
if session_id not in self.sessions: if session_id not in self.sessions:
@ -329,7 +330,9 @@ class Channel(Context):
async def _handle_pairing(self, message_length: int) -> None: async def _handle_pairing(self, message_length: int) -> None:
pass pass
def _handle_channel_message(self, message_length: int, message_type: int) -> None: async def _handle_channel_message(
self, message_length: int, message_type: int
) -> None:
buf = self.buffer[ buf = self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
] ]
@ -361,7 +364,7 @@ class Channel(Context):
log.debug( log.debug(
__name__, "handle_channel_message - message size: %d", message_size __name__, "handle_channel_message - message size: %d", message_size
) )
loop.schedule(self.write_and_encrypt(bufferrone)) await self.write_and_encrypt(bufferrone)
# TODO not finished # TODO not finished
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
@ -537,6 +540,7 @@ class Channel(Context):
from trezor.wire.thp.session_context import SessionContext from trezor.wire.thp.session_context import SessionContext
session = SessionContext.create_new_session(self) session = SessionContext.create_new_session(self)
session.set_session_state(SessionState.ALLOCATED)
self.sessions[session.session_id] = session self.sessions[session.session_id] = session
loop.schedule(session.handle()) loop.schedule(session.handle())
if __debug__: if __debug__:
@ -553,33 +557,6 @@ class Channel(Context):
# TODO Buffer clearing not implemented # TODO Buffer clearing not implemented
pass pass
# TODO add debug logging to ACK handling
def _handle_received_ACK(self, sync_bit: int) -> None:
if self._ack_is_not_expected():
if __debug__:
log.debug(__name__, "handle_received_ACK - ack is not expected")
return
if self._ack_has_incorrect_sync_bit(sync_bit):
if __debug__:
log.debug(
__name__,
"handle_received_ACK - ack has incorrect sync bit",
)
return
if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, "handle_received_ACK - closed waiting for ack")
THP.sync_set_can_send_message(self.channel_cache, True)
def _ack_is_not_expected(self) -> bool:
return THP.sync_can_send_message(self.channel_cache)
def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool:
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {} channels: dict[int, Channel] = {}

@ -1,4 +1,7 @@
from micropython import const # pyright: ignore[reportMissingModuleSource]
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5" DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
PUBKEY_LENGTH = const(32)
# TODO implement # TODO implement

@ -45,10 +45,12 @@ class SessionContext(Context):
return cls(channel_context, session_cache) return cls(channel_context, session_cache)
async def handle(self, is_debug_session: bool = False) -> None: async def handle(self, is_debug_session: bool = False) -> None:
if __debug__ and is_debug_session: if __debug__:
import apps.debug log.debug(__name__, "handle - start")
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take() take = self.incoming_message.take()
next_message: MessageWithType | None = None next_message: MessageWithType | None = None
@ -110,8 +112,16 @@ class SessionContext(Context):
expected_types: Container[int], expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None, expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ) -> protobuf.MessageType:
if __debug__:
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
str(expected_type),
)
message: MessageWithType = await self.incoming_message.take() message: MessageWithType = await self.incoming_message.take()
if __debug__:
log.debug(__name__, "I'm here")
if message.type not in expected_types: if message.type not in expected_types:
raise UnexpectedMessageWithType(message) raise UnexpectedMessageWithType(message)
@ -130,27 +140,22 @@ class SessionContext(Context):
return SessionState(state) return SessionState(state)
def set_session_state(self, state: SessionState) -> None: def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.value.to_bytes(1, "big")) self.session_cache.state = bytearray(state.to_bytes(1, "big"))
# Called by channel context
async def receive_message(self, message_type, encoded_protobuf_message):
pass # TODO implement
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
print("session_context.load_cached_sessions") if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {} sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions() cached_sessions = cache_thp.get_all_allocated_sessions()
print( if __debug__:
"session_context.load_cached_sessions - loaded a total of ", log.debug(
len(cached_sessions), __name__,
"sessions from cache", "load_cached_sessions - loaded a total of %d sessions from cache",
) len(cached_sessions),
)
for session in cached_sessions: for session in cached_sessions:
if session.channel_id == channel.channel_id: if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big") sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session) sessions[sid] = SessionContext(channel, session)
for i in sessions:
print("session", i)
return sessions return sessions

Loading…
Cancel
Save