feat(core): implement cache handling of passphrase, refactor cache

M1nd3r/thp6
M1nd3r 2 weeks ago
parent ad74d7f598
commit 2f0a7ec740

@ -325,6 +325,8 @@ apps.common.address_type
import apps.common.address_type
apps.common.authorization
import apps.common.authorization
apps.common.cache
import apps.common.cache
apps.common.cbor
import apps.common.cbor
apps.common.coininfo

@ -6,6 +6,7 @@ from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType
from trezor.messages import Success, UnlockPath
from trezor.ui.layouts import confirm_action
from trezor.wire import context
from . import workflow_handlers
@ -33,7 +34,7 @@ def busy_expiry_ms() -> int:
Returns the time left until the busy state expires or 0 if the device is not in the busy state.
"""
busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
busy_deadline_ms = context.cache_get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
if busy_deadline_ms is None:
return 0
@ -175,7 +176,7 @@ def get_features() -> Features:
return f
# handle_Initialize should not be used with THP to start a new session
@storage_cache.check_thp_is_not_used
async def handle_Initialize(msg: Initialize) -> Features:
if utils.USE_THP:
raise ValueError("With THP enabled, a session id must be provided in args")
@ -183,8 +184,8 @@ async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id)
if not utils.BITCOIN_ONLY:
derive_cardano = storage_cache.get(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
derive_cardano = context.cache_get(storage_cache.APP_COMMON_DERIVE_CARDANO)
have_seed = context.cache_is_set(storage_cache.APP_COMMON_SEED)
if (
have_seed
@ -194,11 +195,11 @@ async def handle_Initialize(msg: Initialize) -> Features:
# seed is already derived, and host wants to change derive_cardano setting
# => create a new session
storage_cache.end_current_session()
session_id = storage_cache.start_session() # This should not be used in THP
session_id = storage_cache.start_session()
have_seed = False
if not have_seed:
storage_cache.set(
context.cache_set(
storage_cache.APP_COMMON_DERIVE_CARDANO,
b"\x01" if msg.derive_cardano else b"",
)
@ -229,7 +230,7 @@ async def handle_SetBusy(msg: SetBusy) -> Success:
import utime
deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms)
storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline)
context.cache_set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline)
else:
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen()
@ -338,7 +339,7 @@ def set_homescreen() -> None:
set_default = workflow.set_default # local_cache_attribute
if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS):
if context.cache_is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS):
from apps.homescreen import busyscreen
set_default(busyscreen)

@ -1,7 +1,7 @@
from micropython import const
from typing import TYPE_CHECKING
from trezor.wire import DataError
from trezor.wire import DataError, context
from .. import writers
@ -42,7 +42,7 @@ class PaymentRequestVerifier:
if msg.nonce:
nonce = bytes(msg.nonce)
if cache.get(cache.APP_COMMON_NONCE) != nonce:
if context.cache_get(cache.APP_COMMON_NONCE) != nonce:
raise DataError("Invalid nonce in payment request.")
cache.delete(cache.APP_COMMON_NONCE)
else:

@ -15,6 +15,7 @@ if TYPE_CHECKING:
from trezor import messages
from trezor.crypto import bip32
from trezor.enums import CardanoDerivationType
from trezor.wire.protocol_common import Context
from apps.common.keychain import Handler, MsgOut
from apps.common.paths import Bip32Path
@ -110,9 +111,9 @@ def is_minting_path(path: Bip32Path) -> bool:
return path[: len(MINTING_ROOT)] == MINTING_ROOT
def derive_and_store_secrets(passphrase: str) -> None:
def derive_and_store_secrets(ctx: Context, passphrase: str) -> None:
assert device.is_initialized()
assert cache.get(cache.APP_COMMON_DERIVE_CARDANO)
assert ctx.cache_get(cache.APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed
@ -132,8 +133,8 @@ def derive_and_store_secrets(passphrase: str) -> None:
else:
icarus_trezor_secret = icarus_secret
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
ctx.cache_set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret)
ctx.cache_set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:

@ -3,6 +3,7 @@ from typing import Iterable
import storage.cache as storage_cache
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire import context
WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
@ -17,7 +18,7 @@ APP_COMMON_AUTHORIZATION_TYPE = (
def is_set() -> bool:
return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None
return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None
def set(auth_message: protobuf.MessageType) -> None:
@ -29,16 +30,16 @@ def set(auth_message: protobuf.MessageType) -> None:
# (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer)
context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type:
return None
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"")
buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, stored_auth_type)
@ -49,7 +50,7 @@ def is_set_any_session(auth_type: MessageType) -> bool:
def get_wire_types() -> Iterable[int]:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE)
stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None:
return ()

@ -0,0 +1,23 @@
from typing import TYPE_CHECKING
from trezor.wire import context
if TYPE_CHECKING:
from typing import Callable, ParamSpec
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
value = context.cache_get(key)
if value is None:
value = func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator

@ -4,6 +4,7 @@ from typing import Any, NoReturn
import storage.cache as storage_cache
from trezor import TR, config, utils, wire
from trezor.ui.layouts import show_error_and_raise
from trezor.wire import context
async def _request_sd_salt(
@ -77,7 +78,7 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None:
now = utime.ticks_ms()
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
context.cache_set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
_DEF_ARG_PIN_ENTER: str = TR.pin__enter
@ -91,7 +92,7 @@ async def verify_user_pin(
) -> None:
# _get_last_unlock_time
last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
context.cache_get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
)
if (

@ -3,13 +3,14 @@ import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel
from trezor.wire import context
def read_setting() -> SafetyCheckLevel:
"""
Returns the effective safety check level.
"""
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else:
@ -34,7 +35,7 @@ def apply_setting(level: SafetyCheckLevel) -> None:
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily:
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else:
raise ValueError("Unknown SafetyCheckLevel")

@ -2,14 +2,20 @@ from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device
from trezor import utils
from trezor import log, utils
from trezor.crypto import hmac
from trezor.wire import context
from trezor.wire.context import get_context
from apps.common import cache
from . import mnemonic
from .passphrase import get as get_passphrase
if TYPE_CHECKING:
from trezor.crypto import bip32
from trezor.messages import ThpCreateNewSession
from trezor.wire.protocol_common import Context
from .paths import Bip32Path, Slip21Path
@ -45,54 +51,56 @@ class Slip21Node:
return Slip21Node(data=self.data)
async def get_seed() -> bytes:
common_seed = context.cache_get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None
return common_seed
if not utils.BITCOIN_ONLY:
# === Cardano variant ===
# We want to derive both the normal seed and the Cardano seed together, AND
# expose a method for Cardano to do the same
async def derive_and_store_roots() -> None:
async def derive_and_store_roots(
ctx: Context | None = None, msg: ThpCreateNewSession | None = None
) -> None:
if __debug__:
log.debug(__name__, "derive_and_store_roots start")
from trezor import wire
if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized")
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = storage_cache.get(
# For old codec_v1 implementation, the context is passed using get_context
# This handling is specific. In the rest of the code, a context.cache_* is used instead
if ctx is None:
ctx = get_context()
need_seed = not ctx.cache_is_set(storage_cache.APP_COMMON_SEED)
need_cardano_secret = ctx.cache_get(
storage_cache.APP_COMMON_DERIVE_CARDANO
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
) and not ctx.cache_is_set(storage_cache.APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret:
return
passphrase = await get_passphrase()
if msg is None or msg.on_device:
passphrase = await get_passphrase()
else:
passphrase = msg.passphrase or ""
if need_seed:
common_seed = mnemonic.get_seed(passphrase)
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed)
ctx.cache_set(storage_cache.APP_COMMON_SEED, common_seed)
if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed() -> bytes:
await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED)
assert common_seed is not None
return common_seed
else:
# === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED)
async def get_seed() -> bytes:
passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase)
derive_and_store_secrets(ctx, passphrase)
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
@cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes:
if not storage_device.is_initialized():
raise Exception("Device is not initialized")

@ -59,6 +59,7 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature:
from trezor.crypto import cosi
from trezor.crypto.curve import ed25519
from trezor.ui.layouts import confirm_blob, confirm_text
from trezor.wire import context
from trezor.wire.context import call
from apps.common import paths
@ -71,11 +72,11 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature:
seckey = node.private_key()
pubkey = ed25519.publickey(seckey)
if not storage_cache.is_set(storage_cache.APP_MISC_COSI_COMMITMENT):
if not context.cache_is_set(storage_cache.APP_MISC_COSI_COMMITMENT):
nonce, commitment = cosi.commit()
storage_cache.set(storage_cache.APP_MISC_COSI_NONCE, nonce)
storage_cache.set(storage_cache.APP_MISC_COSI_COMMITMENT, commitment)
commitment = storage_cache.get(storage_cache.APP_MISC_COSI_COMMITMENT)
context.cache_set(storage_cache.APP_MISC_COSI_NONCE, nonce)
context.cache_set(storage_cache.APP_MISC_COSI_COMMITMENT, commitment)
commitment = context.cache_get(storage_cache.APP_MISC_COSI_COMMITMENT)
if commitment is None:
raise RuntimeError
@ -101,7 +102,7 @@ async def cosi_commit(msg: CosiCommit) -> CosiSignature:
)
# clear nonce from cache
nonce = storage_cache.get(storage_cache.APP_MISC_COSI_NONCE)
nonce = context.cache_get(storage_cache.APP_MISC_COSI_NONCE)
storage_cache.delete(storage_cache.APP_MISC_COSI_COMMITMENT)
storage_cache.delete(storage_cache.APP_MISC_COSI_NONCE)
if nonce is None:

@ -59,14 +59,15 @@ async def _init_step(
) -> MoneroLiveRefreshStartAck:
import storage.cache as storage_cache
from trezor.messages import MoneroLiveRefreshStartAck
from trezor.wire import context
from apps.common import paths
await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get(storage_cache.APP_MONERO_LIVE_REFRESH):
if not context.cache_get(storage_cache.APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh()
storage_cache.set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01")
context.cache_set(storage_cache.APP_MONERO_LIVE_REFRESH, b"\x01")
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

@ -14,7 +14,11 @@ async def create_new_session(
# from apps.common.seed import get_seed TODO
from trezor.wire.thp.session_manager import create_new_session
from apps.common.seed import derive_and_store_roots
session = create_new_session(channel)
await derive_and_store_roots(session, message)
session.set_session_state(SessionState.ALLOCATED)
channel.sessions[session.session_id] = session
loop.schedule(session.handle())

@ -9,9 +9,25 @@ from trezor import utils
SESSIONLESS_FLAG = const(128)
if TYPE_CHECKING:
from typing import TypeVar, overload
from typing import Callable, ParamSpec, TypeVar, overload
T = TypeVar("T")
P = ParamSpec("P")
def check_thp_is_not_used(f: Callable[P, T]) -> Callable[P, T]:
"""A type-safe decorator to raise an exception when the function is called with THP enabled.
This decorator should be removed after the caches for Codec_v1 and THP are properly refactored and separated.
"""
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if utils.USE_THP:
raise Exception("Cannot call this function with the new THP enabled")
return f(*args, **kwargs)
return inner
# Traditional cache keys
APP_COMMON_SEED = const(0)
@ -74,6 +90,18 @@ _SESSIONLESS_CACHE.clear()
gc.collect()
if TYPE_CHECKING:
@overload
def get(key: int) -> bytes | None: ...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
# Common functions
def clear_all() -> None:
global autolock_last_touch
@ -82,42 +110,99 @@ def clear_all() -> None:
_PROTOCOL_CACHE.clear_all()
def get_int_all_sessions(key: int) -> builtins.set[int]:
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
# Sessionless functions
def get_sessionless(
key: int, default: T | None = None
) -> bytes | T | None: # noqa: F811
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default)
raise ValueError("Argument 'key' does not have a sessionless flag")
def get_int_sessionless(
key: int, default: T | None = None
) -> int | T | None: # noqa: F811
encoded = get_sessionless(key)
if encoded is None:
return default
else:
return int.from_bytes(encoded, "big")
def is_set_sessionless(key: int) -> bool:
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG)
raise ValueError("Argument 'key' does not have a sessionless flag")
def set_sessionless(key: int, value: bytes) -> None:
if key & SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value)
return
raise ValueError("Argument 'key' does not have a sessionless flag")
def set_int_sessionless(key: int, value: int) -> None:
if not key & SESSIONLESS_FLAG:
raise ValueError("Argument 'key' does not have a sessionless flag")
length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG]
encoded = value.to_bytes(length, "big")
# Ensure that the value fits within the length. Micropython's int.to_bytes()
# doesn't raise OverflowError.
assert int.from_bytes(encoded, "big") == value
set_sessionless(key, encoded)
# Codec_v1 specific functions
@check_thp_is_not_used
def start_session(received_session_id: bytes | None = None) -> bytes:
return _PROTOCOL_CACHE.start_session(received_session_id)
return cache_codec.start_session(received_session_id)
@check_thp_is_not_used
def end_current_session() -> None:
_PROTOCOL_CACHE.end_current_session()
cache_codec.end_current_session()
@check_thp_is_not_used
def delete(key: int) -> None:
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.delete(key ^ SESSIONLESS_FLAG)
active_session = _PROTOCOL_CACHE.get_active_session()
active_session = cache_codec.get_active_session()
if active_session is None:
raise InvalidSessionError
return active_session.delete(key)
if TYPE_CHECKING:
@overload
def get(key: int) -> bytes | None: ...
@overload
def get(key: int, default: T) -> bytes | T: # noqa: F811
...
@check_thp_is_not_used
def get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.get(key ^ SESSIONLESS_FLAG, default)
active_session = _PROTOCOL_CACHE.get_active_session()
return get_sessionless(key, default)
active_session = cache_codec.get_active_session()
if active_session is None:
raise InvalidSessionError
return active_session.get(key, default)
@check_thp_is_not_used
def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
encoded = get(key)
if encoded is None:
@ -126,37 +211,30 @@ def get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
return int.from_bytes(encoded, "big")
def get_int_all_sessions(key: int) -> builtins.set[int]:
if key & SESSIONLESS_FLAG:
values = builtins.set()
encoded = _SESSIONLESS_CACHE.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
return values
return _PROTOCOL_CACHE.get_int_all_sessions(key)
@check_thp_is_not_used
def is_set(key: int) -> bool:
if key & SESSIONLESS_FLAG:
return _SESSIONLESS_CACHE.is_set(key ^ SESSIONLESS_FLAG)
active_session = _PROTOCOL_CACHE.get_active_session()
active_session = cache_codec.get_active_session()
if active_session is None:
raise InvalidSessionError
return active_session.is_set(key)
@check_thp_is_not_used
def set(key: int, value: bytes) -> None:
if key & SESSIONLESS_FLAG:
_SESSIONLESS_CACHE.set(key ^ SESSIONLESS_FLAG, value)
return
active_session = _PROTOCOL_CACHE.get_active_session()
active_session = cache_codec.get_active_session()
if active_session is None:
raise InvalidSessionError
active_session.set(key, value)
@check_thp_is_not_used
def set_int(key: int, value: int) -> None:
active_session = _PROTOCOL_CACHE.get_active_session()
active_session = cache_codec.get_active_session()
if key & SESSIONLESS_FLAG:
length = _SESSIONLESS_CACHE.fields[key ^ SESSIONLESS_FLAG]
@ -172,39 +250,3 @@ def set_int(key: int, value: int) -> None:
assert int.from_bytes(encoded, "big") == value
set(key, encoded)
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec, TypeVar
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs):
value = get(key)
if value is None:
value = func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs):
value = get(key)
if value is None:
value = await func(*args, **kwargs)
set(key, value)
return value
return wrapper
return decorator

@ -2,7 +2,7 @@ import builtins
from micropython import const # pyright: ignore[reportMissingModuleSource]
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_common import DataCache, InvalidSessionError
from storage.cache_common import DataCache
from trezor import utils
if TYPE_CHECKING:
@ -96,30 +96,22 @@ class SessionThpCache(ConnectionCache):
_CHANNELS: list[ChannelCache] = []
_SESSIONS: list[SessionThpCache] = []
_UNAUTHENTICATED_SESSIONS: list[SessionThpCache] = [] # TODO remove/replace
def initialize() -> None:
global _CHANNELS
global _SESSIONS
global _UNAUTHENTICATED_SESSIONS
for _ in range(_MAX_CHANNELS_COUNT):
_CHANNELS.append(ChannelCache())
for _ in range(_MAX_SESSIONS_COUNT):
_SESSIONS.append(SessionThpCache())
for _ in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
_UNAUTHENTICATED_SESSIONS.append(SessionThpCache())
for channel in _CHANNELS:
channel.clear()
for session in _SESSIONS:
session.clear()
for session in _UNAUTHENTICATED_SESSIONS:
session.clear()
initialize()
@ -128,8 +120,6 @@ initialize()
_next_unauthenicated_session_index: int = 0 # TODO remove
# First unauthenticated channel will have index 0
_is_active_session_authenticated: bool
_active_session_idx: int | None = None
_usage_counter = 0
# with this (arbitrary) value=4659, the first allocated channel will have cid=1234 (hex)
@ -256,22 +246,6 @@ def _get_session_state(session: SessionThpCache) -> int:
return int.from_bytes(session.state, "big")
def get_active_session_id() -> bytearray | None:
active_session = get_active_session()
if active_session is None:
return None
return active_session.session_id
def get_active_session() -> SessionThpCache | None:
if _active_session_idx is None:
return None
if _is_active_session_authenticated:
return _SESSIONS[_active_session_idx]
return _UNAUTHENTICATED_SESSIONS[_active_session_idx]
def get_next_channel_id() -> bytes:
global cid_counter
while True:
@ -304,7 +278,7 @@ def _is_session_id_unique(channel: ChannelCache) -> bool:
def _is_cid_unique() -> bool:
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
for session in _SESSIONS:
if cid_counter == _get_cid(session):
return False
return True
@ -314,53 +288,6 @@ def _get_cid(session: SessionThpCache) -> int:
return int.from_bytes(session.session_id[2:], "big")
def create_new_unauthenticated_session(session_id: bytes) -> SessionThpCache:
if len(session_id) != SESSION_ID_LENGTH:
raise ValueError("session_id must be X bytes long, where X=", SESSION_ID_LENGTH)
global _active_session_idx
global _is_active_session_authenticated
global _next_unauthenicated_session_index
i = _next_unauthenicated_session_index
_UNAUTHENTICATED_SESSIONS[i] = SessionThpCache()
_UNAUTHENTICATED_SESSIONS[i].session_id = bytearray(session_id)
_next_unauthenicated_session_index += 1
if _next_unauthenicated_session_index >= _MAX_UNAUTHENTICATED_SESSIONS_COUNT:
_next_unauthenicated_session_index = 0
# Set session as active if and only if there is no active session
if _active_session_idx is None:
_active_session_idx = i
_is_active_session_authenticated = False
return _UNAUTHENTICATED_SESSIONS[i]
def get_unauth_session_index(unauth_session: SessionThpCache) -> int | None:
for i in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if unauth_session == _UNAUTHENTICATED_SESSIONS[i]:
return i
return None
def create_new_auth_session(unauth_session: SessionThpCache) -> SessionThpCache:
unauth_session_idx = get_unauth_session_index(unauth_session)
if unauth_session_idx is None:
raise InvalidSessionError
# replace least recently used authenticated session by the new session
new_auth_session_index = get_least_recently_used_authetnicated_session_index()
_SESSIONS[new_auth_session_index] = _UNAUTHENTICATED_SESSIONS[unauth_session_idx]
_UNAUTHENTICATED_SESSIONS[unauth_session_idx].clear()
_SESSIONS[new_auth_session_index].last_usage = _get_usage_counter_and_increment()
return _SESSIONS[new_auth_session_index]
def get_least_recently_used_authetnicated_session_index() -> int:
return get_least_recently_used_item(_SESSIONS, _MAX_SESSIONS_COUNT)
def get_least_recently_used_item(
list: list[ChannelCache] | list[SessionThpCache], max_count: int
):
@ -373,71 +300,9 @@ def get_least_recently_used_item(
return lru_item_index
# The function start_session should not be used in production code. It is present only to assure compatibility with old tests.
def start_session(session_id: bytes | None) -> bytes: # TODO incomplete
global _active_session_idx
global _is_active_session_authenticated
if session_id is not None:
if get_active_session_id() == session_id:
return session_id
for index in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = True
return session_id
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
channel = get_new_unauthenticated_channel(b"\x00")
new_session_id = get_next_session_id(channel)
new_session = create_new_unauthenticated_session(new_session_id)
index = get_unauth_session_index(new_session)
_active_session_idx = index
_is_active_session_authenticated = False
return new_session_id
def start_existing_session(session_id: bytes) -> bytes:
global _active_session_idx
global _is_active_session_authenticated
if session_id is None:
raise ValueError("session_id cannot be None")
if get_active_session_id() == session_id:
return session_id
for index in range(_MAX_SESSIONS_COUNT):
if _SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = True
return session_id
for index in range(_MAX_UNAUTHENTICATED_SESSIONS_COUNT):
if _UNAUTHENTICATED_SESSIONS[index].session_id == session_id:
_active_session_idx = index
_is_active_session_authenticated = False
return session_id
raise ValueError("There is no active session with provided session_id")
def end_current_session() -> None:
global _active_session_idx
active_session = get_active_session()
if active_session is None:
return
active_session.clear()
_active_session_idx = None
def get_int_all_sessions(key: int) -> builtins.set[int]:
values = builtins.set()
for session in _SESSIONS: # Should there be _SESSIONS + _UNAUTHENTICATED_SESSIONS ?
for session in _SESSIONS:
encoded = session.get(key)
if encoded is not None:
values.add(int.from_bytes(encoded, "big"))
@ -445,7 +310,5 @@ def get_int_all_sessions(key: int) -> builtins.set[int]:
def clear_all() -> None:
global _active_session_idx
_active_session_idx = None
for session in _SESSIONS + _UNAUTHENTICATED_SESSIONS:
for session in _SESSIONS:
session.clear()

@ -3,6 +3,9 @@ from typing import TYPE_CHECKING
import storage.cache as storage_cache
from storage import common
from trezor.wire import context
from apps.common import cache
if TYPE_CHECKING:
from trezor.enums import BackupType
@ -314,7 +317,7 @@ def set_safety_check_level(level: StorageSafetyCheckLevel) -> None:
common.set_uint8(_NAMESPACE, _SAFETY_CHECK_LEVEL, level)
@storage_cache.stored(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES)
@cache.stored(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES)
def _get_experimental_features() -> bytes:
if common.get_bool(_NAMESPACE, _EXPERIMENTAL_FEATURES):
return b"\x01"
@ -328,7 +331,7 @@ def get_experimental_features() -> bool:
def set_experimental_features(enabled: bool) -> None:
cached_bytes = b"\x01" if enabled else b""
storage_cache.set(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES, cached_bytes)
context.cache_set(storage_cache.STORAGE_DEVICE_EXPERIMENTAL_FEATURES, cached_bytes)
common.set_true_or_delete(_NAMESPACE, _EXPERIMENTAL_FEATURES, enabled)

@ -15,6 +15,8 @@ for ButtonRequests. Of course, `context.wait()` transparently works in such situ
from typing import TYPE_CHECKING
from storage import cache
from storage.cache import SESSIONLESS_FLAG
from trezor import log, loop, protobuf
from trezor.wire import codec_v1
@ -159,6 +161,32 @@ class CodecContext(Context):
memoryview(buffer)[:msg_size],
)
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(self, key: int) -> bytes | None: ...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: ...
def cache_get(self, key: int, default: T | None = None) -> bytes | T | None:
return cache.get(key, default)
def cache_get_int(self, key: int, default: T | None = None) -> int | T | None:
return cache.get_int(key, default)
def cache_is_set(self, key: int) -> bool:
return cache.is_set(key)
def cache_set(self, key: int, value: bytes) -> None:
cache.set(key, value)
def cache_set_int(self, key: int, value: int) -> None:
cache.set_int(key, value)
CURRENT_CONTEXT: Context | None = None
@ -268,3 +296,49 @@ def with_context(ctx: Context, workflow: loop.Task) -> Generator:
send_exc = e
else:
send_exc = None
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(key: int, default: T | None = None) -> bytes | T | None: # noqa: F811
if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG:
return cache.get_sessionless(key, default)
return CURRENT_CONTEXT.cache_get(key, default)
def cache_get_int(key: int, default: T | None = None) -> int | T | None: # noqa: F811
if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG:
return cache.get_int_sessionless(key, default)
return CURRENT_CONTEXT.cache_get_int(key, default)
def cache_is_set(key: int) -> bool:
if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG:
return cache.is_set_sessionless(key)
return CURRENT_CONTEXT.cache_is_set(key)
def cache_set(key: int, value: bytes) -> None:
if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG:
cache.set_sessionless(key, value)
return
CURRENT_CONTEXT.cache_set(key, value)
def cache_set_int(key: int, value: int) -> None:
if CURRENT_CONTEXT is None or key & SESSIONLESS_FLAG:
cache.set_int_sessionless(key, value)
return
CURRENT_CONTEXT.cache_set_int(key, value)

@ -7,6 +7,7 @@ if TYPE_CHECKING:
from typing import Container, TypeVar, overload
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
T = TypeVar("T")
class Message:
@ -69,6 +70,24 @@ class Context:
async def write(self, msg: protobuf.MessageType) -> None: ...
if TYPE_CHECKING:
@overload
def cache_get(self, key: int) -> bytes | None: ...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: ...
def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: ...
def cache_get_int(self, key: int, default: T | None = None) -> int | T | None: ...
def cache_is_set(self, key: int) -> bool: ...
def cache_set(self, key: int, value: bytes) -> None: ...
def cache_set_int(self, key: int, value: int) -> None: ...
class WireError(Exception):
pass

@ -3,7 +3,7 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
if TYPE_CHECKING:
from enum import IntEnum
from trezorio import WireInterface
from typing import Protocol
from typing import Protocol, TypeVar, overload
from storage.cache_thp import ChannelCache
from trezor import loop, protobuf, utils
@ -11,6 +11,8 @@ if TYPE_CHECKING:
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.session_context import SessionContext
T = TypeVar("T")
class ChannelContext(Protocol):
buffer: utils.BufferType
iface: WireInterface
@ -40,6 +42,18 @@ if TYPE_CHECKING:
def get_channel_id_int(self) -> int: ...
@overload
def cache_get(self, key: int) -> bytes | None: ...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: ...
def cache_get(self, key: int, default: T | None = None) -> bytes | T | None: ...
def cache_is_set(self, key: int) -> bool: ...
def cache_set(self, key: int, value: bytes) -> None: ...
else:
IntEnum = object

@ -22,7 +22,8 @@ if __debug__:
from . import state_to_str
if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports]
from trezorio import WireInterface
from typing import TypeVar, overload
from . import ChannelContext, PairingContext
from .session_context import SessionContext
@ -173,6 +174,7 @@ class Channel:
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
if __debug__:
log.debug(__name__, "write message: %s", msg.MESSAGE_NAME)
self.buffer = memory_manager.get_write_buffer(self.buffer, msg)
noise_payload_len = memory_manager.encode_into_buffer(
self.buffer, msg, session_id
)
@ -274,3 +276,33 @@ class Channel:
async def _wait_for_ack(self) -> None:
await loop.sleep(1000)
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(
self, key: int, default: T | None = None
) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.channel_cache.fields))
if self.channel_cache.data[key][0] != 1:
return default
return bytes(self.channel_cache.data[key][1:])
def cache_is_set(self, key: int) -> bool:
return self.channel_cache.is_set(key)
def cache_set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.channel_cache.fields))
utils.ensure(len(value) <= self.channel_cache.fields[key])
self.channel_cache.data[key][0] = 1
self.channel_cache.data[key][1:] = value

@ -1,16 +1,34 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire.errors import UnexpectedMessage
from apps.thp import create_session
if TYPE_CHECKING:
from typing import Any, Callable, Coroutine
from trezor.messages import LoadDevice
from . import ChannelContext
pass
def get_handler_for_channel_message(
msg: protobuf.MessageType,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
return create_session.create_new_session
if msg.MESSAGE_WIRE_TYPE is MessageType.ThpCreateNewSession:
return create_session.create_new_session
if __debug__:
if msg.MESSAGE_WIRE_TYPE is MessageType.LoadDevice:
from apps.debug.load_device import load_device
def wrapper(
channel: ChannelContext, msg: LoadDevice
) -> Coroutine[Any, Any, protobuf.MessageType]:
return load_device(msg)
return wrapper
raise UnexpectedMessage("There is no handler available for this message")

@ -45,6 +45,19 @@ def select_buffer(
raise Exception("Failed to create a buffer for channel") # TODO handle better
def get_write_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType
) -> utils.BufferType:
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
return bytearray(required_min_size)
return buffer
def encode_into_buffer(
buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int
) -> int:
@ -54,11 +67,6 @@ def encode_into_buffer(
msg_size = protobuf.encoded_length(msg)
payload_size = SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH + msg_size
required_min_size = payload_size + CHECKSUM_LENGTH + TAG_LENGTH
if required_min_size > len(buffer):
# message is too big, we need to allocate a new buffer
buffer = bytearray(required_min_size)
_encode_session_into_buffer(memoryview(buffer), session_id)
_encode_message_type_into_buffer(

@ -35,6 +35,8 @@ if TYPE_CHECKING:
from . import ChannelContext
if __debug__:
from trezor.messages import LoadDevice
from . import state_to_str
@ -237,7 +239,7 @@ async def _handle_state_TH2(
)
# TODO add credential recognition
paired: bool = True # TODO should be output from credential check
paired: bool = False # TODO should be output from credential check
# send hanshake completion response
await ctx.write_handshake_message(
@ -334,7 +336,7 @@ async def _handle_channel_message(
expected_type = protobuf.type_for_wire(message_type)
message = message_handler.wrap_protobuf_load(buf, expected_type)
if not ThpCreateNewSession.is_type_of(message):
if not _is_channel_message(message):
raise ThpError(
"The received message cannot be handled by channel itself. It must be sent to allocated session."
)
@ -348,3 +350,9 @@ async def _handle_channel_message(
await ctx.write(response_message)
if __debug__:
log.debug(__name__, "_handle_channel_message - end")
def _is_channel_message(message) -> bool:
if __debug__:
return ThpCreateNewSession.is_type_of(message) or LoadDevice.is_type_of(message)
return ThpCreateNewSession.is_type_of(message)

@ -1,7 +1,7 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage.cache_thp import SessionThpCache
from trezor import log, loop, protobuf
from trezor import log, loop, protobuf, utils
from trezor.wire import message_handler, protocol_common
from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure
@ -13,6 +13,8 @@ if TYPE_CHECKING:
Any,
Awaitable,
Container,
TypeVar,
overload,
)
from . import ChannelContext
@ -160,3 +162,33 @@ class SessionContext(Context):
def set_session_state(self, state: SessionState) -> None:
self.session_cache.state = bytearray(state.to_bytes(1, "big"))
# ACCESS TO CACHE
if TYPE_CHECKING:
T = TypeVar("T")
@overload
def cache_get(self, key: int) -> bytes | None: # noqa: F811
...
@overload
def cache_get(self, key: int, default: T) -> bytes | T: # noqa: F811
...
def cache_get(
self, key: int, default: T | None = None
) -> bytes | T | None: # noqa: F811
utils.ensure(key < len(self.session_cache.fields))
if self.session_cache.data[key][0] != 1:
return default
return bytes(self.session_cache.data[key][1:])
def cache_is_set(self, key: int) -> bool:
return self.session_cache.is_set(key)
def cache_set(self, key: int, value: bytes) -> None:
utils.ensure(key < len(self.session_cache.fields))
utils.ensure(len(value) <= self.session_cache.fields[key])
self.session_cache.data[key][0] = 1
self.session_cache.data[key][1:] = value

@ -19,38 +19,23 @@ class ThpError(WireError):
class SessionState(IntEnum):
UNALLOCATED = 0
INITIALIZED = 1 # do not change, is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
INITIALIZED = 1 # do not change, it is denoted as constant in storage.cache _THP_SESSION_STATE_INITIALIZED = 1
PAIRED = 2
UNPAIRED = 3
PAIRING = 4
APP_TRAFFIC = 5
def create_autenticated_session(unauthenticated_session: SessionThpCache):
# storage_thp_cache.start_session() - TODO something like this but for THP
raise NotImplementedError("Secure channel is not implemented, yet.")
def create_new_unauthenticated_session(iface: WireInterface, cid: int):
session_id = _get_id(iface, cid)
new_session = storage_thp_cache.create_new_unauthenticated_session(session_id)
set_session_state(new_session, SessionState.INITIALIZED)
def get_active_session() -> SessionThpCache | None:
return storage_thp_cache.get_active_session()
def get_session(iface: WireInterface, cid: int) -> SessionThpCache | None:
session_id = _get_id(iface, cid)
return get_session_from_id(session_id)
def get_session_from_id(session_id) -> SessionThpCache | None:
session = _get_authenticated_session_or_none(session_id)
if session is None:
session = _get_unauthenticated_session_or_none(session_id)
return session
for session in storage_thp_cache._SESSIONS:
if session.session_id == session_id:
return session
return None
def get_state(session: SessionThpCache | None) -> int:
@ -101,12 +86,6 @@ def sync_set_send_bit_to_opposite(cache: SessionThpCache | ChannelCache) -> None
_sync_set_send_bit(cache=cache, bit=1 - sync_get_send_bit(cache))
def is_active_session(session: SessionThpCache):
if session is None:
return False
return session.session_id == storage_thp_cache.get_active_session_id()
def set_session_state(session: SessionThpCache, new_state: SessionState):
session.state = bytearray(new_state.to_bytes(1, "big"))
@ -115,20 +94,6 @@ def _get_id(iface: WireInterface, cid: int) -> bytes:
return ustruct.pack(">HH", iface.iface_num(), cid)
def _get_authenticated_session_or_none(session_id) -> SessionThpCache | None:
for authenticated_session in storage_thp_cache._SESSIONS:
if authenticated_session.session_id == session_id:
return authenticated_session
return None
def _get_unauthenticated_session_or_none(session_id) -> SessionThpCache | None:
for unauthenticated_session in storage_thp_cache._UNAUTHENTICATED_SESSIONS:
if unauthenticated_session.session_id == session_id:
return unauthenticated_session
return None
def _sync_set_send_bit(cache: SessionThpCache | ChannelCache, bit: int) -> None:
if bit not in (0, 1):
raise ThpError("Unexpected send sync bit")

Loading…
Cancel
Save