diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 2700fcee8..88ce04063 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -11,15 +11,48 @@ if TYPE_CHECKING: T = TypeVar("T") # THP specific constants +_MAX_UNAUTHENTICATED_CHANNELS_COUNT = const(5) +_MAX_CHANNELS_COUNT = 10 _MAX_SESSIONS_COUNT = const(20) -_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) +_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove + + +_THP_CHANNEL_STATE_LENGTH = const(1) _THP_SESSION_STATE_LENGTH = const(1) +_CHANNEL_ID_LENGTH = const(4) _SESSION_ID_LENGTH = const(4) BROADCAST_CHANNEL_ID = const(65535) -class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache +class UnauthenticatedChannelCache(DataCache): def __init__(self) -> None: + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + self.fields = () + super().__init__() + + def clear(self) -> None: + self.channel_id[:] = b"" + super().clear() + + +class ChannelCache(UnauthenticatedChannelCache): + def __init__(self) -> None: + self.enc_key = 0 # TODO change + self.dec_key = 1 # TODO change + self.state = bytearray(_THP_CHANNEL_STATE_LENGTH) + self.last_usage = 0 + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) + super().__init__() + + def clear(self) -> None: + self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED + self.last_usage = 0 + super().clear() + + +class SessionThpCache(DataCache): + def __init__(self) -> None: + self.channel_id = bytearray(_CHANNEL_ID_LENGTH) self.session_id = bytearray(_SESSION_ID_LENGTH) self.state = bytearray(_THP_SESSION_STATE_LENGTH) if utils.BITCOIN_ONLY: @@ -44,37 +77,43 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session self.last_usage = 0 super().__init__() - def export_session_id(self) -> bytes: - from trezorcrypto import random # avoid pulling in trezor.crypto - - # generate a new session id if we don't have it yet - if not self.session_id: - self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) - # export it as immutable bytes - return bytes(self.session_id) - def clear(self) -> None: - super().clear() self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED self.last_usage = 0 self.session_id[:] = b"" + self.channel_id[:] = b"" + super().clear() +_UNAUTHENTICATED_CHANNELS: list[UnauthenticatedChannelCache] = [] +_CHANNELS: list[ChannelCache] = [] _SESSIONS: list[SessionThpCache] = [] -_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] +_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace def initialize() -> None: + global _UNAUTHENTICATED_CHANNELS + global _CHANNELS global _SESSIONS global _UNAUTHENTICATED_SESSIONS + for _ in range(_MAX_UNAUTHENTICATED_CHANNELS_COUNT): + _UNAUTHENTICATED_CHANNELS.append(UnauthenticatedChannelCache()) + for _ in range(_MAX_CHANNELS_COUNT): + _CHANNELS.append(ChannelCache()) for _ in range(_MAX_SESSIONS_COUNT): _SESSIONS.append(SessionThpCache()) + for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): _UNAUTHENTICATED_SESSIONS.append(SessionThpCache()) + for unauth_channel in _UNAUTHENTICATED_CHANNELS: + unauth_channel.clear() + for channel in _CHANNELS: + channel.clear() for session in _SESSIONS: session.clear() + for session in _UNAUTHENTICATED_SESSIONS: session.clear() @@ -90,7 +129,7 @@ _session_usage_counter = 0 # with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex) -cid_counter: int = 4659 +cid_counter: int = 4659 # TODO change to random value on start def get_active_session_id() -> bytearray | None: