|
|
|
@ -1,33 +1,60 @@
|
|
|
|
|
from storage import cache_thp
|
|
|
|
|
from storage.cache_thp import SessionThpCache
|
|
|
|
|
from trezor import protobuf
|
|
|
|
|
from trezor import loop, protobuf
|
|
|
|
|
from trezor.wire import message_handler
|
|
|
|
|
|
|
|
|
|
from ..protocol_common import Context
|
|
|
|
|
from ..protocol_common import Context, MessageWithType
|
|
|
|
|
from . import SessionState
|
|
|
|
|
from .channel_context import ChannelContext
|
|
|
|
|
from .channel import Channel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class UnexpectedMessageWithType(Exception):
|
|
|
|
|
"""A message was received that is not part of the current workflow.
|
|
|
|
|
|
|
|
|
|
Utility exception to inform the session handler that the current workflow
|
|
|
|
|
should be aborted and a new one started as if `msg` was the first message.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, msg: MessageWithType) -> None:
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.msg = msg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SessionContext(Context):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, channel_context: ChannelContext, session_cache: SessionThpCache
|
|
|
|
|
) -> None:
|
|
|
|
|
if channel_context.channel_id != session_cache.channel_id:
|
|
|
|
|
def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None:
|
|
|
|
|
if channel.channel_id != session_cache.channel_id:
|
|
|
|
|
raise Exception(
|
|
|
|
|
"The session has different channel id than the provided channel context!"
|
|
|
|
|
)
|
|
|
|
|
super().__init__(channel_context.iface, channel_context.channel_id)
|
|
|
|
|
self.channel_context = channel_context
|
|
|
|
|
super().__init__(channel.iface, channel.channel_id)
|
|
|
|
|
self.channel_context = channel
|
|
|
|
|
self.session_cache = session_cache
|
|
|
|
|
self.session_id = int.from_bytes(session_cache.session_id, "big")
|
|
|
|
|
|
|
|
|
|
async def write(self, msg: protobuf.MessageType) -> None:
|
|
|
|
|
return await self.channel_context.write(msg, self.session_id)
|
|
|
|
|
self.incoming_message = loop.chan()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext":
|
|
|
|
|
def create_new_session(cls, channel_context: Channel) -> "SessionContext":
|
|
|
|
|
session_cache = cache_thp.get_new_session(channel_context.channel_cache)
|
|
|
|
|
return cls(channel_context, session_cache)
|
|
|
|
|
|
|
|
|
|
async def handle(self) -> None:
|
|
|
|
|
take = self.incoming_message.take()
|
|
|
|
|
while True:
|
|
|
|
|
message = await take
|
|
|
|
|
print(message)
|
|
|
|
|
# TODO continue similarly to handle_session function in wire.__init__
|
|
|
|
|
|
|
|
|
|
async def read(self, expected_message_type: int) -> protobuf.MessageType:
|
|
|
|
|
message: MessageWithType = await self.incoming_message.take()
|
|
|
|
|
if message.type != expected_message_type:
|
|
|
|
|
raise UnexpectedMessageWithType(message)
|
|
|
|
|
|
|
|
|
|
expected_type = protobuf.type_for_wire(message.type)
|
|
|
|
|
return message_handler.wrap_protobuf_load(message.data, expected_type)
|
|
|
|
|
|
|
|
|
|
async def write(self, msg: protobuf.MessageType) -> None:
|
|
|
|
|
return await self.channel_context.write(msg, self.session_id)
|
|
|
|
|
|
|
|
|
|
# ACCESS TO SESSION DATA
|
|
|
|
|
|
|
|
|
|
def get_session_state(self) -> SessionState:
|
|
|
|
@ -43,7 +70,7 @@ class SessionContext(Context):
|
|
|
|
|
pass # TODO implement
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO
|
|
|
|
|
def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
|
|
|
|
|
sessions: dict[int, SessionContext] = {}
|
|
|
|
|
cached_sessions = cache_thp.get_all_allocated_sessions()
|
|
|
|
|
for session in cached_sessions:
|
|
|
|
|