From 8fb45754c6c54aab1c685394a6147db01c9e87ce Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 17 Apr 2024 13:07:58 +0200 Subject: [PATCH] Remove unnecessary abstractions with handlers --- core/src/trezor/wire/thp/channel.py | 7 ++--- core/src/trezor/wire/thp/handler_provider.py | 10 +----- core/src/trezor/wire/thp/pairing_context.py | 33 ++++---------------- 3 files changed, 9 insertions(+), 41 deletions(-) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 026604006..c719c279b 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -374,14 +374,11 @@ class Channel(Context): ) async def _handle_pairing(self, message_length: int) -> None: - from trezor.wire.thp.handler_provider import get_handler_for_pairing - from . import pairing_context from .pairing_context import PairingContext if self.connection_context is None: self.connection_context = PairingContext(self) - pairing_context.get_handler = get_handler_for_pairing # noqa loop.schedule(self.connection_context.handle()) self._decrypt_buffer(message_length) @@ -430,9 +427,9 @@ class Channel(Context): "This message cannot be handled by channel itself. It must be send to allocated session." ) # TODO handle other messages than CreateNewSession - from trezor.wire.thp.handler_provider import get_handler_for_handshake + from trezor.wire.thp.handler_provider import get_handler_for_channel_message - handler = get_handler_for_handshake(message) + handler = get_handler_for_channel_message(message) task = handler(self, message) response_message = await task # TODO handle diff --git a/core/src/trezor/wire/thp/handler_provider.py b/core/src/trezor/wire/thp/handler_provider.py index 1f9f6d99c..68d170442 100644 --- a/core/src/trezor/wire/thp/handler_provider.py +++ b/core/src/trezor/wire/thp/handler_provider.py @@ -9,16 +9,8 @@ if TYPE_CHECKING: pass -from apps.thp.pairing import handle_pairing_request - -def get_handler_for_handshake( +def get_handler_for_channel_message( msg: protobuf.MessageType, ) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: return create_session.create_new_session - - -def get_handler_for_pairing( - messageType: int, -) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType | None]]: - return handle_pairing_request diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 633cc32f3..0dd8aeea0 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -10,12 +10,7 @@ from trezor.wire.thp.session_context import UnexpectedMessageWithType from .channel import Channel if TYPE_CHECKING: - from typing import ( # pyright:ignore[reportShadowedImports] - Any, - Callable, - Container, - Coroutine, - ) + from typing import Container # pyright:ignore[reportShadowedImports] pass @@ -55,7 +50,7 @@ class PairingContext(Context): next_message = None try: - next_message = await handle_pairing_message( + next_message = await handle_pairing_request_message( self, message, use_workflow=not is_debug_session ) except Exception as exc: @@ -129,30 +124,13 @@ class PairingContext(Context): return await self.read(expected_types) -def _find_handler_placeholder( - messageType: int, -) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: - raise Exception() - - -get_handler = _find_handler_placeholder - - -async def handle_pairing_message( +async def handle_pairing_request_message( ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool ) -> protocol_common.MessageWithType | None: res_msg: protobuf.MessageType | None = None - # We need to find a handler for this message type. Should not raise. - # TODO register handlers to dict - handler = get_handler(msg.type) # pylint: disable=assignment-from-none - - if handler is None: - # If no handler is found, we can skip decoding and directly - # respond with failure. - await ctx.write(message_handler.unexpected_message()) - return None + from apps.thp.pairing import handle_pairing_request if msg.type in workflow.ALLOW_WHILE_LOCKED: workflow.autolock_interrupts_workflow = False @@ -169,7 +147,7 @@ async def handle_pairing_message( req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handler(ctx, req_msg) + task = handle_pairing_request(ctx, req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -211,6 +189,7 @@ async def handle_pairing_message( else: log.exception(__name__, exc) res_msg = message_handler.failure(exc) + if res_msg is not None: # perform the write outside the big try-except block, so that usb write # problem bubbles up