Connect pairing handlers with appropriate messages

M1nd3r/thp5
M1nd3r 1 month ago
parent 79ff7f09ba
commit ed7f304487

@ -7,6 +7,10 @@ from trezor.messages import (
ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret,
ThpCodeEntryTag,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcUnideirectionalSecret,
ThpNfcUnidirectionalTag,
ThpQrCodeSecret,
@ -25,6 +29,7 @@ async def handle_pairing_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCommitment | None:
assert ThpStartPairingRequest.is_type_of(message)
if __debug__:
log.debug(__name__, "handle_pairing_request")
_check_state(channel, ChannelState.TP1)
@ -36,15 +41,19 @@ async def handle_pairing_request(
async def handle_code_entry_challenge(
channel: Channel, message: ThpCodeEntryChallenge
channel: Channel, message: protobuf.MessageType
) -> None:
assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(channel, ChannelState.TP2)
channel.set_channel_state(ChannelState.TP3)
async def handle_code_entry_cpace(
channel: Channel, message: ThpCodeEntryCpaceHost
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCpaceTrezor:
assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry)
channel.set_channel_state(ChannelState.TP4)
@ -52,16 +61,20 @@ async def handle_code_entry_cpace(
async def handle_code_entry_tag(
channel: Channel, message: ThpCodeEntryTag
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntrySecret:
assert ThpCodeEntryTag.is_type_of(message)
_check_state(channel, ChannelState.TP4)
channel.set_channel_state(ChannelState.TC1)
return ThpCodeEntrySecret()
async def handle_qr_code_tag(
channel: Channel, message: ThpQrCodeTag
channel: Channel, message: protobuf.MessageType
) -> ThpQrCodeSecret:
assert ThpQrCodeTag.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode)
channel.set_channel_state(ChannelState.TC1)
@ -69,14 +82,35 @@ async def handle_qr_code_tag(
async def handle_nfc_unidirectional_tag(
channel: Channel, message: ThpNfcUnidirectionalTag
channel: Channel, message: protobuf.MessageType
) -> ThpNfcUnideirectionalSecret:
assert ThpNfcUnidirectionalTag.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
channel.set_channel_state(ChannelState.TC1)
return ThpNfcUnideirectionalSecret()
async def handle_credential_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCredentialResponse:
assert ThpCredentialRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
return ThpCredentialResponse()
async def handle_end_request(
channel: Channel, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpEndRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
def _check_state(channel: Channel, expected_state: ChannelState) -> None:
if expected_state is not channel.get_channel_state():
raise UnexpectedMessage("Unexpected message")

@ -9,7 +9,16 @@ from trezor.wire.protocol_common import MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType
from trezor.wire.thp.thp_session import ThpError
from apps.thp.pairing import handle_pairing_request
from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
from .channel import Channel
@ -18,6 +27,17 @@ if TYPE_CHECKING:
pass
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}
class PairingContext:
def __init__(self, channel: Channel) -> None:
@ -189,10 +209,12 @@ async def handle_pairing_message(
def get_handler(messageType: int):
if messageType == MessageType.ThpStartPairingRequest:
return handle_pairing_request
else:
raise ThpError("Handler for this method is not implemented yet")
if TYPE_CHECKING:
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator:

Loading…
Cancel
Save