1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-22 23:48:12 +00:00

core: introduce caching decorators

This commit is contained in:
matejcik 2020-04-20 11:36:28 +02:00 committed by matejcik
parent 63dfcb17a7
commit 8c4cb58098
2 changed files with 85 additions and 1 deletions

View File

@ -21,7 +21,9 @@ _session_ids = [] # type: List[bytes]
_sessionless_cache = {} # type: Dict[int, Any]
if False:
from typing import Any
from typing import Any, Callable, TypeVar
F = TypeVar("F", bound=Callable[..., Any])
def _move_session_ids_queue(session_id: bytes) -> None:
@ -70,6 +72,43 @@ def get(key: int) -> Any:
return _caches[_active_session_id].get(key)
def stored(key: int) -> Callable[[F], F]:
def decorator(func: F) -> F:
# if we didn't check this, it would be easy to store an Awaitable[something]
# in cache, which might prove hard to debug
assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"
def wrapper(*args, **kwargs): # type: ignore
value = get(key)
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
return wrapper # type: ignore
return decorator
def stored_async(key: int) -> Callable[[F], F]:
def decorator(func: F) -> F:
# assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
# XXX the test above fails for closures
# We shouldn't need this test here anyway: the 'await func()' should fail
# with functions that do not return an awaitable so the problem is more visible.
async def wrapper(*args, **kwargs): # type: ignore
value = get(key)
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value
return wrapper # type: ignore
return decorator
def clear_all() -> None:
global _active_session_id
global _caches

View File

@ -50,6 +50,51 @@ class TestStorageCache(unittest.TestCase):
with self.assertRaises(RuntimeError):
cache.get(KEY)
def test_decorator_mismatch(self):
with self.assertRaises(AssertionError):
@cache.stored(KEY)
async def async_fun():
pass
def test_decorators(self):
run_count = 0
cache.start_session()
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return "foo"
# cache is empty
self.assertIsNone(cache.get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), "foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), "foo")
# function does not run again but returns cached value
self.assertEqual(func(), "foo")
self.assertEqual(run_count, 1)
@cache.stored_async(KEY)
async def async_func():
nonlocal run_count
run_count += 1
return "bar"
# cache is still full
self.assertEqual(await_result(async_func()), "foo")
self.assertEqual(run_count, 1)
cache.start_session()
self.assertEqual(await_result(async_func()), "bar")
self.assertEqual(run_count, 2)
# awaitable is also run only once
self.assertEqual(await_result(async_func()), "bar")
self.assertEqual(run_count, 2)
@mock_storage
def test_Initialize(self):
def call_Initialize(**kwargs):