From 38fd91c79b24adf180a63fc215ab75b3f409116e Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Thu, 28 Mar 2024 16:42:03 +0100 Subject: [PATCH] Improve SessionContext's handling of expected types --- core/src/trezor/wire/thp/session_context.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index bac7dd6b6..c5572ca0e 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] + from storage import cache_thp from storage.cache_thp import SessionThpCache from trezor import loop, protobuf @@ -7,6 +9,9 @@ from ..protocol_common import Context, MessageWithType from . import SessionState from .channel import Channel +if TYPE_CHECKING: + from typing import Container # pyright: ignore[reportShadowedImports] + class UnexpectedMessageWithType(Exception): """A message was received that is not part of the current workflow. @@ -44,12 +49,19 @@ class SessionContext(Context): print(message) # TODO continue similarly to handle_session function in wire.__init__ - async def read(self, expected_message_type: int) -> protobuf.MessageType: + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + message: MessageWithType = await self.incoming_message.take() - if message.type != expected_message_type: + if message.type not in expected_types: raise UnexpectedMessageWithType(message) - expected_type = protobuf.type_for_wire(message.type) + if expected_type is None: + expected_type = protobuf.type_for_wire(message.type) + return message_handler.wrap_protobuf_load(message.data, expected_type) async def write(self, msg: protobuf.MessageType) -> None: