1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-02-22 12:32:02 +00:00

test(core): fix test for non-THP builds, remove unnecessary imports

[no changelog]
This commit is contained in:
M1nd3r 2024-07-31 11:20:44 +02:00
parent e10753a187
commit ea0d143a76
12 changed files with 274 additions and 211 deletions

View File

@ -1,4 +1,4 @@
from common import H_, await_result, unittest # isort:skip
from common import * # isort:skip
import storage.cache
from trezor import wire
@ -18,8 +18,20 @@ from apps.bitcoin.sign_tx.bitcoin import Bitcoin
from apps.bitcoin.sign_tx.tx_info import TxInfo
from apps.common import coins
if not utils.USE_THP:
import storage.cache_codec
class TestApprover(unittest.TestCase):
if not utils.USE_THP:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
self.coin = coins.by_name("Bitcoin")
self.fee_rate_percent = 0.3
@ -47,7 +59,8 @@ class TestApprover(unittest.TestCase):
coin_name=self.coin.coin_name,
script_type=InputScriptType.SPENDTAPROOT,
)
storage.cache.start_session()
if not utils.USE_THP:
storage.cache_codec.start_session()
def make_coinjoin_request(self, inputs):
return CoinJoinRequest(

View File

@ -1,4 +1,4 @@
from common import H_, unittest # isort:skip
from common import * # isort:skip
import storage.cache
from trezor.enums import InputScriptType
@ -9,8 +9,20 @@ from apps.common import coins
_ROUND_ID_LEN = 32
if not utils.USE_THP:
import storage.cache_codec
class TestAuthorization(unittest.TestCase):
if not utils.USE_THP:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
coin = coins.by_name("Bitcoin")
@ -26,7 +38,8 @@ class TestAuthorization(unittest.TestCase):
)
self.authorization = CoinJoinAuthorization(self.msg_auth)
storage.cache.start_session()
if not utils.USE_THP:
storage.cache_codec.start_session()
def test_ownership_proof_account_depth_mismatch(self):
# Account depth mismatch.

View File

@ -1,17 +1,30 @@
from common import * # isort:skip
from storage import cache
from storage import cache_common
from trezor import wire
from trezor.crypto import bip39
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
if not utils.USE_THP:
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase):
def setUp(self):
cache.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
if not utils.USE_THP:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bitcoin(self):
coin = _get_coin_by_name("Bitcoin")
@ -88,10 +101,20 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase):
def setUp(self):
cache.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
if not utils.USE_THP:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bcash(self):
coin = _get_coin_by_name("Bcash")

View File

@ -1,8 +1,8 @@
from common import * # isort:skip
from mock_storage import mock_storage
from storage import cache
from trezor import wire
from storage import cache, cache_common
from trezor import utils, wire
from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel
@ -10,10 +10,24 @@ from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema
if not utils.USE_THP:
from storage import cache_codec
class TestKeychain(unittest.TestCase):
def setUp(self):
cache.start_session()
if not utils.USE_THP:
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
def tearDown(self):
cache.clear_all()
@ -71,7 +85,7 @@ class TestKeychain(unittest.TestCase):
def test_get_keychain(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
schema = PathSchema.parse("m/44'/1'", 0)
keychain = await_result(get_keychain("secp256k1", [schema]))
@ -85,7 +99,7 @@ class TestKeychain(unittest.TestCase):
def test_with_slip44(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
slip44_id = 42
valid_path = [H_(44), H_(slip44_id), H_(0)]

View File

@ -2,13 +2,17 @@ from common import * # isort:skip
import unittest
from storage import cache
from trezor import utils, wire
from storage import cache_common
from trezor import wire
from trezor.crypto import bip39
from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED
if not utils.USE_THP:
from storage import cache_codec
if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network
from trezor.messages import (
@ -70,11 +74,20 @@ class TestEthereumKeychain(unittest.TestCase):
keychain.derive,
addr,
)
if not utils.USE_THP:
def setUp(self):
cache.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache.set(cache.APP_COMMON_SEED, seed)
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def from_address_n(self, address_n):
slip44 = _slip44_from_address_n(address_n)

View File

@ -3,8 +3,6 @@ from common import * # isort:skip # noqa: F403
from mock_storage import mock_storage
from storage import cache, cache_codec, cache_thp
from storage.cache_common import InvalidSessionError
from trezor import utils
from trezor.messages import Initialize
from trezor.messages import EndSession
@ -17,10 +15,11 @@ if utils.USE_THP:
else:
_PROTOCOL_CACHE = cache_codec
def is_session_started() -> bool:
return cache_codec.get_active_session() is not None
# Function moved from cache.py, as it was not used there
def is_session_started() -> bool:
return _PROTOCOL_CACHE.get_active_session() is not None
def get_active_session():
return cache_codec.get_active_session()
class TestStorageCache(
@ -29,217 +28,209 @@ class TestStorageCache(
def setUp(self):
cache.clear_all()
def test_start_session(self):
session_id_a = cache.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache.start_session()
self.assertNotEqual(session_id_a, session_id_b)
if not utils.USE_THP:
cache.clear_all()
with self.assertRaises(InvalidSessionError):
cache.set(KEY, "something")
with self.assertRaises(InvalidSessionError):
cache.get(KEY)
def __init__(self):
# Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code
from trezor.wire import context
def test_end_session(self):
session_id = cache.start_session()
self.assertTrue(is_session_started())
cache.set(KEY, b"A")
cache.end_current_session()
self.assertFalse(is_session_started())
self.assertRaises(InvalidSessionError, cache.get, KEY)
context.CURRENT_CONTEXT = context.CodecContext(None, bytearray(64))
super().__init__()
# ending an ended session should be a no-op
cache.end_current_session()
self.assertFalse(is_session_started())
def test_start_session(self):
session_id_a = cache_codec.start_session()
self.assertIsNotNone(session_id_a)
session_id_b = cache_codec.start_session()
self.assertNotEqual(session_id_a, session_id_b)
session_id_a = cache.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(cache.get(KEY))
cache.clear_all()
self.assertIsNone(get_active_session())
for session in cache_codec._SESSIONS:
self.assertEqual(session.session_id, b"")
self.assertEqual(session.last_usage, 0)
# create a new session
session_id_b = cache.start_session()
# switch back to original session
session_id = cache.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache.end_current_session()
# switch back to B
session_id = cache.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
def test_end_session(self):
session_id = cache_codec.start_session()
self.assertTrue(is_session_started())
get_active_session().set(KEY, b"A")
cache_codec.end_current_session()
self.assertFalse(is_session_started())
self.assertIsNone(get_active_session())
def test_session_queue(self):
session_id = cache.start_session()
self.assertEqual(cache.start_session(session_id), session_id)
cache.set(KEY, b"A")
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache.start_session()
self.assertNotEqual(cache.start_session(session_id), session_id)
self.assertIsNone(cache.get(KEY))
# ending an ended session should be a no-op
cache_codec.end_current_session()
self.assertFalse(is_session_started())
def test_get_set(self):
session_id1 = cache.start_session()
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
session_id_a = cache_codec.start_session(session_id)
# original session no longer exists
self.assertNotEqual(session_id_a, session_id)
# original session data no longer exists
self.assertIsNone(get_active_session().get(KEY))
session_id2 = cache.start_session()
cache.set(KEY, b"world")
self.assertEqual(cache.get(KEY), b"world")
# create a new session
session_id_b = cache_codec.start_session()
# switch back to original session
session_id = cache_codec.start_session(session_id_a)
self.assertEqual(session_id, session_id_a)
# end original session
cache_codec.end_current_session()
# switch back to B
session_id = cache_codec.start_session(session_id_b)
self.assertEqual(session_id, session_id_b)
cache.start_session(session_id2)
self.assertEqual(cache.get(KEY), b"world")
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
def test_session_queue(self):
session_id = cache_codec.start_session()
self.assertEqual(cache_codec.start_session(session_id), session_id)
get_active_session().set(KEY, b"A")
for i in range(_PROTOCOL_CACHE._MAX_SESSIONS_COUNT):
cache_codec.start_session()
self.assertNotEqual(cache_codec.start_session(session_id), session_id)
self.assertIsNone(get_active_session().get(KEY))
cache.clear_all()
with self.assertRaises(InvalidSessionError):
cache.get(KEY)
def test_get_set(self):
session_id1 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"hello")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
def test_get_set_int(self):
session_id1 = cache.start_session()
cache.set_int(KEY, 1234)
self.assertEqual(cache.get_int(KEY), 1234)
session_id2 = cache_codec.start_session()
cache_codec.get_active_session().set(KEY, b"world")
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
session_id2 = cache.start_session()
cache.set_int(KEY, 5678)
self.assertEqual(cache.get_int(KEY), 5678)
cache_codec.start_session(session_id2)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"world")
cache_codec.start_session(session_id1)
self.assertEqual(cache_codec.get_active_session().get(KEY), b"hello")
cache.start_session(session_id2)
self.assertEqual(cache.get_int(KEY), 5678)
cache.start_session(session_id1)
self.assertEqual(cache.get_int(KEY), 1234)
cache_codec.clear_all()
self.assertIsNone(cache_codec.get_active_session())
cache.clear_all()
with self.assertRaises(InvalidSessionError):
cache.get_int(KEY)
def test_get_set_int(self):
session_id1 = cache_codec.start_session()
get_active_session().set_int(KEY, 1234)
self.assertEqual(get_active_session().get_int(KEY), 1234)
def test_delete(self):
session_id1 = cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
session_id2 = cache_codec.start_session()
get_active_session().set_int(KEY, 5678)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache.set(KEY, b"hello")
cache.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
cache.delete(KEY)
self.assertIsNone(cache.get(KEY))
cache_codec.start_session(session_id2)
self.assertEqual(get_active_session().get_int(KEY), 5678)
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get_int(KEY), 1234)
cache.start_session(session_id1)
self.assertEqual(cache.get(KEY), b"hello")
cache_codec.clear_all()
self.assertIsNone(get_active_session())
def test_decorators(self):
run_count = 0
cache.start_session()
def test_delete(self):
session_id1 = cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b"foo"
get_active_session().set(KEY, b"hello")
cache_codec.start_session()
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
get_active_session().delete(KEY)
self.assertIsNone(get_active_session().get(KEY))
# cache is empty
self.assertIsNone(cache.get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(cache.get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
cache_codec.start_session(session_id1)
self.assertEqual(get_active_session().get(KEY), b"hello")
@cache.stored_async(KEY)
async def async_func():
nonlocal run_count
run_count += 1
return b"bar"
def test_decorators(self):
run_count = 0
cache_codec.start_session()
from apps.common.cache import stored
# cache is still full
self.assertEqual(await_result(async_func()), b"foo")
self.assertEqual(run_count, 1)
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b"foo"
cache.start_session()
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# awaitable is also run only once
self.assertEqual(await_result(async_func()), b"bar")
self.assertEqual(run_count, 2)
# cache is empty
self.assertIsNone(get_active_session().get(KEY))
self.assertEqual(run_count, 0)
self.assertEqual(func(), b"foo")
# function was run
self.assertEqual(run_count, 1)
self.assertEqual(get_active_session().get(KEY), b"foo")
# function does not run again but returns cached value
self.assertEqual(func(), b"foo")
self.assertEqual(run_count, 1)
def test_empty_value(self):
cache.start_session()
def test_empty_value(self):
cache_codec.start_session()
self.assertIsNone(cache.get(KEY))
cache.set(KEY, b"")
self.assertEqual(cache.get(KEY), b"")
self.assertIsNone(get_active_session().get(KEY))
get_active_session().set(KEY, b"")
self.assertEqual(get_active_session().get(KEY), b"")
cache.delete(KEY)
run_count = 0
get_active_session().delete(KEY)
run_count = 0
@cache.stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
from apps.common.cache import stored
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
@stored(KEY)
def func():
nonlocal run_count
run_count += 1
return b""
@mock_storage
def test_Initialize(self):
if utils.USE_THP: # INITIALIZE SHOULD NOT BE IN THP!!! TODO
return
self.assertEqual(func(), b"")
# function gets called once
self.assertEqual(run_count, 1)
self.assertEqual(func(), b"")
# function is not called for a second time
self.assertEqual(run_count, 1)
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
@mock_storage
def test_Initialize(self):
# calling Initialize without an ID allocates a new one
session_id = cache.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
def call_Initialize(**kwargs):
msg = Initialize(**kwargs)
return await_result(handle_Initialize(msg))
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# calling Initialize without an ID allocates a new one
session_id = cache_codec.start_session()
features = call_Initialize()
self.assertNotEqual(session_id, features.session_id)
# store "hello"
cache.set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(cache.get(KEY))
# store "hello" again
cache.set(KEY, b"hello")
self.assertEqual(cache.get(KEY), b"hello")
# calling Initialize with the current ID does not allocate a new one
features = call_Initialize(session_id=session_id)
self.assertEqual(session_id, features.session_id)
# supplying a different session ID starts a new cache
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertIsNone(cache.get(KEY))
# store "hello"
get_active_session().set(KEY, b"hello")
# check that it is cleared
features = call_Initialize()
session_id = features.session_id
self.assertIsNone(get_active_session().get(KEY))
# store "hello" again
get_active_session().set(KEY, b"hello")
self.assertEqual(get_active_session().get(KEY), b"hello")
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(cache.get(KEY), b"hello")
# supplying a different session ID starts a new session
call_Initialize(session_id=b"A" * _PROTOCOL_CACHE.SESSION_ID_LENGTH)
self.assertIsNone(get_active_session().get(KEY))
def test_EndSession(self):
# but resuming a session loads the previous one
call_Initialize(session_id=session_id)
self.assertEqual(get_active_session().get(KEY), b"hello")
self.assertRaises(InvalidSessionError, cache.get, KEY)
session_id = cache.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(cache.get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertRaises(InvalidSessionError, cache.get, KEY)
def test_EndSession(self):
self.assertIsNone(get_active_session())
cache_codec.start_session()
self.assertTrue(is_session_started())
self.assertIsNone(get_active_session().get(KEY))
await_result(handle_EndSession(EndSession()))
self.assertFalse(is_session_started())
self.assertIsNone(cache_codec.get_active_session())
if __name__ == "__main__":

View File

@ -1,5 +1,4 @@
from common import * # isort:skip
from trezor import utils
if utils.USE_THP:
from trezor.wire.thp import checksum

View File

@ -1,5 +1,5 @@
from common import * # isort:skip
from trezor import config, log, utils
from trezor import config, log
if utils.USE_THP:
from trezor.messages import ThpCredentialMetadata

View File

@ -2,7 +2,6 @@ from common import * # isort:skip
from trezorcrypto import aesgcm, curve25519
import storage
from trezor import utils
if utils.USE_THP:
from trezor.wire.thp import crypto

View File

@ -1,6 +1,6 @@
from common import * # isort:skip
from storage import cache_thp
from trezor import config, io, log, protobuf, utils
from trezor import config, io, log, protobuf
from trezor.crypto.curve import curve25519
from trezor.enums import MessageType
from trezor.loop import wait

View File

@ -1,5 +1,4 @@
from common import * # isort:skip
from trezor import utils
if utils.USE_THP:
from trezor.wire.thp import writer

View File

@ -1,10 +1,9 @@
from common import * # isort:skip
import ustruct
from typing import TYPE_CHECKING
from ubinascii import hexlify
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor import io, log, utils
from trezor import io, log
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire.protocol_common import Message