1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-11 17:08:15 +00:00

test(core): update tests to reflect cache refactor

[no changelog]
This commit is contained in:
M1nd3r 2024-11-15 21:46:22 +01:00 committed by Petr Sedláček
parent e77477cb46
commit a3c275f19a
7 changed files with 160 additions and 107 deletions

View File

@ -1,6 +1,6 @@
from common import H_, await_result, unittest # isort:skip from common import H_, await_result, unittest # isort:skip
import storage.cache import storage.cache_codec
from trezor import wire from trezor import wire
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.enums import InputScriptType, OutputScriptType from trezor.enums import InputScriptType, OutputScriptType
@ -11,6 +11,8 @@ from trezor.messages import (
TxInput, TxInput,
TxOutput, TxOutput,
) )
from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization from apps.bitcoin.authorization import FEE_RATE_DECIMALS, CoinJoinAuthorization
from apps.bitcoin.sign_tx.approvers import CoinJoinApprover from apps.bitcoin.sign_tx.approvers import CoinJoinApprover
@ -20,6 +22,11 @@ from apps.common import coins
class TestApprover(unittest.TestCase): class TestApprover(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def setUp(self):
self.coin = coins.by_name("Bitcoin") self.coin = coins.by_name("Bitcoin")
self.fee_rate_percent = 0.3 self.fee_rate_percent = 0.3
@ -47,7 +54,7 @@ class TestApprover(unittest.TestCase):
coin_name=self.coin.coin_name, coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDTAPROOT, script_type=InputScriptType.SPENDTAPROOT,
) )
storage.cache.start_session() storage.cache_codec.start_session()
def make_coinjoin_request(self, inputs): def make_coinjoin_request(self, inputs):
return CoinJoinRequest( return CoinJoinRequest(

View File

@ -1,8 +1,10 @@
from common import H_, unittest # isort:skip from common import H_, unittest # isort:skip
import storage.cache import storage.cache_codec
from trezor.enums import InputScriptType from trezor.enums import InputScriptType
from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx from trezor.messages import AuthorizeCoinJoin, GetOwnershipProof, SignTx
from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.bitcoin.authorization import CoinJoinAuthorization from apps.bitcoin.authorization import CoinJoinAuthorization
from apps.common import coins from apps.common import coins
@ -12,6 +14,10 @@ _ROUND_ID_LEN = 32
class TestAuthorization(unittest.TestCase): class TestAuthorization(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
coin = coins.by_name("Bitcoin") coin = coins.by_name("Bitcoin")
def setUp(self): def setUp(self):
@ -26,7 +32,7 @@ class TestAuthorization(unittest.TestCase):
) )
self.authorization = CoinJoinAuthorization(self.msg_auth) self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache.start_session() storage.cache_codec.start_session()
def test_ownership_proof_account_depth_mismatch(self): def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch. # Account depth mismatch.

View File

@ -1,17 +1,25 @@
from common import * # isort:skip from common import * # isort:skip
from storage import cache from storage import cache_common
from trezor import wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.wire import context
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase): class TestBitcoinKeychain(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def setUp(self):
cache.start_session() cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed) cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bitcoin(self): def test_bitcoin(self):
coin = _get_coin_by_name("Bitcoin") coin = _get_coin_by_name("Bitcoin")
@ -88,10 +96,19 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase): class TestAltcoinKeychains(unittest.TestCase):
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def setUp(self):
cache.start_session() cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed) cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bcash(self): def test_bcash(self):
coin = _get_coin_by_name("Bcash") coin = _get_coin_by_name("Bcash")

View File

@ -1,19 +1,27 @@
from common import * # isort:skip from common import * # isort:skip
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache from storage import cache, cache_common
from trezor import wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
from trezor.wire import context
from apps.common import safety_checks from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema from apps.common.paths import PATTERN_SEP5, PathSchema
from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
class TestKeychain(unittest.TestCase): class TestKeychain(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def setUp(self):
cache.start_session() cache_codec.start_session()
def tearDown(self): def tearDown(self):
cache.clear_all() cache.clear_all()
@ -71,7 +79,7 @@ class TestKeychain(unittest.TestCase):
def test_get_keychain(self): def test_get_keychain(self):
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed) context.cache_set(cache_common.APP_COMMON_SEED, seed)
schema = PathSchema.parse("m/44'/1'", 0) schema = PathSchema.parse("m/44'/1'", 0)
keychain = await_result(get_keychain("secp256k1", [schema])) keychain = await_result(get_keychain("secp256k1", [schema]))
@ -85,7 +93,7 @@ class TestKeychain(unittest.TestCase):
def test_with_slip44(self): def test_with_slip44(self):
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed) context.cache_set(cache_common.APP_COMMON_SEED, seed)
slip44_id = 42 slip44_id = 42
valid_path = [H_(44), H_(slip44_id), H_(0)] valid_path = [H_(44), H_(slip44_id), H_(0)]

View File

@ -2,12 +2,15 @@ from common import * # isort:skip
import unittest import unittest
from storage import cache from storage import cache_common
from trezor import utils, wire from trezor import wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.wire import context
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network from ethereum_common import encode_network, make_network
@ -71,10 +74,14 @@ class TestEthereumKeychain(unittest.TestCase):
addr, addr,
) )
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def setUp(self):
cache.start_session() cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed) cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def from_address_n(self, address_n): def from_address_n(self, address_n):
slip44 = _slip44_from_address_n(address_n) slip44 = _slip44_from_address_n(address_n)

View File

@ -1,150 +1,158 @@
from common import * # isort:skip from common import * # isort:skip
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache from storage import cache, cache_codec, cache_common
from trezor.messages import EndSession, Initialize from trezor.messages import EndSession, Initialize
from trezor.wire import context
from trezor.wire.codec.codec_context import CodecContext
from apps.base import handle_EndSession, handle_Initialize from apps.base import handle_EndSession, handle_Initialize
from apps.common.cache import stored, stored_async
KEY = 0 KEY = 0
# Function moved from cache.py, as it was not used there # Function moved from cache.py, as it was not used there
def is_session_started() -> bool: def is_session_started() -> bool:
return cache._active_session_idx is not None return cache_codec._active_session_idx is not None
class TestStorageCache(unittest.TestCase): class TestStorageCache(unittest.TestCase):
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def setUp(self):
cache.clear_all() cache.clear_all()
def test_start_session(self): def test_start_session(self):
session_id_a = cache.start_session() session_id_a = cache_codec.start_session()
self.assertIsNotNone(session_id_a) self.assertIsNotNone(session_id_a)
session_id_b = cache.start_session() session_id_b = cache_codec.start_session()
self.assertNotEqual(session_id_a, session_id_b) self.assertNotEqual(session_id_a, session_id_b)
cache.clear_all() cache.clear_all()
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(cache_common.InvalidSessionError):
cache.set(KEY, "something") context.cache_set(KEY, "something")
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(cache_common.InvalidSessionError):
cache.get(KEY) context.cache_get(KEY)
def test_end_session(self): def test_end_session(self):
session_id = cache.start_session() session_id = cache_codec.start_session()
self.assertTrue(is_session_started()) self.assertTrue(is_session_started())
cache.set(KEY, b"A") context.cache_set(KEY, b"A")
cache.end_current_session() cache_codec.end_current_session()
self.assertFalse(is_session_started()) self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY) self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
# ending an ended session should be a no-op # ending an ended session should be a no-op
cache.end_current_session() cache_codec.end_current_session()
self.assertFalse(is_session_started()) self.assertFalse(is_session_started())
session_id_a = cache.start_session(session_id) session_id_a = cache_codec.start_session(session_id)
# original session no longer exists # original session no longer exists
self.assertNotEqual(session_id_a, session_id) self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists # original session data no longer exists
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
# create a new session # create a new session
session_id_b = cache.start_session() session_id_b = cache_codec.start_session()
# switch back to original session # switch back to original session
session_id = cache.start_session(session_id_a) session_id = cache_codec.start_session(session_id_a)
self.assertEqual(session_id, session_id_a) self.assertEqual(session_id, session_id_a)
# end original session # end original session
cache.end_current_session() cache_codec.end_current_session()
# switch back to B # switch back to B
session_id = cache.start_session(session_id_b) session_id = cache_codec.start_session(session_id_b)
self.assertEqual(session_id, session_id_b) self.assertEqual(session_id, session_id_b)
def test_session_queue(self): def test_session_queue(self):
session_id = cache.start_session() session_id = cache_codec.start_session()
self.assertEqual(cache.start_session(session_id), session_id) self.assertEqual(cache_codec.start_session(session_id), session_id)
cache.set(KEY, b"A") context.cache_set(KEY, b"A")
for i in range(cache._MAX_SESSIONS_COUNT): for i in range(cache_codec._MAX_SESSIONS_COUNT):
cache.start_session() cache_codec.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id) self.assertNotEqual(cache_codec.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
def test_get_set(self): def test_get_set(self):
session_id1 = cache.start_session() session_id1 = cache_codec.start_session()
cache.set(KEY, b"hello") context.cache_set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
session_id2 = cache.start_session() session_id2 = cache_codec.start_session()
cache.set(KEY, b"world") context.cache_set(KEY, b"world")
self.assertEqual(cache.get(KEY), b"world") self.assertEqual(context.cache_get(KEY), b"world")
cache.start_session(session_id2) cache_codec.start_session(session_id2)
self.assertEqual(cache.get(KEY), b"world") self.assertEqual(context.cache_get(KEY), b"world")
cache.start_session(session_id1) cache_codec.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
cache.clear_all() cache.clear_all()
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(cache_common.InvalidSessionError):
cache.get(KEY) context.cache_get(KEY)
def test_get_set_int(self): def test_get_set_int(self):
session_id1 = cache.start_session() session_id1 = cache_codec.start_session()
cache.set_int(KEY, 1234) context.cache_set_int(KEY, 1234)
self.assertEqual(cache.get_int(KEY), 1234) self.assertEqual(context.cache_get_int(KEY), 1234)
session_id2 = cache.start_session() session_id2 = cache_codec.start_session()
cache.set_int(KEY, 5678) context.cache_set_int(KEY, 5678)
self.assertEqual(cache.get_int(KEY), 5678) self.assertEqual(context.cache_get_int(KEY), 5678)
cache.start_session(session_id2) cache_codec.start_session(session_id2)
self.assertEqual(cache.get_int(KEY), 5678) self.assertEqual(context.cache_get_int(KEY), 5678)
cache.start_session(session_id1) cache_codec.start_session(session_id1)
self.assertEqual(cache.get_int(KEY), 1234) self.assertEqual(context.cache_get_int(KEY), 1234)
cache.clear_all() cache.clear_all()
with self.assertRaises(cache.InvalidSessionError): with self.assertRaises(cache_common.InvalidSessionError):
cache.get_int(KEY) context.cache_get_int(KEY)
def test_delete(self): def test_delete(self):
session_id1 = cache.start_session() session_id1 = cache_codec.start_session()
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
cache.set(KEY, b"hello") context.cache_set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
cache.delete(KEY) context.cache_delete(KEY)
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
cache.set(KEY, b"hello") context.cache_set(KEY, b"hello")
cache.start_session() cache_codec.start_session()
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
cache.set(KEY, b"hello") context.cache_set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
cache.delete(KEY) context.cache_delete(KEY)
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
cache.start_session(session_id1) cache_codec.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
def test_decorators(self): def test_decorators(self):
run_count = 0 run_count = 0
cache.start_session() cache_codec.start_session()
@cache.stored(KEY) @stored(KEY)
def func(): def func():
nonlocal run_count nonlocal run_count
run_count += 1 run_count += 1
return b"foo" return b"foo"
# cache is empty # cache is empty
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
self.assertEqual(run_count, 0) self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo") self.assertEqual(func(), b"foo")
# function was run # function was run
self.assertEqual(run_count, 1) self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), b"foo") self.assertEqual(context.cache_get(KEY), b"foo")
# function does not run again but returns cached value # function does not run again but returns cached value
self.assertEqual(func(), b"foo") self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1) self.assertEqual(run_count, 1)
@cache.stored_async(KEY) @stored_async(KEY)
async def async_func(): async def async_func():
nonlocal run_count nonlocal run_count
run_count += 1 run_count += 1
@ -154,7 +162,7 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(await_result(async_func()), b"foo") self.assertEqual(await_result(async_func()), b"foo")
self.assertEqual(run_count, 1) self.assertEqual(run_count, 1)
cache.start_session() cache_codec.start_session()
self.assertEqual(await_result(async_func()), b"bar") self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2) self.assertEqual(run_count, 2)
# awaitable is also run only once # awaitable is also run only once
@ -162,16 +170,16 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(run_count, 2) self.assertEqual(run_count, 2)
def test_empty_value(self): def test_empty_value(self):
cache.start_session() cache_codec.start_session()
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
cache.set(KEY, b"") context.cache_set(KEY, b"")
self.assertEqual(cache.get(KEY), b"") self.assertEqual(context.cache_get(KEY), b"")
cache.delete(KEY) context.cache_delete(KEY)
run_count = 0 run_count = 0
@cache.stored(KEY) @stored(KEY)
def func(): def func():
nonlocal run_count nonlocal run_count
run_count += 1 run_count += 1
@ -191,7 +199,7 @@ class TestStorageCache(unittest.TestCase):
return await_result(handle_Initialize(msg)) return await_result(handle_Initialize(msg))
# calling Initialize without an ID allocates a new one # calling Initialize without an ID allocates a new one
session_id = cache.start_session() session_id = cache_codec.start_session()
features = call_Initialize() features = call_Initialize()
self.assertNotEqual(session_id, features.session_id) self.assertNotEqual(session_id, features.session_id)
@ -200,31 +208,31 @@ class TestStorageCache(unittest.TestCase):
self.assertEqual(session_id, features.session_id) self.assertEqual(session_id, features.session_id)
# store "hello" # store "hello"
cache.set(KEY, b"hello") context.cache_set(KEY, b"hello")
# check that it is cleared # check that it is cleared
features = call_Initialize() features = call_Initialize()
session_id = features.session_id session_id = features.session_id
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
# store "hello" again # store "hello" again
cache.set(KEY, b"hello") context.cache_set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
# supplying a different session ID starts a new cache # supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * cache._SESSION_ID_LENGTH) call_Initialize(session_id=b"A" * cache_codec.SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
# but resuming a session loads the previous one # but resuming a session loads the previous one
call_Initialize(session_id=session_id) call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), b"hello") self.assertEqual(context.cache_get(KEY), b"hello")
def test_EndSession(self): def test_EndSession(self):
self.assertRaises(cache.InvalidSessionError, cache.get, KEY) self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
cache.start_session() cache_codec.start_session()
self.assertTrue(is_session_started()) self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY)) self.assertIsNone(context.cache_get(KEY))
await_result(handle_EndSession(EndSession())) await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started()) self.assertFalse(is_session_started())
self.assertRaises(cache.InvalidSessionError, cache.get, KEY) self.assertRaises(cache_common.InvalidSessionError, context.cache_get, KEY)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -5,7 +5,7 @@ import ustruct
from trezor import io from trezor import io
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks
from trezor.wire import codec_v1 from trezor.wire.codec import codec_v1
class MockHID: class MockHID: