Connect pairing handlers with appropriate messages

M1nd3r/thp5
M1nd3r 2 months ago
parent 79ff7f09ba
commit ed7f304487

@ -7,6 +7,10 @@ from trezor.messages import (
ThpCodeEntryCpaceTrezor, ThpCodeEntryCpaceTrezor,
ThpCodeEntrySecret, ThpCodeEntrySecret,
ThpCodeEntryTag, ThpCodeEntryTag,
ThpCredentialRequest,
ThpCredentialResponse,
ThpEndRequest,
ThpEndResponse,
ThpNfcUnideirectionalSecret, ThpNfcUnideirectionalSecret,
ThpNfcUnidirectionalTag, ThpNfcUnidirectionalTag,
ThpQrCodeSecret, ThpQrCodeSecret,
@ -25,6 +29,7 @@ async def handle_pairing_request(
channel: Channel, message: protobuf.MessageType channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCommitment | None: ) -> ThpCodeEntryCommitment | None:
assert ThpStartPairingRequest.is_type_of(message) assert ThpStartPairingRequest.is_type_of(message)
if __debug__: if __debug__:
log.debug(__name__, "handle_pairing_request") log.debug(__name__, "handle_pairing_request")
_check_state(channel, ChannelState.TP1) _check_state(channel, ChannelState.TP1)
@ -36,15 +41,19 @@ async def handle_pairing_request(
async def handle_code_entry_challenge( async def handle_code_entry_challenge(
channel: Channel, message: ThpCodeEntryChallenge channel: Channel, message: protobuf.MessageType
) -> None: ) -> None:
assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(channel, ChannelState.TP2) _check_state(channel, ChannelState.TP2)
channel.set_channel_state(ChannelState.TP3) channel.set_channel_state(ChannelState.TP3)
async def handle_code_entry_cpace( async def handle_code_entry_cpace(
channel: Channel, message: ThpCodeEntryCpaceHost channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCpaceTrezor: ) -> ThpCodeEntryCpaceTrezor:
assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(channel, ChannelState.TP3) _check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry) _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry)
channel.set_channel_state(ChannelState.TP4) channel.set_channel_state(ChannelState.TP4)
@ -52,16 +61,20 @@ async def handle_code_entry_cpace(
async def handle_code_entry_tag( async def handle_code_entry_tag(
channel: Channel, message: ThpCodeEntryTag channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntrySecret: ) -> ThpCodeEntrySecret:
assert ThpCodeEntryTag.is_type_of(message)
_check_state(channel, ChannelState.TP4) _check_state(channel, ChannelState.TP4)
channel.set_channel_state(ChannelState.TC1) channel.set_channel_state(ChannelState.TC1)
return ThpCodeEntrySecret() return ThpCodeEntrySecret()
async def handle_qr_code_tag( async def handle_qr_code_tag(
channel: Channel, message: ThpQrCodeTag channel: Channel, message: protobuf.MessageType
) -> ThpQrCodeSecret: ) -> ThpQrCodeSecret:
assert ThpQrCodeTag.is_type_of(message)
_check_state(channel, ChannelState.TP3) _check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode) _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode)
channel.set_channel_state(ChannelState.TC1) channel.set_channel_state(ChannelState.TC1)
@ -69,14 +82,35 @@ async def handle_qr_code_tag(
async def handle_nfc_unidirectional_tag( async def handle_nfc_unidirectional_tag(
channel: Channel, message: ThpNfcUnidirectionalTag channel: Channel, message: protobuf.MessageType
) -> ThpNfcUnideirectionalSecret: ) -> ThpNfcUnideirectionalSecret:
assert ThpNfcUnidirectionalTag.is_type_of(message)
_check_state(channel, ChannelState.TP3) _check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional) _check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
channel.set_channel_state(ChannelState.TC1) channel.set_channel_state(ChannelState.TC1)
return ThpNfcUnideirectionalSecret() return ThpNfcUnideirectionalSecret()
async def handle_credential_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCredentialResponse:
assert ThpCredentialRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
return ThpCredentialResponse()
async def handle_end_request(
channel: Channel, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpEndRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
def _check_state(channel: Channel, expected_state: ChannelState) -> None: def _check_state(channel: Channel, expected_state: ChannelState) -> None:
if expected_state is not channel.get_channel_state(): if expected_state is not channel.get_channel_state():
raise UnexpectedMessage("Unexpected message") raise UnexpectedMessage("Unexpected message")

@ -9,7 +9,16 @@ from trezor.wire.protocol_common import MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType from trezor.wire.thp.session_context import UnexpectedMessageWithType
from trezor.wire.thp.thp_session import ThpError from trezor.wire.thp.thp_session import ThpError
from apps.thp.pairing import handle_pairing_request from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
from .channel import Channel from .channel import Channel
@ -18,6 +27,17 @@ if TYPE_CHECKING:
pass pass
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}
class PairingContext: class PairingContext:
def __init__(self, channel: Channel) -> None: def __init__(self, channel: Channel) -> None:
@ -189,10 +209,12 @@ async def handle_pairing_message(
def get_handler(messageType: int): def get_handler(messageType: int):
if messageType == MessageType.ThpStartPairingRequest: if TYPE_CHECKING:
return handle_pairing_request assert isinstance(messageType, MessageType)
else: handler = handlers.get(messageType)
raise ThpError("Handler for this method is not implemented yet") if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator: def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator:

Loading…
Cancel
Save