From 5bbfd40df642ccfd616184797332938115d30f64 Mon Sep 17 00:00:00 2001 From: Andrew Kozlik Date: Tue, 4 Oct 2022 15:39:51 +0200 Subject: [PATCH] feat(core): Add set_int() and get_int() to storage cache. [no changelog] --- core/src/apps/base.py | 9 +++------ core/src/apps/common/request_pin.py | 8 ++------ core/src/storage/cache.py | 25 +++++++++++++++++++++++++ core/tests/test_storage.cache.py | 18 ++++++++++++++++++ 4 files changed, 48 insertions(+), 12 deletions(-) diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 56da6f80e7..77ad0e6ef8 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -29,13 +29,12 @@ def busy_expiry_ms() -> int: Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_bytes = storage.cache.get(storage.cache.APP_COMMON_BUSY_DEADLINE_MS) - if busy_deadline_bytes is None: + busy_deadline_ms = storage.cache.get_int(storage.cache.APP_COMMON_BUSY_DEADLINE_MS) + if busy_deadline_ms is None: return 0 import utime - busy_deadline_ms = int.from_bytes(busy_deadline_bytes, "big") expiry_ms = utime.ticks_diff(busy_deadline_ms, utime.ticks_ms()) return expiry_ms if expiry_ms > 0 else 0 @@ -171,9 +170,7 @@ async def handle_SetBusy(ctx: wire.Context, msg: SetBusy) -> Success: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - storage.cache.set( - storage.cache.APP_COMMON_BUSY_DEADLINE_MS, deadline.to_bytes(4, "big") - ) + storage.cache.set_int(storage.cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) else: storage.cache.delete(storage.cache.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index a4bd233aa1..3788eb3a70 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -58,15 +58,11 @@ async def request_pin_and_sd_salt( def _set_last_unlock_time() -> None: now = utime.ticks_ms() - storage.cache.set( - storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now.to_bytes(4, "big") - ) + storage.cache.set_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) def _get_last_unlock_time() -> int: - return int.from_bytes( - storage.cache.get(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" - ) + return storage.cache.get_int(storage.cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK) or 0 async def verify_user_pin( diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index a12e802fea..80509cc66a 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -227,6 +227,23 @@ def set(key: int, value: bytes) -> None: _SESSIONS[_active_session_idx].set(key, value) +def set_int(key: int, value: int) -> None: + if key & _SESSIONLESS_FLAG: + length = _SESSIONLESS_CACHE.fields[key ^ _SESSIONLESS_FLAG] + elif _active_session_idx is None: + raise InvalidSessionError + else: + length = _SESSIONS[_active_session_idx].fields[key] + + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + set(key, encoded) + + if TYPE_CHECKING: @overload @@ -246,6 +263,14 @@ def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 return _SESSIONS[_active_session_idx].get(key, default) +def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 + encoded = get(key) + if encoded is None: + return default + else: + return int.from_bytes(encoded, "big") + + def is_set(key: int) -> bool: if key & _SESSIONLESS_FLAG: return _SESSIONLESS_CACHE.is_set(key ^ _SESSIONLESS_FLAG) diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 95aa8d6824..a4e15b1feb 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -83,6 +83,24 @@ class TestStorageCache(unittest.TestCase): with self.assertRaises(cache.InvalidSessionError): cache.get(KEY) + def test_get_set_int(self): + session_id1 = cache.start_session() + cache.set_int(KEY, 1234) + self.assertEqual(cache.get_int(KEY), 1234) + + session_id2 = cache.start_session() + cache.set_int(KEY, 5678) + self.assertEqual(cache.get_int(KEY), 5678) + + cache.start_session(session_id2) + self.assertEqual(cache.get_int(KEY), 5678) + cache.start_session(session_id1) + self.assertEqual(cache.get_int(KEY), 1234) + + cache.clear_all() + with self.assertRaises(cache.InvalidSessionError): + cache.get_int(KEY) + def test_delete(self): session_id1 = cache.start_session() self.assertIsNone(cache.get(KEY))