Remove unnecessary abstractions with handlers

M1nd3r/thp2
M1nd3r 4 weeks ago
parent 4777750b2f
commit 8fb45754c6

@ -374,14 +374,11 @@ class Channel(Context):
)
async def _handle_pairing(self, message_length: int) -> None:
from trezor.wire.thp.handler_provider import get_handler_for_pairing
from . import pairing_context
from .pairing_context import PairingContext
if self.connection_context is None:
self.connection_context = PairingContext(self)
pairing_context.get_handler = get_handler_for_pairing # noqa
loop.schedule(self.connection_context.handle())
self._decrypt_buffer(message_length)
@ -430,9 +427,9 @@ class Channel(Context):
"This message cannot be handled by channel itself. It must be send to allocated session."
)
# TODO handle other messages than CreateNewSession
from trezor.wire.thp.handler_provider import get_handler_for_handshake
from trezor.wire.thp.handler_provider import get_handler_for_channel_message
handler = get_handler_for_handshake(message)
handler = get_handler_for_channel_message(message)
task = handler(self, message)
response_message = await task
# TODO handle

@ -9,16 +9,8 @@ if TYPE_CHECKING:
pass
from apps.thp.pairing import handle_pairing_request
def get_handler_for_handshake(
def get_handler_for_channel_message(
msg: protobuf.MessageType,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
return create_session.create_new_session
def get_handler_for_pairing(
messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType | None]]:
return handle_pairing_request

@ -10,12 +10,7 @@ from trezor.wire.thp.session_context import UnexpectedMessageWithType
from .channel import Channel
if TYPE_CHECKING:
from typing import ( # pyright:ignore[reportShadowedImports]
Any,
Callable,
Container,
Coroutine,
)
from typing import Container # pyright:ignore[reportShadowedImports]
pass
@ -55,7 +50,7 @@ class PairingContext(Context):
next_message = None
try:
next_message = await handle_pairing_message(
next_message = await handle_pairing_request_message(
self, message, use_workflow=not is_debug_session
)
except Exception as exc:
@ -129,30 +124,13 @@ class PairingContext(Context):
return await self.read(expected_types)
def _find_handler_placeholder(
messageType: int,
) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]:
raise Exception()
get_handler = _find_handler_placeholder
async def handle_pairing_message(
async def handle_pairing_request_message(
ctx: PairingContext, msg: protocol_common.MessageWithType, use_workflow: bool
) -> protocol_common.MessageWithType | None:
res_msg: protobuf.MessageType | None = None
# We need to find a handler for this message type. Should not raise.
# TODO register handlers to dict
handler = get_handler(msg.type) # pylint: disable=assignment-from-none
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure.
await ctx.write(message_handler.unexpected_message())
return None
from apps.thp.pairing import handle_pairing_request
if msg.type in workflow.ALLOW_WHILE_LOCKED:
workflow.autolock_interrupts_workflow = False
@ -169,7 +147,7 @@ async def handle_pairing_message(
req_msg = message_handler.wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(ctx, req_msg)
task = handle_pairing_request(ctx, req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -211,6 +189,7 @@ async def handle_pairing_message(
else:
log.exception(__name__, exc)
res_msg = message_handler.failure(exc)
if res_msg is not None:
# perform the write outside the big try-except block, so that usb write
# problem bubbles up

Loading…
Cancel
Save