1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 14:31:06 +00:00

feat(core): implement thp context and cache

This commit is contained in:
M1nd3r 2024-11-15 17:31:36 +01:00
parent aaaeb3abca
commit a6b4d735f9
10 changed files with 338 additions and 159 deletions

View File

@ -204,7 +204,11 @@ def get_features() -> Features:
return f return f
async def handle_Initialize(msg: Initialize) -> Features: if not utils.USE_THP:
async def handle_Initialize(
msg: Initialize,
) -> Features:
import storage.cache_codec as cache_codec import storage.cache_codec as cache_codec
session_id = cache_codec.start_session(msg.session_id) session_id = cache_codec.start_session(msg.session_id)
@ -226,7 +230,9 @@ async def handle_Initialize(msg: Initialize) -> Features:
have_seed = False have_seed = False
if not have_seed: if not have_seed:
context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)) context.cache_set_bool(
APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
features = get_features() features = get_features()
features.session_id = session_id features.session_id = session_id
@ -464,8 +470,9 @@ def boot() -> None:
MT = MessageType # local_cache_global MT = MessageType # local_cache_global
# Register workflow handlers # Register workflow handlers
if not utils.USE_THP:
workflow_handlers.register(MT.Initialize, handle_Initialize)
for msg_type, handler in [ for msg_type, handler in [
(MT.Initialize, handle_Initialize),
(MT.GetFeatures, handle_GetFeatures), (MT.GetFeatures, handle_GetFeatures),
(MT.Cancel, handle_Cancel), (MT.Cancel, handle_Cancel),
(MT.LockDevice, handle_LockDevice), (MT.LockDevice, handle_LockDevice),

View File

@ -8,7 +8,6 @@ from storage.cache_common import (
) )
from trezor import wire from trezor import wire
from trezor.crypto import cardano from trezor.crypto import cardano
from trezor.wire import context
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.seed import get_seed from apps.common.seed import get_seed
@ -21,6 +20,7 @@ if TYPE_CHECKING:
from trezor import messages from trezor import messages
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
from trezor.wire.protocol_common import Context
from apps.common.keychain import Handler, MsgOut from apps.common.keychain import Handler, MsgOut
from apps.common.paths import Bip32Path from apps.common.paths import Bip32Path
@ -116,9 +116,9 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(MINTING_ROOT)] == MINTING_ROOT return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None: def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
assert device.is_initialized() assert device.is_initialized()
assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO) assert ctx.cache.get_bool(APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39(): if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed # nothing to do for SLIP-39, where we can derive the root from the main seed
@ -138,18 +138,19 @@ def derive_and_store_secrets(passphrase: str) -> None:
else: else:
icarus_trezor_secret = icarus_secret icarus_trezor_secret = icarus_secret
context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret) ctx.cache.set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) ctx.cache.set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
from trezor.enums import CardanoDerivationType from trezor.enums import CardanoDerivationType
from trezor.wire import context
if not device.is_initialized(): if not device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
if derivation_type == CardanoDerivationType.LEDGER: if derivation_type == CardanoDerivationType.LEDGER:
seed = await get_seed() seed = get_seed()
return Keychain(cardano.from_seed_ledger(seed)) return Keychain(cardano.from_seed_ledger(seed))
if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO): if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
@ -163,6 +164,11 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
# _get_secret # _get_secret
secret = context.cache_get(cache_entry) secret = context.cache_get(cache_entry)
assert secret is not None assert secret is not None
# TODO solve for non THP
# if secret is None:
# await derive_and_store_roots_legacy()
# secret = context.cache_get(cache_entry)
# assert secret is not None
root = cardano.from_secret(secret) root = cardano.from_secret(secret)
return Keychain(root) return Keychain(root)
@ -173,7 +179,7 @@ async def _get_keychain(derivation_type: CardanoDerivationType) -> Keychain:
return await _get_keychain_bip39(derivation_type) return await _get_keychain_bip39(derivation_type)
else: else:
# derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0023.md # derive the root node via SLIP-0023 https://github.com/satoshilabs/slips/blob/master/slip-0023.md
seed = await get_seed() seed = get_seed()
return Keychain(cardano.from_seed_slip23(seed)) return Keychain(cardano.from_seed_slip23(seed))

View File

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
from trezor import wire from trezor import utils, wire
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.wire import context from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter from trezor.wire.message_handler import filters, remove_filter
@ -24,14 +24,23 @@ def deactivate_repeated_backup() -> None:
remove_filter(_repeated_backup_filter) remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( if utils.USE_THP:
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
MessageType.GetFeatures,
MessageType.EndSession,
MessageType.BackupDevice,
MessageType.WipeDevice,
MessageType.Cancel,
)
else:
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (
MessageType.Initialize, MessageType.Initialize,
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,
MessageType.BackupDevice, MessageType.BackupDevice,
MessageType.WipeDevice, MessageType.WipeDevice,
MessageType.Cancel, MessageType.Cancel,
) )
def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: def _repeated_backup_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:

View File

@ -5,14 +5,18 @@ from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPH
from trezor import utils from trezor import utils
from trezor.crypto import hmac from trezor.crypto import hmac
from trezor.wire import context from trezor.wire import context
from trezor.wire.context import get_context
from trezor.wire.errors import DataError
from apps.common import cache from apps.common import cache
from . import mnemonic from . import mnemonic
from .passphrase import get as get_passphrase from .passphrase import get_passphrase as get_passphrase
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.crypto import bip32 from trezor.crypto import bip32
from trezor.messages import ThpCreateNewSession
from trezor.wire.protocol_common import Context
from .paths import Bip32Path, Slip21Path from .paths import Bip32Path, Slip21Path
@ -22,6 +26,10 @@ if not utils.BITCOIN_ONLY:
APP_COMMON_DERIVE_CARDANO, APP_COMMON_DERIVE_CARDANO,
) )
if not utils.USE_THP:
from .passphrase import get as get_passphrase_legacy
class Slip21Node: class Slip21Node:
""" """
This class implements the SLIP-0021 hierarchical derivation of symmetric keys, see This class implements the SLIP-0021 hierarchical derivation of symmetric keys, see
@ -53,51 +61,88 @@ class Slip21Node:
return Slip21Node(data=self.data) return Slip21Node(data=self.data)
if not utils.BITCOIN_ONLY: def get_seed() -> bytes:
# === Cardano variant === common_seed = context.cache_get(APP_COMMON_SEED)
# We want to derive both the normal seed and the Cardano seed together, AND assert common_seed is not None
# expose a method for Cardano to do the same return common_seed
if utils.BITCOIN_ONLY:
# === Bitcoin_only variant ===
# We want to derive the normal seed ONLY
async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
if ctx.cache.is_set(APP_COMMON_SEED):
raise Exception("Seed is already set!")
async def derive_and_store_roots() -> None:
from trezor import wire from trezor import wire
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
need_seed = not context.cache_is_set(APP_COMMON_SEED) passphrase = await get_passphrase(msg)
need_cardano_secret = context.cache_get_bool( common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
else:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots(ctx: Context, msg: ThpCreateNewSession) -> None:
if msg.passphrase is not None and msg.on_device:
raise DataError("Passphrase provided when it shouldn't be!")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
if ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET):
raise Exception("Cardano icarus secret is already set!")
passphrase = await get_passphrase(msg)
common_seed = mnemonic.get_seed(passphrase)
ctx.cache.set(APP_COMMON_SEED, common_seed)
if msg.derive_cardano:
from apps.cardano.seed import derive_and_store_secrets
ctx.cache.set_bool(APP_COMMON_DERIVE_CARDANO, True)
derive_and_store_secrets(ctx, passphrase)
if not utils.USE_THP:
async def derive_and_store_roots_legacy() -> None:
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
ctx = get_context()
need_seed = not ctx.cache.is_set(APP_COMMON_SEED)
need_cardano_secret = ctx.cache.get_bool(
APP_COMMON_DERIVE_CARDANO APP_COMMON_DERIVE_CARDANO
) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET) ) and not ctx.cache.is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret: if not need_seed and not need_cardano_secret:
return return
passphrase = await get_passphrase() passphrase = await get_passphrase_legacy()
if need_seed: if need_seed:
common_seed = mnemonic.get_seed(passphrase) common_seed = mnemonic.get_seed(passphrase)
context.cache_set(APP_COMMON_SEED, common_seed) ctx.cache.set(APP_COMMON_SEED, common_seed)
if need_cardano_secret: if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase) derive_and_store_secrets(ctx, passphrase)
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None
return common_seed
else:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes:
passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase)
@cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)

