You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/src/storage/cache.py

304 lines
9.1 KiB

import gc
from trezorcrypto import random # avoid pulling in trezor.crypto
from trezor import utils
# Typing-only names. The `if False:` branch is never executed at runtime;
# presumably it exists so that static type checkers still see these names
# without the firmware having to import `typing`.
if False:
    from typing import Sequence, TypeVar, overload
    T = TypeVar("T")
else:
    # Runtime stand-in for `typing.overload`: it ignores the decorated stub and
    # returns None. The real definition that follows each group of `@overload`
    # stubs then rebinds the name.
    def overload(f) -> None:  # type: ignore
        pass
# Number of SessionCache slots that are preallocated at import time.
_MAX_SESSIONS_COUNT = 10
# Bit that marks a key as belonging to the sessionless cache. The module-level
# accessors strip this bit before indexing into the sessionless cache.
_SESSIONLESS_FLAG = 128
# Length of a session id in bytes.
_SESSION_ID_LENGTH = 32
# Traditional cache keys
# Each value is an index into `SessionCache.fields`; these keys require an
# active session.
APP_COMMON_SEED = 0
APP_COMMON_AUTHORIZATION_TYPE = 1
APP_COMMON_AUTHORIZATION_DATA = 2
APP_COMMON_DERIVE_CARDANO = 3
APP_CARDANO_ICARUS_SECRET = 4
APP_CARDANO_ICARUS_TREZOR_SECRET = 5
APP_MONERO_LIVE_REFRESH = 6
# Keys that are valid across sessions
# With `_SESSIONLESS_FLAG` removed, each value is an index into
# `SessionlessCache.fields`.
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
APP_COMMON_SAFETY_CHECKS_TEMPORARY = 1 | _SESSIONLESS_FLAG
STORAGE_DEVICE_EXPERIMENTAL_FEATURES = 2 | _SESSIONLESS_FLAG
APP_COMMON_REQUEST_PIN_LAST_UNLOCK = 3 | _SESSIONLESS_FLAG
# === Homescreen storage ===
# This does not logically belong to the "cache" functionality, but the cache module is
# a convenient place to put this.
# When a Homescreen layout is instantiated, it checks the value of `homescreen_shown`
# to know whether it should render itself or whether the result of a previous instance
# is still on. This way we can avoid unnecessary fadeins/fadeouts when a workflow ends.

# Distinct sentinel objects that `homescreen_shown` can be set to.
HOMESCREEN_ON = object()
LOCKSCREEN_ON = object()
# Starts as None; presumably that means neither of the two screens is currently
# shown -- TODO confirm against the layout code.
homescreen_shown: object | None = None
class InvalidSessionError(Exception):
    """Raised when a session-bound cache key is accessed with no active session."""
class DataCache:
    """Fixed-layout byte cache backed by preallocated bytearrays.

    A subclass must assign `fields` (the maximum value length in bytes for each
    key) before calling `DataCache.__init__`. Every entry is a bytearray whose
    first byte is a "value present" flag (1 when set) followed by the value.
    """

    fields: Sequence[int]

    def __init__(self) -> None:
        # One extra byte per entry holds the "value present" flag.
        self.data = [bytearray(size + 1) for size in self.fields]

    def set(self, key: int, value: bytes) -> None:
        """Store `value` under `key`; it must fit the size declared in `fields`."""
        utils.ensure(key < len(self.fields))
        utils.ensure(len(value) <= self.fields[key])
        entry = self.data[key]
        # Mark the entry as present first, then replace everything after the flag.
        entry[0] = 1
        entry[1:] = value

    @overload
    def get(self, key: int) -> bytes | None:
        ...

    @overload
    def get(self, key: int, default: T) -> bytes | T:  # noqa: F811
        ...

    def get(self, key: int, default: T | None = None) -> bytes | T | None:  # noqa: F811
        """Return a bytes copy of the value under `key`, or `default` if unset."""
        utils.ensure(key < len(self.fields))
        entry = self.data[key]
        if entry[0] == 1:
            return bytes(entry[1:])
        return default

    def is_set(self, key: int) -> bool:
        """Tell whether a value is currently stored under `key`."""
        utils.ensure(key < len(self.fields))
        flag = self.data[key][0]
        return flag == 1

    def delete(self, key: int) -> None:
        """Drop the value under `key`."""
        utils.ensure(key < len(self.fields))
        entry = self.data[key]
        # Replacing the whole content with one zero byte leaves only a cleared flag.
        entry[:] = b"\x00"

    def clear(self) -> None:
        """Delete the value of every key."""
        for key in range(len(self.fields)):
            self.delete(key)
