Fix loading of sessions from cache, improve logging

M1nd3r/thp5
M1nd3r 1 month ago
parent 27765bfb78
commit 3bd33de778

@ -10,6 +10,9 @@ if TYPE_CHECKING:
T = TypeVar("T")
if __debug__:
from trezor import log
# THP specific constants
_MAX_CHANNELS_COUNT = 10
_MAX_SESSIONS_COUNT = const(20)
@ -170,6 +173,12 @@ def get_all_allocated_sessions() -> list[SessionThpCache]:
for session in _SESSIONS:
if _get_session_state(session) != _UNALLOCATED_STATE:
_list.append(session)
if __debug__:
log.debug(
__name__, "session %s is not in UNALLOCATED state", str(session)
)
elif __debug__:
log.debug(__name__, "session %s is in UNALLOCATED state", str(session))
return _list

@ -125,7 +125,6 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)
print("Exception raised:", exc)
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:

@ -20,6 +20,7 @@ class ChannelState(IntEnum):
class SessionState(IntEnum):
UNALLOCATED = 0
ALLOCATED = 1
class WireInterfaceType(IntEnum):

@ -1,31 +1,40 @@
from storage.cache_thp import SessionThpCache
from trezor import log
from storage.cache_thp import ChannelCache, SessionThpCache
from trezor import log, loop
from . import thp_session as THP
def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None:
def handle_received_ACK(
cache: SessionThpCache | ChannelCache,
sync_bit: int,
waiting_for_ack_timeout: loop.spawn | None = None,
) -> None:
if _ack_is_not_expected(cache):
_conditionally_log_debug(__name__, "Received unexpected ACK message")
_conditionally_log_debug("Received unexpected ACK message")
return
if _ack_has_incorrect_sync_bit(cache, sync_bit):
_conditionally_log_debug(__name__, "Received ACK message with wrong sync bit")
_conditionally_log_debug("Received ACK message with wrong sync bit")
return
# ACK is expected and it has correct sync bit
_conditionally_log_debug(__name__, "Received ACK message with correct sync bit")
_conditionally_log_debug("Received ACK message with correct sync bit")
if waiting_for_ack_timeout is not None:
waiting_for_ack_timeout.close()
_conditionally_log_debug('Closed "waiting for ack" task')
THP.sync_set_can_send_message(cache, True)
def _ack_is_not_expected(cache: SessionThpCache) -> bool:
def _ack_is_not_expected(cache: SessionThpCache | ChannelCache) -> bool:
return THP.sync_can_send_message(cache)
def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool:
def _ack_has_incorrect_sync_bit(
cache: SessionThpCache | ChannelCache, sync_bit: int
) -> bool:
return THP.sync_get_send_bit(cache) != sync_bit
def _conditionally_log_debug(name, message):
def _conditionally_log_debug(message):
if __debug__:
log.debug(name, message)
log.debug(__name__, message)

