Improve pairing handling

M1nd3r/thp2
M1nd3r 1 month ago
parent 3fc3bbc756
commit 4777750b2f

@ -18,9 +18,8 @@ from trezor.messages import (
ThpQrCodeTag, ThpQrCodeTag,
ThpStartPairingRequest, ThpStartPairingRequest,
) )
from trezor.wire import context
from trezor.wire.errors import UnexpectedMessage from trezor.wire.errors import UnexpectedMessage
from trezor.wire.thp import ChannelState, pairing_context from trezor.wire.thp import ChannelState
from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_session import ThpError from trezor.wire.thp.thp_session import ThpError
@ -29,117 +28,127 @@ from trezor.wire.thp.thp_session import ThpError
async def handle_pairing_request( async def handle_pairing_request(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpStartPairingRequest.is_type_of(message) assert ThpStartPairingRequest.is_type_of(message)
if __debug__: if __debug__:
log.debug(__name__, "handle_pairing_request") log.debug(__name__, "handle_pairing_request")
_check_state(ctx, ChannelState.TP1) _check_state(ctx, ChannelState.TP1)
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry): if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
ctx.channel.set_channel_state(ChannelState.TP2) ctx.channel.set_channel_state(ChannelState.TP2)
response = await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
else: response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
ctx.channel.set_channel_state(ChannelState.TP3) return await _handle_code_entry_challenge(ctx, response)
response = await context.call_any(
ThpPairingPreparationsFinished(), ctx.channel.set_channel_state(ChannelState.TP3)
MessageType.ThpQrCodeTag, response = await ctx.call_any(
MessageType.ThpNfcUnidirectionalTag, ThpPairingPreparationsFinished(),
) MessageType.ThpQrCodeTag,
await _handle_response(ctx, response) MessageType.ThpNfcUnidirectionalTag,
)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
async def handle_code_entry_challenge( async def _handle_code_entry_challenge(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpCodeEntryChallenge.is_type_of(message) assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(ctx, ChannelState.TP2) _check_state(ctx, ChannelState.TP2)
ctx.channel.set_channel_state(ChannelState.TP3) ctx.channel.set_channel_state(ChannelState.TP3)
response = await context.call_any( response = await ctx.call_any(
ThpPairingPreparationsFinished(), ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost, MessageType.ThpCodeEntryCpaceHost,
MessageType.ThpQrCodeTag, MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag, MessageType.ThpNfcUnidirectionalTag,
) )
await _handle_response(ctx, response) if ThpCodeEntryCpaceHost.is_type_of(response):
return await _handle_code_entry_cpace(ctx, response)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
async def handle_code_entry_cpace( async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpCodeEntryCpaceHost.is_type_of(message) assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
ctx.channel.set_channel_state(ChannelState.TP4) ctx.channel.set_channel_state(ChannelState.TP4)
response = await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag) response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
await _handle_response(ctx, response) return await _handle_code_entry_tag(ctx, response)
async def handle_code_entry_tag( async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpCodeEntryTag.is_type_of(message) assert ThpCodeEntryTag.is_type_of(message)
_check_state(ctx, ChannelState.TP4) _check_state(ctx, ChannelState.TP4)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
response = await context.call_any( response = await ctx.call_any(
ThpCodeEntrySecret(), ThpCodeEntrySecret(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response) await _handle_credential_request_or_end_request(ctx, response)
async def handle_qr_code_tag( async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpQrCodeTag.is_type_of(message) assert ThpQrCodeTag.is_type_of(message)
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
response = await context.call_any( response = await ctx.call_any(
ThpQrCodeSecret(), ThpQrCodeSecret(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response) await _handle_credential_request_or_end_request(ctx, response)
async def handle_nfc_unidirectional_tag( async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpNfcUnidirectionalTag.is_type_of(message) assert ThpNfcUnidirectionalTag.is_type_of(message)
_check_state(ctx, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
ctx.channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
response = await context.call_any( response = await ctx.call_any(
ThpNfcUnideirectionalSecret(), ThpNfcUnideirectionalSecret(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response) await _handle_credential_request_or_end_request(ctx, response)
async def handle_credential_request( async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> None: ) -> None | ThpEndResponse:
assert ThpCredentialRequest.is_type_of(message) assert ThpCredentialRequest.is_type_of(message)
_check_state(ctx, ChannelState.TC1) _check_state(ctx, ChannelState.TC1)
response = await context.call_any( response = await ctx.call_any(
ThpCredentialResponse(), ThpCredentialResponse(),
MessageType.ThpCredentialRequest, MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest, MessageType.ThpEndRequest,
) )
await _handle_response(ctx, response) await _handle_credential_request_or_end_request(ctx, response)
async def handle_end_request( async def _handle_end_request(
ctx: PairingContext, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse: ) -> ThpEndResponse:
assert ThpEndRequest.is_type_of(message) assert ThpEndRequest.is_type_of(message)
@ -163,12 +172,13 @@ def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel.selected_pairing_methods return method in ctx.channel.selected_pairing_methods
async def _handle_response( async def _handle_credential_request_or_end_request(
ctx: PairingContext, response: protobuf.MessageType | None ctx: PairingContext, response: protobuf.MessageType | None
) -> None: ) -> None | ThpEndResponse:
if response is None: if ThpCredentialRequest.is_type_of(response):
raise Exception("Something is not ok") return await _handle_credential_request(ctx, response)
if response.MESSAGE_WIRE_TYPE is None: if ThpEndRequest.is_type_of(response):
raise Exception("Something is not ok") return await _handle_end_request(ctx, response)
handler = pairing_context.get_handler(response.MESSAGE_WIRE_TYPE) raise UnexpectedMessage(
await handler(ctx, response) "Received message is not credential request or end request."
)

@ -1,8 +1,6 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import protobuf from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire.thp.thp_session import ThpError
from apps.thp import create_session from apps.thp import create_session
@ -11,16 +9,7 @@ if TYPE_CHECKING:
pass pass
from apps.thp.pairing import ( from apps.thp.pairing import handle_pairing_request
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
def get_handler_for_handshake( def get_handler_for_handshake(
@ -31,22 +20,5 @@ def get_handler_for_handshake(
def get_handler_for_pairing( def get_handler_for_pairing(
messageType: int, messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: ) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType | None]]:
if TYPE_CHECKING: return handle_pairing_request
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}

@ -111,6 +111,23 @@ class PairingContext(Context):
async def write(self, msg: protobuf.MessageType) -> None: async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg) return await self.channel.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read(expected_types)
def _find_handler_placeholder( def _find_handler_placeholder(
messageType: int, messageType: int,

Loading…
Cancel
Save