From e629a72c3a4ebe7ed048f32857e003dc7005096b Mon Sep 17 00:00:00 2001 From: matejcik Date: Fri, 19 Mar 2021 16:41:05 +0100 Subject: [PATCH] refactor(core): move app registrations to a single handler function apps.webauthn.boot() does not need an if-condition because it's only called from session.py when the usb interface is enabled This means that they do not need to be stored in RAM at all. The obvious drawback is that we need to hand-edit the if/elif sequence, but we don't register new handlers all that often so :shrug: --- core/src/apps/base.py | 32 +++-- core/src/apps/binance/__init__.py | 9 -- core/src/apps/bitcoin/__init__.py | 13 -- core/src/apps/cardano/__init__.py | 10 -- core/src/apps/debug/__init__.py | 25 +++- core/src/apps/eos/__init__.py | 8 -- core/src/apps/ethereum/__init__.py | 11 -- core/src/apps/lisk/__init__.py | 11 -- core/src/apps/management/__init__.py | 16 --- core/src/apps/misc/__init__.py | 9 -- core/src/apps/monero/__init__.py | 15 -- core/src/apps/nem/__init__.py | 8 -- core/src/apps/ripple/__init__.py | 8 -- core/src/apps/stellar/__init__.py | 8 -- core/src/apps/tezos/__init__.py | 9 -- core/src/apps/webauthn/__init__.py | 23 +-- core/src/apps/workflow_handlers.py | 200 +++++++++++++++++++++++++++ core/src/trezor/wire/__init__.py | 52 +------ 18 files changed, 247 insertions(+), 220 deletions(-) create mode 100644 core/src/apps/workflow_handlers.py diff --git a/core/src/apps/base.py b/core/src/apps/base.py index f9fe4f951..42aacc16a 100644 --- a/core/src/apps/base.py +++ b/core/src/apps/base.py @@ -12,6 +12,8 @@ from trezor.messages.Success import Success from apps.common import mnemonic, safety_checks from apps.common.request_pin import verify_user_pin +from . import workflow_handlers + if False: import protobuf from typing import Iterable, NoReturn, Protocol @@ -149,7 +151,9 @@ async def handle_DoPreauthorized( PreauthorizedRequest(), *authorization.expected_wire_types() ) - handler = wire.find_registered_workflow_handler(ctx.iface, req.MESSAGE_WIRE_TYPE) + handler = workflow_handlers.find_registered_handler( + ctx.iface, req.MESSAGE_WIRE_TYPE + ) if handler is None: return wire.unexpected_message() @@ -230,13 +234,13 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None: await verify_user_pin(ctx) set_homescreen() - wire.find_handler = wire.find_registered_workflow_handler + wire.find_handler = workflow_handlers.find_registered_handler def get_pinlocked_handler( iface: wire.WireInterface, msg_type: int ) -> wire.Handler[wire.Msg] | None: - orig_handler = wire.find_registered_workflow_handler(iface, msg_type) + orig_handler = workflow_handlers.find_registered_handler(iface, msg_type) if orig_handler is None: return None @@ -268,13 +272,19 @@ def reload_settings_from_storage() -> None: def boot() -> None: - wire.register(MessageType.Initialize, handle_Initialize) - wire.register(MessageType.GetFeatures, handle_GetFeatures) - wire.register(MessageType.Cancel, handle_Cancel) - wire.register(MessageType.LockDevice, handle_LockDevice) - wire.register(MessageType.EndSession, handle_EndSession) - wire.register(MessageType.Ping, handle_Ping) - wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized) - wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization) + workflow_handlers.register(MessageType.Initialize, handle_Initialize) + workflow_handlers.register(MessageType.GetFeatures, handle_GetFeatures) + workflow_handlers.register(MessageType.Cancel, handle_Cancel) + workflow_handlers.register(MessageType.LockDevice, handle_LockDevice) + workflow_handlers.register(MessageType.EndSession, handle_EndSession) + workflow_handlers.register(MessageType.Ping, handle_Ping) + workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized) + workflow_handlers.register( + MessageType.CancelAuthorization, handle_CancelAuthorization + ) reload_settings_from_storage() + if config.is_unlocked(): + wire.find_handler = workflow_handlers.find_registered_handler + else: + wire.find_handler = get_pinlocked_handler diff --git a/core/src/apps/binance/__init__.py b/core/src/apps/binance/__init__.py index d701b4b72..3947ffdab 100644 --- a/core/src/apps/binance/__init__.py +++ b/core/src/apps/binance/__init__.py @@ -1,14 +1,5 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_BIP44 CURVE = "secp256k1" SLIP44_ID = 714 PATTERN = PATTERN_BIP44 - - -def boot() -> None: - wire.add(MessageType.BinanceGetAddress, __name__, "get_address") - wire.add(MessageType.BinanceGetPublicKey, __name__, "get_public_key") - wire.add(MessageType.BinanceSignTx, __name__, "sign_tx") diff --git a/core/src/apps/bitcoin/__init__.py b/core/src/apps/bitcoin/__init__.py index b55370252..e69de29bb 100644 --- a/core/src/apps/bitcoin/__init__.py +++ b/core/src/apps/bitcoin/__init__.py @@ -1,13 +0,0 @@ -from trezor import wire -from trezor.messages import MessageType - - -def boot() -> None: - wire.add(MessageType.AuthorizeCoinJoin, __name__, "authorize_coinjoin") - wire.add(MessageType.GetPublicKey, __name__, "get_public_key") - wire.add(MessageType.GetAddress, __name__, "get_address") - wire.add(MessageType.GetOwnershipId, __name__, "get_ownership_id") - wire.add(MessageType.GetOwnershipProof, __name__, "get_ownership_proof") - wire.add(MessageType.SignTx, __name__, "sign_tx") - wire.add(MessageType.SignMessage, __name__, "sign_message") - wire.add(MessageType.VerifyMessage, __name__, "verify_message") diff --git a/core/src/apps/cardano/__init__.py b/core/src/apps/cardano/__init__.py index 0ecb24c58..e69de29bb 100644 --- a/core/src/apps/cardano/__init__.py +++ b/core/src/apps/cardano/__init__.py @@ -1,10 +0,0 @@ -from trezor import wire -from trezor.messages import MessageType - -CURVE = "ed25519" - - -def boot() -> None: - wire.add(MessageType.CardanoGetAddress, __name__, "get_address") - wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key") - wire.add(MessageType.CardanoSignTx, __name__, "sign_tx") diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 8732209a2..b32b305c5 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -10,6 +10,8 @@ if __debug__: from trezor import config, crypto, log, loop, utils from trezor.messages.Success import Success + from apps import workflow_handlers + if False: from trezor.messages.DebugLinkDecision import DebugLinkDecision from trezor.messages.DebugLinkGetState import DebugLinkGetState @@ -183,10 +185,19 @@ if __debug__: if not utils.EMULATOR: config.wipe() - wire.add(MessageType.LoadDevice, __name__, "load_device") - wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore - wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) - wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom) - wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen) - wire.register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard) - wire.register(MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout) + workflow_handlers.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore + workflow_handlers.register( + MessageType.DebugLinkGetState, dispatch_DebugLinkGetState + ) + workflow_handlers.register( + MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom + ) + workflow_handlers.register( + MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen + ) + workflow_handlers.register( + MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard + ) + workflow_handlers.register( + MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout + ) diff --git a/core/src/apps/eos/__init__.py b/core/src/apps/eos/__init__.py index 6f9288323..281255758 100644 --- a/core/src/apps/eos/__init__.py +++ b/core/src/apps/eos/__init__.py @@ -1,13 +1,5 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_BIP44 CURVE = "secp256k1" SLIP44_ID = 194 PATTERN = PATTERN_BIP44 - - -def boot() -> None: - wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key") - wire.add(MessageType.EosSignTx, __name__, "sign_tx") diff --git a/core/src/apps/ethereum/__init__.py b/core/src/apps/ethereum/__init__.py index 1ce060d24..aef1e750b 100644 --- a/core/src/apps/ethereum/__init__.py +++ b/core/src/apps/ethereum/__init__.py @@ -1,12 +1 @@ -from trezor import wire -from trezor.messages import MessageType - CURVE = "secp256k1" - - -def boot() -> None: - wire.add(MessageType.EthereumGetAddress, __name__, "get_address") - wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key") - wire.add(MessageType.EthereumSignTx, __name__, "sign_tx") - wire.add(MessageType.EthereumSignMessage, __name__, "sign_message") - wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message") diff --git a/core/src/apps/lisk/__init__.py b/core/src/apps/lisk/__init__.py index dd7b2e7e2..31b94dddb 100644 --- a/core/src/apps/lisk/__init__.py +++ b/core/src/apps/lisk/__init__.py @@ -1,16 +1,5 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_SEP5 CURVE = "ed25519" SLIP44_ID = 134 PATTERN = PATTERN_SEP5 - - -def boot() -> None: - wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key") - wire.add(MessageType.LiskGetAddress, __name__, "get_address") - wire.add(MessageType.LiskSignTx, __name__, "sign_tx") - wire.add(MessageType.LiskSignMessage, __name__, "sign_message") - wire.add(MessageType.LiskVerifyMessage, __name__, "verify_message") diff --git a/core/src/apps/management/__init__.py b/core/src/apps/management/__init__.py index 255bea63a..e69de29bb 100644 --- a/core/src/apps/management/__init__.py +++ b/core/src/apps/management/__init__.py @@ -1,16 +0,0 @@ -from trezor import wire -from trezor.messages import MessageType - - -def boot() -> None: - wire.add(MessageType.ResetDevice, __name__, "reset_device") - wire.add(MessageType.BackupDevice, __name__, "backup_device") - wire.add(MessageType.WipeDevice, __name__, "wipe_device") - wire.add(MessageType.RecoveryDevice, __name__, "recovery_device") - wire.add(MessageType.ApplySettings, __name__, "apply_settings") - wire.add(MessageType.ApplyFlags, __name__, "apply_flags") - wire.add(MessageType.ChangePin, __name__, "change_pin") - wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter") - wire.add(MessageType.GetNextU2FCounter, __name__, "get_next_u2f_counter") - wire.add(MessageType.SdProtect, __name__, "sd_protect") - wire.add(MessageType.ChangeWipeCode, __name__, "change_wipe_code") diff --git a/core/src/apps/misc/__init__.py b/core/src/apps/misc/__init__.py index a086c6f53..e69de29bb 100644 --- a/core/src/apps/misc/__init__.py +++ b/core/src/apps/misc/__init__.py @@ -1,9 +0,0 @@ -from trezor import wire -from trezor.messages import MessageType - - -def boot() -> None: - wire.add(MessageType.GetEntropy, __name__, "get_entropy") - wire.add(MessageType.SignIdentity, __name__, "sign_identity") - wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key") - wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value") diff --git a/core/src/apps/monero/__init__.py b/core/src/apps/monero/__init__.py index 1cfff37bb..153dc3154 100644 --- a/core/src/apps/monero/__init__.py +++ b/core/src/apps/monero/__init__.py @@ -1,20 +1,5 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_SEP5 CURVE = "ed25519" SLIP44_ID = 128 PATTERN = PATTERN_SEP5 - - -def boot() -> None: - wire.add(MessageType.MoneroGetAddress, __name__, "get_address") - wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only") - wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx") - wire.add(MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync") - wire.add(MessageType.MoneroGetTxKeyRequest, __name__, "get_tx_keys") - wire.add(MessageType.MoneroLiveRefreshStartRequest, __name__, "live_refresh") - - if __debug__ and hasattr(MessageType, "DebugMoneroDiagRequest"): - wire.add(MessageType.DebugMoneroDiagRequest, __name__, "diag") diff --git a/core/src/apps/nem/__init__.py b/core/src/apps/nem/__init__.py index 880fd8b59..4d7aa3b54 100644 --- a/core/src/apps/nem/__init__.py +++ b/core/src/apps/nem/__init__.py @@ -1,6 +1,3 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_SEP5 CURVE = "ed25519-keccak" @@ -10,8 +7,3 @@ PATTERNS = ( PATTERN_SEP5, "m/44'/coin_type'/account'/0'/0'", # NanoWallet compatibility ) - - -def boot() -> None: - wire.add(MessageType.NEMGetAddress, __name__, "get_address") - wire.add(MessageType.NEMSignTx, __name__, "sign_tx") diff --git a/core/src/apps/ripple/__init__.py b/core/src/apps/ripple/__init__.py index f2816366e..634415c47 100644 --- a/core/src/apps/ripple/__init__.py +++ b/core/src/apps/ripple/__init__.py @@ -1,13 +1,5 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_BIP44 CURVE = "secp256k1" SLIP44_ID = 144 PATTERN = PATTERN_BIP44 - - -def boot() -> None: - wire.add(MessageType.RippleGetAddress, __name__, "get_address") - wire.add(MessageType.RippleSignTx, __name__, "sign_tx") diff --git a/core/src/apps/stellar/__init__.py b/core/src/apps/stellar/__init__.py index 3677f8779..f81e40fb5 100644 --- a/core/src/apps/stellar/__init__.py +++ b/core/src/apps/stellar/__init__.py @@ -1,13 +1,5 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_SEP5 CURVE = "ed25519" SLIP44_ID = 148 PATTERN = PATTERN_SEP5 - - -def boot() -> None: - wire.add(MessageType.StellarGetAddress, __name__, "get_address") - wire.add(MessageType.StellarSignTx, __name__, "sign_tx") diff --git a/core/src/apps/tezos/__init__.py b/core/src/apps/tezos/__init__.py index 0ca63354f..79bd1683d 100644 --- a/core/src/apps/tezos/__init__.py +++ b/core/src/apps/tezos/__init__.py @@ -1,6 +1,3 @@ -from trezor import wire -from trezor.messages import MessageType - from apps.common.paths import PATTERN_SEP5 CURVE = "ed25519" @@ -9,9 +6,3 @@ PATTERNS = ( PATTERN_SEP5, "m/44'/coin_type'/0'/account'", # Ledger compatibility ) - - -def boot() -> None: - wire.add(MessageType.TezosGetAddress, __name__, "get_address") - wire.add(MessageType.TezosSignTx, __name__, "sign_tx") - wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key") diff --git a/core/src/apps/webauthn/__init__.py b/core/src/apps/webauthn/__init__.py index ee3f7a7dc..401ce0d61 100644 --- a/core/src/apps/webauthn/__init__.py +++ b/core/src/apps/webauthn/__init__.py @@ -1,24 +1,9 @@ -from trezor import loop, wire -from trezor.messages import MessageType +from trezor import loop + +import usb from .fido2 import handle_reports def boot() -> None: - wire.add( - MessageType.WebAuthnListResidentCredentials, - __name__, - "list_resident_credentials", - ) - wire.add( - MessageType.WebAuthnAddResidentCredential, __name__, "add_resident_credential" - ) - wire.add( - MessageType.WebAuthnRemoveResidentCredential, - __name__, - "remove_resident_credential", - ) - import usb - - if usb.ENABLE_IFACE_WEBAUTHN: - loop.schedule(handle_reports(usb.iface_webauthn)) + loop.schedule(handle_reports(usb.iface_webauthn)) diff --git a/core/src/apps/workflow_handlers.py b/core/src/apps/workflow_handlers.py new file mode 100644 index 000000000..f7b9e4b17 --- /dev/null +++ b/core/src/apps/workflow_handlers.py @@ -0,0 +1,200 @@ +from trezor import utils +from trezor.messages import MessageType + +if False: + from trezor.wire import Handler + from trezorio import WireInterface + + +workflow_handlers: dict[int, Handler] = {} + + +def register(wire_type: int, handler: Handler) -> None: + """Register `handler` to get scheduled after `wire_type` message is received.""" + workflow_handlers[wire_type] = handler + + +def find_message_handler_module(msg_type: int) -> str: + """Statically find the appropriate workflow handler. + + For now, new messages must be registered by hand in the if-elif manner below. + The reason for this is memory fragmentation optimization: + - using a dict would mean that the whole thing stays in RAM, whereas an if-elif + sequence is run from flash + - collecting everything as strings instead of importing directly means that we don't + need to load any of the modules into memory until we actually need them + """ + if False: + raise RuntimeError + + # debug + elif __debug__ and msg_type == MessageType.LoadDevice: + return "apps.debug.load_device" + + # management + elif msg_type == MessageType.ResetDevice: + return "apps.management.reset_device" + elif msg_type == MessageType.BackupDevice: + return "apps.management.backup_device" + elif msg_type == MessageType.WipeDevice: + return "apps.management.wipe_device" + elif msg_type == MessageType.RecoveryDevice: + return "apps.management.recovery_device" + elif msg_type == MessageType.ApplySettings: + return "apps.management.apply_settings" + elif msg_type == MessageType.ApplyFlags: + return "apps.management.apply_flags" + elif msg_type == MessageType.ChangePin: + return "apps.management.change_pin" + elif msg_type == MessageType.SetU2FCounter: + return "apps.management.set_u2f_counter" + elif msg_type == MessageType.GetNextU2FCounter: + return "apps.management.get_next_u2f_counter" + elif msg_type == MessageType.SdProtect: + return "apps.management.sd_protect" + elif msg_type == MessageType.ChangeWipeCode: + return "apps.management.change_wipe_code" + + # bitcoin + elif msg_type == MessageType.AuthorizeCoinJoin: + return "apps.bitcoin.authorize_coinjoin" + elif msg_type == MessageType.GetPublicKey: + return "apps.bitcoin.get_public_key" + elif msg_type == MessageType.GetAddress: + return "apps.bitcoin.get_address" + elif msg_type == MessageType.GetOwnershipId: + return "apps.bitcoin.get_ownership_id" + elif msg_type == MessageType.GetOwnershipProof: + return "apps.bitcoin.get_ownership_proof" + elif msg_type == MessageType.SignTx: + return "apps.bitcoin.sign_tx" + elif msg_type == MessageType.SignMessage: + return "apps.bitcoin.sign_message" + elif msg_type == MessageType.VerifyMessage: + return "apps.bitcoin.verify_message" + + # misc + elif msg_type == MessageType.GetEntropy: + return "apps.misc.get_entropy" + elif msg_type == MessageType.SignIdentity: + return "apps.misc.sign_identity" + elif msg_type == MessageType.GetECDHSessionKey: + return "apps.misc.get_ecdh_session_key" + elif msg_type == MessageType.CipherKeyValue: + return "apps.misc.cipher_key_value" + + elif not utils.BITCOIN_ONLY: + if False: + raise RuntimeError + + # webauthn + elif msg_type == MessageType.WebAuthnListResidentCredentials: + return "apps.webauthn.list_resident_credentials" + elif msg_type == MessageType.WebAuthnAddResidentCredential: + return "apps.webauthn.add_resident_credential" + elif msg_type == MessageType.WebAuthnRemoveResidentCredential: + return "apps.webauthn.remove_resident_credential" + + # ethereum + elif msg_type == MessageType.EthereumGetAddress: + return "apps.ethereum.get_address" + elif msg_type == MessageType.EthereumGetPublicKey: + return "apps.ethereum.get_public_key" + elif msg_type == MessageType.EthereumSignTx: + return "apps.ethereum.sign_tx" + elif msg_type == MessageType.EthereumSignMessage: + return "apps.ethereum.sign_message" + elif msg_type == MessageType.EthereumVerifyMessage: + return "apps.ethereum.verify_message" + + # lisk + elif msg_type == MessageType.LiskGetPublicKey: + return "apps.lisk.get_public_key" + elif msg_type == MessageType.LiskGetAddress: + return "apps.lisk.get_address" + elif msg_type == MessageType.LiskSignTx: + return "apps.lisk.sign_tx" + elif msg_type == MessageType.LiskSignMessage: + return "apps.lisk.sign_message" + elif msg_type == MessageType.LiskVerifyMessage: + return "apps.lisk.verify_message" + + # monero + elif msg_type == MessageType.MoneroGetAddress: + return "apps.monero.get_address" + elif msg_type == MessageType.MoneroGetWatchKey: + return "apps.monero.get_watch_only" + elif msg_type == MessageType.MoneroTransactionInitRequest: + return "apps.monero.sign_tx" + elif msg_type == MessageType.MoneroKeyImageExportInitRequest: + return "apps.monero.key_image_sync" + elif msg_type == MessageType.MoneroGetTxKeyRequest: + return "apps.monero.get_tx_keys" + elif msg_type == MessageType.MoneroLiveRefreshStartRequest: + return "apps.monero.live_refresh" + if __debug__ and msg_type == MessageType.DebugMoneroDiagRequest: + return "apps.monero.diag" + + # nem + elif msg_type == MessageType.NEMGetAddress: + return "apps.nem.get_address" + elif msg_type == MessageType.NEMSignTx: + return "apps.nem.sign_tx" + + # stellar + elif msg_type == MessageType.StellarGetAddress: + return "apps.stellar.get_address" + elif msg_type == MessageType.StellarSignTx: + return "apps.stellar.sign_tx" + + # ripple + elif msg_type == MessageType.RippleGetAddress: + return "apps.ripple.get_address" + elif msg_type == MessageType.RippleSignTx: + return "apps.ripple.sign_tx" + + # cardano + elif msg_type == MessageType.CardanoGetAddress: + return "apps.cardano.get_address" + elif msg_type == MessageType.CardanoGetPublicKey: + return "apps.cardano.get_public_key" + elif msg_type == MessageType.CardanoSignTx: + return "apps.cardano.sign_tx" + + # tezos + elif msg_type == MessageType.TezosGetAddress: + return "apps.tezos.get_address" + elif msg_type == MessageType.TezosSignTx: + return "apps.tezos.sign_tx" + elif msg_type == MessageType.TezosGetPublicKey: + return "apps.tezos.get_public_key" + + # eos + elif msg_type == MessageType.EosGetPublicKey: + return "apps.eos.get_public_key" + elif msg_type == MessageType.EosSignTx: + return "apps.eos.sign_tx" + + # binance + elif msg_type == MessageType.BinanceGetAddress: + return "apps.binance.get_address" + elif msg_type == MessageType.BinanceGetPublicKey: + return "apps.binance.get_public_key" + elif msg_type == MessageType.BinanceSignTx: + return "apps.binance.sign_tx" + + raise ValueError + + +def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None: + if msg_type in workflow_handlers: + # Message has a handler available, return it directly. + return workflow_handlers[msg_type] + + try: + modname = find_message_handler_module(msg_type) + handler_name = modname[modname.rfind(".") + 1 :] + module = __import__(modname, None, None, (handler_name,), 0) + return getattr(module, handler_name) # type: ignore + except ValueError: + return None diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 90de32ba9..35db4fa0a 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -63,38 +63,15 @@ if False: Handler = Callable[["Context", Msg], HandlerTask] -# Maps a wire type directly to a handler. -workflow_handlers: dict[int, Handler] = {} - -# Maps a wire type to a tuple of package and module. This allows handlers -# to be dynamically imported when such message arrives. -workflow_packages: dict[int, tuple[str, str]] = {} - # If set to False protobuf messages marked with "unstable" option are rejected. experimental_enabled: bool = False -def add(wire_type: int, pkgname: str, modname: str) -> None: - """Shortcut for registering a dynamically-imported Protobuf workflow.""" - workflow_packages[wire_type] = (pkgname, modname) - - -def register(wire_type: int, handler: Handler) -> None: - """Register `handler` to get scheduled after `wire_type` message is received.""" - workflow_handlers[wire_type] = handler - - def setup(iface: WireInterface, is_debug_session: bool = False) -> None: """Initialize the wire stack on passed USB interface.""" loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) -def clear() -> None: - """Remove all registered handlers.""" - workflow_handlers.clear() - workflow_packages.clear() - - if False: from typing import Protocol @@ -459,33 +436,12 @@ async def handle_session( log.exception(__name__, exc) -def find_registered_workflow_handler( - iface: WireInterface, msg_type: int -) -> Handler | None: - if msg_type in workflow_handlers: - # Message has a handler available, return it directly. - handler = workflow_handlers[msg_type] - - elif msg_type in workflow_packages: - # Message needs a dynamically imported handler, import it. - pkgname, modname = workflow_packages[msg_type] - handler = import_workflow(pkgname, modname) - - else: - # Message does not have any registered handler. - return None - - return handler - - -find_handler = find_registered_workflow_handler +def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: + """Placeholder handler lookup before a proper one is registered.""" + return None -def import_workflow(pkgname: str, modname: str) -> Any: - modpath = "%s.%s" % (pkgname, modname) - module = __import__(modpath, None, None, (modname,), 0) - handler = getattr(module, modname) - return handler +find_handler = _find_handler_placeholder def failure(exc: BaseException) -> Failure: