1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-07 14:00:57 +00:00

feat(core): implement thp context and cache

This commit is contained in:
M1nd3r 2024-11-15 17:31:36 +01:00
parent aaaeb3abca
commit a6b4d735f9
10 changed files with 338 additions and 159 deletions

View File

@ -204,33 +204,39 @@ def get_features() -> Features:
return f return f
async def handle_Initialize(msg: Initialize) -> Features: if not utils.USE_THP:
import storage.cache_codec as cache_codec
session_id = cache_codec.start_session(msg.session_id) async def handle_Initialize(
msg: Initialize,
) -> Features:
import storage.cache_codec as cache_codec
if not utils.BITCOIN_ONLY: session_id = cache_codec.start_session(msg.session_id)
from storage.cache_common import APP_COMMON_DERIVE_CARDANO
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) if not utils.BITCOIN_ONLY:
have_seed = context.cache_is_set(APP_COMMON_SEED) from storage.cache_common import APP_COMMON_DERIVE_CARDANO
if (
have_seed
and msg.derive_cardano is not None
and msg.derive_cardano != bool(derive_cardano)
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
if not have_seed: derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) have_seed = context.cache_is_set(APP_COMMON_SEED)
if (
have_seed
and msg.derive_cardano is not None
and msg.derive_cardano != bool(derive_cardano)
):
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
cache_codec.end_current_session()
session_id = cache_codec.start_session()
have_seed = False
features = get_features() if not have_seed:
features.session_id = session_id context.cache_set_bool(
return features APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
features = get_features()
features.session_id = session_id
return features
async def handle_GetFeatures(msg: GetFeatures) -> Features: async def handle_GetFeatures(msg: GetFeatures) -> Features:
@ -464,8 +470,9 @@ def boot() -> None:
MT = MessageType # local_cache_global MT = MessageType # local_cache_global
# Register workflow handlers # Register workflow handlers
if not utils.USE_THP:
workflow_handlers.register(MT.Initialize, handle_Initialize)
for msg_type, handler in [ for msg_type, handler in [
(MT.Initialize, handle_Initialize),
(MT.GetFeatures, handle_GetFeatures), (MT.GetFeatures, handle_GetFeatures),
(MT.Cancel, handle_Cancel), (MT.Cancel, handle_Cancel),
(MT.LockDevice, handle_LockDevice), (MT.LockDevice, handle_LockDevice),

View File

@ -8,7 +8,6 @@ from storage.cache_common import (
) )
from trezor import wire from trezor import wire
from trezor.crypto import cardano from trezor.crypto import cardano
from trezor.wire import context
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.seed import get_seed from apps.common.seed import get_seed
@ -21,6 +20,7 @@ if TYPE_CHECKING:
from trezor import messages from trezor import messages
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
from trezor.wire.protocol_common import Context
from apps.common.keychain import Handler, MsgOut from apps.common.keychain import Handler, MsgOut
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
@ -116,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(MINTING_ROOT)] == MINTING_ROOT return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None: def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
assert device.is_initialized() assert device.is_initialized()
assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39(): if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed # nothing to do for SLIP-39, where we can derive the root from the main seed
@ -138,18 +138,19 @@ def derive_and_store_secrets(passphrase: str) -> None:
else: else:
icarus_trezor_secret = icarus_secret icarus_trezor_secret = icarus_secret
context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret) ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
from trezor.wire import context
if not device.is_initialized(): if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if derivation_type == CardanoDerivationType.LEDGER: if derivation_type == CardanoDerivationType.LEDGER:
seed = await get_seed() seed = get_seed()
return Keychain(cardano.from_seed_ledger(seed)) return Keychain(cardano.from_seed_ledger(seed))
if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO): if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
@ -163,6 +164,11 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
# _get_secret # _get_secret
secret = context.cache_get(cache_entry) secret = context.cache_get(cache_entry)
assert secret is not None assert secret is not None
# TODO solve for non THP
# if secret is None:
# await derive_and_store_roots_legacy()
# secret = context.cache_get(cache_entry)
# assert secret is not None
root = cardano.from_secret(secret) root = cardano.from_secret(secret)
return Keychain(root) return Keychain(root)
@ -173,7 +179,7 @@ async def _get_keychain(derivation_type: CardanoDerivationType) -> Keychain:
return await _get_keychain_bip39(derivation_type) return await _get_keychain_bip39(derivation_type)
else: else:
# derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0023.md # derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0023.md
seed = await get_seed() seed = get_seed()
return Keychain(cardano.from_seed_slip23(seed)) return Keychain(cardano.from_seed_slip23(seed))

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
from trezor import wire from trezor import utils, wire
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.wire import context from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter from trezor.wire.message_handler import filters, remove_filter
@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None:
remove_filter(_repeated_backup_filter) remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( if utils.USE_THP:
MessageType.Initialize, _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,
MessageType.BackupDevice, MessageType.BackupDevice,
MessageType.WipeDevice, MessageType.WipeDevice,
MessageType.Cancel, MessageType.Cancel,
) )
else:
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
MessageType.Initialize,
MessageType.GetFeatures,
MessageType.EndSession,
MessageType.BackupDevice,
MessageType.WipeDevice,
MessageType.Cancel,
)
def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:

View File

@ -5,14 +5,18 @@ from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPH
from trezor import utils from trezor import utils
from trezor.crypto import hmac from trezor.crypto import hmac
from trezor.wire import context from trezor.wire import context
from trezor.wire.context import get_context
from trezor.wire.errors import DataError
from apps.common import cache from apps.common import cache
from . import mnemonic from . import mnemonic
from .passphrase import get as get_passphrase from .passphrase import get_passphrase as get_passphrase
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.messages import ThpCreateNewSession
from trezor.wire.protocol_common import Context
from .paths import Bip32Path, Slip21Path from .paths import Bip32Path, Slip21Path
@ -22,6 +26,10 @@ if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO, APP_COMMON_DERIVE_CARDANO,
) )
if not utils.USE_THP:
from .passphrase import get as get_passphrase_legacy
class Slip21Node: class Slip21Node:
""" """
This class implements the SLIP-0021 hierarchical derivation of symmetric keys, see This class implements the SLIP-0021 hierarchical derivation of symmetric keys, see
@ -53,51 +61,88 @@ class Slip21Node:
return Slip21Node(data=self.data) return Slip21Node(data=self.data)
if not utils.BITCOIN_ONLY: def get_seed() -> bytes:
# === Cardano variant === common_seed = context.cache_get(APP_COMMON_SEED)
# We want to derive both the normal seed and the Cardano seed together, AND assert common_seed is not None
# expose a method for Cardano to do the same return common_seed
if utils.BITCOIN_ONLY:
# === Bitcoin_only variant ===
# We want to derive the normal seed ONLY
async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
if ctx.cache.is_set(APP_COMMON_SEED):
raise Exception("Seed is already set!")
async def derive_and_store_roots() -> None:
from trezor import wire from trezor import wire
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
need_seed = not context.cache_is_set(APP_COMMON_SEED) passphrase = await get_passphrase(msg)
need_cardano_secret = context.cache_get_bool( common_seed = mnemonic.get_seed(passphrase)
APP_COMMON_DERIVE_CARDANO ctx.cache.set(APP_COMMON_SEED, common_seed)
) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
context.cache_set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None
return common_seed
else: else:
# === Bitcoin-only variant === # === Cardano variant ===
# We use the simple version of `get_seed` that never needs to derive anything else. # We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
@cache.stored_async(APP_COMMON_SEED) async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
async def get_seed() -> bytes:
passphrase = await get_passphrase() if msg.passphrase is not None and msg.on_device:
return mnemonic.get_seed(passphrase) raise DataError("Passphrase provided when it shouldn't be!")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET):
raise Exception("Cardano icarus secret is already set!")
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if msg.derive_cardano:
from apps.cardano.seed import derive_and_store_secrets
ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True)
derive_and_store_secrets(ctx, passphrase)
if not utils.USE_THP:
async def derive_and_store_roots_legacy() -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
ctx = get_context()
need_seed = not ctx.cache.is_set(APP_COMMON_SEED)
need_cardano_secret = ctx.cache.get_bool(
APP_COMMON_DERIVE_CARDANO
) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase_legacy()
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(ctx, passphrase)
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)

View File

