From 2f0a7ec740e11f19ecc3346815e5bd53d5fe8b26 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 3 May 2024 16:55:18 +0200 Subject: [PATCH] feat(core): implement cache handling of passphrase, refactor cache --- core/src/all_modules.py | 2 + core/src/apps/base.py | 17 +- .../apps/bitcoin/sign_tx/payment_request.py | 4 +- core/src/apps/cardano/seed.py | 9 +- core/src/apps/common/authorization.py | 13 +- core/src/apps/common/cache.py | 23 +++ core/src/apps/common/request_pin.py | 5 +- core/src/apps/common/safety_checks.py | 5 +- core/src/apps/common/seed.py | 58 +++--- core/src/apps/misc/cosi_commit.py | 11 +- core/src/apps/monero/live_refresh.py | 5 +- core/src/apps/thp/create_session.py | 4 + core/src/storage/cache.py | 172 +++++++++++------- core/src/storage/cache_thp.py | 145 +-------------- core/src/storage/device.py | 7 +- core/src/trezor/wire/context.py | 74 ++++++++ core/src/trezor/wire/protocol_common.py | 19 ++ core/src/trezor/wire/thp/__init__.py | 16 +- core/src/trezor/wire/thp/channel.py | 34 +++- core/src/trezor/wire/thp/handler_provider.py | 20 +- core/src/trezor/wire/thp/memory_manager.py | 18 +- .../wire/thp/received_message_handler.py | 12 +- core/src/trezor/wire/thp/session_context.py | 34 +++- core/src/trezor/wire/thp/thp_session.py | 45 +---- 24 files changed, 437 insertions(+), 315 deletions(-) create mode 100644 core/src/apps/common/cache.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 0182f8a06..9337f1fb1 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -325,6 +325,8 @@ apps.common.address_type import apps.common.address_type apps.common.authorization import apps.common.authorization +apps.common.cache +import apps.common.cache apps.common.cbor import apps.common.cbor apps.common.coininfo diff --git a/core/src/apps/base.py b/core/src/apps/base.py index 4f06b25d8..0d9b10a41 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -6,6 +6,7 @@ from trezor import TR, config, utils, wire, workflow from trezor.enums import HomescreenFormat, MessageType from trezor.messages import Success, UnlockPath from trezor.ui.layouts import confirm_action +from trezor.wire import context from . import workflow_handlers @@ -33,7 +34,7 @@ def busy_expiry_ms() -> int: Returns the time left until the busy state expires or 0 if the device is not in the busy state. """ - busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) + busy_deadline_ms = context.cache_get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) if busy_deadline_ms is None: return 0 @@ -175,7 +176,7 @@ def get_features() -> Features: return f -# handle_Initialize should not be used with THP to start a new session +@storage_cache.check_thp_is_not_used async def handle_Initialize(msg: Initialize) -> Features: if utils.USE_THP: raise ValueError("With THP enabled, a session id must be provided in args") @@ -183,8 +184,8 @@ async def handle_Initialize(msg: Initialize) -> Features: session_id = storage_cache.start_session(msg.session_id) if not utils.BITCOIN_ONLY: - derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO) - have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED) + derive_cardano = context.cache_get(storage_cache.APP_COMMON_DERIVE_CARDANO) + have_seed = context.cache_is_set(storage_cache.APP_COMMON_SEED) if ( have_seed @@ -194,11 +195,11 @@ async def handle_Initialize(msg: Initialize) -> Features: # seed is already derived, and host wants to change derive_cardano setting # => create a new session storage_cache.end_current_session() - session_id = storage_cache.start_session() # This should not be used in THP + session_id = storage_cache.start_session() have_seed = False if not have_seed: - storage_cache.set( + context.cache_set( storage_cache.APP_COMMON_DERIVE_CARDANO, b"\x01" if msg.derive_cardano else b"", ) @@ -229,7 +230,7 @@ async def handle_SetBusy(msg: SetBusy) -> Success: import utime deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) - storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) + context.cache_set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) else: storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) set_homescreen() @@ -338,7 +339,7 @@ def set_homescreen() -> None: set_default = workflow.set_default # local_cache_attribute - if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): + if context.cache_is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): from apps.homescreen import busyscreen set_default(busyscreen) diff --git a/core/src/apps/bitcoin/sign_tx/payment_request.py b/core/src/apps/bitcoin/sign_tx/payment_request.py index 8f2f7b88a..31cbe8232 100644 --- a/core/src/apps/bitcoin/sign_tx/payment_request.py +++ b/core/src/apps/bitcoin/sign_tx/payment_request.py @@ -1,7 +1,7 @@ from micropython import const from typing import TYPE_CHECKING -from trezor.wire import DataError +from trezor.wire import DataError, context from .. import writers @@ -42,7 +42,7 @@ class PaymentRequestVerifier: if msg.nonce: nonce = bytes(msg.nonce) - if cache.get(cache.APP_COMMON_NONCE) != nonce: + if context.cache_get(cache.APP_COMMON_NONCE) != nonce: raise DataError("Invalid nonce in payment request.") cache.delete(cache.APP_COMMON_NONCE) else: diff --git a/core/src/apps/cardano/seed.py b/core/src/apps/cardano/seed.py index 4a7f7b267..f6f4ccb16 100644 --- a/core/src/apps/cardano/seed.py +++ b/core/src/apps/cardano/seed.py @@ -15,6 +15,7 @@ if TYPE_CHECKING: from trezor import messages from trezor.crypto import bip32 from trezor.enums import CardanoDerivationType + from trezor.wire.protocol_common import Context from apps.common.keychain import Handler, MsgOut from apps.common.paths import Bip32Path @@ -110,9 +111,9 @@ def is_minting_path(path: Bip32Path) -> bool: return path[: len(MINTING_ROOT)] == MINTING_ROOT -def derive_and_store_secrets(passphrase: str) -> None: +def derive_and_store_secrets(ctx: Context, passphrase: str) -> None: assert device.is_initialized() - assert cache.get(cache.APP_COMMON_DERIVE_CARDANO) + assert ctx.cache_get(cache.APP_COMMON_DERIVE_CARDANO) if not mnemonic.is_bip39(): # nothing to do for SLIP-39, where we can derive the root from the main seed @@ -132,8 +133,8 @@ def derive_and_store_secrets(passphrase: str) -> None: else: icarus_trezor_secret = icarus_secret - cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) - cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) + ctx.cache_set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) + ctx.cache_set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: diff --git a/core/src/apps/common/authorization.py b/core/src/apps/common/authorization.py index 4d6e58e4d..e6a160f65 100644 --- a/core/src/apps/common/authorization.py +++ b/core/src/apps/common/authorization.py @@ -3,6 +3,7 @@ from typing import Iterable import storage.cache as storage_cache from trezor import protobuf from trezor.enums import MessageType +from trezor.wire import context WIRE_TYPES: dict[int, tuple[int, ...]] = { MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), @@ -17,7 +18,7 @@ APP_COMMON_AUTHORIZATION_TYPE = ( def is_set() -> bool: - return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None + return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None def set(auth_message: protobuf.MessageType) -> None: @@ -29,16 +30,16 @@ def set(auth_message: protobuf.MessageType) -> None: # (because only wire-level messages have wire_type, which we use as identifier) ensure(auth_message.MESSAGE_WIRE_TYPE is not None) assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too - storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) - storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer) + context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) + context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer) def get() -> protobuf.MessageType | None: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if not stored_auth_type: return None - buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"") + buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"") return protobuf.load_message_buffer(buffer, stored_auth_type) @@ -49,7 +50,7 @@ def is_set_any_session(auth_type: MessageType) -> bool: def get_wire_types() -> Iterable[int]: - stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) + stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE) if stored_auth_type is None: return () diff --git a/core/src/apps/common/cache.py b/core/src/apps/common/cache.py new file mode 100644 index 000000000..af3dd977f --- /dev/null +++ b/core/src/apps/common/cache.py @@ -0,0 +1,23 @@ +from typing import TYPE_CHECKING + +from trezor.wire import context + +if TYPE_CHECKING: + from typing import Callable, ParamSpec + + P = ParamSpec("P") + ByteFunc = Callable[P, bytes] + + +def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: + def decorator(func: ByteFunc[P]) -> ByteFunc[P]: + def wrapper(*args: P.args, **kwargs: P.kwargs): + value = context.cache_get(key) + if value is None: + value = func(*args, **kwargs) + context.cache_set(key, value) + return value + + return wrapper + + return decorator diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index fd5ad0d0a..56fe86423 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -4,6 +4,7 @@ from typing import Any, NoReturn import storage.cache as storage_cache from trezor import TR, config, utils, wire from trezor.ui.layouts import show_error_and_raise +from trezor.wire import context async def _request_sd_salt( @@ -77,7 +78,7 @@ async def request_pin_and_sd_salt( def _set_last_unlock_time() -> None: now = utime.ticks_ms() - storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) + context.cache_set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) _DEF_ARG_PIN_ENTER: str = TR.pin__enter @@ -91,7 +92,7 @@ async def verify_user_pin( ) -> None: # _get_last_unlock_time last_unlock = int.from_bytes( - storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" + context.cache_get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" ) if ( diff --git a/core/src/apps/common/safety_checks.py b/core/src/apps/common/safety_checks.py index dbdff4463..31a609239 100644 --- a/core/src/apps/common/safety_checks.py +++ b/core/src/apps/common/safety_checks.py @@ -3,13 +3,14 @@ import storage.device as storage_device from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from trezor.enums import SafetyCheckLevel +from trezor.wire import context def read_setting() -> SafetyCheckLevel: """ Returns the effective safety check level. """ - temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) + temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) if temporary_safety_check_level: return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum] else: @@ -34,7 +35,7 @@ def apply_setting(level: SafetyCheckLevel) -> None: storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) elif level == SafetyCheckLevel.PromptTemporarily: storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) - storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) + context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) else: raise ValueError("Unknown SafetyCheckLevel") diff --git a/core/src/apps/common/seed.py b/core/src/apps/common/seed.py index d773909b4..f7741b080 100644 --- a/core/src/apps/common/seed.py +++ b/core/src/apps/common/seed.py @@ -2,14 +2,20 @@ from typing import TYPE_CHECKING import storage.cache as storage_cache import storage.device as storage_device -from trezor import utils +from trezor import log, utils from trezor.crypto import hmac +from trezor.wire import context +from trezor.wire.context import get_context + +from apps.common import cache from . import mnemonic from .passphrase import get as get_passphrase if TYPE_CHECKING: from trezor.crypto import bip32 + from trezor.messages import ThpCreateNewSession + from trezor.wire.protocol_common import Context from .paths import Bip32Path, Slip21Path @@ -45,54 +51,56 @@ class Slip21Node: return Slip21Node(data=self.data) +async def get_seed() -> bytes: + common_seed = context.cache_get(storage_cache.APP_COMMON_SEED) + assert common_seed is not None + return common_seed + + if not utils.BITCOIN_ONLY: # === Cardano variant === # We want to derive both the normal seed and the Cardano seed together, AND # expose a method for Cardano to do the same - async def derive_and_store_roots() -> None: + async def derive_and_store_roots( + ctx: Context | None = None, msg: ThpCreateNewSession | None = None + ) -> None: + if __debug__: + log.debug(__name__, "derive_and_store_roots start") + from trezor import wire if not storage_device.is_initialized(): raise wire.NotInitialized("Device is not initialized") - need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED) - need_cardano_secret = storage_cache.get( + # For old codec_v1 implementation, the context is passed using get_context + # This handling is specific. In the rest of the code, a context.cache_* is used instead + if ctx is None: + ctx = get_context() + need_seed = not ctx.cache_is_set(storage_cache.APP_COMMON_SEED) + need_cardano_secret = ctx.cache_get( storage_cache.APP_COMMON_DERIVE_CARDANO - ) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) + ) and not ctx.cache_is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) if not need_seed and not need_cardano_secret: return - passphrase = await get_passphrase() + if msg is None or msg.on_device: + passphrase = await get_passphrase() + else: + passphrase = msg.passphrase or "" if need_seed: common_seed = mnemonic.get_seed(passphrase) - storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed) + ctx.cache_set(storage_cache.APP_COMMON_SEED, common_seed) if need_cardano_secret: from apps.cardano.seed import derive_and_store_secrets - derive_and_store_secrets(passphrase) - - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) - async def get_seed() -> bytes: - await derive_and_store_roots() - common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED) - assert common_seed is not None - return common_seed - -else: - # === Bitcoin-only variant === - # We use the simple version of `get_seed` that never needs to derive anything else. - - @storage_cache.stored_async(storage_cache.APP_COMMON_SEED) - async def get_seed() -> bytes: - passphrase = await get_passphrase() - return mnemonic.get_seed(passphrase) + derive_and_store_secrets(ctx, passphrase) -@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) +@cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) def _get_seed_without_passphrase() -> bytes: if not storage_device.is_initialized(): raise Exception("Device is not initialized") diff --git a/core/src/apps/misc/cosi_commit.py b/core/src/apps/misc/cosi_commit.py index b682ff7ab..0b0459fb5 100644 --- a/core/src/apps/misc/cosi_commit.py +++ b/core/src/apps/misc/cosi_commit.py @@ -59,6 +59,7 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature: from trezor.crypto import cosi from trezor.crypto.curve import ed25519 from trezor.ui.layouts import confirm_blob, confirm_text + from trezor.wire import context from trezor.wire.context import call from apps.common import paths @@ -71,11 +72,11 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature: seckey = node.private_key() pubkey = ed25519.publickey(seckey) - if not storage_cache.is_set(storage_cache.APP_MISC_COSI_COMMITMENT): + if not context.cache_is_set(storage_cache.APP_MISC_COSI_COMMITMENT): nonce, commitment = cosi.commit() - storage_cache.set(storage_cache.APP_MISC_COSI_NONCE, nonce) - storage_cache.set(storage_cache.APP_MISC_COSI_COMMITMENT, commitment) - commitment = storage_cache.get(storage_cache.APP_MISC_COSI_COMMITMENT) + context.cache_set(storage_cache.APP_MISC_COSI_NONCE, nonce) + context.cache_set(storage_cache.APP_MISC_COSI_COMMITMENT, commitment) + commitment = context.cache_get(storage_cache.APP_MISC_COSI_COMMITMENT) if commitment is None: raise RuntimeError @@ -101,7 +102,7 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature: ) # clear nonce from cache - nonce = storage_cache.get(storage_cache.APP_MISC_COSI_NONCE) + nonce = context.cache_get(storage_cache.APP_MISC_COSI_NONCE) storage_cache.delete(storage_cache.APP_MISC_COSI_COMMITMENT) storage_cache.delete(storage_cache.APP_MISC_COSI_NONCE) if nonce is None: diff --git a/core/src/apps/monero/live_refresh.py b/core/src/apps/monero/live_refresh.py index 011b2e283..9bf157149 100644 --- a/core/src/apps/monero/live_refresh.py +++ b/core/src/apps/monero/live_refresh.py @@ -59,14 +59,15 @@ async def _init_step( ) -> MoneroLiveRefreshStartAck: import storage.cache as storage_cache from trezor.messages import MoneroLiveRefreshStartAck + from trezor.wire import context from apps.common import paths await paths.validate_path(keychain, msg.address_n) - if not storage_cache.get(storage_cache.APP_MONERO_LIVE_REFRESH): + if not context.cache_get(storage_cache.APP_MONERO_LIVE_REFRESH): await layout.require_confirm_live_refresh() - storage_cache.set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01") + context.cache_set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01") s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) diff --git a/core/src/apps/thp/create_session.py b/core/src/apps/thp/create_session.py index c76ca52dc..2035d7a76 100644 --- a/core/src/apps/thp/create_session.py +++ b/core/src/apps/thp/create_session.py @@ -14,7 +14,11 @@ async def create_new_session( # from apps.common.seed import get_seed TODO from trezor.wire.thp.session_manager import create_new_session + from apps.common.seed import derive_and_store_roots + session = create_new_session(channel) + await derive_and_store_roots(session, message) + session.set_session_state(SessionState.ALLOCATED) channel.sessions[session.session_id] = session loop.schedule(session.handle()) diff --git a/core/src/storage/cache.py b/core/src/storage/cache.py index e9bb593ef..9bb68b90e 100644 --- a/core/src/storage/cache.py +++ b/core/src/storage/cache.py @@ -9,9 +9,25 @@ from trezor import utils SESSIONLESS_FLAG = const(128) if TYPE_CHECKING: - from typing import TypeVar, overload + from typing import Callable, ParamSpec, TypeVar, overload T = TypeVar("T") + P = ParamSpec("P") + + +def check_thp_is_not_used(f: Callable[P, T]) -> Callable[P, T]: + """A type-safe decorator to raise an exception when the function is called with THP enabled. + + This decorator should be removed after the caches for Codec_v1 and THP are properly refactored and separated. + """ + + def inner(*args: P.args, **kwargs: P.kwargs) -> T: + if utils.USE_THP: + raise Exception("Cannot call this function with the new THP enabled") + return f(*args, **kwargs) + + return inner + # Traditional cache keys APP_COMMON_SEED = const(0) @@ -74,6 +90,18 @@ _SESSIONLESS_CACHE.clear() gc.collect() +if TYPE_CHECKING: + + @overload + def get(key: int) -> bytes | None: ... + + @overload + def get(key: int, default: T) -> bytes | T: # noqa: F811 + ... + + +# Common functions + def clear_all() -> None: global autolock_last_touch @@ -82,42 +110,99 @@ def clear_all() -> None: _PROTOCOL_CACHE.clear_all() +def get_int_all_sessions(key: int) -> builtins.set[int]: + if key & SESSIONLESS_FLAG: + values = builtins.set() + encoded = _SESSIONLESS_CACHE.get(key) + if encoded is not None: + values.add(int.from_bytes(encoded, "big")) + return values + return _PROTOCOL_CACHE.get_int_all_sessions(key) + + +# Sessionless functions + + +def get_sessionless( + key: int, default: T | None = None +) -> bytes | T | None: # noqa: F811 + if key & SESSIONLESS_FLAG: + return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default) + raise ValueError("Argument 'key' does not have a sessionless flag") + + +def get_int_sessionless( + key: int, default: T | None = None +) -> int | T | None: # noqa: F811 + encoded = get_sessionless(key) + if encoded is None: + return default + else: + return int.from_bytes(encoded, "big") + + +def is_set_sessionless(key: int) -> bool: + if key & SESSIONLESS_FLAG: + return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG) + raise ValueError("Argument 'key' does not have a sessionless flag") + + +def set_sessionless(key: int, value: bytes) -> None: + if key & SESSIONLESS_FLAG: + _SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value) + return + raise ValueError("Argument 'key' does not have a sessionless flag") + + +def set_int_sessionless(key: int, value: int) -> None: + + if not key & SESSIONLESS_FLAG: + raise ValueError("Argument 'key' does not have a sessionless flag") + + length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG] + encoded = value.to_bytes(length, "big") + + # Ensure that the value fits within the length. Micropython's int.to_bytes() + # doesn't raise OverflowError. + assert int.from_bytes(encoded, "big") == value + + set_sessionless(key, encoded) + + +# Codec_v1 specific functions + + +@check_thp_is_not_used def start_session(received_session_id: bytes | None = None) -> bytes: - return _PROTOCOL_CACHE.start_session(received_session_id) + return cache_codec.start_session(received_session_id) +@check_thp_is_not_used def end_current_session() -> None: - _PROTOCOL_CACHE.end_current_session() + cache_codec.end_current_session() +@check_thp_is_not_used def delete(key: int) -> None: if key & SESSIONLESS_FLAG: return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG) - active_session = _PROTOCOL_CACHE.get_active_session() + active_session = cache_codec.get_active_session() if active_session is None: raise InvalidSessionError return active_session.delete(key) -if TYPE_CHECKING: - - @overload - def get(key: int) -> bytes | None: ... - - @overload - def get(key: int, default: T) -> bytes | T: # noqa: F811 - ... - - +@check_thp_is_not_used def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 if key & SESSIONLESS_FLAG: - return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default) - active_session = _PROTOCOL_CACHE.get_active_session() + return get_sessionless(key, default) + active_session = cache_codec.get_active_session() if active_session is None: raise InvalidSessionError return active_session.get(key, default) +@check_thp_is_not_used def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 encoded = get(key) if encoded is None: @@ -126,37 +211,30 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 return int.from_bytes(encoded, "big") -def get_int_all_sessions(key: int) -> builtins.set[int]: - if key & SESSIONLESS_FLAG: - values = builtins.set() - encoded = _SESSIONLESS_CACHE.get(key) - if encoded is not None: - values.add(int.from_bytes(encoded, "big")) - return values - return _PROTOCOL_CACHE.get_int_all_sessions(key) - - +@check_thp_is_not_used def is_set(key: int) -> bool: if key & SESSIONLESS_FLAG: return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG) - active_session = _PROTOCOL_CACHE.get_active_session() + active_session = cache_codec.get_active_session() if active_session is None: raise InvalidSessionError return active_session.is_set(key) +@check_thp_is_not_used def set(key: int, value: bytes) -> None: if key & SESSIONLESS_FLAG: _SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value) return - active_session = _PROTOCOL_CACHE.get_active_session() + active_session = cache_codec.get_active_session() if active_session is None: raise InvalidSessionError active_session.set(key, value) +@check_thp_is_not_used def set_int(key: int, value: int) -> None: - active_session = _PROTOCOL_CACHE.get_active_session() + active_session = cache_codec.get_active_session() if key & SESSIONLESS_FLAG: length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG] @@ -172,39 +250,3 @@ def set_int(key: int, value: int) -> None: assert int.from_bytes(encoded, "big") == value set(key, encoded) - - -if TYPE_CHECKING: - from typing import Awaitable, Callable, ParamSpec, TypeVar - - P = ParamSpec("P") - ByteFunc = Callable[P, bytes] - AsyncByteFunc = Callable[P, Awaitable[bytes]] - - -def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]: - def decorator(func: ByteFunc[P]) -> ByteFunc[P]: - def wrapper(*args: P.args, **kwargs: P.kwargs): - value = get(key) - if value is None: - value = func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator - - -def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]: - def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]: - async def wrapper(*args: P.args, **kwargs: P.kwargs): - value = get(key) - if value is None: - value = await func(*args, **kwargs) - set(key, value) - return value - - return wrapper - - return decorator diff --git a/core/src/storage/cache_thp.py b/core/src/storage/cache_thp.py index 7e0604283..01877d20a 100644 --- a/core/src/storage/cache_thp.py +++ b/core/src/storage/cache_thp.py @@ -2,7 +2,7 @@ import builtins from micropython import const # pyright: ignore[reportMissingModuleSource] from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] -from storage.cache_common import DataCache, InvalidSessionError +from storage.cache_common import DataCache from trezor import utils if TYPE_CHECKING: @@ -96,30 +96,22 @@ class SessionThpCache(ConnectionCache): _CHANNELS: list[ChannelCache] = [] _SESSIONS: list[SessionThpCache] = [] -_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace def initialize() -> None: global _CHANNELS global _SESSIONS - global _UNAUTHENTICATED_SESSIONS for _ in range(_MAX_CHANNELS_COUNT): _CHANNELS.append(ChannelCache()) for _ in range(_MAX_SESSIONS_COUNT): _SESSIONS.append(SessionThpCache()) - for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): - _UNAUTHENTICATED_SESSIONS.append(SessionThpCache()) - for channel in _CHANNELS: channel.clear() for session in _SESSIONS: session.clear() - for session in _UNAUTHENTICATED_SESSIONS: - session.clear() - initialize() @@ -128,8 +120,6 @@ initialize() _next_unauthenicated_session_index: int = 0 # TODO remove # First unauthenticated channel will have index 0 -_is_active_session_authenticated: bool -_active_session_idx: int | None = None _usage_counter = 0 # with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex) @@ -256,22 +246,6 @@ def _get_session_state(session: SessionThpCache) -> int: return int.from_bytes(session.state, "big") -def get_active_session_id() -> bytearray | None: - active_session = get_active_session() - - if active_session is None: - return None - return active_session.session_id - - -def get_active_session() -> SessionThpCache | None: - if _active_session_idx is None: - return None - if _is_active_session_authenticated: - return _SESSIONS[_active_session_idx] - return _UNAUTHENTICATED_SESSIONS[_active_session_idx] - - def get_next_channel_id() -> bytes: global cid_counter while True: @@ -304,7 +278,7 @@ def _is_session_id_unique(channel: ChannelCache) -> bool: def _is_cid_unique() -> bool: - for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS: + for session in _SESSIONS: if cid_counter == _get_cid(session): return False return True @@ -314,53 +288,6 @@ def _get_cid(session: SessionThpCache) -> int: return int.from_bytes(session.session_id[2:], "big") -def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache: - if len(session_id) != SESSION_ID_LENGTH: - raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH) - global _active_session_idx - global _is_active_session_authenticated - global _next_unauthenicated_session_index - - i = _next_unauthenicated_session_index - _UNAUTHENTICATED_SESSIONS[i] = SessionThpCache() - _UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id) - _next_unauthenicated_session_index += 1 - if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT: - _next_unauthenicated_session_index = 0 - - # Set session as active if and only if there is no active session - if _active_session_idx is None: - _active_session_idx = i - _is_active_session_authenticated = False - return _UNAUTHENTICATED_SESSIONS[i] - - -def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None: - for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): - if unauth_session == _UNAUTHENTICATED_SESSIONS[i]: - return i - return None - - -def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache: - unauth_session_idx = get_unauth_session_index(unauth_session) - if unauth_session_idx is None: - raise InvalidSessionError - - # replace least recently used authenticated session by the new session - new_auth_session_index = get_least_recently_used_authetnicated_session_index() - - _SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx] - _UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear() - - _SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment() - return _SESSIONS[new_auth_session_index] - - -def get_least_recently_used_authetnicated_session_index() -> int: - return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT) - - def get_least_recently_used_item( list: list[ChannelCache] | list[SessionThpCache], max_count: int ): @@ -373,71 +300,9 @@ def get_least_recently_used_item( return lru_item_index -# The function start_session should not be used in production code. It is present only to assure compatibility with old tests. -def start_session(session_id: bytes | None) -> bytes: # TODO incomplete - global _active_session_idx - global _is_active_session_authenticated - - if session_id is not None: - if get_active_session_id() == session_id: - return session_id - for index in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[index].session_id == session_id: - _active_session_idx = index - _is_active_session_authenticated = True - return session_id - for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): - if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id: - _active_session_idx = index - _is_active_session_authenticated = False - return session_id - - channel = get_new_unauthenticated_channel(b"\x00") - - new_session_id = get_next_session_id(channel) - - new_session = create_new_unauthenticated_session(new_session_id) - - index = get_unauth_session_index(new_session) - _active_session_idx = index - _is_active_session_authenticated = False - - return new_session_id - - -def start_existing_session(session_id: bytes) -> bytes: - global _active_session_idx - global _is_active_session_authenticated - - if session_id is None: - raise ValueError("session_id cannot be None") - if get_active_session_id() == session_id: - return session_id - for index in range(_MAX_SESSIONS_COUNT): - if _SESSIONS[index].session_id == session_id: - _active_session_idx = index - _is_active_session_authenticated = True - return session_id - for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT): - if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id: - _active_session_idx = index - _is_active_session_authenticated = False - return session_id - raise ValueError("There is no active session with provided session_id") - - -def end_current_session() -> None: - global _active_session_idx - active_session = get_active_session() - if active_session is None: - return - active_session.clear() - _active_session_idx = None - - def get_int_all_sessions(key: int) -> builtins.set[int]: values = builtins.set() - for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ? + for session in _SESSIONS: encoded = session.get(key) if encoded is not None: values.add(int.from_bytes(encoded, "big")) @@ -445,7 +310,5 @@ def get_int_all_sessions(key: int) -> builtins.set[int]: def clear_all() -> None: - global _active_session_idx - _active_session_idx = None - for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS: + for session in _SESSIONS: session.clear() diff --git a/core/src/storage/device.py b/core/src/storage/device.py index cf6ba0e92..145e5d9de 100644 --- a/core/src/storage/device.py +++ b/core/src/storage/device.py @@ -3,6 +3,9 @@ from typing import TYPE_CHECKING import storage.cache as storage_cache from storage import common +from trezor.wire import context + +from apps.common import cache if TYPE_CHECKING: from trezor.enums import BackupType @@ -314,7 +317,7 @@ def set_safety_check_level(level: StorageSafetyCheckLevel) -> None: common.set_uint8(_NAMESPACE, _SAFETY_CHECK_LEVEL, level) -@storage_cache.stored(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES) +@cache.stored(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES) def _get_experimental_features() -> bytes: if common.get_bool(_NAMESPACE, _EXPERIMENTAL_FEATURES): return b"\x01" @@ -328,7 +331,7 @@ def get_experimental_features() -> bool: def set_experimental_features(enabled: bool) -> None: cached_bytes = b"\x01" if enabled else b"" - storage_cache.set(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES, cached_bytes) + context.cache_set(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES, cached_bytes) common.set_true_or_delete(_NAMESPACE, _EXPERIMENTAL_FEATURES, enabled) diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index f35d28099..5458dc949 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -15,6 +15,8 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ from typing import TYPE_CHECKING +from storage import cache +from storage.cache import SESSIONLESS_FLAG from trezor import log, loop, protobuf from trezor.wire import codec_v1 @@ -159,6 +161,32 @@ class CodecContext(Context): memoryview(buffer)[:msg_size], ) + # ACCESS TO CACHE + + if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(self, key: int) -> bytes | None: ... + + @overload + def cache_get(self, key: int, default: T) -> bytes | T: ... + + def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: + return cache.get(key, default) + + def cache_get_int(self, key: int, default: T | None = None) -> int | T | None: + return cache.get_int(key, default) + + def cache_is_set(self, key: int) -> bool: + return cache.is_set(key) + + def cache_set(self, key: int, value: bytes) -> None: + cache.set(key, value) + + def cache_set_int(self, key: int, value: int) -> None: + cache.set_int(key, value) + CURRENT_CONTEXT: Context | None = None @@ -268,3 +296,49 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator: send_exc = e else: send_exc = None + + +# ACCESS TO CACHE + +if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def cache_get(key: int, default: T) -> bytes | T: # noqa: F811 + ... + + +def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811 + if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: + return cache.get_sessionless(key, default) + return CURRENT_CONTEXT.cache_get(key, default) + + +def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811 + if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: + return cache.get_int_sessionless(key, default) + return CURRENT_CONTEXT.cache_get_int(key, default) + + +def cache_is_set(key: int) -> bool: + if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: + return cache.is_set_sessionless(key) + return CURRENT_CONTEXT.cache_is_set(key) + + +def cache_set(key: int, value: bytes) -> None: + if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: + cache.set_sessionless(key, value) + return + CURRENT_CONTEXT.cache_set(key, value) + + +def cache_set_int(key: int, value: int) -> None: + if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG: + cache.set_int_sessionless(key, value) + return + CURRENT_CONTEXT.cache_set_int(key, value) diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 0c6f642fe..81a60ce0d 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from typing import Container, TypeVar, overload LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + T = TypeVar("T") class Message: @@ -69,6 +70,24 @@ class Context: async def write(self, msg: protobuf.MessageType) -> None: ... + if TYPE_CHECKING: + + @overload + def cache_get(self, key: int) -> bytes | None: ... + + @overload + def cache_get(self, key: int, default: T) -> bytes | T: ... + + def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: ... + + def cache_get_int(self, key: int, default: T | None = None) -> int | T | None: ... + + def cache_is_set(self, key: int) -> bool: ... + + def cache_set(self, key: int, value: bytes) -> None: ... + + def cache_set_int(self, key: int, value: int) -> None: ... + class WireError(Exception): pass diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index b06fa576d..09e105637 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] if TYPE_CHECKING: from enum import IntEnum from trezorio import WireInterface - from typing import Protocol + from typing import Protocol, TypeVar, overload from storage.cache_thp import ChannelCache from trezor import loop, protobuf, utils @@ -11,6 +11,8 @@ if TYPE_CHECKING: from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.session_context import SessionContext + T = TypeVar("T") + class ChannelContext(Protocol): buffer: utils.BufferType iface: WireInterface @@ -40,6 +42,18 @@ if TYPE_CHECKING: def get_channel_id_int(self) -> int: ... + @overload + def cache_get(self, key: int) -> bytes | None: ... + + @overload + def cache_get(self, key: int, default: T) -> bytes | T: ... + + def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: ... + + def cache_is_set(self, key: int) -> bool: ... + + def cache_set(self, key: int, value: bytes) -> None: ... + else: IntEnum = object diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 9f33b9047..744978047 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -22,7 +22,8 @@ if __debug__: from . import state_to_str if TYPE_CHECKING: - from trezorio import WireInterface # pyright: ignore[reportMissingImports] + from trezorio import WireInterface + from typing import TypeVar, overload from . import ChannelContext, PairingContext from .session_context import SessionContext @@ -173,6 +174,7 @@ class Channel: async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: if __debug__: log.debug(__name__, "write message: %s", msg.MESSAGE_NAME) + self.buffer = memory_manager.get_write_buffer(self.buffer, msg) noise_payload_len = memory_manager.encode_into_buffer( self.buffer, msg, session_id ) @@ -274,3 +276,33 @@ class Channel: async def _wait_for_ack(self) -> None: await loop.sleep(1000) + + # ACCESS TO CACHE + + if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(self, key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def cache_get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def cache_get( + self, key: int, default: T | None = None + ) -> bytes | T | None: # noqa: F811 + utils.ensure(key < len(self.channel_cache.fields)) + if self.channel_cache.data[key][0] != 1: + return default + return bytes(self.channel_cache.data[key][1:]) + + def cache_is_set(self, key: int) -> bool: + return self.channel_cache.is_set(key) + + def cache_set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.channel_cache.fields)) + utils.ensure(len(value) <= self.channel_cache.fields[key]) + self.channel_cache.data[key][0] = 1 + self.channel_cache.data[key][1:] = value diff --git a/core/src/trezor/wire/thp/handler_provider.py b/core/src/trezor/wire/thp/handler_provider.py index 68d170442..59e9a969b 100644 --- a/core/src/trezor/wire/thp/handler_provider.py +++ b/core/src/trezor/wire/thp/handler_provider.py @@ -1,16 +1,34 @@ from typing import TYPE_CHECKING from trezor import protobuf +from trezor.enums import MessageType +from trezor.wire.errors import UnexpectedMessage from apps.thp import create_session if TYPE_CHECKING: from typing import Any, Callable, Coroutine + from trezor.messages import LoadDevice + + from . import ChannelContext + pass def get_handler_for_channel_message( msg: protobuf.MessageType, ) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: - return create_session.create_new_session + if msg.MESSAGE_WIRE_TYPE is MessageType.ThpCreateNewSession: + return create_session.create_new_session + if __debug__: + if msg.MESSAGE_WIRE_TYPE is MessageType.LoadDevice: + from apps.debug.load_device import load_device + + def wrapper( + channel: ChannelContext, msg: LoadDevice + ) -> Coroutine[Any, Any, protobuf.MessageType]: + return load_device(msg) + + return wrapper + raise UnexpectedMessage("There is no handler available for this message") diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index a230e0ed6..c51e3b78d 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -45,6 +45,19 @@ def select_buffer( raise Exception("Failed to create a buffer for channel") # TODO handle better +def get_write_buffer( + buffer: utils.BufferType, msg: protobuf.MessageType +) -> utils.BufferType: + msg_size = protobuf.encoded_length(msg) + payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size + required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH + + if required_min_size > len(buffer): + # message is too big, we need to allocate a new buffer + return bytearray(required_min_size) + return buffer + + def encode_into_buffer( buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int ) -> int: @@ -54,11 +67,6 @@ def encode_into_buffer( msg_size = protobuf.encoded_length(msg) payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size - required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH - - if required_min_size > len(buffer): - # message is too big, we need to allocate a new buffer - buffer = bytearray(required_min_size) _encode_session_into_buffer(memoryview(buffer), session_id) _encode_message_type_into_buffer( diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index f82f22b6a..cefda1256 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -35,6 +35,8 @@ if TYPE_CHECKING: from . import ChannelContext if __debug__: + from trezor.messages import LoadDevice + from . import state_to_str @@ -237,7 +239,7 @@ async def _handle_state_TH2( ) # TODO add credential recognition - paired: bool = True # TODO should be output from credential check + paired: bool = False # TODO should be output from credential check # send hanshake completion response await ctx.write_handshake_message( @@ -334,7 +336,7 @@ async def _handle_channel_message( expected_type = protobuf.type_for_wire(message_type) message = message_handler.wrap_protobuf_load(buf, expected_type) - if not ThpCreateNewSession.is_type_of(message): + if not _is_channel_message(message): raise ThpError( "The received message cannot be handled by channel itself. It must be sent to allocated session." ) @@ -348,3 +350,9 @@ async def _handle_channel_message( await ctx.write(response_message) if __debug__: log.debug(__name__, "_handle_channel_message - end") + + +def _is_channel_message(message) -> bool: + if __debug__: + return ThpCreateNewSession.is_type_of(message) or LoadDevice.is_type_of(message) + return ThpCreateNewSession.is_type_of(message) diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 5a7c54546..34cb68e55 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from storage.cache_thp import SessionThpCache -from trezor import log, loop, protobuf +from trezor import log, loop, protobuf, utils from trezor.wire import message_handler, protocol_common from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure @@ -13,6 +13,8 @@ if TYPE_CHECKING: Any, Awaitable, Container, + TypeVar, + overload, ) from . import ChannelContext @@ -160,3 +162,33 @@ class SessionContext(Context): def set_session_state(self, state: SessionState) -> None: self.session_cache.state = bytearray(state.to_bytes(1, "big")) + + # ACCESS TO CACHE + + if TYPE_CHECKING: + T = TypeVar("T") + + @overload + def cache_get(self, key: int) -> bytes | None: # noqa: F811 + ... + + @overload + def cache_get(self, key: int, default: T) -> bytes | T: # noqa: F811 + ... + + def cache_get( + self, key: int, default: T | None = None + ) -> bytes | T | None: # noqa: F811 + utils.ensure(key < len(self.session_cache.fields)) + if self.session_cache.data[key][0] != 1: + return default + return bytes(self.session_cache.data[key][1:]) + + def cache_is_set(self, key: int) -> bool: + return self.session_cache.is_set(key) + + def cache_set(self, key: int, value: bytes) -> None: + utils.ensure(key < len(self.session_cache.fields)) + utils.ensure(len(value) <= self.session_cache.fields[key]) + self.session_cache.data[key][0] = 1 + self.session_cache.data[key][1:] = value diff --git a/core/src/trezor/wire/thp/thp_session.py b/core/src/trezor/wire/thp/thp_session.py index 451b6bd60..ec73b8e56 100644 --- a/core/src/trezor/wire/thp/thp_session.py +++ b/core/src/trezor/wire/thp/thp_session.py @@ -19,38 +19,23 @@ class ThpError(WireError): class SessionState(IntEnum): UNALLOCATED = 0 - INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1 + INITIALIZED = 1 # do not change, it is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1 PAIRED = 2 UNPAIRED = 3 PAIRING = 4 APP_TRAFFIC = 5 -def create_autenticated_session(unauthenticated_session: SessionThpCache): - # storage_thp_cache.start_session() - TODO something like this but for THP - raise NotImplementedError("Secure channel is not implemented, yet.") - - -def create_new_unauthenticated_session(iface: WireInterface, cid: int): - session_id = _get_id(iface, cid) - new_session = storage_thp_cache.create_new_unauthenticated_session(session_id) - set_session_state(new_session, SessionState.INITIALIZED) - - -def get_active_session() -> SessionThpCache | None: - return storage_thp_cache.get_active_session() - - def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None: session_id = _get_id(iface, cid) return get_session_from_id(session_id) def get_session_from_id(session_id) -> SessionThpCache | None: - session = _get_authenticated_session_or_none(session_id) - if session is None: - session = _get_unauthenticated_session_or_none(session_id) - return session + for session in storage_thp_cache._SESSIONS: + if session.session_id == session_id: + return session + return None def get_state(session: SessionThpCache | None) -> int: @@ -101,12 +86,6 @@ def sync_set_send_bit_to_opposite(cache: SessionThpCache | ChannelCache) -> None _sync_set_send_bit(cache=cache, bit=1 - sync_get_send_bit(cache)) -def is_active_session(session: SessionThpCache): - if session is None: - return False - return session.session_id == storage_thp_cache.get_active_session_id() - - def set_session_state(session: SessionThpCache, new_state: SessionState): session.state = bytearray(new_state.to_bytes(1, "big")) @@ -115,20 +94,6 @@ def _get_id(iface: WireInterface, cid: int) -> bytes: return ustruct.pack(">HH", iface.iface_num(), cid) -def _get_authenticated_session_or_none(session_id) -> SessionThpCache | None: - for authenticated_session in storage_thp_cache._SESSIONS: - if authenticated_session.session_id == session_id: - return authenticated_session - return None - - -def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None: - for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS: - if unauthenticated_session.session_id == session_id: - return unauthenticated_session - return None - - def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None: if bit not in (0, 1): raise ThpError("Unexpected send sync bit")