Remove code duplication in pairing.py

M1nd3r/thp2
M1nd3r 4 weeks ago
parent 8fb45754c6
commit 11bbdc7cc7

@ -28,7 +28,7 @@ from trezor.wire.thp.thp_session import ThpError
async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpStartPairingRequest.is_type_of(message)
if __debug__:
@ -52,11 +52,14 @@ async def handle_pairing_request(
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise Exception(
"TODO Change this exception message and type. This exception should result in channel destruction."
)
async def _handle_code_entry_challenge(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(ctx, ChannelState.TP2)
@ -73,11 +76,14 @@ async def _handle_code_entry_challenge(
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise Exception(
"TODO Change this exception message and type. This exception should result in channel destruction."
)
async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(ctx, ChannelState.TP3)
@ -89,54 +95,43 @@ async def _handle_code_entry_cpace(
async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpCodeEntryTag.is_type_of(message)
_check_state(ctx, ChannelState.TP4)
ctx.channel.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
ThpCodeEntrySecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
return await _handle_tag_message(
ctx,
expected_state=ChannelState.TP4,
used_method=ThpPairingMethod.PairingMethod_CodeEntry,
msg=ThpCodeEntrySecret(),
)
await _handle_credential_request_or_end_request(ctx, response)
async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpQrCodeTag.is_type_of(message)
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode)
ctx.channel.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
ThpQrCodeSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
return await _handle_tag_message(
ctx,
expected_state=ChannelState.TP3,
used_method=ThpPairingMethod.PairingMethod_QrCode,
msg=ThpQrCodeSecret(),
)
await _handle_credential_request_or_end_request(ctx, response)
async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpNfcUnidirectionalTag.is_type_of(message)
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
ctx.channel.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
ThpNfcUnideirectionalSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
return await _handle_tag_message(
ctx,
expected_state=ChannelState.TP3,
used_method=ThpPairingMethod.PairingMethod_NFC_Unidirectional,
msg=ThpNfcUnideirectionalSecret(),
)
await _handle_credential_request_or_end_request(ctx, response)
async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType
) -> None | ThpEndResponse:
) -> ThpEndResponse:
assert ThpCredentialRequest.is_type_of(message)
_check_state(ctx, ChannelState.TC1)
@ -145,7 +140,7 @@ async def _handle_credential_request(
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
await _handle_credential_request_or_end_request(ctx, response)
return await _handle_credential_request_or_end_request(ctx, response)
async def _handle_end_request(
@ -158,6 +153,23 @@ async def _handle_end_request(
return ThpEndResponse()
async def _handle_tag_message(
ctx: PairingContext,
expected_state: ChannelState,
used_method: ThpPairingMethod,
msg: protobuf.MessageType,
) -> ThpEndResponse:
_check_state(ctx, expected_state)
_check_method_is_allowed(ctx, used_method)
ctx.channel.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
msg,
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
return await _handle_credential_request_or_end_request(ctx, response)
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not ctx.channel.get_channel_state():
raise UnexpectedMessage("Unexpected message")
@ -174,7 +186,7 @@ def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
async def _handle_credential_request_or_end_request(
ctx: PairingContext, response: protobuf.MessageType | None
) -> None | ThpEndResponse:
) -> ThpEndResponse:
if ThpCredentialRequest.is_type_of(response):
return await _handle_credential_request(ctx, response)
if ThpEndRequest.is_type_of(response):

Loading…
Cancel
Save