@ -3,7 +3,7 @@ import gc
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
from storage import cache_codec from trezor import utils
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Tuple from typing import Tuple
@ -20,7 +20,15 @@ if TYPE_CHECKING:
_SESSIONLESS_CACHE = SessionlessCache() _SESSIONLESS_CACHE = SessionlessCache()
_PROTOCOL_CACHE = cache_codec
if utils.USE_THP:
from storage import cache_thp
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize() _PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear() _SESSIONLESS_CACHE.clear()
@ -32,7 +40,12 @@ def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None:
global autolock_last_touch global autolock_last_touch
autolock_last_touch = None autolock_last_touch = None
_SESSIONLESS_CACHE.clear() _SESSIONLESS_CACHE.clear()
_PROTOCOL_CACHE.clear_all()
if utils.USE_THP and excluded is not None:
# If we want to keep THP connection alive, we do not clear communication keys
cache_thp.clear_all_except_one_session_keys(excluded)
else:
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]: def get_int_all_sessions(key: int) -> builtins.set[int]:

View File

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response. - Request / response.
- Protobuf-encoded, see `protobuf.py`. - Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. - Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`.
- Transferred over USB interface, or UDP in case of Unix emulation. - Transferred over USB interface, or UDP in case of Unix emulation.
This module: This module:
@ -27,7 +27,13 @@ from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, utils from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common from trezor.wire import message_handler, protocol_common
from trezor.wire.codec.codec_context import CodecContext
if utils.USE_THP:
from trezor.wire.message_handler import WIRE_BUFFER_2
from trezor.wire.thp import thp_main
else:
from trezor.wire.codec.codec_context import CodecContext
from trezor.wire.context import UnexpectedMessageException from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler
@ -52,59 +58,89 @@ def setup(iface: WireInterface) -> None:
loop.schedule(handle_session(iface)) loop.schedule(handle_session(iface))
async def handle_session(iface: WireInterface) -> None: if utils.USE_THP:
ctx = CodecContext(iface, WIRE_BUFFER)
next_msg: protocol_common.Message | None = None
# Take a mark of modules that are imported at this point, so we can async def handle_session(iface: WireInterface) -> None:
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
if next_msg is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
continue
else: thp_main.set_read_buffer(WIRE_BUFFER)
# Process the message from previous run. thp_main.set_write_buffer(WIRE_BUFFER_2)
msg = next_msg
next_msg = None
do_not_restart = False # Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try: try:
do_not_restart = await message_handler.handle_single_message( await thp_main.thp_main_loop(iface)
ctx, msg, handler_finder=find_handler
)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
# ...and we must not restart because that would lose the message.
do_not_restart = True
continue
except Exception as exc: except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the # Log and try again.
# following finally block.
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
finally: finally:
# Unload modules imported by the workflow. Should not raise. # Unload modules imported by the workflow. Should not raise.
if __debug__:
log.debug(__name__, "utils.unimport_end(modules) and loop.clear()")
utils.unimport_end(modules) utils.unimport_end(modules)
loop.clear()
return # pylint: disable=lost-exception
if not do_not_restart: else:
# Let the session be restarted from `main`.
loop.clear()
return # pylint: disable=lost-exception
except Exception as exc: async def handle_session(iface: WireInterface) -> None:
# Log and try again. The session handler can only exit explicitly via ctx = CodecContext(iface, WIRE_BUFFER)
# loop.clear() above. next_msg: protocol_common.Message | None = None
if __debug__:
log.exception(__name__, exc) # Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
if next_msg is None:
# If the previous run did not keep an unprocessed message for us,
# wait for a new one coming from the wire.
try:
msg = await ctx.read_from_wire()
except protocol_common.WireError as exc:
if __debug__:
log.exception(__name__, exc)
await ctx.write(failure(exc))
continue
else:
# Process the message from previous run.
msg = next_msg
next_msg = None
do_not_restart = False
try:
do_not_restart = await message_handler.handle_single_message(
ctx, msg, handler_finder=find_handler
)
except UnexpectedMessageException as unexpected:
# The workflow was interrupted by an unexpected message. We need to
# process it as if it was a new message...
next_msg = unexpected.msg
# ...and we must not restart because that would lose the message.
do_not_restart = True
continue
except Exception as exc:
# Log and ignore. The session handler can only exit explicitly in the
# following finally block.
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
utils.unimport_end(modules)
if not do_not_restart:
# Let the session be restarted from `main`.
if __debug__:
log.debug(__name__, "loop.clear()")
loop.clear()
return # pylint: disable=lost-exception
except Exception as exc:
# Log and try again. The session handler can only exit explicitly via
# loop.clear() above.
if __debug__:
log.exception(__name__, exc)

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from storage import cache from storage import cache
from storage.cache_common import SESSIONLESS_FLAG from storage.cache_common import SESSIONLESS_FLAG
from trezor import loop, protobuf from trezor import loop, protobuf, utils
from .protocol_common import Context, Message from .protocol_common import Context, Message
@ -137,6 +137,18 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
else: else:
send_exc = None send_exc = None
def try_get_ctx_ids() -> Tuple[bytes, bytes] | None:
ids = None
if utils.USE_THP:
from trezor.wire.thp.session_context import GenericSessionContext
ctx = get_context()
if isinstance(ctx, GenericSessionContext):
ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big"))
return ids
# ACCESS TO CACHE # ACCESS TO CACHE
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -7,19 +7,36 @@ from trezor.wire import context
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
from trezor.wire.codec.codec_context import CodecContext from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase): class TestBitcoinKeychain(unittest.TestCase):
if utils.USE_THP:
def __init__(self): def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) if __debug__:
super().__init__() thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
def setUp(self): def setUp(self):
cache_codec.start_session() seed = bip39.seed(" ".join(["all"] * 12), "")
seed = bip39.seed(" ".join(["all"] * 12), "") context.cache_set(cache_common.APP_COMMON_SEED, seed)
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
else:
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bitcoin(self): def test_bitcoin(self):
coin = _get_coin_by_name("Bitcoin") coin = _get_coin_by_name("Bitcoin")
@ -96,19 +113,20 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase): class TestAltcoinKeychains(unittest.TestCase):
if not utils.USE_THP:
def __init__(self): def __init__(self):
# Context is needed to test decorators and handleInitialize # Context is needed to test decorators and handleInitialize
# It allows access to codec cache from different parts of the code # It allows access to codec cache from different parts of the code
from trezor.wire import context from trezor.wire import context
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__() super().__init__()
def setUp(self): def setUp(self):
cache_codec.start_session() cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "") seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def test_bcash(self): def test_bcash(self):
coin = _get_coin_by_name("Bcash") coin = _get_coin_by_name("Bcash")

View File

@ -2,7 +2,7 @@ from common import * # isort:skip
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache, cache_common from storage import cache, cache_common
from trezor import wire from trezor import utils, wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
from trezor.wire import context from trezor.wire import context
@ -11,17 +11,31 @@ from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema from apps.common.paths import PATTERN_SEP5, PathSchema
from trezor.wire.codec.codec_context import CodecContext from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if utils.USE_THP:
import thp_common
if not utils.USE_THP:
from storage import cache_codec
class TestKeychain(unittest.TestCase): class TestKeychain(unittest.TestCase):
def __init__(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def __init__(self):
cache_codec.start_session() if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
def tearDown(self): def tearDown(self):
cache.clear_all() cache.clear_all()

View File

@ -10,7 +10,12 @@ from trezor.wire import context
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
from trezor.wire.codec.codec_context import CodecContext from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network from ethereum_common import encode_network, make_network
@ -74,14 +79,28 @@ class TestEthereumKeychain(unittest.TestCase):
addr, addr,
) )
def __init__(self): if utils.USE_THP:
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self): def __init__(self):
cache_codec.start_session() if __debug__:
seed = bip39.seed(" ".join(["all"] * 12), "") thp_common.suppres_debug_log()
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed) thp_common.prepare_context()
super().__init__()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__()
def setUp(self):
cache_codec.start_session()
seed = bip39.seed(" ".join(["all"] * 12), "")
cache_codec.get_active_session().set(cache_common.APP_COMMON_SEED, seed)
def from_address_n(self, address_n): def from_address_n(self, address_n):
slip44 = _slip44_from_address_n(address_n) slip44 = _slip44_from_address_n(address_n)