1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-08 06:20:56 +00:00

chore(core): update core to reflect cache and context refactor

[no changelog]
This commit is contained in:
M1nd3r 2024-11-15 21:52:30 +01:00
parent 80e4f506ba
commit 6b17af6a55
18 changed files with 165 additions and 106 deletions

View File

@ -1,11 +1,13 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device import storage.device as storage_device
from storage.cache_common import APP_COMMON_BUSY_DEADLINE_MS, APP_COMMON_SEED
from trezor import TR, config, utils, wire, workflow from trezor import TR, config, utils, wire, workflow
from trezor.enums import HomescreenFormat, MessageType from trezor.enums import HomescreenFormat, MessageType
from trezor.messages import Success, UnlockPath from trezor.messages import Success, UnlockPath
from trezor.ui.layouts import confirm_action from trezor.ui.layouts import confirm_action
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
from . import workflow_handlers from . import workflow_handlers
@ -34,7 +36,7 @@ def busy_expiry_ms() -> int:
Returns the time left until the busy state expires or 0 if the device is not in the busy state. Returns the time left until the busy state expires or 0 if the device is not in the busy state.
""" """
busy_deadline_ms = storage_cache.get_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) busy_deadline_ms = context.cache_get_int(APP_COMMON_BUSY_DEADLINE_MS)
if busy_deadline_ms is None: if busy_deadline_ms is None:
return 0 return 0
@ -203,12 +205,15 @@ def get_features() -> Features:
async def handle_Initialize(msg: Initialize) -> Features: async def handle_Initialize(msg: Initialize) -> Features:
session_id = storage_cache.start_session(msg.session_id) import storage.cache_codec as cache_codec
session_id = cache_codec.start_session(msg.session_id)
if not utils.BITCOIN_ONLY: if not utils.BITCOIN_ONLY:
derive_cardano = storage_cache.get_bool(storage_cache.APP_COMMON_DERIVE_CARDANO) from storage.cache_common import APP_COMMON_DERIVE_CARDANO
have_seed = storage_cache.is_set(storage_cache.APP_COMMON_SEED)
derive_cardano = context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
have_seed = context.cache_is_set(APP_COMMON_SEED)
if ( if (
have_seed have_seed
and msg.derive_cardano is not None and msg.derive_cardano is not None
@ -216,14 +221,12 @@ async def handle_Initialize(msg: Initialize) -> Features:
): ):
# seed is already derived, and host wants to change derive_cardano setting # seed is already derived, and host wants to change derive_cardano setting
# => create a new session # => create a new session
storage_cache.end_current_session() cache_codec.end_current_session()
session_id = storage_cache.start_session() session_id = cache_codec.start_session()
have_seed = False have_seed = False
if not have_seed: if not have_seed:
storage_cache.set_bool( context.cache_set_bool(APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano))
storage_cache.APP_COMMON_DERIVE_CARDANO, bool(msg.derive_cardano)
)
features = get_features() features = get_features()
features.session_id = session_id features.session_id = session_id
@ -252,16 +255,17 @@ async def handle_SetBusy(msg: SetBusy) -> Success:
import utime import utime
deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms) deadline = utime.ticks_add(utime.ticks_ms(), msg.expiry_ms)
storage_cache.set_int(storage_cache.APP_COMMON_BUSY_DEADLINE_MS, deadline) context.cache_set_int(APP_COMMON_BUSY_DEADLINE_MS, deadline)
else: else:
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) context.cache_delete(APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen() set_homescreen()
workflow.close_others() workflow.close_others()
return Success() return Success()
async def handle_EndSession(msg: EndSession) -> Success: async def handle_EndSession(msg: EndSession) -> Success:
storage_cache.end_current_session() ctx = context.get_context()
ctx.release()
return Success() return Success()
@ -276,7 +280,7 @@ async def handle_Ping(msg: Ping) -> Success:
async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType: async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
from trezor.messages import PreauthorizedRequest from trezor.messages import PreauthorizedRequest
from trezor.wire.context import call_any, get_context from trezor.wire.context import call_any
from apps.common import authorization from apps.common import authorization
@ -289,11 +293,9 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
req = await call_any(PreauthorizedRequest(), *wire_types) req = await call_any(PreauthorizedRequest(), *wire_types)
assert req.MESSAGE_WIRE_TYPE is not None assert req.MESSAGE_WIRE_TYPE is not None
handler = workflow_handlers.find_registered_handler( handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
get_context().iface, req.MESSAGE_WIRE_TYPE
)
if handler is None: if handler is None:
return wire.unexpected_message() return wire.message_handler.unexpected_message()
return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument] return await handler(req, authorization.get()) # type: ignore [Expected 1 positional argument]
@ -301,7 +303,7 @@ async def handle_DoPreauthorized(msg: DoPreauthorized) -> protobuf.MessageType:
async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType: async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
from trezor.crypto import hmac from trezor.crypto import hmac
from trezor.messages import UnlockedPathRequest from trezor.messages import UnlockedPathRequest
from trezor.wire.context import call_any, get_context from trezor.wire.context import call_any
from apps.common.paths import SLIP25_PURPOSE from apps.common.paths import SLIP25_PURPOSE
from apps.common.seed import Slip21Node, get_seed from apps.common.seed import Slip21Node, get_seed
@ -342,9 +344,7 @@ async def handle_UnlockPath(msg: UnlockPath) -> protobuf.MessageType:
req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types) req = await call_any(UnlockedPathRequest(mac=expected_mac), *wire_types)
assert req.MESSAGE_WIRE_TYPE in wire_types assert req.MESSAGE_WIRE_TYPE in wire_types
handler = workflow_handlers.find_registered_handler( handler = workflow_handlers.find_registered_handler(req.MESSAGE_WIRE_TYPE)
get_context().iface, req.MESSAGE_WIRE_TYPE
)
assert handler is not None assert handler is not None
return await handler(req, msg) # type: ignore [Expected 1 positional argument] return await handler(req, msg) # type: ignore [Expected 1 positional argument]
@ -364,7 +364,7 @@ def set_homescreen() -> None:
set_default = workflow.set_default # local_cache_attribute set_default = workflow.set_default # local_cache_attribute
if storage_cache.is_set(storage_cache.APP_COMMON_BUSY_DEADLINE_MS): if context.cache_is_set(APP_COMMON_BUSY_DEADLINE_MS):
from apps.homescreen import busyscreen from apps.homescreen import busyscreen
set_default(busyscreen) set_default(busyscreen)
@ -393,7 +393,7 @@ def set_homescreen() -> None:
def lock_device(interrupt_workflow: bool = True) -> None: def lock_device(interrupt_workflow: bool = True) -> None:
if config.has_pin(): if config.has_pin():
config.lock() config.lock()
wire.filters.append(_pinlock_filter) filters.append(_pinlock_filter)
set_homescreen() set_homescreen()
if interrupt_workflow: if interrupt_workflow:
workflow.close_others() workflow.close_others()
@ -429,7 +429,7 @@ async def unlock_device() -> None:
_SCREENSAVER_IS_ON = False _SCREENSAVER_IS_ON = False
set_homescreen() set_homescreen()
wire.remove_filter(_pinlock_filter) remove_filter(_pinlock_filter)
def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]: def _pinlock_filter(msg_type: int, prev_handler: Handler[Msg]) -> Handler[Msg]:
@ -450,7 +450,9 @@ def reload_settings_from_storage() -> None:
workflow.idle_timer.set( workflow.idle_timer.set(
storage_device.get_autolock_delay_ms(), lock_device_if_unlocked storage_device.get_autolock_delay_ms(), lock_device_if_unlocked
) )
wire.EXPERIMENTAL_ENABLED = storage_device.get_experimental_features() wire.message_handler.EXPERIMENTAL_ENABLED = (
storage_device.get_experimental_features()
)
if ui.display.orientation() != storage_device.get_rotation(): if ui.display.orientation() != storage_device.get_rotation():
ui.backlight_fade(ui.BacklightLevels.DIM) ui.backlight_fade(ui.BacklightLevels.DIM)
ui.display.orientation(storage_device.get_rotation()) ui.display.orientation(storage_device.get_rotation())
@ -482,4 +484,4 @@ def boot() -> None:
backup.activate_repeated_backup() backup.activate_repeated_backup()
if not config.is_unlocked(): if not config.is_unlocked():
# pinlocked handler should always be the last one # pinlocked handler should always be the last one
wire.filters.append(_pinlock_filter) filters.append(_pinlock_filter)

