From 52c34c73648072c2fc061f33e2a4d37ca3a6973c Mon Sep 17 00:00:00 2001 From: matejcik Date: Thu, 10 Jun 2021 16:13:24 +0200 Subject: [PATCH] fix(core): allow caching empty values (fixes #1659) --- core/.changelog.d/1659.fixed | 1 + core/src/apps/common/authorization.py | 8 ++-- core/src/apps/common/request_pin.py | 2 +- core/src/apps/common/safety_checks.py | 4 +- core/src/storage/cache.py | 63 ++++++++++++++++++++++----- core/tests/test_storage.cache.py | 54 ++++++++++++++++++++--- 6 files changed, 108 insertions(+), 24 deletions(-) create mode 100644 core/.changelog.d/1659.fixed diff --git a/core/.changelog.d/1659.fixed b/core/.changelog.d/1659.fixed new file mode 100644 index 000000000..38274f474 --- /dev/null +++ b/core/.changelog.d/1659.fixed @@ -0,0 +1 @@ +Empty passphrase is properly cached in Cardano functions diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index 715e862aa..4bba8edc9 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -35,13 +35,13 @@ def get() -> protobuf.MessageType | None: return None msg_wire_type = int.from_bytes(stored_auth_type, "big") - buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA) + buffer = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"") return protobuf.load_message_buffer(buffer, msg_wire_type) def get_wire_types() -> Iterable[int]: stored_auth_type = storage.cache.get(storage.cache.APP_COMMON_AUTHORIZATION_TYPE) - if not stored_auth_type: + if stored_auth_type is None: return () msg_wire_type = int.from_bytes(stored_auth_type, "big") @@ -49,5 +49,5 @@ def get_wire_types() -> Iterable[int]: def clear() -> None: - storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_TYPE, b"") - storage.cache.set(storage.cache.APP_COMMON_AUTHORIZATION_DATA, b"") + storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_TYPE) + storage.cache.delete(storage.cache.APP_COMMON_AUTHORIZATION_DATA) diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index 0f08c631a..e01b8723d 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -88,7 +88,7 @@ def _set_last_unlock_time() -> None: def _get_last_unlock_time() -> int: return int.from_bytes( - storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK), "big" + storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" ) diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index 6cfae81ac..be8ecef08 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -27,10 +27,10 @@ def apply_setting(level: SafetyCheckLevel) -> None: Changes the safety level settings. """ if level == SafetyCheckLevel.Strict: - storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"") + storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) elif level == SafetyCheckLevel.PromptAlways: - storage.cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, b"") + storage.cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage.device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 3ea73d2d1..3be662743 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -4,7 +4,15 @@ from trezorcrypto import random # avoid pulling in trezor.crypto from trezor import utils if False: - from typing import Sequence + from typing import Sequence, TypeVar, overload + + T = TypeVar("T") + +else: + + def overload(f) -> None: # type: ignore + pass + _MAX_SESSIONS_COUNT = 10 _SESSIONLESS_FLAG = 128 @@ -43,20 +51,35 @@ class DataCache: fields: Sequence[int] def __init__(self) -> None: - self.data = [bytearray(f) for f in self.fields] + self.data = [bytearray(f + 1) for f in self.fields] def set(self, key: int, value: bytes) -> None: utils.ensure(key < len(self.fields)) utils.ensure(len(value) <= self.fields[key]) - self.data[key][:] = value + self.data[key][0] = 1 + self.data[key][1:] = value + + @overload + def get(self, key: int) -> bytes | None: + ... - def get(self, key: int) -> bytes: + @overload + def get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def get(self, key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 utils.ensure(key < len(self.fields), "failed to load key %d" % key) - return bytes(self.data[key]) + if self.data[key][0] != 1: + return default + return bytes(self.data[key][1:]) + + def delete(self, key: int) -> None: + utils.ensure(key < len(self.fields)) + self.data[key][:] = b"\x00" def clear(self) -> None: for i in range(len(self.fields)): - self.set(i, b"") + self.delete(i) class SessionCache(DataCache): @@ -184,12 +207,30 @@ def set(key: int, value: bytes) -> None: _SESSIONS[_active_session_idx].set(key, value) -def get(key: int) -> bytes: +@overload +def get(key: int) -> bytes | None: + ... + + +@overload +def get(key: int, default: T) -> bytes | T: # noqa: F811 + ... + + +def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + if key & _SESSIONLESS_FLAG: + return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG, default) + if _active_session_idx is None: + raise InvalidSessionError + return _SESSIONS[_active_session_idx].get(key, default) + + +def delete(key: int) -> None: if key & _SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ _SESSIONLESS_FLAG) + return _SESSIONLESS_CACHE.delete(key ^ _SESSIONLESS_FLAG) if _active_session_idx is None: raise InvalidSessionError - return _SESSIONS[_active_session_idx].get(key) + return _SESSIONS[_active_session_idx].delete(key) if False: @@ -208,7 +249,7 @@ def stored(key: int) -> Callable[[ByteFunc], ByteFunc]: def wrapper(*args, **kwargs): # type: ignore value = get(key) - if not value: + if value is None: value = func(*args, **kwargs) set(key, value) return value @@ -227,7 +268,7 @@ def stored_async(key: int) -> Callable[[AsyncByteFunc], AsyncByteFunc]: async def wrapper(*args, **kwargs): # type: ignore value = get(key) - if not value: + if value is None: value = await func(*args, **kwargs) set(key, value) return value diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index cc2bfcf3e..5337685f8 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -43,7 +43,7 @@ class TestStorageCache(unittest.TestCase): # original session no longer exists self.assertNotEqual(session_id_a, session_id) # original session data no longer exists - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(cache.get(KEY)) # create a new session session_id_b = cache.start_session() @@ -63,7 +63,7 @@ class TestStorageCache(unittest.TestCase): for i in range(cache._MAX_SESSIONS_COUNT): cache.start_session() self.assertNotEqual(cache.start_session(session_id), session_id) - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(cache.get(KEY)) def test_get_set(self): session_id1 = cache.start_session() @@ -83,6 +83,25 @@ class TestStorageCache(unittest.TestCase): with self.assertRaises(cache.InvalidSessionError): cache.get(KEY) + def test_delete(self): + session_id1 = cache.start_session() + self.assertIsNone(cache.get(KEY)) + cache.set(KEY, b"hello") + self.assertEqual(cache.get(KEY), b"hello") + cache.delete(KEY) + self.assertIsNone(cache.get(KEY)) + + cache.set(KEY, b"hello") + session_id2 = cache.start_session() + self.assertIsNone(cache.get(KEY)) + cache.set(KEY, b"hello") + self.assertEqual(cache.get(KEY), b"hello") + cache.delete(KEY) + self.assertIsNone(cache.get(KEY)) + + cache.start_session(session_id1) + self.assertEqual(cache.get(KEY), b"hello") + def test_decorator_mismatch(self): with self.assertRaises(AssertionError): @@ -101,7 +120,7 @@ class TestStorageCache(unittest.TestCase): return b"foo" # cache is empty - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(cache.get(KEY)) self.assertEqual(run_count, 0) self.assertEqual(func(), b"foo") # function was run @@ -128,6 +147,29 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(await_result(async_func()), b"bar") self.assertEqual(run_count, 2) + def test_empty_value(self): + cache.start_session() + + self.assertIsNone(cache.get(KEY)) + cache.set(KEY, b"") + self.assertEqual(cache.get(KEY), b"") + + cache.delete(KEY) + run_count = 0 + + @cache.stored(KEY) + def func(): + nonlocal run_count + run_count += 1 + return b"" + + self.assertEqual(func(), b"") + # function gets called once + self.assertEqual(run_count, 1) + self.assertEqual(func(), b"") + # function is not called for a second time + self.assertEqual(run_count, 1) + @mock_storage def test_Initialize(self): def call_Initialize(**kwargs): @@ -148,14 +190,14 @@ class TestStorageCache(unittest.TestCase): # check that it is cleared features = call_Initialize() session_id = features.session_id - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(cache.get(KEY)) # store "hello" again cache.set(KEY, b"hello") self.assertEqual(cache.get(KEY), b"hello") # supplying a different session ID starts a new cache call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(cache.get(KEY)) # but resuming a session loads the previous one call_Initialize(session_id=session_id) @@ -165,7 +207,7 @@ class TestStorageCache(unittest.TestCase): self.assertRaises(cache.InvalidSessionError, cache.get, KEY) session_id = cache.start_session() self.assertTrue(cache.is_session_started()) - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(cache.get(KEY)) await_result(handle_EndSession(DUMMY_CONTEXT, EndSession())) self.assertFalse(cache.is_session_started()) self.assertRaises(cache.InvalidSessionError, cache.get, KEY)