diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 6dceb1a3c..59e47a0d2 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -213,8 +213,8 @@ trezor.wire.thp import trezor.wire.thp trezor.wire.thp.ack_handler import trezor.wire.thp.ack_handler -trezor.wire.thp.channel_context -import trezor.wire.thp.channel_context +trezor.wire.thp.channel +import trezor.wire.thp.channel trezor.wire.thp.checksum import trezor.wire.thp.checksum trezor.wire.thp.session_context diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index 3eb49e54a..0ff213641 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -42,7 +42,7 @@ if TYPE_CHECKING: T = TypeVar("T") -class UnexpectedMessage(Exception): +class UnexpectedMessageWithId(Exception): """A message was received that is not part of the current workflow. Utility exception to inform the session handler that the current workflow @@ -118,7 +118,7 @@ class CodecContext(Context): # If we got a message with unexpected type, raise the message via # `UnexpectedMessageError` and let the session handler deal with it. if msg.type not in expected_types: - raise UnexpectedMessage(msg) + raise UnexpectedMessageWithId(msg) # TODO check that the message has the expected session_id. If not, raise UnexpectedMessageError # (and maybe update ctx.session_id - depends on expected behaviour) diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py index dc055281b..2ccd43c69 100644 --- a/core/src/trezor/wire/message_handler.py +++ b/core/src/trezor/wire/message_handler.py @@ -137,7 +137,7 @@ async def handle_single_message( # results of the handler. res_msg = await task - except context.UnexpectedMessage as exc: + except context.UnexpectedMessageWithId as exc: # Workflow was trying to read a message from the wire, and # something unexpected came in. See Context.read() for # example, which expects some particular message and raises diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 89e795b02..733d361c9 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -4,6 +4,7 @@ from trezor import protobuf if TYPE_CHECKING: from trezorio import WireInterface # pyright: ignore[reportMissingImports] + from typing import Container # pyright: ignore[reportShadowedImports] class Message: @@ -46,6 +47,12 @@ class Context: self.iface: WireInterface = iface self.channel_id: bytes = channel_id + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: ... + async def write(self, msg: protobuf.MessageType) -> None: ... diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel.py similarity index 96% rename from core/src/trezor/wire/thp/channel_context.py rename to core/src/trezor/wire/thp/channel.py index 1e7cdcf0e..ca5cdebc5 100644 --- a/core/src/trezor/wire/thp/channel_context.py +++ b/core/src/trezor/wire/thp/channel.py @@ -10,7 +10,7 @@ from trezor import loop, protobuf, utils from trezor.messages import ThpCreateNewSession from trezor.wire import message_handler -from ..protocol_common import Context +from ..protocol_common import Context, MessageWithType from . import ChannelState, SessionState, checksum from . import thp_session as THP from .checksum import CHECKSUM_LENGTH @@ -39,7 +39,7 @@ REPORT_LENGTH = const(64) MAX_PAYLOAD_LEN = const(60000) -class ChannelContext(Context): +class Channel(Context): def __init__(self, channel_cache: ChannelCache) -> None: iface = _decode_iface(channel_cache.iface) super().__init__(iface, channel_cache.channel_id) @@ -56,7 +56,7 @@ class ChannelContext(Context): @classmethod def create_new_channel( cls, iface: WireInterface, buffer: utils.BufferType - ) -> "ChannelContext": + ) -> "Channel": channel_cache = cache_thp.get_new_unauthenticated_channel(_encode_iface(iface)) r = cls(channel_cache) r.set_buffer(buffer) @@ -217,9 +217,11 @@ class ChannelContext(Context): if session_state is SessionState.UNALLOCATED: raise Exception("Unalloacted session") - await self.sessions[session_id].receive_message( - message_type, - self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], + self.sessions[session_id].incoming_message.publish( + MessageWithType( + message_type, + self.buffer[INIT_DATA_OFFSET + 3 : msg_len - CHECKSUM_LENGTH], + ) ) if state is ChannelState.TH2: @@ -275,6 +277,7 @@ class ChannelContext(Context): session = SessionContext.create_new_session(self) print("help") self.sessions[session.session_id] = session + loop.schedule(session.handle()) print("new session created. Session id:", session.session_id) def _todo_clear_buffer(self): @@ -300,11 +303,11 @@ class ChannelContext(Context): return THP.sync_get_send_bit(self.channel_cache) != sync_bit -def load_cached_channels(buffer: utils.BufferType) -> dict[int, ChannelContext]: # TODO - channels: dict[int, ChannelContext] = {} +def load_cached_channels(buffer: utils.BufferType) -> dict[int, Channel]: # TODO + channels: dict[int, Channel] = {} cached_channels = cache_thp.get_all_allocated_channels() for c in cached_channels: - channels[int.from_bytes(c.channel_id, "big")] = ChannelContext(c) + channels[int.from_bytes(c.channel_id, "big")] = Channel(c) for c in channels.values(): c.set_buffer(buffer) return channels diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index a766e15cf..bac7dd6b6 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -1,33 +1,60 @@ from storage import cache_thp from storage.cache_thp import SessionThpCache -from trezor import protobuf +from trezor import loop, protobuf +from trezor.wire import message_handler -from ..protocol_common import Context +from ..protocol_common import Context, MessageWithType from . import SessionState -from .channel_context import ChannelContext +from .channel import Channel + + +class UnexpectedMessageWithType(Exception): + """A message was received that is not part of the current workflow. + + Utility exception to inform the session handler that the current workflow + should be aborted and a new one started as if `msg` was the first message. + """ + + def __init__(self, msg: MessageWithType) -> None: + super().__init__() + self.msg = msg class SessionContext(Context): - def __init__( - self, channel_context: ChannelContext, session_cache: SessionThpCache - ) -> None: - if channel_context.channel_id != session_cache.channel_id: + def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None: + if channel.channel_id != session_cache.channel_id: raise Exception( "The session has different channel id than the provided channel context!" ) - super().__init__(channel_context.iface, channel_context.channel_id) - self.channel_context = channel_context + super().__init__(channel.iface, channel.channel_id) + self.channel_context = channel self.session_cache = session_cache self.session_id = int.from_bytes(session_cache.session_id, "big") - - async def write(self, msg: protobuf.MessageType) -> None: - return await self.channel_context.write(msg, self.session_id) + self.incoming_message = loop.chan() @classmethod - def create_new_session(cls, channel_context: ChannelContext) -> "SessionContext": + def create_new_session(cls, channel_context: Channel) -> "SessionContext": session_cache = cache_thp.get_new_session(channel_context.channel_cache) return cls(channel_context, session_cache) + async def handle(self) -> None: + take = self.incoming_message.take() + while True: + message = await take + print(message) + # TODO continue similarly to handle_session function in wire.__init__ + + async def read(self, expected_message_type: int) -> protobuf.MessageType: + message: MessageWithType = await self.incoming_message.take() + if message.type != expected_message_type: + raise UnexpectedMessageWithType(message) + + expected_type = protobuf.type_for_wire(message.type) + return message_handler.wrap_protobuf_load(message.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel_context.write(msg, self.session_id) + # ACCESS TO SESSION DATA def get_session_state(self) -> SessionState: @@ -43,7 +70,7 @@ class SessionContext(Context): pass # TODO implement -def load_cached_sessions(channel: ChannelContext) -> dict[int, SessionContext]: # TODO +def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO sessions: dict[int, SessionContext] = {} cached_sessions = cache_thp.get_all_allocated_sessions() for session in cached_sessions: diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index 714207499..6bac26f7c 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -8,12 +8,12 @@ from trezor import io, log, loop, utils from .protocol_common import MessageWithId from .thp import ChannelState, ack_handler, checksum, thp_messages from .thp import thp_session as THP -from .thp.channel_context import ( +from .thp.channel import ( CONT_DATA_OFFSET, INIT_DATA_OFFSET, MAX_PAYLOAD_LEN, REPORT_LENGTH, - ChannelContext, + Channel, load_cached_channels, ) from .thp.checksum import CHECKSUM_LENGTH @@ -38,7 +38,7 @@ _PLAINTEXT = 0x01 _BUFFER: bytearray _BUFFER_LOCK = None -_CHANNEL_CONTEXTS: dict[int, ChannelContext] = {} +_CHANNEL_CONTEXTS: dict[int, Channel] = {} async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: @@ -346,7 +346,7 @@ async def _handle_broadcast( if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") - new_context: ChannelContext = ChannelContext.create_new_channel(iface, _BUFFER) + new_context: Channel = Channel.create_new_channel(iface, _BUFFER) cid = int.from_bytes(new_context.channel_id, "big") _CHANNEL_CONTEXTS[cid] = new_context