all: use a specific error code for "invalid session"

pull/1236/head
matejcik 4 years ago committed by matejcik
parent e96a9e8d39
commit e0583dd5cb

@ -34,6 +34,7 @@ message Failure {
Failure_NotInitialized = 11;
Failure_PinMismatch = 12;
Failure_WipeCodeMismatch = 13;
Failure_InvalidSession = 14;
Failure_FirmwareError = 99;
}
}

@ -75,7 +75,7 @@ def set(key: int, value: Any) -> None:
_sessionless_cache[key] = value
return
if _active_session_id is None:
raise wire.ProcessError("Invalid session")
raise wire.InvalidSession
_caches[_active_session_id][key] = value
@ -83,7 +83,7 @@ def get(key: int) -> Any:
if key & _SESSIONLESS_FLAG:
return _sessionless_cache.get(key)
if _active_session_id is None:
raise wire.ProcessError("Invalid session")
raise wire.InvalidSession
return _caches[_active_session_id].get(key)
@ -93,7 +93,7 @@ def delete(key: int) -> None:
del _sessionless_cache[key]
return
if _active_session_id is None:
raise wire.ProcessError("Invalid session")
raise wire.InvalidSession
if key in _caches[_active_session_id]:
del _caches[_active_session_id][key]

@ -6,7 +6,7 @@ if __debug__:
try:
from typing import Dict, List # noqa: F401
from typing_extensions import Literal # noqa: F401
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99]
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99]
except ImportError:
pass
@ -25,6 +25,6 @@ class Failure(p.MessageType):
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0),
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0),
2: ('message', p.UnicodeType, 0),
}

@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10]
NotInitialized = 11 # type: Literal[11]
PinMismatch = 12 # type: Literal[12]
WipeCodeMismatch = 13 # type: Literal[13]
InvalidSession = 14 # type: Literal[14]
FirmwareError = 99 # type: Literal[99]

@ -76,6 +76,11 @@ class WipeCodeMismatch(Error):
super().__init__(FailureType.WipeCodeMismatch, message)
class InvalidSession(Error):
def __init__(self, message: str = "Invalid session") -> None:
super().__init__(FailureType.InvalidSession, message)
class FirmwareError(Error):
def __init__(self, message: str) -> None:
super().__init__(FailureType.FirmwareError, message)

@ -4,7 +4,7 @@ from mock_storage import mock_storage
from storage import cache
from trezor.messages.Initialize import Initialize
from trezor.messages.EndSession import EndSession
from trezor.wire import DUMMY_CONTEXT, ProcessError
from trezor.wire import DUMMY_CONTEXT, InvalidSession
from apps.base import handle_Initialize, handle_EndSession
@ -22,9 +22,9 @@ class TestStorageCache(unittest.TestCase):
self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all()
with self.assertRaises(ProcessError):
with self.assertRaises(InvalidSession):
cache.set(KEY, "something")
with self.assertRaises(ProcessError):
with self.assertRaises(InvalidSession):
cache.get(KEY)
def test_end_session(self):
@ -33,7 +33,7 @@ class TestStorageCache(unittest.TestCase):
cache.set(KEY, "A")
cache.end_current_session()
self.assertFalse(cache.is_session_started())
self.assertRaises(ProcessError, cache.get, KEY)
self.assertRaises(InvalidSession, cache.get, KEY)
# ending an ended session should be a no-op
cache.end_current_session()
@ -80,7 +80,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), "hello")
cache.clear_all()
with self.assertRaises(ProcessError):
with self.assertRaises(InvalidSession):
cache.get(KEY)
def test_decorator_mismatch(self):
@ -162,13 +162,13 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(cache.get(KEY), "hello")
def test_EndSession(self):
self.assertRaises(ProcessError, cache.get, KEY)
self.assertRaises(InvalidSession, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(cache.is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(DUMMY_CONTEXT, EndSession()))
self.assertFalse(cache.is_session_started())
self.assertRaises(ProcessError, cache.get, KEY)
self.assertRaises(InvalidSession, cache.get, KEY)
if __name__ == "__main__":

@ -583,7 +583,7 @@ static void get_root_node_callback(uint32_t iter, uint32_t total) {
const uint8_t *config_getSeed(void) {
if (activeSessionCache == NULL) {
fsm_sendFailure(FailureType_Failure_ProcessError, "Invalid session");
fsm_sendFailure(FailureType_Failure_InvalidSession, "Invalid session");
return NULL;
}

@ -170,6 +170,9 @@ void fsm_sendFailure(FailureType code, const char *text)
case FailureType_Failure_WipeCodeMismatch:
text = _("Wipe code mismatch");
break;
case FailureType_Failure_InvalidSession:
text = _("Invalid session");
break;
case FailureType_Failure_FirmwareError:
text = _("Firmware error");
break;

@ -6,7 +6,7 @@ if __debug__:
try:
from typing import Dict, List # noqa: F401
from typing_extensions import Literal # noqa: F401
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99]
EnumTypeFailureType = Literal[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99]
except ImportError:
pass
@ -25,6 +25,6 @@ class Failure(p.MessageType):
@classmethod
def get_fields(cls) -> Dict:
return {
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 99)), 0),
1: ('code', p.EnumType("FailureType", (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 99)), 0),
2: ('message', p.UnicodeType, 0),
}

@ -16,4 +16,5 @@ NotEnoughFunds = 10 # type: Literal[10]
NotInitialized = 11 # type: Literal[11]
PinMismatch = 12 # type: Literal[12]
WipeCodeMismatch = 13 # type: Literal[13]
InvalidSession = 14 # type: Literal[14]
FirmwareError = 99 # type: Literal[99]

@ -77,8 +77,10 @@ def test_end_session(client):
client.end_session()
assert client.session_id is None
with pytest.raises(TrezorFailure, match="Invalid session"):
with pytest.raises(TrezorFailure) as exc:
get_test_address(client)
assert exc.value.code == messages.FailureType.InvalidSession
assert exc.value.message.endswith("Invalid session")
client.init_device()
assert client.session_id is not None

Loading…
Cancel
Save