refactor(core): modify cache to preallocate all its data

also get rid of expensive "wire" import
pull/1610/head
matejcik 3 years ago committed by matejcik
parent ea505b592c
commit 3cdb09c294

@ -173,7 +173,7 @@ async def handle_CancelAuthorization(
raise wire.ProcessError("No preauthorized operation")
authorization.__del__()
storage.cache.delete(storage.cache.APP_BASE_AUTHORIZATION)
storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, b"")
return Success(message="Authorization cancelled")

@ -12,11 +12,9 @@ def read_setting() -> EnumTypeSafetyCheckLevel:
"""
Returns the effective safety check level.
"""
temporary_safety_check_level: EnumTypeSafetyCheckLevel | None = storage.cache.get(
APP_COMMON_SAFETY_CHECKS_TEMPORARY
)
if temporary_safety_check_level is not None:
return temporary_safety_check_level
temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore
else:
stored = storage.device.safety_check_level()
if stored == SAFETY_CHECK_LEVEL_STRICT:
@ -32,14 +30,14 @@ def apply_setting(level: EnumTypeSafetyCheckLevel) -> None:
Changes the safety level settings.
"""
if level == SafetyCheckLevel.Strict:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else:
raise ValueError("Unknown SafetyCheckLevel")

@ -51,7 +51,7 @@ async def _init_step(
if not storage.cache.get(storage.cache.APP_MONERO_LIVE_REFRESH):
await confirms.require_confirm_live_refresh(ctx)
storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, True)
storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -1,11 +1,14 @@
from trezor import wire
import gc
from trezor import utils
from trezor.crypto import random
if False:
from typing import Any
from typing import Sequence
_MAX_SESSIONS_COUNT = 10
_SESSIONLESS_FLAG = 128
_SESSION_ID_LENGTH = 32
# Traditional cache keys
APP_COMMON_SEED = 0
@ -14,100 +17,181 @@ APP_MONERO_LIVE_REFRESH = 2
APP_BASE_AUTHORIZATION = 3
# Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 | _SESSIONLESS_FLAG
APP_COMMON_SAFETY_CHECKS_TEMPORARY = 2 | _SESSIONLESS_FLAG
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
APP_COMMON_SAFETY_CHECKS_TEMPORARY = 1 | _SESSIONLESS_FLAG
_active_session_id: bytes | None = None
_caches: dict[bytes, dict[int, Any]] = {}
_session_ids: list[bytes] = []
_sessionless_cache: dict[int, Any] = {}
class InvalidSessionError(Exception):
pass
if False:
from typing import Any, Callable, TypeVar
F = TypeVar("F", bound=Callable[..., Any])
class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f) for f in self.fields]
def _move_session_ids_queue(session_id: bytes) -> None:
# Move the LRU session ids queue.
if session_id in _session_ids:
_session_ids.remove(session_id)
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][:] = value
while len(_session_ids) >= _MAX_SESSIONS_COUNT:
remove_session_id = _session_ids.pop()
del _caches[remove_session_id]
def get(self, key: int) -> bytes:
utils.ensure(key < len(self.fields), "failed to load key %d" % key)
return bytes(self.data[key])
_session_ids.insert(0, session_id)
def clear(self) -> None:
for i in range(len(self.fields)):
self.set(i, b"")
def start_session(received_session_id: bytes | None = None) -> bytes:
if received_session_id and received_session_id in _session_ids:
session_id = received_session_id
else:
session_id = random.bytes(32)
_caches[session_id] = {}
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.fields = (
64, # APP_COMMON_SEED
128, # APP_CARDANO_ROOT
1, # APP_MONERO_LIVE_REFRESH
128, # APP_BASE_AUTHORIZATION
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
)
super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
# that the preallocated bytearrays are compact in memory.
# That is why the initialization is two-step: first create appropriately sized
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_SESSIONLESS_CACHE.clear()
global _active_session_id
_active_session_id = session_id
_move_session_ids_queue(session_id)
return _active_session_id
gc.collect()
_active_session_idx: int | None = None
_session_usage_counter = 0
def start_session(received_session_id: bytes | None = None) -> bytes:
global _active_session_idx
global _session_usage_counter
if (
received_session_id is not None
and len(received_session_id) != _SESSION_ID_LENGTH
):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None:
global _active_session_id
global _active_session_idx
if _active_session_id is None:
if _active_session_idx is None:
return
current_session_id = _active_session_id
_active_session_id = None
_session_ids.remove(current_session_id)
del _caches[current_session_id]
_SESSIONS[_active_session_idx].clear()
_active_session_idx = None
def is_session_started() -> bool:
return _active_session_id is not None
return _active_session_idx is not None
def set(key: int, value: Any) -> None:
def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG:
_sessionless_cache[key] = value
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return
if _active_session_id is None:
raise wire.InvalidSession
_caches[_active_session_id][key] = value
if _active_session_idx is None:
raise InvalidSessionError
_SESSIONS[_active_session_idx].set(key, value)
def get(key: int) -> Any:
def get(key: int) -> bytes:
if key & _SESSIONLESS_FLAG:
return _sessionless_cache.get(key)
if _active_session_id is None:
raise wire.InvalidSession
return _caches[_active_session_id].get(key)
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
if key in _sessionless_cache:
del _sessionless_cache[key]
return
if _active_session_id is None:
raise wire.InvalidSession
if key in _caches[_active_session_id]:
del _caches[_active_session_id][key]
if False:
from typing import Awaitable, Callable, TypeVar
ByteFunc = TypeVar("ByteFunc", bound=Callable[..., bytes])
AsyncByteFunc = TypeVar("AsyncByteFunc", bound=Callable[..., Awaitable[bytes]])
def stored(key: int) -> Callable[[F], F]:
def decorator(func: F) -> F:
def stored(key: int) -> Callable[[ByteFunc], ByteFunc]:
def decorator(func: ByteFunc) -> ByteFunc:
# if we didn't check this, it would be easy to store an Awaitable[something]
# in cache, which might prove hard to debug
# XXX mypy should be checking this now, but we don't have full coverage yet
assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"
def wrapper(*args, **kwargs): # type: ignore
value = get(key)
if value is None:
if not value:
value = func(*args, **kwargs)
set(key, value)
return value
@ -117,8 +201,8 @@ def stored(key: int) -> Callable[[F], F]:
return decorator
def stored_async(key: int) -> Callable[[F], F]:
def decorator(func: F) -> F:
def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]:
def decorator(func: AsyncByteFunc) -> AsyncByteFunc:
# assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
# XXX the test above fails for closures
# We shouldn't need this test here anyway: the 'await func()' should fail
@ -126,7 +210,7 @@ def stored_async(key: int) -> Callable[[F], F]:
async def wrapper(*args, **kwargs): # type: ignore
value = get(key)
if value is None:
if not value:
value = await func(*args, **kwargs)
set(key, value)
return value
@ -137,12 +221,9 @@ def stored_async(key: int) -> Callable[[F], F]:
def clear_all() -> None:
global _active_session_id
global _caches
global _session_ids
global _sessionless_cache
_active_session_id = None
_caches.clear()
_session_ids.clear()
_sessionless_cache.clear()
global _active_session_idx
_active_session_idx = None
_SESSIONLESS_CACHE.clear()
for session in _SESSIONS:
session.clear()

@ -36,6 +36,7 @@ reads the message's header. When the message type is known the first handler is
"""
import protobuf
from storage.cache import InvalidSessionError
from trezor import log, loop, messages, ui, utils, workflow
from trezor.messages import FailureType
from trezor.messages.Failure import Failure
@ -492,6 +493,8 @@ def failure(exc: BaseException) -> Failure:
return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else:
return Failure(code=FailureType.FirmwareError, message="Firmware error")

