refactor(core): move app registrations to a single handler function

apps.webauthn.boot() does not need an if-condition because it's only
called from session.py when the usb interface is enabled

This means that they do not need to be stored in RAM at all. The obvious
drawback is that we need to hand-edit the if/elif sequence, but we don't
register new handlers all that often so 🤷
pull/1610/head
matejcik 3 years ago committed by matejcik
parent 391976bcda
commit e629a72c3a

@ -12,6 +12,8 @@ from trezor.messages.Success import Success
from apps.common import mnemonic, safety_checks
from apps.common.request_pin import verify_user_pin
from . import workflow_handlers
if False:
import protobuf
from typing import Iterable, NoReturn, Protocol
@ -149,7 +151,9 @@ async def handle_DoPreauthorized(
PreauthorizedRequest(), *authorization.expected_wire_types()
)
handler = wire.find_registered_workflow_handler(ctx.iface, req.MESSAGE_WIRE_TYPE)
handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
)
if handler is None:
return wire.unexpected_message()
@ -230,13 +234,13 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
await verify_user_pin(ctx)
set_homescreen()
wire.find_handler = wire.find_registered_workflow_handler
wire.find_handler = workflow_handlers.find_registered_handler
def get_pinlocked_handler(
iface: wire.WireInterface, msg_type: int
) -> wire.Handler[wire.Msg] | None:
orig_handler = wire.find_registered_workflow_handler(iface, msg_type)
orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
if orig_handler is None:
return None
@ -268,13 +272,19 @@ def reload_settings_from_storage() -> None:
def boot() -> None:
wire.register(MessageType.Initialize, handle_Initialize)
wire.register(MessageType.GetFeatures, handle_GetFeatures)
wire.register(MessageType.Cancel, handle_Cancel)
wire.register(MessageType.LockDevice, handle_LockDevice)
wire.register(MessageType.EndSession, handle_EndSession)
wire.register(MessageType.Ping, handle_Ping)
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization)
workflow_handlers.register(MessageType.Initialize, handle_Initialize)
workflow_handlers.register(MessageType.GetFeatures, handle_GetFeatures)
workflow_handlers.register(MessageType.Cancel, handle_Cancel)
workflow_handlers.register(MessageType.LockDevice, handle_LockDevice)
workflow_handlers.register(MessageType.EndSession, handle_EndSession)
workflow_handlers.register(MessageType.Ping, handle_Ping)
workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
workflow_handlers.register(
MessageType.CancelAuthorization, handle_CancelAuthorization
)
reload_settings_from_storage()
if config.is_unlocked():
wire.find_handler = workflow_handlers.find_registered_handler
else:
wire.find_handler = get_pinlocked_handler

@ -1,14 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_BIP44
CURVE = "secp256k1"
SLIP44_ID = 714
PATTERN = PATTERN_BIP44
def boot() -> None:
wire.add(MessageType.BinanceGetAddress, __name__, "get_address")
wire.add(MessageType.BinanceGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.BinanceSignTx, __name__, "sign_tx")

@ -1,13 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
def boot() -> None:
wire.add(MessageType.AuthorizeCoinJoin, __name__, "authorize_coinjoin")
wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
wire.add(MessageType.GetAddress, __name__, "get_address")
wire.add(MessageType.GetOwnershipId, __name__, "get_ownership_id")
wire.add(MessageType.GetOwnershipProof, __name__, "get_ownership_proof")
wire.add(MessageType.SignTx, __name__, "sign_tx")
wire.add(MessageType.SignMessage, __name__, "sign_message")
wire.add(MessageType.VerifyMessage, __name__, "verify_message")

@ -1,10 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
CURVE = "ed25519"
def boot() -> None:
wire.add(MessageType.CardanoGetAddress, __name__, "get_address")
wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.CardanoSignTx, __name__, "sign_tx")

@ -10,6 +10,8 @@ if __debug__:
from trezor import config, crypto, log, loop, utils
from trezor.messages.Success import Success
from apps import workflow_handlers
if False:
from trezor.messages.DebugLinkDecision import DebugLinkDecision
from trezor.messages.DebugLinkGetState import DebugLinkGetState
@ -183,10 +185,19 @@ if __debug__:
if not utils.EMULATOR:
config.wipe()
wire.add(MessageType.LoadDevice, __name__, "load_device")
wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore
wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState)
wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom)
wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen)
wire.register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard)
wire.register(MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout)
workflow_handlers.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore
workflow_handlers.register(
MessageType.DebugLinkGetState, dispatch_DebugLinkGetState
)
workflow_handlers.register(
MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom
)
workflow_handlers.register(
MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen
)
workflow_handlers.register(
MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard
)
workflow_handlers.register(
MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout
)

@ -1,13 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_BIP44
CURVE = "secp256k1"
SLIP44_ID = 194
PATTERN = PATTERN_BIP44
def boot() -> None:
wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.EosSignTx, __name__, "sign_tx")

@ -1,12 +1 @@
from trezor import wire
from trezor.messages import MessageType
CURVE = "secp256k1"
def boot() -> None:
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")

@ -1,16 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519"
SLIP44_ID = 134
PATTERN = PATTERN_SEP5
def boot() -> None:
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.LiskGetAddress, __name__, "get_address")
wire.add(MessageType.LiskSignTx, __name__, "sign_tx")
wire.add(MessageType.LiskSignMessage, __name__, "sign_message")
wire.add(MessageType.LiskVerifyMessage, __name__, "verify_message")

@ -1,16 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
def boot() -> None:
wire.add(MessageType.ResetDevice, __name__, "reset_device")
wire.add(MessageType.BackupDevice, __name__, "backup_device")
wire.add(MessageType.WipeDevice, __name__, "wipe_device")
wire.add(MessageType.RecoveryDevice, __name__, "recovery_device")
wire.add(MessageType.ApplySettings, __name__, "apply_settings")
wire.add(MessageType.ApplyFlags, __name__, "apply_flags")
wire.add(MessageType.ChangePin, __name__, "change_pin")
wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter")
wire.add(MessageType.GetNextU2FCounter, __name__, "get_next_u2f_counter")
wire.add(MessageType.SdProtect, __name__, "sd_protect")
wire.add(MessageType.ChangeWipeCode, __name__, "change_wipe_code")

@ -1,9 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
def boot() -> None:
wire.add(MessageType.GetEntropy, __name__, "get_entropy")
wire.add(MessageType.SignIdentity, __name__, "sign_identity")
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")

@ -1,20 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519"
SLIP44_ID = 128
PATTERN = PATTERN_SEP5
def boot() -> None:
wire.add(MessageType.MoneroGetAddress, __name__, "get_address")
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only")
wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx")
wire.add(MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync")
wire.add(MessageType.MoneroGetTxKeyRequest, __name__, "get_tx_keys")
wire.add(MessageType.MoneroLiveRefreshStartRequest, __name__, "live_refresh")
if __debug__ and hasattr(MessageType, "DebugMoneroDiagRequest"):
wire.add(MessageType.DebugMoneroDiagRequest, __name__, "diag")

@ -1,6 +1,3 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519-keccak"
@ -10,8 +7,3 @@ PATTERNS = (
PATTERN_SEP5,
"m/44'/coin_type'/account'/0'/0'", # NanoWallet compatibility
)
def boot() -> None:
wire.add(MessageType.NEMGetAddress, __name__, "get_address")
wire.add(MessageType.NEMSignTx, __name__, "sign_tx")

@ -1,13 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_BIP44
CURVE = "secp256k1"
SLIP44_ID = 144
PATTERN = PATTERN_BIP44
def boot() -> None:
wire.add(MessageType.RippleGetAddress, __name__, "get_address")
wire.add(MessageType.RippleSignTx, __name__, "sign_tx")

@ -1,13 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519"
SLIP44_ID = 148
PATTERN = PATTERN_SEP5
def boot() -> None:
wire.add(MessageType.StellarGetAddress, __name__, "get_address")
wire.add(MessageType.StellarSignTx, __name__, "sign_tx")

@ -1,6 +1,3 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519"
@ -9,9 +6,3 @@ PATTERNS = (
PATTERN_SEP5,
"m/44'/coin_type'/0'/account'", # Ledger compatibility
)
def boot() -> None:
wire.add(MessageType.TezosGetAddress, __name__, "get_address")
wire.add(MessageType.TezosSignTx, __name__, "sign_tx")
wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key")

@ -1,24 +1,9 @@
from trezor import loop, wire
from trezor.messages import MessageType
from trezor import loop
import usb
from .fido2 import handle_reports
def boot() -> None:
wire.add(
MessageType.WebAuthnListResidentCredentials,
__name__,
"list_resident_credentials",
)
wire.add(
MessageType.WebAuthnAddResidentCredential, __name__, "add_resident_credential"
)
wire.add(
MessageType.WebAuthnRemoveResidentCredential,
__name__,
"remove_resident_credential",
)
import usb
if usb.ENABLE_IFACE_WEBAUTHN:
loop.schedule(handle_reports(usb.iface_webauthn))
loop.schedule(handle_reports(usb.iface_webauthn))

@ -0,0 +1,200 @@
from trezor import utils
from trezor.messages import MessageType
if False:
from trezor.wire import Handler
from trezorio import WireInterface
workflow_handlers: dict[int, Handler] = {}
def register(wire_type: int, handler: Handler) -> None:
"""Register `handler` to get scheduled after `wire_type` message is received."""
workflow_handlers[wire_type] = handler
def find_message_handler_module(msg_type: int) -> str:
"""Statically find the appropriate workflow handler.
For now, new messages must be registered by hand in the if-elif manner below.
The reason for this is memory fragmentation optimization:
- using a dict would mean that the whole thing stays in RAM, whereas an if-elif
sequence is run from flash
- collecting everything as strings instead of importing directly means that we don't
need to load any of the modules into memory until we actually need them
"""
if False:
raise RuntimeError
# debug
elif __debug__ and msg_type == MessageType.LoadDevice:
return "apps.debug.load_device"
# management
elif msg_type == MessageType.ResetDevice:
return "apps.management.reset_device"
elif msg_type == MessageType.BackupDevice:
return "apps.management.backup_device"
elif msg_type == MessageType.WipeDevice:
return "apps.management.wipe_device"
elif msg_type == MessageType.RecoveryDevice:
return "apps.management.recovery_device"
elif msg_type == MessageType.ApplySettings:
return "apps.management.apply_settings"
elif msg_type == MessageType.ApplyFlags:
return "apps.management.apply_flags"
elif msg_type == MessageType.ChangePin:
return "apps.management.change_pin"
elif msg_type == MessageType.SetU2FCounter:
return "apps.management.set_u2f_counter"
elif msg_type == MessageType.GetNextU2FCounter:
return "apps.management.get_next_u2f_counter"
elif msg_type == MessageType.SdProtect:
return "apps.management.sd_protect"
elif msg_type == MessageType.ChangeWipeCode:
return "apps.management.change_wipe_code"
# bitcoin
elif msg_type == MessageType.AuthorizeCoinJoin:
return "apps.bitcoin.authorize_coinjoin"
elif msg_type == MessageType.GetPublicKey:
return "apps.bitcoin.get_public_key"
elif msg_type == MessageType.GetAddress:
return "apps.bitcoin.get_address"
elif msg_type == MessageType.GetOwnershipId:
return "apps.bitcoin.get_ownership_id"
elif msg_type == MessageType.GetOwnershipProof:
return "apps.bitcoin.get_ownership_proof"
elif msg_type == MessageType.SignTx:
return "apps.bitcoin.sign_tx"
elif msg_type == MessageType.SignMessage:
return "apps.bitcoin.sign_message"
elif msg_type == MessageType.VerifyMessage:
return "apps.bitcoin.verify_message"
# misc
elif msg_type == MessageType.GetEntropy:
return "apps.misc.get_entropy"
elif msg_type == MessageType.SignIdentity:
return "apps.misc.sign_identity"
elif msg_type == MessageType.GetECDHSessionKey:
return "apps.misc.get_ecdh_session_key"
elif msg_type == MessageType.CipherKeyValue:
return "apps.misc.cipher_key_value"
elif not utils.BITCOIN_ONLY:
if False:
raise RuntimeError
# webauthn
elif msg_type == MessageType.WebAuthnListResidentCredentials:
return "apps.webauthn.list_resident_credentials"
elif msg_type == MessageType.WebAuthnAddResidentCredential:
return "apps.webauthn.add_resident_credential"
elif msg_type == MessageType.WebAuthnRemoveResidentCredential:
return "apps.webauthn.remove_resident_credential"
# ethereum
elif msg_type == MessageType.EthereumGetAddress:
return "apps.ethereum.get_address"
elif msg_type == MessageType.EthereumGetPublicKey:
return "apps.ethereum.get_public_key"
elif msg_type == MessageType.EthereumSignTx:
return "apps.ethereum.sign_tx"
elif msg_type == MessageType.EthereumSignMessage:
return "apps.ethereum.sign_message"
elif msg_type == MessageType.EthereumVerifyMessage:
return "apps.ethereum.verify_message"
# lisk
elif msg_type == MessageType.LiskGetPublicKey:
return "apps.lisk.get_public_key"
elif msg_type == MessageType.LiskGetAddress:
return "apps.lisk.get_address"
elif msg_type == MessageType.LiskSignTx:
return "apps.lisk.sign_tx"
elif msg_type == MessageType.LiskSignMessage:
return "apps.lisk.sign_message"
elif msg_type == MessageType.LiskVerifyMessage:
return "apps.lisk.verify_message"
# monero
elif msg_type == MessageType.MoneroGetAddress:
return "apps.monero.get_address"
elif msg_type == MessageType.MoneroGetWatchKey:
return "apps.monero.get_watch_only"
elif msg_type == MessageType.MoneroTransactionInitRequest:
return "apps.monero.sign_tx"
elif msg_type == MessageType.MoneroKeyImageExportInitRequest:
return "apps.monero.key_image_sync"
elif msg_type == MessageType.MoneroGetTxKeyRequest:
return "apps.monero.get_tx_keys"
elif msg_type == MessageType.MoneroLiveRefreshStartRequest:
return "apps.monero.live_refresh"
if __debug__ and msg_type == MessageType.DebugMoneroDiagRequest:
return "apps.monero.diag"
# nem
elif msg_type == MessageType.NEMGetAddress:
return "apps.nem.get_address"
elif msg_type == MessageType.NEMSignTx:
return "apps.nem.sign_tx"
# stellar
elif msg_type == MessageType.StellarGetAddress:
return "apps.stellar.get_address"
elif msg_type == MessageType.StellarSignTx:
return "apps.stellar.sign_tx"
# ripple
elif msg_type == MessageType.RippleGetAddress:
return "apps.ripple.get_address"
elif msg_type == MessageType.RippleSignTx:
return "apps.ripple.sign_tx"
# cardano
elif msg_type == MessageType.CardanoGetAddress:
return "apps.cardano.get_address"
elif msg_type == MessageType.CardanoGetPublicKey:
return "apps.cardano.get_public_key"
elif msg_type == MessageType.CardanoSignTx:
return "apps.cardano.sign_tx"
# tezos
elif msg_type == MessageType.TezosGetAddress:
return "apps.tezos.get_address"
elif msg_type == MessageType.TezosSignTx:
return "apps.tezos.sign_tx"
elif msg_type == MessageType.TezosGetPublicKey:
return "apps.tezos.get_public_key"
# eos
elif msg_type == MessageType.EosGetPublicKey:
return "apps.eos.get_public_key"
elif msg_type == MessageType.EosSignTx:
return "apps.eos.sign_tx"
# binance
elif msg_type == MessageType.BinanceGetAddress:
return "apps.binance.get_address"
elif msg_type == MessageType.BinanceGetPublicKey:
return "apps.binance.get_public_key"
elif msg_type == MessageType.BinanceSignTx:
return "apps.binance.sign_tx"
raise ValueError
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
return workflow_handlers[msg_type]
try:
modname = find_message_handler_module(msg_type)
handler_name = modname[modname.rfind(".") + 1 :]
module = __import__(modname, None, None, (handler_name,), 0)
return getattr(module, handler_name) # type: ignore
except ValueError:
return None

@ -63,38 +63,15 @@ if False:
Handler = Callable[["Context", Msg], HandlerTask]
# Maps a wire type directly to a handler.
workflow_handlers: dict[int, Handler] = {}
# Maps a wire type to a tuple of package and module. This allows handlers
# to be dynamically imported when such message arrives.
workflow_packages: dict[int, tuple[str, str]] = {}
# If set to False protobuf messages marked with "unstable" option are rejected.
experimental_enabled: bool = False
def add(wire_type: int, pkgname: str, modname: str) -> None:
"""Shortcut for registering a dynamically-imported Protobuf workflow."""
workflow_packages[wire_type] = (pkgname, modname)
def register(wire_type: int, handler: Handler) -> None:
"""Register `handler` to get scheduled after `wire_type` message is received."""
workflow_handlers[wire_type] = handler
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed USB interface."""
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
def clear() -> None:
"""Remove all registered handlers."""
workflow_handlers.clear()
workflow_packages.clear()
if False:
from typing import Protocol
@ -459,33 +436,12 @@ async def handle_session(
log.exception(__name__, exc)
def find_registered_workflow_handler(
iface: WireInterface, msg_type: int
) -> Handler | None:
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
handler = workflow_handlers[msg_type]
elif msg_type in workflow_packages:
# Message needs a dynamically imported handler, import it.
pkgname, modname = workflow_packages[msg_type]
handler = import_workflow(pkgname, modname)
else:
# Message does not have any registered handler.
return None
return handler
find_handler = find_registered_workflow_handler
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
"""Placeholder handler lookup before a proper one is registered."""
return None
def import_workflow(pkgname: str, modname: str) -> Any:
modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0)
handler = getattr(module, modname)
return handler
find_handler = _find_handler_placeholder
def failure(exc: BaseException) -> Failure:

Loading…
Cancel
Save