View File

@ -1,7 +1,7 @@
from micropython import const from micropython import const
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor.wire import DataError from trezor.wire import DataError, context
from .. import writers from .. import writers
@ -26,7 +26,7 @@ class PaymentRequestVerifier:
def __init__( def __init__(
self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain self, msg: TxAckPaymentRequest, coin: coininfo.CoinInfo, keychain: Keychain
) -> None: ) -> None:
from storage import cache from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto.hashlib import sha256 from trezor.crypto.hashlib import sha256
from trezor.utils import HashWriter from trezor.utils import HashWriter
@ -42,9 +42,9 @@ class PaymentRequestVerifier:
if msg.nonce: if msg.nonce:
nonce = bytes(msg.nonce) nonce = bytes(msg.nonce)
if cache.get(cache.APP_COMMON_NONCE) != nonce: if context.cache_get(APP_COMMON_NONCE) != nonce:
raise DataError("Invalid nonce in payment request.") raise DataError("Invalid nonce in payment request.")
cache.delete(cache.APP_COMMON_NONCE) context.cache_delete(APP_COMMON_NONCE)
else: else:
nonce = b"" nonce = b""
if msg.memos: if msg.memos:

View File

@ -1,8 +1,14 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from storage import cache, device import storage.device as device
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_CARDANO_ICARUS_TREZOR_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
from trezor import wire from trezor import wire
from trezor.crypto import cardano from trezor.crypto import cardano
from trezor.wire import context
from apps.common import mnemonic from apps.common import mnemonic
from apps.common.seed import get_seed from apps.common.seed import get_seed
@ -112,7 +118,7 @@ def is_minting_path(path: Bip32Path) -> bool:
def derive_and_store_secrets(passphrase: str) -> None: def derive_and_store_secrets(passphrase: str) -> None:
assert device.is_initialized() assert device.is_initialized()
assert cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO) assert context.cache_get_bool(APP_COMMON_DERIVE_CARDANO)
if not mnemonic.is_bip39(): if not mnemonic.is_bip39():
# nothing to do for SLIP-39, where we can derive the root from the main seed # nothing to do for SLIP-39, where we can derive the root from the main seed
@ -132,8 +138,8 @@ def derive_and_store_secrets(passphrase: str) -> None:
else: else:
icarus_trezor_secret = icarus_secret icarus_trezor_secret = icarus_secret
cache.set(cache.APP_CARDANO_ICARUS_SECRET, icarus_secret) context.cache_set(APP_CARDANO_ICARUS_SECRET, icarus_secret)
cache.set(cache.APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret) context.cache_set(APP_CARDANO_ICARUS_TREZOR_SECRET, icarus_trezor_secret)
async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain: async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychain:
@ -148,19 +154,19 @@ async def _get_keychain_bip39(derivation_type: CardanoDerivationType) -> Keychai
seed = await get_seed() seed = await get_seed()
return Keychain(cardano.from_seed_ledger(seed)) return Keychain(cardano.from_seed_ledger(seed))
if not cache.get_bool(cache.APP_COMMON_DERIVE_CARDANO): if not context.cache_get_bool(APP_COMMON_DERIVE_CARDANO):
raise wire.ProcessError("Cardano derivation is not enabled for this session") raise wire.ProcessError("Cardano derivation is not enabled for this session")
if derivation_type == CardanoDerivationType.ICARUS: if derivation_type == CardanoDerivationType.ICARUS:
cache_entry = cache.APP_CARDANO_ICARUS_SECRET cache_entry = APP_CARDANO_ICARUS_SECRET
else: else:
cache_entry = cache.APP_CARDANO_ICARUS_TREZOR_SECRET cache_entry = APP_CARDANO_ICARUS_TREZOR_SECRET
# _get_secret # _get_secret
secret = cache.get(cache_entry) secret = context.cache_get(cache_entry)
if secret is None: if secret is None:
await derive_and_store_roots() await derive_and_store_roots()
secret = cache.get(cache_entry) secret = context.cache_get(cache_entry)
assert secret is not None assert secret is not None
root = cardano.from_secret(secret) root = cardano.from_secret(secret)

View File

@ -1,23 +1,20 @@
from typing import Iterable from typing import Iterable
import storage.cache as storage_cache from storage.cache_common import (
APP_COMMON_AUTHORIZATION_DATA,
APP_COMMON_AUTHORIZATION_TYPE,
)
from trezor import protobuf from trezor import protobuf
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.wire import context
WIRE_TYPES: dict[int, tuple[int, ...]] = { WIRE_TYPES: dict[int, tuple[int, ...]] = {
MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof), MessageType.AuthorizeCoinJoin: (MessageType.SignTx, MessageType.GetOwnershipProof),
} }
APP_COMMON_AUTHORIZATION_DATA = (
storage_cache.APP_COMMON_AUTHORIZATION_DATA
) # global_import_cache
APP_COMMON_AUTHORIZATION_TYPE = (
storage_cache.APP_COMMON_AUTHORIZATION_TYPE
) # global_import_cache
def is_set() -> bool: def is_set() -> bool:
return storage_cache.get(APP_COMMON_AUTHORIZATION_TYPE) is not None return context.cache_get(APP_COMMON_AUTHORIZATION_TYPE) is not None
def set(auth_message: protobuf.MessageType) -> None: def set(auth_message: protobuf.MessageType) -> None:
@ -29,27 +26,27 @@ def set(auth_message: protobuf.MessageType) -> None:
# (because only wire-level messages have wire_type, which we use as identifier) # (because only wire-level messages have wire_type, which we use as identifier)
ensure(auth_message.MESSAGE_WIRE_TYPE is not None) ensure(auth_message.MESSAGE_WIRE_TYPE is not None)
assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too assert auth_message.MESSAGE_WIRE_TYPE is not None # so that typechecker knows too
storage_cache.set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE) context.cache_set_int(APP_COMMON_AUTHORIZATION_TYPE, auth_message.MESSAGE_WIRE_TYPE)
storage_cache.set(APP_COMMON_AUTHORIZATION_DATA, buffer) context.cache_set(APP_COMMON_AUTHORIZATION_DATA, buffer)
def get() -> protobuf.MessageType | None: def get() -> protobuf.MessageType | None:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if not stored_auth_type: if not stored_auth_type:
return None return None
buffer = storage_cache.get(APP_COMMON_AUTHORIZATION_DATA, b"") buffer = context.cache_get(APP_COMMON_AUTHORIZATION_DATA, b"")
return protobuf.load_message_buffer(buffer, stored_auth_type) return protobuf.load_message_buffer(buffer, stored_auth_type)
def is_set_any_session(auth_type: MessageType) -> bool: def is_set_any_session(auth_type: MessageType) -> bool:
return auth_type in storage_cache.get_int_all_sessions( return auth_type in context.cache_get_int_all_sessions(
APP_COMMON_AUTHORIZATION_TYPE APP_COMMON_AUTHORIZATION_TYPE
) )
def get_wire_types() -> Iterable[int]: def get_wire_types() -> Iterable[int]:
stored_auth_type = storage_cache.get_int(APP_COMMON_AUTHORIZATION_TYPE) stored_auth_type = context.cache_get_int(APP_COMMON_AUTHORIZATION_TYPE)
if stored_auth_type is None: if stored_auth_type is None:
return () return ()
@ -57,5 +54,5 @@ def get_wire_types() -> Iterable[int]:
def clear() -> None: def clear() -> None:
storage_cache.delete(APP_COMMON_AUTHORIZATION_TYPE) context.cache_delete(APP_COMMON_AUTHORIZATION_TYPE)
storage_cache.delete(APP_COMMON_AUTHORIZATION_DATA) context.cache_delete(APP_COMMON_AUTHORIZATION_DATA)

View File

@ -1,25 +1,27 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import storage.cache as storage_cache from storage.cache_common import APP_RECOVERY_REPEATED_BACKUP_UNLOCKED
from trezor import wire from trezor import wire
from trezor.enums import MessageType from trezor.enums import MessageType
from trezor.wire import context
from trezor.wire.message_handler import filters, remove_filter
if TYPE_CHECKING: if TYPE_CHECKING:
from trezor.wire import Handler, Msg from trezor.wire import Handler, Msg
def repeated_backup_enabled() -> bool: def repeated_backup_enabled() -> bool:
return storage_cache.get_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) return context.cache_get_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
def activate_repeated_backup() -> None: def activate_repeated_backup() -> None:
storage_cache.set_bool(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True) context.cache_set_bool(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED, True)
wire.filters.append(_repeated_backup_filter) filters.append(_repeated_backup_filter)
def deactivate_repeated_backup() -> None: def deactivate_repeated_backup() -> None:
storage_cache.delete(storage_cache.APP_RECOVERY_REPEATED_BACKUP_UNLOCKED) context.cache_delete(APP_RECOVERY_REPEATED_BACKUP_UNLOCKED)
wire.remove_filter(_repeated_backup_filter) remove_filter(_repeated_backup_filter)
_ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = ( _ALLOW_WHILE_REPEATED_BACKUP_UNLOCKED = (

View File

@ -0,0 +1,39 @@
from typing import TYPE_CHECKING
from trezor.wire import context
if TYPE_CHECKING:
from typing import Awaitable, Callable, ParamSpec
P = ParamSpec("P")
ByteFunc = Callable[P, bytes]
AsyncByteFunc = Callable[P, Awaitable[bytes]]
def stored(key: int) -> Callable[[ByteFunc[P]], ByteFunc[P]]:
def decorator(func: ByteFunc[P]) -> ByteFunc[P]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = context.cache_get(key)
if value is None:
value = func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator
def stored_async(key: int) -> Callable[[AsyncByteFunc[P]], AsyncByteFunc[P]]:
def decorator(func: AsyncByteFunc[P]) -> AsyncByteFunc[P]:
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> bytes:
value = context.cache_get(key)
if value is None:
value = await func(*args, **kwargs)
context.cache_set(key, value)
return value
return wrapper
return decorator

View File

@ -1,9 +1,10 @@
import utime import utime
from typing import Any, NoReturn from typing import Any, NoReturn
import storage.cache as storage_cache from storage.cache_common import APP_COMMON_REQUEST_PIN_LAST_UNLOCK
from trezor import TR, config, utils, wire from trezor import TR, config, utils, wire
from trezor.ui.layouts import show_error_and_raise from trezor.ui.layouts import show_error_and_raise
from trezor.wire import context
async def _request_sd_salt( async def _request_sd_salt(
@ -77,7 +78,7 @@ async def request_pin_and_sd_salt(
def _set_last_unlock_time() -> None: def _set_last_unlock_time() -> None:
now = utime.ticks_ms() now = utime.ticks_ms()
storage_cache.set_int(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now) context.cache_set_int(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, now)
_DEF_ARG_PIN_ENTER: str = TR.pin__enter _DEF_ARG_PIN_ENTER: str = TR.pin__enter
@ -91,7 +92,7 @@ async def verify_user_pin(
) -> None: ) -> None:
# _get_last_unlock_time # _get_last_unlock_time
last_unlock = int.from_bytes( last_unlock = int.from_bytes(
storage_cache.get(storage_cache.APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big" context.cache_get(APP_COMMON_REQUEST_PIN_LAST_UNLOCK, b""), "big"
) )
if ( if (

View File

@ -1,15 +1,15 @@
import storage.cache as storage_cache
import storage.device as storage_device import storage.device as storage_device
from storage.cache import APP_COMMON_SAFETY_CHECKS_TEMPORARY from storage.cache_common import APP_COMMON_SAFETY_CHECKS_TEMPORARY
from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT from storage.device import SAFETY_CHECK_LEVEL_PROMPT, SAFETY_CHECK_LEVEL_STRICT
from trezor.enums import SafetyCheckLevel from trezor.enums import SafetyCheckLevel
from trezor.wire import context
def read_setting() -> SafetyCheckLevel: def read_setting() -> SafetyCheckLevel:
""" """
Returns the effective safety check level. Returns the effective safety check level.
""" """
temporary_safety_check_level = storage_cache.get(APP_COMMON_SAFETY_CHECKS_TEMPORARY) temporary_safety_check_level = context.cache_get(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
if temporary_safety_check_level: if temporary_safety_check_level:
return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum] return int.from_bytes(temporary_safety_check_level, "big") # type: ignore [int-into-enum]
else: else:
@ -27,14 +27,14 @@ def apply_setting(level: SafetyCheckLevel) -> None:
Changes the safety level settings. Changes the safety level settings.
""" """
if level == SafetyCheckLevel.Strict: if level == SafetyCheckLevel.Strict:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
elif level == SafetyCheckLevel.PromptAlways: elif level == SafetyCheckLevel.PromptAlways:
storage_cache.delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY) context.cache_delete(APP_COMMON_SAFETY_CHECKS_TEMPORARY)
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_PROMPT)
elif level == SafetyCheckLevel.PromptTemporarily: elif level == SafetyCheckLevel.PromptTemporarily:
storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT) storage_device.set_safety_check_level(SAFETY_CHECK_LEVEL_STRICT)
storage_cache.set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big")) context.cache_set(APP_COMMON_SAFETY_CHECKS_TEMPORARY, level.to_bytes(1, "big"))
else: else:
raise ValueError("Unknown SafetyCheckLevel") raise ValueError("Unknown SafetyCheckLevel")

View File

@ -1,9 +1,12 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import storage.cache as storage_cache
import storage.device as storage_device import storage.device as storage_device
from storage.cache_common import APP_COMMON_SEED, APP_COMMON_SEED_WITHOUT_PASSPHRASE
from trezor import utils from trezor import utils
from trezor.crypto import hmac from trezor.crypto import hmac
from trezor.wire import context
from apps.common import cache
from . import mnemonic from . import mnemonic
from .passphrase import get as get_passphrase from .passphrase import get as get_passphrase
@ -13,6 +16,12 @@ if TYPE_CHECKING:
from .paths import Bip32Path, Slip21Path from .paths import Bip32Path, Slip21Path
if not utils.BITCOIN_ONLY:
from storage.cache_common import (
APP_CARDANO_ICARUS_SECRET,
APP_COMMON_DERIVE_CARDANO,
)
class Slip21Node: class Slip21Node:
""" """
@ -56,10 +65,10 @@ if not utils.BITCOIN_ONLY:
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise wire.NotInitialized("Device is not initialized") raise wire.NotInitialized("Device is not initialized")
need_seed = not storage_cache.is_set(storage_cache.APP_COMMON_SEED) need_seed = not context.cache_is_set(APP_COMMON_SEED)
need_cardano_secret = storage_cache.get_bool( need_cardano_secret = context.cache_get_bool(
storage_cache.APP_COMMON_DERIVE_CARDANO APP_COMMON_DERIVE_CARDANO
) and not storage_cache.is_set(storage_cache.APP_CARDANO_ICARUS_SECRET) ) and not context.cache_is_set(APP_CARDANO_ICARUS_SECRET)
if not need_seed and not need_cardano_secret: if not need_seed and not need_cardano_secret:
return return
@ -68,17 +77,17 @@ if not utils.BITCOIN_ONLY:
if need_seed: if need_seed:
common_seed = mnemonic.get_seed(passphrase) common_seed = mnemonic.get_seed(passphrase)
storage_cache.set(storage_cache.APP_COMMON_SEED, common_seed) context.cache_set(APP_COMMON_SEED, common_seed)
if need_cardano_secret: if need_cardano_secret:
from apps.cardano.seed import derive_and_store_secrets from apps.cardano.seed import derive_and_store_secrets
derive_and_store_secrets(passphrase) derive_and_store_secrets(passphrase)
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED) @cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes: async def get_seed() -> bytes:
await derive_and_store_roots() await derive_and_store_roots()
common_seed = storage_cache.get(storage_cache.APP_COMMON_SEED) common_seed = context.cache_get(APP_COMMON_SEED)
assert common_seed is not None assert common_seed is not None
return common_seed return common_seed
@ -86,13 +95,13 @@ else:
# === Bitcoin-only variant === # === Bitcoin-only variant ===
# We use the simple version of `get_seed` that never needs to derive anything else. # We use the simple version of `get_seed` that never needs to derive anything else.
@storage_cache.stored_async(storage_cache.APP_COMMON_SEED) @cache.stored_async(APP_COMMON_SEED)
async def get_seed() -> bytes: async def get_seed() -> bytes:
passphrase = await get_passphrase() passphrase = await get_passphrase()
return mnemonic.get_seed(passphrase) return mnemonic.get_seed(passphrase)
@storage_cache.stored(storage_cache.APP_COMMON_SEED_WITHOUT_PASSPHRASE) @cache.stored(APP_COMMON_SEED_WITHOUT_PASSPHRASE)
def _get_seed_without_passphrase() -> bytes: def _get_seed_without_passphrase() -> bytes:
if not storage_device.is_initialized(): if not storage_device.is_initialized():
raise Exception("Device is not initialized") raise Exception("Device is not initialized")

View File

@ -71,7 +71,7 @@ if __debug__:
) )
async def return_layout_change( async def return_layout_change(
ctx: wire.context.Context, detect_deadlock: bool = False ctx: wire.protocol_common.Context, detect_deadlock: bool = False
) -> None: ) -> None:
# set up the wait # set up the wait
storage.layout_watcher = True storage.layout_watcher = True
@ -356,11 +356,12 @@ if __debug__:
async def handle_session(iface: WireInterface) -> None: async def handle_session(iface: WireInterface) -> None:
from trezor import protobuf, wire from trezor import protobuf, wire
from trezor.wire import codec_v1, context from trezor.wire.codec import codec_v1
from trezor.wire.codec.codec_context import CodecContext
global DEBUG_CONTEXT global DEBUG_CONTEXT
DEBUG_CONTEXT = ctx = context.Context(iface, WIRE_BUFFER_DEBUG) DEBUG_CONTEXT = ctx = CodecContext(iface, WIRE_BUFFER_DEBUG)
if storage.layout_watcher: if storage.layout_watcher:
try: try:
@ -391,7 +392,7 @@ if __debug__:
) )
if msg.type not in WORKFLOW_HANDLERS: if msg.type not in WORKFLOW_HANDLERS:
await ctx.write(wire.unexpected_message()) await ctx.write(wire.message_handler.unexpected_message())
continue continue
elif req_type is None: elif req_type is None:
@ -402,7 +403,7 @@ if __debug__:
await ctx.write(Success()) await ctx.write(Success())
continue continue
req_msg = wire.wrap_protobuf_load(msg.data, req_type) req_msg = wire.message_handler.wrap_protobuf_load(msg.data, req_type)
try: try:
res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg) res_msg = await WORKFLOW_HANDLERS[msg.type](req_msg)
except Exception as exc: except Exception as exc:

View File

@ -5,10 +5,11 @@ if TYPE_CHECKING:
async def get_nonce(msg: GetNonce) -> Nonce: async def get_nonce(msg: GetNonce) -> Nonce:
from storage import cache from storage.cache_common import APP_COMMON_NONCE
from trezor.crypto import random from trezor.crypto import random
from trezor.messages import Nonce from trezor.messages import Nonce
from trezor.wire import context
nonce = random.bytes(32) nonce = random.bytes(32)
cache.set(cache.APP_COMMON_NONCE, nonce) context.cache_set(APP_COMMON_NONCE, nonce)
return Nonce(nonce=nonce) return Nonce(nonce=nonce)

View File

@ -38,7 +38,7 @@ async def recovery_process() -> Success:
recovery_type = storage_recovery.get_type() recovery_type = storage_recovery.get_type()
wire.AVOID_RESTARTING_FOR = ( wire.message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize, MessageType.Initialize,
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,
@ -59,7 +59,7 @@ async def _continue_repeated_backup() -> None:
from apps.common import backup from apps.common import backup
from apps.management.backup_device import perform_backup from apps.management.backup_device import perform_backup
wire.AVOID_RESTARTING_FOR = ( wire.message_handler.AVOID_RESTARTING_FOR = (
MessageType.Initialize, MessageType.Initialize,
MessageType.GetFeatures, MessageType.GetFeatures,
MessageType.EndSession, MessageType.EndSession,

View File

@ -57,16 +57,17 @@ async def _init_step(
msg: MoneroLiveRefreshStartRequest, msg: MoneroLiveRefreshStartRequest,
keychain: Keychain, keychain: Keychain,
) -> MoneroLiveRefreshStartAck: ) -> MoneroLiveRefreshStartAck:
import storage.cache as storage_cache from storage.cache_common import APP_MONERO_LIVE_REFRESH
from trezor.messages import MoneroLiveRefreshStartAck from trezor.messages import MoneroLiveRefreshStartAck
from trezor.wire import context
from apps.common import paths from apps.common import paths
await paths.validate_path(keychain, msg.address_n) await paths.validate_path(keychain, msg.address_n)
if not storage_cache.get_bool(storage_cache.APP_MONERO_LIVE_REFRESH): if not context.cache_get_bool(APP_MONERO_LIVE_REFRESH):
await layout.require_confirm_live_refresh() await layout.require_confirm_live_refresh()
storage_cache.set_bool(storage_cache.APP_MONERO_LIVE_REFRESH, True) context.cache_set_bool(APP_MONERO_LIVE_REFRESH, True)
s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type) s.creds = misc.get_creds(keychain, msg.address_n, msg.network_type)

View File

@ -7,7 +7,7 @@ from trezor.messages import (
ThpCredentialMetadata, ThpCredentialMetadata,
ThpPairingCredential, ThpPairingCredential,
) )
from trezor.wire import wrap_protobuf_load from trezor.wire.message_handler import wrap_protobuf_load
if TYPE_CHECKING: if TYPE_CHECKING:
from apps.common.paths import Slip21Path from apps.common.paths import Slip21Path

View File

@ -1,8 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface
from trezor.wire import Handler, Msg from trezor.wire import Handler, Msg
@ -215,7 +213,7 @@ def _find_message_handler_module(msg_type: int) -> str:
raise ValueError raise ValueError
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None: def find_registered_handler(msg_type: int) -> Handler | None:
if msg_type in workflow_handlers: if msg_type in workflow_handlers:
# Message has a handler available, return it directly. # Message has a handler available, return it directly.
return workflow_handlers[msg_type] return workflow_handlers[msg_type]

View File

@ -332,7 +332,7 @@ class Layout(Generic[T]):
def _paint(self) -> None: def _paint(self) -> None:
"""Paint the layout and ensure that homescreen cache is properly invalidated.""" """Paint the layout and ensure that homescreen cache is properly invalidated."""
import storage.cache as storage_cache import storage.cache_common as storage_cache
painted = self.layout.paint() painted = self.layout.paint()
if painted: if painted:

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import storage.cache as storage_cache import storage.cache_common as storage_cache
import trezorui2 import trezorui2
from trezor import TR, ui from trezor import TR, ui
@ -125,11 +125,13 @@ class Busyscreen(HomescreenBase):
) )
async def get_result(self) -> Any: async def get_result(self) -> Any:
from trezor.wire import context
from apps.base import set_homescreen from apps.base import set_homescreen
# Handle timeout. # Handle timeout.
result = await super().get_result() result = await super().get_result()
assert result == trezorui2.CANCELLED assert result == trezorui2.CANCELLED
storage_cache.delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS) context.cache_delete(storage_cache.APP_COMMON_BUSY_DEADLINE_MS)
set_homescreen() set_homescreen()
return result return result

