diff --git a/core/src/trezor/messages/__init__.py b/core/src/trezor/messages/__init__.py index 94dc675df..a85980add 100644 --- a/core/src/trezor/messages/__init__.py +++ b/core/src/trezor/messages/__init__.py @@ -1,43 +1,19 @@ -from trezor.messages import MessageType - -if __debug__: - from trezor import log - if False: - from protobuf import MessageType as MessageType_ # noqa: F401 - - MessageClass = type[MessageType_] - -type_to_name: dict[int, str] = {} # reverse table of wire_type mapping -registered: dict[int, MessageClass] = {} # dynamically registered types + import protobuf + from typing import Type -def register(msg_type: MessageClass) -> None: - """Register custom message type in runtime.""" - if __debug__: - log.debug(__name__, "register %s", msg_type) - registered[msg_type.MESSAGE_WIRE_TYPE] = msg_type - - -def get_type(wire_type: int) -> MessageClass: +def get_type(wire_type: int) -> Type[protobuf.MessageType]: """Get message class for handling given wire_type.""" - if wire_type in registered: - # message class is explicitly registered - msg_type = registered[wire_type] - else: - # import message class from trezor.messages dynamically - name = type_to_name[wire_type] - module = __import__("trezor.messages.%s" % name, None, None, (name,), 0) - msg_type = getattr(module, name) - return msg_type - - -# build reverse table of wire types -for msg_name in dir(MessageType): - # Modules contain internal variables that may cause exception here. - # No Message begins with underscore so it's safe to skip those. - if msg_name[0] == "_": - continue - if msg_name == "utils": # skip imported trezor.utils - continue - type_to_name[getattr(MessageType, msg_name)] = msg_name + from trezor.messages import MessageType + + for msg_name in dir(MessageType): + # walk the list of symbols in MessageType + if getattr(MessageType, msg_name) == wire_type: + # import submodule/class of the same name + module = __import__( + "trezor.messages.%s" % msg_name, None, None, (msg_name,), 0 + ) + return getattr(module, msg_name) # type: ignore + + raise KeyError