class SessionCache(DataCache):
    """Cache slot of a single session: its id, usage stamp and session-bound values."""

    def __init__(self) -> None:
        # The id buffer is allocated at full size here; `clear()` empties it.
        self.session_id = bytearray(_SESSION_ID_LENGTH)
        # Maximum value size in bytes for each session-bound key.
        self.fields = (
            64,  # APP_COMMON_SEED
            2,  # APP_COMMON_AUTHORIZATION_TYPE
            128,  # APP_COMMON_AUTHORIZATION_DATA
            1,  # APP_COMMON_DERIVE_CARDANO
            96,  # APP_CARDANO_ICARUS_SECRET
            96,  # APP_CARDANO_ICARUS_TREZOR_SECRET
            1,  # APP_MONERO_LIVE_REFRESH
        )
        # Value of the usage counter when this slot was last activated.
        self.last_usage = 0
        super().__init__()

    def export_session_id(self) -> bytes:
        """Return the session id as immutable bytes, generating one if it is empty."""
        if len(self.session_id) == 0:
            # No id yet: fill the buffer with fresh random bytes.
            self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
        return bytes(self.session_id)

    def clear(self) -> None:
        """Delete all cached values and reset the usage stamp and the session id."""
        super().clear()
        self.last_usage = 0
        # An empty id marks the slot as unused.
        self.session_id[:] = b""
class SessionlessCache(DataCache):
    """Cache for the keys that are valid across sessions."""

    def __init__(self) -> None:
        # Maximum value size in bytes for each sessionless key, listed in key order.
        self.fields = (
            64,  # APP_COMMON_SEED_WITHOUT_PASSPHRASE
            1,  # APP_COMMON_SAFETY_CHECKS_TEMPORARY
            1,  # STORAGE_DEVICE_EXPERIMENTAL_FEATURES
            4,  # APP_COMMON_REQUEST_PIN_LAST_UNLOCK
        )
        super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
# that the preallocated bytearrays are compact in memory.
# That is why the initialization is two-step: first create appropriately sized
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.

# Step 1: allocate every session slot and the sessionless cache at full size.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
    _SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
# Step 2: reset all the objects created above.
for session in _SESSIONS:
    session.clear()
_SESSIONLESS_CACHE.clear()
# Collect the trash produced while clearing.
gc.collect()
# Index into `_SESSIONS` of the active session, or None when no session is active.
_active_session_idx: int | None = None
# Incremented on every `start_session()` call; its value is stamped onto the
# activated session and used to pick the least recently used slot.
_session_usage_counter = 0
def start_session(received_session_id: bytes | None = None) -> bytes:
    """Activate a session and return its id.

    When `received_session_id` matches a cached session, that session becomes
    active again and the same id is returned. Otherwise the least recently used
    slot is cleared, activated and given a newly generated id.
    """
    global _active_session_idx
    global _session_usage_counter

    lookup_id = received_session_id
    if lookup_id is not None and len(lookup_id) != _SESSION_ID_LENGTH:
        # Prevent the caller from setting received_session_id=b"" and finding a cleared
        # session. More generally, short-circuit the session id search, because we know
        # that wrong-length session ids should not be in cache.
        # Reduce to "session id not provided" case because that's what we do when
        # caller supplies an id that is not found.
        lookup_id = None

    _session_usage_counter += 1
    stamp = _session_usage_counter

    # attempt to find specified session id
    if lookup_id:
        for idx in range(_MAX_SESSIONS_COUNT):
            candidate = _SESSIONS[idx]
            if candidate.session_id == lookup_id:
                _active_session_idx = idx
                candidate.last_usage = stamp
                return lookup_id

    # allocate least recently used session
    oldest_usage = stamp
    oldest_idx = 0
    for idx in range(_MAX_SESSIONS_COUNT):
        usage = _SESSIONS[idx].last_usage
        if usage < oldest_usage:
            oldest_usage = usage
            oldest_idx = idx

    _active_session_idx = oldest_idx
    recycled = _SESSIONS[oldest_idx]
    # Drop whatever the previous owner of this slot had cached.
    recycled.clear()
    recycled.last_usage = stamp
    return recycled.export_session_id()
def end_current_session() -> None:
    """Clear the active session and deactivate it; do nothing if none is active."""
    global _active_session_idx

    idx = _active_session_idx
    if idx is not None:
        _SESSIONS[idx].clear()
        _active_session_idx = None
