From 3a8c4c63309f7268ce54c6f68c0588826c2c4c54 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 12 Apr 2024 18:42:35 +0200 Subject: [PATCH] Change pairing process into a workflow --- core/src/apps/thp/pairing.py | 130 +++++++++++-------- core/src/trezor/wire/thp/channel.py | 9 +- core/src/trezor/wire/thp/handler_provider.py | 38 +++++- core/src/trezor/wire/thp/pairing_context.py | 93 ++++--------- 4 files changed, 146 insertions(+), 124 deletions(-) diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index a83c6d3a1..0eae6b921 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -1,5 +1,5 @@ from trezor import log, protobuf -from trezor.enums import ThpPairingMethod +from trezor.enums import MessageType, ThpPairingMethod from trezor.messages import ( ThpCodeEntryChallenge, ThpCodeEntryCommitment, @@ -18,110 +18,138 @@ from trezor.messages import ( ThpQrCodeTag, ThpStartPairingRequest, ) +from trezor.wire import context from trezor.wire.errors import UnexpectedMessage from trezor.wire.thp import ChannelState -from trezor.wire.thp.channel import Channel +from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.thp_session import ThpError # TODO implement the following handlers async def handle_pairing_request( - channel: Channel, message: protobuf.MessageType -) -> ThpCodeEntryCommitment | ThpPairingPreparationsFinished: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpStartPairingRequest.is_type_of(message) if __debug__: log.debug(__name__, "handle_pairing_request") - _check_state(channel, ChannelState.TP1) - if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry): - channel.set_channel_state(ChannelState.TP2) - return ThpCodeEntryCommitment() - channel.set_channel_state(ChannelState.TP3) - return ThpPairingPreparationsFinished() + _check_state(ctx, ChannelState.TP1) + + if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): + ctx.channel.set_channel_state(ChannelState.TP2) + await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) + + ctx.channel.set_channel_state(ChannelState.TP3) + await context.call_any( + ThpPairingPreparationsFinished(), + MessageType.ThpQrCodeTag, + MessageType.ThpNfcUnidirectionalTag, + ) async def handle_code_entry_challenge( - channel: Channel, message: protobuf.MessageType -) -> ThpPairingPreparationsFinished: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpCodeEntryChallenge.is_type_of(message) - _check_state(channel, ChannelState.TP2) - channel.set_channel_state(ChannelState.TP3) - return ThpPairingPreparationsFinished() + _check_state(ctx, ChannelState.TP2) + ctx.channel.set_channel_state(ChannelState.TP3) + await context.call_any( + ThpPairingPreparationsFinished(), + MessageType.ThpCodeEntryCpaceHost, + MessageType.ThpQrCodeTag, + MessageType.ThpNfcUnidirectionalTag, + ) async def handle_code_entry_cpace( - channel: Channel, message: protobuf.MessageType -) -> ThpCodeEntryCpaceTrezor: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpCodeEntryCpaceHost.is_type_of(message) - _check_state(channel, ChannelState.TP3) - _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry) - channel.set_channel_state(ChannelState.TP4) - return ThpCodeEntryCpaceTrezor() + _check_state(ctx, ChannelState.TP3) + _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) + ctx.channel.set_channel_state(ChannelState.TP4) + await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) async def handle_code_entry_tag( - channel: Channel, message: protobuf.MessageType -) -> ThpCodeEntrySecret: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpCodeEntryTag.is_type_of(message) - _check_state(channel, ChannelState.TP4) - channel.set_channel_state(ChannelState.TC1) - return ThpCodeEntrySecret() + _check_state(ctx, ChannelState.TP4) + ctx.channel.set_channel_state(ChannelState.TC1) + await context.call_any( + ThpCodeEntrySecret(), + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) async def handle_qr_code_tag( - channel: Channel, message: protobuf.MessageType -) -> ThpQrCodeSecret: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpQrCodeTag.is_type_of(message) - _check_state(channel, ChannelState.TP3) - _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode) - channel.set_channel_state(ChannelState.TC1) - return ThpQrCodeSecret() + _check_state(ctx, ChannelState.TP3) + _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) + ctx.channel.set_channel_state(ChannelState.TC1) + await context.call_any( + ThpQrCodeSecret(), + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) async def handle_nfc_unidirectional_tag( - channel: Channel, message: protobuf.MessageType -) -> ThpNfcUnideirectionalSecret: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpNfcUnidirectionalTag.is_type_of(message) - _check_state(channel, ChannelState.TP3) - _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional) - channel.set_channel_state(ChannelState.TC1) - return ThpNfcUnideirectionalSecret() + _check_state(ctx, ChannelState.TP3) + _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) + ctx.channel.set_channel_state(ChannelState.TC1) + await context.call_any( + ThpNfcUnideirectionalSecret(), + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) async def handle_credential_request( - channel: Channel, message: protobuf.MessageType -) -> ThpCredentialResponse: + ctx: PairingContext, message: protobuf.MessageType +) -> None: assert ThpCredentialRequest.is_type_of(message) - _check_state(channel, ChannelState.TC1) - return ThpCredentialResponse() + _check_state(ctx, ChannelState.TC1) + await context.call_any( + ThpCredentialResponse(), + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) async def handle_end_request( - channel: Channel, message: protobuf.MessageType + ctx: PairingContext, message: protobuf.MessageType ) -> ThpEndResponse: assert ThpEndRequest.is_type_of(message) - _check_state(channel, ChannelState.TC1) - channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + _check_state(ctx, ChannelState.TC1) + ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) return ThpEndResponse() -def _check_state(channel: Channel, expected_state: ChannelState) -> None: - if expected_state is not channel.get_channel_state(): +def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None: + if expected_state is not ctx.channel.get_channel_state(): raise UnexpectedMessage("Unexpected message") -def _check_method_is_allowed(channel: Channel, method: ThpPairingMethod) -> None: - if not _is_method_included(channel, method): +def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None: + if not _is_method_included(ctx, method): raise ThpError("Unexpected pairing method") -def _is_method_included(channel: Channel, method: ThpPairingMethod) -> bool: - return method in channel.selected_pairing_methods +def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: + return method in ctx.channel.selected_pairing_methods diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 35bc1715d..026604006 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -14,7 +14,6 @@ from trezor.messages import ( ) from trezor.wire import message_handler from trezor.wire.thp import ack_handler, thp_messages -from trezor.wire.thp.handler_provider import get_handler from ..protocol_common import Context, MessageWithType from . import ChannelState, SessionState, checksum, crypto @@ -375,10 +374,15 @@ class Channel(Context): ) async def _handle_pairing(self, message_length: int) -> None: + from trezor.wire.thp.handler_provider import get_handler_for_pairing + + from . import pairing_context from .pairing_context import PairingContext if self.connection_context is None: self.connection_context = PairingContext(self) + pairing_context.get_handler = get_handler_for_pairing # noqa + loop.schedule(self.connection_context.handle()) self._decrypt_buffer(message_length) @@ -426,8 +430,9 @@ class Channel(Context): "This message cannot be handled by channel itself. It must be send to allocated session." ) # TODO handle other messages than CreateNewSession + from trezor.wire.thp.handler_provider import get_handler_for_handshake - handler = get_handler(message) + handler = get_handler_for_handshake(message) task = handler(self, message) response_message = await task # TODO handle diff --git a/core/src/trezor/wire/thp/handler_provider.py b/core/src/trezor/wire/thp/handler_provider.py index a0774ded1..7cb57befd 100644 --- a/core/src/trezor/wire/thp/handler_provider.py +++ b/core/src/trezor/wire/thp/handler_provider.py @@ -1,6 +1,8 @@ from typing import TYPE_CHECKING from trezor import protobuf +from trezor.enums import MessageType +from trezor.wire.thp.thp_session import ThpError from apps.thp import create_session @@ -9,8 +11,42 @@ if TYPE_CHECKING: pass +from apps.thp.pairing import ( + handle_code_entry_challenge, + handle_code_entry_cpace, + handle_code_entry_tag, + handle_credential_request, + handle_end_request, + handle_nfc_unidirectional_tag, + handle_pairing_request, + handle_qr_code_tag, +) -def get_handler( + +def get_handler_for_handshake( msg: protobuf.MessageType, ) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: return create_session.create_new_session + + +def get_handler_for_pairing( + messageType: int, +) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: + if TYPE_CHECKING: + assert isinstance(messageType, MessageType) + handler = handlers.get(messageType) + if handler is None: + raise ThpError("Pairing handler for this message is not available!") + return handler + + +handlers = { + MessageType.ThpStartPairingRequest: handle_pairing_request, + MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge, + MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace, + MessageType.ThpCodeEntryTag: handle_code_entry_tag, + MessageType.ThpQrCodeTag: handle_qr_code_tag, + MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag, + MessageType.ThpCredentialRequest: handle_credential_request, + MessageType.ThpEndRequest: handle_end_request, +} diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index a81a95dd7..0f4722b35 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -1,46 +1,28 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from trezor import log, loop, protobuf, workflow -from trezor.enums import MessageType -from trezor.wire import message_handler, protocol_common +from trezor.wire import context, message_handler, protocol_common from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.errors import ActionCancelled -from trezor.wire.protocol_common import MessageWithType +from trezor.wire.protocol_common import Context, MessageWithType from trezor.wire.thp.session_context import UnexpectedMessageWithType -from trezor.wire.thp.thp_session import ThpError - -from apps.thp.pairing import ( - handle_code_entry_challenge, - handle_code_entry_cpace, - handle_code_entry_tag, - handle_credential_request, - handle_end_request, - handle_nfc_unidirectional_tag, - handle_pairing_request, - handle_qr_code_tag, -) from .channel import Channel if TYPE_CHECKING: - from typing import Container, Generator # pyright:ignore[reportShadowedImports] + from typing import ( # pyright:ignore[reportShadowedImports] + Any, + Callable, + Container, + Coroutine, + ) pass -handlers = { - MessageType.ThpStartPairingRequest: handle_pairing_request, - MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge, - MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace, - MessageType.ThpCodeEntryTag: handle_code_entry_tag, - MessageType.ThpQrCodeTag: handle_qr_code_tag, - MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag, - MessageType.ThpCredentialRequest: handle_credential_request, - MessageType.ThpEndRequest: handle_end_request, -} - -class PairingContext: +class PairingContext(Context): def __init__(self, channel: Channel) -> None: + super().__init__(channel.iface, channel.channel_id) self.channel = channel self.incoming_message = loop.chan() @@ -115,7 +97,9 @@ class PairingContext: str(expected_types), exp_type, ) + message: MessageWithType = await self.incoming_message.take() + if message.type not in expected_types: raise UnexpectedMessageWithType(message) @@ -128,6 +112,15 @@ class PairingContext: return await self.channel.write(msg) +def _find_handler_placeholder( + messageType: int, +) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: + raise Exception() + + +get_handler = _find_handler_placeholder + + async def handle_pairing_message( ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool ) -> protocol_common.MessageWithType | None: @@ -159,7 +152,7 @@ async def handle_pairing_message( req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handler(ctx.channel, req_msg) + task = handler(ctx, req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -168,7 +161,7 @@ async def handle_pairing_message( if use_workflow: # Spawn a workflow around the task. This ensures that concurrent # workflows are shut down. - res_msg = await workflow.spawn(with_context(ctx, task)) + res_msg = await workflow.spawn(context.with_context(ctx, task)) pass # TODO else: # For debug messages, ignore workflow processing and just await @@ -206,43 +199,3 @@ async def handle_pairing_message( # problem bubbles up await ctx.write(res_msg) return None - - -def get_handler(messageType: int): - if TYPE_CHECKING: - assert isinstance(messageType, MessageType) - handler = handlers.get(messageType) - if handler is None: - raise ThpError("Pairing handler for this message is not available!") - return handler - - -def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator: - """Run a workflow in a particular context. - - Stores the context in a closure and installs it into the global variable every time - the closure is resumed, thus making sure that all calls to `wire.context.*` will - work as expected. - """ - global CURRENT_CONTEXT - send_val = None - send_exc = None - - while True: - CURRENT_CONTEXT = ctx - try: - if send_exc is not None: - res = workflow.throw(send_exc) - else: - res = workflow.send(send_val) - except StopIteration as st: - return st.value - finally: - CURRENT_CONTEXT = None - - try: - send_val = yield res - except BaseException as e: - send_exc = e - else: - send_exc = None