1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-11 07:50:57 +00:00

refactor(core): modify cache to preallocate all its data

also get rid of expensive "wire" import
This commit is contained in:
matejcik 2021-02-25 15:37:02 +01:00 committed by matejcik
parent ea505b592c
commit 3cdb09c294
6 changed files with 198 additions and 116 deletions

View File

@ -173,7 +173,7 @@ async def handle_CancelAuthorization(
raise wire.ProcessError("No preauthorized operation") raise wire.ProcessError("No preauthorized operation")
authorization.__del__() authorization.__del__()
storage.cache.delete(storage.cache.APP_BASE_AUTHORIZATION) storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, b"")
return Success(message="Authorization cancelled") return Success(message="Authorization cancelled")

View File

@ -12,11 +12,9 @@ def read_setting() -> EnumTypeSafetyCheckLevel:
""" """
Returns the effective safety check level. Returns the effective safety check level.
""" """
temporary_safety_check_level: EnumTypeSafetyCheckLevel | None = storage.cache.get( temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
APP_COMMON_SAFETY_CHECKS_TEMPORARY if temporary_safety_check_level:
) return int.from_bytes(temporary_safety_check_level, "big") # type: ignore
if temporary_safety_check_level is not None:
return temporary_safety_check_level
else: else:
stored = storage.device.safety_check_level() stored = storage.device.safety_check_level()
if stored == SAFETY_CHECK_LEVEL_STRICT: if stored == SAFETY_CHECK_LEVEL_STRICT:
@ -32,14 +30,14 @@ def apply_setting(level: EnumTypeSafetyCheckLevel) -> None:
Changes the safety level settings. Changes the safety level settings.
""" """
if level == SafetyCheckLevel.Strict: if level == SafetyCheckLevel.Strict:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways: elif level == SafetyCheckLevel.PromptAlways:
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily: elif level == SafetyCheckLevel.PromptTemporarily:
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level) storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else: else:
raise ValueError("Unknown SafetyCheckLevel") raise ValueError("Unknown SafetyCheckLevel")

View File

