From f75ee29ffa532f2fa0f638ad63938842bbc5edda Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 21 Mar 2024 11:57:36 +0100 Subject: [PATCH] feat(core): make changes to thp cache, part 1 --- core/src/storage/cache_thp.py | 167 +++++++++++++++++++++++++++------- 1 file changed, 132 insertions(+), 35 deletions(-) diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 2700fcee8..0962d0890 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -11,17 +11,53 @@ if TYPE_CHECKING: T = TypeVar("T") # THP specific constants +_MAX_CHANNELS_COUNT = 10 _MAX_SESSIONS_COUNT = const(20) -_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) -_THP_SESSION_STATE_LENGTH = const(1) +_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove + + +_CHANNEL_STATE_LENGTH = const(1) +_WIRE_INTERFACE_LENGTH = const(1) +_SESSION_STATE_LENGTH = const(1) +_CHANNEL_ID_LENGTH = const(4) _SESSION_ID_LENGTH = const(4) BROADCAST_CHANNEL_ID = const(65535) +_UNALLOCATED_STATE = const(0) + + +class ConnectionCache(DataCache): + def __init__(self) -> None: + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + self.last_usage = 0 + super().__init__() + + def clear(self) -> None: + self.channel_id[:] = b"" + self.last_usage = 0 + super().clear() -class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache + +class ChannelCache(ConnectionCache): + def __init__(self) -> None: + self.enc_key = 0 # TODO change + self.dec_key = 1 # TODO change + self.state = bytearray(_CHANNEL_STATE_LENGTH) + self.iface = bytearray(1) # TODO add decoding + super().__init__() + + def clear(self) -> None: + self.state[:] = bytearray( + int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big") + ) # Set state to UNALLOCATED + # TODO clear all sessions that are under this channel + super().clear() + + +class SessionThpCache(ConnectionCache): def __init__(self) -> None: self.session_id = bytearray(_SESSION_ID_LENGTH) - self.state = bytearray(_THP_SESSION_STATE_LENGTH) + self.state = bytearray(_SESSION_STATE_LENGTH) if utils.BITCOIN_ONLY: self.fields = ( 64, # APP_COMMON_SEED @@ -44,37 +80,35 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session self.last_usage = 0 super().__init__() - def export_session_id(self) -> bytes: - from trezorcrypto import random # avoid pulling in trezor.crypto - - # generate a new session id if we don't have it yet - if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) - # export it as immutable bytes - return bytes(self.session_id) - def clear(self) -> None: - super().clear() - self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED - self.last_usage = 0 + self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED self.session_id[:] = b"" + super().clear() +_CHANNELS: list[ChannelCache] = [] _SESSIONS: list[SessionThpCache] = [] -_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] +_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace def initialize() -> None: + global _CHANNELS global _SESSIONS global _UNAUTHENTICATED_SESSIONS + for _ in range(_MAX_CHANNELS_COUNT): + _CHANNELS.append(ChannelCache()) for _ in range(_MAX_SESSIONS_COUNT): _SESSIONS.append(SessionThpCache()) + for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): _UNAUTHENTICATED_SESSIONS.append(SessionThpCache()) + for channel in _CHANNELS: + channel.clear() for session in _SESSIONS: session.clear() + for session in _UNAUTHENTICATED_SESSIONS: session.clear() @@ -83,14 +117,71 @@ initialize() # THP vars -_next_unauthenicated_session_index: int = 0 +_next_unauthenicated_session_index: int = 0 # TODO remove + +# First unauthenticated channel will have index 0 _is_active_session_authenticated: bool _active_session_idx: int | None = None -_session_usage_counter = 0 - +_usage_counter = 0 # with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex) -cid_counter: int = 4659 +cid_counter: int = 4659 # TODO change to random value on start + + +def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache: + if len(iface) != _WIRE_INTERFACE_LENGTH: + raise Exception("Invalid WireInterface (encoded) length") + + new_cid = get_next_channel_id() + index = _get_next_unauthenticated_channel_index() + + _CHANNELS[index] = ChannelCache() + _CHANNELS[index].channel_id[:] = new_cid + _CHANNELS[index].last_usage = _get_usage_counter_and_increment() + _CHANNELS[index].state = bytearray( + _UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big") + ) + _CHANNELS[index].iface = bytearray(iface) + return _CHANNELS[index] + + +def get_all_allocated_channels() -> list[ChannelCache]: + _list: list[ChannelCache] = [] + for channel in _CHANNELS: + if _get_channel_state(channel) != _UNALLOCATED_STATE: + _list.append(channel) + return _list + + +def _get_usage_counter() -> int: + global _usage_counter + return _usage_counter + + +def _get_usage_counter_and_increment() -> int: + global _usage_counter + _usage_counter += 1 + return _usage_counter + + +def _get_next_unauthenticated_channel_index() -> int: + idx = _get_unallocated_channel_index() + if idx is not None: + return idx + return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT) + + +def _get_unallocated_channel_index() -> int | None: + for i in range(_MAX_CHANNELS_COUNT): + if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE: + return i + return None + + +def _get_channel_state(channel: ChannelCache) -> int: + if channel is None: + return _UNALLOCATED_STATE + return int.from_bytes(channel.state, "big") def get_active_session_id() -> bytearray | None: @@ -109,7 +200,10 @@ def get_active_session() -> SessionThpCache | None: return _UNAUTHENTICATED_SESSIONS[_active_session_idx] -def get_next_channel_id() -> int: +_session_usage_counter = 0 + + +def get_next_channel_id() -> bytes: global cid_counter while True: cid_counter += 1 @@ -117,7 +211,7 @@ def get_next_channel_id() -> int: cid_counter = 1 if _is_cid_unique(): break - return cid_counter + return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big") def _is_cid_unique() -> bool: @@ -160,8 +254,6 @@ def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None: def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: - global _session_usage_counter - unauth_session_idx = get_unauth_session_index(unauth_session) if unauth_session_idx is None: raise InvalidSessionError @@ -172,19 +264,24 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: _SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx] _UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear() - _session_usage_counter += 1 - _SESSIONS[new_auth_session_index].last_usage = _session_usage_counter + _SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment() return _SESSIONS[new_auth_session_index] def get_least_recently_used_authetnicated_session_index() -> int: - lru_counter = _session_usage_counter - lru_session_idx = 0 - for i in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[i].last_usage < lru_counter: - lru_counter = _SESSIONS[i].last_usage - lru_session_idx = i - return lru_session_idx + return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT) + + +def get_least_recently_used_item( + list: list[ChannelCache] | list[SessionThpCache], max_count: int +): + lru_counter = _get_usage_counter() + lru_item_index = 0 + for i in range(max_count): + if list[i].last_usage < lru_counter: + lru_counter = list[i].last_usage + lru_item_index = i + return lru_item_index # The function start_session should not be used in production code. It is present only to assure compatibility with old tests. @@ -205,7 +302,7 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete _active_session_idx = index _is_active_session_authenticated = False return session_id - new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big") + new_session_id = b"\x00\x00" + get_next_channel_id() new_session = create_new_unauthenticated_session(new_session_id)