@ -4,11 +4,11 @@ from mock_storage import mock_storage
from storage import cache
from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession
from trezor.wire import DUMMY_CONTEXT, InvalidSession
from trezor.wire import DUMMY_CONTEXT
from apps.base import handle_Initialize, handle_EndSession
KEY = 99
KEY = 0
class TestStorageCache(unittest.TestCase):
@ -22,18 +22,18 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
with self.assertRaises(InvalidSession):
with self.assertRaises(cache.InvalidSessionError):
cache.set(KEY, "something")
with self.assertRaises(InvalidSession):
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
def test_end_session(self):
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
cache.set(KEY, "A")
cache.set(KEY, b"A")
cache.end_current_session()
self.assertFalse(cache.is_session_started())
self.assertRaises(InvalidSession, cache.get, KEY)
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
# ending an ended session should be a no-op
cache.end_current_session()
@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase):
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(cache.get(KEY))
self.assertEqual(cache.get(KEY), b"")
# create a new session
session_id_b = cache.start_session()
@ -59,28 +59,28 @@ class TestStorageCache(unittest.TestCase):
def test_session_queue(self):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, "A")
cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY))
self.assertEqual(cache.get(KEY), b"")
def test_get_set(self):
session_id1 = cache.start_session()
cache.set(KEY, "hello")
self.assertEqual(cache.get(KEY), "hello")
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
session_id2 = cache.start_session()
cache.set(KEY, "world")
self.assertEqual(cache.get(KEY), "world")
cache.set(KEY, b"world")
self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id2)
self.assertEqual(cache.get(KEY), "world")
self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), "hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.clear_all()
with self.assertRaises(InvalidSession):
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
def test_decorator_mismatch(self):
@ -98,34 +98,34 @@ class TestStorageCache(unittest.TestCase):
def func():
nonlocal run_count
run_count += 1
return "foo"
return b"foo"
# cache is empty
self.assertIsNone(cache.get(KEY))
self.assertEqual(cache.get(KEY), b"")
self.assertEqual(run_count, 0)
self.assertEqual(func(), "foo")
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), "foo")
self.assertEqual(cache.get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), "foo")
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
@cache.stored_async(KEY)
async def async_func():
nonlocal run_count
run_count += 1
return "bar"
return b"bar"
# cache is still full
self.assertEqual(await_result(async_func()), "foo")
self.assertEqual(await_result(async_func()), b"foo")
self.assertEqual(run_count, 1)
cache.start_session()
self.assertEqual(await_result(async_func()), "bar")
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# awaitable is also run only once
self.assertEqual(await_result(async_func()), "bar")
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
@mock_storage
@ -144,31 +144,31 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(session_id, features.session_id)
# store "hello"
cache.set(KEY, "hello")
cache.set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(cache.get(KEY))
self.assertEqual(cache.get(KEY), b"")
# store "hello" again
cache.set(KEY, "hello")
self.assertEqual(cache.get(KEY), "hello")
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A")
self.assertIsNone(cache.get(KEY))
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
self.assertEqual(cache.get(KEY), b"")
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), "hello")
self.assertEqual(cache.get(KEY), b"hello")
def test_EndSession(self):
self.assertRaises(InvalidSession, cache.get, KEY)
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
self.assertIsNone(cache.get(KEY))
self.assertEqual(cache.get(KEY), b"")
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
self.assertFalse(cache.is_session_started())
self.assertRaises(InvalidSession, cache.get, KEY)
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
if __name__ == "__main__":

Loading…
Cancel
Save