diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 5f85ba15da..f9fe4f9518 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -173,7 +173,7 @@ async def handle_CancelAuthorization( raise wire.ProcessError("No preauthorized operation") authorization.__del__() - storage.cache.delete(storage.cache.APP_BASE_AUTHORIZATION) + storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, b"") return Success(message="Authorization cancelled") diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index c770fd6ab9..2c19d9771c 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -12,11 +12,9 @@ def read_setting() -> EnumTypeSafetyCheckLevel: """ Returns the effective safety check level. """ - temporary_safety_check_level: EnumTypeSafetyCheckLevel | None = storage.cache.get( - APP_COMMON_SAFETY_CHECKS_TEMPORARY - ) - if temporary_safety_check_level is not None: - return temporary_safety_check_level + temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + if temporary_safety_check_level: + return int.from_bytes(temporary_safety_check_level, "big") # type: ignore else: stored = storage.device.safety_check_level() if stored == SAFETY_CHECK_LEVEL_STRICT: @@ -32,14 +30,14 @@ def apply_setting(level: EnumTypeSafetyCheckLevel) -> None: Changes the safety level settings. """ if level == SafetyCheckLevel.Strict: - storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"") storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) elif level == SafetyCheckLevel.PromptAlways: - storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"") storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) - storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level) + storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) else: raise ValueError("Unknown SafetyCheckLevel") diff --git a/core/src/apps/monero/live_refresh.py b/core/src/apps/monero/live_refresh.py index 152a9374f1..d0a3342cf0 100644 --- a/core/src/apps/monero/live_refresh.py +++ b/core/src/apps/monero/live_refresh.py @@ -51,7 +51,7 @@ async def _init_step( if not storage.cache.get(storage.cache.APP_MONERO_LIVE_REFRESH): await confirms.require_confirm_live_refresh(ctx) - storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, True) + storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, b"\x01") s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index c1f39f70cc..b907625afa 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,11 +1,14 @@ -from trezor import wire +import gc + +from trezor import utils from trezor.crypto import random if False: - from typing import Any + from typing import Sequence _MAX_SESSIONS_COUNT = 10 _SESSIONLESS_FLAG = 128 +_SESSION_ID_LENGTH = 32 # Traditional cache keys APP_COMMON_SEED = 0 @@ -14,100 +17,181 @@ APP_MONERO_LIVE_REFRESH = 2 APP_BASE_AUTHORIZATION = 3 # Keys that are valid across sessions -APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 | _SESSIONLESS_FLAG -APP_COMMON_SAFETY_CHECKS_TEMPORARY = 2 | _SESSIONLESS_FLAG +APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG +APP_COMMON_SAFETY_CHECKS_TEMPORARY = 1 | _SESSIONLESS_FLAG -_active_session_id: bytes | None = None -_caches: dict[bytes, dict[int, Any]] = {} -_session_ids: list[bytes] = [] -_sessionless_cache: dict[int, Any] = {} - -if False: - from typing import Any, Callable, TypeVar - - F = TypeVar("F", bound=Callable[..., Any]) +class InvalidSessionError(Exception): + pass -def _move_session_ids_queue(session_id: bytes) -> None: - # Move the LRU session ids queue. - if session_id in _session_ids: - _session_ids.remove(session_id) +class DataCache: + fields: Sequence[int] - while len(_session_ids) >= _MAX_SESSIONS_COUNT: - remove_session_id = _session_ids.pop() - del _caches[remove_session_id] + def __init__(self) -> None: + self.data = [bytearray(f) for f in self.fields] - _session_ids.insert(0, session_id) + def set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.fields)) + utils.ensure(len(value) <= self.fields[key]) + self.data[key][:] = value + + def get(self, key: int) -> bytes: + utils.ensure(key < len(self.fields), "failed to load key %d" % key) + return bytes(self.data[key]) + + def clear(self) -> None: + for i in range(len(self.fields)): + self.set(i, b"") + + +class SessionCache(DataCache): + def __init__(self) -> None: + self.session_id = bytearray(_SESSION_ID_LENGTH) + self.fields = ( + 64, # APP_COMMON_SEED + 128, # APP_CARDANO_ROOT + 1, # APP_MONERO_LIVE_REFRESH + 128, # APP_BASE_AUTHORIZATION + ) + self.last_usage = 0 + super().__init__() + + def export_session_id(self) -> bytes: + # generate a new session id if we don't have it yet + if not self.session_id: + self.session_id[:] = random.bytes(_SESSION_ID_LENGTH) + # export it as immutable bytes + return bytes(self.session_id) + + def clear(self) -> None: + super().clear() + self.last_usage = 0 + self.session_id[:] = b"" + + +class SessionlessCache(DataCache): + def __init__(self) -> None: + self.fields = ( + 64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE + 1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY + ) + super().__init__() + + +# XXX +# Allocation notes: +# Instantiation of a DataCache subclass should make as little garbage as possible, so +# that the preallocated bytearrays are compact in memory. +# That is why the initialization is two-step: first create appropriately sized +# bytearrays, then later call `clear()` on all the existing objects, which resets them +# to zero length. This is producing some trash - `b[:]` allocates a slice. + +_SESSIONS: list[SessionCache] = [] +for _ in range(_MAX_SESSIONS_COUNT): + _SESSIONS.append(SessionCache()) + +_SESSIONLESS_CACHE = SessionlessCache() + +for session in _SESSIONS: + session.clear() +_SESSIONLESS_CACHE.clear() + +gc.collect() + + +_active_session_idx: int | None = None +_session_usage_counter = 0 def start_session(received_session_id: bytes | None = None) -> bytes: - if received_session_id and received_session_id in _session_ids: - session_id = received_session_id - else: - session_id = random.bytes(32) - _caches[session_id] = {} + global _active_session_idx + global _session_usage_counter - global _active_session_id - _active_session_id = session_id - _move_session_ids_queue(session_id) - return _active_session_id + if ( + received_session_id is not None + and len(received_session_id) != _SESSION_ID_LENGTH + ): + # Prevent the caller from setting received_session_id=b"" and finding a cleared + # session. More generally, short-circuit the session id search, because we know + # that wrong-length session ids should not be in cache. + # Reduce to "session id not provided" case because that's what we do when + # caller supplies an id that is not found. + received_session_id = None + + _session_usage_counter += 1 + + # attempt to find specified session id + if received_session_id: + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].session_id == received_session_id: + _active_session_idx = i + _SESSIONS[i].last_usage = _session_usage_counter + return received_session_id + + # allocate least recently used session + lru_counter = _session_usage_counter + lru_session_idx = 0 + for i in range(_MAX_SESSIONS_COUNT): + if _SESSIONS[i].last_usage < lru_counter: + lru_counter = _SESSIONS[i].last_usage + lru_session_idx = i + + _active_session_idx = lru_session_idx + selected_session = _SESSIONS[lru_session_idx] + selected_session.clear() + selected_session.last_usage = _session_usage_counter + return selected_session.export_session_id() def end_current_session() -> None: - global _active_session_id + global _active_session_idx - if _active_session_id is None: + if _active_session_idx is None: return - current_session_id = _active_session_id - _active_session_id = None - - _session_ids.remove(current_session_id) - del _caches[current_session_id] + _SESSIONS[_active_session_idx].clear() + _active_session_idx = None def is_session_started() -> bool: - return _active_session_id is not None + return _active_session_idx is not None -def set(key: int, value: Any) -> None: +def set(key: int, value: bytes) -> None: if key & _SESSIONLESS_FLAG: - _sessionless_cache[key] = value + _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value) return - if _active_session_id is None: - raise wire.InvalidSession - _caches[_active_session_id][key] = value + if _active_session_idx is None: + raise InvalidSessionError + _SESSIONS[_active_session_idx].set(key, value) -def get(key: int) -> Any: +def get(key: int) -> bytes: if key & _SESSIONLESS_FLAG: - return _sessionless_cache.get(key) - if _active_session_id is None: - raise wire.InvalidSession - return _caches[_active_session_id].get(key) + return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG) + if _active_session_idx is None: + raise InvalidSessionError + return _SESSIONS[_active_session_idx].get(key) -def delete(key: int) -> None: - if key & _SESSIONLESS_FLAG: - if key in _sessionless_cache: - del _sessionless_cache[key] - return - if _active_session_id is None: - raise wire.InvalidSession - if key in _caches[_active_session_id]: - del _caches[_active_session_id][key] +if False: + from typing import Awaitable, Callable, TypeVar + + ByteFunc = TypeVar("ByteFunc", bound=Callable[..., bytes]) + AsyncByteFunc = TypeVar("AsyncByteFunc", bound=Callable[..., Awaitable[bytes]]) -def stored(key: int) -> Callable[[F], F]: - def decorator(func: F) -> F: +def stored(key: int) -> Callable[[ByteFunc], ByteFunc]: + def decorator(func: ByteFunc) -> ByteFunc: # if we didn't check this, it would be easy to store an Awaitable[something] # in cache, which might prove hard to debug + # XXX mypy should be checking this now, but we don't have full coverage yet assert not isinstance(func, type(lambda: (yield))), "use stored_async instead" def wrapper(*args, **kwargs): # type: ignore value = get(key) - if value is None: + if not value: value = func(*args, **kwargs) set(key, value) return value @@ -117,8 +201,8 @@ def stored(key: int) -> Callable[[F], F]: return decorator -def stored_async(key: int) -> Callable[[F], F]: - def decorator(func: F) -> F: +def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]: + def decorator(func: AsyncByteFunc) -> AsyncByteFunc: # assert isinstance(func, type(lambda: (yield))), "do not use stored_async" # XXX the test above fails for closures # We shouldn't need this test here anyway: the 'await func()' should fail @@ -126,7 +210,7 @@ def stored_async(key: int) -> Callable[[F], F]: async def wrapper(*args, **kwargs): # type: ignore value = get(key) - if value is None: + if not value: value = await func(*args, **kwargs) set(key, value) return value @@ -137,12 +221,9 @@ def stored_async(key: int) -> Callable[[F], F]: def clear_all() -> None: - global _active_session_id - global _caches - global _session_ids - global _sessionless_cache + global _active_session_idx - _active_session_id = None - _caches.clear() - _session_ids.clear() - _sessionless_cache.clear() + _active_session_idx = None + _SESSIONLESS_CACHE.clear() + for session in _SESSIONS: + session.clear() diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index c0de662edb..90de32ba92 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -36,6 +36,7 @@ reads the message's header. When the message type is known the first handler is """ import protobuf +from storage.cache import InvalidSessionError from trezor import log, loop, messages, ui, utils, workflow from trezor.messages import FailureType from trezor.messages.Failure import Failure @@ -492,6 +493,8 @@ def failure(exc: BaseException) -> Failure: return Failure(code=exc.code, message=exc.message) elif isinstance(exc, loop.TaskClosed): return Failure(code=FailureType.ActionCancelled, message="Cancelled") + elif isinstance(exc, InvalidSessionError): + return Failure(code=FailureType.InvalidSession, message="Invalid session") else: return Failure(code=FailureType.FirmwareError, message="Firmware error") diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index a076f8b6f7..220db6b6ef 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -4,11 +4,11 @@ from mock_storage import mock_storage from storage import cache from trezor.messages.Initialize import Initialize from trezor.messages.EndSession import EndSession -from trezor.wire import DUMMY_CONTEXT, InvalidSession +from trezor.wire import DUMMY_CONTEXT from apps.base import handle_Initialize, handle_EndSession -KEY = 99 +KEY = 0 class TestStorageCache(unittest.TestCase): @@ -22,18 +22,18 @@ class TestStorageCache(unittest.TestCase): self.assertNotEqual(session_id_a, session_id_b) cache.clear_all() - with self.assertRaises(InvalidSession): + with self.assertRaises(cache.InvalidSessionError): cache.set(KEY, "something") - with self.assertRaises(InvalidSession): + with self.assertRaises(cache.InvalidSessionError): cache.get(KEY) def test_end_session(self): session_id = cache.start_session() self.assertTrue(cache.is_session_started()) - cache.set(KEY, "A") + cache.set(KEY, b"A") cache.end_current_session() self.assertFalse(cache.is_session_started()) - self.assertRaises(InvalidSession, cache.get, KEY) + self.assertRaises(cache.InvalidSessionError, cache.get, KEY) # ending an ended session should be a no-op cache.end_current_session() @@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase): # original session no longer exists self.assertNotEqual(session_id_a, session_id) # original session data no longer exists - self.assertIsNone(cache.get(KEY)) + self.assertEqual(cache.get(KEY), b"") # create a new session session_id_b = cache.start_session() @@ -59,28 +59,28 @@ class TestStorageCache(unittest.TestCase): def test_session_queue(self): session_id = cache.start_session() self.assertEqual(cache.start_session(session_id), session_id) - cache.set(KEY, "A") + cache.set(KEY, b"A") for i in range(cache._MAX_SESSIONS_COUNT): cache.start_session() self.assertNotEqual(cache.start_session(session_id), session_id) - self.assertIsNone(cache.get(KEY)) + self.assertEqual(cache.get(KEY), b"") def test_get_set(self): session_id1 = cache.start_session() - cache.set(KEY, "hello") - self.assertEqual(cache.get(KEY), "hello") + cache.set(KEY, b"hello") + self.assertEqual(cache.get(KEY), b"hello") session_id2 = cache.start_session() - cache.set(KEY, "world") - self.assertEqual(cache.get(KEY), "world") + cache.set(KEY, b"world") + self.assertEqual(cache.get(KEY), b"world") cache.start_session(session_id2) - self.assertEqual(cache.get(KEY), "world") + self.assertEqual(cache.get(KEY), b"world") cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), "hello") + self.assertEqual(cache.get(KEY), b"hello") cache.clear_all() - with self.assertRaises(InvalidSession): + with self.assertRaises(cache.InvalidSessionError): cache.get(KEY) def test_decorator_mismatch(self): @@ -98,34 +98,34 @@ class TestStorageCache(unittest.TestCase): def func(): nonlocal run_count run_count += 1 - return "foo" + return b"foo" # cache is empty - self.assertIsNone(cache.get(KEY)) + self.assertEqual(cache.get(KEY), b"") self.assertEqual(run_count, 0) - self.assertEqual(func(), "foo") + self.assertEqual(func(), b"foo") # function was run self.assertEqual(run_count, 1) - self.assertEqual(cache.get(KEY), "foo") + self.assertEqual(cache.get(KEY), b"foo") # function does not run again but returns cached value - self.assertEqual(func(), "foo") + self.assertEqual(func(), b"foo") self.assertEqual(run_count, 1) @cache.stored_async(KEY) async def async_func(): nonlocal run_count run_count += 1 - return "bar" + return b"bar" # cache is still full - self.assertEqual(await_result(async_func()), "foo") + self.assertEqual(await_result(async_func()), b"foo") self.assertEqual(run_count, 1) cache.start_session() - self.assertEqual(await_result(async_func()), "bar") + self.assertEqual(await_result(async_func()), b"bar") self.assertEqual(run_count, 2) # awaitable is also run only once - self.assertEqual(await_result(async_func()), "bar") + self.assertEqual(await_result(async_func()), b"bar") self.assertEqual(run_count, 2) @mock_storage @@ -144,31 +144,31 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(session_id, features.session_id) # store "hello" - cache.set(KEY, "hello") + cache.set(KEY, b"hello") # check that it is cleared features = call_Initialize() session_id = features.session_id - self.assertIsNone(cache.get(KEY)) + self.assertEqual(cache.get(KEY), b"") # store "hello" again - cache.set(KEY, "hello") - self.assertEqual(cache.get(KEY), "hello") + cache.set(KEY, b"hello") + self.assertEqual(cache.get(KEY), b"hello") # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A") - self.assertIsNone(cache.get(KEY)) + call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) + self.assertEqual(cache.get(KEY), b"") # but resuming a session loads the previous one call_Initialize(session_id=session_id) - self.assertEqual(cache.get(KEY), "hello") + self.assertEqual(cache.get(KEY), b"hello") def test_EndSession(self): - self.assertRaises(InvalidSession, cache.get, KEY) + self.assertRaises(cache.InvalidSessionError, cache.get, KEY) session_id = cache.start_session() self.assertTrue(cache.is_session_started()) - self.assertIsNone(cache.get(KEY)) + self.assertEqual(cache.get(KEY), b"") await_result(handle_EndSession(DUMMY_CONTEXT, EndSession())) self.assertFalse(cache.is_session_started()) - self.assertRaises(InvalidSession, cache.get, KEY) + self.assertRaises(cache.InvalidSessionError, cache.get, KEY) if __name__ == "__main__":