diff --git a/core/src/apps/homescreen/__init__.py b/core/src/apps/homescreen/__init__.py index 91d11448c..2c5b13b6c 100644 --- a/core/src/apps/homescreen/__init__.py +++ b/core/src/apps/homescreen/__init__.py @@ -72,15 +72,16 @@ def get_features() -> Features: f.sd_card_present = sdcard.is_present() f.sd_protection = storage.sd_salt.is_enabled() f.wipe_code_protection = config.has_wipe_code() - f.session_id = cache.get_session_id() f.passphrase_always_on_device = storage.device.get_passphrase_always_on_device() return f async def handle_Initialize(ctx: wire.Context, msg: Initialize) -> Features: - if msg.session_id is None or msg.session_id != cache.get_session_id(): - cache.clear() - return get_features() + features = get_features() + if msg.session_id: + msg.session_id = bytes(msg.session_id) + features.session_id = cache.start_session(msg.session_id) + return features async def handle_GetFeatures(ctx: wire.Context, msg: GetFeatures) -> Features: diff --git a/core/src/storage/__init__.py b/core/src/storage/__init__.py index a0bca1945..df4888bd4 100644 --- a/core/src/storage/__init__.py +++ b/core/src/storage/__init__.py @@ -12,7 +12,7 @@ def is_initialized() -> bool: def wipe() -> None: config.wipe() - cache.clear() + cache.clear_all() def init_unlocked() -> None: diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 9411e8f23..580acc599 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -1,36 +1,82 @@ from trezor.crypto import random if False: - from typing import Optional + from typing import Optional, Dict, List +_MAX_SESSIONS_COUNT = 10 +_SESSIONLESS_FLAG = 128 + +# Traditional cache keys APP_COMMON_SEED = 0 -APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 -APP_CARDANO_ROOT = 2 -APP_MONERO_LIVE_REFRESH = 3 +APP_CARDANO_ROOT = 1 +APP_MONERO_LIVE_REFRESH = 2 + +# Keys that are valid across sessions +APP_COMMON_SEED_WITHOUT_PASSPHRASE = 1 | _SESSIONLESS_FLAG + -_cache_session_id = None # type: Optional[bytes] -_cache = {} +_active_session_id = None # type: Optional[bytes] +_caches = {} # type: Dict[bytes, Dict[int, Any]] +_session_ids = [] # type: List[bytes] +_sessionless_cache = {} # type: Dict[int, Any] if False: from typing import Any -def get_session_id() -> bytes: - global _cache_session_id - if not _cache_session_id: - _cache_session_id = random.bytes(32) - return _cache_session_id +def _move_session_ids_queue(session_id: bytes) -> None: + # Move the LRU session ids queue. + if session_id in _session_ids: + _session_ids.remove(session_id) + + while len(_session_ids) >= _MAX_SESSIONS_COUNT: + remove_session_id = _session_ids.pop() + del _caches[remove_session_id] + + _session_ids.insert(0, session_id) + + +def start_session(received_session_id: bytes = None) -> bytes: + if received_session_id and received_session_id in _session_ids: + session_id = received_session_id + else: + session_id = random.bytes(32) + _caches[session_id] = {} + + global _active_session_id + _active_session_id = session_id + _move_session_ids_queue(session_id) + return _active_session_id + + +def is_session_started() -> bool: + return _active_session_id is not None def set(key: int, value: Any) -> None: - _cache[key] = value + if key & _SESSIONLESS_FLAG: + _sessionless_cache[key] = value + return + if _active_session_id is None: + raise RuntimeError # no session active + _caches[_active_session_id][key] = value def get(key: int) -> Any: - return _cache.get(key) + if key & _SESSIONLESS_FLAG: + return _sessionless_cache.get(key) + if _active_session_id is None: + raise RuntimeError # no session active + return _caches[_active_session_id].get(key) + +def clear_all() -> None: + global _active_session_id + global _caches + global _session_ids + global _sessionless_cache -def clear() -> None: - global _cache_session_id - _cache_session_id = None - _cache.clear() + _active_session_id = None + _caches.clear() + _session_ids.clear() + _sessionless_cache.clear() diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 6a290c1c9..9f392c94f 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,41 +1,54 @@ from common import * -from mock import patch from mock_storage import mock_storage -import storage from storage import cache from trezor.messages.Initialize import Initialize -from trezor.messages.ClearSession import ClearSession from trezor.wire import DUMMY_CONTEXT -from apps.homescreen import handle_Initialize, handle_ClearSession +from apps.homescreen import handle_Initialize KEY = 99 class TestStorageCache(unittest.TestCase): - def test_session_id(self): - session_id_a = cache.get_session_id() + def test_start_session(self): + session_id_a = cache.start_session() self.assertIsNotNone(session_id_a) - session_id_b = cache.get_session_id() - self.assertEqual(session_id_a, session_id_b) - - cache.clear() - session_id_c = cache.get_session_id() - self.assertIsNotNone(session_id_c) - self.assertNotEqual(session_id_a, session_id_c) + session_id_b = cache.start_session() + self.assertNotEqual(session_id_a, session_id_b) + + cache.clear_all() + with self.assertRaises(RuntimeError): + cache.set(KEY, "something") + with self.assertRaises(RuntimeError): + cache.get(KEY) + + def test_session_queue(self): + session_id = cache.start_session() + self.assertEqual(cache.start_session(session_id), session_id) + cache.set(KEY, "A") + for i in range(cache._MAX_SESSIONS_COUNT): + cache.start_session() + self.assertNotEqual(cache.start_session(session_id), session_id) + self.assertIsNone(cache.get(KEY)) def test_get_set(self): - value = cache.get(KEY) - self.assertIsNone(value) - + session_id1 = cache.start_session() cache.set(KEY, "hello") - value = cache.get(KEY) - self.assertEqual(value, "hello") + self.assertEqual(cache.get(KEY), "hello") + + session_id2 = cache.start_session() + cache.set(KEY, "world") + self.assertEqual(cache.get(KEY), "world") - cache.clear() - value = cache.get(KEY) - self.assertIsNone(value) + cache.start_session(session_id2) + self.assertEqual(cache.get(KEY), "world") + cache.start_session(session_id1) + self.assertEqual(cache.get(KEY), "hello") + + cache.clear_all() + with self.assertRaises(RuntimeError): + cache.get(KEY) @mock_storage def test_Initialize(self): @@ -44,38 +57,32 @@ class TestStorageCache(unittest.TestCase): return await_result(handle_Initialize(DUMMY_CONTEXT, msg)) # calling Initialize without an ID allocates a new one - session_id = cache.get_session_id() + session_id = cache.start_session() features = call_Initialize() - new_session_id = cache.get_session_id() - self.assertNotEqual(session_id, new_session_id) - self.assertEqual(new_session_id, features.session_id) + self.assertNotEqual(session_id, features.session_id) # calling Initialize with the current ID does not allocate a new one - features = call_Initialize(session_id=new_session_id) - same_session_id = cache.get_session_id() - self.assertEqual(new_session_id, same_session_id) - self.assertEqual(same_session_id, features.session_id) - - call_Initialize() - # calling Initialize with a non-current ID returns a different one - features = call_Initialize(session_id=new_session_id) - self.assertNotEqual(new_session_id, features.session_id) + features = call_Initialize(session_id=session_id) + self.assertEqual(session_id, features.session_id) - # allocating a new session ID clears the cache + # store "hello" cache.set(KEY, "hello") + # check that it is cleared features = call_Initialize() + session_id = features.session_id self.assertIsNone(cache.get(KEY)) - - # resuming a session does not clear the cache + # store "hello" again cache.set(KEY, "hello") - call_Initialize(session_id=features.session_id) self.assertEqual(cache.get(KEY), "hello") - # supplying a different session ID clears the cache - self.assertNotEqual(new_session_id, features.session_id) - call_Initialize(session_id=new_session_id) + # supplying a different session ID starts a new cache + call_Initialize(session_id=b"A") self.assertIsNone(cache.get(KEY)) + # but resuming a session loads the previous one + call_Initialize(session_id=session_id) + self.assertEqual(cache.get(KEY), "hello") + if __name__ == "__main__": unittest.main() diff --git a/tests/device_tests/test_session_id_and_passphrase.py b/tests/device_tests/test_session_id_and_passphrase.py index 3f401d6fc..040b1add8 100644 --- a/tests/device_tests/test_session_id_and_passphrase.py +++ b/tests/device_tests/test_session_id_and_passphrase.py @@ -20,13 +20,26 @@ from trezorlib import messages from trezorlib.messages import FailureType from trezorlib.tools import parse_path -XPUB_PASSPHRASE_A = "xpub6CekxGcnqnJ6osfY4Rrq7W5ogFtR54KUvz4H16XzaQuukMFZCGebEpVznfq4yFcKEmYyShwj2UKjL7CazuNSuhdkofF4mHabHkLxCMVvsqG" +XPUB_PASSPHRASES = { + "A": "xpub6CekxGcnqnJ6osfY4Rrq7W5ogFtR54KUvz4H16XzaQuukMFZCGebEpVznfq4yFcKEmYyShwj2UKjL7CazuNSuhdkofF4mHabHkLxCMVvsqG", + "B": "xpub6CFxuyQpgryoR64QC38w42dLgDv5P4qWXhn1fbaN62UYzu1wJXZyrYqGnkq5d8xPUK68RXtXFBiqp3rfLGpeQ57zLtx675ZZn5ezKMAWQfu", + "C": "xpub6BhJMNFwCjGKyRb9RUcnuHhJ2TgcnurfUrQszrmZ1rg8aadsMXLySF6LY3qf4pR7bY4vwpd1VwLPQvuCRr7BPTs8wvqrv2gexxViwj96czT", + "D": "xpub6DK1vnTBe9EkhLACJRvovv8RSUC3MSiEV64opM7XUqrowxQ8J5C2WpA6n4vt5LS3bs618aKzi7k5w7VzNCv3SfqEeSepvvHaPhRoTvRqR5u", + "E": "xpub6CqbQjHN7r68GHh7RsiAyrdAmyiZQgWvDxQtba2NxZHumvfMK31U6emVQSexYrTAHWQeLygRD1yXZQLsCs1LLJtaeSxMAnh2YUmP3ov6EQz", + "F": "xpub6CRDxB1aHVNHfqjPeYhnPBhBfkQb4b4K581uYKxwv4KnkiVsRttBCXSkZM5jtP1Vv2v3wr5FxfzqWWDApLCbutBLnfwYpkWpZUmZSp6hqg5", + "G": "xpub6DGKmAKYDF44KQEaqXY3bbJNufEDi6QPnahV4JdBxFbFCN9Vg7ZfUHxPv3uhjeeJEtPe2PjFKWRsUrEF3RDttnXf9wXq3BfYBZemwKipJ24", + "H": "xpub6Bg8zbY94d1cBbAGT2crZL7C1UM8JWCP5CCtiHMnV4tB1pE9oCfjvZxRRFLi6EiamBDyCs3ARaHwU2FLx76YYCPFRVc1YyJi6depNtWRnoJ", + "I": "xpub6DMpHuTZTTN64eEHcNpyeQwehXgWTrY668ZkRWnRfkFEGKpNv2uPR3js1dJgcFRksSmrdtpHqFDPTzFsR1HqvzNdgZwXmk9vCLt1ypwUzA3", + "J": "xpub6CVeYPTG57D4tm9BvwCcakppwGJstbXyK8Yd611agusZuHmx7og3dNvr6pjMN6e4BoaNc5MZA4TjMLjMT2h2vJRU8rYLvHFUwrEL9zDbuqe", +} XPUB_PASSPHRASE_NONE = "xpub6BiVtCpG9fQPxnPmHXG8PhtzQdWC2Su4qWu6XW9tpWFYhxydCLJGrWBJZ5H6qTAHdPQ7pQhtpjiYZVZARo14qHiay2fvrX996oEP42u8wZy" XPUB_CARDANO_PASSPHRASE_B = "d80e770f6dfc3edb58eaab68aa091b2c27b08a47583471e93437ac5f8baa61880c7af4938a941c084c19731e6e57a5710e6ad1196263291aea297ce0eec0f177" ADDRESS_N = parse_path("44h/0h/0h") XPUB_REQUEST = messages.GetPublicKey(address_n=ADDRESS_N, coin_name="Bitcoin") +SESSIONS_STORED = 10 + def _init_session(client, session_id=None): """Call Initialize, check and return the session ID.""" @@ -54,24 +67,82 @@ def test_session_with_passphrase(client): # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] # Call Initialize again, this time with the received session id and then call # GetPublicKey. The passphrase should be cached now so Trezor must # not ask for it again, whilst returning the same xpub. new_session_id = _init_session(client, session_id=session_id) assert new_session_id == session_id - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] # If we set session id in Initialize to None, the cache will be cleared # and Trezor will ask for the passphrase again. new_session_id = _init_session(client) assert new_session_id != session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] # Unknown session id is the same as setting it to None. _init_session(client, session_id=b"X" * 32) - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + + +@pytest.mark.skip_ui +@pytest.mark.setup_client(passphrase=True) +def test_multiple_sessions(client): + session_ids = [] + + # start a session + session_ids.append(_init_session(client)) + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] + # start it again wit the same session id + new_session_id = _init_session(client, session_id=session_ids[0]) + # session is the same + assert new_session_id == session_ids[0] + # passphrase is not prompted + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + + # start a second session + session_ids.append(_init_session(client)) + # new session -> new session id and passphrase prompt + assert _get_xpub(client, passphrase="B") == XPUB_PASSPHRASES["B"] + + # provide the same session id -> must not ask for passphrase again. + new_session_id = _init_session(client, session_id=session_ids[1]) + assert new_session_id == session_ids[1] + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["B"] + + # provide the first session id -> must not ask for passphrase again and return the same result. + new_session_id = _init_session(client, session_id=session_ids[0]) + assert new_session_id == session_ids[0] + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] + + passphrases = list(XPUB_PASSPHRASES.keys()) + xpubs = list(XPUB_PASSPHRASES.values()) + # start as many sessions as the limit is (2 were started already) + for i in range(2, SESSIONS_STORED): + new_session_id = _init_session(client) + assert new_session_id not in session_ids + session_ids.append(new_session_id) + assert _get_xpub(client, passphrase=passphrases[i]) == xpubs[i] + + # passphrase is not prompted for the started the sessions, regardless the order + for i in reversed(range(0, SESSIONS_STORED)): + _init_session(client, session_id=session_ids[i]) + assert _get_xpub(client, passphrase=None) == xpubs[i] + + # creating one more will exceed the limit, the LRU item is at the moment the last one (see above) + # it should have been removed -> must ask for passphrase + _init_session(client) + _get_xpub(client, passphrase="XX") # create one more + _init_session(client, session_id=session_ids[SESSIONS_STORED - 1]) + _get_xpub(client, passphrase="whatever") # the session is gone + + # now the second to last is the next LRU + _init_session(client) + _get_xpub(client, passphrase="XXXX") + _init_session(client, session_id=session_ids[SESSIONS_STORED - 2]) + _get_xpub(client, passphrase="whatever") @pytest.mark.skip_ui @@ -95,24 +166,7 @@ def test_session_enable_passphrase(client): # We clear the session id now, so the passphrase should be asked. new_session_id = _init_session(client) assert session_id != new_session_id - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A - - -@pytest.mark.skip_ui -@pytest.mark.setup_client(passphrase=True) -def test_clear_session_passphrase(client): - # at first attempt, we are prompted for passphrase - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A - - # now the passphrase is cached - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_A - - # Erase the cached passphrase - response = client.call(messages.Initialize()) - assert isinstance(response, messages.Features) - - # we have to enter passphrase again - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] @pytest.mark.skip_ui @@ -126,12 +180,12 @@ def test_passphrase_on_device(client): assert isinstance(response, messages.PassphraseRequest) response = client.call_raw(messages.PassphraseAck(passphrase="A", on_device=False)) assert isinstance(response, messages.PublicKey) - assert response.xpub == XPUB_PASSPHRASE_A + assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached response = client.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) - assert response.xpub == XPUB_PASSPHRASE_A + assert response.xpub == XPUB_PASSPHRASES["A"] # make a new session _init_session(client) @@ -144,12 +198,12 @@ def test_passphrase_on_device(client): client.debug.input("A") response = client.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) - assert response.xpub == XPUB_PASSPHRASE_A + assert response.xpub == XPUB_PASSPHRASES["A"] # try to get xpub again, passphrase should be cached response = client.call_raw(XPUB_REQUEST) assert isinstance(response, messages.PublicKey) - assert response.xpub == XPUB_PASSPHRASE_A + assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.skip_ui @@ -184,7 +238,7 @@ def test_passphrase_always_on_device(client): client.debug.input("A") # Input empty passphrase. response = client.call_raw(messages.ButtonAck()) assert isinstance(response, messages.PublicKey) - assert response.xpub == XPUB_PASSPHRASE_A + assert response.xpub == XPUB_PASSPHRASES["A"] @pytest.mark.skip_ui @@ -275,10 +329,10 @@ def test_cardano_passphrase(client): # GetPublicKey requires passphrase and since it is not cached, # Trezor will prompt for it. - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] # The passphrase is now cached for non-Cardano coins. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] # Cardano will prompt for it again. assert _get_xpub_cardano(client, passphrase="B") == XPUB_CARDANO_PASSPHRASE_B @@ -287,18 +341,18 @@ def test_cardano_passphrase(client): assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B # And others behaviour did not change. - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] # Initialize with the session id does not destroy the state _init_session(client, session_id=session_id) - assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase=None) == XPUB_PASSPHRASES["A"] assert _get_xpub_cardano(client, passphrase=None) == XPUB_CARDANO_PASSPHRASE_B # New session will destroy the state _init_session(client) # GetPublicKey must ask for passphrase again - assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASE_A + assert _get_xpub(client, passphrase="A") == XPUB_PASSPHRASES["A"] # Cardano must also ask for passphrase again assert _get_xpub_cardano(client, passphrase="B") == XPUB_CARDANO_PASSPHRASE_B