diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 4fe865436..37c5a3abc 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -16,6 +16,7 @@ if False: import protobuf from typing import Iterable, NoReturn, Optional, Protocol from trezor.messages.Initialize import Initialize + from trezor.messages.EndSession import EndSession from trezor.messages.GetFeatures import GetFeatures from trezor.messages.Cancel import Cancel from trezor.messages.LockDevice import LockDevice @@ -116,6 +117,11 @@ async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success: return Success() +async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success: + cache.end_current_session() + return Success() + + async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success: if msg.button_protection: from apps.common.confirm import require_confirm @@ -172,6 +178,7 @@ async def handle_CancelAuthorization( ALLOW_WHILE_LOCKED = ( MessageType.Initialize, + MessageType.EndSession, MessageType.GetFeatures, MessageType.Cancel, MessageType.LockDevice, @@ -249,6 +256,7 @@ def boot() -> None: wire.register(MessageType.GetFeatures, handle_GetFeatures) wire.register(MessageType.Cancel, handle_Cancel) wire.register(MessageType.LockDevice, handle_LockDevice) + wire.register(MessageType.EndSession, handle_EndSession) wire.register(MessageType.Ping, handle_Ping) wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized) wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 7726983d5..de752cb3d 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,3 +1,4 @@ +from trezor import wire from trezor.crypto import random if False: @@ -52,6 +53,19 @@ def start_session(received_session_id: bytes = None) -> bytes: return _active_session_id +def end_current_session() -> None: + global _active_session_id + + if _active_session_id is None: + return + + current_session_id = _active_session_id + _active_session_id = None + + _session_ids.remove(current_session_id) + del _caches[current_session_id] + + def is_session_started() -> bool: return _active_session_id is not None @@ -61,7 +75,7 @@ def set(key: int, value: Any) -> None: _sessionless_cache[key] = value return if _active_session_id is None: - raise RuntimeError # no session active + raise wire.ProcessError("Invalid session") _caches[_active_session_id][key] = value @@ -69,7 +83,7 @@ def get(key: int) -> Any: if key & _SESSIONLESS_FLAG: return _sessionless_cache.get(key) if _active_session_id is None: - raise RuntimeError # no session active + raise wire.ProcessError("Invalid session") return _caches[_active_session_id].get(key) @@ -79,7 +93,7 @@ def delete(key: int) -> None: del _sessionless_cache[key] return if _active_session_id is None: - raise RuntimeError # no session active + raise wire.ProcessError("Invalid session") if key in _caches[_active_session_id]: del _caches[_active_session_id][key] diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 457917409..3cb0f9fbd 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -3,14 +3,18 @@ from mock_storage import mock_storage from storage import cache from trezor.messages.Initialize import Initialize -from trezor.wire import DUMMY_CONTEXT +from trezor.messages.EndSession import EndSession +from trezor.wire import DUMMY_CONTEXT, ProcessError -from apps.base import handle_Initialize +from apps.base import handle_Initialize, handle_EndSession KEY = 99 class TestStorageCache(unittest.TestCase): + def setUp(self): + cache.clear_all() + def test_start_session(self): session_id_a = cache.start_session() self.assertIsNotNone(session_id_a) @@ -18,11 +22,40 @@ class TestStorageCache(unittest.TestCase): self.assertNotEqual(session_id_a, session_id_b) cache.clear_all() - with self.assertRaises(RuntimeError): + with self.assertRaises(ProcessError): cache.set(KEY, "something") - with self.assertRaises(RuntimeError): + with self.assertRaises(ProcessError): cache.get(KEY) + def test_end_session(self): + session_id = cache.start_session() + self.assertTrue(cache.is_session_started()) + cache.set(KEY, "A") + cache.end_current_session() + self.assertFalse(cache.is_session_started()) + self.assertRaises(ProcessError, cache.get, KEY) + + # ending an ended session should be a no-op + cache.end_current_session() + self.assertFalse(cache.is_session_started()) + + session_id_a = cache.start_session(session_id) + # original session no longer exists + self.assertNotEqual(session_id_a, session_id) + # original session data no longer exists + self.assertIsNone(cache.get(KEY)) + + # create a new session + session_id_b = cache.start_session() + # switch back to original session + session_id = cache.start_session(session_id_a) + self.assertEqual(session_id, session_id_a) + # end original session + cache.end_current_session() + # switch back to B + session_id = cache.start_session(session_id_b) + self.assertEqual(session_id, session_id_b) + def test_session_queue(self): session_id = cache.start_session() self.assertEqual(cache.start_session(session_id), session_id) @@ -47,7 +80,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), "hello") cache.clear_all() - with self.assertRaises(RuntimeError): + with self.assertRaises(ProcessError): cache.get(KEY) def test_decorator_mismatch(self): @@ -128,6 +161,15 @@ class TestStorageCache(unittest.TestCase): call_Initialize(session_id=session_id) self.assertEqual(cache.get(KEY), "hello") + def test_EndSession(self): + self.assertRaises(ProcessError, cache.get, KEY) + session_id = cache.start_session() + self.assertTrue(cache.is_session_started()) + self.assertIsNone(cache.get(KEY)) + await_result(handle_EndSession(DUMMY_CONTEXT, EndSession())) + self.assertFalse(cache.is_session_started()) + self.assertRaises(ProcessError, cache.get, KEY) + if __name__ == "__main__": unittest.main()