diff --git a/common/protob/messages-common.proto b/common/protob/messages-common.proto index 4b1c077ef..be38977cf 100644 --- a/common/protob/messages-common.proto +++ b/common/protob/messages-common.proto @@ -34,6 +34,7 @@ message Failure { Failure_NotInitialized = 11; Failure_PinMismatch = 12; Failure_WipeCodeMismatch = 13; + Failure_InvalidSession = 14; Failure_FirmwareError = 99; } } diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index de752cb3d..30ec0330b 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -75,7 +75,7 @@ def set(key: int, value: Any) -> None: _sessionless_cache[key] = value return if _active_session_id is None: - raise wire.ProcessError("Invalid session") + raise wire.InvalidSession _caches[_active_session_id][key] = value @@ -83,7 +83,7 @@ def get(key: int) -> Any: if key & _SESSIONLESS_FLAG: return _sessionless_cache.get(key) if _active_session_id is None: - raise wire.ProcessError("Invalid session") + raise wire.InvalidSession return _caches[_active_session_id].get(key) @@ -93,7 +93,7 @@ def delete(key: int) -> None: del _sessionless_cache[key] return if _active_session_id is None: - raise wire.ProcessError("Invalid session") + raise wire.InvalidSession if key in _caches[_active_session_id]: del _caches[_active_session_id][key] diff --git a/core/src/trezor/messages/Failure.py b/core/src/trezor/messages/Failure.py index 7c1d37dde..2a085e374 100644 --- a/core/src/trezor/messages/Failure.py +++ b/core/src/trezor/messages/Failure.py @@ -6,7 +6,7 @@ if __debug__: try: from typing import Dict, List # noqa: F401 from typing_extensions import Literal # noqa: F401 - EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99] + EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99] except ImportError: pass @@ -25,6 +25,6 @@ class Failure(p.MessageType): @classmethod def get_fields(cls) -> Dict: return { - 1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0), + 1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0), 2: ('message', p.UnicodeType, 0), } diff --git a/core/src/trezor/messages/FailureType.py b/core/src/trezor/messages/FailureType.py index f94652b9e..df77d8db9 100644 --- a/core/src/trezor/messages/FailureType.py +++ b/core/src/trezor/messages/FailureType.py @@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10] NotInitialized = 11 # type: Literal[11] PinMismatch = 12 # type: Literal[12] WipeCodeMismatch = 13 # type: Literal[13] +InvalidSession = 14 # type: Literal[14] FirmwareError = 99 # type: Literal[99] diff --git a/core/src/trezor/wire/errors.py b/core/src/trezor/wire/errors.py index 38ee73902..aeed58cb9 100644 --- a/core/src/trezor/wire/errors.py +++ b/core/src/trezor/wire/errors.py @@ -76,6 +76,11 @@ class WipeCodeMismatch(Error): super().__init__(FailureType.WipeCodeMismatch, message) +class InvalidSession(Error): + def __init__(self, message: str = "Invalid session") -> None: + super().__init__(FailureType.InvalidSession, message) + + class FirmwareError(Error): def __init__(self, message: str) -> None: super().__init__(FailureType.FirmwareError, message) diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 3cb0f9fbd..a076f8b6f 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -4,7 +4,7 @@ from mock_storage import mock_storage from storage import cache from trezor.messages.Initialize import Initialize from trezor.messages.EndSession import EndSession -from trezor.wire import DUMMY_CONTEXT, ProcessError +from trezor.wire import DUMMY_CONTEXT, InvalidSession from apps.base import handle_Initialize, handle_EndSession @@ -22,9 +22,9 @@ class TestStorageCache(unittest.TestCase): self.assertNotEqual(session_id_a, session_id_b) cache.clear_all() - with self.assertRaises(ProcessError): + with self.assertRaises(InvalidSession): cache.set(KEY, "something") - with self.assertRaises(ProcessError): + with self.assertRaises(InvalidSession): cache.get(KEY) def test_end_session(self): @@ -33,7 +33,7 @@ class TestStorageCache(unittest.TestCase): cache.set(KEY, "A") cache.end_current_session() self.assertFalse(cache.is_session_started()) - self.assertRaises(ProcessError, cache.get, KEY) + self.assertRaises(InvalidSession, cache.get, KEY) # ending an ended session should be a no-op cache.end_current_session() @@ -80,7 +80,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), "hello") cache.clear_all() - with self.assertRaises(ProcessError): + with self.assertRaises(InvalidSession): cache.get(KEY) def test_decorator_mismatch(self): @@ -162,13 +162,13 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(cache.get(KEY), "hello") def test_EndSession(self): - self.assertRaises(ProcessError, cache.get, KEY) + self.assertRaises(InvalidSession, cache.get, KEY) session_id = cache.start_session() self.assertTrue(cache.is_session_started()) self.assertIsNone(cache.get(KEY)) await_result(handle_EndSession(DUMMY_CONTEXT, EndSession())) self.assertFalse(cache.is_session_started()) - self.assertRaises(ProcessError, cache.get, KEY) + self.assertRaises(InvalidSession, cache.get, KEY) if __name__ == "__main__": diff --git a/legacy/firmware/config.c b/legacy/firmware/config.c index 90443a39d..3d160aae3 100644 --- a/legacy/firmware/config.c +++ b/legacy/firmware/config.c @@ -583,7 +583,7 @@ static void get_root_node_callback(uint32_t iter, uint32_t total) { const uint8_t *config_getSeed(void) { if (activeSessionCache == NULL) { - fsm_sendFailure(FailureType_Failure_ProcessError, "Invalid session"); + fsm_sendFailure(FailureType_Failure_InvalidSession, "Invalid session"); return NULL; } diff --git a/legacy/firmware/fsm.c b/legacy/firmware/fsm.c index a1962d3a7..be95b2f9f 100644 --- a/legacy/firmware/fsm.c +++ b/legacy/firmware/fsm.c @@ -170,6 +170,9 @@ void fsm_sendFailure(FailureType code, const char *text) case FailureType_Failure_WipeCodeMismatch: text = _("Wipe code mismatch"); break; + case FailureType_Failure_InvalidSession: + text = _("Invalid session"); + break; case FailureType_Failure_FirmwareError: text = _("Firmware error"); break; diff --git a/python/src/trezorlib/messages/Failure.py b/python/src/trezorlib/messages/Failure.py index c6c2dd066..bba2a2d52 100644 --- a/python/src/trezorlib/messages/Failure.py +++ b/python/src/trezorlib/messages/Failure.py @@ -6,7 +6,7 @@ if __debug__: try: from typing import Dict, List # noqa: F401 from typing_extensions import Literal # noqa: F401 - EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99] + EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99] except ImportError: pass @@ -25,6 +25,6 @@ class Failure(p.MessageType): @classmethod def get_fields(cls) -> Dict: return { - 1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0), + 1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0), 2: ('message', p.UnicodeType, 0), } diff --git a/python/src/trezorlib/messages/FailureType.py b/python/src/trezorlib/messages/FailureType.py index f94652b9e..df77d8db9 100644 --- a/python/src/trezorlib/messages/FailureType.py +++ b/python/src/trezorlib/messages/FailureType.py @@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10] NotInitialized = 11 # type: Literal[11] PinMismatch = 12 # type: Literal[12] WipeCodeMismatch = 13 # type: Literal[13] +InvalidSession = 14 # type: Literal[14] FirmwareError = 99 # type: Literal[99] diff --git a/tests/device_tests/test_session.py b/tests/device_tests/test_session.py index 89a22f6fa..6bfcd8c77 100644 --- a/tests/device_tests/test_session.py +++ b/tests/device_tests/test_session.py @@ -77,8 +77,10 @@ def test_end_session(client): client.end_session() assert client.session_id is None - with pytest.raises(TrezorFailure, match="Invalid session"): + with pytest.raises(TrezorFailure) as exc: get_test_address(client) + assert exc.value.code == messages.FailureType.InvalidSession + assert exc.value.message.endswith("Invalid session") client.init_device() assert client.session_id is not None