mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-23 06:48:16 +00:00
refactor(core): move app registrations to a single handler function
apps.webauthn.boot() does not need an if-condition because it's only
called from session.py when the usb interface is enabled
This means that they do not need to be stored in RAM at all. The obvious
drawback is that we need to hand-edit the if/elif sequence, but we don't
register new handlers all that often so 🤷
This commit is contained in:
parent
391976bcda
commit
e629a72c3a
@ -12,6 +12,8 @@ from trezor.messages.Success import Success
|
||||
from apps.common import mnemonic, safety_checks
|
||||
from apps.common.request_pin import verify_user_pin
|
||||
|
||||
from . import workflow_handlers
|
||||
|
||||
if False:
|
||||
import protobuf
|
||||
from typing import Iterable, NoReturn, Protocol
|
||||
@ -149,7 +151,9 @@ async def handle_DoPreauthorized(
|
||||
PreauthorizedRequest(), *authorization.expected_wire_types()
|
||||
)
|
||||
|
||||
handler = wire.find_registered_workflow_handler(ctx.iface, req.MESSAGE_WIRE_TYPE)
|
||||
handler = workflow_handlers.find_registered_handler(
|
||||
ctx.iface, req.MESSAGE_WIRE_TYPE
|
||||
)
|
||||
if handler is None:
|
||||
return wire.unexpected_message()
|
||||
|
||||
@ -230,13 +234,13 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
|
||||
await verify_user_pin(ctx)
|
||||
|
||||
set_homescreen()
|
||||
wire.find_handler = wire.find_registered_workflow_handler
|
||||
wire.find_handler = workflow_handlers.find_registered_handler
|
||||
|
||||
|
||||
def get_pinlocked_handler(
|
||||
iface: wire.WireInterface, msg_type: int
|
||||
) -> wire.Handler[wire.Msg] | None:
|
||||
orig_handler = wire.find_registered_workflow_handler(iface, msg_type)
|
||||
orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
|
||||
if orig_handler is None:
|
||||
return None
|
||||
|
||||
@ -268,13 +272,19 @@ def reload_settings_from_storage() -> None:
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.register(MessageType.Initialize, handle_Initialize)
|
||||
wire.register(MessageType.GetFeatures, handle_GetFeatures)
|
||||
wire.register(MessageType.Cancel, handle_Cancel)
|
||||
wire.register(MessageType.LockDevice, handle_LockDevice)
|
||||
wire.register(MessageType.EndSession, handle_EndSession)
|
||||
wire.register(MessageType.Ping, handle_Ping)
|
||||
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
|
||||
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)
|
||||
workflow_handlers.register(MessageType.Initialize, handle_Initialize)
|
||||
workflow_handlers.register(MessageType.GetFeatures, handle_GetFeatures)
|
||||
workflow_handlers.register(MessageType.Cancel, handle_Cancel)
|
||||
workflow_handlers.register(MessageType.LockDevice, handle_LockDevice)
|
||||
workflow_handlers.register(MessageType.EndSession, handle_EndSession)
|
||||
workflow_handlers.register(MessageType.Ping, handle_Ping)
|
||||
workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
|
||||
workflow_handlers.register(
|
||||
MessageType.CancelAuthorization, handle_CancelAuthorization
|
||||
)
|
||||
|
||||
reload_settings_from_storage()
|
||||
if config.is_unlocked():
|
||||
wire.find_handler = workflow_handlers.find_registered_handler
|
||||
else:
|
||||
wire.find_handler = get_pinlocked_handler
|
||||
|
@ -1,14 +1,5 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_BIP44
|
||||
|
||||
CURVE = "secp256k1"
|
||||
SLIP44_ID = 714
|
||||
PATTERN = PATTERN_BIP44
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.BinanceGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.BinanceGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.BinanceSignTx, __name__, "sign_tx")
|
||||
|
@ -1,13 +0,0 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.AuthorizeCoinJoin, __name__, "authorize_coinjoin")
|
||||
wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.GetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.GetOwnershipId, __name__, "get_ownership_id")
|
||||
wire.add(MessageType.GetOwnershipProof, __name__, "get_ownership_proof")
|
||||
wire.add(MessageType.SignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.SignMessage, __name__, "sign_message")
|
||||
wire.add(MessageType.VerifyMessage, __name__, "verify_message")
|
@ -1,10 +0,0 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
CURVE = "ed25519"
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.CardanoGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.CardanoSignTx, __name__, "sign_tx")
|
@ -10,6 +10,8 @@ if __debug__:
|
||||
from trezor import config, crypto, log, loop, utils
|
||||
from trezor.messages.Success import Success
|
||||
|
||||
from apps import workflow_handlers
|
||||
|
||||
if False:
|
||||
from trezor.messages.DebugLinkDecision import DebugLinkDecision
|
||||
from trezor.messages.DebugLinkGetState import DebugLinkGetState
|
||||
@ -183,10 +185,19 @@ if __debug__:
|
||||
if not utils.EMULATOR:
|
||||
config.wipe()
|
||||
|
||||
wire.add(MessageType.LoadDevice, __name__, "load_device")
|
||||
wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore
|
||||
wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState)
|
||||
wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom)
|
||||
wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen)
|
||||
wire.register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard)
|
||||
wire.register(MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout)
|
||||
workflow_handlers.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore
|
||||
workflow_handlers.register(
|
||||
MessageType.DebugLinkGetState, dispatch_DebugLinkGetState
|
||||
)
|
||||
workflow_handlers.register(
|
||||
MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom
|
||||
)
|
||||
workflow_handlers.register(
|
||||
MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen
|
||||
)
|
||||
workflow_handlers.register(
|
||||
MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard
|
||||
)
|
||||
workflow_handlers.register(
|
||||
MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout
|
||||
)
|
||||
|
@ -1,13 +1,5 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_BIP44
|
||||
|
||||
CURVE = "secp256k1"
|
||||
SLIP44_ID = 194
|
||||
PATTERN = PATTERN_BIP44
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.EosSignTx, __name__, "sign_tx")
|
||||
|
@ -1,12 +1 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
CURVE = "secp256k1"
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
|
||||
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")
|
||||
|
@ -1,16 +1,5 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_SEP5
|
||||
|
||||
CURVE = "ed25519"
|
||||
SLIP44_ID = 134
|
||||
PATTERN = PATTERN_SEP5
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key")
|
||||
wire.add(MessageType.LiskGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.LiskSignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.LiskSignMessage, __name__, "sign_message")
|
||||
wire.add(MessageType.LiskVerifyMessage, __name__, "verify_message")
|
||||
|
@ -1,16 +0,0 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.ResetDevice, __name__, "reset_device")
|
||||
wire.add(MessageType.BackupDevice, __name__, "backup_device")
|
||||
wire.add(MessageType.WipeDevice, __name__, "wipe_device")
|
||||
wire.add(MessageType.RecoveryDevice, __name__, "recovery_device")
|
||||
wire.add(MessageType.ApplySettings, __name__, "apply_settings")
|
||||
wire.add(MessageType.ApplyFlags, __name__, "apply_flags")
|
||||
wire.add(MessageType.ChangePin, __name__, "change_pin")
|
||||
wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter")
|
||||
wire.add(MessageType.GetNextU2FCounter, __name__, "get_next_u2f_counter")
|
||||
wire.add(MessageType.SdProtect, __name__, "sd_protect")
|
||||
wire.add(MessageType.ChangeWipeCode, __name__, "change_wipe_code")
|
@ -1,9 +0,0 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.GetEntropy, __name__, "get_entropy")
|
||||
wire.add(MessageType.SignIdentity, __name__, "sign_identity")
|
||||
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
|
||||
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")
|
@ -1,20 +1,5 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_SEP5
|
||||
|
||||
CURVE = "ed25519"
|
||||
SLIP44_ID = 128
|
||||
PATTERN = PATTERN_SEP5
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.MoneroGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only")
|
||||
wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx")
|
||||
wire.add(MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync")
|
||||
wire.add(MessageType.MoneroGetTxKeyRequest, __name__, "get_tx_keys")
|
||||
wire.add(MessageType.MoneroLiveRefreshStartRequest, __name__, "live_refresh")
|
||||
|
||||
if __debug__ and hasattr(MessageType, "DebugMoneroDiagRequest"):
|
||||
wire.add(MessageType.DebugMoneroDiagRequest, __name__, "diag")
|
||||
|
@ -1,6 +1,3 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_SEP5
|
||||
|
||||
CURVE = "ed25519-keccak"
|
||||
@ -10,8 +7,3 @@ PATTERNS = (
|
||||
PATTERN_SEP5,
|
||||
"m/44'/coin_type'/account'/0'/0'", # NanoWallet compatibility
|
||||
)
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.NEMGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.NEMSignTx, __name__, "sign_tx")
|
||||
|
@ -1,13 +1,5 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_BIP44
|
||||
|
||||
CURVE = "secp256k1"
|
||||
SLIP44_ID = 144
|
||||
PATTERN = PATTERN_BIP44
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.RippleGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.RippleSignTx, __name__, "sign_tx")
|
||||
|
@ -1,13 +1,5 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_SEP5
|
||||
|
||||
CURVE = "ed25519"
|
||||
SLIP44_ID = 148
|
||||
PATTERN = PATTERN_SEP5
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.StellarGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.StellarSignTx, __name__, "sign_tx")
|
||||
|
@ -1,6 +1,3 @@
|
||||
from trezor import wire
|
||||
from trezor.messages import MessageType
|
||||
|
||||
from apps.common.paths import PATTERN_SEP5
|
||||
|
||||
CURVE = "ed25519"
|
||||
@ -9,9 +6,3 @@ PATTERNS = (
|
||||
PATTERN_SEP5,
|
||||
"m/44'/coin_type'/0'/account'", # Ledger compatibility
|
||||
)
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(MessageType.TezosGetAddress, __name__, "get_address")
|
||||
wire.add(MessageType.TezosSignTx, __name__, "sign_tx")
|
||||
wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key")
|
||||
|
@ -1,24 +1,9 @@
|
||||
from trezor import loop, wire
|
||||
from trezor.messages import MessageType
|
||||
from trezor import loop
|
||||
|
||||
import usb
|
||||
|
||||
from .fido2 import handle_reports
|
||||
|
||||
|
||||
def boot() -> None:
|
||||
wire.add(
|
||||
MessageType.WebAuthnListResidentCredentials,
|
||||
__name__,
|
||||
"list_resident_credentials",
|
||||
)
|
||||
wire.add(
|
||||
MessageType.WebAuthnAddResidentCredential, __name__, "add_resident_credential"
|
||||
)
|
||||
wire.add(
|
||||
MessageType.WebAuthnRemoveResidentCredential,
|
||||
__name__,
|
||||
"remove_resident_credential",
|
||||
)
|
||||
import usb
|
||||
|
||||
if usb.ENABLE_IFACE_WEBAUTHN:
|
||||
loop.schedule(handle_reports(usb.iface_webauthn))
|
||||
loop.schedule(handle_reports(usb.iface_webauthn))
|
||||
|
200
core/src/apps/workflow_handlers.py
Normal file
200
core/src/apps/workflow_handlers.py
Normal file
@ -0,0 +1,200 @@
|
||||
from trezor import utils
|
||||
from trezor.messages import MessageType
|
||||
|
||||
if False:
|
||||
from trezor.wire import Handler
|
||||
from trezorio import WireInterface
|
||||
|
||||
|
||||
workflow_handlers: dict[int, Handler] = {}
|
||||
|
||||
|
||||
def register(wire_type: int, handler: Handler) -> None:
|
||||
"""Register `handler` to get scheduled after `wire_type` message is received."""
|
||||
workflow_handlers[wire_type] = handler
|
||||
|
||||
|
||||
def find_message_handler_module(msg_type: int) -> str:
|
||||
"""Statically find the appropriate workflow handler.
|
||||
|
||||
For now, new messages must be registered by hand in the if-elif manner below.
|
||||
The reason for this is memory fragmentation optimization:
|
||||
- using a dict would mean that the whole thing stays in RAM, whereas an if-elif
|
||||
sequence is run from flash
|
||||
- collecting everything as strings instead of importing directly means that we don't
|
||||
need to load any of the modules into memory until we actually need them
|
||||
"""
|
||||
if False:
|
||||
raise RuntimeError
|
||||
|
||||
# debug
|
||||
elif __debug__ and msg_type == MessageType.LoadDevice:
|
||||
return "apps.debug.load_device"
|
||||
|
||||
# management
|
||||
elif msg_type == MessageType.ResetDevice:
|
||||
return "apps.management.reset_device"
|
||||
elif msg_type == MessageType.BackupDevice:
|
||||
return "apps.management.backup_device"
|
||||
elif msg_type == MessageType.WipeDevice:
|
||||
return "apps.management.wipe_device"
|
||||
elif msg_type == MessageType.RecoveryDevice:
|
||||
return "apps.management.recovery_device"
|
||||
elif msg_type == MessageType.ApplySettings:
|
||||
return "apps.management.apply_settings"
|
||||
elif msg_type == MessageType.ApplyFlags:
|
||||
return "apps.management.apply_flags"
|
||||
elif msg_type == MessageType.ChangePin:
|
||||
return "apps.management.change_pin"
|
||||
elif msg_type == MessageType.SetU2FCounter:
|
||||
return "apps.management.set_u2f_counter"
|
||||
elif msg_type == MessageType.GetNextU2FCounter:
|
||||
return "apps.management.get_next_u2f_counter"
|
||||
elif msg_type == MessageType.SdProtect:
|
||||
return "apps.management.sd_protect"
|
||||
elif msg_type == MessageType.ChangeWipeCode:
|
||||
return "apps.management.change_wipe_code"
|
||||
|
||||
# bitcoin
|
||||
elif msg_type == MessageType.AuthorizeCoinJoin:
|
||||
return "apps.bitcoin.authorize_coinjoin"
|
||||
elif msg_type == MessageType.GetPublicKey:
|
||||
return "apps.bitcoin.get_public_key"
|
||||
elif msg_type == MessageType.GetAddress:
|
||||
return "apps.bitcoin.get_address"
|
||||
elif msg_type == MessageType.GetOwnershipId:
|
||||
return "apps.bitcoin.get_ownership_id"
|
||||
elif msg_type == MessageType.GetOwnershipProof:
|
||||
return "apps.bitcoin.get_ownership_proof"
|
||||
elif msg_type == MessageType.SignTx:
|
||||
return "apps.bitcoin.sign_tx"
|
||||
elif msg_type == MessageType.SignMessage:
|
||||
return "apps.bitcoin.sign_message"
|
||||
elif msg_type == MessageType.VerifyMessage:
|
||||
return "apps.bitcoin.verify_message"
|
||||
|
||||
# misc
|
||||
elif msg_type == MessageType.GetEntropy:
|
||||
return "apps.misc.get_entropy"
|
||||
elif msg_type == MessageType.SignIdentity:
|
||||
return "apps.misc.sign_identity"
|
||||
elif msg_type == MessageType.GetECDHSessionKey:
|
||||
return "apps.misc.get_ecdh_session_key"
|
||||
elif msg_type == MessageType.CipherKeyValue:
|
||||
return "apps.misc.cipher_key_value"
|
||||
|
||||
elif not utils.BITCOIN_ONLY:
|
||||
if False:
|
||||
raise RuntimeError
|
||||
|
||||
# webauthn
|
||||
elif msg_type == MessageType.WebAuthnListResidentCredentials:
|
||||
return "apps.webauthn.list_resident_credentials"
|
||||
elif msg_type == MessageType.WebAuthnAddResidentCredential:
|
||||
return "apps.webauthn.add_resident_credential"
|
||||
elif msg_type == MessageType.WebAuthnRemoveResidentCredential:
|
||||
return "apps.webauthn.remove_resident_credential"
|
||||
|
||||
# ethereum
|
||||
elif msg_type == MessageType.EthereumGetAddress:
|
||||
return "apps.ethereum.get_address"
|
||||
elif msg_type == MessageType.EthereumGetPublicKey:
|
||||
return "apps.ethereum.get_public_key"
|
||||
elif msg_type == MessageType.EthereumSignTx:
|
||||
return "apps.ethereum.sign_tx"
|
||||
elif msg_type == MessageType.EthereumSignMessage:
|
||||
return "apps.ethereum.sign_message"
|
||||
elif msg_type == MessageType.EthereumVerifyMessage:
|
||||
return "apps.ethereum.verify_message"
|
||||
|
||||
# lisk
|
||||
elif msg_type == MessageType.LiskGetPublicKey:
|
||||
return "apps.lisk.get_public_key"
|
||||
elif msg_type == MessageType.LiskGetAddress:
|
||||
return "apps.lisk.get_address"
|
||||
elif msg_type == MessageType.LiskSignTx:
|
||||
return "apps.lisk.sign_tx"
|
||||
elif msg_type == MessageType.LiskSignMessage:
|
||||
return "apps.lisk.sign_message"
|
||||
elif msg_type == MessageType.LiskVerifyMessage:
|
||||
return "apps.lisk.verify_message"
|
||||
|
||||
# monero
|
||||
elif msg_type == MessageType.MoneroGetAddress:
|
||||
return "apps.monero.get_address"
|
||||
elif msg_type == MessageType.MoneroGetWatchKey:
|
||||
return "apps.monero.get_watch_only"
|
||||
elif msg_type == MessageType.MoneroTransactionInitRequest:
|
||||
return "apps.monero.sign_tx"
|
||||
elif msg_type == MessageType.MoneroKeyImageExportInitRequest:
|
||||
return "apps.monero.key_image_sync"
|
||||
elif msg_type == MessageType.MoneroGetTxKeyRequest:
|
||||
return "apps.monero.get_tx_keys"
|
||||
elif msg_type == MessageType.MoneroLiveRefreshStartRequest:
|
||||
return "apps.monero.live_refresh"
|
||||
if __debug__ and msg_type == MessageType.DebugMoneroDiagRequest:
|
||||
return "apps.monero.diag"
|
||||
|
||||
# nem
|
||||
elif msg_type == MessageType.NEMGetAddress:
|
||||
return "apps.nem.get_address"
|
||||
elif msg_type == MessageType.NEMSignTx:
|
||||
return "apps.nem.sign_tx"
|
||||
|
||||
# stellar
|
||||
elif msg_type == MessageType.StellarGetAddress:
|
||||
return "apps.stellar.get_address"
|
||||
elif msg_type == MessageType.StellarSignTx:
|
||||
return "apps.stellar.sign_tx"
|
||||
|
||||
# ripple
|
||||
elif msg_type == MessageType.RippleGetAddress:
|
||||
return "apps.ripple.get_address"
|
||||
elif msg_type == MessageType.RippleSignTx:
|
||||
return "apps.ripple.sign_tx"
|
||||
|
||||
# cardano
|
||||
elif msg_type == MessageType.CardanoGetAddress:
|
||||
return "apps.cardano.get_address"
|
||||
elif msg_type == MessageType.CardanoGetPublicKey:
|
||||
return "apps.cardano.get_public_key"
|
||||
elif msg_type == MessageType.CardanoSignTx:
|
||||
return "apps.cardano.sign_tx"
|
||||
|
||||
# tezos
|
||||
elif msg_type == MessageType.TezosGetAddress:
|
||||
return "apps.tezos.get_address"
|
||||
elif msg_type == MessageType.TezosSignTx:
|
||||
return "apps.tezos.sign_tx"
|
||||
elif msg_type == MessageType.TezosGetPublicKey:
|
||||
return "apps.tezos.get_public_key"
|
||||
|
||||
# eos
|
||||
elif msg_type == MessageType.EosGetPublicKey:
|
||||
return "apps.eos.get_public_key"
|
||||
elif msg_type == MessageType.EosSignTx:
|
||||
return "apps.eos.sign_tx"
|
||||
|
||||
# binance
|
||||
elif msg_type == MessageType.BinanceGetAddress:
|
||||
return "apps.binance.get_address"
|
||||
elif msg_type == MessageType.BinanceGetPublicKey:
|
||||
return "apps.binance.get_public_key"
|
||||
elif msg_type == MessageType.BinanceSignTx:
|
||||
return "apps.binance.sign_tx"
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
if msg_type in workflow_handlers:
|
||||
# Message has a handler available, return it directly.
|
||||
return workflow_handlers[msg_type]
|
||||
|
||||
try:
|
||||
modname = find_message_handler_module(msg_type)
|
||||
handler_name = modname[modname.rfind(".") + 1 :]
|
||||
module = __import__(modname, None, None, (handler_name,), 0)
|
||||
return getattr(module, handler_name) # type: ignore
|
||||
except ValueError:
|
||||
return None
|
@ -63,38 +63,15 @@ if False:
|
||||
Handler = Callable[["Context", Msg], HandlerTask]
|
||||
|
||||
|
||||
# Maps a wire type directly to a handler.
|
||||
workflow_handlers: dict[int, Handler] = {}
|
||||
|
||||
# Maps a wire type to a tuple of package and module. This allows handlers
|
||||
# to be dynamically imported when such message arrives.
|
||||
workflow_packages: dict[int, tuple[str, str]] = {}
|
||||
|
||||
# If set to False protobuf messages marked with "unstable" option are rejected.
|
||||
experimental_enabled: bool = False
|
||||
|
||||
|
||||
def add(wire_type: int, pkgname: str, modname: str) -> None:
|
||||
"""Shortcut for registering a dynamically-imported Protobuf workflow."""
|
||||
workflow_packages[wire_type] = (pkgname, modname)
|
||||
|
||||
|
||||
def register(wire_type: int, handler: Handler) -> None:
|
||||
"""Register `handler` to get scheduled after `wire_type` message is received."""
|
||||
workflow_handlers[wire_type] = handler
|
||||
|
||||
|
||||
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
"""Initialize the wire stack on passed USB interface."""
|
||||
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
|
||||
|
||||
|
||||
def clear() -> None:
|
||||
"""Remove all registered handlers."""
|
||||
workflow_handlers.clear()
|
||||
workflow_packages.clear()
|
||||
|
||||
|
||||
if False:
|
||||
from typing import Protocol
|
||||
|
||||
@ -459,33 +436,12 @@ async def handle_session(
|
||||
log.exception(__name__, exc)
|
||||
|
||||
|
||||
def find_registered_workflow_handler(
|
||||
iface: WireInterface, msg_type: int
|
||||
) -> Handler | None:
|
||||
if msg_type in workflow_handlers:
|
||||
# Message has a handler available, return it directly.
|
||||
handler = workflow_handlers[msg_type]
|
||||
|
||||
elif msg_type in workflow_packages:
|
||||
# Message needs a dynamically imported handler, import it.
|
||||
pkgname, modname = workflow_packages[msg_type]
|
||||
handler = import_workflow(pkgname, modname)
|
||||
|
||||
else:
|
||||
# Message does not have any registered handler.
|
||||
return None
|
||||
|
||||
return handler
|
||||
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
"""Placeholder handler lookup before a proper one is registered."""
|
||||
return None
|
||||
|
||||
|
||||
find_handler = find_registered_workflow_handler
|
||||
|
||||
|
||||
def import_workflow(pkgname: str, modname: str) -> Any:
|
||||
modpath = "%s.%s" % (pkgname, modname)
|
||||
module = __import__(modpath, None, None, (modname,), 0)
|
||||
handler = getattr(module, modname)
|
||||
return handler
|
||||
find_handler = _find_handler_placeholder
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
|
Loading…
Reference in New Issue
Block a user