Change pairing process into a workflow

M1nd3r/thp2
M1nd3r 2 months ago
parent 72e482ee97
commit 870f1a3b98

@ -1,5 +1,5 @@
from trezor import log, protobuf from trezor import log, protobuf
from trezor.enums import ThpPairingMethod from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import ( from trezor.messages import (
ThpCodeEntryChallenge, ThpCodeEntryChallenge,
ThpCodeEntryCommitment, ThpCodeEntryCommitment,
@ -18,110 +18,138 @@ from trezor.messages import (
ThpQrCodeTag, ThpQrCodeTag,
ThpStartPairingRequest, ThpStartPairingRequest,
) )
from trezor.wire import context
from trezor.wire.errors import UnexpectedMessage from trezor.wire.errors import UnexpectedMessage
from trezor.wire.thp import ChannelState from trezor.wire.thp import ChannelState
from trezor.wire.thp.channel import Channel from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_session import ThpError from trezor.wire.thp.thp_session import ThpError
# TODO implement the following handlers # TODO implement the following handlers
async def handle_pairing_request( async def handle_pairing_request(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpCodeEntryCommitment | ThpPairingPreparationsFinished: ) -> None:
assert ThpStartPairingRequest.is_type_of(message) assert ThpStartPairingRequest.is_type_of(message)
if __debug__: if __debug__:
log.debug(__name__, "handle_pairing_request") log.debug(__name__, "handle_pairing_request")
_check_state(channel, ChannelState.TP1) _check_state(ctx, ChannelState.TP1)
if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry):
channel.set_channel_state(ChannelState.TP2) if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
return ThpCodeEntryCommitment() ctx.channel.set_channel_state(ChannelState.TP2)
channel.set_channel_state(ChannelState.TP3) await context.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
return ThpPairingPreparationsFinished()
ctx.channel.set_channel_state(ChannelState.TP3)
await context.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag,
)
async def handle_code_entry_challenge( async def handle_code_entry_challenge(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpPairingPreparationsFinished: ) -> None:
assert ThpCodeEntryChallenge.is_type_of(message) assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(channel, ChannelState.TP2) _check_state(ctx, ChannelState.TP2)
channel.set_channel_state(ChannelState.TP3) ctx.channel.set_channel_state(ChannelState.TP3)
return ThpPairingPreparationsFinished() await context.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost,
MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag,
)
async def handle_code_entry_cpace( async def handle_code_entry_cpace(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpCodeEntryCpaceTrezor: ) -> None:
assert ThpCodeEntryCpaceHost.is_type_of(message) assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(channel, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
channel.set_channel_state(ChannelState.TP4) ctx.channel.set_channel_state(ChannelState.TP4)
return ThpCodeEntryCpaceTrezor() await context.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
async def handle_code_entry_tag( async def handle_code_entry_tag(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpCodeEntrySecret: ) -> None:
assert ThpCodeEntryTag.is_type_of(message) assert ThpCodeEntryTag.is_type_of(message)
_check_state(channel, ChannelState.TP4) _check_state(ctx, ChannelState.TP4)
channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
return ThpCodeEntrySecret() await context.call_any(
ThpCodeEntrySecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_qr_code_tag( async def handle_qr_code_tag(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpQrCodeSecret: ) -> None:
assert ThpQrCodeTag.is_type_of(message) assert ThpQrCodeTag.is_type_of(message)
_check_state(channel, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_QrCode)
channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
return ThpQrCodeSecret() await context.call_any(
ThpQrCodeSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_nfc_unidirectional_tag( async def handle_nfc_unidirectional_tag(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpNfcUnideirectionalSecret: ) -> None:
assert ThpNfcUnidirectionalTag.is_type_of(message) assert ThpNfcUnidirectionalTag.is_type_of(message)
_check_state(channel, ChannelState.TP3) _check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional) _check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
channel.set_channel_state(ChannelState.TC1) ctx.channel.set_channel_state(ChannelState.TC1)
return ThpNfcUnideirectionalSecret() await context.call_any(
ThpNfcUnideirectionalSecret(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_credential_request( async def handle_credential_request(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpCredentialResponse: ) -> None:
assert ThpCredentialRequest.is_type_of(message) assert ThpCredentialRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1) _check_state(ctx, ChannelState.TC1)
return ThpCredentialResponse() await context.call_any(
ThpCredentialResponse(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
async def handle_end_request( async def handle_end_request(
channel: Channel, message: protobuf.MessageType ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse: ) -> ThpEndResponse:
assert ThpEndRequest.is_type_of(message) assert ThpEndRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1) _check_state(ctx, ChannelState.TC1)
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT) ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse() return ThpEndResponse()
def _check_state(channel: Channel, expected_state: ChannelState) -> None: def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not channel.get_channel_state(): if expected_state is not ctx.channel.get_channel_state():
raise UnexpectedMessage("Unexpected message") raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(channel: Channel, method: ThpPairingMethod) -> None: def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(channel, method): if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method") raise ThpError("Unexpected pairing method")
def _is_method_included(channel: Channel, method: ThpPairingMethod) -> bool: def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in channel.selected_pairing_methods return method in ctx.channel.selected_pairing_methods

@ -14,7 +14,6 @@ from trezor.messages import (
) )
from trezor.wire import message_handler from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.handler_provider import get_handler
from ..protocol_common import Context, MessageWithType from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto from . import ChannelState, SessionState, checksum, crypto
@ -375,10 +374,15 @@ class Channel(Context):
) )
async def _handle_pairing(self, message_length: int) -> None: async def _handle_pairing(self, message_length: int) -> None:
from trezor.wire.thp.handler_provider import get_handler_for_pairing
from . import pairing_context
from .pairing_context import PairingContext from .pairing_context import PairingContext
if self.connection_context is None: if self.connection_context is None:
self.connection_context = PairingContext(self) self.connection_context = PairingContext(self)
pairing_context.get_handler = get_handler_for_pairing # noqa
loop.schedule(self.connection_context.handle()) loop.schedule(self.connection_context.handle())
self._decrypt_buffer(message_length) self._decrypt_buffer(message_length)
@ -426,8 +430,9 @@ class Channel(Context):
"This message cannot be handled by channel itself. It must be send to allocated session." "This message cannot be handled by channel itself. It must be send to allocated session."
) )
# TODO handle other messages than CreateNewSession # TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_handshake
handler = get_handler(message) handler = get_handler_for_handshake(message)
task = handler(self, message) task = handler(self, message)
response_message = await task response_message = await task
# TODO handle # TODO handle

@ -1,6 +1,8 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from trezor import protobuf from trezor import protobuf
from trezor.enums import MessageType
from trezor.wire.thp.thp_session import ThpError
from apps.thp import create_session from apps.thp import create_session
@ -9,8 +11,42 @@ if TYPE_CHECKING:
pass pass
from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
def get_handler(
def get_handler_for_handshake(
msg: protobuf.MessageType, msg: protobuf.MessageType,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: ) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
return create_session.create_new_session return create_session.create_new_session
def get_handler_for_pairing(
messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
if TYPE_CHECKING:
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}

@ -1,46 +1,28 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor import log, loop, protobuf, workflow from trezor import log, loop, protobuf, workflow
from trezor.enums import MessageType from trezor.wire import context, message_handler, protocol_common
from trezor.wire import message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import MessageWithType from trezor.wire.protocol_common import Context, MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType from trezor.wire.thp.session_context import UnexpectedMessageWithType
from trezor.wire.thp.thp_session import ThpError
from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
from .channel import Channel from .channel import Channel
if TYPE_CHECKING: if TYPE_CHECKING:
from typing import Container, Generator # pyright:ignore[reportShadowedImports] from typing import ( # pyright:ignore[reportShadowedImports]
Any,
Callable,
Container,
Coroutine,
)
pass pass
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}
class PairingContext(Context):
class PairingContext:
def __init__(self, channel: Channel) -> None: def __init__(self, channel: Channel) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel = channel self.channel = channel
self.incoming_message = loop.chan() self.incoming_message = loop.chan()
@ -115,7 +97,9 @@ class PairingContext:
str(expected_types), str(expected_types),
exp_type, exp_type,
) )
message: MessageWithType = await self.incoming_message.take() message: MessageWithType = await self.incoming_message.take()
if message.type not in expected_types: if message.type not in expected_types:
raise UnexpectedMessageWithType(message) raise UnexpectedMessageWithType(message)
@ -128,6 +112,15 @@ class PairingContext:
return await self.channel.write(msg) return await self.channel.write(msg)
def _find_handler_placeholder(
messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
raise Exception()
get_handler = _find_handler_placeholder
async def handle_pairing_message( async def handle_pairing_message(
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool
) -> protocol_common.MessageWithType | None: ) -> protocol_common.MessageWithType | None:
@ -159,7 +152,7 @@ async def handle_pairing_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type) req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task. # Create the handler task.
task = handler(ctx.channel, req_msg) task = handler(ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire # Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a # communication inside, but it should eventually return a
@ -168,7 +161,7 @@ async def handle_pairing_message(
if use_workflow: if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent # Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down. # workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task)) res_msg = await workflow.spawn(context.with_context(ctx, task))
pass # TODO pass # TODO
else: else:
# For debug messages, ignore workflow processing and just await # For debug messages, ignore workflow processing and just await
@ -206,43 +199,3 @@ async def handle_pairing_message(
# problem bubbles up # problem bubbles up
await ctx.write(res_msg) await ctx.write(res_msg)
return None return None
def get_handler(messageType: int):
if TYPE_CHECKING:
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator:
"""Run a workflow in a particular context.
Stores the context in a closure and installs it into the global variable every time
the closure is resumed, thus making sure that all calls to `wire.context.*` will
work as expected.
"""
global CURRENT_CONTEXT
send_val = None
send_exc = None
while True:
CURRENT_CONTEXT = ctx
try:
if send_exc is not None:
res = workflow.throw(send_exc)
else:
res = workflow.send(send_val)
except StopIteration as st:
return st.value
finally:
CURRENT_CONTEXT = None
try:
send_val = yield res
except BaseException as e:
send_exc = e
else:
send_exc = None

Loading…
Cancel
Save