1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

Cache changes, part 1

This commit is contained in:
M1nd3r 2024-03-21 11:57:36 +01:00
parent 3371d8177e
commit 51b3cd5626

View File

@ -11,15 +11,48 @@ if TYPE_CHECKING:
T = TypeVar("T")
# THP specific constants
_MAX_UNAUTHENTICATED_CHANNELS_COUNT = const(5)
_MAX_CHANNELS_COUNT = 10
_MAX_SESSIONS_COUNT = const(20)
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove
_THP_CHANNEL_STATE_LENGTH = const(1)
_THP_SESSION_STATE_LENGTH = const(1)
_CHANNEL_ID_LENGTH = const(4)
_SESSION_ID_LENGTH = const(4)
BROADCAST_CHANNEL_ID = const(65535)
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
class UnauthenticatedChannelCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.fields = ()
super().__init__()
def clear(self) -> None:
self.channel_id[:] = b""
super().clear()
class ChannelCache(UnauthenticatedChannelCache):
def __init__(self) -> None:
self.enc_key = 0 # TODO change
self.dec_key = 1 # TODO change
self.state = bytearray(_THP_CHANNEL_STATE_LENGTH)
self.last_usage = 0
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
super().__init__()
def clear(self) -> None:
self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.last_usage = 0
super().clear()
class SessionThpCache(DataCache):
def __init__(self) -> None:
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
if utils.BITCOIN_ONLY:
@ -44,37 +77,43 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
from trezorcrypto import random # avoid pulling in trezor.crypto
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
self.last_usage = 0
self.session_id[:] = b""
self.channel_id[:] = b""
super().clear()
_UNAUTHENTICATED_CHANNELS: list[UnauthenticatedChannelCache] = []
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace
def initialize() -> None:
global _UNAUTHENTICATED_CHANNELS
global _CHANNELS
global _SESSIONS
global _UNAUTHENTICATED_SESSIONS
for _ in range(_MAX_UNAUTHENTICATED_CHANNELS_COUNT):
_UNAUTHENTICATED_CHANNELS.append(UnauthenticatedChannelCache())
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
for unauth_channel in _UNAUTHENTICATED_CHANNELS:
unauth_channel.clear()
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
session.clear()
for session in _UNAUTHENTICATED_SESSIONS:
session.clear()
@ -90,7 +129,7 @@ _session_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
cid_counter: int = 4659
cid_counter: int = 4659 # TODO change to random value on start
def get_active_session_id() -> bytearray | None: