1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-23 14:58:09 +00:00

core/wire: make handler lookup pluggable

This commit is contained in:
matejcik 2020-04-21 13:06:09 +02:00 committed by matejcik
parent 837c4df61f
commit 341c5b7d10
2 changed files with 15 additions and 8 deletions

View File

@ -166,7 +166,7 @@ if __debug__:
wire.add(MessageType.LoadDevice, __name__, "load_device") wire.add(MessageType.LoadDevice, __name__, "load_device")
wire.add(MessageType.DebugLinkShowText, __name__, "show_text") wire.add(MessageType.DebugLinkShowText, __name__, "show_text")
wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore
wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState)
wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom) wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom)
wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen) wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen)

View File

@ -51,15 +51,19 @@ if False:
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Coroutine,
Dict, Dict,
Iterable, Iterable,
Optional, Optional,
Tuple, Tuple,
Type, Type,
TypeVar,
) )
from trezorio import WireInterface from trezorio import WireInterface
Handler = Callable[..., loop.Task] Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[["Context", Msg], HandlerTask]
# Maps a wire type directly to a handler. # Maps a wire type directly to a handler.
@ -257,7 +261,7 @@ class UnexpectedMessageError(Exception):
async def handle_session(iface: WireInterface, session_id: int) -> None: async def handle_session(iface: WireInterface, session_id: int) -> None:
ctx = Context(iface, session_id) ctx = Context(iface, session_id)
next_reader = None # type: Optional[codec_v1.Reader] next_reader = None # type: Optional[codec_v1.Reader]
res_msg = None res_msg = None # type: Optional[protobuf.MessageType]
req_reader = None req_reader = None
req_type = None req_type = None
req_msg = None req_msg = None
@ -301,7 +305,7 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
# We need to find a handler for this message type. Should not # We need to find a handler for this message type. Should not
# raise. # raise.
handler = get_workflow_handler(req_reader) handler = find_handler(iface, req_reader.type)
if handler is None: if handler is None:
# If no handler is found, we can skip decoding and directly # If no handler is found, we can skip decoding and directly
@ -314,7 +318,7 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
# We found a valid handler for this message type. # We found a valid handler for this message type.
# Workflow task, declared for the `workflow.on_close` call later. # Workflow task, declared for the `workflow.on_close` call later.
wf_task = None # type: Optional[loop.Task] wf_task = None # type: Optional[HandlerTask]
# Here we make sure we always respond with a Failure response # Here we make sure we always respond with a Failure response
# in case of any errors. # in case of any errors.
@ -407,9 +411,9 @@ async def handle_session(iface: WireInterface, session_id: int) -> None:
log.exception(__name__, exc) log.exception(__name__, exc)
def get_workflow_handler(reader: codec_v1.Reader) -> Optional[Handler]: def find_registered_workflow_handler(
msg_type = reader.type iface: WireInterface, msg_type: int
) -> Optional[Handler]:
if msg_type in workflow_handlers: if msg_type in workflow_handlers:
# Message has a handler available, return it directly. # Message has a handler available, return it directly.
handler = workflow_handlers[msg_type] handler = workflow_handlers[msg_type]
@ -426,6 +430,9 @@ def get_workflow_handler(reader: codec_v1.Reader) -> Optional[Handler]:
return handler return handler
find_handler = find_registered_workflow_handler
def import_workflow(pkgname: str, modname: str) -> Any: def import_workflow(pkgname: str, modname: str) -> Any:
modpath = "%s.%s" % (pkgname, modname) modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0) module = __import__(modpath, None, None, (modname,), 0)