mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
fix(core): fix loading of sessions from cache, improve logging
Improve logging of expected message type in session context
This commit is contained in:
parent
6c762c50bc
commit
2c97edb183
@ -10,6 +10,9 @@ if TYPE_CHECKING:
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
# THP specific constants
|
||||
_MAX_CHANNELS_COUNT = 10
|
||||
_MAX_SESSIONS_COUNT = const(20)
|
||||
@ -170,6 +173,12 @@ def get_all_allocated_sessions() -> list[SessionThpCache]:
|
||||
for session in _SESSIONS:
|
||||
if _get_session_state(session) != _UNALLOCATED_STATE:
|
||||
_list.append(session)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__, "session %s is not in UNALLOCATED state", str(session)
|
||||
)
|
||||
elif __debug__:
|
||||
log.debug(__name__, "session %s is in UNALLOCATED state", str(session))
|
||||
return _list
|
||||
|
||||
|
||||
|
@ -125,7 +125,6 @@ async def handle_thp_session(iface: WireInterface, is_debug_session: bool = Fals
|
||||
# loop.clear() above.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
print("Exception raised:", exc)
|
||||
|
||||
|
||||
async def handle_session(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
|
@ -20,6 +20,7 @@ class ChannelState(IntEnum):
|
||||
|
||||
class SessionState(IntEnum):
|
||||
UNALLOCATED = 0
|
||||
ALLOCATED = 1
|
||||
|
||||
|
||||
class WireInterfaceType(IntEnum):
|
||||
|
@ -1,31 +1,40 @@
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import log
|
||||
from storage.cache_thp import ChannelCache, SessionThpCache
|
||||
from trezor import log, loop
|
||||
|
||||
from . import thp_session as THP
|
||||
|
||||
|
||||
def handle_received_ACK(cache: SessionThpCache, sync_bit: int) -> None:
|
||||
def handle_received_ACK(
|
||||
cache: SessionThpCache | ChannelCache,
|
||||
sync_bit: int,
|
||||
waiting_for_ack_timeout: loop.spawn | None = None,
|
||||
) -> None:
|
||||
|
||||
if _ack_is_not_expected(cache):
|
||||
_conditionally_log_debug(__name__, "Received unexpected ACK message")
|
||||
_conditionally_log_debug("Received unexpected ACK message")
|
||||
return
|
||||
if _ack_has_incorrect_sync_bit(cache, sync_bit):
|
||||
_conditionally_log_debug(__name__, "Received ACK message with wrong sync bit")
|
||||
_conditionally_log_debug("Received ACK message with wrong sync bit")
|
||||
return
|
||||
|
||||
# ACK is expected and it has correct sync bit
|
||||
_conditionally_log_debug(__name__, "Received ACK message with correct sync bit")
|
||||
_conditionally_log_debug("Received ACK message with correct sync bit")
|
||||
if waiting_for_ack_timeout is not None:
|
||||
waiting_for_ack_timeout.close()
|
||||
_conditionally_log_debug('Closed "waiting for ack" task')
|
||||
THP.sync_set_can_send_message(cache, True)
|
||||
|
||||
|
||||
def _ack_is_not_expected(cache: SessionThpCache) -> bool:
|
||||
def _ack_is_not_expected(cache: SessionThpCache | ChannelCache) -> bool:
|
||||
return THP.sync_can_send_message(cache)
|
||||
|
||||
|
||||
def _ack_has_incorrect_sync_bit(cache: SessionThpCache, sync_bit: int) -> bool:
|
||||
def _ack_has_incorrect_sync_bit(
|
||||
cache: SessionThpCache | ChannelCache, sync_bit: int
|
||||
) -> bool:
|
||||
return THP.sync_get_send_bit(cache) != sync_bit
|
||||
|
||||
|
||||
def _conditionally_log_debug(name, message):
|
||||
def _conditionally_log_debug(message):
|
||||
if __debug__:
|
||||
log.debug(name, message)
|
||||
log.debug(__name__, message)
|
||||
|
@ -9,12 +9,13 @@ from trezor import log, loop, protobuf, utils
|
||||
from trezor.enums import FailureType, MessageType
|
||||
from trezor.messages import Failure, ThpCreateNewSession
|
||||
from trezor.wire import message_handler
|
||||
from trezor.wire.thp import thp_messages
|
||||
from trezor.wire.thp import ack_handler, thp_messages
|
||||
|
||||
from ..protocol_common import Context, MessageWithType
|
||||
from . import ChannelState, SessionState, checksum, crypto
|
||||
from . import thp_session as THP
|
||||
from .checksum import CHECKSUM_LENGTH
|
||||
from .crypto import PUBKEY_LENGTH
|
||||
from .thp_messages import (
|
||||
ACK_MESSAGE,
|
||||
CONTINUATION_PACKET,
|
||||
@ -38,8 +39,6 @@ if TYPE_CHECKING:
|
||||
_WIRE_INTERFACE_USB = b"\x01"
|
||||
_MOCK_INTERFACE_HID = b"\x00"
|
||||
|
||||
_PUBKEY_LENGTH = const(32)
|
||||
|
||||
|
||||
MESSAGE_TYPE_LENGTH = const(2)
|
||||
|
||||
@ -183,7 +182,9 @@ class Channel(Context):
|
||||
|
||||
# 1: Handle ACKs
|
||||
if _is_ctrl_byte_ack(ctrl_byte):
|
||||
self._handle_received_ACK(sync_bit)
|
||||
ack_handler.handle_received_ACK(
|
||||
self.channel_cache, sync_bit, self.waiting_for_ack_timeout
|
||||
)
|
||||
self._todo_clear_buffer()
|
||||
return
|
||||
|
||||
@ -251,7 +252,7 @@ class Channel(Context):
|
||||
) -> None:
|
||||
if not _is_ctrl_byte_handshake_init:
|
||||
raise ThpError("Message received is not a handshake init request!")
|
||||
if not payload_length == _PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
if not payload_length == PUBKEY_LENGTH + CHECKSUM_LENGTH:
|
||||
raise ThpError("Message received is not a valid handshake init request!")
|
||||
host_ephemeral_key = bytearray(
|
||||
self.buffer[INIT_DATA_OFFSET : message_length - CHECKSUM_LENGTH]
|
||||
@ -301,7 +302,7 @@ class Channel(Context):
|
||||
self._decrypt_buffer(message_length)
|
||||
session_id, message_type = ustruct.unpack(">BH", self.buffer[INIT_DATA_OFFSET:])
|
||||
if session_id == 0:
|
||||
self._handle_channel_message(message_length, message_type)
|
||||
await self._handle_channel_message(message_length, message_type)
|
||||
return
|
||||
|
||||
if session_id not in self.sessions:
|
||||
@ -329,7 +330,9 @@ class Channel(Context):
|
||||
async def _handle_pairing(self, message_length: int) -> None:
|
||||
pass
|
||||
|
||||
def _handle_channel_message(self, message_length: int, message_type: int) -> None:
|
||||
async def _handle_channel_message(
|
||||
self, message_length: int, message_type: int
|
||||
) -> None:
|
||||
buf = self.buffer[
|
||||
INIT_DATA_OFFSET + 3 : message_length - CHECKSUM_LENGTH - TAG_LENGTH
|
||||
]
|
||||
@ -361,7 +364,7 @@ class Channel(Context):
|
||||
log.debug(
|
||||
__name__, "handle_channel_message - message size: %d", message_size
|
||||
)
|
||||
loop.schedule(self.write_and_encrypt(bufferrone))
|
||||
await self.write_and_encrypt(bufferrone)
|
||||
# TODO not finished
|
||||
|
||||
def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray:
|
||||
@ -537,6 +540,7 @@ class Channel(Context):
|
||||
from trezor.wire.thp.session_context import SessionContext
|
||||
|
||||
session = SessionContext.create_new_session(self)
|
||||
session.set_session_state(SessionState.ALLOCATED)
|
||||
self.sessions[session.session_id] = session
|
||||
loop.schedule(session.handle())
|
||||
if __debug__:
|
||||
@ -553,33 +557,6 @@ class Channel(Context):
|
||||
# TODO Buffer clearing not implemented
|
||||
pass
|
||||
|
||||
# TODO add debug logging to ACK handling
|
||||
def _handle_received_ACK(self, sync_bit: int) -> None:
|
||||
if self._ack_is_not_expected():
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_received_ACK - ack is not expected")
|
||||
return
|
||||
if self._ack_has_incorrect_sync_bit(sync_bit):
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"handle_received_ACK - ack has incorrect sync bit",
|
||||
)
|
||||
return
|
||||
|
||||
if self.waiting_for_ack_timeout is not None:
|
||||
self.waiting_for_ack_timeout.close()
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle_received_ACK - closed waiting for ack")
|
||||
|
||||
THP.sync_set_can_send_message(self.channel_cache, True)
|
||||
|
||||
def _ack_is_not_expected(self) -> bool:
|
||||
return THP.sync_can_send_message(self.channel_cache)
|
||||
|
||||
def _ack_has_incorrect_sync_bit(self, sync_bit: int) -> bool:
|
||||
return THP.sync_get_send_bit(self.channel_cache) != sync_bit
|
||||
|
||||
|
||||
def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
|
||||
channels: dict[int, Channel] = {}
|
||||
|
@ -1,4 +1,7 @@
|
||||
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
||||
|
||||
DUMMY_TAG = b"\xA0\xA1\xA2\xA3\xA4\xA5\xA6\xA7\xA8\xA9\xB0\xB1\xB2\xB3\xB4\xB5"
|
||||
PUBKEY_LENGTH = const(32)
|
||||
|
||||
|
||||
# TODO implement
|
||||
|
@ -45,10 +45,12 @@ class SessionContext(Context):
|
||||
return cls(channel_context, session_cache)
|
||||
|
||||
async def handle(self, is_debug_session: bool = False) -> None:
|
||||
if __debug__ and is_debug_session:
|
||||
import apps.debug
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle - start")
|
||||
if is_debug_session:
|
||||
import apps.debug
|
||||
|
||||
apps.debug.DEBUG_CONTEXT = self
|
||||
apps.debug.DEBUG_CONTEXT = self
|
||||
|
||||
take = self.incoming_message.take()
|
||||
next_message: MessageWithType | None = None
|
||||
@ -110,8 +112,19 @@ class SessionContext(Context):
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
|
||||
if __debug__:
|
||||
exp_type: str = str(expected_type)
|
||||
if expected_type is not None:
|
||||
exp_type = expected_type.MESSAGE_NAME
|
||||
log.debug(
|
||||
__name__,
|
||||
"Read - with expected types %s and expected type %s",
|
||||
str(expected_types),
|
||||
exp_type,
|
||||
)
|
||||
message: MessageWithType = await self.incoming_message.take()
|
||||
if __debug__:
|
||||
log.debug(__name__, "I'm here")
|
||||
if message.type not in expected_types:
|
||||
raise UnexpectedMessageWithType(message)
|
||||
|
||||
@ -130,27 +143,22 @@ class SessionContext(Context):
|
||||
return SessionState(state)
|
||||
|
||||
def set_session_state(self, state: SessionState) -> None:
|
||||
self.session_cache.state = bytearray(state.value.to_bytes(1, "big"))
|
||||
|
||||
# Called by channel context
|
||||
|
||||
async def receive_message(self, message_type, encoded_protobuf_message):
|
||||
pass # TODO implement
|
||||
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
|
||||
|
||||
|
||||
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
|
||||
print("session_context.load_cached_sessions")
|
||||
if __debug__:
|
||||
log.debug(__name__, "load_cached_sessions")
|
||||
sessions: dict[int, SessionContext] = {}
|
||||
cached_sessions = cache_thp.get_all_allocated_sessions()
|
||||
print(
|
||||
"session_context.load_cached_sessions - loaded a total of ",
|
||||
len(cached_sessions),
|
||||
"sessions from cache",
|
||||
)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"load_cached_sessions - loaded a total of %d sessions from cache",
|
||||
len(cached_sessions),
|
||||
)
|
||||
for session in cached_sessions:
|
||||
if session.channel_id == channel.channel_id:
|
||||
sid = int.from_bytes(session.session_id, "big")
|
||||
sessions[sid] = SessionContext(channel, session)
|
||||
for i in sessions:
|
||||
print("session", i)
|
||||
return sessions
|
||||
|
Loading…
Reference in New Issue
Block a user