View File

@ -1,7 +1,7 @@
import utime import utime
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import storage.cache import storage.cache_common as cache_common
from trezor import log, loop from trezor import log, loop
from trezor.enums import MessageType from trezor.enums import MessageType
@ -153,7 +153,7 @@ def close_others() -> None:
if not task.is_running(): if not task.is_running():
task.close() task.close()
storage.cache.homescreen_shown = None cache_common.homescreen_shown = None
# if tasks were running, closing the last of them will run start_default # if tasks were running, closing the last of them will run start_default
@ -211,11 +211,11 @@ class IdleTimer:
time and saves it to storage.cache. This is done to avoid losing an time and saves it to storage.cache. This is done to avoid losing an
active timer when workflow restart happens and tasks are lost. active timer when workflow restart happens and tasks are lost.
""" """
if _restore_from_cache and storage.cache.autolock_last_touch is not None: if _restore_from_cache and cache_common.autolock_last_touch is not None:
now = storage.cache.autolock_last_touch now = cache_common.autolock_last_touch
else: else:
now = utime.ticks_ms() now = utime.ticks_ms()
storage.cache.autolock_last_touch = now cache_common.autolock_last_touch = now
for callback, task in self.tasks.items(): for callback, task in self.tasks.items():
timeout_us = self.timeouts[callback] timeout_us = self.timeouts[callback]