diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 8f6027d94..0b40e835b 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -18,9 +18,8 @@ from trezor.messages import ( ThpQrCodeTag, ThpStartPairingRequest, ) -from trezor.wire import context from trezor.wire.errors import UnexpectedMessage -from trezor.wire.thp import ChannelState, pairing_context +from trezor.wire.thp import ChannelState from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.thp_session import ThpError @@ -29,117 +28,127 @@ from trezor.wire.thp.thp_session import ThpError async def handle_pairing_request( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpStartPairingRequest.is_type_of(message) if __debug__: log.debug(__name__, "handle_pairing_request") + _check_state(ctx, ChannelState.TP1) if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): ctx.channel.set_channel_state(ChannelState.TP2) - response = await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) - else: - ctx.channel.set_channel_state(ChannelState.TP3) - response = await context.call_any( - ThpPairingPreparationsFinished(), - MessageType.ThpQrCodeTag, - MessageType.ThpNfcUnidirectionalTag, - ) - await _handle_response(ctx, response) + response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) + return await _handle_code_entry_challenge(ctx, response) + + ctx.channel.set_channel_state(ChannelState.TP3) + response = await ctx.call_any( + ThpPairingPreparationsFinished(), + MessageType.ThpQrCodeTag, + MessageType.ThpNfcUnidirectionalTag, + ) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcUnidirectionalTag.is_type_of(response): + return await _handle_nfc_unidirectional_tag(ctx, response) -async def handle_code_entry_challenge( +async def _handle_code_entry_challenge( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpCodeEntryChallenge.is_type_of(message) _check_state(ctx, ChannelState.TP2) ctx.channel.set_channel_state(ChannelState.TP3) - response = await context.call_any( + response = await ctx.call_any( ThpPairingPreparationsFinished(), MessageType.ThpCodeEntryCpaceHost, MessageType.ThpQrCodeTag, MessageType.ThpNfcUnidirectionalTag, ) - await _handle_response(ctx, response) + if ThpCodeEntryCpaceHost.is_type_of(response): + return await _handle_code_entry_cpace(ctx, response) + if ThpQrCodeTag.is_type_of(response): + return await _handle_qr_code_tag(ctx, response) + if ThpNfcUnidirectionalTag.is_type_of(response): + return await _handle_nfc_unidirectional_tag(ctx, response) -async def handle_code_entry_cpace( +async def _handle_code_entry_cpace( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpCodeEntryCpaceHost.is_type_of(message) _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) ctx.channel.set_channel_state(ChannelState.TP4) - response = await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) - await _handle_response(ctx, response) + response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) + return await _handle_code_entry_tag(ctx, response) -async def handle_code_entry_tag( +async def _handle_code_entry_tag( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpCodeEntryTag.is_type_of(message) _check_state(ctx, ChannelState.TP4) ctx.channel.set_channel_state(ChannelState.TC1) - response = await context.call_any( + response = await ctx.call_any( ThpCodeEntrySecret(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) - await _handle_response(ctx, response) + await _handle_credential_request_or_end_request(ctx, response) -async def handle_qr_code_tag( +async def _handle_qr_code_tag( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpQrCodeTag.is_type_of(message) _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) ctx.channel.set_channel_state(ChannelState.TC1) - response = await context.call_any( + response = await ctx.call_any( ThpQrCodeSecret(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) - await _handle_response(ctx, response) + await _handle_credential_request_or_end_request(ctx, response) -async def handle_nfc_unidirectional_tag( +async def _handle_nfc_unidirectional_tag( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpNfcUnidirectionalTag.is_type_of(message) _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) ctx.channel.set_channel_state(ChannelState.TC1) - response = await context.call_any( + response = await ctx.call_any( ThpNfcUnideirectionalSecret(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) - await _handle_response(ctx, response) + await _handle_credential_request_or_end_request(ctx, response) -async def handle_credential_request( +async def _handle_credential_request( ctx: PairingContext, message: protobuf.MessageType -) -> None: +) -> None | ThpEndResponse: assert ThpCredentialRequest.is_type_of(message) _check_state(ctx, ChannelState.TC1) - response = await context.call_any( + response = await ctx.call_any( ThpCredentialResponse(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) - await _handle_response(ctx, response) + await _handle_credential_request_or_end_request(ctx, response) -async def handle_end_request( +async def _handle_end_request( ctx: PairingContext, message: protobuf.MessageType ) -> ThpEndResponse: assert ThpEndRequest.is_type_of(message) @@ -163,12 +172,13 @@ def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: return method in ctx.channel.selected_pairing_methods -async def _handle_response( +async def _handle_credential_request_or_end_request( ctx: PairingContext, response: protobuf.MessageType | None -) -> None: - if response is None: - raise Exception("Something is not ok") - if response.MESSAGE_WIRE_TYPE is None: - raise Exception("Something is not ok") - handler = pairing_context.get_handler(response.MESSAGE_WIRE_TYPE) - await handler(ctx, response) +) -> None | ThpEndResponse: + if ThpCredentialRequest.is_type_of(response): + return await _handle_credential_request(ctx, response) + if ThpEndRequest.is_type_of(response): + return await _handle_end_request(ctx, response) + raise UnexpectedMessage( + "Received message is not credential request or end request." + ) diff --git a/core/src/trezor/wire/thp/handler_provider.py b/core/src/trezor/wire/thp/handler_provider.py index 7cb57befd..1f9f6d99c 100644 --- a/core/src/trezor/wire/thp/handler_provider.py +++ b/core/src/trezor/wire/thp/handler_provider.py @@ -1,8 +1,6 @@ from typing import TYPE_CHECKING from trezor import protobuf -from trezor.enums import MessageType -from trezor.wire.thp.thp_session import ThpError from apps.thp import create_session @@ -11,16 +9,7 @@ if TYPE_CHECKING: pass -from apps.thp.pairing import ( - handle_code_entry_challenge, - handle_code_entry_cpace, - handle_code_entry_tag, - handle_credential_request, - handle_end_request, - handle_nfc_unidirectional_tag, - handle_pairing_request, - handle_qr_code_tag, -) +from apps.thp.pairing import handle_pairing_request def get_handler_for_handshake( @@ -31,22 +20,5 @@ def get_handler_for_handshake( def get_handler_for_pairing( messageType: int, -) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: - if TYPE_CHECKING: - assert isinstance(messageType, MessageType) - handler = handlers.get(messageType) - if handler is None: - raise ThpError("Pairing handler for this message is not available!") - return handler - - -handlers = { - MessageType.ThpStartPairingRequest: handle_pairing_request, - MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge, - MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace, - MessageType.ThpCodeEntryTag: handle_code_entry_tag, - MessageType.ThpQrCodeTag: handle_qr_code_tag, - MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag, - MessageType.ThpCredentialRequest: handle_credential_request, - MessageType.ThpEndRequest: handle_end_request, -} +) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType | None]]: + return handle_pairing_request diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 0f4722b35..633cc32f3 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -111,6 +111,23 @@ class PairingContext(Context): async def write(self, msg: protobuf.MessageType) -> None: return await self.channel.write(msg) + async def call( + self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType] + ) -> protobuf.MessageType: + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await self.write(msg) + del msg + + return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + async def call_any( + self, msg: protobuf.MessageType, *expected_types: int + ) -> protobuf.MessageType: + await self.write(msg) + del msg + return await self.read(expected_types) + def _find_handler_placeholder( messageType: int,