def is_session_started() -> bool:
    """Tell whether a session is currently active."""
    if _active_session_idx is None:
        return False
    return True
def set(key: int, value: bytes) -> None:
    """Store `value` under `key` in the sessionless cache or the active session.

    Raises InvalidSessionError when `key` is session-bound and no session is active.
    """
    if key & _SESSIONLESS_FLAG:
        # Strip the flag to obtain the index within the sessionless cache.
        _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
    elif _active_session_idx is None:
        raise InvalidSessionError
    else:
        _SESSIONS[_active_session_idx].set(key, value)
@overload
def get(key: int) -> bytes | None:
    ...


@overload
def get(key: int, default: T) -> bytes | T:  # noqa: F811
    ...


def get(key: int, default: T | None = None) -> bytes | T | None:  # noqa: F811
    """Return the value cached under `key`, or `default` when it is not set.

    Raises InvalidSessionError when `key` is session-bound and no session is active.
    """
    if key & _SESSIONLESS_FLAG:
        cache = _SESSIONLESS_CACHE
        # Strip the flag to obtain the index within the sessionless cache.
        field = key ^ _SESSIONLESS_FLAG
    elif _active_session_idx is None:
        raise InvalidSessionError
    else:
        cache = _SESSIONS[_active_session_idx]
        field = key
    return cache.get(field, default)
def is_set(key: int) -> bool:
    """Tell whether a value is cached under `key`.

    Raises InvalidSessionError when `key` is session-bound and no session is active.
    """
    if key & _SESSIONLESS_FLAG:
        cache = _SESSIONLESS_CACHE
        # Strip the flag to obtain the index within the sessionless cache.
        field = key ^ _SESSIONLESS_FLAG
    elif _active_session_idx is None:
        raise InvalidSessionError
    else:
        cache = _SESSIONS[_active_session_idx]
        field = key
    return cache.is_set(field)
def delete(key: int) -> None:
    """Remove the value cached under `key`.

    Raises InvalidSessionError when `key` is session-bound and no session is active.
    """
    if key & _SESSIONLESS_FLAG:
        # Strip the flag to obtain the index within the sessionless cache.
        _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
    elif _active_session_idx is None:
        raise InvalidSessionError
    else:
        _SESSIONS[_active_session_idx].delete(key)
# Typing-only names used in the annotations of `stored` and `stored_async`.
# The `if False:` body is never executed at runtime.
if False:
    from typing import Awaitable, Callable, TypeVar
    # A callable that returns bytes.
    ByteFunc = TypeVar("ByteFunc", bound=Callable[..., bytes])
    # A callable that returns an awaitable of bytes.
    AsyncByteFunc = TypeVar("AsyncByteFunc", bound=Callable[..., Awaitable[bytes]])
def stored(key: int) -> Callable[[ByteFunc], ByteFunc]:
    """Return a decorator that caches a synchronous function's result under `key`.

    The wrapped function is only called when nothing is cached under `key`; its
    result is then stored and returned. Otherwise the cached value is returned.
    """

    def decorator(func: ByteFunc) -> ByteFunc:
        # if we didn't check this, it would be easy to store an Awaitable[something]
        # in cache, which might prove hard to debug
        # XXX mypy should be checking this now, but we don't have full coverage yet
        assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"

        def wrapper(*args, **kwargs):  # type: ignore
            cached = get(key)
            if cached is not None:
                return cached
            result = func(*args, **kwargs)
            set(key, result)
            return result

        return wrapper  # type: ignore

    return decorator
def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]:
    """Return a decorator that caches an async function's awaited result under `key`.

    The wrapped function is only awaited when nothing is cached under `key`; its
    result is then stored and returned. Otherwise the cached value is returned.
    """

    def decorator(func: AsyncByteFunc) -> AsyncByteFunc:
        # assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
        # XXX the test above fails for closures
        # We shouldn't need this test here anyway: the 'await func()' should fail
        # with functions that do not return an awaitable so the problem is more visible.
        async def wrapper(*args, **kwargs):  # type: ignore
            cached = get(key)
            if cached is not None:
                return cached
            result = await func(*args, **kwargs)
            set(key, result)
            return result

        return wrapper  # type: ignore

    return decorator
def clear_all() -> None:
    """Deactivate the current session and clear the sessionless and all session caches."""
    global _active_session_idx

    _active_session_idx = None
    _SESSIONLESS_CACHE.clear()
    for slot in _SESSIONS:
        slot.clear()