diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index a83c6d3a1..982b0503a 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -1,5 +1,5 @@ from trezor import log, protobuf -from trezor.enums import ThpPairingMethod +from trezor.enums import MessageType, ThpPairingMethod from trezor.messages import ( ThpCodeEntryChallenge, ThpCodeEntryCommitment, @@ -20,108 +20,177 @@ from trezor.messages import ( ) from trezor.wire.errors import UnexpectedMessage from trezor.wire.thp import ChannelState -from trezor.wire.thp.channel import Channel +from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.thp_session import ThpError # TODO implement the following handlers async def handle_pairing_request( - channel: Channel, message: protobuf.MessageType -) -> ThpCodeEntryCommitment | ThpPairingPreparationsFinished: + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: assert ThpStartPairingRequest.is_type_of(message) if __debug__: log.debug(__name__, "handle_pairing_request") - _check_state(channel, ChannelState.TP1) - if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry): - channel.set_channel_state(ChannelState.TP2) - return ThpCodeEntryCommitment() - channel.set_channel_state(ChannelState.TP3) - return ThpPairingPreparationsFinished() + _check_state(ctx, ChannelState.TP1) -async def handle_code_entry_challenge( - channel: Channel, message: protobuf.MessageType -) -> ThpPairingPreparationsFinished: - assert ThpCodeEntryChallenge.is_type_of(message) + if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): + ctx.channel.set_channel_state(ChannelState.TP2) + + response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) + return await _handle_code_entry_challenge(ctx, response) - _check_state(channel, ChannelState.TP2) - channel.set_channel_state(ChannelState.TP3) - return ThpPairingPreparationsFinished() + ctx.channel.set_channel_state(ChannelState.TP3) + response = await ctx.call_any( + ThpPairingPreparationsFinished(), + MessageType.ThpQrCodeTag, + MessageType.ThpNfcUnidirectionalTag, + ) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcUnidirectionalTag.is_type_of(response): + return await _handle_nfc_unidirectional_tag(ctx, response) + raise Exception( + "TODO Change this exception message and type. This exception should result in channel destruction." + ) -async def handle_code_entry_cpace( - channel: Channel, message: protobuf.MessageType -) -> ThpCodeEntryCpaceTrezor: +async def _handle_code_entry_challenge( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: + assert ThpCodeEntryChallenge.is_type_of(message) + + _check_state(ctx, ChannelState.TP2) + ctx.channel.set_channel_state(ChannelState.TP3) + response = await ctx.call_any( + ThpPairingPreparationsFinished(), + MessageType.ThpCodeEntryCpaceHost, + MessageType.ThpQrCodeTag, + MessageType.ThpNfcUnidirectionalTag, + ) + if ThpCodeEntryCpaceHost.is_type_of(response): + return await _handle_code_entry_cpace(ctx, response) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcUnidirectionalTag.is_type_of(response): + return await _handle_nfc_unidirectional_tag(ctx, response) + raise Exception( + "TODO Change this exception message and type. This exception should result in channel destruction." + ) + + +async def _handle_code_entry_cpace( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: assert ThpCodeEntryCpaceHost.is_type_of(message) - _check_state(channel, ChannelState.TP3) - _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry) - channel.set_channel_state(ChannelState.TP4) - return ThpCodeEntryCpaceTrezor() + _check_state(ctx, ChannelState.TP3) + _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) + ctx.channel.set_channel_state(ChannelState.TP4) + response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) + return await _handle_code_entry_tag(ctx, response) -async def handle_code_entry_tag( - channel: Channel, message: protobuf.MessageType -) -> ThpCodeEntrySecret: +async def _handle_code_entry_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: assert ThpCodeEntryTag.is_type_of(message) - - _check_state(channel, ChannelState.TP4) - channel.set_channel_state(ChannelState.TC1) - return ThpCodeEntrySecret() + return await _handle_tag_message( + ctx, + expected_state=ChannelState.TP4, + used_method=ThpPairingMethod.PairingMethod_CodeEntry, + msg=ThpCodeEntrySecret(), + ) -async def handle_qr_code_tag( - channel: Channel, message: protobuf.MessageType -) -> ThpQrCodeSecret: +async def _handle_qr_code_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: assert ThpQrCodeTag.is_type_of(message) - - _check_state(channel, ChannelState.TP3) - _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode) - channel.set_channel_state(ChannelState.TC1) - return ThpQrCodeSecret() + return await _handle_tag_message( + ctx, + expected_state=ChannelState.TP3, + used_method=ThpPairingMethod.PairingMethod_QrCode, + msg=ThpQrCodeSecret(), + ) -async def handle_nfc_unidirectional_tag( - channel: Channel, message: protobuf.MessageType -) -> ThpNfcUnideirectionalSecret: +async def _handle_nfc_unidirectional_tag( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: assert ThpNfcUnidirectionalTag.is_type_of(message) + return await _handle_tag_message( + ctx, + expected_state=ChannelState.TP3, + used_method=ThpPairingMethod.PairingMethod_NFC_Unidirectional, + msg=ThpNfcUnideirectionalSecret(), + ) - _check_state(channel, ChannelState.TP3) - _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional) - channel.set_channel_state(ChannelState.TC1) - return ThpNfcUnideirectionalSecret() - -async def handle_credential_request( - channel: Channel, message: protobuf.MessageType -) -> ThpCredentialResponse: +async def _handle_credential_request( + ctx: PairingContext, message: protobuf.MessageType +) -> ThpEndResponse: assert ThpCredentialRequest.is_type_of(message) - _check_state(channel, ChannelState.TC1) - return ThpCredentialResponse() + _check_state(ctx, ChannelState.TC1) + response = await ctx.call_any( + ThpCredentialResponse(), + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) + return await _handle_credential_request_or_end_request(ctx, response) -async def handle_end_request( - channel: Channel, message: protobuf.MessageType +async def _handle_end_request( + ctx: PairingContext, message: protobuf.MessageType ) -> ThpEndResponse: assert ThpEndRequest.is_type_of(message) - _check_state(channel, ChannelState.TC1) - channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + _check_state(ctx, ChannelState.TC1) + ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) return ThpEndResponse() -def _check_state(channel: Channel, expected_state: ChannelState) -> None: - if expected_state is not channel.get_channel_state(): +async def _handle_tag_message( + ctx: PairingContext, + expected_state: ChannelState, + used_method: ThpPairingMethod, + msg: protobuf.MessageType, +) -> ThpEndResponse: + _check_state(ctx, expected_state) + _check_method_is_allowed(ctx, used_method) + ctx.channel.set_channel_state(ChannelState.TC1) + response = await ctx.call_any( + msg, + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) + return await _handle_credential_request_or_end_request(ctx, response) + + +def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None: + if expected_state is not ctx.channel.get_channel_state(): raise UnexpectedMessage("Unexpected message") -def _check_method_is_allowed(channel: Channel, method: ThpPairingMethod) -> None: - if not _is_method_included(channel, method): +def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None: + if not _is_method_included(ctx, method): raise ThpError("Unexpected pairing method") -def _is_method_included(channel: Channel, method: ThpPairingMethod) -> bool: - return method in channel.selected_pairing_methods +def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: + return method in ctx.channel.selected_pairing_methods + + +async def _handle_credential_request_or_end_request( + ctx: PairingContext, response: protobuf.MessageType | None +) -> ThpEndResponse: + if ThpCredentialRequest.is_type_of(response): + return await _handle_credential_request(ctx, response) + if ThpEndRequest.is_type_of(response): + return await _handle_end_request(ctx, response) + raise UnexpectedMessage( + "Received message is not credential request or end request." + ) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 35bc1715d..c719c279b 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -14,7 +14,6 @@ from trezor.messages import ( ) from trezor.wire import message_handler from trezor.wire.thp import ack_handler, thp_messages -from trezor.wire.thp.handler_provider import get_handler from ..protocol_common import Context, MessageWithType from . import ChannelState, SessionState, checksum, crypto @@ -375,10 +374,12 @@ class Channel(Context): ) async def _handle_pairing(self, message_length: int) -> None: + from .pairing_context import PairingContext if self.connection_context is None: self.connection_context = PairingContext(self) + loop.schedule(self.connection_context.handle()) self._decrypt_buffer(message_length) @@ -426,8 +427,9 @@ class Channel(Context): "This message cannot be handled by channel itself. It must be send to allocated session." ) # TODO handle other messages than CreateNewSession + from trezor.wire.thp.handler_provider import get_handler_for_channel_message - handler = get_handler(message) + handler = get_handler_for_channel_message(message) task = handler(self, message) response_message = await task # TODO handle diff --git a/core/src/trezor/wire/thp/handler_provider.py b/core/src/trezor/wire/thp/handler_provider.py index a0774ded1..68d170442 100644 --- a/core/src/trezor/wire/thp/handler_provider.py +++ b/core/src/trezor/wire/thp/handler_provider.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: pass -def get_handler( +def get_handler_for_channel_message( msg: protobuf.MessageType, ) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: return create_session.create_new_session diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index a81a95dd7..0dd8aeea0 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -1,46 +1,23 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from trezor import log, loop, protobuf, workflow -from trezor.enums import MessageType -from trezor.wire import message_handler, protocol_common +from trezor.wire import context, message_handler, protocol_common from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.errors import ActionCancelled -from trezor.wire.protocol_common import MessageWithType +from trezor.wire.protocol_common import Context, MessageWithType from trezor.wire.thp.session_context import UnexpectedMessageWithType -from trezor.wire.thp.thp_session import ThpError - -from apps.thp.pairing import ( - handle_code_entry_challenge, - handle_code_entry_cpace, - handle_code_entry_tag, - handle_credential_request, - handle_end_request, - handle_nfc_unidirectional_tag, - handle_pairing_request, - handle_qr_code_tag, -) from .channel import Channel if TYPE_CHECKING: - from typing import Container, Generator # pyright:ignore[reportShadowedImports] + from typing import Container # pyright:ignore[reportShadowedImports] pass -handlers = { - MessageType.ThpStartPairingRequest: handle_pairing_request, - MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge, - MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace, - MessageType.ThpCodeEntryTag: handle_code_entry_tag, - MessageType.ThpQrCodeTag: handle_qr_code_tag, - MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag, - MessageType.ThpCredentialRequest: handle_credential_request, - MessageType.ThpEndRequest: handle_end_request, -} - -class PairingContext: +class PairingContext(Context): def __init__(self, channel: Channel) -> None: + super().__init__(channel.iface, channel.channel_id) self.channel = channel self.incoming_message = loop.chan() @@ -73,7 +50,7 @@ class PairingContext: next_message = None try: - next_message = await handle_pairing_message( + next_message = await handle_pairing_request_message( self, message, use_workflow=not is_debug_session ) except Exception as exc: @@ -115,7 +92,9 @@ class PairingContext: str(expected_types), exp_type, ) + message: MessageWithType = await self.incoming_message.take() + if message.type not in expected_types: raise UnexpectedMessageWithType(message) @@ -127,22 +106,31 @@ class PairingContext: async def write(self, msg: protobuf.MessageType) -> None: return await self.channel.write(msg) + async def call( + self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] + ) -> protobuf.MessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg -async def handle_pairing_message( + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + async def call_any( + self, msg: protobuf.MessageType, *expected_types: int + ) -> protobuf.MessageType: + await self.write(msg) + del msg + return await self.read(expected_types) + + +async def handle_pairing_request_message( ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool ) -> protocol_common.MessageWithType | None: res_msg: protobuf.MessageType | None = None - # We need to find a handler for this message type. Should not raise. - # TODO register handlers to dict - handler = get_handler(msg.type) # pylint: disable=assignment-from-none - - if handler is None: - # If no handler is found, we can skip decoding and directly - # respond with failure. - await ctx.write(message_handler.unexpected_message()) - return None + from apps.thp.pairing import handle_pairing_request if msg.type in workflow.ALLOW_WHILE_LOCKED: workflow.autolock_interrupts_workflow = False @@ -159,7 +147,7 @@ async def handle_pairing_message( req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handler(ctx.channel, req_msg) + task = handle_pairing_request(ctx, req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -168,7 +156,7 @@ async def handle_pairing_message( if use_workflow: # Spawn a workflow around the task. This ensures that concurrent # workflows are shut down. - res_msg = await workflow.spawn(with_context(ctx, task)) + res_msg = await workflow.spawn(context.with_context(ctx, task)) pass # TODO else: # For debug messages, ignore workflow processing and just await @@ -201,48 +189,9 @@ async def handle_pairing_message( else: log.exception(__name__, exc) res_msg = message_handler.failure(exc) + if res_msg is not None: # perform the write outside the big try-except block, so that usb write # problem bubbles up await ctx.write(res_msg) return None - - -def get_handler(messageType: int): - if TYPE_CHECKING: - assert isinstance(messageType, MessageType) - handler = handlers.get(messageType) - if handler is None: - raise ThpError("Pairing handler for this message is not available!") - return handler - - -def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator: - """Run a workflow in a particular context. - - Stores the context in a closure and installs it into the global variable every time - the closure is resumed, thus making sure that all calls to `wire.context.*` will - work as expected. - """ - global CURRENT_CONTEXT - send_val = None - send_exc = None - - while True: - CURRENT_CONTEXT = ctx - try: - if send_exc is not None: - res = workflow.throw(send_exc) - else: - res = workflow.send(send_val) - except StopIteration as st: - return st.value - finally: - CURRENT_CONTEXT = None - - try: - send_val = yield res - except BaseException as e: - send_exc = e - else: - send_exc = None