View File

@ -3,7 +3,7 @@ import gc
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache from storage.cache_common import SESSIONLESS_FLAG, SessionlessCache
from storage import cache_codec from trezor import utils
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Tuple from typing import Tuple
@ -20,7 +20,15 @@ if TYPE_CHECKING:
_SESSIONLESS_CACHE = SessionlessCache() _SESSIONLESS_CACHE = SessionlessCache()
_PROTOCOL_CACHE = cache_codec
if utils.USE_THP:
from storage import cache_thp
_PROTOCOL_CACHE = cache_thp
else:
from storage import cache_codec
_PROTOCOL_CACHE = cache_codec
_PROTOCOL_CACHE.initialize() _PROTOCOL_CACHE.initialize()
_SESSIONLESS_CACHE.clear() _SESSIONLESS_CACHE.clear()
@ -32,6 +40,11 @@ def clear_all(excluded: Tuple[bytes, bytes] | None = None) -> None:
global autolock_last_touch global autolock_last_touch
autolock_last_touch = None autolock_last_touch = None
_SESSIONLESS_CACHE.clear() _SESSIONLESS_CACHE.clear()
if utils.USE_THP and excluded is not None:
# If we want to keep THP connection alive, we do not clear communication keys
cache_thp.clear_all_except_one_session_keys(excluded)
else:
_PROTOCOL_CACHE.clear_all() _PROTOCOL_CACHE.clear_all()

View File

