refactor(core): move app registrations to a single handler function

apps.webauthn.boot() does not need an if-condition because it's only
called from session.py when the usb interface is enabled

This means that they do not need to be stored in RAM at all. The obvious
drawback is that we need to hand-edit the if/elif sequence, but we don't
register new handlers all that often so 🤷
pull/1610/head
matejcik 3 years ago committed by matejcik
parent 391976bcda
commit e629a72c3a

@ -12,6 +12,8 @@ from trezor.messages.Success import Success
from apps.common import mnemonic, safety_checks from apps.common import mnemonic, safety_checks
from apps.common.request_pin import verify_user_pin from apps.common.request_pin import verify_user_pin
from . import workflow_handlers
if False: if False:
import protobuf import protobuf
from typing import Iterable, NoReturn, Protocol from typing import Iterable, NoReturn, Protocol
@ -149,7 +151,9 @@ async def handle_DoPreauthorized(
PreauthorizedRequest(), *authorization.expected_wire_types() PreauthorizedRequest(), *authorization.expected_wire_types()
) )
handler = wire.find_registered_workflow_handler(ctx.iface, req.MESSAGE_WIRE_TYPE) handler = workflow_handlers.find_registered_handler(
ctx.iface, req.MESSAGE_WIRE_TYPE
)
if handler is None: if handler is None:
return wire.unexpected_message() return wire.unexpected_message()
@ -230,13 +234,13 @@ async def unlock_device(ctx: wire.GenericContext = wire.DUMMY_CONTEXT) -> None:
await verify_user_pin(ctx) await verify_user_pin(ctx)
set_homescreen() set_homescreen()
wire.find_handler = wire.find_registered_workflow_handler wire.find_handler = workflow_handlers.find_registered_handler
def get_pinlocked_handler( def get_pinlocked_handler(
iface: wire.WireInterface, msg_type: int iface: wire.WireInterface, msg_type: int
) -> wire.Handler[wire.Msg] | None: ) -> wire.Handler[wire.Msg] | None:
orig_handler = wire.find_registered_workflow_handler(iface, msg_type) orig_handler = workflow_handlers.find_registered_handler(iface, msg_type)
if orig_handler is None: if orig_handler is None:
return None return None
@ -268,13 +272,19 @@ def reload_settings_from_storage() -> None:
def boot() -> None: def boot() -> None:
wire.register(MessageType.Initialize, handle_Initialize) workflow_handlers.register(MessageType.Initialize, handle_Initialize)
wire.register(MessageType.GetFeatures, handle_GetFeatures) workflow_handlers.register(MessageType.GetFeatures, handle_GetFeatures)
wire.register(MessageType.Cancel, handle_Cancel) workflow_handlers.register(MessageType.Cancel, handle_Cancel)
wire.register(MessageType.LockDevice, handle_LockDevice) workflow_handlers.register(MessageType.LockDevice, handle_LockDevice)
wire.register(MessageType.EndSession, handle_EndSession) workflow_handlers.register(MessageType.EndSession, handle_EndSession)
wire.register(MessageType.Ping, handle_Ping) workflow_handlers.register(MessageType.Ping, handle_Ping)
wire.register(MessageType.DoPreauthorized, handle_DoPreauthorized) workflow_handlers.register(MessageType.DoPreauthorized, handle_DoPreauthorized)
wire.register(MessageType.CancelAuthorization, handle_CancelAuthorization) workflow_handlers.register(
MessageType.CancelAuthorization, handle_CancelAuthorization
)
reload_settings_from_storage() reload_settings_from_storage()
if config.is_unlocked():
wire.find_handler = workflow_handlers.find_registered_handler
else:
wire.find_handler = get_pinlocked_handler

@ -1,14 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_BIP44 from apps.common.paths import PATTERN_BIP44
CURVE = "secp256k1" CURVE = "secp256k1"
SLIP44_ID = 714 SLIP44_ID = 714
PATTERN = PATTERN_BIP44 PATTERN = PATTERN_BIP44
def boot() -> None:
wire.add(MessageType.BinanceGetAddress, __name__, "get_address")
wire.add(MessageType.BinanceGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.BinanceSignTx, __name__, "sign_tx")

@ -1,13 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
def boot() -> None:
wire.add(MessageType.AuthorizeCoinJoin, __name__, "authorize_coinjoin")
wire.add(MessageType.GetPublicKey, __name__, "get_public_key")
wire.add(MessageType.GetAddress, __name__, "get_address")
wire.add(MessageType.GetOwnershipId, __name__, "get_ownership_id")
wire.add(MessageType.GetOwnershipProof, __name__, "get_ownership_proof")
wire.add(MessageType.SignTx, __name__, "sign_tx")
wire.add(MessageType.SignMessage, __name__, "sign_message")
wire.add(MessageType.VerifyMessage, __name__, "verify_message")

