Implement SessionContext structure

M1nd3r/thp5
M1nd3r 2 months ago
parent 53bdc3979a
commit a1a5962838

@ -213,8 +213,8 @@ trezor.wire.thp
import trezor.wire.thp import trezor.wire.thp
trezor.wire.thp.ack_handler trezor.wire.thp.ack_handler
import trezor.wire.thp.ack_handler import trezor.wire.thp.ack_handler
trezor.wire.thp.channel_context trezor.wire.thp.channel
import trezor.wire.thp.channel_context import trezor.wire.thp.channel
trezor.wire.thp.checksum trezor.wire.thp.checksum
import trezor.wire.thp.checksum import trezor.wire.thp.checksum
trezor.wire.thp.session_context trezor.wire.thp.session_context

@ -42,7 +42,7 @@ if TYPE_CHECKING:
T = TypeVar("T") T = TypeVar("T")
class UnexpectedMessage(Exception): class UnexpectedMessageWithId(Exception):
"""A message was received that is not part of the current workflow. """A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow Utility exception to inform the session handler that the current workflow
@ -118,7 +118,7 @@ class CodecContext(Context):
# If we got a message with unexpected type, raise the message via # If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it. # `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types: if msg.type not in expected_types:
raise UnexpectedMessage(msg) raise UnexpectedMessageWithId(msg)
# TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError # TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError
# (and maybe update ctx.session_id - depends on expected behaviour) # (and maybe update ctx.session_id - depends on expected behaviour)

@ -137,7 +137,7 @@ async def handle_single_message(
# results of the handler. # results of the handler.
res_msg = await task res_msg = await task
except context.UnexpectedMessage as exc: except context.UnexpectedMessageWithId as exc:
# Workflow was trying to read a message from the wire, and # Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for # something unexpected came in. See Context.read() for
# example, which expects some particular message and raises # example, which expects some particular message and raises

@ -4,6 +4,7 @@ from trezor import protobuf
if TYPE_CHECKING: if TYPE_CHECKING:
from trezorio import WireInterface # pyright: ignore[reportMissingImports] from trezorio import WireInterface # pyright: ignore[reportMissingImports]
from typing import Container # pyright: ignore[reportShadowedImports]
class Message: class Message:
@ -46,6 +47,12 @@ class Context:
self.iface: WireInterface = iface self.iface: WireInterface = iface
self.channel_id: bytes = channel_id self.channel_id: bytes = channel_id
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType: ...
async def write(self, msg: protobuf.MessageType) -> None: ... async def write(self, msg: protobuf.MessageType) -> None: ...

@ -10,7 +10,7 @@ from trezor import loop, protobuf, utils
from trezor.messages import ThpCreateNewSession from trezor.messages import ThpCreateNewSession
from trezor.wire import message_handler from trezor.wire import message_handler
from ..protocol_common import Context from ..protocol_common import Context, MessageWithType
from . import ChannelState, SessionState, checksum from . import ChannelState, SessionState, checksum
from . import thp_session as THP from . import thp_session as THP
from .checksum import CHECKSUM_LENGTH from .checksum import CHECKSUM_LENGTH
@ -39,7 +39,7 @@ REPORT_LENGTH = const(64)
MAX_PAYLOAD_LEN = const(60000) MAX_PAYLOAD_LEN = const(60000)
class ChannelContext(Context): class Channel(Context):
def __init__(self, channel_cache: ChannelCache) -> None: def __init__(self, channel_cache: ChannelCache) -> None:
iface = _decode_iface(channel_cache.iface) iface = _decode_iface(channel_cache.iface)
super().__init__(iface, channel_cache.channel_id) super().__init__(iface, channel_cache.channel_id)
@ -56,7 +56,7 @@ class ChannelContext(Context):
@classmethod @classmethod
def create_new_channel( def create_new_channel(
cls, iface: WireInterface, buffer: utils.BufferType cls, iface: WireInterface, buffer: utils.BufferType
) -> "ChannelContext": ) -> "Channel":
channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface)) channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface))
r = cls(channel_cache) r = cls(channel_cache)
r.set_buffer(buffer) r.set_buffer(buffer)
@ -217,9 +217,11 @@ class ChannelContext(Context):
if session_state is SessionState.UNALLOCATED: if session_state is SessionState.UNALLOCATED:
raise Exception("Unalloacted session") raise Exception("Unalloacted session")
await self.sessions[session_id].receive_message( self.sessions[session_id].incoming_message.publish(
message_type, MessageWithType(
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], message_type,
self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH],
)
) )
if state is ChannelState.TH2: if state is ChannelState.TH2:
@ -275,6 +277,7 @@ class ChannelContext(Context):
session = SessionContext.create_new_session(self) session = SessionContext.create_new_session(self)
print("help") print("help")
self.sessions[session.session_id] = session self.sessions[session.session_id] = session
loop.schedule(session.handle())
print("new session created. Session id:", session.session_id) print("new session created. Session id:", session.session_id)
def _todo_clear_buffer(self): def _todo_clear_buffer(self):
@ -300,11 +303,11 @@ class ChannelContext(Context):
return THP.sync_get_send_bit(self.channel_cache) != sync_bit return THP.sync_get_send_bit(self.channel_cache) != sync_bit
def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO
channels: dict[int, ChannelContext] = {} channels: dict[int, Channel] = {}
cached_channels = cache_thp.get_all_allocated_channels() cached_channels = cache_thp.get_all_allocated_channels()
for c in cached_channels: for c in cached_channels:
channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c) channels[int.from_bytes(c.channel_id, "big")] = Channel(c)
for c in channels.values(): for c in channels.values():
c.set_buffer(buffer) c.set_buffer(buffer)
return channels return channels

