mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 07:28:10 +00:00
all: use a specific error code for "invalid session"
This commit is contained in:
parent
e96a9e8d39
commit
e0583dd5cb
@ -34,6 +34,7 @@ message Failure {
|
||||
Failure_NotInitialized = 11;
|
||||
Failure_PinMismatch = 12;
|
||||
Failure_WipeCodeMismatch = 13;
|
||||
Failure_InvalidSession = 14;
|
||||
Failure_FirmwareError = 99;
|
||||
}
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ def set(key: int, value: Any) -> None:
|
||||
_sessionless_cache[key] = value
|
||||
return
|
||||
if _active_session_id is None:
|
||||
raise wire.ProcessError("Invalid session")
|
||||
raise wire.InvalidSession
|
||||
_caches[_active_session_id][key] = value
|
||||
|
||||
|
||||
@ -83,7 +83,7 @@ def get(key: int) -> Any:
|
||||
if key & _SESSIONLESS_FLAG:
|
||||
return _sessionless_cache.get(key)
|
||||
if _active_session_id is None:
|
||||
raise wire.ProcessError("Invalid session")
|
||||
raise wire.InvalidSession
|
||||
return _caches[_active_session_id].get(key)
|
||||
|
||||
|
||||
@ -93,7 +93,7 @@ def delete(key: int) -> None:
|
||||
del _sessionless_cache[key]
|
||||
return
|
||||
if _active_session_id is None:
|
||||
raise wire.ProcessError("Invalid session")
|
||||
raise wire.InvalidSession
|
||||
if key in _caches[_active_session_id]:
|
||||
del _caches[_active_session_id][key]
|
||||
|
||||
|
@ -6,7 +6,7 @@ if __debug__:
|
||||
try:
|
||||
from typing import Dict, List # noqa: F401
|
||||
from typing_extensions import Literal # noqa: F401
|
||||
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99]
|
||||
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -25,6 +25,6 @@ class Failure(p.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls) -> Dict:
|
||||
return {
|
||||
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0),
|
||||
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0),
|
||||
2: ('message', p.UnicodeType, 0),
|
||||
}
|
||||
|
@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10]
|
||||
NotInitialized = 11 # type: Literal[11]
|
||||
PinMismatch = 12 # type: Literal[12]
|
||||
WipeCodeMismatch = 13 # type: Literal[13]
|
||||
InvalidSession = 14 # type: Literal[14]
|
||||
FirmwareError = 99 # type: Literal[99]
|
||||
|
@ -76,6 +76,11 @@ class WipeCodeMismatch(Error):
|
||||
super().__init__(FailureType.WipeCodeMismatch, message)
|
||||
|
||||
|
||||
class InvalidSession(Error):
|
||||
def __init__(self, message: str = "Invalid session") -> None:
|
||||
super().__init__(FailureType.InvalidSession, message)
|
||||
|
||||
|
||||
class FirmwareError(Error):
|
||||
def __init__(self, message: str) -> None:
|
||||
super().__init__(FailureType.FirmwareError, message)
|
||||
|
@ -4,7 +4,7 @@ from mock_storage import mock_storage
|
||||
from storage import cache
|
||||
from trezor.messages.Initialize import Initialize
|
||||
from trezor.messages.EndSession import EndSession
|
||||
from trezor.wire import DUMMY_CONTEXT, ProcessError
|
||||
from trezor.wire import DUMMY_CONTEXT, InvalidSession
|
||||
|
||||
from apps.base import handle_Initialize, handle_EndSession
|
||||
|
||||
@ -22,9 +22,9 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertNotEqual(session_id_a, session_id_b)
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(ProcessError):
|
||||
with self.assertRaises(InvalidSession):
|
||||
cache.set(KEY, "something")
|
||||
with self.assertRaises(ProcessError):
|
||||
with self.assertRaises(InvalidSession):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_end_session(self):
|
||||
@ -33,7 +33,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
cache.set(KEY, "A")
|
||||
cache.end_current_session()
|
||||
self.assertFalse(cache.is_session_started())
|
||||
self.assertRaises(ProcessError, cache.get, KEY)
|
||||
self.assertRaises(InvalidSession, cache.get, KEY)
|
||||
|
||||
# ending an ended session should be a no-op
|
||||
cache.end_current_session()
|
||||
@ -80,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), "hello")
|
||||
|
||||
cache.clear_all()
|
||||
with self.assertRaises(ProcessError):
|
||||
with self.assertRaises(InvalidSession):
|
||||
cache.get(KEY)
|
||||
|
||||
def test_decorator_mismatch(self):
|
||||
@ -162,13 +162,13 @@ class TestStorageCache(unittest.TestCase):
|
||||
self.assertEqual(cache.get(KEY), "hello")
|
||||
|
||||
def test_EndSession(self):
|
||||
self.assertRaises(ProcessError, cache.get, KEY)
|
||||
self.assertRaises(InvalidSession, cache.get, KEY)
|
||||
session_id = cache.start_session()
|
||||
self.assertTrue(cache.is_session_started())
|
||||
self.assertIsNone(cache.get(KEY))
|
||||
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
||||
self.assertFalse(cache.is_session_started())
|
||||
self.assertRaises(ProcessError, cache.get, KEY)
|
||||
self.assertRaises(InvalidSession, cache.get, KEY)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -583,7 +583,7 @@ static void get_root_node_callback(uint32_t iter, uint32_t total) {
|
||||
|
||||
const uint8_t *config_getSeed(void) {
|
||||
if (activeSessionCache == NULL) {
|
||||
fsm_sendFailure(FailureType_Failure_ProcessError, "Invalid session");
|
||||
fsm_sendFailure(FailureType_Failure_InvalidSession, "Invalid session");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
|
@ -170,6 +170,9 @@ void fsm_sendFailure(FailureType code, const char *text)
|
||||
case FailureType_Failure_WipeCodeMismatch:
|
||||
text = _("Wipe code mismatch");
|
||||
break;
|
||||
case FailureType_Failure_InvalidSession:
|
||||
text = _("Invalid session");
|
||||
break;
|
||||
case FailureType_Failure_FirmwareError:
|
||||
text = _("Firmware error");
|
||||
break;
|
||||
|
@ -6,7 +6,7 @@ if __debug__:
|
||||
try:
|
||||
from typing import Dict, List # noqa: F401
|
||||
from typing_extensions import Literal # noqa: F401
|
||||
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99]
|
||||
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@ -25,6 +25,6 @@ class Failure(p.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls) -> Dict:
|
||||
return {
|
||||
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0),
|
||||
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0),
|
||||
2: ('message', p.UnicodeType, 0),
|
||||
}
|
||||
|
@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10]
|
||||
NotInitialized = 11 # type: Literal[11]
|
||||
PinMismatch = 12 # type: Literal[12]
|
||||
WipeCodeMismatch = 13 # type: Literal[13]
|
||||
InvalidSession = 14 # type: Literal[14]
|
||||
FirmwareError = 99 # type: Literal[99]
|
||||
|
@ -77,8 +77,10 @@ def test_end_session(client):
|
||||
|
||||
client.end_session()
|
||||
assert client.session_id is None
|
||||
with pytest.raises(TrezorFailure, match="Invalid session"):
|
||||
with pytest.raises(TrezorFailure) as exc:
|
||||
get_test_address(client)
|
||||
assert exc.value.code == messages.FailureType.InvalidSession
|
||||
assert exc.value.message.endswith("Invalid session")
|
||||
|
||||
client.init_device()
|
||||
assert client.session_id is not None
|
||||
|
Loading…
Reference in New Issue
Block a user