Fix broken pairing workflow

M1nd3r/thp2
M1nd3r 2 months ago
parent 58df712f29
commit 3fc3bbc756

@ -20,7 +20,7 @@ from trezor.messages import (
) )
from trezor.wire import context from trezor.wire import context
from trezor.wire.errors import UnexpectedMessage from trezor.wire.errors import UnexpectedMessage
from trezor.wire.thp import ChannelState from trezor.wire.thp import ChannelState, pairing_context
from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_session import ThpError from trezor.wire.thp.thp_session import ThpError
@ -38,14 +38,16 @@ async def handle_pairing_request(
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
ctx.channel.set_channel_state(ChannelState.TP2) ctx.channel.set_channel_state(ChannelState.TP2)
await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge) response = await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
else: else:
ctx.channel.set_channel_state(ChannelState.TP3) ctx.channel.set_channel_state(ChannelState.TP3)
await context.call_any( response = await context.call_any(
ThpPairingPreparationsFinished(), ThpPairingPreparationsFinished(),
MessageType.ThpQrCodeTag, MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag, MessageType.ThpNfcUnidirectionalTag,
) )
await _handle_response(ctx, response)
async def handle_code_entry_challenge( async def handle_code_entry_challenge(
@ -55,12 +57,13 @@ async def handle_code_entry_challenge(
_check_state(ctx, ChannelState.TP2) _check_state(ctx, ChannelState.TP2)
ctx.channel.set_channel_state(ChannelState.TP3) ctx.channel.set_channel_state(ChannelState.TP3)
await context.call_any( response = await context.call_any(
ThpPairingPreparationsFinished(), ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost, MessageType.ThpCodeEntryCpaceHost,
MessageType.ThpQrCodeTag, MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag, MessageType.ThpNfcUnidirectionalTag,
) )
await _handle_response(ctx, response)
async def handle_code_entry_cpace( async def handle_code_entry_cpace(
@ -71,7 +74,8 @@ async def handle_code_entry_cpace(
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
ctx.channel.set_channel_state(ChannelState.TP4) ctx.channel.set_channel_state(ChannelState.TP4)
await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) response = await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
await _handle_response(ctx, response)
async def handle_code_entry_tag( async def handle_code_entry_tag(
@ -81,11 +85,12 @@ async def handle_code_entry_tag(
_check_state(ctx, ChannelState.TP4) _check_state(ctx, ChannelState.TP4)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
await context.call_any( response = await context.call_any(
ThpCodeEntrySecret(), ThpCodeEntrySecret(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response)
async def handle_qr_code_tag( async def handle_qr_code_tag(
@ -96,11 +101,12 @@ async def handle_qr_code_tag(
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
await context.call_any( response = await context.call_any(
ThpQrCodeSecret(), ThpQrCodeSecret(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response)
async def handle_nfc_unidirectional_tag( async def handle_nfc_unidirectional_tag(
@ -111,11 +117,12 @@ async def handle_nfc_unidirectional_tag(
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
await context.call_any( response = await context.call_any(
ThpNfcUnideirectionalSecret(), ThpNfcUnideirectionalSecret(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response)
async def handle_credential_request( async def handle_credential_request(
@ -124,11 +131,12 @@ async def handle_credential_request(
assert ThpCredentialRequest.is_type_of(message) assert ThpCredentialRequest.is_type_of(message)
_check_state(ctx, ChannelState.TC1) _check_state(ctx, ChannelState.TC1)
await context.call_any( response = await context.call_any(
ThpCredentialResponse(), ThpCredentialResponse(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response)
async def handle_end_request( async def handle_end_request(
@ -153,3 +161,14 @@ def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> N
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool: def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel.selected_pairing_methods return method in ctx.channel.selected_pairing_methods
async def _handle_response(
ctx: PairingContext, response: protobuf.MessageType | None
) -> None:
if response is None:
raise Exception("Something is not ok")
if response.MESSAGE_WIRE_TYPE is None:
raise Exception("Something is not ok")
handler = pairing_context.get_handler(response.MESSAGE_WIRE_TYPE)
await handler(ctx, response)

Loading…
Cancel
Save