mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-24 22:31:35 +00:00
97b2e6c751
[no changelog]
243 lines
8.4 KiB
Python
243 lines
8.4 KiB
Python
# flake8: noqa: F403,F405
|
|
from common import * # isort:skip
|
|
|
|
from mock_storage import mock_storage
|
|
from storage import cache, cache_codec, cache_common
|
|
from trezor.messages import EndSession, Initialize
|
|
from trezor.wire import context
|
|
from trezor.wire.codec.codec_context import CodecContext
|
|
|
|
from apps.base import handle_EndSession, handle_Initialize
|
|
from apps.common.cache import stored, stored_async
|
|
|
|
# Cache key shared by every test in this module: all cache_set / cache_get /
# cache_delete calls and both @stored decorators below operate on this key.
KEY = 0
|
|
|
|
|
|
# Function moved from cache.py, as it was not used there
|
|
def is_session_started() -> bool:
    """Report whether `cache_codec` currently has an active session.

    Returns:
        True when `cache_codec._active_session_idx` holds an index,
        False when it is None.
    """
    active_idx = cache_codec._active_session_idx
    return active_idx is not None
|
|
|
|
|
|
class TestStorageCache(unittest.TestCase):
    """Tests of the codec session cache, driven via `cache_codec` and `context`.

    All values are stored under the module-level `KEY`.
    """

    # NOTE(review): setUpClass / tearDownClass take `self` and are not
    # decorated with @classmethod; presumably the unittest implementation
    # pulled in through `common` calls them on an instance -- confirm before
    # changing.
    def setUpClass(self):
        """Install a CodecContext (no interface, 64-byte buffer) as current context."""
        codec_ctx = CodecContext(None, bytearray(64))
        context.CURRENT_CONTEXT = codec_ctx

    def tearDownClass(self):
        """Reset the current wire context back to None."""
        context.CURRENT_CONTEXT = None

    def setUp(self):
        """Clear the whole cache before each test."""
        cache.clear_all()

    def test_start_session(self):
        """Two started sessions get different IDs; clear_all leaves no usable session."""
        first_id = cache_codec.start_session()
        self.assertIsNotNone(first_id)
        second_id = cache_codec.start_session()
        self.assertNotEqual(first_id, second_id)

        # once everything is cleared, both writing and reading must be rejected
        cache.clear_all()
        self.assertRaises(
            cache_common.InvalidSessionError, context.cache_set, KEY, "something"
        )
        self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)

    def test_end_session(self):
        """Ending a session drops it and its data; other sessions stay resumable."""
        original_id = cache_codec.start_session()
        self.assertTrue(is_session_started())
        context.cache_set(KEY, b"A")
        cache_codec.end_current_session()
        self.assertFalse(is_session_started())
        with self.assertRaises(cache_common.InvalidSessionError):
            context.cache_get(KEY)

        # ending an ended session should be a no-op
        cache_codec.end_current_session()
        self.assertFalse(is_session_started())

        # asking for the ended session yields a different ID (it no longer
        # exists) and an empty cache (its data no longer exists)
        replacement_id = cache_codec.start_session(original_id)
        self.assertNotEqual(replacement_id, original_id)
        self.assertIsNone(context.cache_get(KEY))

        # open another session, then go back to the replacement one
        other_id = cache_codec.start_session()
        resumed_id = cache_codec.start_session(replacement_id)
        self.assertEqual(resumed_id, replacement_id)
        # end the replacement session; the other one must still be resumable
        cache_codec.end_current_session()
        resumed_id = cache_codec.start_session(other_id)
        self.assertEqual(resumed_id, other_id)

    def test_session_queue(self):
        """After _MAX_SESSIONS_COUNT newer sessions, the first one is gone."""
        oldest_id = cache_codec.start_session()
        resumed_id = cache_codec.start_session(oldest_id)
        self.assertEqual(resumed_id, oldest_id)
        context.cache_set(KEY, b"A")

        # start _MAX_SESSIONS_COUNT further sessions
        remaining = cache_codec._MAX_SESSIONS_COUNT
        while remaining > 0:
            cache_codec.start_session()
            remaining -= 1

        # the first session can no longer be resumed and its data is not visible
        resumed_id = cache_codec.start_session(oldest_id)
        self.assertNotEqual(resumed_id, oldest_id)
        self.assertIsNone(context.cache_get(KEY))

    def test_get_set(self):
        """Byte values are stored per session and survive switching sessions."""
        first_id = cache_codec.start_session()
        context.cache_set(KEY, b"hello")
        self.assertEqual(context.cache_get(KEY), b"hello")

        second_id = cache_codec.start_session()
        context.cache_set(KEY, b"world")
        self.assertEqual(context.cache_get(KEY), b"world")

        # resume each session in turn and check it still sees its own value
        for resume_id, expected in ((second_id, b"world"), (first_id, b"hello")):
            cache_codec.start_session(resume_id)
            self.assertEqual(context.cache_get(KEY), expected)

        cache.clear_all()
        self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)

    def test_get_set_int(self):
        """Integer values are stored per session and survive switching sessions."""
        first_id = cache_codec.start_session()
        context.cache_set_int(KEY, 1234)
        self.assertEqual(context.cache_get_int(KEY), 1234)

        second_id = cache_codec.start_session()
        context.cache_set_int(KEY, 5678)
        self.assertEqual(context.cache_get_int(KEY), 5678)

        # resume each session in turn and check it still sees its own value
        for resume_id, expected in ((second_id, 5678), (first_id, 1234)):
            cache_codec.start_session(resume_id)
            self.assertEqual(context.cache_get_int(KEY), expected)

        cache.clear_all()
        self.assertRaises(
            cache_common.InvalidSessionError, context.cache_get_int, KEY
        )

    def test_delete(self):
        """cache_delete removes a value only from the session it is called in."""
        first_id = cache_codec.start_session()
        self.assertIsNone(context.cache_get(KEY))
        context.cache_set(KEY, b"hello")
        self.assertEqual(context.cache_get(KEY), b"hello")
        context.cache_delete(KEY)
        self.assertIsNone(context.cache_get(KEY))

        # leave a value behind in the first session, then move to a fresh one
        context.cache_set(KEY, b"hello")
        cache_codec.start_session()
        self.assertIsNone(context.cache_get(KEY))
        context.cache_set(KEY, b"hello")
        self.assertEqual(context.cache_get(KEY), b"hello")
        context.cache_delete(KEY)
        self.assertIsNone(context.cache_get(KEY))

        # deleting in the second session must not have touched the first one
        cache_codec.start_session(first_id)
        self.assertEqual(context.cache_get(KEY), b"hello")

    def test_decorators(self):
        """@stored / @stored_async run the wrapped function once per session."""
        calls = 0
        cache_codec.start_session()

        @stored(KEY)
        def produce_foo():
            nonlocal calls
            calls += 1
            return b"foo"

        # cache is empty
        self.assertIsNone(context.cache_get(KEY))
        self.assertEqual(calls, 0)
        first_result = produce_foo()
        self.assertEqual(first_result, b"foo")
        # function was run
        self.assertEqual(calls, 1)
        self.assertEqual(context.cache_get(KEY), b"foo")
        # function does not run again but returns cached value
        second_result = produce_foo()
        self.assertEqual(second_result, b"foo")
        self.assertEqual(calls, 1)

        @stored_async(KEY)
        async def produce_bar():
            nonlocal calls
            calls += 1
            return b"bar"

        # cache is still full, so the value cached above is returned
        self.assertEqual(await_result(produce_bar()), b"foo")
        self.assertEqual(calls, 1)

        # a fresh session has nothing cached, so the coroutine runs
        cache_codec.start_session()
        self.assertEqual(await_result(produce_bar()), b"bar")
        self.assertEqual(calls, 2)
        # awaitable is also run only once
        self.assertEqual(await_result(produce_bar()), b"bar")
        self.assertEqual(calls, 2)

    def test_empty_value(self):
        """An empty bytes value counts as cached, distinct from a missing value."""
        cache_codec.start_session()

        self.assertIsNone(context.cache_get(KEY))
        context.cache_set(KEY, b"")
        self.assertEqual(context.cache_get(KEY), b"")

        context.cache_delete(KEY)
        calls = 0

        @stored(KEY)
        def produce_empty():
            nonlocal calls
            calls += 1
            return b""

        first_result = produce_empty()
        self.assertEqual(first_result, b"")
        # function gets called once
        self.assertEqual(calls, 1)
        second_result = produce_empty()
        self.assertEqual(second_result, b"")
        # function is not called for a second time
        self.assertEqual(calls, 1)

    @mock_storage
    def test_Initialize(self):
        """handle_Initialize allocates, resumes and isolates sessions by session_id."""

        def send_initialize(**fields):
            # build the Initialize message and run its handler to completion
            return await_result(handle_Initialize(Initialize(**fields)))

        # calling Initialize without an ID allocates a new one
        started_id = cache_codec.start_session()
        features = send_initialize()
        self.assertNotEqual(started_id, features.session_id)

        # calling Initialize with an existing ID does not allocate a new one
        features = send_initialize(session_id=started_id)
        self.assertEqual(started_id, features.session_id)

        # store "hello"
        context.cache_set(KEY, b"hello")
        # check that it is cleared
        features = send_initialize()
        fresh_id = features.session_id
        self.assertIsNone(context.cache_get(KEY))
        # store "hello" again
        context.cache_set(KEY, b"hello")
        self.assertEqual(context.cache_get(KEY), b"hello")

        # supplying a different session ID starts a new cache
        unknown_id = b"A" * cache_codec.SESSION_ID_LENGTH
        send_initialize(session_id=unknown_id)
        self.assertIsNone(context.cache_get(KEY))

        # but resuming a session loads the previous one
        send_initialize(session_id=fresh_id)
        self.assertEqual(context.cache_get(KEY), b"hello")

    def test_EndSession(self):
        """handle_EndSession ends the active session, making the cache inaccessible."""
        # no session yet: reading must be rejected
        with self.assertRaises(cache_common.InvalidSessionError):
            context.cache_get(KEY)

        cache_codec.start_session()
        self.assertTrue(is_session_started())
        self.assertIsNone(context.cache_get(KEY))

        await_result(handle_EndSession(EndSession()))
        self.assertFalse(is_session_started())
        with self.assertRaises(cache_common.InvalidSessionError):
            context.cache_get(KEY)
|
|
|
|
|
|
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|