@ -5,7 +5,7 @@ Handles on-the-wire communication with a host computer. The communication is:
- Request / response. - Request / response.
- Protobuf-encoded, see `protobuf.py`. - Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py`. - Wrapped in a simple envelope format, see `trezor/wire/codec/codec_v1.py` or `trezor/wire/thp/thp_main.py`.
- Transferred over USB interface, or UDP in case of Unix emulation. - Transferred over USB interface, or UDP in case of Unix emulation.
This module: This module:
@ -27,7 +27,13 @@ from typing import TYPE_CHECKING
from trezor import log, loop, protobuf, utils from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common from trezor.wire import message_handler, protocol_common
from trezor.wire.codec.codec_context import CodecContext
if utils.USE_THP:
from trezor.wire.message_handler import WIRE_BUFFER_2
from trezor.wire.thp import thp_main
else:
from trezor.wire.codec.codec_context import CodecContext
from trezor.wire.context import UnexpectedMessageException from trezor.wire.context import UnexpectedMessageException
from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler from trezor.wire.message_handler import WIRE_BUFFER, failure, find_handler
@ -52,7 +58,35 @@ def setup(iface: WireInterface) -> None:
loop.schedule(handle_session(iface)) loop.schedule(handle_session(iface))
async def handle_session(iface: WireInterface) -> None: if utils.USE_THP:
async def handle_session(iface: WireInterface) -> None:
thp_main.set_read_buffer(WIRE_BUFFER)
thp_main.set_write_buffer(WIRE_BUFFER_2)
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others.
modules = utils.unimport_begin()
while True:
try:
await thp_main.thp_main_loop(iface)
except Exception as exc:
# Log and try again.
if __debug__:
log.exception(__name__, exc)
finally:
# Unload modules imported by the workflow. Should not raise.
if __debug__:
log.debug(__name__, "utils.unimport_end(modules) and loop.clear()")
utils.unimport_end(modules)
loop.clear()
return # pylint: disable=lost-exception
else:
async def handle_session(iface: WireInterface) -> None:
ctx = CodecContext(iface, WIRE_BUFFER) ctx = CodecContext(iface, WIRE_BUFFER)
next_msg: protocol_common.Message | None = None next_msg: protocol_common.Message | None = None
@ -100,6 +134,8 @@ async def handle_session(iface: WireInterface) -> None:
if not do_not_restart: if not do_not_restart:
# Let the session be restarted from `main`. # Let the session be restarted from `main`.
if __debug__:
log.debug(__name__, "loop.clear()")
loop.clear() loop.clear()
return # pylint: disable=lost-exception return # pylint: disable=lost-exception

View File

@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
from storage import cache from storage import cache
from storage.cache_common import SESSIONLESS_FLAG from storage.cache_common import SESSIONLESS_FLAG
from trezor import loop, protobuf from trezor import loop, protobuf, utils
from .protocol_common import Context, Message from .protocol_common import Context, Message
@ -137,6 +137,18 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
else: else:
send_exc = None send_exc = None
def try_get_ctx_ids() -> Tuple[bytes, bytes] | None:
ids = None
if utils.USE_THP:
from trezor.wire.thp.session_context import GenericSessionContext
ctx = get_context()
if isinstance(ctx, GenericSessionContext):
ids = (ctx.channel_id, ctx.session_id.to_bytes(1, "big"))
return ids
# ACCESS TO CACHE # ACCESS TO CACHE
if TYPE_CHECKING: if TYPE_CHECKING:

View File

@ -7,10 +7,27 @@ from trezor.wire import context
from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin from apps.bitcoin.keychain import _get_coin_by_name, _get_keychain_for_coin
from trezor.wire.codec.codec_context import CodecContext from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
class TestBitcoinKeychain(unittest.TestCase): class TestBitcoinKeychain(unittest.TestCase):
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def __init__(self): def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
@ -96,6 +113,7 @@ class TestBitcoinKeychain(unittest.TestCase):
@unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin") @unittest.skipUnless(not utils.BITCOIN_ONLY, "altcoin")
class TestAltcoinKeychains(unittest.TestCase): class TestAltcoinKeychains(unittest.TestCase):
if not utils.USE_THP:
def __init__(self): def __init__(self):
# Context is needed to test decorators and handleInitialize # Context is needed to test decorators and handleInitialize

View File

@ -2,7 +2,7 @@ from common import * # isort:skip
from mock_storage import mock_storage from mock_storage import mock_storage
from storage import cache, cache_common from storage import cache, cache_common
from trezor import wire from trezor import utils, wire
from trezor.crypto import bip39 from trezor.crypto import bip39
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
from trezor.wire import context from trezor.wire import context
@ -11,11 +11,25 @@ from apps.common import safety_checks
from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain from apps.common.keychain import Keychain, LRUCache, get_keychain, with_slip44_keychain
from apps.common.paths import PATTERN_SEP5, PathSchema from apps.common.paths import PATTERN_SEP5, PathSchema
from trezor.wire.codec.codec_context import CodecContext from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if utils.USE_THP:
import thp_common
if not utils.USE_THP:
from storage import cache_codec
class TestKeychain(unittest.TestCase): class TestKeychain(unittest.TestCase):
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
else:
def __init__(self): def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__() super().__init__()

View File

@ -10,7 +10,12 @@ from trezor.wire import context
from apps.common.keychain import get_keychain from apps.common.keychain import get_keychain
from apps.common.paths import HARDENED from apps.common.paths import HARDENED
from trezor.wire.codec.codec_context import CodecContext from trezor.wire.codec.codec_context import CodecContext
from storage import cache_codec
if utils.USE_THP:
import thp_common
else:
from storage import cache_codec
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
from ethereum_common import encode_network, make_network from ethereum_common import encode_network, make_network
@ -74,6 +79,20 @@ class TestEthereumKeychain(unittest.TestCase):
addr, addr,
) )
if utils.USE_THP:
def __init__(self):
if __debug__:
thp_common.suppres_debug_log()
thp_common.prepare_context()
super().__init__()
def setUp(self):
seed = bip39.seed(" ".join(["all"] * 12), "")
context.cache_set(cache_common.APP_COMMON_SEED, seed)
else:
def __init__(self): def __init__(self):
context.CURRENT_CONTEXT = CodecContext(None, bytearray(64)) context.CURRENT_CONTEXT = CodecContext(None, bytearray(64))
super().__init__() super().__init__()