@ -9,12 +9,13 @@ from trezor import log, loop, protobuf, utils
from trezor.enums import FailureType, MessageType
from trezor.messages import Failure, ThpCreateNewSession
from trezor.wire import message_handler
from trezor.wire.thp import thp_messages
from trezor.wire.thp import ack_handler, thp_messages
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto
from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH
from .crypto import PUBKEY_LENGTH
from .thp_messages import (
ACK_MESSAGE,
CONTINUATION_PACKET,
@ -38,8 +39,6 @@ if TYPE_CHECKING:
_WIRE_INTERFACE_USB = b"\x01"
_MOCK_INTERFACE_HID = b"\x00"
_PUBKEY_LENGTH = const(32)
MESSAGE_TYPE_LENGTH = const(2)
@ -183,7 +182,9 @@ class Channel(Context):
# 1: Handle ACKs
if _is_ctrl_byte_ack(ctrl_byte):
self._handle_received_ACK(sync_bit)
ack_handler.handle_received_ACK(
self.channel_cache, sync_bit, self.waiting_for_ack_timeout
)
self._todo_clear_buffer()
return
@ -251,7 +252,7 @@ class Channel(Context):
) -> None:
if not _is_ctrl_byte_handshake_init:
raise ThpError("Message received is not a handshake init request!")
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
raise ThpError("Message received is not a valid handshake init request!")
host_ephemeral_key = bytearray(
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
@ -301,7 +302,7 @@ class Channel(Context):
self._decrypt_buffer(message_length)
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
if session_id == 0:
self._handle_channel_message(message_length, message_type)
await self._handle_channel_message(message_length, message_type)
return
if session_id not in self.sessions:
@ -329,7 +330,9 @@ class Channel(Context):
async def _handle_pairing(self, message_length: int) -> None:
pass
def _handle_channel_message(self, message_length: int, message_type: int) -> None:
async def _handle_channel_message(
self, message_length: int, message_type: int
) -> None:
buf = self.buffer[
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
]
@ -361,7 +364,7 @@ class Channel(Context):
log.debug(
__name__, "handle_channel_message - message size: %d", message_size
)
loop.schedule(self.write_and_encrypt(bufferrone))
await self.write_and_encrypt(bufferrone)
# TODO not finished
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
@ -537,6 +540,7 @@ class Channel(Context):
from trezor.wire.thp.session_context import SessionContext
session = SessionContext.create_new_session(self)
session.set_session_state(SessionState.ALLOCATED)
self.sessions[session.session_id] = session
loop.schedule(session.handle())
if __debug__:
@ -553,33 +557,6 @@ class Channel(Context):
# TODO Buffer clearing not implemented
pass
# TODO add debug logging to ACK handling
def _handle_received_ACK(self, sync_bit: int) -> None:
if self._ack_is_not_expected():
if __debug__:
log.debug(__name__, "handle_received_ACK - ack is not expected")
return
if self._ack_has_incorrect_sync_bit(sync_bit):
if __debug__:
log.debug(
__name__,
"handle_received_ACK - ack has incorrect sync bit",
)
return
if self.waiting_for_ack_timeout is not None:
self.waiting_for_ack_timeout.close()
if __debug__:
log.debug(__name__, "handle_received_ACK - closed waiting for ack")
THP.sync_set_can_send_message(self.channel_cache, True)
def _ack_is_not_expected(self) -> bool:
return THP.sync_can_send_message(self.channel_cache)
def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool:
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, Channel] = {}

@ -1,4 +1,7 @@
from micropython import const # pyright: ignore[reportMissingModuleSource]
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
PUBKEY_LENGTH = const(32)
# TODO implement

@ -45,10 +45,12 @@ class SessionContext(Context):
return cls(channel_context, session_cache)
async def handle(self, is_debug_session: bool = False) -> None:
if __debug__ and is_debug_session:
import apps.debug
if __debug__:
log.debug(__name__, "handle - start")
if is_debug_session:
import apps.debug
apps.debug.DEBUG_CONTEXT = self
apps.debug.DEBUG_CONTEXT = self
take = self.incoming_message.take()
next_message: MessageWithType | None = None
@ -110,8 +112,16 @@ class SessionContext(Context):
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
if __debug__:
log.debug(
__name__,
"Read - with expected types %s and expected type %s",
str(expected_types),
str(expected_type),
)
message: MessageWithType = await self.incoming_message.take()
if __debug__:
log.debug(__name__, "I'm here")
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
@ -130,27 +140,22 @@ class SessionContext(Context):
return SessionState(state)
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.value.to_bytes(1, "big"))
# Called by channel context
async def receive_message(self, message_type, encoded_protobuf_message):
pass # TODO implement
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
print("session_context.load_cached_sessions")
if __debug__:
log.debug(__name__, "load_cached_sessions")
sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions()
print(
"session_context.load_cached_sessions - loaded a total of ",
len(cached_sessions),
"sessions from cache",
)
if __debug__:
log.debug(
__name__,
"load_cached_sessions - loaded a total of %d sessions from cache",
len(cached_sessions),
)
for session in cached_sessions:
if session.channel_id == channel.channel_id:
sid = int.from_bytes(session.session_id, "big")
sessions[sid] = SessionContext(channel, session)
for i in sessions:
print("session", i)
return sessions

Loading…
Cancel
Save