mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-26 01:18:28 +00:00
test(core): update tests to reflect cache refactor
[no changelog]
This commit is contained in:
parent
7fc226258e
commit
6bc085c828
@ -1,6 +1,6 @@
|
|||||||
from common import H_, await_result, unittest # isort:skip
|
from common import H_, await_result, unittest # isort:skip
|
||||||
|
|
||||||
import storage.cache
|
import storage.cache_codec
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip32
|
from trezor.crypto import bip32
|
||||||
from trezor.enums import InputScriptType, OutputScriptType
|
from trezor.enums import InputScriptType, OutputScriptType
|
||||||
@ -11,6 +11,8 @@ from trezor.messages import (
|
|||||||
TxInput,
|
TxInput,
|
||||||
TxOutput,
|
TxOutput,
|
||||||
)
|
)
|
||||||
|
from trezor.wire import context
|
||||||
|
from trezor.wire.codec.codec_context import CodecContext
|
||||||
|
|
||||||
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
|
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
|
||||||
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
|
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
|
||||||
@ -20,6 +22,11 @@ from apps.common import coins
|
|||||||
|
|
||||||
|
|
||||||
class TestApprover(unittest.TestCase):
|
class TestApprover(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.coin = coins.by_name("Bitcoin")
|
self.coin = coins.by_name("Bitcoin")
|
||||||
self.fee_rate_percent = 0.3
|
self.fee_rate_percent = 0.3
|
||||||
@ -47,7 +54,7 @@ class TestApprover(unittest.TestCase):
|
|||||||
coin_name=self.coin.coin_name,
|
coin_name=self.coin.coin_name,
|
||||||
script_type=InputScriptType.SPENDTAPROOT,
|
script_type=InputScriptType.SPENDTAPROOT,
|
||||||
)
|
)
|
||||||
storage.cache.start_session()
|
storage.cache_codec.start_session()
|
||||||
|
|
||||||
def make_coinjoin_request(self, inputs):
|
def make_coinjoin_request(self, inputs):
|
||||||
return CoinJoinRequest(
|
return CoinJoinRequest(
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
from common import H_, unittest # isort:skip
|
from common import H_, unittest # isort:skip
|
||||||
|
|
||||||
import storage.cache
|
import storage.cache_codec
|
||||||
from trezor.enums import InputScriptType
|
from trezor.enums import InputScriptType
|
||||||
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
|
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
|
||||||
|
from trezor.wire import context
|
||||||
|
from trezor.wire.codec.codec_context import CodecContext
|
||||||
|
|
||||||
from apps.bitcoin.authorization import CoinJoinAuthorization
|
from apps.bitcoin.authorization import CoinJoinAuthorization
|
||||||
from apps.common import coins
|
from apps.common import coins
|
||||||
@ -12,6 +14,10 @@ _ROUND_ID_LEN = 32
|
|||||||
|
|
||||||
class TestAuthorization(unittest.TestCase):
|
class TestAuthorization(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
coin = coins.by_name("Bitcoin")
|
coin = coins.by_name("Bitcoin")
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
@ -26,7 +32,7 @@ class TestAuthorization(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.authorization = CoinJoinAuthorization(self.msg_auth)
|
self.authorization = CoinJoinAuthorization(self.msg_auth)
|
||||||
storage.cache.start_session()
|
storage.cache_codec.start_session()
|
||||||
|
|
||||||
def test_ownership_proof_account_depth_mismatch(self):
|
def test_ownership_proof_account_depth_mismatch(self):
|
||||||
# Account depth mismatch.
|
# Account depth mismatch.
|
||||||
|
@ -1,17 +1,25 @@
|
|||||||
from common import * # isort:skip
|
from common import * # isort:skip
|
||||||
|
|
||||||
from storage import cache
|
from storage import cache_common
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip39
|
from trezor.crypto import bip39
|
||||||
|
from trezor.wire import context
|
||||||
|
|
||||||
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
|
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
|
||||||
|
from trezor.wire.codec.codec_context import CodecContext
|
||||||
|
from storage import cache_codec
|
||||||
|
|
||||||
|
|
||||||
class TestBitcoinKeychain(unittest.TestCase):
|
class TestBitcoinKeychain(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
def test_bitcoin(self):
|
def test_bitcoin(self):
|
||||||
coin = _get_coin_by_name("Bitcoin")
|
coin = _get_coin_by_name("Bitcoin")
|
||||||
@ -88,10 +96,19 @@ class TestBitcoinKeychain(unittest.TestCase):
|
|||||||
|
|
||||||
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
|
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
|
||||||
class TestAltcoinKeychains(unittest.TestCase):
|
class TestAltcoinKeychains(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Context is needed to test decorators and handleInitialize
|
||||||
|
# It allows access to codec cache from different parts of the code
|
||||||
|
from trezor.wire import context
|
||||||
|
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
def test_bcash(self):
|
def test_bcash(self):
|
||||||
coin = _get_coin_by_name("Bcash")
|
coin = _get_coin_by_name("Bcash")
|
||||||
|
@ -1,19 +1,27 @@
|
|||||||
from common import * # isort:skip
|
from common import * # isort:skip
|
||||||
|
|
||||||
from mock_storage import mock_storage
|
from mock_storage import mock_storage
|
||||||
from storage import cache
|
from storage import cache, cache_common
|
||||||
from trezor import wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip39
|
from trezor.crypto import bip39
|
||||||
from trezor.enums import SafetyCheckLevel
|
from trezor.enums import SafetyCheckLevel
|
||||||
|
from trezor.wire import context
|
||||||
|
|
||||||
from apps.common import safety_checks
|
from apps.common import safety_checks
|
||||||
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
|
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
|
||||||
from apps.common.paths import PATTERN_SEP5, PathSchema
|
from apps.common.paths import PATTERN_SEP5, PathSchema
|
||||||
|
from trezor.wire.codec.codec_context import CodecContext
|
||||||
|
from storage import cache_codec
|
||||||
|
|
||||||
|
|
||||||
class TestKeychain(unittest.TestCase):
|
class TestKeychain(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
@ -71,7 +79,7 @@ class TestKeychain(unittest.TestCase):
|
|||||||
|
|
||||||
def test_get_keychain(self):
|
def test_get_keychain(self):
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
context.cache_set(cache_common.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
schema = PathSchema.parse("m/44'/1'", 0)
|
schema = PathSchema.parse("m/44'/1'", 0)
|
||||||
keychain = await_result(get_keychain("secp256k1", [schema]))
|
keychain = await_result(get_keychain("secp256k1", [schema]))
|
||||||
@ -85,7 +93,7 @@ class TestKeychain(unittest.TestCase):
|
|||||||
|
|
||||||
def test_with_slip44(self):
|
def test_with_slip44(self):
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
context.cache_set(cache_common.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
slip44_id = 42
|
slip44_id = 42
|
||||||
valid_path = [H_(44), H_(slip44_id), H_(0)]
|
valid_path = [H_(44), H_(slip44_id), H_(0)]
|
||||||
|
@ -2,12 +2,15 @@ from common import * # isort:skip
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from storage import cache
|
from storage import cache_common
|
||||||
from trezor import utils, wire
|
from trezor import wire
|
||||||
from trezor.crypto import bip39
|
from trezor.crypto import bip39
|
||||||
|
from trezor.wire import context
|
||||||
|
|
||||||
from apps.common.keychain import get_keychain
|
from apps.common.keychain import get_keychain
|
||||||
from apps.common.paths import HARDENED
|
from apps.common.paths import HARDENED
|
||||||
|
from trezor.wire.codec.codec_context import CodecContext
|
||||||
|
from storage import cache_codec
|
||||||
|
|
||||||
if not utils.BITCOIN_ONLY:
|
if not utils.BITCOIN_ONLY:
|
||||||
from ethereum_common import encode_network, make_network
|
from ethereum_common import encode_network, make_network
|
||||||
@ -71,10 +74,14 @@ class TestEthereumKeychain(unittest.TestCase):
|
|||||||
addr,
|
addr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
seed = bip39.seed(" ".join(["all"] * 12), "")
|
seed = bip39.seed(" ".join(["all"] * 12), "")
|
||||||
cache.set(cache.APP_COMMON_SEED, seed)
|
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
|
||||||
|
|
||||||
def from_address_n(self, address_n):
|
def from_address_n(self, address_n):
|
||||||
slip44 = _slip44_from_address_n(address_n)
|
slip44 = _slip44_from_address_n(address_n)
|
||||||
|
@ -1,150 +1,158 @@
|
|||||||
from common import * # isort:skip
|
from common import * # isort:skip
|
||||||
|
|
||||||
from mock_storage import mock_storage
|
from mock_storage import mock_storage
|
||||||
from storage import cache
|
from storage import cache, cache_codec, cache_common
|
||||||
from trezor.messages import EndSession, Initialize
|
from trezor.messages import EndSession, Initialize
|
||||||
|
from trezor.wire import context
|
||||||
|
from trezor.wire.codec.codec_context import CodecContext
|
||||||
|
|
||||||
from apps.base import handle_EndSession, handle_Initialize
|
from apps.base import handle_EndSession, handle_Initialize
|
||||||
|
from apps.common.cache import stored, stored_async
|
||||||
|
|
||||||
KEY = 0
|
KEY = 0
|
||||||
|
|
||||||
|
|
||||||
# Function moved from cache.py, as it was not used there
|
# Function moved from cache.py, as it was not used there
|
||||||
def is_session_started() -> bool:
|
def is_session_started() -> bool:
|
||||||
return cache._active_session_idx is not None
|
return cache_codec._active_session_idx is not None
|
||||||
|
|
||||||
|
|
||||||
class TestStorageCache(unittest.TestCase):
|
class TestStorageCache(unittest.TestCase):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
|
|
||||||
def test_start_session(self):
|
def test_start_session(self):
|
||||||
session_id_a = cache.start_session()
|
session_id_a = cache_codec.start_session()
|
||||||
self.assertIsNotNone(session_id_a)
|
self.assertIsNotNone(session_id_a)
|
||||||
session_id_b = cache.start_session()
|
session_id_b = cache_codec.start_session()
|
||||||
self.assertNotEqual(session_id_a, session_id_b)
|
self.assertNotEqual(session_id_a, session_id_b)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(cache_common.InvalidSessionError):
|
||||||
cache.set(KEY, "something")
|
context.cache_set(KEY, "something")
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(cache_common.InvalidSessionError):
|
||||||
cache.get(KEY)
|
context.cache_get(KEY)
|
||||||
|
|
||||||
def test_end_session(self):
|
def test_end_session(self):
|
||||||
session_id = cache.start_session()
|
session_id = cache_codec.start_session()
|
||||||
self.assertTrue(is_session_started())
|
self.assertTrue(is_session_started())
|
||||||
cache.set(KEY, b"A")
|
context.cache_set(KEY, b"A")
|
||||||
cache.end_current_session()
|
cache_codec.end_current_session()
|
||||||
self.assertFalse(is_session_started())
|
self.assertFalse(is_session_started())
|
||||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
|
||||||
|
|
||||||
# ending an ended session should be a no-op
|
# ending an ended session should be a no-op
|
||||||
cache.end_current_session()
|
cache_codec.end_current_session()
|
||||||
self.assertFalse(is_session_started())
|
self.assertFalse(is_session_started())
|
||||||
|
|
||||||
session_id_a = cache.start_session(session_id)
|
session_id_a = cache_codec.start_session(session_id)
|
||||||
# original session no longer exists
|
# original session no longer exists
|
||||||
self.assertNotEqual(session_id_a, session_id)
|
self.assertNotEqual(session_id_a, session_id)
|
||||||
# original session data no longer exists
|
# original session data no longer exists
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
|
|
||||||
# create a new session
|
# create a new session
|
||||||
session_id_b = cache.start_session()
|
session_id_b = cache_codec.start_session()
|
||||||
# switch back to original session
|
# switch back to original session
|
||||||
session_id = cache.start_session(session_id_a)
|
session_id = cache_codec.start_session(session_id_a)
|
||||||
self.assertEqual(session_id, session_id_a)
|
self.assertEqual(session_id, session_id_a)
|
||||||
# end original session
|
# end original session
|
||||||
cache.end_current_session()
|
cache_codec.end_current_session()
|
||||||
# switch back to B
|
# switch back to B
|
||||||
session_id = cache.start_session(session_id_b)
|
session_id = cache_codec.start_session(session_id_b)
|
||||||
self.assertEqual(session_id, session_id_b)
|
self.assertEqual(session_id, session_id_b)
|
||||||
|
|
||||||
def test_session_queue(self):
|
def test_session_queue(self):
|
||||||
session_id = cache.start_session()
|
session_id = cache_codec.start_session()
|
||||||
self.assertEqual(cache.start_session(session_id), session_id)
|
self.assertEqual(cache_codec.start_session(session_id), session_id)
|
||||||
cache.set(KEY, b"A")
|
context.cache_set(KEY, b"A")
|
||||||
for i in range(cache._MAX_SESSIONS_COUNT):
|
for i in range(cache_codec._MAX_SESSIONS_COUNT):
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
self.assertNotEqual(cache.start_session(session_id), session_id)
|
self.assertNotEqual(cache_codec.start_session(session_id), session_id)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
|
|
||||||
def test_get_set(self):
|
def test_get_set(self):
|
||||||
session_id1 = cache.start_session()
|
session_id1 = cache_codec.start_session()
|
||||||
cache.set(KEY, b"hello")
|
context.cache_set(KEY, b"hello")
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
|
|
||||||
session_id2 = cache.start_session()
|
session_id2 = cache_codec.start_session()
|
||||||
cache.set(KEY, b"world")
|
context.cache_set(KEY, b"world")
|
||||||
self.assertEqual(cache.get(KEY), b"world")
|
self.assertEqual(context.cache_get(KEY), b"world")
|
||||||
|
|
||||||
cache.start_session(session_id2)
|
cache_codec.start_session(session_id2)
|
||||||
self.assertEqual(cache.get(KEY), b"world")
|
self.assertEqual(context.cache_get(KEY), b"world")
|
||||||
cache.start_session(session_id1)
|
cache_codec.start_session(session_id1)
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(cache_common.InvalidSessionError):
|
||||||
cache.get(KEY)
|
context.cache_get(KEY)
|
||||||
|
|
||||||
def test_get_set_int(self):
|
def test_get_set_int(self):
|
||||||
session_id1 = cache.start_session()
|
session_id1 = cache_codec.start_session()
|
||||||
cache.set_int(KEY, 1234)
|
context.cache_set_int(KEY, 1234)
|
||||||
self.assertEqual(cache.get_int(KEY), 1234)
|
self.assertEqual(context.cache_get_int(KEY), 1234)
|
||||||
|
|
||||||
session_id2 = cache.start_session()
|
session_id2 = cache_codec.start_session()
|
||||||
cache.set_int(KEY, 5678)
|
context.cache_set_int(KEY, 5678)
|
||||||
self.assertEqual(cache.get_int(KEY), 5678)
|
self.assertEqual(context.cache_get_int(KEY), 5678)
|
||||||
|
|
||||||
cache.start_session(session_id2)
|
cache_codec.start_session(session_id2)
|
||||||
self.assertEqual(cache.get_int(KEY), 5678)
|
self.assertEqual(context.cache_get_int(KEY), 5678)
|
||||||
cache.start_session(session_id1)
|
cache_codec.start_session(session_id1)
|
||||||
self.assertEqual(cache.get_int(KEY), 1234)
|
self.assertEqual(context.cache_get_int(KEY), 1234)
|
||||||
|
|
||||||
cache.clear_all()
|
cache.clear_all()
|
||||||
with self.assertRaises(cache.InvalidSessionError):
|
with self.assertRaises(cache_common.InvalidSessionError):
|
||||||
cache.get_int(KEY)
|
context.cache_get_int(KEY)
|
||||||
|
|
||||||
def test_delete(self):
|
def test_delete(self):
|
||||||
session_id1 = cache.start_session()
|
session_id1 = cache_codec.start_session()
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
cache.set(KEY, b"hello")
|
context.cache_set(KEY, b"hello")
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
cache.delete(KEY)
|
context.cache_delete(KEY)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
|
|
||||||
cache.set(KEY, b"hello")
|
context.cache_set(KEY, b"hello")
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
cache.set(KEY, b"hello")
|
context.cache_set(KEY, b"hello")
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
cache.delete(KEY)
|
context.cache_delete(KEY)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
|
|
||||||
cache.start_session(session_id1)
|
cache_codec.start_session(session_id1)
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
|
|
||||||
def test_decorators(self):
|
def test_decorators(self):
|
||||||
run_count = 0
|
run_count = 0
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
|
|
||||||
@cache.stored(KEY)
|
@stored(KEY)
|
||||||
def func():
|
def func():
|
||||||
nonlocal run_count
|
nonlocal run_count
|
||||||
run_count += 1
|
run_count += 1
|
||||||
return b"foo"
|
return b"foo"
|
||||||
|
|
||||||
# cache is empty
|
# cache is empty
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
self.assertEqual(run_count, 0)
|
self.assertEqual(run_count, 0)
|
||||||
self.assertEqual(func(), b"foo")
|
self.assertEqual(func(), b"foo")
|
||||||
# function was run
|
# function was run
|
||||||
self.assertEqual(run_count, 1)
|
self.assertEqual(run_count, 1)
|
||||||
self.assertEqual(cache.get(KEY), b"foo")
|
self.assertEqual(context.cache_get(KEY), b"foo")
|
||||||
# function does not run again but returns cached value
|
# function does not run again but returns cached value
|
||||||
self.assertEqual(func(), b"foo")
|
self.assertEqual(func(), b"foo")
|
||||||
self.assertEqual(run_count, 1)
|
self.assertEqual(run_count, 1)
|
||||||
|
|
||||||
@cache.stored_async(KEY)
|
@stored_async(KEY)
|
||||||
async def async_func():
|
async def async_func():
|
||||||
nonlocal run_count
|
nonlocal run_count
|
||||||
run_count += 1
|
run_count += 1
|
||||||
@ -154,7 +162,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(await_result(async_func()), b"foo")
|
self.assertEqual(await_result(async_func()), b"foo")
|
||||||
self.assertEqual(run_count, 1)
|
self.assertEqual(run_count, 1)
|
||||||
|
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
self.assertEqual(await_result(async_func()), b"bar")
|
self.assertEqual(await_result(async_func()), b"bar")
|
||||||
self.assertEqual(run_count, 2)
|
self.assertEqual(run_count, 2)
|
||||||
# awaitable is also run only once
|
# awaitable is also run only once
|
||||||
@ -162,16 +170,16 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(run_count, 2)
|
self.assertEqual(run_count, 2)
|
||||||
|
|
||||||
def test_empty_value(self):
|
def test_empty_value(self):
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
|
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
cache.set(KEY, b"")
|
context.cache_set(KEY, b"")
|
||||||
self.assertEqual(cache.get(KEY), b"")
|
self.assertEqual(context.cache_get(KEY), b"")
|
||||||
|
|
||||||
cache.delete(KEY)
|
context.cache_delete(KEY)
|
||||||
run_count = 0
|
run_count = 0
|
||||||
|
|
||||||
@cache.stored(KEY)
|
@stored(KEY)
|
||||||
def func():
|
def func():
|
||||||
nonlocal run_count
|
nonlocal run_count
|
||||||
run_count += 1
|
run_count += 1
|
||||||
@ -191,7 +199,7 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
return await_result(handle_Initialize(msg))
|
return await_result(handle_Initialize(msg))
|
||||||
|
|
||||||
# calling Initialize without an ID allocates a new one
|
# calling Initialize without an ID allocates a new one
|
||||||
session_id = cache.start_session()
|
session_id = cache_codec.start_session()
|
||||||
features = call_Initialize()
|
features = call_Initialize()
|
||||||
self.assertNotEqual(session_id, features.session_id)
|
self.assertNotEqual(session_id, features.session_id)
|
||||||
|
|
||||||
@ -200,31 +208,31 @@ class TestStorageCache(unittest.TestCase):
|
|||||||
self.assertEqual(session_id, features.session_id)
|
self.assertEqual(session_id, features.session_id)
|
||||||
|
|
||||||
# store "hello"
|
# store "hello"
|
||||||
cache.set(KEY, b"hello")
|
context.cache_set(KEY, b"hello")
|
||||||
# check that it is cleared
|
# check that it is cleared
|
||||||
features = call_Initialize()
|
features = call_Initialize()
|
||||||
session_id = features.session_id
|
session_id = features.session_id
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
# store "hello" again
|
# store "hello" again
|
||||||
cache.set(KEY, b"hello")
|
context.cache_set(KEY, b"hello")
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
|
|
||||||
# supplying a different session ID starts a new cache
|
# supplying a different session ID starts a new cache
|
||||||
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH)
|
call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH)
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
|
|
||||||
# but resuming a session loads the previous one
|
# but resuming a session loads the previous one
|
||||||
call_Initialize(session_id=session_id)
|
call_Initialize(session_id=session_id)
|
||||||
self.assertEqual(cache.get(KEY), b"hello")
|
self.assertEqual(context.cache_get(KEY), b"hello")
|
||||||
|
|
||||||
def test_EndSession(self):
|
def test_EndSession(self):
|
||||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
|
||||||
cache.start_session()
|
cache_codec.start_session()
|
||||||
self.assertTrue(is_session_started())
|
self.assertTrue(is_session_started())
|
||||||
self.assertIsNone(cache.get(KEY))
|
self.assertIsNone(context.cache_get(KEY))
|
||||||
await_result(handle_EndSession(EndSession()))
|
await_result(handle_EndSession(EndSession()))
|
||||||
self.assertFalse(is_session_started())
|
self.assertFalse(is_session_started())
|
||||||
self.assertRaises(cache.InvalidSessionError, cache.get, KEY)
|
self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -5,7 +5,7 @@ import ustruct
|
|||||||
from trezor import io
|
from trezor import io
|
||||||
from trezor.loop import wait
|
from trezor.loop import wait
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks
|
||||||
from trezor.wire import codec_v1
|
from trezor.wire.codec import codec_v1
|
||||||
|
|
||||||
|
|
||||||
class MockHID:
|
class MockHID:
|
||||||
|
Loading…
Reference in New Issue
Block a user