mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-29 19:08:12 +00:00
core: implement EndSession
This commit is contained in:
parent
c7934116ec
commit
4909821f35
@ -16,6 +16,7 @@ if False:
|
||||
import protobuf
|
||||
from typing import Iterable, NoReturn, Optional, Protocol
|
||||
from trezor.messages.Initialize import Initialize
|
||||
from trezor.messages.EndSession import EndSession
|
||||
from trezor.messages.GetFeatures import GetFeatures
|
||||
from trezor.messages.Cancel import Cancel
|
||||
from trezor.messages.LockDevice import LockDevice
|
||||
@ -116,6 +117,11 @@ async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success:
|
||||
return Success()
|
||||
|
||||
|
||||
async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success:
|
||||
cache.end_current_session()
|
||||
return Success()
|
||||
|
||||
|
||||
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
|
||||
if msg.button_protection:
|
||||
from apps.common.confirm import require_confirm
|
||||
@ -172,6 +178,7 @@ async def handle_CancelAuthorization(
|
||||
|
||||
ALLOW_WHILE_LOCKED = (
|
||||
MessageType.Initialize,
|
||||
MessageType.EndSession,
|
||||
MessageType.GetFeatures,
|
||||
MessageType.Cancel,
|
||||
MessageType.LockDevice,
|
||||
@ -249,6 +256,7 @@ def boot() -> None:
|
||||
wire.register(MessageType.GetFeatures, handle_GetFeatures)
|
||||
wire.register(MessageType.Cancel, handle_Cancel)
|
||||
wire.register(MessageType.LockDevice, handle_LockDevice)
|
||||
wire.register(MessageType.EndSession, handle_EndSession)
|
||||
wire.register(MessageType.Ping, handle_Ping)
|
||||
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
|
||||
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from trezor import wire
|
||||
from trezor.crypto import random
|
||||
|
||||
if False:
|
||||
@ -52,6 +53,19 @@ def start_session(received_session_id: bytes = None) -> bytes:
|
||||
return _active_session_id
|
||||
|
||||
|
||||
def end_current_session() -> None:
|
||||
global _active_session_id
|
||||
|
||||
if _active_session_id is None:
|
||||
return
|
||||
|
||||
current_session_id = _active_session_id
|
||||
_active_session_id = None
|
||||
|
||||
_session_ids.remove(current_session_id)
|
||||
del _caches[current_session_id]
|
||||
|
||||
|
||||
def is_session_started() -> bool:
|
||||
return _active_session_id is not None
|
||||
|
||||
@ -61,7 +75,7 @@ def set(key: int, value: Any) -> None:
|
||||
_sessionless_cache[key] = value
|
||||
return
|
||||
if _active_session_id is None:
|
||||
raise RuntimeError # no session active
|
||||
raise wire.ProcessError("Invalid session")
|
||||
_caches[_active_session_id][key] = value
|
||||
|
||||
|
||||
@ -69,7 +83,7 @@ def get(key: int) -> Any:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _sessionless_cache.get(key)
|
||||
if _active_session_id is None:
|
||||
raise RuntimeError # no session active
|
||||
raise wire.ProcessError("Invalid session")
|
||||
return _caches[_active_session_id].get(key)
|
||||
|
||||
|
||||
@ -79,7 +93,7 @@ def delete(key: int) -> None:
|
||||
del _sessionless_cache[key]
|
||||
return
|
||||
if _active_session_id is None:
|
||||
raise RuntimeError # no session active
|
||||
raise wire.ProcessError("Invalid session")
|
||||
if key in _caches[_active_session_id]:
|
||||
del _caches[_active_session_id][key]
|
||||
|
||||
|
@ -3,14 +3,18 @@ from mock_storage import mock_storage
|
||||
|
||||
from storage import cache
|
||||
from trezor.messages.Initialize import Initialize
|
||||
from trezor.wire import DUMMY_CONTEXT
|
||||
from trezor.messages.EndSession import EndSession
|
||||
from trezor.wire import DUMMY_CONTEXT, ProcessError
|
||||
|
||||
from apps.base import handle_Initialize
|
||||
from apps.base import handle_Initialize, handle_EndSession
|
||||
|
||||
KEY = 99
|
||||
|
||||
|
||||
class TestStorageCache(unittest.TestCase):
|
||||
def setUp(self):
|
||||
cache.clear_all()
|
||||
|
||||
def test_start_session(self):
|
||||
session_id_a = cache.start_session()
|
||||
self.assertIsNotNone(session_id_a)
|
||||
@ -18,11 +22,40 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertNotEqual(session_id_a, session_id_b)
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaises(ProcessError):
|
||||
cache.set(KEY, "something")
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaises(ProcessError):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_end_session(self):
|
||||
session_id = cache.start_session()
|
||||
self.assertTrue(cache.is_session_started())
|
||||
cache.set(KEY, "A")
|
||||
cache.end_current_session()
|
||||
self.assertFalse(cache.is_session_started())
|
||||
self.assertRaises(ProcessError, cache.get, KEY)
|
||||
|
||||
# ending an ended session should be a no-op
|
||||
cache.end_current_session()
|
||||
self.assertFalse(cache.is_session_started())
|
||||
|
||||
session_id_a = cache.start_session(session_id)
|
||||
# original session no longer exists
|
||||
self.assertNotEqual(session_id_a, session_id)
|
||||
# original session data no longer exists
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
|
||||
# create a new session
|
||||
session_id_b = cache.start_session()
|
||||
# switch back to original session
|
||||
session_id = cache.start_session(session_id_a)
|
||||
self.assertEqual(session_id, session_id_a)
|
||||
# end original session
|
||||
cache.end_current_session()
|
||||
# switch back to B
|
||||
session_id = cache.start_session(session_id_b)
|
||||
self.assertEqual(session_id, session_id_b)
|
||||
|
||||
def test_session_queue(self):
|
||||
session_id = cache.start_session()
|
||||
self.assertEqual(cache.start_session(session_id), session_id)
|
||||
@ -47,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), "hello")
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaises(ProcessError):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_decorator_mismatch(self):
|
||||
@ -128,6 +161,15 @@ class TestStorageCache(unittest.TestCase):
|
||||
call_Initialize(session_id=session_id)
|
||||
self.assertEqual(cache.get(KEY), "hello")
|
||||
|
||||
def test_EndSession(self):
|
||||
self.assertRaises(ProcessError, cache.get, KEY)
|
||||
session_id = cache.start_session()
|
||||
self.assertTrue(cache.is_session_started())
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
||||
self.assertFalse(cache.is_session_started())
|
||||
self.assertRaises(ProcessError, cache.get, KEY)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user