mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-16 17:42:02 +00:00
all: use a specific error code for "invalid session"
This commit is contained in:
parent
e96a9e8d39
commit
e0583dd5cb
@ -34,6 +34,7 @@ message Failure {
|
|||||||
Failure_NotInitialized = 11;
|
Failure_NotInitialized = 11;
|
||||||
Failure_PinMismatch = 12;
|
Failure_PinMismatch = 12;
|
||||||
Failure_WipeCodeMismatch = 13;
|
Failure_WipeCodeMismatch = 13;
|
||||||
|
Failure_InvalidSession = 14;
|
||||||
Failure_FirmwareError = 99;
|
Failure_FirmwareError = 99;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -75,7 +75,7 @@ def set(key: int, value: Any) -> None:
|
|||||||
_sessionless_cache[key] = value
|
_sessionless_cache[key] = value
|
||||||
return
|
return
|
||||||
if _active_session_id is None:
|
if _active_session_id is None:
|
||||||
raise wire.ProcessError("Invalid session")
|
raise wire.InvalidSession
|
||||||
_caches[_active_session_id][key] = value
|
_caches[_active_session_id][key] = value
|
||||||
|
|
||||||
|
|
||||||
@ -83,7 +83,7 @@ def get(key: int) -> Any:
|
|||||||
if key & _SESSIONLESS_FLAG:
|
if key & _SESSIONLESS_FLAG:
|
||||||
return _sessionless_cache.get(key)
|
return _sessionless_cache.get(key)
|
||||||
if _active_session_id is None:
|
if _active_session_id is None:
|
||||||
raise wire.ProcessError("Invalid session")
|
raise wire.InvalidSession
|
||||||
return _caches[_active_session_id].get(key)
|
return _caches[_active_session_id].get(key)
|
||||||
|
|
||||||
|
|
||||||
@ -93,7 +93,7 @@ def delete(key: int) -> None:
|
|||||||
del _sessionless_cache[key]
|
del _sessionless_cache[key]
|
||||||
return
|
return
|
||||||
if _active_session_id is None:
|
if _active_session_id is None:
|
||||||
raise wire.ProcessError("Invalid session")
|
raise wire.InvalidSession
|
||||||
if key in _caches[_active_session_id]:
|
if key in _caches[_active_session_id]:
|
||||||
del _caches[_active_session_id][key]
|
del _caches[_active_session_id][key]
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ if __debug__:
|
|||||||
try:
|
try:
|
||||||
from typing import Dict, List # noqa: F401
|
from typing import Dict, List # noqa: F401
|
||||||
from typing_extensions import Literal # noqa: F401
|
from typing_extensions import Literal # noqa: F401
|
||||||
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99]
|
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -25,6 +25,6 @@ class Failure(p.MessageType):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls) -> Dict:
|
def get_fields(cls) -> Dict:
|
||||||
return {
|
return {
|
||||||
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0),
|
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0),
|
||||||
2: ('message', p.UnicodeType, 0),
|
2: ('message', p.UnicodeType, 0),
|
||||||
}
|
}
|
||||||
|
@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10]
|
|||||||
NotInitialized = 11 # type: Literal[11]
|
NotInitialized = 11 # type: Literal[11]
|
||||||
PinMismatch = 12 # type: Literal[12]
|
PinMismatch = 12 # type: Literal[12]
|
||||||
WipeCodeMismatch = 13 # type: Literal[13]
|
WipeCodeMismatch = 13 # type: Literal[13]
|
||||||
|
InvalidSession = 14 # type: Literal[14]
|
||||||
FirmwareError = 99 # type: Literal[99]
|
FirmwareError = 99 # type: Literal[99]
|
||||||
|
@ -76,6 +76,11 @@ class WipeCodeMismatch(Error):
|
|||||||
super().__init__(FailureType.WipeCodeMismatch, message)
|
super().__init__(FailureType.WipeCodeMismatch, message)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidSession(Error):
|
||||||
|
def __init__(self, message: str = "Invalid session") -> None:
|
||||||
|
super().__init__(FailureType.InvalidSession, message)
|
||||||
|
|
||||||
|
|
||||||
class FirmwareError(Error):
|
class FirmwareError(Error):
|
||||||
def __init__(self, message: str) -> None:
|
def __init__(self, message: str) -> None:
|
||||||
super().__init__(FailureType.FirmwareError, message)
|
super().__init__(FailureType.FirmwareError, message)
|
||||||
|
@ -4,7 +4,7 @@ from mock_storage import mock_storage
|
|||||||
from storage import cache
|
from storage import cache
|
||||||
from trezor.messages.Initialize import Initialize
|
from trezor.messages.Initialize import Initialize
|
||||||
from trezor.messages.EndSession import EndSession
|
from trezor.messages.EndSession import EndSession
|
||||||
from trezor.wire import DUMMY_CONTEXT, ProcessError
|
from trezor.wire import DUMMY_CONTEXT, InvalidSession
|
||||||
|
|
||||||
from apps.base import handle_Initialize, handle_EndSession
|
from apps.base import handle_Initialize, handle_EndSession
|
||||||
|
|
||||||
@ -22,9 +22,9 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertNotEqual(session_id_a, session_id_b)
|
self.assertNotEqual(session_id_a, session_id_b)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(ProcessError):
|
with self.assertRaises(InvalidSession):
|
||||||
cache.set(KEY, "something")
|
cache.set(KEY, "something")
|
||||||
with self.assertRaises(ProcessError):
|
with self.assertRaises(InvalidSession):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_end_session(self):
|
def test_end_session(self):
|
||||||
@ -33,7 +33,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
cache.set(KEY, "A")
|
cache.set(KEY, "A")
|
||||||
cache.end_current_session()
|
cache.end_current_session()
|
||||||
self.assertFalse(cache.is_session_started())
|
self.assertFalse(cache.is_session_started())
|
||||||
self.assertRaises(ProcessError, cache.get, KEY)
|
self.assertRaises(InvalidSession, cache.get, KEY)
|
||||||
|
|
||||||
# ending an ended session should be a no-op
|
# ending an ended session should be a no-op
|
||||||
cache.end_current_session()
|
cache.end_current_session()
|
||||||
@ -80,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), "hello")
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(ProcessError):
|
with self.assertRaises(InvalidSession):
|
||||||
cache.get(KEY)
|
cache.get(KEY)
|
||||||
|
|
||||||
def test_decorator_mismatch(self):
|
def test_decorator_mismatch(self):
|
||||||
@ -162,13 +162,13 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(cache.get(KEY), "hello")
|
self.assertEqual(cache.get(KEY), "hello")
|
||||||
|
|
||||||
def test_EndSession(self):
|
def test_EndSession(self):
|
||||||
self.assertRaises(ProcessError, cache.get, KEY)
|
self.assertRaises(InvalidSession, cache.get, KEY)
|
||||||
session_id = cache.start_session()
|
session_id = cache.start_session()
|
||||||
self.assertTrue(cache.is_session_started())
|
self.assertTrue(cache.is_session_started())
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(cache.get(KEY))
|
||||||
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
|
||||||
self.assertFalse(cache.is_session_started())
|
self.assertFalse(cache.is_session_started())
|
||||||
self.assertRaises(ProcessError, cache.get, KEY)
|
self.assertRaises(InvalidSession, cache.get, KEY)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -583,7 +583,7 @@ static void get_root_node_callback(uint32_t iter, uint32_t total) {
|
|||||||
|
|
||||||
const uint8_t *config_getSeed(void) {
|
const uint8_t *config_getSeed(void) {
|
||||||
if (activeSessionCache == NULL) {
|
if (activeSessionCache == NULL) {
|
||||||
fsm_sendFailure(FailureType_Failure_ProcessError, "Invalid session");
|
fsm_sendFailure(FailureType_Failure_InvalidSession, "Invalid session");
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,6 +170,9 @@ void fsm_sendFailure(FailureType code, const char *text)
|
|||||||
case FailureType_Failure_WipeCodeMismatch:
|
case FailureType_Failure_WipeCodeMismatch:
|
||||||
text = _("Wipe code mismatch");
|
text = _("Wipe code mismatch");
|
||||||
break;
|
break;
|
||||||
|
case FailureType_Failure_InvalidSession:
|
||||||
|
text = _("Invalid session");
|
||||||
|
break;
|
||||||
case FailureType_Failure_FirmwareError:
|
case FailureType_Failure_FirmwareError:
|
||||||
text = _("Firmware error");
|
text = _("Firmware error");
|
||||||
break;
|
break;
|
||||||
|
@ -6,7 +6,7 @@ if __debug__:
|
|||||||
try:
|
try:
|
||||||
from typing import Dict, List # noqa: F401
|
from typing import Dict, List # noqa: F401
|
||||||
from typing_extensions import Literal # noqa: F401
|
from typing_extensions import Literal # noqa: F401
|
||||||
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99]
|
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99]
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -25,6 +25,6 @@ class Failure(p.MessageType):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls) -> Dict:
|
def get_fields(cls) -> Dict:
|
||||||
return {
|
return {
|
||||||
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0),
|
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0),
|
||||||
2: ('message', p.UnicodeType, 0),
|
2: ('message', p.UnicodeType, 0),
|
||||||
}
|
}
|
||||||
|
@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10]
|
|||||||
NotInitialized = 11 # type: Literal[11]
|
NotInitialized = 11 # type: Literal[11]
|
||||||
PinMismatch = 12 # type: Literal[12]
|
PinMismatch = 12 # type: Literal[12]
|
||||||
WipeCodeMismatch = 13 # type: Literal[13]
|
WipeCodeMismatch = 13 # type: Literal[13]
|
||||||
|
InvalidSession = 14 # type: Literal[14]
|
||||||
FirmwareError = 99 # type: Literal[99]
|
FirmwareError = 99 # type: Literal[99]
|
||||||
|
@ -77,8 +77,10 @@ def test_end_session(client):
|
|||||||
|
|
||||||
client.end_session()
|
client.end_session()
|
||||||
assert client.session_id is None
|
assert client.session_id is None
|
||||||
with pytest.raises(TrezorFailure, match="Invalid session"):
|
with pytest.raises(TrezorFailure) as exc:
|
||||||
get_test_address(client)
|
get_test_address(client)
|
||||||
|
assert exc.value.code == messages.FailureType.InvalidSession
|
||||||
|
assert exc.value.message.endswith("Invalid session")
|
||||||
|
|
||||||
client.init_device()
|
client.init_device()
|
||||||
assert client.session_id is not None
|
assert client.session_id is not None
|
||||||
|
Loading…
Reference in New Issue
Block a user