From 601834d2333992eee0b8d7e248f441f2f7f059e9 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 10 Apr 2024 12:00:33 +0200 Subject: [PATCH] Clean session creation --- core/SConscript.firmware | 2 + core/SConscript.unix | 2 + core/src/all_modules.py | 2 + core/src/apps/thp/create_session.py | 29 ++++++--- core/src/trezor/wire/thp/channel.py | 63 +++----------------- core/src/trezor/wire/thp/handler_provider.py | 16 +++++ core/src/trezor/wire/thp/session_context.py | 10 ++-- core/src/trezor/wire/thp/thp_messages.py | 13 ---- 8 files changed, 56 insertions(+), 81 deletions(-) create mode 100644 core/src/trezor/wire/thp/handler_provider.py diff --git a/core/SConscript.firmware b/core/SConscript.firmware index 6b139850d..a258462e7 100644 --- a/core/SConscript.firmware +++ b/core/SConscript.firmware @@ -690,6 +690,8 @@ if FROZEN: SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py')) diff --git a/core/SConscript.unix b/core/SConscript.unix index 112be304f..22240ac3d 100644 --- a/core/SConscript.unix +++ b/core/SConscript.unix @@ -776,6 +776,8 @@ if FROZEN: SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/tezos/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'trezor/enums/Tezos*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/thp/*.py')) + SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/zcash/*.py')) SOURCE_PY.extend(Glob(SOURCE_PY_DIR + 'apps/webauthn/*.py')) diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 8f69ca6c8..32dd1c58c 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -217,6 +217,8 @@ trezor.wire.thp.checksum import trezor.wire.thp.checksum trezor.wire.thp.crypto import trezor.wire.thp.crypto +trezor.wire.thp.handler_provider +import trezor.wire.thp.handler_provider trezor.wire.thp.pairing_context import trezor.wire.thp.pairing_context trezor.wire.thp.session_context diff --git a/core/src/apps/thp/create_session.py b/core/src/apps/thp/create_session.py index c03500e49..7e4cded0d 100644 --- a/core/src/apps/thp/create_session.py +++ b/core/src/apps/thp/create_session.py @@ -1,13 +1,26 @@ -from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] - -from trezor.wire.thp.channel import Channel - -if TYPE_CHECKING: - from trezor.messages import ThpCreateNewSession, ThpNewSession +from trezor import log, loop +from trezor.messages import ThpCreateNewSession, ThpNewSession +from trezor.wire.thp import SessionState, channel +from trezor.wire.thp.session_context import SessionContext async def create_new_session( - channel: Channel, message: ThpCreateNewSession + channel: channel.Channel, message: ThpCreateNewSession ) -> ThpNewSession: - new_session_id: int = channel.create_new_session(message.passphrase) + + session = SessionContext.create_new_session(channel) + session.set_session_state(SessionState.ALLOCATED) + channel.sessions[session.session_id] = session + loop.schedule(session.handle()) + new_session_id: int = session.session_id + + if __debug__: + log.debug( + __name__, + "create_new_session - new session created. Passphrase: %s, Session id: %d", + message.passphrase, + session.session_id, + ) + print(channel.sessions) + return ThpNewSession(new_session_id=new_session_id) diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 18c5c64f8..f67a88d80 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -7,9 +7,10 @@ from storage import cache_thp from storage.cache_thp import KEY_LENGTH, SESSION_ID_LENGTH, TAG_LENGTH, ChannelCache from trezor import log, loop, protobuf, utils from trezor.enums import FailureType, MessageType # , ThpPairingMethod -from trezor.messages import Failure, ThpCreateNewSession, ThpNewSession +from trezor.messages import Failure from trezor.wire import message_handler from trezor.wire.thp import ack_handler, thp_messages +from trezor.wire.thp.handler_provider import get_handler from ..protocol_common import Context, MessageWithType from . import ChannelState, SessionState, checksum, crypto @@ -397,40 +398,12 @@ class Channel(Context): if __debug__: log.debug(__name__, "handle_channel_message: %s", message) # TODO handle other messages than CreateNewSession - if TYPE_CHECKING: - assert isinstance(message, ThpCreateNewSession) - if __debug__: - log.debug( - __name__, - "handle_channel_message - passphrase: %s", - message.passphrase, - ) - # await thp_messages.handle_CreateNewSession(message) - new_session_id: int = self.create_new_session(message.passphrase) - - # TODO reuse existing buffer and compute size dynamically - bufferrone = bytearray(5) - msg = ThpNewSession(new_session_id=new_session_id) - message_size: int = thp_messages.get_new_session_message( - bufferrone, new_session_id - ) - if __debug__: - log.debug( - __name__, "handle_channel_message - message size: %d", message_size - ) - - _encode_session_into_buffer(memoryview(bufferrone), 0) - if TYPE_CHECKING: - assert msg.MESSAGE_WIRE_TYPE is not None - _encode_message_type_into_buffer( - memoryview(bufferrone), msg.MESSAGE_WIRE_TYPE, SESSION_ID_LENGTH - ) - _encode_message_into_buffer( - memoryview(bufferrone), msg, SESSION_ID_LENGTH + MESSAGE_TYPE_LENGTH - ) - await self.write(ThpNewSession(new_session_id=new_session_id)) - # TODO not finished + handler = get_handler(message) + task = handler(self, message) + response_message = await task + # TODO handle + await self.write(response_message) def _decrypt_single_packet_payload(self, payload: bytes) -> bytearray: payload_buffer = bytearray(payload) @@ -600,28 +573,6 @@ class Channel(Context): ) return protobuf.encoded_length(error_message) - def create_new_session( - self, - passphrase: str | None, - ) -> int: - if __debug__: - log.debug(__name__, " create_new_session") - from trezor.wire.thp.session_context import SessionContext - - session = SessionContext.create_new_session(self) - session.set_session_state(SessionState.ALLOCATED) - self.sessions[session.session_id] = session - loop.schedule(session.handle()) - if __debug__: - log.debug( - __name__, - "create_new_session - new session created. Session id: %d", - session.session_id, - ) - if __debug__: - print(self.sessions) - return session.session_id - def _todo_clear_buffer(self): # TODO Buffer clearing not implemented pass diff --git a/core/src/trezor/wire/thp/handler_provider.py b/core/src/trezor/wire/thp/handler_provider.py new file mode 100644 index 000000000..a0774ded1 --- /dev/null +++ b/core/src/trezor/wire/thp/handler_provider.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING + +from trezor import protobuf + +from apps.thp import create_session + +if TYPE_CHECKING: + from typing import Any, Callable, Coroutine + + pass + + +def get_handler( + msg: protobuf.MessageType, +) -> Callable[[Any, Any], Coroutine[Any, Any, protobuf.MessageType]]: + return create_session.create_new_session diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index 0b3664165..0ffc6a39b 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -8,7 +8,7 @@ from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure from ..protocol_common import Context, MessageWithType from . import SessionState -from .channel import Channel +from . import channel if TYPE_CHECKING: from typing import Container # pyright: ignore[reportShadowedImports] @@ -29,7 +29,9 @@ class UnexpectedMessageWithType(Exception): class SessionContext(Context): - def __init__(self, channel: Channel, session_cache: SessionThpCache) -> None: + def __init__( + self, channel: channel.Channel, session_cache: SessionThpCache + ) -> None: if channel.channel_id != session_cache.channel_id: raise Exception( "The session has different channel id than the provided channel context!" @@ -41,7 +43,7 @@ class SessionContext(Context): self.incoming_message = loop.chan() @classmethod - def create_new_session(cls, channel_context: Channel) -> "SessionContext": + def create_new_session(cls, channel_context: channel.Channel) -> "SessionContext": session_cache = cache_thp.get_new_session(channel_context.channel_cache) return cls(channel_context, session_cache) @@ -145,7 +147,7 @@ class SessionContext(Context): self.session_cache.state = bytearray(state.to_bytes(1, "big")) -def load_cached_sessions(channel: Channel) -> dict[int, SessionContext]: # TODO +def load_cached_sessions(channel: channel.Channel) -> dict[int, SessionContext]: # TODO if __debug__: log.debug(__name__, "load_cached_sessions") sessions: dict[int, SessionContext] = {} diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index cb6f8b649..f1dadb8a4 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -2,7 +2,6 @@ import ustruct # pyright:ignore[reportMissingModuleSource] from storage.cache_thp import BROADCAST_CHANNEL_ID from trezor import protobuf -from trezor.messages import ThpCreateNewSession, ThpNewSession from .. import message_handler from ..protocol_common import Message @@ -98,21 +97,9 @@ def get_handshake_completion_response() -> bytes: ) -def get_new_session_message(buffer: bytearray, new_session_id: int) -> int: - msg = ThpNewSession(new_session_id=new_session_id) - encoded_msg = protobuf.encode(buffer, msg) - return encoded_msg - - def decode_message(buffer: bytes, msg_type: int) -> protobuf.MessageType: print("decode message") expected_type = protobuf.type_for_wire(msg_type) x = message_handler.wrap_protobuf_load(buffer, expected_type) print("result decoded", x) return x - - -async def handle_CreateNewSession(msg: ThpCreateNewSession) -> None: - print(msg.passphrase) - print(msg.on_device) - pass