Remove code duplication in pairing.py

M1nd3r/thp2
M1nd3r 2 months ago
parent 8fb45754c6
commit 11bbdc7cc7

@ -28,7 +28,7 @@ from trezor.wire.thp.thp_session import ThpError
async def handle_pairing_request( async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpStartPairingRequest.is_type_of(message) assert ThpStartPairingRequest.is_type_of(message)
if __debug__: if __debug__:
@ -52,11 +52,14 @@ async def handle_pairing_request(
return await _handle_qr_code_tag(ctx, response) return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response): if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response) return await _handle_nfc_unidirectional_tag(ctx, response)
raise Exception(
"TODO Change this exception message and type. This exception should result in channel destruction."
)
async def _handle_code_entry_challenge( async def _handle_code_entry_challenge(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpCodeEntryChallenge.is_type_of(message) assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(ctx, ChannelState.TP2) _check_state(ctx, ChannelState.TP2)
@ -73,11 +76,14 @@ async def _handle_code_entry_challenge(
return await _handle_qr_code_tag(ctx, response) return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response): if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response) return await _handle_nfc_unidirectional_tag(ctx, response)
raise Exception(
"TODO Change this exception message and type. This exception should result in channel destruction."
)
async def _handle_code_entry_cpace( async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpCodeEntryCpaceHost.is_type_of(message) assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
@ -89,54 +95,43 @@ async def _handle_code_entry_cpace(
async def _handle_code_entry_tag( async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpCodeEntryTag.is_type_of(message) assert ThpCodeEntryTag.is_type_of(message)
return await _handle_tag_message(
_check_state(ctx, ChannelState.TP4) ctx,
ctx.channel.set_channel_state(ChannelState.TC1) expected_state=ChannelState.TP4,
response = await ctx.call_any( used_method=ThpPairingMethod.PairingMethod_CodeEntry,
ThpCodeEntrySecret(), msg=ThpCodeEntrySecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
) )
await _handle_credential_request_or_end_request(ctx, response)
async def _handle_qr_code_tag( async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpQrCodeTag.is_type_of(message) assert ThpQrCodeTag.is_type_of(message)
return await _handle_tag_message(
_check_state(ctx, ChannelState.TP3) ctx,
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) expected_state=ChannelState.TP3,
ctx.channel.set_channel_state(ChannelState.TC1) used_method=ThpPairingMethod.PairingMethod_QrCode,
response = await ctx.call_any( msg=ThpQrCodeSecret(),
ThpQrCodeSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
) )
await _handle_credential_request_or_end_request(ctx, response)
async def _handle_nfc_unidirectional_tag( async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpNfcUnidirectionalTag.is_type_of(message) assert ThpNfcUnidirectionalTag.is_type_of(message)
return await _handle_tag_message(
_check_state(ctx, ChannelState.TP3) ctx,
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) expected_state=ChannelState.TP3,
ctx.channel.set_channel_state(ChannelState.TC1) used_method=ThpPairingMethod.PairingMethod_NFC_Unidirectional,
response = await ctx.call_any( msg=ThpNfcUnideirectionalSecret(),
ThpNfcUnideirectionalSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
) )
await _handle_credential_request_or_end_request(ctx, response)
async def _handle_credential_request( async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse: ) -> ThpEndResponse:
assert ThpCredentialRequest.is_type_of(message) assert ThpCredentialRequest.is_type_of(message)
_check_state(ctx, ChannelState.TC1) _check_state(ctx, ChannelState.TC1)
@ -145,7 +140,7 @@ async def _handle_credential_request(
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_credential_request_or_end_request(ctx, response) return await _handle_credential_request_or_end_request(ctx, response)
async def _handle_end_request( async def _handle_end_request(
@ -158,6 +153,23 @@ async def _handle_end_request(
return ThpEndResponse() return ThpEndResponse()
async def _handle_tag_message(
ctx: PairingContext,
expected_state: ChannelState,
used_method: ThpPairingMethod,
msg: protobuf.MessageType,
) -> ThpEndResponse:
_check_state(ctx, expected_state)
_check_method_is_allowed(ctx, used_method)
ctx.channel.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
msg,
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
return await _handle_credential_request_or_end_request(ctx, response)
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None: def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not ctx.channel.get_channel_state(): if expected_state is not ctx.channel.get_channel_state():
raise UnexpectedMessage("Unexpected message") raise UnexpectedMessage("Unexpected message")
@ -174,7 +186,7 @@ def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
async def _handle_credential_request_or_end_request( async def _handle_credential_request_or_end_request(
ctx: PairingContext, response: protobuf.MessageType | None ctx: PairingContext, response: protobuf.MessageType | None
) -> None | ThpEndResponse: ) -> ThpEndResponse:
if ThpCredentialRequest.is_type_of(response): if ThpCredentialRequest.is_type_of(response):
return await _handle_credential_request(ctx, response) return await _handle_credential_request(ctx, response)
if ThpEndRequest.is_type_of(response): if ThpEndRequest.is_type_of(response):

Loading…
Cancel
Save