diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 74b22600ec..e3cee6bf3c 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -166,7 +166,7 @@ if __debug__: wire.add(MessageType.LoadDevice, __name__, "load_device") wire.add(MessageType.DebugLinkShowText, __name__, "show_text") - wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) + wire.register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) # type: ignore wire.register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) wire.register(MessageType.DebugLinkReseedRandom, dispatch_DebugLinkReseedRandom) wire.register(MessageType.DebugLinkRecordScreen, dispatch_DebugLinkRecordScreen) diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index e09ac7ead5..51d69328f4 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -51,15 +51,19 @@ if False: Any, Awaitable, Callable, + Coroutine, Dict, Iterable, Optional, Tuple, Type, + TypeVar, ) from trezorio import WireInterface - Handler = Callable[..., loop.Task] + Msg = TypeVar("Msg", bound=protobuf.MessageType) + HandlerTask = Coroutine[Any, Any, protobuf.MessageType] + Handler = Callable[["Context", Msg], HandlerTask] # Maps a wire type directly to a handler. @@ -257,7 +261,7 @@ class UnexpectedMessageError(Exception): async def handle_session(iface: WireInterface, session_id: int) -> None: ctx = Context(iface, session_id) next_reader = None # type: Optional[codec_v1.Reader] - res_msg = None + res_msg = None # type: Optional[protobuf.MessageType] req_reader = None req_type = None req_msg = None @@ -301,7 +305,7 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: # We need to find a handler for this message type. Should not # raise. - handler = get_workflow_handler(req_reader) + handler = find_handler(iface, req_reader.type) if handler is None: # If no handler is found, we can skip decoding and directly @@ -314,7 +318,7 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: # We found a valid handler for this message type. # Workflow task, declared for the `workflow.on_close` call later. - wf_task = None # type: Optional[loop.Task] + wf_task = None # type: Optional[HandlerTask] # Here we make sure we always respond with a Failure response # in case of any errors. @@ -407,9 +411,9 @@ async def handle_session(iface: WireInterface, session_id: int) -> None: log.exception(__name__, exc) -def get_workflow_handler(reader: codec_v1.Reader) -> Optional[Handler]: - msg_type = reader.type - +def find_registered_workflow_handler( + iface: WireInterface, msg_type: int +) -> Optional[Handler]: if msg_type in workflow_handlers: # Message has a handler available, return it directly. handler = workflow_handlers[msg_type] @@ -426,6 +430,9 @@ def get_workflow_handler(reader: codec_v1.Reader) -> Optional[Handler]: return handler +find_handler = find_registered_workflow_handler + + def import_workflow(pkgname: str, modname: str) -> Any: modpath = "%s.%s" % (pkgname, modname) module = __import__(modpath, None, None, (modname,), 0)