Improve SessionContext's handling of expected types

M1nd3r 1 month ago
parent da6688f7b1
commit 8c1aca4bbd

@ -1,3 +1,5 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage import cache_thp
from storage.cache_thp import SessionThpCache
from trezor import loop, protobuf
@ -7,6 +9,9 @@ from ..protocol_common import Context, MessageWithType
from . import SessionState
from .channel import Channel
if TYPE_CHECKING:
from typing import Container # pyright: ignore[reportShadowedImports]
class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow.
@ -44,12 +49,19 @@ class SessionContext(Context):
print(message)
# TODO continue similarly to handle_session function in wire.__init__
async def read(self, expected_message_type: int) -> protobuf.MessageType:
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
message: MessageWithType = await self.incoming_message.take()
if message.type != expected_message_type:
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
expected_type = protobuf.type_for_wire(message.type)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:

Loading…
Cancel
Save