core: implement EndSession

pull/1236/head
matejcik 4 years ago committed by matejcik
parent c7934116ec
commit 4909821f35

@ -16,6 +16,7 @@ if False:
import protobuf import protobuf
from typing import Iterable, NoReturn, Optional, Protocol from typing import Iterable, NoReturn, Optional, Protocol
from trezor.messages.Initialize import Initialize from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession
from trezor.messages.GetFeatures import GetFeatures from trezor.messages.GetFeatures import GetFeatures
from trezor.messages.Cancel import Cancel from trezor.messages.Cancel import Cancel
from trezor.messages.LockDevice import LockDevice from trezor.messages.LockDevice import LockDevice
@ -116,6 +117,11 @@ async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success:
return Success() return Success()
async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success:
cache.end_current_session()
return Success()
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success: async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
if msg.button_protection: if msg.button_protection:
from apps.common.confirm import require_confirm from apps.common.confirm import require_confirm
@ -172,6 +178,7 @@ async def handle_CancelAuthorization(
ALLOW_WHILE_LOCKED = ( ALLOW_WHILE_LOCKED = (
MessageType.Initialize, MessageType.Initialize,
MessageType.EndSession,
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.Cancel, MessageType.Cancel,
MessageType.LockDevice, MessageType.LockDevice,
@ -249,6 +256,7 @@ def boot() -> None:
wire.register(MessageType.GetFeatures, handle_GetFeatures) wire.register(MessageType.GetFeatures, handle_GetFeatures)
wire.register(MessageType.Cancel, handle_Cancel) wire.register(MessageType.Cancel, handle_Cancel)
wire.register(MessageType.LockDevice, handle_LockDevice) wire.register(MessageType.LockDevice, handle_LockDevice)
wire.register(MessageType.EndSession, handle_EndSession)
wire.register(MessageType.Ping, handle_Ping) wire.register(MessageType.Ping, handle_Ping)
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized) wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization) wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)

@ -1,3 +1,4 @@
from trezor import wire
from trezor.crypto import random from trezor.crypto import random
if False: if False:
@ -52,6 +53,19 @@ def start_session(received_session_id: bytes = None) -> bytes:
return _active_session_id return _active_session_id
def end_current_session() -> None:
global _active_session_id
if _active_session_id is None:
return
current_session_id = _active_session_id
_active_session_id = None
_session_ids.remove(current_session_id)
del _caches[current_session_id]
def is_session_started() -> bool: def is_session_started() -> bool:
return _active_session_id is not None return _active_session_id is not None
@ -61,7 +75,7 @@ def set(key: int, value: Any) -> None:
_sessionless_cache[key] = value _sessionless_cache[key] = value
return return
if _active_session_id is None: if _active_session_id is None:
raise RuntimeError # no session active raise wire.ProcessError("Invalid session")
_caches[_active_session_id][key] = value _caches[_active_session_id][key] = value
@ -69,7 +83,7 @@ def get(key: int) -> Any:
if key & _SESSIONLESS_FLAG: if key & _SESSIONLESS_FLAG:
return _sessionless_cache.get(key) return _sessionless_cache.get(key)
if _active_session_id is None: if _active_session_id is None:
raise RuntimeError # no session active raise wire.ProcessError("Invalid session")
return _caches[_active_session_id].get(key) return _caches[_active_session_id].get(key)
@ -79,7 +93,7 @@ def delete(key: int) -> None:
del _sessionless_cache[key] del _sessionless_cache[key]
return return
if _active_session_id is None: if _active_session_id is None:
raise RuntimeError # no session active raise wire.ProcessError("Invalid session")
if key in _caches[_active_session_id]: if key in _caches[_active_session_id]:
del _caches[_active_session_id][key] del _caches[_active_session_id][key]

@ -3,14 +3,18 @@ from mock_storage import mock_storage
from storage import cache from storage import cache
from trezor.messages.Initialize import Initialize from trezor.messages.Initialize import Initialize
from trezor.wire import DUMMY_CONTEXT from trezor.messages.EndSession import EndSession
from trezor.wire import DUMMY_CONTEXT, ProcessError
from apps.base import handle_Initialize from apps.base import handle_Initialize, handle_EndSession
KEY = 99 KEY = 99
class TestStorageCache(unittest.TestCase): class TestStorageCache(unittest.TestCase):
def setUp(self):
cache.clear_all()
def test_start_session(self): def test_start_session(self):
session_id_a = cache.start_session() session_id_a = cache.start_session()
self.assertIsNotNone(session_id_a) self.assertIsNotNone(session_id_a)
@ -18,11 +22,40 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b) self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all() cache.clear_all()
with self.assertRaises(RuntimeError): with self.assertRaises(ProcessError):
cache.set(KEY, "something") cache.set(KEY, "something")
with self.assertRaises(RuntimeError): with self.assertRaises(ProcessError):
cache.get(KEY) cache.get(KEY)
def test_end_session(self):
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
cache.set(KEY, "A")
cache.end_current_session()
self.assertFalse(cache.is_session_started())
self.assertRaises(ProcessError, cache.get, KEY)
# ending an ended session should be a no-op
cache.end_current_session()
self.assertFalse(cache.is_session_started())
session_id_a = cache.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(cache.get(KEY))
# create a new session
session_id_b = cache.start_session()
# switch back to original session
session_id = cache.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache.end_current_session()
# switch back to B
session_id = cache.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_session_queue(self): def test_session_queue(self):
session_id = cache.start_session() session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id) self.assertEqual(cache.start_session(session_id), session_id)
@ -47,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), "hello") self.assertEqual(cache.get(KEY), "hello")
cache.clear_all() cache.clear_all()
with self.assertRaises(RuntimeError): with self.assertRaises(ProcessError):
cache.get(KEY) cache.get(KEY)
def test_decorator_mismatch(self): def test_decorator_mismatch(self):
@ -128,6 +161,15 @@ class TestStorageCache(unittest.TestCase):
call_Initialize(session_id=session_id) call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), "hello") self.assertEqual(cache.get(KEY), "hello")
def test_EndSession(self):
self.assertRaises(ProcessError, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
self.assertFalse(cache.is_session_started())
self.assertRaises(ProcessError, cache.get, KEY)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

Loading…
Cancel
Save