diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 433d2d74e..8f6027d94 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -20,7 +20,7 @@ from trezor.messages import ( ) from trezor.wire import context from trezor.wire.errors import UnexpectedMessage -from trezor.wire.thp import ChannelState +from trezor.wire.thp import ChannelState, pairing_context from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.thp_session import ThpError @@ -38,14 +38,16 @@ async def handle_pairing_request( if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): ctx.channel.set_channel_state(ChannelState.TP2) - await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) + response = await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) + else: ctx.channel.set_channel_state(ChannelState.TP3) - await context.call_any( + response = await context.call_any( ThpPairingPreparationsFinished(), MessageType.ThpQrCodeTag, MessageType.ThpNfcUnidirectionalTag, ) + await _handle_response(ctx, response) async def handle_code_entry_challenge( @@ -55,12 +57,13 @@ async def handle_code_entry_challenge( _check_state(ctx, ChannelState.TP2) ctx.channel.set_channel_state(ChannelState.TP3) - await context.call_any( + response = await context.call_any( ThpPairingPreparationsFinished(), MessageType.ThpCodeEntryCpaceHost, MessageType.ThpQrCodeTag, MessageType.ThpNfcUnidirectionalTag, ) + await _handle_response(ctx, response) async def handle_code_entry_cpace( @@ -71,7 +74,8 @@ async def handle_code_entry_cpace( _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) ctx.channel.set_channel_state(ChannelState.TP4) - await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) + response = await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) + await _handle_response(ctx, response) async def handle_code_entry_tag( @@ -81,11 +85,12 @@ async def handle_code_entry_tag( _check_state(ctx, ChannelState.TP4) ctx.channel.set_channel_state(ChannelState.TC1) - await context.call_any( + response = await context.call_any( ThpCodeEntrySecret(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) + await _handle_response(ctx, response) async def handle_qr_code_tag( @@ -96,11 +101,12 @@ async def handle_qr_code_tag( _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) ctx.channel.set_channel_state(ChannelState.TC1) - await context.call_any( + response = await context.call_any( ThpQrCodeSecret(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) + await _handle_response(ctx, response) async def handle_nfc_unidirectional_tag( @@ -111,11 +117,12 @@ async def handle_nfc_unidirectional_tag( _check_state(ctx, ChannelState.TP3) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) ctx.channel.set_channel_state(ChannelState.TC1) - await context.call_any( + response = await context.call_any( ThpNfcUnideirectionalSecret(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) + await _handle_response(ctx, response) async def handle_credential_request( @@ -124,11 +131,12 @@ async def handle_credential_request( assert ThpCredentialRequest.is_type_of(message) _check_state(ctx, ChannelState.TC1) - await context.call_any( + response = await context.call_any( ThpCredentialResponse(), MessageType.ThpCredentialRequest, MessageType.ThpEndRequest, ) + await _handle_response(ctx, response) async def handle_end_request( @@ -153,3 +161,14 @@ def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> N def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: return method in ctx.channel.selected_pairing_methods + + +async def _handle_response( + ctx: PairingContext, response: protobuf.MessageType | None +) -> None: + if response is None: + raise Exception("Something is not ok") + if response.MESSAGE_WIRE_TYPE is None: + raise Exception("Something is not ok") + handler = pairing_context.get_handler(response.MESSAGE_WIRE_TYPE) + await handler(ctx, response)