@ -51,7 +51,7 @@ async def _init_step(
if not storage.cache.get(storage.cache.APP_MONERO_LIVE_REFRESH): if not storage.cache.get(storage.cache.APP_MONERO_LIVE_REFRESH):
await confirms.require_confirm_live_refresh(ctx) await confirms.require_confirm_live_refresh(ctx)
storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, True) storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -1,11 +1,14 @@
from trezor import wire import gc
from trezor import utils
from trezor.crypto import random from trezor.crypto import random
if False: if False:
from typing import Any from typing import Sequence
_MAX_SESSIONS_COUNT = 10 _MAX_SESSIONS_COUNT = 10
_SESSIONLESS_FLAG = 128 _SESSIONLESS_FLAG = 128
_SESSION_ID_LENGTH = 32
# Traditional cache keys # Traditional cache keys
APP_COMMON_SEED = 0 APP_COMMON_SEED = 0
@ -14,100 +17,181 @@ APP_MONERO_LIVE_REFRESH = 2
APP_BASE_AUTHORIZATION = 3 APP_BASE_AUTHORIZATION = 3
# Keys that are valid across sessions # Keys that are valid across sessions
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 | _SESSIONLESS_FLAG APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
APP_COMMON_SAFETY_CHECKS_TEMPORARY = 2 | _SESSIONLESS_FLAG APP_COMMON_SAFETY_CHECKS_TEMPORARY = 1 | _SESSIONLESS_FLAG
_active_session_id: bytes | None = None class InvalidSessionError(Exception):
_caches: dict[bytes, dict[int, Any]] = {} pass
_session_ids: list[bytes] = []
_sessionless_cache: dict[int, Any] = {}
if False:
from typing import Any, Callable, TypeVar
F = TypeVar("F", bound=Callable[..., Any])
def _move_session_ids_queue(session_id: bytes) -> None: class DataCache:
# Move the LRU session ids queue. fields: Sequence[int]
if session_id in _session_ids:
_session_ids.remove(session_id)
while len(_session_ids) >= _MAX_SESSIONS_COUNT: def __init__(self) -> None:
remove_session_id = _session_ids.pop() self.data = [bytearray(f) for f in self.fields]
del _caches[remove_session_id]
_session_ids.insert(0, session_id) def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][:] = value
def get(self, key: int) -> bytes:
utils.ensure(key < len(self.fields), "failed to load key %d" % key)
return bytes(self.data[key])
def clear(self) -> None:
for i in range(len(self.fields)):
self.set(i, b"")
class SessionCache(DataCache):
def __init__(self) -> None:
self.session_id = bytearray(_SESSION_ID_LENGTH)
self.fields = (
64, # APP_COMMON_SEED
128, # APP_CARDANO_ROOT
1, # APP_MONERO_LIVE_REFRESH
128, # APP_BASE_AUTHORIZATION
)
self.last_usage = 0
super().__init__()
def export_session_id(self) -> bytes:
# generate a new session id if we don't have it yet
if not self.session_id:
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
# export it as immutable bytes
return bytes(self.session_id)
def clear(self) -> None:
super().clear()
self.last_usage = 0
self.session_id[:] = b""
class SessionlessCache(DataCache):
def __init__(self) -> None:
self.fields = (
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
)
super().__init__()
# XXX
# Allocation notes:
# Instantiation of a DataCache subclass should make as little garbage as possible, so
# that the preallocated bytearrays are compact in memory.
# That is why the initialization is two-step: first create appropriately sized
# bytearrays, then later call `clear()` on all the existing objects, which resets them
# to zero length. This is producing some trash - `b[:]` allocates a slice.
_SESSIONS: list[SessionCache] = []
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionCache())
_SESSIONLESS_CACHE = SessionlessCache()
for session in _SESSIONS:
session.clear()
_SESSIONLESS_CACHE.clear()
gc.collect()
_active_session_idx: int | None = None
_session_usage_counter = 0
def start_session(received_session_id: bytes | None = None) -> bytes: def start_session(received_session_id: bytes | None = None) -> bytes:
if received_session_id and received_session_id in _session_ids: global _active_session_idx
session_id = received_session_id global _session_usage_counter
else:
session_id = random.bytes(32)
_caches[session_id] = {}
global _active_session_id if (
_active_session_id = session_id received_session_id is not None
_move_session_ids_queue(session_id) and len(received_session_id) != _SESSION_ID_LENGTH
return _active_session_id ):
# Prevent the caller from setting received_session_id=b"" and finding a cleared
# session. More generally, short-circuit the session id search, because we know
# that wrong-length session ids should not be in cache.
# Reduce to "session id not provided" case because that's what we do when
# caller supplies an id that is not found.
received_session_id = None
_session_usage_counter += 1
# attempt to find specified session id
if received_session_id:
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].session_id == received_session_id:
_active_session_idx = i
_SESSIONS[i].last_usage = _session_usage_counter
return received_session_id
# allocate least recently used session
lru_counter = _session_usage_counter
lru_session_idx = 0
for i in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[i].last_usage < lru_counter:
lru_counter = _SESSIONS[i].last_usage
lru_session_idx = i
_active_session_idx = lru_session_idx
selected_session = _SESSIONS[lru_session_idx]
selected_session.clear()
selected_session.last_usage = _session_usage_counter
return selected_session.export_session_id()
def end_current_session() -> None: def end_current_session() -> None:
global _active_session_id global _active_session_idx
if _active_session_id is None: if _active_session_idx is None:
return return
current_session_id = _active_session_id _SESSIONS[_active_session_idx].clear()
_active_session_id = None _active_session_idx = None
_session_ids.remove(current_session_id)
del _caches[current_session_id]
def is_session_started() -> bool: def is_session_started() -> bool:
return _active_session_id is not None return _active_session_idx is not None
def set(key: int, value: Any) -> None: def set(key: int, value: bytes) -> None:
if key & _SESSIONLESS_FLAG: if key & _SESSIONLESS_FLAG:
_sessionless_cache[key] = value _SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
return return
if _active_session_id is None: if _active_session_idx is None:
raise wire.InvalidSession raise InvalidSessionError
_caches[_active_session_id][key] = value _SESSIONS[_active_session_idx].set(key, value)
def get(key: int) -> Any: def get(key: int) -> bytes:
if key & _SESSIONLESS_FLAG: if key & _SESSIONLESS_FLAG:
return _sessionless_cache.get(key) return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG)
if _active_session_id is None: if _active_session_idx is None:
raise wire.InvalidSession raise InvalidSessionError
return _caches[_active_session_id].get(key) return _SESSIONS[_active_session_idx].get(key)
def delete(key: int) -> None: if False:
if key & _SESSIONLESS_FLAG: from typing import Awaitable, Callable, TypeVar
if key in _sessionless_cache:
del _sessionless_cache[key] ByteFunc = TypeVar("ByteFunc", bound=Callable[..., bytes])
return AsyncByteFunc = TypeVar("AsyncByteFunc", bound=Callable[..., Awaitable[bytes]])
if _active_session_id is None:
raise wire.InvalidSession
if key in _caches[_active_session_id]:
del _caches[_active_session_id][key]
def stored(key: int) -> Callable[[F], F]: def stored(key: int) -> Callable[[ByteFunc], ByteFunc]:
def decorator(func: F) -> F: def decorator(func: ByteFunc) -> ByteFunc:
# if we didn't check this, it would be easy to store an Awaitable[something] # if we didn't check this, it would be easy to store an Awaitable[something]
# in cache, which might prove hard to debug # in cache, which might prove hard to debug
# XXX mypy should be checking this now, but we don't have full coverage yet
assert not isinstance(func, type(lambda: (yield))), "use stored_async instead" assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"
def wrapper(*args, **kwargs): # type: ignore def wrapper(*args, **kwargs): # type: ignore
value = get(key) value = get(key)
if value is None: if not value:
value = func(*args, **kwargs) value = func(*args, **kwargs)
set(key, value) set(key, value)
return value return value
@ -117,8 +201,8 @@ def stored(key: int) -> Callable[[F], F]:
return decorator return decorator
def stored_async(key: int) -> Callable[[F], F]: def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]:
def decorator(func: F) -> F: def decorator(func: AsyncByteFunc) -> AsyncByteFunc:
# assert isinstance(func, type(lambda: (yield))), "do not use stored_async" # assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
# XXX the test above fails for closures # XXX the test above fails for closures
# We shouldn't need this test here anyway: the 'await func()' should fail # We shouldn't need this test here anyway: the 'await func()' should fail
@ -126,7 +210,7 @@ def stored_async(key: int) -> Callable[[F], F]:
async def wrapper(*args, **kwargs): # type: ignore async def wrapper(*args, **kwargs): # type: ignore
value = get(key) value = get(key)
if value is None: if not value:
value = await func(*args, **kwargs) value = await func(*args, **kwargs)
set(key, value) set(key, value)
return value return value
@ -137,12 +221,9 @@ def stored_async(key: int) -> Callable[[F], F]:
def clear_all() -> None: def clear_all() -> None:
global _active_session_id global _active_session_idx
global _caches
global _session_ids
global _sessionless_cache
_active_session_id = None _active_session_idx = None
_caches.clear() _SESSIONLESS_CACHE.clear()
_session_ids.clear() for session in _SESSIONS:
_sessionless_cache.clear() session.clear()

