feat(core): improve SessionContext's handling of expected types

M1nd3r 2 months ago
parent 714a949919
commit d4622b1b15

@ -1,3 +1,5 @@
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
from storage import cache_thp
from storage.cache_thp import SessionThpCache
from trezor import loop, protobuf
@ -7,14 +9,13 @@ from ..protocol_common import Context, MessageWithType
from . import SessionState
from .channel import Channel
if TYPE_CHECKING:
from typing import Container # pyright: ignore[reportShadowedImports]
class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow.
pass
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
class UnexpectedMessageWithType(Exception):
def __init__(self, msg: MessageWithType) -> None:
super().__init__()
self.msg = msg
@ -44,12 +45,19 @@ class SessionContext(Context):
print(message)
# TODO continue similarly to handle_session function in wire.__init__
async def read(self, expected_message_type: int) -> protobuf.MessageType:
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
message: MessageWithType = await self.incoming_message.take()
if message.type != expected_message_type:
if message.type not in expected_types:
raise UnexpectedMessageWithType(message)
expected_type = protobuf.type_for_wire(message.type)
if expected_type is None:
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:

Loading…
Cancel
Save