mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-10 23:40:58 +00:00
fix(core): allow caching empty values (fixes #1659)
This commit is contained in:
parent
74cf309a93
commit
52c34c7364
1
core/.changelog.d/1659.fixed
Normal file
1
core/.changelog.d/1659.fixed
Normal file
@ -0,0 +1 @@
|
||||
Empty passphrase is properly cached in Cardano functions
|
@ -35,13 +35,13 @@ def get() -> protobuf.MessageType | None:
|
||||
return None
|
||||
|
||||
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
||||
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
||||
buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||
return protobuf.load_message_buffer(buffer, msg_wire_type)
|
||||
|
||||
|
||||
def get_wire_types() -> Iterable[int]:
|
||||
stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
||||
if not stored_auth_type:
|
||||
if stored_auth_type is None:
|
||||
return ()
|
||||
|
||||
msg_wire_type = int.from_bytes(stored_auth_type, "big")
|
||||
@ -49,5 +49,5 @@ def get_wire_types() -> Iterable[int]:
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_TYPE, b"")
|
||||
storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"")
|
||||
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE)
|
||||
storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA)
|
||||
|
@ -88,7 +88,7 @@ def _set_last_unlock_time() -> None:
|
||||
|
||||
def _get_last_unlock_time() -> int:
|
||||
return int.from_bytes(
|
||||
storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK), "big"
|
||||
storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
|
||||
)
|
||||
|
||||
|
||||
|
@ -27,10 +27,10 @@ def apply_setting(level: SafetyCheckLevel) -> None:
|
||||
Changes the safety level settings.
|
||||
"""
|
||||
if level == SafetyCheckLevel.Strict:
|
||||
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
|
||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
elif level == SafetyCheckLevel.PromptAlways:
|
||||
storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"")
|
||||
storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
|
||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
|
||||
elif level == SafetyCheckLevel.PromptTemporarily:
|
||||
storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
|
||||
|
@ -4,7 +4,15 @@ from trezorcrypto import random # avoid pulling in trezor.crypto
|
||||
from trezor import utils
|
||||
|
||||
if False:
|
||||
from typing import Sequence
|
||||
from typing import Sequence, TypeVar, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
else:
|
||||
|
||||
def overload(f) -> None: # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
_MAX_SESSIONS_COUNT = 10
|
||||
_SESSIONLESS_FLAG = 128
|
||||
@ -43,20 +51,35 @@ class DataCache:
|
||||
fields: Sequence[int]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.data = [bytearray(f) for f in self.fields]
|
||||
self.data = [bytearray(f + 1) for f in self.fields]
|
||||
|
||||
def set(self, key: int, value: bytes) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
utils.ensure(len(value) <= self.fields[key])
|
||||
self.data[key][:] = value
|
||||
self.data[key][0] = 1
|
||||
self.data[key][1:] = value
|
||||
|
||||
def get(self, key: int) -> bytes:
|
||||
@overload
|
||||
def get(self, key: int) -> bytes | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def get(self, key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
utils.ensure(key < len(self.fields), "failed to load key %d" % key)
|
||||
return bytes(self.data[key])
|
||||
if self.data[key][0] != 1:
|
||||
return default
|
||||
return bytes(self.data[key][1:])
|
||||
|
||||
def delete(self, key: int) -> None:
|
||||
utils.ensure(key < len(self.fields))
|
||||
self.data[key][:] = b"\x00"
|
||||
|
||||
def clear(self) -> None:
|
||||
for i in range(len(self.fields)):
|
||||
self.set(i, b"")
|
||||
self.delete(i)
|
||||
|
||||
|
||||
class SessionCache(DataCache):
|
||||
@ -184,12 +207,30 @@ def set(key: int, value: bytes) -> None:
|
||||
_SESSIONS[_active_session_idx].set(key, value)
|
||||
|
||||
|
||||
def get(key: int) -> bytes:
|
||||
@overload
|
||||
def get(key: int) -> bytes | None:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def get(key: int, default: T) -> bytes | T: # noqa: F811
|
||||
...
|
||||
|
||||
|
||||
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG)
|
||||
return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].get(key)
|
||||
return _SESSIONS[_active_session_idx].get(key, default)
|
||||
|
||||
|
||||
def delete(key: int) -> None:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG)
|
||||
if _active_session_idx is None:
|
||||
raise InvalidSessionError
|
||||
return _SESSIONS[_active_session_idx].delete(key)
|
||||
|
||||
|
||||
if False:
|
||||
@ -208,7 +249,7 @@ def stored(key: int) -> Callable[[ByteFunc], ByteFunc]:
|
||||
|
||||
def wrapper(*args, **kwargs): # type: ignore
|
||||
value = get(key)
|
||||
if not value:
|
||||
if value is None:
|
||||
value = func(*args, **kwargs)
|
||||
set(key, value)
|
||||
return value
|
||||
@ -227,7 +268,7 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]:
|
||||
|
||||
async def wrapper(*args, **kwargs): # type: ignore
|
||||
value = get(key)
|
||||
if not value:
|
||||
if value is None:
|
||||
value = await func(*args, **kwargs)
|
||||
set(key, value)
|
||||
return value
|
||||
|
@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
# original session no longer exists
|
||||
self.assertNotEqual(session_id_a, session_id)
|
||||
# original session data no longer exists
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
# create a new session
|
||||
session_id_b = cache.start_session()
|
||||
@ -63,7 +63,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
for i in range(cache._MAX_SESSIONS_COUNT):
|
||||
cache.start_session()
|
||||
self.assertNotEqual(cache.start_session(session_id), session_id)
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
def test_get_set(self):
|
||||
session_id1 = cache.start_session()
|
||||
@ -83,6 +83,25 @@ class TestStorageCache(unittest.TestCase):
|
||||
with self.assertRaises(cache.InvalidSessionError):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_delete(self):
|
||||
session_id1 = cache.start_session()
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
cache.delete(KEY)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
cache.set(KEY, b"hello")
|
||||
session_id2 = cache.start_session()
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
cache.delete(KEY)
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
cache.start_session(session_id1)
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
|
||||
def test_decorator_mismatch(self):
|
||||
with self.assertRaises(AssertionError):
|
||||
|
||||
@ -101,7 +120,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
return b"foo"
|
||||
|
||||
# cache is empty
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
self.assertEqual(run_count, 0)
|
||||
self.assertEqual(func(), b"foo")
|
||||
# function was run
|
||||
@ -128,6 +147,29 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(await_result(async_func()), b"bar")
|
||||
self.assertEqual(run_count, 2)
|
||||
|
||||
def test_empty_value(self):
|
||||
cache.start_session()
|
||||
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
cache.set(KEY, b"")
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
|
||||
cache.delete(KEY)
|
||||
run_count = 0
|
||||
|
||||
@cache.stored(KEY)
|
||||
def func():
|
||||
nonlocal run_count
|
||||
run_count += 1
|
||||
return b""
|
||||
|
||||
self.assertEqual(func(), b"")
|
||||
# function gets called once
|
||||
self.assertEqual(run_count, 1)
|
||||
self.assertEqual(func(), b"")
|
||||
# function is not called for a second time
|
||||
self.assertEqual(run_count, 1)
|
||||
|
||||
@mock_storage
|
||||
def test_Initialize(self):
|
||||
def call_Initialize(**kwargs):
|
||||
@ -148,14 +190,14 @@ class TestStorageCache(unittest.TestCase):
|
||||
# check that it is cleared
|
||||
features = call_Initialize()
|
||||
session_id = features.session_id
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
# store "hello" again
|
||||
cache.set(KEY, b"hello")
|
||||
self.assertEqual(cache.get(KEY), b"hello")
|
||||
|
||||
# supplying a different session ID starts a new cache
|
||||
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
# but resuming a session loads the previous one
|
||||
call_Initialize(session_id=session_id)
|
||||
@ -165,7 +207,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
session_id = cache.start_session()
|
||||
self.assertTrue(cache.is_session_started())
|
||||
self.assertEqual(cache.get(KEY), b"")
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
||||
self.assertFalse(cache.is_session_started())
|
||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
||||
|
Loading…
Reference in New Issue
Block a user