View File

@ -36,6 +36,7 @@ reads the message's header. When the message type is known the first handler is
""" """
import protobuf import protobuf
from storage.cache import InvalidSessionError
from trezor import log, loop, messages, ui, utils, workflow from trezor import log, loop, messages, ui, utils, workflow
from trezor.messages import FailureType from trezor.messages import FailureType
from trezor.messages.Failure import Failure from trezor.messages.Failure import Failure
@ -492,6 +493,8 @@ def failure(exc: BaseException) -> Failure:
return Failure(code=exc.code, message=exc.message) return Failure(code=exc.code, message=exc.message)
elif isinstance(exc, loop.TaskClosed): elif isinstance(exc, loop.TaskClosed):
return Failure(code=FailureType.ActionCancelled, message="Cancelled") return Failure(code=FailureType.ActionCancelled, message="Cancelled")
elif isinstance(exc, InvalidSessionError):
return Failure(code=FailureType.InvalidSession, message="Invalid session")
else: else:
return Failure(code=FailureType.FirmwareError, message="Firmware error") return Failure(code=FailureType.FirmwareError, message="Firmware error")

View File

@ -4,11 +4,11 @@ from mock_storage import mock_storage
from storage import cache from storage import cache
from trezor.messages.Initialize import Initialize from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession from trezor.messages.EndSession import EndSession
from trezor.wire import DUMMY_CONTEXT, InvalidSession from trezor.wire import DUMMY_CONTEXT
from apps.base import handle_Initialize, handle_EndSession from apps.base import handle_Initialize, handle_EndSession
KEY = 99 KEY = 0
class TestStorageCache(unittest.TestCase): class TestStorageCache(unittest.TestCase):
@ -22,18 +22,18 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b) self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all() cache.clear_all()
with self.assertRaises(InvalidSession): with self.assertRaises(cache.InvalidSessionError):
cache.set(KEY, "something") cache.set(KEY, "something")
with self.assertRaises(InvalidSession): with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY) cache.get(KEY)
def test_end_session(self): def test_end_session(self):
session_id = cache.start_session() session_id = cache.start_session()
self.assertTrue(cache.is_session_started()) self.assertTrue(cache.is_session_started())
cache.set(KEY, "A") cache.set(KEY, b"A")
cache.end_current_session() cache.end_current_session()
self.assertFalse(cache.is_session_started()) self.assertFalse(cache.is_session_started())
self.assertRaises(InvalidSession, cache.get, KEY) self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
# ending an ended session should be a no-op # ending an ended session should be a no-op
cache.end_current_session() cache.end_current_session()
@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase):
# original session no longer exists # original session no longer exists
self.assertNotEqual(session_id_a, session_id) self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists # original session data no longer exists
self.assertIsNone(cache.get(KEY)) self.assertEqual(cache.get(KEY), b"")
# create a new session # create a new session
session_id_b = cache.start_session() session_id_b = cache.start_session()
@ -59,28 +59,28 @@ class TestStorageCache(unittest.TestCase):
def test_session_queue(self): def test_session_queue(self):
session_id = cache.start_session() session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id) self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, "A") cache.set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT): for i in range(cache._MAX_SESSIONS_COUNT):
cache.start_session() cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id) self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY)) self.assertEqual(cache.get(KEY), b"")
def test_get_set(self): def test_get_set(self):
session_id1 = cache.start_session() session_id1 = cache.start_session()
cache.set(KEY, "hello") cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), "hello") self.assertEqual(cache.get(KEY), b"hello")
session_id2 = cache.start_session() session_id2 = cache.start_session()
cache.set(KEY, "world") cache.set(KEY, b"world")
self.assertEqual(cache.get(KEY), "world") self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id2) cache.start_session(session_id2)
self.assertEqual(cache.get(KEY), "world") self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id1) cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), "hello") self.assertEqual(cache.get(KEY), b"hello")
cache.clear_all() cache.clear_all()
with self.assertRaises(InvalidSession): with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY) cache.get(KEY)
def test_decorator_mismatch(self): def test_decorator_mismatch(self):
@ -98,34 +98,34 @@ class TestStorageCache(unittest.TestCase):
def func(): def func():
nonlocal run_count nonlocal run_count
run_count += 1 run_count += 1
return "foo" return b"foo"
# cache is empty # cache is empty
self.assertIsNone(cache.get(KEY)) self.assertEqual(cache.get(KEY), b"")
self.assertEqual(run_count, 0) self.assertEqual(run_count, 0)
self.assertEqual(func(), "foo") self.assertEqual(func(), b"foo")
# function was run # function was run
self.assertEqual(run_count, 1) self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), "foo") self.assertEqual(cache.get(KEY), b"foo")
# function does not run again but returns cached value # function does not run again but returns cached value
self.assertEqual(func(), "foo") self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1) self.assertEqual(run_count, 1)
@cache.stored_async(KEY) @cache.stored_async(KEY)
async def async_func(): async def async_func():
nonlocal run_count nonlocal run_count
run_count += 1 run_count += 1
return "bar" return b"bar"
# cache is still full # cache is still full
self.assertEqual(await_result(async_func()), "foo") self.assertEqual(await_result(async_func()), b"foo")
self.assertEqual(run_count, 1) self.assertEqual(run_count, 1)
cache.start_session() cache.start_session()
self.assertEqual(await_result(async_func()), "bar") self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2) self.assertEqual(run_count, 2)
# awaitable is also run only once # awaitable is also run only once
self.assertEqual(await_result(async_func()), "bar") self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2) self.assertEqual(run_count, 2)
@mock_storage @mock_storage
@ -144,31 +144,31 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(session_id, features.session_id) self.assertEqual(session_id, features.session_id)
# store "hello" # store "hello"
cache.set(KEY, "hello") cache.set(KEY, b"hello")
# check that it is cleared # check that it is cleared
features = call_Initialize() features = call_Initialize()
session_id = features.session_id session_id = features.session_id
self.assertIsNone(cache.get(KEY)) self.assertEqual(cache.get(KEY), b"")
# store "hello" again # store "hello" again
cache.set(KEY, "hello") cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), "hello") self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache # supplying a different session ID starts a new cache
call_Initialize(session_id=b"A") call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY)) self.assertEqual(cache.get(KEY), b"")
# but resuming a session loads the previous one # but resuming a session loads the previous one
call_Initialize(session_id=session_id) call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), "hello") self.assertEqual(cache.get(KEY), b"hello")
def test_EndSession(self): def test_EndSession(self):
self.assertRaises(InvalidSession, cache.get, KEY) self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
session_id = cache.start_session() session_id = cache.start_session()
self.assertTrue(cache.is_session_started()) self.assertTrue(cache.is_session_started())
self.assertIsNone(cache.get(KEY)) self.assertEqual(cache.get(KEY), b"")
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession())) await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
self.assertFalse(cache.is_session_started()) self.assertFalse(cache.is_session_started())
self.assertRaises(InvalidSession, cache.get, KEY) self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
if __name__ == "__main__": if __name__ == "__main__":