@ -1,10 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
CURVE = "ed25519"
def boot() -> None:
wire.add(MessageType.CardanoGetAddress, __name__, "get_address")
wire.add(MessageType.CardanoGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.CardanoSignTx, __name__, "sign_tx")

@ -10,6 +10,8 @@ if __debug__:
from trezor import config, crypto, log, loop, utils from trezor import config, crypto, log, loop, utils
from trezor.messages.Success import Success from trezor.messages.Success import Success
from apps import workflow_handlers
if False: if False:
from trezor.messages.DebugLinkDecision import DebugLinkDecision from trezor.messages.DebugLinkDecision import DebugLinkDecision
from trezor.messages.DebugLinkGetState import DebugLinkGetState from trezor.messages.DebugLinkGetState import DebugLinkGetState
@ -183,10 +185,19 @@ if __debug__:
if not utils.EMULATOR: if not utils.EMULATOR:
config.wipe() config.wipe()
wire.add(MessageType.LoadDevice, __name__, "load_device") workflow_handlers.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore
wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore workflow_handlers.register(
wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) MessageType.DebugLinkGetState, dispatch_DebugLinkGetState
wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom) )
wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen) workflow_handlers.register(
wire.register(MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard) MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom
wire.register(MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout) )
workflow_handlers.register(
MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen
)
workflow_handlers.register(
MessageType.DebugLinkEraseSdCard, dispatch_DebugLinkEraseSdCard
)
workflow_handlers.register(
MessageType.DebugLinkWatchLayout, dispatch_DebugLinkWatchLayout
)

@ -1,13 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_BIP44 from apps.common.paths import PATTERN_BIP44
CURVE = "secp256k1" CURVE = "secp256k1"
SLIP44_ID = 194 SLIP44_ID = 194
PATTERN = PATTERN_BIP44 PATTERN = PATTERN_BIP44
def boot() -> None:
wire.add(MessageType.EosGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.EosSignTx, __name__, "sign_tx")

@ -1,12 +1 @@
from trezor import wire
from trezor.messages import MessageType
CURVE = "secp256k1" CURVE = "secp256k1"
def boot() -> None:
wire.add(MessageType.EthereumGetAddress, __name__, "get_address")
wire.add(MessageType.EthereumGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.EthereumSignTx, __name__, "sign_tx")
wire.add(MessageType.EthereumSignMessage, __name__, "sign_message")
wire.add(MessageType.EthereumVerifyMessage, __name__, "verify_message")

@ -1,16 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5 from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519" CURVE = "ed25519"
SLIP44_ID = 134 SLIP44_ID = 134
PATTERN = PATTERN_SEP5 PATTERN = PATTERN_SEP5
def boot() -> None:
wire.add(MessageType.LiskGetPublicKey, __name__, "get_public_key")
wire.add(MessageType.LiskGetAddress, __name__, "get_address")
wire.add(MessageType.LiskSignTx, __name__, "sign_tx")
wire.add(MessageType.LiskSignMessage, __name__, "sign_message")
wire.add(MessageType.LiskVerifyMessage, __name__, "verify_message")

@ -1,16 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
def boot() -> None:
wire.add(MessageType.ResetDevice, __name__, "reset_device")
wire.add(MessageType.BackupDevice, __name__, "backup_device")
wire.add(MessageType.WipeDevice, __name__, "wipe_device")
wire.add(MessageType.RecoveryDevice, __name__, "recovery_device")
wire.add(MessageType.ApplySettings, __name__, "apply_settings")
wire.add(MessageType.ApplyFlags, __name__, "apply_flags")
wire.add(MessageType.ChangePin, __name__, "change_pin")
wire.add(MessageType.SetU2FCounter, __name__, "set_u2f_counter")
wire.add(MessageType.GetNextU2FCounter, __name__, "get_next_u2f_counter")
wire.add(MessageType.SdProtect, __name__, "sd_protect")
wire.add(MessageType.ChangeWipeCode, __name__, "change_wipe_code")

@ -1,9 +0,0 @@
from trezor import wire
from trezor.messages import MessageType
def boot() -> None:
wire.add(MessageType.GetEntropy, __name__, "get_entropy")
wire.add(MessageType.SignIdentity, __name__, "sign_identity")
wire.add(MessageType.GetECDHSessionKey, __name__, "get_ecdh_session_key")
wire.add(MessageType.CipherKeyValue, __name__, "cipher_key_value")

@ -1,20 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5 from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519" CURVE = "ed25519"
SLIP44_ID = 128 SLIP44_ID = 128
PATTERN = PATTERN_SEP5 PATTERN = PATTERN_SEP5
def boot() -> None:
wire.add(MessageType.MoneroGetAddress, __name__, "get_address")
wire.add(MessageType.MoneroGetWatchKey, __name__, "get_watch_only")
wire.add(MessageType.MoneroTransactionInitRequest, __name__, "sign_tx")
wire.add(MessageType.MoneroKeyImageExportInitRequest, __name__, "key_image_sync")
wire.add(MessageType.MoneroGetTxKeyRequest, __name__, "get_tx_keys")
wire.add(MessageType.MoneroLiveRefreshStartRequest, __name__, "live_refresh")
if __debug__ and hasattr(MessageType, "DebugMoneroDiagRequest"):
wire.add(MessageType.DebugMoneroDiagRequest, __name__, "diag")

@ -1,6 +1,3 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5 from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519-keccak" CURVE = "ed25519-keccak"
@ -10,8 +7,3 @@ PATTERNS = (
PATTERN_SEP5, PATTERN_SEP5,
"m/44'/coin_type'/account'/0'/0'", # NanoWallet compatibility "m/44'/coin_type'/account'/0'/0'", # NanoWallet compatibility
) )
def boot() -> None:
wire.add(MessageType.NEMGetAddress, __name__, "get_address")
wire.add(MessageType.NEMSignTx, __name__, "sign_tx")

@ -1,13 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_BIP44 from apps.common.paths import PATTERN_BIP44
CURVE = "secp256k1" CURVE = "secp256k1"
SLIP44_ID = 144 SLIP44_ID = 144
PATTERN = PATTERN_BIP44 PATTERN = PATTERN_BIP44
def boot() -> None:
wire.add(MessageType.RippleGetAddress, __name__, "get_address")
wire.add(MessageType.RippleSignTx, __name__, "sign_tx")

@ -1,13 +1,5 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5 from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519" CURVE = "ed25519"
SLIP44_ID = 148 SLIP44_ID = 148
PATTERN = PATTERN_SEP5 PATTERN = PATTERN_SEP5
def boot() -> None:
wire.add(MessageType.StellarGetAddress, __name__, "get_address")
wire.add(MessageType.StellarSignTx, __name__, "sign_tx")

@ -1,6 +1,3 @@
from trezor import wire
from trezor.messages import MessageType
from apps.common.paths import PATTERN_SEP5 from apps.common.paths import PATTERN_SEP5
CURVE = "ed25519" CURVE = "ed25519"
@ -9,9 +6,3 @@ PATTERNS = (
PATTERN_SEP5, PATTERN_SEP5,
"m/44'/coin_type'/0'/account'", # Ledger compatibility "m/44'/coin_type'/0'/account'", # Ledger compatibility
) )
def boot() -> None:
wire.add(MessageType.TezosGetAddress, __name__, "get_address")
wire.add(MessageType.TezosSignTx, __name__, "sign_tx")
wire.add(MessageType.TezosGetPublicKey, __name__, "get_public_key")

@ -1,24 +1,9 @@
from trezor import loop, wire from trezor import loop
from trezor.messages import MessageType
import usb
from .fido2 import handle_reports from .fido2 import handle_reports
def boot() -> None: def boot() -> None:
wire.add( loop.schedule(handle_reports(usb.iface_webauthn))
MessageType.WebAuthnListResidentCredentials,
__name__,
"list_resident_credentials",
)
wire.add(
MessageType.WebAuthnAddResidentCredential, __name__, "add_resident_credential"
)
wire.add(
MessageType.WebAuthnRemoveResidentCredential,
__name__,
"remove_resident_credential",
)
import usb
if usb.ENABLE_IFACE_WEBAUTHN:
loop.schedule(handle_reports(usb.iface_webauthn))

@ -0,0 +1,200 @@
from trezor import utils
from trezor.messages import MessageType
if False:
from trezor.wire import Handler
from trezorio import WireInterface
workflow_handlers: dict[int, Handler] = {}
def register(wire_type: int, handler: Handler) -> None:
"""Register `handler` to get scheduled after `wire_type` message is received."""
workflow_handlers[wire_type] = handler
def find_message_handler_module(msg_type: int) -> str:
"""Statically find the appropriate workflow handler.
For now, new messages must be registered by hand in the if-elif manner below.
The reason for this is memory fragmentation optimization:
- using a dict would mean that the whole thing stays in RAM, whereas an if-elif
sequence is run from flash
- collecting everything as strings instead of importing directly means that we don't
need to load any of the modules into memory until we actually need them
"""
if False:
raise RuntimeError
# debug
elif __debug__ and msg_type == MessageType.LoadDevice:
return "apps.debug.load_device"
# management
elif msg_type == MessageType.ResetDevice:
return "apps.management.reset_device"
elif msg_type == MessageType.BackupDevice:
return "apps.management.backup_device"
elif msg_type == MessageType.WipeDevice:
return "apps.management.wipe_device"
elif msg_type == MessageType.RecoveryDevice:
return "apps.management.recovery_device"
elif msg_type == MessageType.ApplySettings:
return "apps.management.apply_settings"
elif msg_type == MessageType.ApplyFlags:
return "apps.management.apply_flags"
elif msg_type == MessageType.ChangePin:
return "apps.management.change_pin"
elif msg_type == MessageType.SetU2FCounter:
return "apps.management.set_u2f_counter"
elif msg_type == MessageType.GetNextU2FCounter:
return "apps.management.get_next_u2f_counter"
elif msg_type == MessageType.SdProtect:
return "apps.management.sd_protect"
elif msg_type == MessageType.ChangeWipeCode:
return "apps.management.change_wipe_code"
# bitcoin
elif msg_type == MessageType.AuthorizeCoinJoin:
return "apps.bitcoin.authorize_coinjoin"
elif msg_type == MessageType.GetPublicKey:
return "apps.bitcoin.get_public_key"
elif msg_type == MessageType.GetAddress:
return "apps.bitcoin.get_address"
elif msg_type == MessageType.GetOwnershipId:
return "apps.bitcoin.get_ownership_id"
elif msg_type == MessageType.GetOwnershipProof:
return "apps.bitcoin.get_ownership_proof"
elif msg_type == MessageType.SignTx:
return "apps.bitcoin.sign_tx"
elif msg_type == MessageType.SignMessage:
return "apps.bitcoin.sign_message"
elif msg_type == MessageType.VerifyMessage:
return "apps.bitcoin.verify_message"
# misc
elif msg_type == MessageType.GetEntropy:
return "apps.misc.get_entropy"
elif msg_type == MessageType.SignIdentity:
return "apps.misc.sign_identity"
elif msg_type == MessageType.GetECDHSessionKey:
return "apps.misc.get_ecdh_session_key"
elif msg_type == MessageType.CipherKeyValue:
return "apps.misc.cipher_key_value"
elif not utils.BITCOIN_ONLY:
if False:
raise RuntimeError
# webauthn
elif msg_type == MessageType.WebAuthnListResidentCredentials:
return "apps.webauthn.list_resident_credentials"
elif msg_type == MessageType.WebAuthnAddResidentCredential:
return "apps.webauthn.add_resident_credential"
elif msg_type == MessageType.WebAuthnRemoveResidentCredential:
return "apps.webauthn.remove_resident_credential"
# ethereum
elif msg_type == MessageType.EthereumGetAddress:
return "apps.ethereum.get_address"
elif msg_type == MessageType.EthereumGetPublicKey:
return "apps.ethereum.get_public_key"
elif msg_type == MessageType.EthereumSignTx:
return "apps.ethereum.sign_tx"
elif msg_type == MessageType.EthereumSignMessage:
return "apps.ethereum.sign_message"
elif msg_type == MessageType.EthereumVerifyMessage:
return "apps.ethereum.verify_message"
# lisk
elif msg_type == MessageType.LiskGetPublicKey:
return "apps.lisk.get_public_key"
elif msg_type == MessageType.LiskGetAddress:
return "apps.lisk.get_address"
elif msg_type == MessageType.LiskSignTx:
return "apps.lisk.sign_tx"
elif msg_type == MessageType.LiskSignMessage:
return "apps.lisk.sign_message"
elif msg_type == MessageType.LiskVerifyMessage:
return "apps.lisk.verify_message"
# monero
elif msg_type == MessageType.MoneroGetAddress:
return "apps.monero.get_address"
elif msg_type == MessageType.MoneroGetWatchKey:
return "apps.monero.get_watch_only"
elif msg_type == MessageType.MoneroTransactionInitRequest:
return "apps.monero.sign_tx"
elif msg_type == MessageType.MoneroKeyImageExportInitRequest:
return "apps.monero.key_image_sync"
elif msg_type == MessageType.MoneroGetTxKeyRequest:
return "apps.monero.get_tx_keys"
elif msg_type == MessageType.MoneroLiveRefreshStartRequest:
return "apps.monero.live_refresh"
if __debug__ and msg_type == MessageType.DebugMoneroDiagRequest:
return "apps.monero.diag"
# nem
elif msg_type == MessageType.NEMGetAddress:
return "apps.nem.get_address"
elif msg_type == MessageType.NEMSignTx:
return "apps.nem.sign_tx"
# stellar
elif msg_type == MessageType.StellarGetAddress:
return "apps.stellar.get_address"
elif msg_type == MessageType.StellarSignTx:
return "apps.stellar.sign_tx"
# ripple
elif msg_type == MessageType.RippleGetAddress:
return "apps.ripple.get_address"
elif msg_type == MessageType.RippleSignTx:
return "apps.ripple.sign_tx"
# cardano
elif msg_type == MessageType.CardanoGetAddress:
return "apps.cardano.get_address"
elif msg_type == MessageType.CardanoGetPublicKey:
return "apps.cardano.get_public_key"
elif msg_type == MessageType.CardanoSignTx:
return "apps.cardano.sign_tx"
# tezos
elif msg_type == MessageType.TezosGetAddress:
return "apps.tezos.get_address"
elif msg_type == MessageType.TezosSignTx:
return "apps.tezos.sign_tx"
elif msg_type == MessageType.TezosGetPublicKey:
return "apps.tezos.get_public_key"
# eos
elif msg_type == MessageType.EosGetPublicKey:
return "apps.eos.get_public_key"
elif msg_type == MessageType.EosSignTx:
return "apps.eos.sign_tx"
# binance
elif msg_type == MessageType.BinanceGetAddress:
return "apps.binance.get_address"
elif msg_type == MessageType.BinanceGetPublicKey:
return "apps.binance.get_public_key"
elif msg_type == MessageType.BinanceSignTx:
return "apps.binance.sign_tx"
raise ValueError
def find_registered_handler(iface: WireInterface, msg_type: int) -> Handler | None:
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
return workflow_handlers[msg_type]
try:
modname = find_message_handler_module(msg_type)
handler_name = modname[modname.rfind(".") + 1 :]
module = __import__(modname, None, None, (handler_name,), 0)
return getattr(module, handler_name) # type: ignore
except ValueError:
return None

@ -63,38 +63,15 @@ if False:
Handler = Callable[["Context", Msg], HandlerTask] Handler = Callable[["Context", Msg], HandlerTask]
# Maps a wire type directly to a handler.
workflow_handlers: dict[int, Handler] = {}
# Maps a wire type to a tuple of package and module. This allows handlers
# to be dynamically imported when such message arrives.
workflow_packages: dict[int, tuple[str, str]] = {}
# If set to False protobuf messages marked with "unstable" option are rejected. # If set to False protobuf messages marked with "unstable" option are rejected.
experimental_enabled: bool = False experimental_enabled: bool = False
def add(wire_type: int, pkgname: str, modname: str) -> None:
"""Shortcut for registering a dynamically-imported Protobuf workflow."""
workflow_packages[wire_type] = (pkgname, modname)
def register(wire_type: int, handler: Handler) -> None:
"""Register `handler` to get scheduled after `wire_type` message is received."""
workflow_handlers[wire_type] = handler
def setup(iface: WireInterface, is_debug_session: bool = False) -> None: def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
"""Initialize the wire stack on passed USB interface.""" """Initialize the wire stack on passed USB interface."""
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
def clear() -> None:
"""Remove all registered handlers."""
workflow_handlers.clear()
workflow_packages.clear()
if False: if False:
from typing import Protocol from typing import Protocol
@ -459,33 +436,12 @@ async def handle_session(
log.exception(__name__, exc) log.exception(__name__, exc)
def find_registered_workflow_handler( def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
iface: WireInterface, msg_type: int """Placeholder handler lookup before a proper one is registered."""
) -> Handler | None: return None
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
handler = workflow_handlers[msg_type]
elif msg_type in workflow_packages:
# Message needs a dynamically imported handler, import it.
pkgname, modname = workflow_packages[msg_type]
handler = import_workflow(pkgname, modname)
else:
# Message does not have any registered handler.
return None
return handler
find_handler = find_registered_workflow_handler
def import_workflow(pkgname: str, modname: str) -> Any: find_handler = _find_handler_placeholder
modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0)
handler = getattr(module, modname)
return handler
def failure(exc: BaseException) -> Failure: def failure(exc: BaseException) -> Failure:

Loading…
Cancel
Save