mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-18 10:32:02 +00:00
core: implement EndSession
This commit is contained in:
parent
c7934116ec
commit
4909821f35
@ -16,6 +16,7 @@ if False:
|
|||||||
import protobuf
|
import protobuf
|
||||||
from typing import Iterable, NoReturn, Optional, Protocol
|
from typing import Iterable, NoReturn, Optional, Protocol
|
||||||
from trezor.messages.Initialize import Initialize
|
from trezor.messages.Initialize import Initialize
|
||||||
|
from trezor.messages.EndSession import EndSession
|
||||||
from trezor.messages.GetFeatures import GetFeatures
|
from trezor.messages.GetFeatures import GetFeatures
|
||||||
from trezor.messages.Cancel import Cancel
|
from trezor.messages.Cancel import Cancel
|
||||||
from trezor.messages.LockDevice import LockDevice
|
from trezor.messages.LockDevice import LockDevice
|
||||||
@ -116,6 +117,11 @@ async def handle_LockDevice(ctx: wire.Context, msg: LockDevice) -> Success:
|
|||||||
return Success()
|
return Success()
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_EndSession(ctx: wire.Context, msg: EndSession) -> Success:
|
||||||
|
cache.end_current_session()
|
||||||
|
return Success()
|
||||||
|
|
||||||
|
|
||||||
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
|
async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
|
||||||
if msg.button_protection:
|
if msg.button_protection:
|
||||||
from apps.common.confirm import require_confirm
|
from apps.common.confirm import require_confirm
|
||||||
@ -172,6 +178,7 @@ async def handle_CancelAuthorization(
|
|||||||
|
|
||||||
ALLOW_WHILE_LOCKED = (
|
ALLOW_WHILE_LOCKED = (
|
||||||
MessageType.Initialize,
|
MessageType.Initialize,
|
||||||
|
MessageType.EndSession,
|
||||||
MessageType.GetFeatures,
|
MessageType.GetFeatures,
|
||||||
MessageType.Cancel,
|
MessageType.Cancel,
|
||||||
MessageType.LockDevice,
|
MessageType.LockDevice,
|
||||||
@ -249,6 +256,7 @@ def boot() -> None:
|
|||||||
wire.register(MessageType.GetFeatures, handle_GetFeatures)
|
wire.register(MessageType.GetFeatures, handle_GetFeatures)
|
||||||
wire.register(MessageType.Cancel, handle_Cancel)
|
wire.register(MessageType.Cancel, handle_Cancel)
|
||||||
wire.register(MessageType.LockDevice, handle_LockDevice)
|
wire.register(MessageType.LockDevice, handle_LockDevice)
|
||||||
|
wire.register(MessageType.EndSession, handle_EndSession)
|
||||||
wire.register(MessageType.Ping, handle_Ping)
|
wire.register(MessageType.Ping, handle_Ping)
|
||||||
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
|
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
|
||||||
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)
|
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from trezor import wire
|
||||||
from trezor.crypto import random
|
from trezor.crypto import random
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
@ -52,6 +53,19 @@ def start_session(received_session_id: bytes = None) -> bytes:
|
|||||||
return _active_session_id
|
return _active_session_id
|
||||||
|
|
||||||
|
|
||||||
|
def end_current_session() -> None:
|
||||||
|
global _active_session_id
|
||||||
|
|
||||||
|
if _active_session_id is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_session_id = _active_session_id
|
||||||
|
_active_session_id = None
|
||||||
|
|
||||||
|
_session_ids.remove(current_session_id)
|
||||||
|
del _caches[current_session_id]
|
||||||
|
|
||||||
|
|
||||||
def is_session_started() -> bool:
|
def is_session_started() -> bool:
|
||||||
return _active_session_id is not None
|
return _active_session_id is not None
|
||||||
|
|
||||||
@ -61,7 +75,7 @@ def set(key: int, value: Any) -> None:
|
|||||||
_sessionless_cache[key] = value
|
_sessionless_cache[key] = value
|
||||||
return
|
return
|
||||||
if _active_session_id is None:
|
if _active_session_id is None:
|
||||||
raise RuntimeError # no session active
|
raise wire.ProcessError("Invalid session")
|
||||||
_caches[_active_session_id][key] = value
|
_caches[_active_session_id][key] = value
|
||||||
|
|
||||||
|
|
||||||
@ -69,7 +83,7 @@ def get(key: int) -> Any:
|
|||||||
if key & _SESSIONLESS_FLAG:
|
if key & _SESSIONLESS_FLAG:
|
||||||
return _sessionless_cache.get(key)
|
return _sessionless_cache.get(key)
|
||||||
if _active_session_id is None:
|
if _active_session_id is None:
|
||||||
raise RuntimeError # no session active
|
raise wire.ProcessError("Invalid session")
|
||||||
return _caches[_active_session_id].get(key)
|
return _caches[_active_session_id].get(key)
|
||||||
|
|
||||||
|
|
||||||
@ -79,7 +93,7 @@ def delete(key: int) -> None:
|
|||||||
del _sessionless_cache[key]
|
del _sessionless_cache[key]
|
||||||
return
|
return
|
||||||
if _active_session_id is None:
|
if _active_session_id is None:
|
||||||
raise RuntimeError # no session active
|
raise wire.ProcessError("Invalid session")
|
||||||
if key in _caches[_active_session_id]:
|
if key in _caches[_active_session_id]:
|
||||||
del _caches[_active_session_id][key]
|
del _caches[_active_session_id][key]
|
||||||
|
|
||||||
|
@ -3,14 +3,18 @@ from mock_storage import mock_storage
|
|||||||
|
|
||||||
from storage import cache
|
from storage import cache
|
||||||
from trezor.messages.Initialize import Initialize
|
from trezor.messages.Initialize import Initialize
|
||||||
from trezor.wire import DUMMY_CONTEXT
|
from trezor.messages.EndSession import EndSession
|
||||||
|
from trezor.wire import DUMMY_CONTEXT, ProcessError
|
||||||
|
|
||||||
from apps.base import handle_Initialize
|
from apps.base import handle_Initialize, handle_EndSession
|
||||||
|
|
||||||
KEY = 99
|
KEY = 99
|
||||||
|
|
||||||
|
|
||||||
class TestStorageCache(unittest.TestCase):
|
class TestStorageCache(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
cache.clear_all()
|
||||||
|
|
||||||
def test_start_session(self):
|
def test_start_session(self):
|
||||||
session_id_a = cache.start_session()
|
session_id_a = cache.start_session()
|
||||||
self.assertIsNotNone(session_id_a)
|
self.assertIsNotNone(session_id_a)
|
||||||
@ -18,11 +22,40 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertNotEqual(session_id_a, session_id_b)
|
self.assertNotEqual(session_id_a, session_id_b)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(ProcessError):
|
||||||
cache.set(KEY, "something")
|
cache.set(KEY, "something")
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(ProcessError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
|
def test_end_session(self):
|
||||||
|
session_id = cache.start_session()
|
||||||
|
self.assertTrue(cache.is_session_started())
|
||||||
|
cache.set(KEY, "A")
|
||||||
|
cache.end_current_session()
|
||||||
|
self.assertFalse(cache.is_session_started())
|
||||||
|
self.assertRaises(ProcessError, cache.get, KEY)
|
||||||
|
|
||||||
|
# ending an ended session should be a no-op
|
||||||
|
cache.end_current_session()
|
||||||
|
self.assertFalse(cache.is_session_started())
|
||||||
|
|
||||||
|
session_id_a = cache.start_session(session_id)
|
||||||
|
# original session no longer exists
|
||||||
|
self.assertNotEqual(session_id_a, session_id)
|
||||||
|
# original session data no longer exists
|
||||||
|
self.assertIsNone(cache.get(KEY))
|
||||||
|
|
||||||
|
# create a new session
|
||||||
|
session_id_b = cache.start_session()
|
||||||
|
# switch back to original session
|
||||||
|
session_id = cache.start_session(session_id_a)
|
||||||
|
self.assertEqual(session_id, session_id_a)
|
||||||
|
# end original session
|
||||||
|
cache.end_current_session()
|
||||||
|
# switch back to B
|
||||||
|
session_id = cache.start_session(session_id_b)
|
||||||
|
self.assertEqual(session_id, session_id_b)
|
||||||
|
|
||||||
def test_session_queue(self):
|
def test_session_queue(self):
|
||||||
session_id = cache.start_session()
|
session_id = cache.start_session()
|
||||||
self.assertEqual(cache.start_session(session_id), session_id)
|
self.assertEqual(cache.start_session(session_id), session_id)
|
||||||
@ -47,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), "hello")
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(ProcessError):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_decorator_mismatch(self):
|
def test_decorator_mismatch(self):
|
||||||
@ -128,6 +161,15 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
call_Initialize(session_id=session_id)
|
call_Initialize(session_id=session_id)
|
||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), "hello")
|
||||||
|
|
||||||
|
def test_EndSession(self):
|
||||||
|
self.assertRaises(ProcessError, cache.get, KEY)
|
||||||
|
session_id = cache.start_session()
|
||||||
|
self.assertTrue(cache.is_session_started())
|
||||||
|
self.assertIsNone(cache.get(KEY))
|
||||||
|
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
||||||
|
self.assertFalse(cache.is_session_started())
|
||||||
|
self.assertRaises(ProcessError, cache.get, KEY)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user