1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 23:40:58 +00:00

fix(core): allow caching empty values (fixes #1659)

This commit is contained in:
matejcik 2021-06-10 16:13:24 +02:00 committed by matejcik
parent 74cf309a93
commit 52c34c7364
6 changed files with 108 additions and 24 deletions

View File

@ -0,0 +1 @@
Empty passphrase is properly cached in Cardano functions

View File

@ -35,13 +35,13 @@ def get() -> protobuf.MessageType | None:
return None
msg_wire_type = int.from_bytes(stored_auth_type, "big")
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, msg_wire_type)
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
if stored_auth_type is None:
return ()
msg_wire_type = int.from_bytes(stored_auth_type, "big")
@ -49,5 +49,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None:
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_TYPE, b"")
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA)

View File

@ -88,7 +88,7 @@ def _set_last_unlock_time() -> None:
def _get_last_unlock_time() -> int:
return int.from_bytes(
storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK), "big"
storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)

View File

@ -27,10 +27,10 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings.
"""
if level == SafetyCheckLevel.Strict:
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways:
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)

View File

@ -4,7 +4,15 @@ from trezorcrypto import random # avoid pulling in trezor.crypto
from trezor import utils
if False:
from typing import Sequence
from typing import Sequence, TypeVar, overload
T = TypeVar("T")
else:
def overload(f) -> None: # type: ignore
pass
_MAX_SESSIONS_COUNT = 10
_SESSIONLESS_FLAG = 128
@ -43,20 +51,35 @@ class DataCache:
fields: Sequence[int]
def __init__(self) -> None:
self.data = [bytearray(f) for f in self.fields]
self.data = [bytearray(f + 1) for f in self.fields]
def set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.fields))
utils.ensure(len(value) <= self.fields[key])
self.data[key][:] = value
self.data[key][0] = 1
self.data[key][1:] = value
def get(self, key: int) -> bytes:
@overload
def get(self, key: int) -> bytes | None:
...
@overload
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.fields), "failed to load key %d" % key)
return bytes(self.data[key])
if self.data[key][0] != 1:
return default
return bytes(self.data[key][1:])
def delete(self, key: int) -> None:
utils.ensure(key < len(self.fields))
self.data[key][:] = b"\x00"
def clear(self) -> None:
for i in range(len(self.fields)):
self.set(i, b"")
self.delete(i)
class SessionCache(DataCache):
@ -184,12 +207,30 @@ def set(key: int, value: bytes) -> None:
_SESSIONS[_active_session_idx].set(key, value)
def get(key: int) -> bytes:
@overload
def get(key: int) -> bytes | None:
...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG)
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].get(key)
return _SESSIONS[_active_session_idx].get(key, default)
def delete(key: int) -> None:
if key & _SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
if _active_session_idx is None:
raise InvalidSessionError
return _SESSIONS[_active_session_idx].delete(key)
if False:
@ -208,7 +249,7 @@ def stored(key: int) -> Callable[[ByteFunc], ByteFunc]:
def wrapper(*args, **kwargs): # type: ignore
value = get(key)
if not value:
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
@ -227,7 +268,7 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]:
async def wrapper(*args, **kwargs): # type: ignore
value = get(key)
if not value:
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value

View File

@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase):
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(cache.get(KEY))
# create a new session
session_id_b = cache.start_session()
@ -63,7 +63,7 @@ class TestStorageCache(unittest.TestCase):
for i in range(cache._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(cache.get(KEY))
def test_get_set(self):
session_id1 = cache.start_session()
@ -83,6 +83,25 @@ class TestStorageCache(unittest.TestCase):
with self.assertRaises(cache.InvalidSessionError):
cache.get(KEY)
def test_delete(self):
session_id1 = cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
session_id2 = cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
def test_decorator_mismatch(self):
with self.assertRaises(AssertionError):
@ -101,7 +120,7 @@ class TestStorageCache(unittest.TestCase):
return b"foo"
# cache is empty
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(cache.get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
@ -128,6 +147,29 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
def test_empty_value(self):
cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"")
self.assertEqual(cache.get(KEY), b"")
cache.delete(KEY)
run_count = 0
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
@mock_storage
def test_Initialize(self):
def call_Initialize(**kwargs):
@ -148,14 +190,14 @@ class TestStorageCache(unittest.TestCase):
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(cache.get(KEY))
# store "hello" again
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(cache.get(KEY))
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
@ -165,7 +207,7 @@ class TestStorageCache(unittest.TestCase):
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
self.assertFalse(cache.is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)