1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-21 21:00:58 +00:00

chore(core): improve loading of sessions from cache, add docstrings

[no changelog]
This commit is contained in:
M1nd3r 2024-11-29 18:18:13 +01:00
parent e50521fc54
commit a2e6204db7
2 changed files with 39 additions and 26 deletions

View File

@ -9,10 +9,6 @@ if TYPE_CHECKING:
pass
if __debug__:
from trezor import log
pass
# THP specific constants
_MAX_CHANNELS_COUNT = const(10)
@ -175,25 +171,31 @@ def get_all_allocated_channels() -> list[ChannelCache]:
return _list
def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]:
if __debug__:
from trezor.utils import get_bytes_as_str
_list: list[SessionThpCache] = []
def get_allocated_session(
channel_id: bytes, session_id: bytes
) -> SessionThpCache | None:
"""
Finds and returns the first allocated session matching the given `channel_id` and `session_id`,
or `None` if no match is found.
Raises `Exception` if either channel_id or session_id has an invalid length.
"""
if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH:
raise Exception("At least one of arguments has invalid length")
for session in _SESSIONS:
if _get_session_state(session) == _UNALLOCATED_STATE:
continue
if session.channel_id != channel_id:
continue
_list.append(session)
if __debug__:
log.debug(
__name__,
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
get_bytes_as_str(session.channel_id),
get_bytes_as_str(session.session_id),
)
if session.session_id != session_id:
continue
return session
return None
return _list
def is_management_session(session_cache: SessionThpCache) -> bool:
return _get_session_state(session_cache) == _MANAGEMENT_STATE
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:

View File

@ -13,6 +13,9 @@ if TYPE_CHECKING:
def create_new_session(channel_ctx: Channel) -> SessionContext:
"""
Creates new `SessionContext` backed by cache.
"""
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
return SessionContext(channel_ctx, session_cache)
@ -20,18 +23,26 @@ def create_new_session(channel_ctx: Channel) -> SessionContext:
def create_new_management_session(
channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID
) -> ManagementSessionContext:
"""
Creates new `ManagementSessionContext` that is not backed by cache entry.
Seed cannot be derived with this type of session.
"""
return ManagementSessionContext(channel_ctx, session_id)
def get_session_from_cache(
channel_ctx: Channel, session_id: int
) -> GenericSessionContext | None:
cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id)
for s in cached_sessions:
print(s, s.channel_id, int.from_bytes(s.session_id, "big"))
if (
s.channel_id == channel_ctx.channel_id
and int.from_bytes(s.session_id, "big") == session_id
):
return SessionContext(channel_ctx, s)
"""
Returns a `SessionContext` (or `ManagementSessionContext`) reconstructed from a cache or `None` if backing cache is not found.
"""
session_id_bytes = session_id.to_bytes(1, "big")
session_cache = cache_thp.get_allocated_session(
channel_ctx.channel_id, session_id_bytes
)
if session_cache is None:
return None
elif cache_thp.is_management_session(session_cache):
return ManagementSessionContext(channel_ctx, session_id)
return SessionContext(channel_ctx, session_cache)