mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-25 14:50:57 +00:00
refactor(core): modify cache to preallocate all its data
also get rid of expensive "wire" import
This commit is contained in:
parent
ea505b592c
commit
3cdb09c294
@ -173,7 +173,7 @@ async def handle_CancelAuthorization(
|
|||||||
raise wire.ProcessError("No preauthorized operation")
|
raise wire.ProcessError("No preauthorized operation")
|
||||||
|
|
||||||
authorization.__del__()
|
authorization.__del__()
|
||||||
storage.cache.delete(storage.cache.APP_BASE_AUTHORIZATION)
|
storage.cache.set(storage.cache.APP_BASE_AUTHORIZATION, b"")
|
||||||
|
|
||||||
return Success(message="Authorization cancelled")
|
return Success(message="Authorization cancelled")
|
||||||
|
|
||||||
|
@ -12,11 +12,9 @@ def read_setting() -> EnumTypeSafetyCheckLevel:
|
|||||||
"""
|
"""
|
||||||
Returns the effective safety check level.
|
Returns the effective safety check level.
|
||||||
"""
|
"""
|
||||||
temporary_safety_check_level: EnumTypeSafetyCheckLevel | None = storage.cache.get(
|
temporary_safety_check_level = storage.cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
if temporary_safety_check_level:
|
||||||
)
|
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore
|
||||||
if temporary_safety_check_level is not None:
|
|
||||||
return temporary_safety_check_level
|
|
||||||
else:
|
else:
|
||||||
stored = storage.device.safety_check_level()
|
stored = storage.device.safety_check_level()
|
||||||
if stored == SAFETY_CHECK_LEVEL_STRICT:
|
if stored == SAFETY_CHECK_LEVEL_STRICT:
|
||||||
@ -32,14 +30,14 @@ def apply_setting(level: EnumTypeSafetyCheckLevel) -> None:
|
|||||||
Changes the safety level settings.
|
Changes the safety level settings.
|
||||||
"""
|
"""
|
||||||
if level == SafetyCheckLevel.Strict:
|
if level == SafetyCheckLevel.Strict:
|
||||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
|
||||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||||
elif level == SafetyCheckLevel.PromptAlways:
|
elif level == SafetyCheckLevel.PromptAlways:
|
||||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
|
||||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
||||||
elif level == SafetyCheckLevel.PromptTemporarily:
|
elif level == SafetyCheckLevel.PromptTemporarily:
|
||||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||||
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level)
|
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown SafetyCheckLevel")
|
raise ValueError("Unknown SafetyCheckLevel")
|
||||||
|
|
||||||
|
@ -51,7 +51,7 @@ async def _init_step(
|
|||||||
|
|
||||||
if not storage.cache.get(storage.cache.APP_MONERO_LIVE_REFRESH):
|
if not storage.cache.get(storage.cache.APP_MONERO_LIVE_REFRESH):
|
||||||
await confirms.require_confirm_live_refresh(ctx)
|
await confirms.require_confirm_live_refresh(ctx)
|
||||||
storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, True)
|
storage.cache.set(storage.cache.APP_MONERO_LIVE_REFRESH, b"\x01")
|
||||||
|
|
||||||
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)
|
||||||
|
|
||||||
|
@ -1,11 +1,14 @@
|
|||||||
from trezor import wire
|
import gc
|
||||||
|
|
||||||
|
from trezor import utils
|
||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any
|
from typing import Sequence
|
||||||
|
|
||||||
_MAX_SESSIONS_COUNT = 10
|
_MAX_SESSIONS_COUNT = 10
|
||||||
_SESSIONLESS_FLAG = 128
|
_SESSIONLESS_FLAG = 128
|
||||||
|
_SESSION_ID_LENGTH = 32
|
||||||
|
|
||||||
# Traditional cache keys
|
# Traditional cache keys
|
||||||
APP_COMMON_SEED = 0
|
APP_COMMON_SEED = 0
|
||||||
@ -14,100 +17,181 @@ APP_MONERO_LIVE_REFRESH = 2
|
|||||||
APP_BASE_AUTHORIZATION = 3
|
APP_BASE_AUTHORIZATION = 3
|
||||||
|
|
||||||
# Keys that are valid across sessions
|
# Keys that are valid across sessions
|
||||||
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 | _SESSIONLESS_FLAG
|
APP_COMMON_SEED_WITHOUT_PASSPHRASE = 0 | _SESSIONLESS_FLAG
|
||||||
APP_COMMON_SAFETY_CHECKS_TEMPORARY = 2 | _SESSIONLESS_FLAG
|
APP_COMMON_SAFETY_CHECKS_TEMPORARY = 1 | _SESSIONLESS_FLAG
|
||||||
|
|
||||||
|
|
||||||
_active_session_id: bytes | None = None
|
class InvalidSessionError(Exception):
|
||||||
_caches: dict[bytes, dict[int, Any]] = {}
|
pass
|
||||||
_session_ids: list[bytes] = []
|
|
||||||
_sessionless_cache: dict[int, Any] = {}
|
|
||||||
|
|
||||||
if False:
|
|
||||||
from typing import Any, Callable, TypeVar
|
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
|
||||||
|
|
||||||
|
|
||||||
def _move_session_ids_queue(session_id: bytes) -> None:
|
class DataCache:
|
||||||
# Move the LRU session ids queue.
|
fields: Sequence[int]
|
||||||
if session_id in _session_ids:
|
|
||||||
_session_ids.remove(session_id)
|
|
||||||
|
|
||||||
while len(_session_ids) >= _MAX_SESSIONS_COUNT:
|
def __init__(self) -> None:
|
||||||
remove_session_id = _session_ids.pop()
|
self.data = [bytearray(f) for f in self.fields]
|
||||||
del _caches[remove_session_id]
|
|
||||||
|
|
||||||
_session_ids.insert(0, session_id)
|
def set(self, key: int, value: bytes) -> None:
|
||||||
|
utils.ensure(key < len(self.fields))
|
||||||
|
utils.ensure(len(value) <= self.fields[key])
|
||||||
|
self.data[key][:] = value
|
||||||
|
|
||||||
|
def get(self, key: int) -> bytes:
|
||||||
|
utils.ensure(key < len(self.fields), "failed to load key %d" % key)
|
||||||
|
return bytes(self.data[key])
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
for i in range(len(self.fields)):
|
||||||
|
self.set(i, b"")
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCache(DataCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.session_id = bytearray(_SESSION_ID_LENGTH)
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED
|
||||||
|
128, # APP_CARDANO_ROOT
|
||||||
|
1, # APP_MONERO_LIVE_REFRESH
|
||||||
|
128, # APP_BASE_AUTHORIZATION
|
||||||
|
)
|
||||||
|
self.last_usage = 0
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def export_session_id(self) -> bytes:
|
||||||
|
# generate a new session id if we don't have it yet
|
||||||
|
if not self.session_id:
|
||||||
|
self.session_id[:] = random.bytes(_SESSION_ID_LENGTH)
|
||||||
|
# export it as immutable bytes
|
||||||
|
return bytes(self.session_id)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
super().clear()
|
||||||
|
self.last_usage = 0
|
||||||
|
self.session_id[:] = b""
|
||||||
|
|
||||||
|
|
||||||
|
class SessionlessCache(DataCache):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.fields = (
|
||||||
|
64, # APP_COMMON_SEED_WITHOUT_PASSPHRASE
|
||||||
|
1, # APP_COMMON_SAFETY_CHECKS_TEMPORARY
|
||||||
|
)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
# XXX
|
||||||
|
# Allocation notes:
|
||||||
|
# Instantiation of a DataCache subclass should make as little garbage as possible, so
|
||||||
|
# that the preallocated bytearrays are compact in memory.
|
||||||
|
# That is why the initialization is two-step: first create appropriately sized
|
||||||
|
# bytearrays, then later call `clear()` on all the existing objects, which resets them
|
||||||
|
# to zero length. This is producing some trash - `b[:]` allocates a slice.
|
||||||
|
|
||||||
|
_SESSIONS: list[SessionCache] = []
|
||||||
|
for _ in range(_MAX_SESSIONS_COUNT):
|
||||||
|
_SESSIONS.append(SessionCache())
|
||||||
|
|
||||||
|
_SESSIONLESS_CACHE = SessionlessCache()
|
||||||
|
|
||||||
|
for session in _SESSIONS:
|
||||||
|
session.clear()
|
||||||
|
_SESSIONLESS_CACHE.clear()
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
|
||||||
|
_active_session_idx: int | None = None
|
||||||
|
_session_usage_counter = 0
|
||||||
|
|
||||||
|
|
||||||
def start_session(received_session_id: bytes | None = None) -> bytes:
|
def start_session(received_session_id: bytes | None = None) -> bytes:
|
||||||
if received_session_id and received_session_id in _session_ids:
|
global _active_session_idx
|
||||||
session_id = received_session_id
|
global _session_usage_counter
|
||||||
else:
|
|
||||||
session_id = random.bytes(32)
|
|
||||||
_caches[session_id] = {}
|
|
||||||
|
|
||||||
global _active_session_id
|
if (
|
||||||
_active_session_id = session_id
|
received_session_id is not None
|
||||||
_move_session_ids_queue(session_id)
|
and len(received_session_id) != _SESSION_ID_LENGTH
|
||||||
return _active_session_id
|
):
|
||||||
|
# Prevent the caller from setting received_session_id=b"" and finding a cleared
|
||||||
|
# session. More generally, short-circuit the session id search, because we know
|
||||||
|
# that wrong-length session ids should not be in cache.
|
||||||
|
# Reduce to "session id not provided" case because that's what we do when
|
||||||
|
# caller supplies an id that is not found.
|
||||||
|
received_session_id = None
|
||||||
|
|
||||||
|
_session_usage_counter += 1
|
||||||
|
|
||||||
|
# attempt to find specified session id
|
||||||
|
if received_session_id:
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].session_id == received_session_id:
|
||||||
|
_active_session_idx = i
|
||||||
|
_SESSIONS[i].last_usage = _session_usage_counter
|
||||||
|
return received_session_id
|
||||||
|
|
||||||
|
# allocate least recently used session
|
||||||
|
lru_counter = _session_usage_counter
|
||||||
|
lru_session_idx = 0
|
||||||
|
for i in range(_MAX_SESSIONS_COUNT):
|
||||||
|
if _SESSIONS[i].last_usage < lru_counter:
|
||||||
|
lru_counter = _SESSIONS[i].last_usage
|
||||||
|
lru_session_idx = i
|
||||||
|
|
||||||
|
_active_session_idx = lru_session_idx
|
||||||
|
selected_session = _SESSIONS[lru_session_idx]
|
||||||
|
selected_session.clear()
|
||||||
|
selected_session.last_usage = _session_usage_counter
|
||||||
|
return selected_session.export_session_id()
|
||||||
|
|
||||||
|
|
||||||
def end_current_session() -> None:
|
def end_current_session() -> None:
|
||||||
global _active_session_id
|
global _active_session_idx
|
||||||
|
|
||||||
if _active_session_id is None:
|
if _active_session_idx is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
current_session_id = _active_session_id
|
_SESSIONS[_active_session_idx].clear()
|
||||||
_active_session_id = None
|
_active_session_idx = None
|
||||||
|
|
||||||
_session_ids.remove(current_session_id)
|
|
||||||
del _caches[current_session_id]
|
|
||||||
|
|
||||||
|
|
||||||
def is_session_started() -> bool:
|
def is_session_started() -> bool:
|
||||||
return _active_session_id is not None
|
return _active_session_idx is not None
|
||||||
|
|
||||||
|
|
||||||
def set(key: int, value: Any) -> None:
|
def set(key: int, value: bytes) -> None:
|
||||||
if key & _SESSIONLESS_FLAG:
|
if key & _SESSIONLESS_FLAG:
|
||||||
_sessionless_cache[key] = value
|
_SESSIONLESS_CACHE.set(key ^ _SESSIONLESS_FLAG, value)
|
||||||
return
|
return
|
||||||
if _active_session_id is None:
|
if _active_session_idx is None:
|
||||||
raise wire.InvalidSession
|
raise InvalidSessionError
|
||||||
_caches[_active_session_id][key] = value
|
_SESSIONS[_active_session_idx].set(key, value)
|
||||||
|
|
||||||
|
|
||||||
def get(key: int) -> Any:
|
def get(key: int) -> bytes:
|
||||||
if key & _SESSIONLESS_FLAG:
|
if key & _SESSIONLESS_FLAG:
|
||||||
return _sessionless_cache.get(key)
|
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG)
|
||||||
if _active_session_id is None:
|
if _active_session_idx is None:
|
||||||
raise wire.InvalidSession
|
raise InvalidSessionError
|
||||||
return _caches[_active_session_id].get(key)
|
return _SESSIONS[_active_session_idx].get(key)
|
||||||
|
|
||||||
|
|
||||||
def delete(key: int) -> None:
|
if False:
|
||||||
if key & _SESSIONLESS_FLAG:
|
from typing import Awaitable, Callable, TypeVar
|
||||||
if key in _sessionless_cache:
|
|
||||||
del _sessionless_cache[key]
|
ByteFunc = TypeVar("ByteFunc", bound=Callable[..., bytes])
|
||||||
return
|
AsyncByteFunc = TypeVar("AsyncByteFunc", bound=Callable[..., Awaitable[bytes]])
|
||||||
if _active_session_id is None:
|
|
||||||
raise wire.InvalidSession
|
|
||||||
if key in _caches[_active_session_id]:
|
|
||||||
del _caches[_active_session_id][key]
|
|
||||||
|
|
||||||
|
|
||||||
def stored(key: int) -> Callable[[F], F]:
|
def stored(key: int) -> Callable[[ByteFunc], ByteFunc]:
|
||||||
def decorator(func: F) -> F:
|
def decorator(func: ByteFunc) -> ByteFunc:
|
||||||
# if we didn't check this, it would be easy to store an Awaitable[something]
|
# if we didn't check this, it would be easy to store an Awaitable[something]
|
||||||
# in cache, which might prove hard to debug
|
# in cache, which might prove hard to debug
|
||||||
|
# XXX mypy should be checking this now, but we don't have full coverage yet
|
||||||
assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"
|
assert not isinstance(func, type(lambda: (yield))), "use stored_async instead"
|
||||||
|
|
||||||
def wrapper(*args, **kwargs): # type: ignore
|
def wrapper(*args, **kwargs): # type: ignore
|
||||||
value = get(key)
|
value = get(key)
|
||||||
if value is None:
|
if not value:
|
||||||
value = func(*args, **kwargs)
|
value = func(*args, **kwargs)
|
||||||
set(key, value)
|
set(key, value)
|
||||||
return value
|
return value
|
||||||
@ -117,8 +201,8 @@ def stored(key: int) -> Callable[[F], F]:
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def stored_async(key: int) -> Callable[[F], F]:
|
def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]:
|
||||||
def decorator(func: F) -> F:
|
def decorator(func: AsyncByteFunc) -> AsyncByteFunc:
|
||||||
# assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
|
# assert isinstance(func, type(lambda: (yield))), "do not use stored_async"
|
||||||
# XXX the test above fails for closures
|
# XXX the test above fails for closures
|
||||||
# We shouldn't need this test here anyway: the 'await func()' should fail
|
# We shouldn't need this test here anyway: the 'await func()' should fail
|
||||||
@ -126,7 +210,7 @@ def stored_async(key: int) -> Callable[[F], F]:
|
|||||||
|
|
||||||
async def wrapper(*args, **kwargs): # type: ignore
|
async def wrapper(*args, **kwargs): # type: ignore
|
||||||
value = get(key)
|
value = get(key)
|
||||||
if value is None:
|
if not value:
|
||||||
value = await func(*args, **kwargs)
|
value = await func(*args, **kwargs)
|
||||||
set(key, value)
|
set(key, value)
|
||||||
return value
|
return value
|
||||||
@ -137,12 +221,9 @@ def stored_async(key: int) -> Callable[[F], F]:
|
|||||||
|
|
||||||
|
|
||||||
def clear_all() -> None:
|
def clear_all() -> None:
|
||||||
global _active_session_id
|
global _active_session_idx
|
||||||
global _caches
|
|
||||||
global _session_ids
|
|
||||||
global _sessionless_cache
|
|
||||||
|
|
||||||
_active_session_id = None
|
_active_session_idx = None
|
||||||
_caches.clear()
|
_SESSIONLESS_CACHE.clear()
|
||||||
_session_ids.clear()
|
for session in _SESSIONS:
|
||||||
_sessionless_cache.clear()
|
session.clear()
|
||||||
|
@ -36,6 +36,7 @@ reads the message's header. When the message type is known the first handler is
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import protobuf
|
import protobuf
|
||||||
|
from storage.cache import InvalidSessionError
|
||||||
from trezor import log, loop, messages, ui, utils, workflow
|
from trezor import log, loop, messages, ui, utils, workflow
|
||||||
from trezor.messages import FailureType
|
from trezor.messages import FailureType
|
||||||
from trezor.messages.Failure import Failure
|
from trezor.messages.Failure import Failure
|
||||||
@ -492,6 +493,8 @@ def failure(exc: BaseException) -> Failure:
|
|||||||
return Failure(code=exc.code, message=exc.message)
|
return Failure(code=exc.code, message=exc.message)
|
||||||
elif isinstance(exc, loop.TaskClosed):
|
elif isinstance(exc, loop.TaskClosed):
|
||||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||||
|
elif isinstance(exc, InvalidSessionError):
|
||||||
|
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||||
else:
|
else:
|
||||||
return Failure(code=FailureType.FirmwareError, message="Firmware error")
|
return Failure(code=FailureType.FirmwareError, message="Firmware error")
|
||||||
|
|
||||||
|
@ -4,11 +4,11 @@ from mock_storage import mock_storage
|
|||||||
from storage import cache
|
from storage import cache
|
||||||
from trezor.messages.Initialize import Initialize
|
from trezor.messages.Initialize import Initialize
|
||||||
from trezor.messages.EndSession import EndSession
|
from trezor.messages.EndSession import EndSession
|
||||||
from trezor.wire import DUMMY_CONTEXT, InvalidSession
|
from trezor.wire import DUMMY_CONTEXT
|
||||||
|
|
||||||
from apps.base import handle_Initialize, handle_EndSession
|
from apps.base import handle_Initialize, handle_EndSession
|
||||||
|
|
||||||
KEY = 99
|
KEY = 0
|
||||||
|
|
||||||
|
|
||||||
class TestStorageCache(unittest.TestCase):
|
class TestStorageCache(unittest.TestCase):
|
||||||
@ -22,18 +22,18 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertNotEqual(session_id_a, session_id_b)
|
self.assertNotEqual(session_id_a, session_id_b)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(InvalidSession):
|
with self.assertRaises(cache.InvalidSessionError):
|
||||||
cache.set(KEY, "something")
|
cache.set(KEY, "something")
|
||||||
with self.assertRaises(InvalidSession):
|
with self.assertRaises(cache.InvalidSessionError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_end_session(self):
|
def test_end_session(self):
|
||||||
session_id = cache.start_session()
|
session_id = cache.start_session()
|
||||||
self.assertTrue(cache.is_session_started())
|
self.assertTrue(cache.is_session_started())
|
||||||
cache.set(KEY, "A")
|
cache.set(KEY, b"A")
|
||||||
cache.end_current_session()
|
cache.end_current_session()
|
||||||
self.assertFalse(cache.is_session_started())
|
self.assertFalse(cache.is_session_started())
|
||||||
self.assertRaises(InvalidSession, cache.get, KEY)
|
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||||
|
|
||||||
# ending an ended session should be a no-op
|
# ending an ended session should be a no-op
|
||||||
cache.end_current_session()
|
cache.end_current_session()
|
||||||
@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
# original session no longer exists
|
# original session no longer exists
|
||||||
self.assertNotEqual(session_id_a, session_id)
|
self.assertNotEqual(session_id_a, session_id)
|
||||||
# original session data no longer exists
|
# original session data no longer exists
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertEqual(cache.get(KEY), b"")
|
||||||
|
|
||||||
# create a new session
|
# create a new session
|
||||||
session_id_b = cache.start_session()
|
session_id_b = cache.start_session()
|
||||||
@ -59,28 +59,28 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
def test_session_queue(self):
|
def test_session_queue(self):
|
||||||
session_id = cache.start_session()
|
session_id = cache.start_session()
|
||||||
self.assertEqual(cache.start_session(session_id), session_id)
|
self.assertEqual(cache.start_session(session_id), session_id)
|
||||||
cache.set(KEY, "A")
|
cache.set(KEY, b"A")
|
||||||
for i in range(cache._MAX_SESSIONS_COUNT):
|
for i in range(cache._MAX_SESSIONS_COUNT):
|
||||||
cache.start_session()
|
cache.start_session()
|
||||||
self.assertNotEqual(cache.start_session(session_id), session_id)
|
self.assertNotEqual(cache.start_session(session_id), session_id)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertEqual(cache.get(KEY), b"")
|
||||||
|
|
||||||
def test_get_set(self):
|
def test_get_set(self):
|
||||||
session_id1 = cache.start_session()
|
session_id1 = cache.start_session()
|
||||||
cache.set(KEY, "hello")
|
cache.set(KEY, b"hello")
|
||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
session_id2 = cache.start_session()
|
session_id2 = cache.start_session()
|
||||||
cache.set(KEY, "world")
|
cache.set(KEY, b"world")
|
||||||
self.assertEqual(cache.get(KEY), "world")
|
self.assertEqual(cache.get(KEY), b"world")
|
||||||
|
|
||||||
cache.start_session(session_id2)
|
cache.start_session(session_id2)
|
||||||
self.assertEqual(cache.get(KEY), "world")
|
self.assertEqual(cache.get(KEY), b"world")
|
||||||
cache.start_session(session_id1)
|
cache.start_session(session_id1)
|
||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(InvalidSession):
|
with self.assertRaises(cache.InvalidSessionError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_decorator_mismatch(self):
|
def test_decorator_mismatch(self):
|
||||||
@ -98,34 +98,34 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
def func():
|
def func():
|
||||||
nonlocal run_count
|
nonlocal run_count
|
||||||
run_count += 1
|
run_count += 1
|
||||||
return "foo"
|
return b"foo"
|
||||||
|
|
||||||
# cache is empty
|
# cache is empty
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertEqual(cache.get(KEY), b"")
|
||||||
self.assertEqual(run_count, 0)
|
self.assertEqual(run_count, 0)
|
||||||
self.assertEqual(func(), "foo")
|
self.assertEqual(func(), b"foo")
|
||||||
# function was run
|
# function was run
|
||||||
self.assertEqual(run_count, 1)
|
self.assertEqual(run_count, 1)
|
||||||
self.assertEqual(cache.get(KEY), "foo")
|
self.assertEqual(cache.get(KEY), b"foo")
|
||||||
# function does not run again but returns cached value
|
# function does not run again but returns cached value
|
||||||
self.assertEqual(func(), "foo")
|
self.assertEqual(func(), b"foo")
|
||||||
self.assertEqual(run_count, 1)
|
self.assertEqual(run_count, 1)
|
||||||
|
|
||||||
@cache.stored_async(KEY)
|
@cache.stored_async(KEY)
|
||||||
async def async_func():
|
async def async_func():
|
||||||
nonlocal run_count
|
nonlocal run_count
|
||||||
run_count += 1
|
run_count += 1
|
||||||
return "bar"
|
return b"bar"
|
||||||
|
|
||||||
# cache is still full
|
# cache is still full
|
||||||
self.assertEqual(await_result(async_func()), "foo")
|
self.assertEqual(await_result(async_func()), b"foo")
|
||||||
self.assertEqual(run_count, 1)
|
self.assertEqual(run_count, 1)
|
||||||
|
|
||||||
cache.start_session()
|
cache.start_session()
|
||||||
self.assertEqual(await_result(async_func()), "bar")
|
self.assertEqual(await_result(async_func()), b"bar")
|
||||||
self.assertEqual(run_count, 2)
|
self.assertEqual(run_count, 2)
|
||||||
# awaitable is also run only once
|
# awaitable is also run only once
|
||||||
self.assertEqual(await_result(async_func()), "bar")
|
self.assertEqual(await_result(async_func()), b"bar")
|
||||||
self.assertEqual(run_count, 2)
|
self.assertEqual(run_count, 2)
|
||||||
|
|
||||||
@mock_storage
|
@mock_storage
|
||||||
@ -144,31 +144,31 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(session_id, features.session_id)
|
self.assertEqual(session_id, features.session_id)
|
||||||
|
|
||||||
# store "hello"
|
# store "hello"
|
||||||
cache.set(KEY, "hello")
|
cache.set(KEY, b"hello")
|
||||||
# check that it is cleared
|
# check that it is cleared
|
||||||
features = call_Initialize()
|
features = call_Initialize()
|
||||||
session_id = features.session_id
|
session_id = features.session_id
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertEqual(cache.get(KEY), b"")
|
||||||
# store "hello" again
|
# store "hello" again
|
||||||
cache.set(KEY, "hello")
|
cache.set(KEY, b"hello")
|
||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
# supplying a different session ID starts a new cache
|
# supplying a different session ID starts a new cache
|
||||||
call_Initialize(session_id=b"A")
|
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertEqual(cache.get(KEY), b"")
|
||||||
|
|
||||||
# but resuming a session loads the previous one
|
# but resuming a session loads the previous one
|
||||||
call_Initialize(session_id=session_id)
|
call_Initialize(session_id=session_id)
|
||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), b"hello")
|
||||||
|
|
||||||
def test_EndSession(self):
|
def test_EndSession(self):
|
||||||
self.assertRaises(InvalidSession, cache.get, KEY)
|
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||||
session_id = cache.start_session()
|
session_id = cache.start_session()
|
||||||
self.assertTrue(cache.is_session_started())
|
self.assertTrue(cache.is_session_started())
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertEqual(cache.get(KEY), b"")
|
||||||
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
||||||
self.assertFalse(cache.is_session_started())
|
self.assertFalse(cache.is_session_started())
|
||||||
self.assertRaises(InvalidSession, cache.get, KEY)
|
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user