|
|
|
@ -1,3 +1,5 @@
|
|
|
|
|
from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports]
|
|
|
|
|
|
|
|
|
|
from storage import cache_thp
|
|
|
|
|
from storage.cache_thp import SessionThpCache
|
|
|
|
|
from trezor import loop, protobuf
|
|
|
|
@ -7,6 +9,9 @@ from ..protocol_common import Context, MessageWithType
|
|
|
|
|
from . import SessionState
|
|
|
|
|
from .channel import Channel
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from typing import Container # pyright: ignore[reportShadowedImports]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnexpectedMessageWithType(Exception):
|
|
|
|
|
"""A message was received that is not part of the current workflow.
|
|
|
|
@ -44,12 +49,19 @@ class SessionContext(Context):
|
|
|
|
|
print(message)
|
|
|
|
|
# TODO continue similarly to handle_session function in wire.__init__
|
|
|
|
|
|
|
|
|
|
async def read(self, expected_message_type: int) -> protobuf.MessageType:
|
|
|
|
|
async def read(
|
|
|
|
|
self,
|
|
|
|
|
expected_types: Container[int],
|
|
|
|
|
expected_type: type[protobuf.MessageType] | None = None,
|
|
|
|
|
) -> protobuf.MessageType:
|
|
|
|
|
|
|
|
|
|
message: MessageWithType = await self.incoming_message.take()
|
|
|
|
|
if message.type != expected_message_type:
|
|
|
|
|
if message.type not in expected_types:
|
|
|
|
|
raise UnexpectedMessageWithType(message)
|
|
|
|
|
|
|
|
|
|
expected_type = protobuf.type_for_wire(message.type)
|
|
|
|
|
if expected_type is None:
|
|
|
|
|
expected_type = protobuf.type_for_wire(message.type)
|
|
|
|
|
|
|
|
|
|
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
|
|
|
|
|
|
|
|
|
async def write(self, msg: protobuf.MessageType) -> None:
|
|
|
|
|