@ -1,33 +1,60 @@
from storage import cache_thp from storage import cache_thp
from storage.cache_thp import SessionThpCache from storage.cache_thp import SessionThpCache
from trezor import protobuf from trezor import loop, protobuf
from trezor.wire import message_handler
from ..protocol_common import Context from ..protocol_common import Context, MessageWithType
from . import SessionState from . import SessionState
from .channel_context import ChannelContext from .channel import Channel
class UnexpectedMessageWithType(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: MessageWithType) -> None:
super().__init__()
self.msg = msg
class SessionContext(Context): class SessionContext(Context):
def __init__( def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None:
self, channel_context: ChannelContext, session_cache: SessionThpCache if channel.channel_id != session_cache.channel_id:
) -> None:
if channel_context.channel_id != session_cache.channel_id:
raise Exception( raise Exception(
"The session has different channel id than the provided channel context!" "The session has different channel id than the provided channel context!"
) )
super().__init__(channel_context.iface, channel_context.channel_id) super().__init__(channel.iface, channel.channel_id)
self.channel_context = channel_context self.channel_context = channel
self.session_cache = session_cache self.session_cache = session_cache
self.session_id = int.from_bytes(session_cache.session_id, "big") self.session_id = int.from_bytes(session_cache.session_id, "big")
self.incoming_message = loop.chan()
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_context.write(msg, self.session_id)
@classmethod @classmethod
def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext": def create_new_session(cls, channel_context: Channel) -> "SessionContext":
session_cache = cache_thp.get_new_session(channel_context.channel_cache) session_cache = cache_thp.get_new_session(channel_context.channel_cache)
return cls(channel_context, session_cache) return cls(channel_context, session_cache)
async def handle(self) -> None:
take = self.incoming_message.take()
while True:
message = await take
print(message)
# TODO continue similarly to handle_session function in wire.__init__
async def read(self, expected_message_type: int) -> protobuf.MessageType:
message: MessageWithType = await self.incoming_message.take()
if message.type != expected_message_type:
raise UnexpectedMessageWithType(message)
expected_type = protobuf.type_for_wire(message.type)
return message_handler.wrap_protobuf_load(message.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
return await self.channel_context.write(msg, self.session_id)
# ACCESS TO SESSION DATA # ACCESS TO SESSION DATA
def get_session_state(self) -> SessionState: def get_session_state(self) -> SessionState:
@ -43,7 +70,7 @@ class SessionContext(Context):
pass # TODO implement pass # TODO implement
def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO
sessions: dict[int, SessionContext] = {} sessions: dict[int, SessionContext] = {}
cached_sessions = cache_thp.get_all_allocated_sessions() cached_sessions = cache_thp.get_all_allocated_sessions()
for session in cached_sessions: for session in cached_sessions:

@ -8,12 +8,12 @@ from trezor import io, log, loop, utils
from .protocol_common import MessageWithId from .protocol_common import MessageWithId
from .thp import ChannelState, ack_handler, checksum, thp_messages from .thp import ChannelState, ack_handler, checksum, thp_messages
from .thp import thp_session as THP from .thp import thp_session as THP
from .thp.channel_context import ( from .thp.channel import (
CONT_DATA_OFFSET, CONT_DATA_OFFSET,
INIT_DATA_OFFSET, INIT_DATA_OFFSET,
MAX_PAYLOAD_LEN, MAX_PAYLOAD_LEN,
REPORT_LENGTH, REPORT_LENGTH,
ChannelContext, Channel,
load_cached_channels, load_cached_channels,
) )
from .thp.checksum import CHECKSUM_LENGTH from .thp.checksum import CHECKSUM_LENGTH
@ -38,7 +38,7 @@ _PLAINTEXT = 0x01
_BUFFER: bytearray _BUFFER: bytearray
_BUFFER_LOCK = None _BUFFER_LOCK = None
_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {} _CHANNEL_CONTEXTS: dict[int, Channel] = {}
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
@ -346,7 +346,7 @@ async def _handle_broadcast(
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid") raise ThpError("Checksum is not valid")
new_context: ChannelContext = ChannelContext.create_new_channel(iface, _BUFFER) new_context: Channel = Channel.create_new_channel(iface, _BUFFER)
cid = int.from_bytes(new_context.channel_id, "big") cid = int.from_bytes(new_context.channel_id, "big")
_CHANNEL_CONTEXTS[cid] = new_context _CHANNEL_CONTEXTS[cid] = new_context

Loading…
Cancel
Save