|
|
|
@ -2,7 +2,7 @@ import builtins
|
|
|
|
|
from micropython import const # pyright: ignore[reportMissingModuleSource]
|
|
|
|
|
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
|
|
|
|
|
|
|
|
|
from storage.cache_common import DataCache, InvalidSessionError
|
|
|
|
|
from storage.cache_common import DataCache
|
|
|
|
|
from trezor import utils
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
@ -96,30 +96,22 @@ class SessionThpCache(ConnectionCache):
|
|
|
|
|
|
|
|
|
|
_CHANNELS: list[ChannelCache] = []
|
|
|
|
|
_SESSIONS: list[SessionThpCache] = []
|
|
|
|
|
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize() -> None:
|
|
|
|
|
global _CHANNELS
|
|
|
|
|
global _SESSIONS
|
|
|
|
|
global _UNAUTHENTICATED_SESSIONS
|
|
|
|
|
|
|
|
|
|
for _ in range(_MAX_CHANNELS_COUNT):
|
|
|
|
|
_CHANNELS.append(ChannelCache())
|
|
|
|
|
for _ in range(_MAX_SESSIONS_COUNT):
|
|
|
|
|
_SESSIONS.append(SessionThpCache())
|
|
|
|
|
|
|
|
|
|
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
|
|
|
|
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
|
|
|
|
|
|
|
|
|
|
for channel in _CHANNELS:
|
|
|
|
|
channel.clear()
|
|
|
|
|
for session in _SESSIONS:
|
|
|
|
|
session.clear()
|
|
|
|
|
|
|
|
|
|
for session in _UNAUTHENTICATED_SESSIONS:
|
|
|
|
|
session.clear()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
initialize()
|
|
|
|
|
|
|
|
|
@ -128,8 +120,6 @@ initialize()
|
|
|
|
|
_next_unauthenicated_session_index: int = 0 # TODO remove
|
|
|
|
|
|
|
|
|
|
# First unauthenticated channel will have index 0
|
|
|
|
|
_is_active_session_authenticated: bool
|
|
|
|
|
_active_session_idx: int | None = None
|
|
|
|
|
_usage_counter = 0
|
|
|
|
|
|
|
|
|
|
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
|
|
|
|
@ -256,22 +246,6 @@ def _get_session_state(session: SessionThpCache) -> int:
|
|
|
|
|
return int.from_bytes(session.state, "big")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_active_session_id() -> bytearray | None:
|
|
|
|
|
active_session = get_active_session()
|
|
|
|
|
|
|
|
|
|
if active_session is None:
|
|
|
|
|
return None
|
|
|
|
|
return active_session.session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_active_session() -> SessionThpCache | None:
|
|
|
|
|
if _active_session_idx is None:
|
|
|
|
|
return None
|
|
|
|
|
if _is_active_session_authenticated:
|
|
|
|
|
return _SESSIONS[_active_session_idx]
|
|
|
|
|
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_next_channel_id() -> bytes:
|
|
|
|
|
global cid_counter
|
|
|
|
|
while True:
|
|
|
|
@ -304,7 +278,7 @@ def _is_session_id_unique(channel: ChannelCache) -> bool:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_cid_unique() -> bool:
|
|
|
|
|
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
|
|
|
|
for session in _SESSIONS:
|
|
|
|
|
if cid_counter == _get_cid(session):
|
|
|
|
|
return False
|
|
|
|
|
return True
|
|
|
|
@ -314,53 +288,6 @@ def _get_cid(session: SessionThpCache) -> int:
|
|
|
|
|
return int.from_bytes(session.session_id[2:], "big")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
|
|
|
|
|
if len(session_id) != SESSION_ID_LENGTH:
|
|
|
|
|
raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH)
|
|
|
|
|
global _active_session_idx
|
|
|
|
|
global _is_active_session_authenticated
|
|
|
|
|
global _next_unauthenicated_session_index
|
|
|
|
|
|
|
|
|
|
i = _next_unauthenicated_session_index
|
|
|
|
|
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
|
|
|
|
|
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
|
|
|
|
|
_next_unauthenicated_session_index += 1
|
|
|
|
|
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
|
|
|
|
|
_next_unauthenicated_session_index = 0
|
|
|
|
|
|
|
|
|
|
# Set session as active if and only if there is no active session
|
|
|
|
|
if _active_session_idx is None:
|
|
|
|
|
_active_session_idx = i
|
|
|
|
|
_is_active_session_authenticated = False
|
|
|
|
|
return _UNAUTHENTICATED_SESSIONS[i]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
|
|
|
|
|
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
|
|
|
|
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
|
|
|
|
|
return i
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
|
|
|
|
unauth_session_idx = get_unauth_session_index(unauth_session)
|
|
|
|
|
if unauth_session_idx is None:
|
|
|
|
|
raise InvalidSessionError
|
|
|
|
|
|
|
|
|
|
# replace least recently used authenticated session by the new session
|
|
|
|
|
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
|
|
|
|
|
|
|
|
|
|
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
|
|
|
|
|
_UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear()
|
|
|
|
|
|
|
|
|
|
_SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment()
|
|
|
|
|
return _SESSIONS[new_auth_session_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_least_recently_used_authetnicated_session_index() -> int:
|
|
|
|
|
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_least_recently_used_item(
|
|
|
|
|
list: list[ChannelCache] | list[SessionThpCache], max_count: int
|
|
|
|
|
):
|
|
|
|
@ -373,71 +300,9 @@ def get_least_recently_used_item(
|
|
|
|
|
return lru_item_index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
|
|
|
|
|
def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
|
|
|
|
|
global _active_session_idx
|
|
|
|
|
global _is_active_session_authenticated
|
|
|
|
|
|
|
|
|
|
if session_id is not None:
|
|
|
|
|
if get_active_session_id() == session_id:
|
|
|
|
|
return session_id
|
|
|
|
|
for index in range(_MAX_SESSIONS_COUNT):
|
|
|
|
|
if _SESSIONS[index].session_id == session_id:
|
|
|
|
|
_active_session_idx = index
|
|
|
|
|
_is_active_session_authenticated = True
|
|
|
|
|
return session_id
|
|
|
|
|
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
|
|
|
|
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
|
|
|
|
_active_session_idx = index
|
|
|
|
|
_is_active_session_authenticated = False
|
|
|
|
|
return session_id
|
|
|
|
|
|
|
|
|
|
channel = get_new_unauthenticated_channel(b"\x00")
|
|
|
|
|
|
|
|
|
|
new_session_id = get_next_session_id(channel)
|
|
|
|
|
|
|
|
|
|
new_session = create_new_unauthenticated_session(new_session_id)
|
|
|
|
|
|
|
|
|
|
index = get_unauth_session_index(new_session)
|
|
|
|
|
_active_session_idx = index
|
|
|
|
|
_is_active_session_authenticated = False
|
|
|
|
|
|
|
|
|
|
return new_session_id
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def start_existing_session(session_id: bytes) -> bytes:
|
|
|
|
|
global _active_session_idx
|
|
|
|
|
global _is_active_session_authenticated
|
|
|
|
|
|
|
|
|
|
if session_id is None:
|
|
|
|
|
raise ValueError("session_id cannot be None")
|
|
|
|
|
if get_active_session_id() == session_id:
|
|
|
|
|
return session_id
|
|
|
|
|
for index in range(_MAX_SESSIONS_COUNT):
|
|
|
|
|
if _SESSIONS[index].session_id == session_id:
|
|
|
|
|
_active_session_idx = index
|
|
|
|
|
_is_active_session_authenticated = True
|
|
|
|
|
return session_id
|
|
|
|
|
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
|
|
|
|
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
|
|
|
|
|
_active_session_idx = index
|
|
|
|
|
_is_active_session_authenticated = False
|
|
|
|
|
return session_id
|
|
|
|
|
raise ValueError("There is no active session with provided session_id")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def end_current_session() -> None:
|
|
|
|
|
global _active_session_idx
|
|
|
|
|
active_session = get_active_session()
|
|
|
|
|
if active_session is None:
|
|
|
|
|
return
|
|
|
|
|
active_session.clear()
|
|
|
|
|
_active_session_idx = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_int_all_sessions(key: int) -> builtins.set[int]:
|
|
|
|
|
values = builtins.set()
|
|
|
|
|
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
|
|
|
|
|
for session in _SESSIONS:
|
|
|
|
|
encoded = session.get(key)
|
|
|
|
|
if encoded is not None:
|
|
|
|
|
values.add(int.from_bytes(encoded, "big"))
|
|
|
|
@ -445,7 +310,5 @@ def get_int_all_sessions(key: int) -> builtins.set[int]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_all() -> None:
|
|
|
|
|
global _active_session_idx
|
|
|
|
|
_active_session_idx = None
|
|
|
|
|
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
|
|
|
|
|
for session in _SESSIONS:
|
|
|
|
|
session.clear()
|
|
|
|
|