mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-26 17:38:39 +00:00
core: introduce caching decorators
This commit is contained in:
parent
63dfcb17a7
commit
8c4cb58098
@ -21,7 +21,9 @@ _session_ids = [] # type: List[bytes]
|
|||||||
_sessionless_cache = {} # type: Dict[int, Any]
|
_sessionless_cache = {} # type: Dict[int, Any]
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any
|
from typing import Any, Callable, TypeVar
|
||||||
|
|
||||||
|
F = TypeVar("F", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
|
||||||
def _move_session_ids_queue(session_id: bytes) -> None:
|
def _move_session_ids_queue(session_id: bytes) -> None:
|
||||||
@ -70,6 +72,43 @@ def get(key: int) -> Any:
|
|||||||
return _caches[_active_session_id].get(key)
|
return _caches[_active_session_id].get(key)
|
||||||
|
|
||||||
|
|
||||||
|
def stored(key: int) -> Callable[[F], F]:
|
||||||
|
def decorator(func: F) -> F:
|
||||||
|
# if we didn't check this, it would be easy to store an Awaitable[something]
|
||||||
|
# in cache, which might prove hard to debug
|
||||||
|
assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs): # type: ignore
|
||||||
|
value = get(key)
|
||||||
|
if value is None:
|
||||||
|
value = func(*args, **kwargs)
|
||||||
|
set(key, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def stored_async(key: int) -> Callable[[F], F]:
|
||||||
|
def decorator(func: F) -> F:
|
||||||
|
# assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
|
||||||
|
# XXX the test above fails for closures
|
||||||
|
# We shouldn't need this test here anyway: the 'await func()' should fail
|
||||||
|
# with functions that do not return an awaitable so the problem is more visible.
|
||||||
|
|
||||||
|
async def wrapper(*args, **kwargs): # type: ignore
|
||||||
|
value = get(key)
|
||||||
|
if value is None:
|
||||||
|
value = await func(*args, **kwargs)
|
||||||
|
set(key, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
return wrapper # type: ignore
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def clear_all() -> None:
|
def clear_all() -> None:
|
||||||
global _active_session_id
|
global _active_session_id
|
||||||
global _caches
|
global _caches
|
||||||
|
@ -50,6 +50,51 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
|
def test_decorator_mismatch(self):
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
|
||||||
|
@cache.stored(KEY)
|
||||||
|
async def async_fun():
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_decorators(self):
|
||||||
|
run_count = 0
|
||||||
|
cache.start_session()
|
||||||
|
|
||||||
|
@cache.stored(KEY)
|
||||||
|
def func():
|
||||||
|
nonlocal run_count
|
||||||
|
run_count += 1
|
||||||
|
return "foo"
|
||||||
|
|
||||||
|
# cache is empty
|
||||||
|
self.assertIsNone(cache.get(KEY))
|
||||||
|
self.assertEqual(run_count, 0)
|
||||||
|
self.assertEqual(func(), "foo")
|
||||||
|
# function was run
|
||||||
|
self.assertEqual(run_count, 1)
|
||||||
|
self.assertEqual(cache.get(KEY), "foo")
|
||||||
|
# function does not run again but returns cached value
|
||||||
|
self.assertEqual(func(), "foo")
|
||||||
|
self.assertEqual(run_count, 1)
|
||||||
|
|
||||||
|
@cache.stored_async(KEY)
|
||||||
|
async def async_func():
|
||||||
|
nonlocal run_count
|
||||||
|
run_count += 1
|
||||||
|
return "bar"
|
||||||
|
|
||||||
|
# cache is still full
|
||||||
|
self.assertEqual(await_result(async_func()), "foo")
|
||||||
|
self.assertEqual(run_count, 1)
|
||||||
|
|
||||||
|
cache.start_session()
|
||||||
|
self.assertEqual(await_result(async_func()), "bar")
|
||||||
|
self.assertEqual(run_count, 2)
|
||||||
|
# awaitable is also run only once
|
||||||
|
self.assertEqual(await_result(async_func()), "bar")
|
||||||
|
self.assertEqual(run_count, 2)
|
||||||
|
|
||||||
@mock_storage
|
@mock_storage
|
||||||
def test_Initialize(self):
|
def test_Initialize(self):
|
||||||
def call_Initialize(**kwargs):
|
def call_Initialize(**kwargs):
|
||||||
|
Loading…
Reference in New Issue
Block a user