diff --git a/core/tests/test_apps.bitcoin.approver.py b/core/tests/test_apps.bitcoin.approver.py index 7354a846b1..1bc48040e7 100644 --- a/core/tests/test_apps.bitcoin.approver.py +++ b/core/tests/test_apps.bitcoin.approver.py @@ -1,6 +1,6 @@ from common import H_, await_result, unittest # isort:skip -import storage.cache +import storage.cache_codec from trezor import wire from trezor.crypto import bip32 from trezor.enums import InputScriptType, OutputScriptType @@ -11,6 +11,8 @@ from trezor.messages import ( TxInput, TxOutput, ) +from trezor.wire import context +from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.sign_tx.approvers import CoinJoinApprover @@ -20,6 +22,11 @@ from apps.common import coins class TestApprover(unittest.TestCase): + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): self.coin = coins.by_name("Bitcoin") self.fee_rate_percent = 0.3 @@ -47,7 +54,7 @@ class TestApprover(unittest.TestCase): coin_name=self.coin.coin_name, script_type=InputScriptType.SPENDTAPROOT, ) - storage.cache.start_session() + storage.cache_codec.start_session() def make_coinjoin_request(self, inputs): return CoinJoinRequest( diff --git a/core/tests/test_apps.bitcoin.authorization.py b/core/tests/test_apps.bitcoin.authorization.py index 503c181569..a2a8747285 100644 --- a/core/tests/test_apps.bitcoin.authorization.py +++ b/core/tests/test_apps.bitcoin.authorization.py @@ -1,8 +1,10 @@ from common import H_, unittest # isort:skip -import storage.cache +import storage.cache_codec from trezor.enums import InputScriptType from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx +from trezor.wire import context +from trezor.wire.codec.codec_context import CodecContext from apps.bitcoin.authorization import CoinJoinAuthorization from apps.common import coins @@ -12,6 +14,10 @@ _ROUND_ID_LEN = 32 class TestAuthorization(unittest.TestCase): + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + coin = coins.by_name("Bitcoin") def setUp(self): @@ -26,7 +32,7 @@ class TestAuthorization(unittest.TestCase): ) self.authorization = CoinJoinAuthorization(self.msg_auth) - storage.cache.start_session() + storage.cache_codec.start_session() def test_ownership_proof_account_depth_mismatch(self): # Account depth mismatch. diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index 3828a3ebbc..e21f88c8c0 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -1,17 +1,25 @@ from common import * # isort:skip -from storage import cache +from storage import cache_common from trezor import wire from trezor.crypto import bip39 +from trezor.wire import context from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin +from trezor.wire.codec.codec_context import CodecContext +from storage import cache_codec class TestBitcoinKeychain(unittest.TestCase): + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): - cache.start_session() + cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bitcoin(self): coin = _get_coin_by_name("Bitcoin") @@ -88,10 +96,19 @@ class TestBitcoinKeychain(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestAltcoinKeychains(unittest.TestCase): + + def __init__(self): + # Context is needed to test decorators and handleInitialize + # It allows access to codec cache from different parts of the code + from trezor.wire import context + + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): - cache.start_session() + cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bcash(self): coin = _get_coin_by_name("Bcash") diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index 84681a0b01..ea85b2eb8f 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -1,19 +1,27 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache +from storage import cache, cache_common from trezor import wire from trezor.crypto import bip39 from trezor.enums import SafetyCheckLevel +from trezor.wire import context from apps.common import safety_checks from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.paths import PATTERN_SEP5, PathSchema +from trezor.wire.codec.codec_context import CodecContext +from storage import cache_codec class TestKeychain(unittest.TestCase): + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): - cache.start_session() + cache_codec.start_session() def tearDown(self): cache.clear_all() @@ -71,7 +79,7 @@ class TestKeychain(unittest.TestCase): def test_get_keychain(self): seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + context.cache_set(cache_common.APP_COMMON_SEED, seed) schema = PathSchema.parse("m/44'/1'", 0) keychain = await_result(get_keychain("secp256k1", [schema])) @@ -85,7 +93,7 @@ class TestKeychain(unittest.TestCase): def test_with_slip44(self): seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + context.cache_set(cache_common.APP_COMMON_SEED, seed) slip44_id = 42 valid_path = [H_(44), H_(slip44_id), H_(0)] diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 53affef1b7..404dc07641 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -2,12 +2,15 @@ from common import * # isort:skip import unittest -from storage import cache -from trezor import utils, wire +from storage import cache_common +from trezor import wire from trezor.crypto import bip39 +from trezor.wire import context from apps.common.keychain import get_keychain from apps.common.paths import HARDENED +from trezor.wire.codec.codec_context import CodecContext +from storage import cache_codec if not utils.BITCOIN_ONLY: from ethereum_common import encode_network, make_network @@ -71,10 +74,14 @@ class TestEthereumKeychain(unittest.TestCase): addr, ) + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): - cache.start_session() + cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "") - cache.set(cache.APP_COMMON_SEED, seed) + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n) diff --git a/core/tests/test_storage.cache.py b/core/tests/test_storage.cache.py index 76fe29655b..60d1168004 100644 --- a/core/tests/test_storage.cache.py +++ b/core/tests/test_storage.cache.py @@ -1,150 +1,158 @@ from common import * # isort:skip from mock_storage import mock_storage -from storage import cache +from storage import cache, cache_codec, cache_common from trezor.messages import EndSession, Initialize +from trezor.wire import context +from trezor.wire.codec.codec_context import CodecContext from apps.base import handle_EndSession, handle_Initialize +from apps.common.cache import stored, stored_async KEY = 0 # Function moved from cache.py, as it was not used there def is_session_started() -> bool: - return cache._active_session_idx is not None + return cache_codec._active_session_idx is not None class TestStorageCache(unittest.TestCase): + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + def setUp(self): cache.clear_all() def test_start_session(self): - session_id_a = cache.start_session() + session_id_a = cache_codec.start_session() self.assertIsNotNone(session_id_a) - session_id_b = cache.start_session() + session_id_b = cache_codec.start_session() self.assertNotEqual(session_id_a, session_id_b) cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.set(KEY, "something") - with self.assertRaises(cache.InvalidSessionError): - cache.get(KEY) + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_set(KEY, "something") + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_get(KEY) def test_end_session(self): - session_id = cache.start_session() + session_id = cache_codec.start_session() self.assertTrue(is_session_started()) - cache.set(KEY, b"A") - cache.end_current_session() + context.cache_set(KEY, b"A") + cache_codec.end_current_session() self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) # ending an ended session should be a no-op - cache.end_current_session() + cache_codec.end_current_session() self.assertFalse(is_session_started()) - session_id_a = cache.start_session(session_id) + session_id_a = cache_codec.start_session(session_id) # original session no longer exists self.assertNotEqual(session_id_a, session_id) # original session data no longer exists - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) # create a new session - session_id_b = cache.start_session() + session_id_b = cache_codec.start_session() # switch back to original session - session_id = cache.start_session(session_id_a) + session_id = cache_codec.start_session(session_id_a) self.assertEqual(session_id, session_id_a) # end original session - cache.end_current_session() + cache_codec.end_current_session() # switch back to B - session_id = cache.start_session(session_id_b) + session_id = cache_codec.start_session(session_id_b) self.assertEqual(session_id, session_id_b) def test_session_queue(self): - session_id = cache.start_session() - self.assertEqual(cache.start_session(session_id), session_id) - cache.set(KEY, b"A") - for i in range(cache._MAX_SESSIONS_COUNT): - cache.start_session() - self.assertNotEqual(cache.start_session(session_id), session_id) - self.assertIsNone(cache.get(KEY)) + session_id = cache_codec.start_session() + self.assertEqual(cache_codec.start_session(session_id), session_id) + context.cache_set(KEY, b"A") + for i in range(cache_codec._MAX_SESSIONS_COUNT): + cache_codec.start_session() + self.assertNotEqual(cache_codec.start_session(session_id), session_id) + self.assertIsNone(context.cache_get(KEY)) def test_get_set(self): - session_id1 = cache.start_session() - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") + session_id1 = cache_codec.start_session() + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") - session_id2 = cache.start_session() - cache.set(KEY, b"world") - self.assertEqual(cache.get(KEY), b"world") + session_id2 = cache_codec.start_session() + context.cache_set(KEY, b"world") + self.assertEqual(context.cache_get(KEY), b"world") - cache.start_session(session_id2) - self.assertEqual(cache.get(KEY), b"world") - cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), b"hello") + cache_codec.start_session(session_id2) + self.assertEqual(context.cache_get(KEY), b"world") + cache_codec.start_session(session_id1) + self.assertEqual(context.cache_get(KEY), b"hello") cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.get(KEY) + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_get(KEY) def test_get_set_int(self): - session_id1 = cache.start_session() - cache.set_int(KEY, 1234) - self.assertEqual(cache.get_int(KEY), 1234) + session_id1 = cache_codec.start_session() + context.cache_set_int(KEY, 1234) + self.assertEqual(context.cache_get_int(KEY), 1234) - session_id2 = cache.start_session() - cache.set_int(KEY, 5678) - self.assertEqual(cache.get_int(KEY), 5678) + session_id2 = cache_codec.start_session() + context.cache_set_int(KEY, 5678) + self.assertEqual(context.cache_get_int(KEY), 5678) - cache.start_session(session_id2) - self.assertEqual(cache.get_int(KEY), 5678) - cache.start_session(session_id1) - self.assertEqual(cache.get_int(KEY), 1234) + cache_codec.start_session(session_id2) + self.assertEqual(context.cache_get_int(KEY), 5678) + cache_codec.start_session(session_id1) + self.assertEqual(context.cache_get_int(KEY), 1234) cache.clear_all() - with self.assertRaises(cache.InvalidSessionError): - cache.get_int(KEY) + with self.assertRaises(cache_common.InvalidSessionError): + context.cache_get_int(KEY) def test_delete(self): - session_id1 = cache.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - cache.delete(KEY) - self.assertIsNone(cache.get(KEY)) + session_id1 = cache_codec.start_session() + self.assertIsNone(context.cache_get(KEY)) + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") + context.cache_delete(KEY) + self.assertIsNone(context.cache_get(KEY)) - cache.set(KEY, b"hello") - cache.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") - cache.delete(KEY) - self.assertIsNone(cache.get(KEY)) + context.cache_set(KEY, b"hello") + cache_codec.start_session() + self.assertIsNone(context.cache_get(KEY)) + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") + context.cache_delete(KEY) + self.assertIsNone(context.cache_get(KEY)) - cache.start_session(session_id1) - self.assertEqual(cache.get(KEY), b"hello") + cache_codec.start_session(session_id1) + self.assertEqual(context.cache_get(KEY), b"hello") def test_decorators(self): run_count = 0 - cache.start_session() + cache_codec.start_session() - @cache.stored(KEY) + @stored(KEY) def func(): nonlocal run_count run_count += 1 return b"foo" # cache is empty - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) self.assertEqual(run_count, 0) self.assertEqual(func(), b"foo") # function was run self.assertEqual(run_count, 1) - self.assertEqual(cache.get(KEY), b"foo") + self.assertEqual(context.cache_get(KEY), b"foo") # function does not run again but returns cached value self.assertEqual(func(), b"foo") self.assertEqual(run_count, 1) - @cache.stored_async(KEY) + @stored_async(KEY) async def async_func(): nonlocal run_count run_count += 1 @@ -154,7 +162,7 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(await_result(async_func()), b"foo") self.assertEqual(run_count, 1) - cache.start_session() + cache_codec.start_session() self.assertEqual(await_result(async_func()), b"bar") self.assertEqual(run_count, 2) # awaitable is also run only once @@ -162,16 +170,16 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(run_count, 2) def test_empty_value(self): - cache.start_session() + cache_codec.start_session() - self.assertIsNone(cache.get(KEY)) - cache.set(KEY, b"") - self.assertEqual(cache.get(KEY), b"") + self.assertIsNone(context.cache_get(KEY)) + context.cache_set(KEY, b"") + self.assertEqual(context.cache_get(KEY), b"") - cache.delete(KEY) + context.cache_delete(KEY) run_count = 0 - @cache.stored(KEY) + @stored(KEY) def func(): nonlocal run_count run_count += 1 @@ -191,7 +199,7 @@ class TestStorageCache(unittest.TestCase): return await_result(handle_Initialize(msg)) # calling Initialize without an ID allocates a new one - session_id = cache.start_session() + session_id = cache_codec.start_session() features = call_Initialize() self.assertNotEqual(session_id, features.session_id) @@ -200,31 +208,31 @@ class TestStorageCache(unittest.TestCase): self.assertEqual(session_id, features.session_id) # store "hello" - cache.set(KEY, b"hello") + context.cache_set(KEY, b"hello") # check that it is cleared features = call_Initialize() session_id = features.session_id - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) # store "hello" again - cache.set(KEY, b"hello") - self.assertEqual(cache.get(KEY), b"hello") + context.cache_set(KEY, b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") # supplying a different session ID starts a new cache - call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) - self.assertIsNone(cache.get(KEY)) + call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH) + self.assertIsNone(context.cache_get(KEY)) # but resuming a session loads the previous one call_Initialize(session_id=session_id) - self.assertEqual(cache.get(KEY), b"hello") + self.assertEqual(context.cache_get(KEY), b"hello") def test_EndSession(self): - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) - cache.start_session() + self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) + cache_codec.start_session() self.assertTrue(is_session_started()) - self.assertIsNone(cache.get(KEY)) + self.assertIsNone(context.cache_get(KEY)) await_result(handle_EndSession(EndSession())) self.assertFalse(is_session_started()) - self.assertRaises(cache.InvalidSessionError, cache.get, KEY) + self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec.codec_v1.py similarity index 99% rename from core/tests/test_trezor.wire.codec_v1.py rename to core/tests/test_trezor.wire.codec.codec_v1.py index 1da0ea896b..78675859e2 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec.codec_v1.py @@ -5,7 +5,7 @@ import ustruct from trezor import io from trezor.loop import wait from trezor.utils import chunks -from trezor.wire import codec_v1 +from trezor.wire.codec import codec_v1 class MockHID: