From a6b4d735f9f2cf4d4bc05136ed03ed4a189e089a Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 15 Nov 2024 17:31:36 +0100 Subject: [PATCH] feat(core): implement thp context and cache --- core/src/apps/base.py | 53 +++++---- core/src/apps/cardano/seed.py | 20 ++-- core/src/apps/common/backup.py | 27 +++-- core/src/apps/common/seed.py | 119 +++++++++++++------- core/src/storage/cache.py | 19 +++- core/src/trezor/wire/__init__.py | 128 ++++++++++++++-------- core/src/trezor/wire/context.py | 14 ++- core/tests/test_apps.bitcoin.keychain.py | 54 ++++++--- core/tests/test_apps.common.keychain.py | 28 +++-- core/tests/test_apps.ethereum.keychain.py | 35 ++++-- 10 files changed, 338 insertions(+), 159 deletions(-) diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 43bcd5f9e0..6a969f0068 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -204,33 +204,39 @@ def get_features() -> Features: return f -async def handle_Initialize(msg: Initialize) -> Features: - import storage.cache_codec as cache_codec +if not utils.USE_THP: - session_id = cache_codec.start_session(msg.session_id) + async def handle_Initialize( + msg: Initialize, + ) -> Features: + import storage.cache_codec as cache_codec - if not utils.BITCOIN_ONLY: - from storage.cache_common import APP_COMMON_DERIVE_CARDANO + session_id = cache_codec.start_session(msg.session_id) - derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) - have_seed = context.cache_is_set(APP_COMMON_SEED) - if ( - have_seed - and msg.derive_cardano is not None - and msg.derive_cardano != bool(derive_cardano) - ): - # seed is already derived, and host wants to change derive_cardano setting - # => create a new session - cache_codec.end_current_session() - session_id = cache_codec.start_session() - have_seed = False + if not utils.BITCOIN_ONLY: + from storage.cache_common import APP_COMMON_DERIVE_CARDANO - if not have_seed: - context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) + derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(APP_COMMON_SEED) + if ( + have_seed + and msg.derive_cardano is not None + and msg.derive_cardano != bool(derive_cardano) + ): + # seed is already derived, and host wants to change derive_cardano setting + # => create a new session + cache_codec.end_current_session() + session_id = cache_codec.start_session() + have_seed = False - features = get_features() - features.session_id = session_id - return features + if not have_seed: + context.cache_set_bool( + APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano) + ) + + features = get_features() + features.session_id = session_id + return features async def handle_GetFeatures(msg: GetFeatures) -> Features: @@ -464,8 +470,9 @@ def boot() -> None: MT = MessageType # local_cache_global # Register workflow handlers + if not utils.USE_THP: + workflow_handlers.register(MT.Initialize, handle_Initialize) for msg_type, handler in [ - (MT.Initialize, handle_Initialize), (MT.GetFeatures, handle_GetFeatures), (MT.Cancel, handle_Cancel), (MT.LockDevice, handle_LockDevice), diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 6e5309a968..28d93af3d7 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -8,7 +8,6 @@ from storage.cache_common import ( ) from trezor import wire from trezor.crypto import cardano -from trezor.wire import context from apps.common import mnemonic from apps.common.seed import get_seed @@ -21,6 +20,7 @@ if TYPE_CHECKING: from trezor import messages from trezor.crypto import bip32 from trezor.enums import CardanoDerivationType + from trezor.wire.protocol_common import Context from apps.common.keychain import Handler, MsgOut from apps.common.paths import Bip32Path @@ -116,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool: return path[: len(MINTING_ROOT)] == MINTING_ROOT -def derive_and_store_secrets(passphrase: str) -> None: +def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: assert device.is_initialized() - assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) + assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -138,18 +138,19 @@ def derive_and_store_secrets(passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret) - context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret) + ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: from trezor.enums import CardanoDerivationType + from trezor.wire import context if not device.is_initialized(): raise wire.NotInitialized("Device is not initialized") if derivation_type == CardanoDerivationType.LEDGER: - seed = await get_seed() + seed = get_seed() return Keychain(cardano.from_seed_ledger(seed)) if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO): @@ -163,6 +164,11 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai # _get_secret secret = context.cache_get(cache_entry) assert secret is not None + # TODO solve for non THP + # if secret is None: + # await derive_and_store_roots_legacy() + # secret = context.cache_get(cache_entry) + # assert secret is not None root = cardano.from_secret(secret) return Keychain(root) @@ -173,7 +179,7 @@ async def _get_keychain(derivation_type: CardanoDerivationType) -> Keychain: return await _get_keychain_bip39(derivation_type) else: # derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0023.md - seed = await get_seed() + seed = get_seed() return Keychain(cardano.from_seed_slip23(seed)) diff --git a/core/src/apps/common/backup.py b/core/src/apps/common/backup.py index fc56f42f9b..8037aba698 100644 --- a/core/src/apps/common/backup.py +++ b/core/src/apps/common/backup.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED -from trezor import wire +from trezor import utils, wire from trezor.enums import MessageType from trezor.wire import context from trezor.wire.message_handler import filters, remove_filter @@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None: remove_filter(_repeated_backup_filter) -_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( - MessageType.Initialize, - MessageType.GetFeatures, - MessageType.EndSession, - MessageType.BackupDevice, - MessageType.WipeDevice, - MessageType.Cancel, -) +if utils.USE_THP: + _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( + MessageType.GetFeatures, + MessageType.EndSession, + MessageType.BackupDevice, + MessageType.WipeDevice, + MessageType.Cancel, + ) +else: + _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( + MessageType.Initialize, + MessageType.GetFeatures, + MessageType.EndSession, + MessageType.BackupDevice, + MessageType.WipeDevice, + MessageType.Cancel, + ) def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index ee1bf681c6..a060e6e17d 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -5,14 +5,18 @@ from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPH from trezor import utils from trezor.crypto import hmac from trezor.wire import context +from trezor.wire.context import get_context +from trezor.wire.errors import DataError from apps.common import cache from . import mnemonic -from .passphrase import get as get_passphrase +from .passphrase import get_passphrase as get_passphrase if TYPE_CHECKING: from trezor.crypto import bip32 + from trezor.messages import ThpCreateNewSession + from trezor.wire.protocol_common import Context from .paths import Bip32Path, Slip21Path @@ -22,6 +26,10 @@ if not utils.BITCOIN_ONLY: APP_COMMON_DERIVE_CARDANO, ) +if not utils.USE_THP: + from .passphrase import get as get_passphrase_legacy + + class Slip21Node: """ This class implements the SLIP-0021 hierarchical derivation of symmetric keys, see @@ -53,51 +61,88 @@ class Slip21Node: return Slip21Node(data=self.data) -if not utils.BITCOIN_ONLY: - # === Cardano variant === - # We want to derive both the normal seed and the Cardano seed together, AND - # expose a method for Cardano to do the same +def get_seed() -> bytes: + common_seed = context.cache_get(APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + +if utils.BITCOIN_ONLY: + # === Bitcoin_only variant === + # We want to derive the normal seed ONLY + + async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None: + + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") + + if ctx.cache.is_set(APP_COMMON_SEED): + raise Exception("Seed is already set!") - async def derive_and_store_roots() -> None: from trezor import wire if not storage_device.is_initialized(): raise wire.NotInitialized("Device is not initialized") - need_seed = not context.cache_is_set(APP_COMMON_SEED) - need_cardano_secret = context.cache_get_bool( - APP_COMMON_DERIVE_CARDANO - ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET) - - if not need_seed and not need_cardano_secret: - return - - passphrase = await get_passphrase() - - if need_seed: - common_seed = mnemonic.get_seed(passphrase) - context.cache_set(APP_COMMON_SEED, common_seed) - - if need_cardano_secret: - from apps.cardano.seed import derive_and_store_secrets - - derive_and_store_secrets(passphrase) - - @cache.stored_async(APP_COMMON_SEED) - async def get_seed() -> bytes: - await derive_and_store_roots() - common_seed = context.cache_get(APP_COMMON_SEED) - assert common_seed is not None - return common_seed + passphrase = await get_passphrase(msg) + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) else: - # === Bitcoin-only variant === - # We use the simple version of `get_seed` that never needs to derive anything else. + # === Cardano variant === + # We want to derive both the normal seed and the Cardano seed together, AND + # expose a method for Cardano to do the same - @cache.stored_async(APP_COMMON_SEED) - async def get_seed() -> bytes: - passphrase = await get_passphrase() - return mnemonic.get_seed(passphrase) + async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None: + + if msg.passphrase is not None and msg.on_device: + raise DataError("Passphrase provided when it shouldn't be!") + + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET): + raise Exception("Cardano icarus secret is already set!") + + passphrase = await get_passphrase(msg) + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if msg.derive_cardano: + from apps.cardano.seed import derive_and_store_secrets + + ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True) + derive_and_store_secrets(ctx, passphrase) + + if not utils.USE_THP: + + async def derive_and_store_roots_legacy() -> None: + from trezor import wire + + if not storage_device.is_initialized(): + raise wire.NotInitialized("Device is not initialized") + + ctx = get_context() + need_seed = not ctx.cache.is_set(APP_COMMON_SEED) + need_cardano_secret = ctx.cache.get_bool( + APP_COMMON_DERIVE_CARDANO + ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET) + + if not need_seed and not need_cardano_secret: + return + + passphrase = await get_passphrase_legacy() + + if need_seed: + common_seed = mnemonic.get_seed(passphrase) + ctx.cache.set(APP_COMMON_SEED, common_seed) + + if need_cardano_secret: + from apps.cardano.seed import derive_and_store_secrets + + derive_and_store_secrets(ctx, passphrase) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index 89f038f706..bf6534a972 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -3,7 +3,7 @@ import gc from typing import TYPE_CHECKING from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache -from storage import cache_codec +from trezor import utils if TYPE_CHECKING: from typing import Tuple @@ -20,7 +20,15 @@ if TYPE_CHECKING: _SESSIONLESS_CACHE = SessionlessCache() -_PROTOCOL_CACHE = cache_codec + +if utils.USE_THP: + from storage import cache_thp + + _PROTOCOL_CACHE = cache_thp +else: + from storage import cache_codec + + _PROTOCOL_CACHE = cache_codec _PROTOCOL_CACHE.initialize() _SESSIONLESS_CACHE.clear() @@ -32,7 +40,12 @@ def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None: global autolock_last_touch autolock_last_touch = None _SESSIONLESS_CACHE.clear() - _PROTOCOL_CACHE.clear_all() + + if utils.USE_THP and excluded is not None: + # If we want to keep THP connection alive, we do not clear communication keys + cache_thp.clear_all_except_one_session_keys(excluded) + else: + _PROTOCOL_CACHE.clear_all() def get_int_all_sessions(key: int) -> builtins.set[int]: diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 26d7309325..68b9e51577 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is: - Request / response. - Protobuf-encoded, see `protobuf.py`. -- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`. - Transferred over USB interface, or UDP in case of Unix emulation. This module: @@ -27,7 +27,13 @@ from typing import TYPE_CHECKING from trezor import log, loop, protobuf, utils from trezor.wire import message_handler, protocol_common -from trezor.wire.codec.codec_context import CodecContext + +if utils.USE_THP: + from trezor.wire.message_handler import WIRE_BUFFER_2 + from trezor.wire.thp import thp_main +else: + from trezor.wire.codec.codec_context import CodecContext + from trezor.wire.context import UnexpectedMessageException from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler @@ -52,59 +58,89 @@ def setup(iface: WireInterface) -> None: loop.schedule(handle_session(iface)) -async def handle_session(iface: WireInterface) -> None: - ctx = CodecContext(iface, WIRE_BUFFER) - next_msg: protocol_common.Message | None = None +if utils.USE_THP: - # Take a mark of modules that are imported at this point, so we can - # roll back and un-import any others. - modules = utils.unimport_begin() - while True: - try: - if next_msg is None: - # If the previous run did not keep an unprocessed message for us, - # wait for a new one coming from the wire. - try: - msg = await ctx.read_from_wire() - except protocol_common.WireError as exc: - if __debug__: - log.exception(__name__, exc) - await ctx.write(failure(exc)) - continue + async def handle_session(iface: WireInterface) -> None: - else: - # Process the message from previous run. - msg = next_msg - next_msg = None + thp_main.set_read_buffer(WIRE_BUFFER) + thp_main.set_write_buffer(WIRE_BUFFER_2) - do_not_restart = False + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() + + while True: try: - do_not_restart = await message_handler.handle_single_message( - ctx, msg, handler_finder=find_handler - ) - except UnexpectedMessageException as unexpected: - # The workflow was interrupted by an unexpected message. We need to - # process it as if it was a new message... - next_msg = unexpected.msg - # ...and we must not restart because that would lose the message. - do_not_restart = True - continue + await thp_main.thp_main_loop(iface) except Exception as exc: - # Log and ignore. The session handler can only exit explicitly in the - # following finally block. + # Log and try again. if __debug__: log.exception(__name__, exc) finally: # Unload modules imported by the workflow. Should not raise. + if __debug__: + log.debug(__name__, "utils.unimport_end(modules) and loop.clear()") utils.unimport_end(modules) + loop.clear() + return # pylint: disable=lost-exception - if not do_not_restart: - # Let the session be restarted from `main`. - loop.clear() - return # pylint: disable=lost-exception +else: - except Exception as exc: - # Log and try again. The session handler can only exit explicitly via - # loop.clear() above. - if __debug__: - log.exception(__name__, exc) + async def handle_session(iface: WireInterface) -> None: + ctx = CodecContext(iface, WIRE_BUFFER) + next_msg: protocol_common.Message | None = None + + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() + while True: + try: + if next_msg is None: + # If the previous run did not keep an unprocessed message for us, + # wait for a new one coming from the wire. + try: + msg = await ctx.read_from_wire() + except protocol_common.WireError as exc: + if __debug__: + log.exception(__name__, exc) + await ctx.write(failure(exc)) + continue + + else: + # Process the message from previous run. + msg = next_msg + next_msg = None + + do_not_restart = False + try: + do_not_restart = await message_handler.handle_single_message( + ctx, msg, handler_finder=find_handler + ) + except UnexpectedMessageException as unexpected: + # The workflow was interrupted by an unexpected message. We need to + # process it as if it was a new message... + next_msg = unexpected.msg + # ...and we must not restart because that would lose the message. + do_not_restart = True + continue + except Exception as exc: + # Log and ignore. The session handler can only exit explicitly in the + # following finally block. + if __debug__: + log.exception(__name__, exc) + finally: + # Unload modules imported by the workflow. Should not raise. + utils.unimport_end(modules) + + if not do_not_restart: + # Let the session be restarted from `main`. + if __debug__: + log.debug(__name__, "loop.clear()") + loop.clear() + return # pylint: disable=lost-exception + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. + if __debug__: + log.exception(__name__, exc) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 124d2ce770..352a3db7bc 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING from storage import cache from storage.cache_common import SESSIONLESS_FLAG -from trezor import loop, protobuf +from trezor import loop, protobuf, utils from .protocol_common import Context, Message @@ -137,6 +137,18 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: else: send_exc = None + +def try_get_ctx_ids() -> Tuple[bytes, bytes] | None: + ids = None + if utils.USE_THP: + from trezor.wire.thp.session_context import GenericSessionContext + + ctx = get_context() + if isinstance(ctx, GenericSessionContext): + ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big")) + return ids + + # ACCESS TO CACHE if TYPE_CHECKING: diff --git a/core/tests/test_apps.bitcoin.keychain.py b/core/tests/test_apps.bitcoin.keychain.py index e21f88c8c0..e23fd54b55 100644 --- a/core/tests/test_apps.bitcoin.keychain.py +++ b/core/tests/test_apps.bitcoin.keychain.py @@ -7,19 +7,36 @@ from trezor.wire import context from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin from trezor.wire.codec.codec_context import CodecContext -from storage import cache_codec + +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec class TestBitcoinKeychain(unittest.TestCase): + if utils.USE_THP: - def __init__(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) - super().__init__() + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bitcoin(self): coin = _get_coin_by_name("Bitcoin") @@ -96,19 +113,20 @@ class TestBitcoinKeychain(unittest.TestCase): @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") class TestAltcoinKeychains(unittest.TestCase): + if not utils.USE_THP: - def __init__(self): - # Context is needed to test decorators and handleInitialize - # It allows access to codec cache from different parts of the code - from trezor.wire import context + def __init__(self): + # Context is needed to test decorators and handleInitialize + # It allows access to codec cache from different parts of the code + from trezor.wire import context - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) - super().__init__() + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def test_bcash(self): coin = _get_coin_by_name("Bcash") diff --git a/core/tests/test_apps.common.keychain.py b/core/tests/test_apps.common.keychain.py index ea85b2eb8f..6ff532efba 100644 --- a/core/tests/test_apps.common.keychain.py +++ b/core/tests/test_apps.common.keychain.py @@ -2,7 +2,7 @@ from common import * # isort:skip from mock_storage import mock_storage from storage import cache, cache_common -from trezor import wire +from trezor import utils, wire from trezor.crypto import bip39 from trezor.enums import SafetyCheckLevel from trezor.wire import context @@ -11,17 +11,31 @@ from apps.common import safety_checks from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.paths import PATTERN_SEP5, PathSchema from trezor.wire.codec.codec_context import CodecContext -from storage import cache_codec + +if utils.USE_THP: + import thp_common +if not utils.USE_THP: + from storage import cache_codec class TestKeychain(unittest.TestCase): - def __init__(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) - super().__init__() + if utils.USE_THP: - def setUp(self): - cache_codec.start_session() + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + else: + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() def tearDown(self): cache.clear_all() diff --git a/core/tests/test_apps.ethereum.keychain.py b/core/tests/test_apps.ethereum.keychain.py index 404dc07641..bcc94597f0 100644 --- a/core/tests/test_apps.ethereum.keychain.py +++ b/core/tests/test_apps.ethereum.keychain.py @@ -10,7 +10,12 @@ from trezor.wire import context from apps.common.keychain import get_keychain from apps.common.paths import HARDENED from trezor.wire.codec.codec_context import CodecContext -from storage import cache_codec + +if utils.USE_THP: + import thp_common +else: + from storage import cache_codec + if not utils.BITCOIN_ONLY: from ethereum_common import encode_network, make_network @@ -74,14 +79,28 @@ class TestEthereumKeychain(unittest.TestCase): addr, ) - def __init__(self): - context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) - super().__init__() + if utils.USE_THP: - def setUp(self): - cache_codec.start_session() - seed = bip39.seed(" ".join(["all"] * 12), "") - cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) + def __init__(self): + if __debug__: + thp_common.suppres_debug_log() + thp_common.prepare_context() + super().__init__() + + def setUp(self): + seed = bip39.seed(" ".join(["all"] * 12), "") + context.cache_set(cache_common.APP_COMMON_SEED, seed) + + else: + + def __init__(self): + context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) + super().__init__() + + def setUp(self): + cache_codec.start_session() + seed = bip39.seed(" ".join(["all"] * 12), "") + cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) def from_address_n(self, address_n): slip44 = _slip44_from_address_n(address_n)