mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
feat(core): make changes to thp cache, part 1
This commit is contained in:
parent
fb99d1dbe6
commit
f75ee29ffa
@ -11,17 +11,53 @@ if TYPE_CHECKING:
|
||||
T = TypeVar("T")
|
||||
|
||||
# THP specific constants
|
||||
_MAX_CHANNELS_COUNT = 10
|
||||
_MAX_SESSIONS_COUNT = const(20)
|
||||
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5)
|
||||
_THP_SESSION_STATE_LENGTH = const(1)
|
||||
_MAX_UNAUTHENTICATED_SESSIONS_COUNT = const(5) # TODO remove
|
||||
|
||||
|
||||
_CHANNEL_STATE_LENGTH = const(1)
|
||||
_WIRE_INTERFACE_LENGTH = const(1)
|
||||
_SESSION_STATE_LENGTH = const(1)
|
||||
_CHANNEL_ID_LENGTH = const(4)
|
||||
_SESSION_ID_LENGTH = const(4)
|
||||
BROADCAST_CHANNEL_ID = const(65535)
|
||||
|
||||
_UNALLOCATED_STATE = const(0)
|
||||
|
||||
class SessionThpCache(DataCache): # TODO implement, this is just copied SessionCache
|
||||
|
||||
class ConnectionCache(DataCache):
|
||||
def __init__(self) -> None:
|
||||
self.channel_id = bytearray(_CHANNEL_ID_LENGTH)
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.channel_id[:] = b""
|
||||
self.last_usage = 0
|
||||
super().clear()
|
||||
|
||||
|
||||
class ChannelCache(ConnectionCache):
|
||||
def __init__(self) -> None:
|
||||
self.enc_key = 0 # TODO change
|
||||
self.dec_key = 1 # TODO change
|
||||
self.state = bytearray(_CHANNEL_STATE_LENGTH)
|
||||
self.iface = bytearray(1) # TODO add decoding
|
||||
super().__init__()
|
||||
|
||||
def clear(self) -> None:
|
||||
self.state[:] = bytearray(
|
||||
int.to_bytes(0, _CHANNEL_STATE_LENGTH, "big")
|
||||
) # Set state to UNALLOCATED
|
||||
# TODO clear all sessions that are under this channel
|
||||
super().clear()
|
||||
|
||||
|
||||
class SessionThpCache(ConnectionCache):
|
||||
def __init__(self) -> None:
|
||||
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||
self.state = bytearray(_THP_SESSION_STATE_LENGTH)
|
||||
self.state = bytearray(_SESSION_STATE_LENGTH)
|
||||
if utils.BITCOIN_ONLY:
|
||||
self.fields = (
|
||||
64, # APP_COMMON_SEED
|
||||
@ -44,37 +80,35 @@ class SessionThpCache(DataCache): # TODO implement, this is just copied Session
|
||||
self.last_usage = 0
|
||||
super().__init__()
|
||||
|
||||
def export_session_id(self) -> bytes:
|
||||
from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
|
||||
# generate a new session id if we don't have it yet
|
||||
if not self.session_id:
|
||||
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||
# export it as immutable bytes
|
||||
return bytes(self.session_id)
|
||||
|
||||
def clear(self) -> None:
|
||||
super().clear()
|
||||
self.state = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
|
||||
self.last_usage = 0
|
||||
self.state[:] = bytearray(int.to_bytes(0, 1, "big")) # Set state to UNALLOCATED
|
||||
self.session_id[:] = b""
|
||||
super().clear()
|
||||
|
||||
|
||||
_CHANNELS: list[ChannelCache] = []
|
||||
_SESSIONS: list[SessionThpCache] = []
|
||||
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = []
|
||||
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace
|
||||
|
||||
|
||||
def initialize() -> None:
|
||||
global _CHANNELS
|
||||
global _SESSIONS
|
||||
global _UNAUTHENTICATED_SESSIONS
|
||||
|
||||
for _ in range(_MAX_CHANNELS_COUNT):
|
||||
_CHANNELS.append(ChannelCache())
|
||||
for _ in range(_MAX_SESSIONS_COUNT):
|
||||
_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
|
||||
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
|
||||
|
||||
for channel in _CHANNELS:
|
||||
channel.clear()
|
||||
for session in _SESSIONS:
|
||||
session.clear()
|
||||
|
||||
for session in _UNAUTHENTICATED_SESSIONS:
|
||||
session.clear()
|
||||
|
||||
@ -83,14 +117,71 @@ initialize()
|
||||
|
||||
|
||||
# THP vars
|
||||
_next_unauthenicated_session_index: int = 0
|
||||
_next_unauthenicated_session_index: int = 0 # TODO remove
|
||||
|
||||
# First unauthenticated channel will have index 0
|
||||
_is_active_session_authenticated: bool
|
||||
_active_session_idx: int | None = None
|
||||
_session_usage_counter = 0
|
||||
|
||||
_usage_counter = 0
|
||||
|
||||
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
|
||||
cid_counter: int = 4659
|
||||
cid_counter: int = 4659 # TODO change to random value on start
|
||||
|
||||
|
||||
def get_new_unauthenticated_channel(iface: bytes) -> ChannelCache:
|
||||
if len(iface) != _WIRE_INTERFACE_LENGTH:
|
||||
raise Exception("Invalid WireInterface (encoded) length")
|
||||
|
||||
new_cid = get_next_channel_id()
|
||||
index = _get_next_unauthenticated_channel_index()
|
||||
|
||||
_CHANNELS[index] = ChannelCache()
|
||||
_CHANNELS[index].channel_id[:] = new_cid
|
||||
_CHANNELS[index].last_usage = _get_usage_counter_and_increment()
|
||||
_CHANNELS[index].state = bytearray(
|
||||
_UNALLOCATED_STATE.to_bytes(_CHANNEL_STATE_LENGTH, "big")
|
||||
)
|
||||
_CHANNELS[index].iface = bytearray(iface)
|
||||
return _CHANNELS[index]
|
||||
|
||||
|
||||
def get_all_allocated_channels() -> list[ChannelCache]:
|
||||
_list: list[ChannelCache] = []
|
||||
for channel in _CHANNELS:
|
||||
if _get_channel_state(channel) != _UNALLOCATED_STATE:
|
||||
_list.append(channel)
|
||||
return _list
|
||||
|
||||
|
||||
def _get_usage_counter() -> int:
|
||||
global _usage_counter
|
||||
return _usage_counter
|
||||
|
||||
|
||||
def _get_usage_counter_and_increment() -> int:
|
||||
global _usage_counter
|
||||
_usage_counter += 1
|
||||
return _usage_counter
|
||||
|
||||
|
||||
def _get_next_unauthenticated_channel_index() -> int:
|
||||
idx = _get_unallocated_channel_index()
|
||||
if idx is not None:
|
||||
return idx
|
||||
return get_least_recently_used_item(_CHANNELS, max_count=_MAX_CHANNELS_COUNT)
|
||||
|
||||
|
||||
def _get_unallocated_channel_index() -> int | None:
|
||||
for i in range(_MAX_CHANNELS_COUNT):
|
||||
if _get_channel_state(_CHANNELS[i]) is _UNALLOCATED_STATE:
|
||||
return i
|
||||
return None
|
||||
|
||||
|
||||
def _get_channel_state(channel: ChannelCache) -> int:
|
||||
if channel is None:
|
||||
return _UNALLOCATED_STATE
|
||||
return int.from_bytes(channel.state, "big")
|
||||
|
||||
|
||||
def get_active_session_id() -> bytearray | None:
|
||||
@ -109,7 +200,10 @@ def get_active_session() -> SessionThpCache | None:
|
||||
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
|
||||
|
||||
|
||||
def get_next_channel_id() -> int:
|
||||
_session_usage_counter = 0
|
||||
|
||||
|
||||
def get_next_channel_id() -> bytes:
|
||||
global cid_counter
|
||||
while True:
|
||||
cid_counter += 1
|
||||
@ -117,7 +211,7 @@ def get_next_channel_id() -> int:
|
||||
cid_counter = 1
|
||||
if _is_cid_unique():
|
||||
break
|
||||
return cid_counter
|
||||
return cid_counter.to_bytes(_CHANNEL_ID_LENGTH, "big")
|
||||
|
||||
|
||||
def _is_cid_unique() -> bool:
|
||||
@ -160,8 +254,6 @@ def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
|
||||
|
||||
|
||||
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
||||
global _session_usage_counter
|
||||
|
||||
unauth_session_idx = get_unauth_session_index(unauth_session)
|
||||
if unauth_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
@ -172,19 +264,24 @@ def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
|
||||
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
|
||||
_UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear()
|
||||
|
||||
_session_usage_counter += 1
|
||||
_SESSIONS[new_auth_session_index].last_usage = _session_usage_counter
|
||||
_SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment()
|
||||
return _SESSIONS[new_auth_session_index]
|
||||
|
||||
|
||||
def get_least_recently_used_authetnicated_session_index() -> int:
|
||||
lru_counter = _session_usage_counter
|
||||
lru_session_idx = 0
|
||||
for i in range(_MAX_SESSIONS_COUNT):
|
||||
if _SESSIONS[i].last_usage < lru_counter:
|
||||
lru_counter = _SESSIONS[i].last_usage
|
||||
lru_session_idx = i
|
||||
return lru_session_idx
|
||||
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
|
||||
|
||||
|
||||
def get_least_recently_used_item(
|
||||
list: list[ChannelCache] | list[SessionThpCache], max_count: int
|
||||
):
|
||||
lru_counter = _get_usage_counter()
|
||||
lru_item_index = 0
|
||||
for i in range(max_count):
|
||||
if list[i].last_usage < lru_counter:
|
||||
lru_counter = list[i].last_usage
|
||||
lru_item_index = i
|
||||
return lru_item_index
|
||||
|
||||
|
||||
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
|
||||
@ -205,7 +302,7 @@ def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
|
||||
_active_session_idx = index
|
||||
_is_active_session_authenticated = False
|
||||
return session_id
|
||||
new_session_id = b"\x00\x00" + get_next_channel_id().to_bytes(2, "big")
|
||||
new_session_id = b"\x00\x00" + get_next_channel_id()
|
||||
|
||||
new_session = create_new_unauthenticated_session(new_session_id)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user