diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index b3ac3f263..6842b3ea2 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -7,6 +7,10 @@ from trezor.messages import ( ThpCodeEntryCpaceTrezor, ThpCodeEntrySecret, ThpCodeEntryTag, + ThpCredentialRequest, + ThpCredentialResponse, + ThpEndRequest, + ThpEndResponse, ThpNfcUnideirectionalSecret, ThpNfcUnidirectionalTag, ThpQrCodeSecret, @@ -25,6 +29,7 @@ async def handle_pairing_request( channel: Channel, message: protobuf.MessageType ) -> ThpCodeEntryCommitment | None: assert ThpStartPairingRequest.is_type_of(message) + if __debug__: log.debug(__name__, "handle_pairing_request") _check_state(channel, ChannelState.TP1) @@ -36,15 +41,19 @@ async def handle_pairing_request( async def handle_code_entry_challenge( - channel: Channel, message: ThpCodeEntryChallenge + channel: Channel, message: protobuf.MessageType ) -> None: + assert ThpCodeEntryChallenge.is_type_of(message) + _check_state(channel, ChannelState.TP2) channel.set_channel_state(ChannelState.TP3) async def handle_code_entry_cpace( - channel: Channel, message: ThpCodeEntryCpaceHost + channel: Channel, message: protobuf.MessageType ) -> ThpCodeEntryCpaceTrezor: + assert ThpCodeEntryCpaceHost.is_type_of(message) + _check_state(channel, ChannelState.TP3) _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry) channel.set_channel_state(ChannelState.TP4) @@ -52,16 +61,20 @@ async def handle_code_entry_cpace( async def handle_code_entry_tag( - channel: Channel, message: ThpCodeEntryTag + channel: Channel, message: protobuf.MessageType ) -> ThpCodeEntrySecret: + assert ThpCodeEntryTag.is_type_of(message) + _check_state(channel, ChannelState.TP4) channel.set_channel_state(ChannelState.TC1) return ThpCodeEntrySecret() async def handle_qr_code_tag( - channel: Channel, message: ThpQrCodeTag + channel: Channel, message: protobuf.MessageType ) -> ThpQrCodeSecret: + assert ThpQrCodeTag.is_type_of(message) + _check_state(channel, ChannelState.TP3) _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode) channel.set_channel_state(ChannelState.TC1) @@ -69,14 +82,35 @@ async def handle_qr_code_tag( async def handle_nfc_unidirectional_tag( - channel: Channel, message: ThpNfcUnidirectionalTag + channel: Channel, message: protobuf.MessageType ) -> ThpNfcUnideirectionalSecret: + assert ThpNfcUnidirectionalTag.is_type_of(message) + _check_state(channel, ChannelState.TP3) _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional) channel.set_channel_state(ChannelState.TC1) return ThpNfcUnideirectionalSecret() +async def handle_credential_request( + channel: Channel, message: protobuf.MessageType +) -> ThpCredentialResponse: + assert ThpCredentialRequest.is_type_of(message) + + _check_state(channel, ChannelState.TC1) + return ThpCredentialResponse() + + +async def handle_end_request( + channel: Channel, message: protobuf.MessageType +) -> ThpEndResponse: + assert ThpEndRequest.is_type_of(message) + + _check_state(channel, ChannelState.TC1) + channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) + return ThpEndResponse() + + def _check_state(channel: Channel, expected_state: ChannelState) -> None: if expected_state is not channel.get_channel_state(): raise UnexpectedMessage("Unexpected message") diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index 5f4a5fa47..a81a95dd7 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -9,7 +9,16 @@ from trezor.wire.protocol_common import MessageWithType from trezor.wire.thp.session_context import UnexpectedMessageWithType from trezor.wire.thp.thp_session import ThpError -from apps.thp.pairing import handle_pairing_request +from apps.thp.pairing import ( + handle_code_entry_challenge, + handle_code_entry_cpace, + handle_code_entry_tag, + handle_credential_request, + handle_end_request, + handle_nfc_unidirectional_tag, + handle_pairing_request, + handle_qr_code_tag, +) from .channel import Channel @@ -18,6 +27,17 @@ if TYPE_CHECKING: pass +handlers = { + MessageType.ThpStartPairingRequest: handle_pairing_request, + MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge, + MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace, + MessageType.ThpCodeEntryTag: handle_code_entry_tag, + MessageType.ThpQrCodeTag: handle_qr_code_tag, + MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag, + MessageType.ThpCredentialRequest: handle_credential_request, + MessageType.ThpEndRequest: handle_end_request, +} + class PairingContext: def __init__(self, channel: Channel) -> None: @@ -189,10 +209,12 @@ async def handle_pairing_message( def get_handler(messageType: int): - if messageType == MessageType.ThpStartPairingRequest: - return handle_pairing_request - else: - raise ThpError("Handler for this method is not implemented yet") + if TYPE_CHECKING: + assert isinstance(messageType, MessageType) + handler = handlers.get(messageType) + if handler is None: + raise ThpError("Pairing handler for this message is not available!") + return handler def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator: