core: implement EndSession

pull/1236/head
matejcik 4 years ago committed by matejcik
parent c7934116ec
commit 4909821f35

@ -16,6 +16,7 @@ if False:
import protobuf
from typing import Iterable, NoReturn, Optional, Protocol
from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession
from trezor.messages.GetFeatures import GetFeatures
from trezor.messages.Cancel import Cancel
from trezor.messages.LockDevice import LockDevice
@ -116,6 +117,11 @@ async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success:
return Success()
async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success:
cache.end_current_session()
return Success()
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
if msg.button_protection:
from apps.common.confirm import require_confirm
@ -172,6 +178,7 @@ async def handle_CancelAuthorization(
ALLOW_WHILE_LOCKED = (
MessageType.Initialize,
MessageType.EndSession,
MessageType.GetFeatures,
MessageType.Cancel,
MessageType.LockDevice,
@ -249,6 +256,7 @@ def boot() -> None:
wire.register(MessageType.GetFeatures, handle_GetFeatures)
wire.register(MessageType.Cancel, handle_Cancel)
wire.register(MessageType.LockDevice, handle_LockDevice)
wire.register(MessageType.EndSession, handle_EndSession)
wire.register(MessageType.Ping, handle_Ping)
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)

@ -1,3 +1,4 @@
from trezor import wire
from trezor.crypto import random
if False:
@ -52,6 +53,19 @@ def start_session(received_session_id: bytes = None) -> bytes:
return _active_session_id
def end_current_session() -> None:
global _active_session_id
if _active_session_id is None:
return
current_session_id = _active_session_id
_active_session_id = None
_session_ids.remove(current_session_id)
del _caches[current_session_id]
def is_session_started() -> bool:
return _active_session_id is not None
@ -61,7 +75,7 @@ def set(key: int, value: Any) -> None:
_sessionless_cache[key] = value
return
if _active_session_id is None:
raise RuntimeError # no session active
raise wire.ProcessError("Invalid session")
_caches[_active_session_id][key] = value
@ -69,7 +83,7 @@ def get(key: int) -> Any:
if key & _SESSIONLESS_FLAG:
return _sessionless_cache.get(key)
if _active_session_id is None:
raise RuntimeError # no session active
raise wire.ProcessError("Invalid session")
return _caches[_active_session_id].get(key)
@ -79,7 +93,7 @@ def delete(key: int) -> None:
del _sessionless_cache[key]
return
if _active_session_id is None:
raise RuntimeError # no session active
raise wire.ProcessError("Invalid session")
if key in _caches[_active_session_id]:
del _caches[_active_session_id][key]

@ -3,14 +3,18 @@ from mock_storage import mock_storage
from storage import cache
from trezor.messages.Initialize import Initialize
from trezor.wire import DUMMY_CONTEXT
from trezor.messages.EndSession import EndSession
from trezor.wire import DUMMY_CONTEXT, ProcessError
from apps.base import handle_Initialize
from apps.base import handle_Initialize, handle_EndSession
KEY = 99
class TestStorageCache(unittest.TestCase):
def setUp(self):
cache.clear_all()
def test_start_session(self):
session_id_a = cache.start_session()
self.assertIsNotNone(session_id_a)
@ -18,11 +22,40 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
with self.assertRaises(RuntimeError):
with self.assertRaises(ProcessError):
cache.set(KEY, "something")
with self.assertRaises(RuntimeError):
with self.assertRaises(ProcessError):
cache.get(KEY)
def test_end_session(self):
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
cache.set(KEY, "A")
cache.end_current_session()
self.assertFalse(cache.is_session_started())
self.assertRaises(ProcessError, cache.get, KEY)
# ending an ended session should be a no-op
cache.end_current_session()
self.assertFalse(cache.is_session_started())
session_id_a = cache.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(cache.get(KEY))
# create a new session
session_id_b = cache.start_session()
# switch back to original session
session_id = cache.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache.end_current_session()
# switch back to B
session_id = cache.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_session_queue(self):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
@ -47,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), "hello")
cache.clear_all()
with self.assertRaises(RuntimeError):
with self.assertRaises(ProcessError):
cache.get(KEY)
def test_decorator_mismatch(self):
@ -128,6 +161,15 @@ class TestStorageCache(unittest.TestCase):
call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), "hello")
def test_EndSession(self):
self.assertRaises(ProcessError, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
self.assertFalse(cache.is_session_started())
self.assertRaises(ProcessError, cache.get, KEY)
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save