diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 1b530f09bf..6ed41b8415 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -9,10 +9,6 @@ if TYPE_CHECKING: pass -if __debug__: - from trezor import log - - pass # THP specific constants _MAX_CHANNELS_COUNT = const(10) @@ -175,25 +171,31 @@ def get_all_allocated_channels() -> list[ChannelCache]: return _list -def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]: - if __debug__: - from trezor.utils import get_bytes_as_str - _list: list[SessionThpCache] = [] +def get_allocated_session( + channel_id: bytes, session_id: bytes +) -> SessionThpCache | None: + """ + Finds and returns the first allocated session matching the given `channel_id` and `session_id`, + or `None` if no match is found. + + Raises `Exception` if either channel_id or session_id has an invalid length. + """ + if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH: + raise Exception("At least one of arguments has invalid length") + for session in _SESSIONS: if _get_session_state(session) == _UNALLOCATED_STATE: continue if session.channel_id != channel_id: continue - _list.append(session) - if __debug__: - log.debug( - __name__, - "session with channel_id: %s and session_id: %s is in ALLOCATED state", - get_bytes_as_str(session.channel_id), - get_bytes_as_str(session.session_id), - ) + if session.session_id != session_id: + continue + return session + return None - return _list + +def is_management_session(session_cache: SessionThpCache) -> bool: + return _get_session_state(session_cache) == _MANAGEMENT_STATE def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None: diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py index d7ab1762d6..3377ce437f 100644 --- a/core/src/trezor/wire/thp/session_manager.py +++ b/core/src/trezor/wire/thp/session_manager.py @@ -13,6 +13,9 @@ if TYPE_CHECKING: def create_new_session(channel_ctx: Channel) -> SessionContext: + """ + Creates new `SessionContext` backed by cache. + """ session_cache = cache_thp.get_new_session(channel_ctx.channel_cache) return SessionContext(channel_ctx, session_cache) @@ -20,18 +23,26 @@ def create_new_session(channel_ctx: Channel) -> SessionContext: def create_new_management_session( channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID ) -> ManagementSessionContext: + """ + Creates new `ManagementSessionContext` that is not backed by cache entry. + + Seed cannot be derived with this type of session. + """ return ManagementSessionContext(channel_ctx, session_id) def get_session_from_cache( channel_ctx: Channel, session_id: int ) -> GenericSessionContext | None: - cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id) - for s in cached_sessions: - print(s, s.channel_id, int.from_bytes(s.session_id, "big")) - if ( - s.channel_id == channel_ctx.channel_id - and int.from_bytes(s.session_id, "big") == session_id - ): - return SessionContext(channel_ctx, s) - return None + """ + Returns a `SessionContext` (or `ManagementSessionContext`) reconstructed from a cache or `None` if backing cache is not found. + """ + session_id_bytes = session_id.to_bytes(1, "big") + session_cache = cache_thp.get_allocated_session( + channel_ctx.channel_id, session_id_bytes + ) + if session_cache is None: + return None + elif cache_thp.is_management_session(session_cache): + return ManagementSessionContext(channel_ctx, session_id) + return SessionContext(channel_ctx, session_cache)