Change pairing process into a workflow

M1nd3r/thp5
M1nd3r 1 month ago
parent dca9f05921
commit 3a8c4c6330

@ -1,5 +1,5 @@
from trezor import log, protobuf
from trezor.enums import ThpPairingMethod
from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
@ -18,110 +18,138 @@ from trezor.messages import (
ThpQrCodeTag,
ThpStartPairingRequest,
)
from trezor.wire import context
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.thp import ChannelState
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_session import ThpError
# TODO implement the following handlers
async def handle_pairing_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCommitment | ThpPairingPreparationsFinished:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpStartPairingRequest.is_type_of(message)
if __debug__:
log.debug(__name__, "handle_pairing_request")
_check_state(channel, ChannelState.TP1)
if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry):
channel.set_channel_state(ChannelState.TP2)
return ThpCodeEntryCommitment()
channel.set_channel_state(ChannelState.TP3)
return ThpPairingPreparationsFinished()
_check_state(ctx, ChannelState.TP1)
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
ctx.channel.set_channel_state(ChannelState.TP2)
await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
ctx.channel.set_channel_state(ChannelState.TP3)
await context.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag,
)
async def handle_code_entry_challenge(
channel: Channel, message: protobuf.MessageType
) -> ThpPairingPreparationsFinished:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(channel, ChannelState.TP2)
channel.set_channel_state(ChannelState.TP3)
return ThpPairingPreparationsFinished()
_check_state(ctx, ChannelState.TP2)
ctx.channel.set_channel_state(ChannelState.TP3)
await context.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost,
MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag,
)
async def handle_code_entry_cpace(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCpaceTrezor:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry)
channel.set_channel_state(ChannelState.TP4)
return ThpCodeEntryCpaceTrezor()
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
ctx.channel.set_channel_state(ChannelState.TP4)
await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
async def handle_code_entry_tag(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntrySecret:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpCodeEntryTag.is_type_of(message)
_check_state(channel, ChannelState.TP4)
channel.set_channel_state(ChannelState.TC1)
return ThpCodeEntrySecret()
_check_state(ctx, ChannelState.TP4)
ctx.channel.set_channel_state(ChannelState.TC1)
await context.call_any(
ThpCodeEntrySecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_qr_code_tag(
channel: Channel, message: protobuf.MessageType
) -> ThpQrCodeSecret:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpQrCodeTag.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode)
channel.set_channel_state(ChannelState.TC1)
return ThpQrCodeSecret()
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode)
ctx.channel.set_channel_state(ChannelState.TC1)
await context.call_any(
ThpQrCodeSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_nfc_unidirectional_tag(
channel: Channel, message: protobuf.MessageType
) -> ThpNfcUnideirectionalSecret:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpNfcUnidirectionalTag.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
channel.set_channel_state(ChannelState.TC1)
return ThpNfcUnideirectionalSecret()
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
ctx.channel.set_channel_state(ChannelState.TC1)
await context.call_any(
ThpNfcUnideirectionalSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_credential_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCredentialResponse:
ctx: PairingContext, message: protobuf.MessageType
) -> None:
assert ThpCredentialRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
return ThpCredentialResponse()
_check_state(ctx, ChannelState.TC1)
await context.call_any(
ThpCredentialResponse(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_end_request(
channel: Channel, message: protobuf.MessageType
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpEndRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
_check_state(ctx, ChannelState.TC1)
ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
def _check_state(channel: Channel, expected_state: ChannelState) -> None:
if expected_state is not channel.get_channel_state():
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not ctx.channel.get_channel_state():
raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(channel: Channel, method: ThpPairingMethod) -> None:
if not _is_method_included(channel, method):
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method")
def _is_method_included(channel: Channel, method: ThpPairingMethod) -> bool:
return method in channel.selected_pairing_methods
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel.selected_pairing_methods

@ -14,7 +14,6 @@ from trezor.messages import (
)
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.handler_provider import get_handler
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto
@ -375,10 +374,15 @@ class Channel(Context):
)
async def _handle_pairing(self, message_length: int) -> None:
from trezor.wire.thp.handler_provider import get_handler_for_pairing
from . import pairing_context
from .pairing_context import PairingContext
if self.connection_context is None:
self.connection_context = PairingContext(self)
pairing_context.get_handler = get_handler_for_pairing # noqa
loop.schedule(self.connection_context.handle())
self._decrypt_buffer(message_length)
@ -426,8 +430,9 @@ class Channel(Context):
"This message cannot be handled by channel itself. It must be send to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_handshake
handler = get_handler(message)
handler = get_handler_for_handshake(message)
task = handler(self, message)
response_message = await task
# TODO handle

@ -1,6 +1,8 @@
from typing import TYPE_CHECKING
from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire.thp.thp_session import ThpError
from apps.thp import create_session
@ -9,8 +11,42 @@ if TYPE_CHECKING:
pass
from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
def get_handler(
def get_handler_for_handshake(
msg: protobuf.MessageType,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
return create_session.create_new_session
def get_handler_for_pairing(
messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
if TYPE_CHECKING:
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}

@ -1,46 +1,28 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor import log, loop, protobuf, workflow
from trezor.enums import MessageType
from trezor.wire import message_handler, protocol_common
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import MessageWithType
from trezor.wire.protocol_common import Context, MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType
from trezor.wire.thp.thp_session import ThpError
from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
from .channel import Channel
if TYPE_CHECKING:
from typing import Container, Generator # pyright:ignore[reportShadowedImports]
from typing import ( # pyright:ignore[reportShadowedImports]
Any,
Callable,
Container,
Coroutine,
)
pass
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}
class PairingContext:
class PairingContext(Context):
def __init__(self, channel: Channel) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel = channel
self.incoming_message = loop.chan()
@ -115,7 +97,9 @@ class PairingContext:
str(expected_types),
exp_type,
)
message: MessageWithType = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
@ -128,6 +112,15 @@ class PairingContext:
return await self.channel.write(msg)
def _find_handler_placeholder(
messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
raise Exception()
get_handler = _find_handler_placeholder
async def handle_pairing_message(
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool
) -> protocol_common.MessageWithType | None:
@ -159,7 +152,7 @@ async def handle_pairing_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(ctx.channel, req_msg)
task = handler(ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -168,7 +161,7 @@ async def handle_pairing_message(
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task))
res_msg = await workflow.spawn(context.with_context(ctx, task))
pass # TODO
else:
# For debug messages, ignore workflow processing and just await
@ -206,43 +199,3 @@ async def handle_pairing_message(
# problem bubbles up
await ctx.write(res_msg)
return None
def get_handler(messageType: int):
if TYPE_CHECKING:
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator:
"""Run a workflow in a particular context.
Stores the context in a closure and installs it into the global variable every time
the closure is resumed, thus making sure that all calls to `wire.context.*` will
work as expected.
"""
global CURRENT_CONTEXT
send_val = None
send_exc = None
while True:
CURRENT_CONTEXT = ctx
try:
if send_exc is not None:
res = workflow.throw(send_exc)
else:
res = workflow.send(send_val)
except StopIteration as st:
return st.value
finally:
CURRENT_CONTEXT = None
try:
send_val = yield res
except BaseException as e:
send_exc = e
else:
send_exc = None

Loading…
Cancel
Save