mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-21 21:00:58 +00:00
chore(core): improve loading of sessions from cache, add docstrings
[no changelog]
This commit is contained in:
parent
e50521fc54
commit
a2e6204db7
@ -9,10 +9,6 @@ if TYPE_CHECKING:
|
||||
|
||||
pass
|
||||
|
||||
if __debug__:
|
||||
from trezor import log
|
||||
|
||||
pass
|
||||
|
||||
# THP specific constants
|
||||
_MAX_CHANNELS_COUNT = const(10)
|
||||
@ -175,25 +171,31 @@ def get_all_allocated_channels() -> list[ChannelCache]:
|
||||
return _list
|
||||
|
||||
|
||||
def get_allocated_sessions(channel_id: bytes) -> list[SessionThpCache]:
|
||||
if __debug__:
|
||||
from trezor.utils import get_bytes_as_str
|
||||
_list: list[SessionThpCache] = []
|
||||
def get_allocated_session(
|
||||
channel_id: bytes, session_id: bytes
|
||||
) -> SessionThpCache | None:
|
||||
"""
|
||||
Finds and returns the first allocated session matching the given `channel_id` and `session_id`,
|
||||
or `None` if no match is found.
|
||||
|
||||
Raises `Exception` if either channel_id or session_id has an invalid length.
|
||||
"""
|
||||
if len(channel_id) != _CHANNEL_ID_LENGTH or len(session_id) != SESSION_ID_LENGTH:
|
||||
raise Exception("At least one of arguments has invalid length")
|
||||
|
||||
for session in _SESSIONS:
|
||||
if _get_session_state(session) == _UNALLOCATED_STATE:
|
||||
continue
|
||||
if session.channel_id != channel_id:
|
||||
continue
|
||||
_list.append(session)
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"session with channel_id: %s and session_id: %s is in ALLOCATED state",
|
||||
get_bytes_as_str(session.channel_id),
|
||||
get_bytes_as_str(session.session_id),
|
||||
)
|
||||
if session.session_id != session_id:
|
||||
continue
|
||||
return session
|
||||
return None
|
||||
|
||||
return _list
|
||||
|
||||
def is_management_session(session_cache: SessionThpCache) -> bool:
|
||||
return _get_session_state(session_cache) == _MANAGEMENT_STATE
|
||||
|
||||
|
||||
def set_channel_host_ephemeral_key(channel: ChannelCache, key: bytearray) -> None:
|
||||
|
@ -13,6 +13,9 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
def create_new_session(channel_ctx: Channel) -> SessionContext:
|
||||
"""
|
||||
Creates new `SessionContext` backed by cache.
|
||||
"""
|
||||
session_cache = cache_thp.get_new_session(channel_ctx.channel_cache)
|
||||
return SessionContext(channel_ctx, session_cache)
|
||||
|
||||
@ -20,18 +23,26 @@ def create_new_session(channel_ctx: Channel) -> SessionContext:
|
||||
def create_new_management_session(
|
||||
channel_ctx: Channel, session_id: int = cache_thp.MANAGEMENT_SESSION_ID
|
||||
) -> ManagementSessionContext:
|
||||
"""
|
||||
Creates new `ManagementSessionContext` that is not backed by cache entry.
|
||||
|
||||
Seed cannot be derived with this type of session.
|
||||
"""
|
||||
return ManagementSessionContext(channel_ctx, session_id)
|
||||
|
||||
|
||||
def get_session_from_cache(
|
||||
channel_ctx: Channel, session_id: int
|
||||
) -> GenericSessionContext | None:
|
||||
cached_sessions = cache_thp.get_allocated_sessions(channel_ctx.channel_id)
|
||||
for s in cached_sessions:
|
||||
print(s, s.channel_id, int.from_bytes(s.session_id, "big"))
|
||||
if (
|
||||
s.channel_id == channel_ctx.channel_id
|
||||
and int.from_bytes(s.session_id, "big") == session_id
|
||||
):
|
||||
return SessionContext(channel_ctx, s)
|
||||
"""
|
||||
Returns a `SessionContext` (or `ManagementSessionContext`) reconstructed from a cache or `None` if backing cache is not found.
|
||||
"""
|
||||
session_id_bytes = session_id.to_bytes(1, "big")
|
||||
session_cache = cache_thp.get_allocated_session(
|
||||
channel_ctx.channel_id, session_id_bytes
|
||||
)
|
||||
if session_cache is None:
|
||||
return None
|
||||
elif cache_thp.is_management_session(session_cache):
|
||||
return ManagementSessionContext(channel_ctx, session_id)
|
||||
return SessionContext(channel_ctx, session_cache)
|
||||
|
Loading…
Reference in New Issue
Block a user