feat(core): change pairing process into a workflow

M1nd3r/thp6
M1nd3r 1 month ago
parent 1ad87e6b4f
commit e884aaef51

@ -1,5 +1,5 @@
from trezor import log, protobuf
from trezor.enums import ThpPairingMethod
from trezor.enums import MessageType, ThpPairingMethod
from trezor.messages import (
ThpCodeEntryChallenge,
ThpCodeEntryCommitment,
@ -20,108 +20,177 @@ from trezor.messages import (
)
from trezor.wire.errors import UnexpectedMessage
from trezor.wire.thp import ChannelState
from trezor.wire.thp.channel import Channel
from trezor.wire.thp.pairing_context import PairingContext
from trezor.wire.thp.thp_session import ThpError
# TODO implement the following handlers
async def handle_pairing_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCommitment | ThpPairingPreparationsFinished:
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpStartPairingRequest.is_type_of(message)
if __debug__:
log.debug(__name__, "handle_pairing_request")
_check_state(channel, ChannelState.TP1)
if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry):
channel.set_channel_state(ChannelState.TP2)
return ThpCodeEntryCommitment()
channel.set_channel_state(ChannelState.TP3)
return ThpPairingPreparationsFinished()
_check_state(ctx, ChannelState.TP1)
async def handle_code_entry_challenge(
channel: Channel, message: protobuf.MessageType
) -> ThpPairingPreparationsFinished:
assert ThpCodeEntryChallenge.is_type_of(message)
if _is_method_included(ctx, ThpPairingMethod.PairingMethod_CodeEntry):
ctx.channel.set_channel_state(ChannelState.TP2)
response = await ctx.call(ThpCodeEntryCommitment(), ThpCodeEntryChallenge)
return await _handle_code_entry_challenge(ctx, response)
_check_state(channel, ChannelState.TP2)
channel.set_channel_state(ChannelState.TP3)
return ThpPairingPreparationsFinished()
ctx.channel.set_channel_state(ChannelState.TP3)
response = await ctx.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag,
)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise Exception(
"TODO Change this exception message and type. This exception should result in channel destruction."
)
async def handle_code_entry_cpace(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntryCpaceTrezor:
async def _handle_code_entry_challenge(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpCodeEntryChallenge.is_type_of(message)
_check_state(ctx, ChannelState.TP2)
ctx.channel.set_channel_state(ChannelState.TP3)
response = await ctx.call_any(
ThpPairingPreparationsFinished(),
MessageType.ThpCodeEntryCpaceHost,
MessageType.ThpQrCodeTag,
MessageType.ThpNfcUnidirectionalTag,
)
if ThpCodeEntryCpaceHost.is_type_of(response):
return await _handle_code_entry_cpace(ctx, response)
if ThpQrCodeTag.is_type_of(response):
return await _handle_qr_code_tag(ctx, response)
if ThpNfcUnidirectionalTag.is_type_of(response):
return await _handle_nfc_unidirectional_tag(ctx, response)
raise Exception(
"TODO Change this exception message and type. This exception should result in channel destruction."
)
async def _handle_code_entry_cpace(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpCodeEntryCpaceHost.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry)
channel.set_channel_state(ChannelState.TP4)
return ThpCodeEntryCpaceTrezor()
_check_state(ctx, ChannelState.TP3)
_check_method_is_allowed(ctx, ThpPairingMethod.PairingMethod_CodeEntry)
ctx.channel.set_channel_state(ChannelState.TP4)
response = await ctx.call(ThpCodeEntryCpaceTrezor(), ThpCodeEntryTag)
return await _handle_code_entry_tag(ctx, response)
async def handle_code_entry_tag(
channel: Channel, message: protobuf.MessageType
) -> ThpCodeEntrySecret:
async def _handle_code_entry_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpCodeEntryTag.is_type_of(message)
_check_state(channel, ChannelState.TP4)
channel.set_channel_state(ChannelState.TC1)
return ThpCodeEntrySecret()
return await _handle_tag_message(
ctx,
expected_state=ChannelState.TP4,
used_method=ThpPairingMethod.PairingMethod_CodeEntry,
msg=ThpCodeEntrySecret(),
)
async def handle_qr_code_tag(
channel: Channel, message: protobuf.MessageType
) -> ThpQrCodeSecret:
async def _handle_qr_code_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpQrCodeTag.is_type_of(message)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode)
channel.set_channel_state(ChannelState.TC1)
return ThpQrCodeSecret()
return await _handle_tag_message(
ctx,
expected_state=ChannelState.TP3,
used_method=ThpPairingMethod.PairingMethod_QrCode,
msg=ThpQrCodeSecret(),
)
async def handle_nfc_unidirectional_tag(
channel: Channel, message: protobuf.MessageType
) -> ThpNfcUnideirectionalSecret:
async def _handle_nfc_unidirectional_tag(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpNfcUnidirectionalTag.is_type_of(message)
return await _handle_tag_message(
ctx,
expected_state=ChannelState.TP3,
used_method=ThpPairingMethod.PairingMethod_NFC_Unidirectional,
msg=ThpNfcUnideirectionalSecret(),
)
_check_state(channel, ChannelState.TP3)
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
channel.set_channel_state(ChannelState.TC1)
return ThpNfcUnideirectionalSecret()
async def handle_credential_request(
channel: Channel, message: protobuf.MessageType
) -> ThpCredentialResponse:
async def _handle_credential_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpCredentialRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
return ThpCredentialResponse()
_check_state(ctx, ChannelState.TC1)
response = await ctx.call_any(
ThpCredentialResponse(),
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
return await _handle_credential_request_or_end_request(ctx, response)
async def handle_end_request(
channel: Channel, message: protobuf.MessageType
async def _handle_end_request(
ctx: PairingContext, message: protobuf.MessageType
) -> ThpEndResponse:
assert ThpEndRequest.is_type_of(message)
_check_state(channel, ChannelState.TC1)
channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
_check_state(ctx, ChannelState.TC1)
ctx.channel.set_channel_state(ChannelState.ENCRYPTED_TRANSPORT)
return ThpEndResponse()
def _check_state(channel: Channel, expected_state: ChannelState) -> None:
if expected_state is not channel.get_channel_state():
async def _handle_tag_message(
ctx: PairingContext,
expected_state: ChannelState,
used_method: ThpPairingMethod,
msg: protobuf.MessageType,
) -> ThpEndResponse:
_check_state(ctx, expected_state)
_check_method_is_allowed(ctx, used_method)
ctx.channel.set_channel_state(ChannelState.TC1)
response = await ctx.call_any(
msg,
MessageType.ThpCredentialRequest,
MessageType.ThpEndRequest,
)
return await _handle_credential_request_or_end_request(ctx, response)
def _check_state(ctx: PairingContext, expected_state: ChannelState) -> None:
if expected_state is not ctx.channel.get_channel_state():
raise UnexpectedMessage("Unexpected message")
def _check_method_is_allowed(channel: Channel, method: ThpPairingMethod) -> None:
if not _is_method_included(channel, method):
def _check_method_is_allowed(ctx: PairingContext, method: ThpPairingMethod) -> None:
if not _is_method_included(ctx, method):
raise ThpError("Unexpected pairing method")
def _is_method_included(channel: Channel, method: ThpPairingMethod) -> bool:
return method in channel.selected_pairing_methods
def _is_method_included(ctx: PairingContext, method: ThpPairingMethod) -> bool:
return method in ctx.channel.selected_pairing_methods
async def _handle_credential_request_or_end_request(
ctx: PairingContext, response: protobuf.MessageType | None
) -> ThpEndResponse:
if ThpCredentialRequest.is_type_of(response):
return await _handle_credential_request(ctx, response)
if ThpEndRequest.is_type_of(response):
return await _handle_end_request(ctx, response)
raise UnexpectedMessage(
"Received message is not credential request or end request."
)

@ -14,7 +14,6 @@ from trezor.messages import (
)
from trezor.wire import message_handler
from trezor.wire.thp import ack_handler, thp_messages
from trezor.wire.thp.handler_provider import get_handler
from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum, crypto
@ -375,10 +374,12 @@ class Channel(Context):
)
async def _handle_pairing(self, message_length: int) -> None:
from .pairing_context import PairingContext
if self.connection_context is None:
self.connection_context = PairingContext(self)
loop.schedule(self.connection_context.handle())
self._decrypt_buffer(message_length)
@ -426,8 +427,9 @@ class Channel(Context):
"This message cannot be handled by channel itself. It must be send to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
handler = get_handler(message)
handler = get_handler_for_channel_message(message)
task = handler(self, message)
response_message = await task
# TODO handle

@ -10,7 +10,7 @@ if TYPE_CHECKING:
pass
def get_handler(
def get_handler_for_channel_message(
msg: protobuf.MessageType,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
return create_session.create_new_session

@ -1,46 +1,23 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from trezor import log, loop, protobuf, workflow
from trezor.enums import MessageType
from trezor.wire import message_handler, protocol_common
from trezor.wire import context, message_handler, protocol_common
from trezor.wire.context import UnexpectedMessageWithId
from trezor.wire.errors import ActionCancelled
from trezor.wire.protocol_common import MessageWithType
from trezor.wire.protocol_common import Context, MessageWithType
from trezor.wire.thp.session_context import UnexpectedMessageWithType
from trezor.wire.thp.thp_session import ThpError
from apps.thp.pairing import (
handle_code_entry_challenge,
handle_code_entry_cpace,
handle_code_entry_tag,
handle_credential_request,
handle_end_request,
handle_nfc_unidirectional_tag,
handle_pairing_request,
handle_qr_code_tag,
)
from .channel import Channel
if TYPE_CHECKING:
from typing import Container, Generator # pyright:ignore[reportShadowedImports]
from typing import Container # pyright:ignore[reportShadowedImports]
pass
handlers = {
MessageType.ThpStartPairingRequest: handle_pairing_request,
MessageType.ThpCodeEntryChallenge: handle_code_entry_challenge,
MessageType.ThpCodeEntryCpaceHost: handle_code_entry_cpace,
MessageType.ThpCodeEntryTag: handle_code_entry_tag,
MessageType.ThpQrCodeTag: handle_qr_code_tag,
MessageType.ThpNfcUnidirectionalTag: handle_nfc_unidirectional_tag,
MessageType.ThpCredentialRequest: handle_credential_request,
MessageType.ThpEndRequest: handle_end_request,
}
class PairingContext:
class PairingContext(Context):
def __init__(self, channel: Channel) -> None:
super().__init__(channel.iface, channel.channel_id)
self.channel = channel
self.incoming_message = loop.chan()
@ -73,7 +50,7 @@ class PairingContext:
next_message = None
try:
next_message = await handle_pairing_message(
next_message = await handle_pairing_request_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
@ -115,7 +92,9 @@ class PairingContext:
str(expected_types),
exp_type,
)
message: MessageWithType = await self.incoming_message.take()
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
@ -127,22 +106,31 @@ class PairingContext:
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel.write(msg)
async def call(
self, msg: protobuf.MessageType, expected_type: type[protobuf.MessageType]
) -> protobuf.MessageType:
assert expected_type.MESSAGE_WIRE_TYPE is not None
await self.write(msg)
del msg
async def handle_pairing_message(
return await self.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read(expected_types)
async def handle_pairing_request_message(
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool
) -> protocol_common.MessageWithType | None:
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type. Should not raise.
# TODO register handlers to dict
handler = get_handler(msg.type) # pylint: disable=assignment-from-none
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure.
await ctx.write(message_handler.unexpected_message())
return None
from apps.thp.pairing import handle_pairing_request
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
@ -159,7 +147,7 @@ async def handle_pairing_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(ctx.channel, req_msg)
task = handle_pairing_request(ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -168,7 +156,7 @@ async def handle_pairing_message(
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(with_context(ctx, task))
res_msg = await workflow.spawn(context.with_context(ctx, task))
pass # TODO
else:
# For debug messages, ignore workflow processing and just await
@ -201,48 +189,9 @@ async def handle_pairing_message(
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up
await ctx.write(res_msg)
return None
def get_handler(messageType: int):
if TYPE_CHECKING:
assert isinstance(messageType, MessageType)
handler = handlers.get(messageType)
if handler is None:
raise ThpError("Pairing handler for this message is not available!")
return handler
def with_context(ctx: PairingContext, workflow: loop.Task) -> Generator:
"""Run a workflow in a particular context.
Stores the context in a closure and installs it into the global variable every time
the closure is resumed, thus making sure that all calls to `wire.context.*` will
work as expected.
"""
global CURRENT_CONTEXT
send_val = None
send_exc = None
while True:
CURRENT_CONTEXT = ctx
try:
if send_exc is not None:
res = workflow.throw(send_exc)
else:
res = workflow.send(send_val)
except StopIteration as st:
return st.value
finally:
CURRENT_CONTEXT = None
try:
send_val = yield res
except BaseException as e:
send_exc = e
else:
send_exc = None

Loading…
Cancel
Save