diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 0b40e835b..982b0503a 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -28,7 +28,7 @@ from trezor.wire.thp.thp_session import ThpError async def handle_pairing_request( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpStartPairingRequest.is_type_of(message) if __debug__: @@ -52,11 +52,14 @@ async def handle_pairing_request( return await _handle_qr_code_tag(ctx, response) if ThpNfcUnidirectionalTag.is_type_of(response): return await _handle_nfc_unidirectional_tag(ctx, response) + raise Exception( + "TODO Change this exception message and type. This exception should result in channel destruction." + ) async def _handle_code_entry_challenge( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpCodeEntryChallenge.is_type_of(message) _check_state(ctx, ChannelState.TP2) @@ -73,11 +76,14 @@ async def _handle_code_entry_challenge( return await _handle_qr_code_tag(ctx, response) if ThpNfcUnidirectionalTag.is_type_of(response): return await _handle_nfc_unidirectional_tag(ctx, response) + raise Exception( + "TODO Change this exception message and type. This exception should result in channel destruction." + ) async def _handle_code_entry_cpace( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpCodeEntryCpaceHost.is_type_of(message) _check_state(ctx, ChannelState.TP3) @@ -89,54 +95,43 @@ async def _handle_code_entry_cpace( async def _handle_code_entry_tag( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpCodeEntryTag.is_type_of(message) - - _check_state(ctx, ChannelState.TP4) - ctx.channel.set_channel_state(ChannelState.TC1) - response = await ctx.call_any( - ThpCodeEntrySecret(), - MessageType.ThpCredentialRequest, - MessageType.ThpEndRequest, + return await _handle_tag_message( + ctx, + expected_state=ChannelState.TP4, + used_method=ThpPairingMethod.PairingMethod_CodeEntry, + msg=ThpCodeEntrySecret(), ) - await _handle_credential_request_or_end_request(ctx, response) async def _handle_qr_code_tag( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpQrCodeTag.is_type_of(message) - - _check_state(ctx, ChannelState.TP3) - _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) - ctx.channel.set_channel_state(ChannelState.TC1) - response = await ctx.call_any( - ThpQrCodeSecret(), - MessageType.ThpCredentialRequest, - MessageType.ThpEndRequest, + return await _handle_tag_message( + ctx, + expected_state=ChannelState.TP3, + used_method=ThpPairingMethod.PairingMethod_QrCode, + msg=ThpQrCodeSecret(), ) - await _handle_credential_request_or_end_request(ctx, response) async def _handle_nfc_unidirectional_tag( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpNfcUnidirectionalTag.is_type_of(message) - - _check_state(ctx, ChannelState.TP3) - _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) - ctx.channel.set_channel_state(ChannelState.TC1) - response = await ctx.call_any( - ThpNfcUnideirectionalSecret(), - MessageType.ThpCredentialRequest, - MessageType.ThpEndRequest, + return await _handle_tag_message( + ctx, + expected_state=ChannelState.TP3, + used_method=ThpPairingMethod.PairingMethod_NFC_Unidirectional, + msg=ThpNfcUnideirectionalSecret(), ) - await _handle_credential_request_or_end_request(ctx, response) async def _handle_credential_request( ctx: PairingContext, message: protobuf.MessageType -) -> None | ThpEndResponse: +) -> ThpEndResponse: assert ThpCredentialRequest.is_type_of(message) _check_state(ctx, ChannelState.TC1) @@ -145,7 +140,7 @@ async def _handle_credential_request( MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) - await _handle_credential_request_or_end_request(ctx, response) + return await _handle_credential_request_or_end_request(ctx, response) async def _handle_end_request( @@ -158,6 +153,23 @@ async def _handle_end_request( return ThpEndResponse() +async def _handle_tag_message( + ctx: PairingContext, + expected_state: ChannelState, + used_method: ThpPairingMethod, + msg: protobuf.MessageType, +) -> ThpEndResponse: + _check_state(ctx, expected_state) + _check_method_is_allowed(ctx, used_method) + ctx.channel.set_channel_state(ChannelState.TC1) + response = await ctx.call_any( + msg, + MessageType.ThpCredentialRequest, + MessageType.ThpEndRequest, + ) + return await _handle_credential_request_or_end_request(ctx, response) + + def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None: if expected_state is not ctx.channel.get_channel_state(): raise UnexpectedMessage("Unexpected message") @@ -174,7 +186,7 @@ def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: async def _handle_credential_request_or_end_request( ctx: PairingContext, response: protobuf.MessageType | None -) -> None | ThpEndResponse: +) -> ThpEndResponse: if ThpCredentialRequest.is_type_of(response): return await _handle_credential_request(ctx, response) if ThpEndRequest.is_type_of(response):