parent
6a93c29fe1
commit
74cb074469
@ -0,0 +1,13 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from trezor.wire.thp.channel import Channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.messages import ThpCreateNewSession, ThpNewSession
|
||||
|
||||
|
||||
async def create_new_session(
|
||||
channel: Channel, message: ThpCreateNewSession
|
||||
) -> ThpNewSession:
|
||||
new_session_id: int = channel.create_new_session(message.passphrase)
|
||||
return ThpNewSession(new_session_id=new_session_id)
|
@ -0,0 +1,92 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from trezor.wire.errors import UnexpectedMessage
|
||||
from trezor.wire.thp import ChannelState
|
||||
from trezor.wire.thp.channel import Channel
|
||||
from trezor.wire.thp.thp_session import ThpError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezor.enums import ThpPairingMethod
|
||||
from trezor.messages import (
|
||||
ThpCodeEntryChallenge,
|
||||
ThpCodeEntryCommitment,
|
||||
ThpCodeEntryCpaceHost,
|
||||
ThpCodeEntryCpaceTrezor,
|
||||
ThpCodeEntrySecret,
|
||||
ThpCodeEntryTag,
|
||||
ThpNfcUnideirectionalSecret,
|
||||
ThpNfcUnidirectionalTag,
|
||||
ThpQrCodeSecret,
|
||||
ThpQrCodeTag,
|
||||
ThpStartPairingRequest,
|
||||
)
|
||||
|
||||
|
||||
# TODO implement the following handlers
|
||||
|
||||
|
||||
async def handle_pairing_request(
|
||||
channel: Channel, message: ThpStartPairingRequest
|
||||
) -> ThpCodeEntryCommitment | None:
|
||||
_check_state(channel, ChannelState.TP1)
|
||||
if _is_method_included(channel, ThpPairingMethod.PairingMethod_CodeEntry):
|
||||
channel.set_channel_state(ChannelState.TP2)
|
||||
return ThpCodeEntryCommitment()
|
||||
channel.set_channel_state(ChannelState.TP3)
|
||||
return None
|
||||
|
||||
|
||||
async def handle_code_entry_challenge(
|
||||
channel: Channel, message: ThpCodeEntryChallenge
|
||||
) -> None:
|
||||
_check_state(channel, ChannelState.TP2)
|
||||
channel.set_channel_state(ChannelState.TP3)
|
||||
|
||||
|
||||
async def handle_code_entry_cpace(
|
||||
channel: Channel, message: ThpCodeEntryCpaceHost
|
||||
) -> ThpCodeEntryCpaceTrezor:
|
||||
_check_state(channel, ChannelState.TP3)
|
||||
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_CodeEntry)
|
||||
channel.set_channel_state(ChannelState.TP4)
|
||||
return ThpCodeEntryCpaceTrezor()
|
||||
|
||||
|
||||
async def handle_code_entry_tag(
|
||||
channel: Channel, message: ThpCodeEntryTag
|
||||
) -> ThpCodeEntrySecret:
|
||||
_check_state(channel, ChannelState.TP4)
|
||||
channel.set_channel_state(ChannelState.TC1)
|
||||
return ThpCodeEntrySecret()
|
||||
|
||||
|
||||
async def handle_qr_code_tag(
|
||||
channel: Channel, message: ThpQrCodeTag
|
||||
) -> ThpQrCodeSecret:
|
||||
_check_state(channel, ChannelState.TP3)
|
||||
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_QrCode)
|
||||
channel.set_channel_state(ChannelState.TC1)
|
||||
return ThpQrCodeSecret()
|
||||
|
||||
|
||||
async def handle_nfc_unidirectional_tag(
|
||||
channel: Channel, message: ThpNfcUnidirectionalTag
|
||||
) -> ThpNfcUnideirectionalSecret:
|
||||
_check_state(channel, ChannelState.TP3)
|
||||
_check_method_is_allowed(channel, ThpPairingMethod.PairingMethod_NFC_Unidirectional)
|
||||
channel.set_channel_state(ChannelState.TC1)
|
||||
return ThpNfcUnideirectionalSecret()
|
||||
|
||||
|
||||
def _check_state(channel: Channel, expected_state: ChannelState) -> None:
|
||||
if expected_state is not channel.get_channel_state():
|
||||
raise UnexpectedMessage("Unexpected message")
|
||||
|
||||
|
||||
def _check_method_is_allowed(channel: Channel, method: ThpPairingMethod) -> None:
|
||||
if not _is_method_included(channel, method):
|
||||
raise ThpError("Unexpected pairing method")
|
||||
|
||||
|
||||
def _is_method_included(channel: Channel, method: ThpPairingMethod) -> bool:
|
||||
return method in channel.selected_pairing_methods
|
@ -0,0 +1,199 @@
|
||||
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
||||
|
||||
from trezor import log, loop, protobuf, workflow
|
||||
from trezor.messages import ThpStartPairingRequest
|
||||
from trezor.wire import message_handler, protocol_common
|
||||
from trezor.wire.context import UnexpectedMessageWithId
|
||||
from trezor.wire.errors import ActionCancelled
|
||||
from trezor.wire.protocol_common import MessageWithType
|
||||
from trezor.wire.thp.session_context import UnexpectedMessageWithType
|
||||
|
||||
from apps.thp.pairing import handle_pairing_request
|
||||
|
||||
from .channel import Channel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Container # pyright:ignore[reportShadowedImports]
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class PairingContext:
|
||||
def __init__(self, channel: Channel) -> None:
|
||||
self.channel = channel
|
||||
self.incoming_message = loop.chan()
|
||||
|
||||
async def handle(self, is_debug_session: bool = False) -> None:
|
||||
if __debug__:
|
||||
log.debug(__name__, "handle - start")
|
||||
if is_debug_session:
|
||||
import apps.debug
|
||||
|
||||
apps.debug.DEBUG_CONTEXT = self
|
||||
|
||||
take = self.incoming_message.take()
|
||||
next_message: MessageWithType | None = None
|
||||
|
||||
# Take a mark of modules that are imported at this point, so we can
|
||||
# roll back and un-import any others.
|
||||
# TODO modules = utils.unimport_begin()
|
||||
while True:
|
||||
try:
|
||||
if next_message is None:
|
||||
# If the previous run did not keep an unprocessed message for us,
|
||||
# wait for a new one.
|
||||
try:
|
||||
message: MessageWithType = await take
|
||||
except protocol_common.WireError as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
await self.write(message_handler.failure(e))
|
||||
continue
|
||||
else:
|
||||
# Process the message from previous run.
|
||||
message = next_message
|
||||
next_message = None
|
||||
|
||||
try:
|
||||
next_message = await handle_pairing_message(
|
||||
self, message, use_workflow=not is_debug_session
|
||||
)
|
||||
except Exception as exc:
|
||||
# Log and ignore. The session handler can only exit explicitly in the
|
||||
# following finally block.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
finally:
|
||||
if not __debug__ or not is_debug_session:
|
||||
# Unload modules imported by the workflow. Should not raise.
|
||||
# This is not done for the debug session because the snapshot taken
|
||||
# in a debug session would clear modules which are in use by the
|
||||
# workflow running on wire.
|
||||
# TODO utils.unimport_end(modules)
|
||||
|
||||
if next_message is None:
|
||||
|
||||
# Shut down the loop if there is no next message waiting.
|
||||
# Let the session be restarted from `main`.
|
||||
loop.clear()
|
||||
return # pylint: disable=lost-exception
|
||||
|
||||
except Exception as exc:
|
||||
# Log and try again. The session handler can only exit explicitly via
|
||||
# loop.clear() above.
|
||||
if __debug__:
|
||||
log.exception(__name__, exc)
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
exp_type: str = str(expected_type)
|
||||
if expected_type is not None:
|
||||
exp_type = expected_type.MESSAGE_NAME
|
||||
log.debug(
|
||||
__name__,
|
||||
"Read - with expected types %s and expected type %s",
|
||||
str(expected_types),
|
||||
exp_type,
|
||||
)
|
||||
message: MessageWithType = await self.incoming_message.take()
|
||||
if message.type not in expected_types:
|
||||
raise UnexpectedMessageWithType(message)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(message.type)
|
||||
|
||||
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel.write(msg)
|
||||
|
||||
|
||||
async def handle_pairing_message(
|
||||
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool
|
||||
) -> protocol_common.MessageWithType | None:
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type. Should not raise.
|
||||
# TODO register handlers to dict
|
||||
handler = get_handler(msg.type) # pylint: disable=assignment-from-none
|
||||
|
||||
if handler is None:
|
||||
# If no handler is found, we can skip decoding and directly
|
||||
# respond with failure.
|
||||
await ctx.write(message_handler.unexpected_message())
|
||||
return None
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
if TYPE_CHECKING:
|
||||
assert isinstance(req_msg, ThpStartPairingRequest) # TODO remove
|
||||
task = handler(ctx.channel, req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
if use_workflow:
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
# res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||
pass # TODO
|
||||
else:
|
||||
# For debug messages, ignore workflow processing and just await
|
||||
# results of the handler.
|
||||
res_msg = await task
|
||||
|
||||
except UnexpectedMessageWithId as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
# Initialize and Cancel.
|
||||
return exc.msg
|
||||
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = message_handler.failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
return None
|
||||
|
||||
|
||||
def get_handler(messageType: int):
|
||||
return handle_pairing_request
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue