From a2f7a0cc782d500a7a91c447744f733d55e1107d Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Fri, 26 Apr 2024 16:15:50 +0200 Subject: [PATCH] Remake ChannelContext, change buffer types --- core/src/apps/thp/create_session.py | 7 ++- core/src/apps/thp/pairing.py | 5 +- core/src/trezor/wire/thp/__init__.py | 54 +++++++++++-------- core/src/trezor/wire/thp/channel.py | 32 ++++++----- core/src/trezor/wire/thp/memory_manager.py | 4 +- core/src/trezor/wire/thp/pairing_context.py | 12 +++-- .../wire/thp/received_message_handler.py | 3 +- core/src/trezor/wire/thp/session_context.py | 4 +- core/src/trezor/wire/thp/session_manager.py | 14 +++-- 9 files changed, 86 insertions(+), 49 deletions(-) diff --git a/core/src/apps/thp/create_session.py b/core/src/apps/thp/create_session.py index d77d88c97..c76ca52dc 100644 --- a/core/src/apps/thp/create_session.py +++ b/core/src/apps/thp/create_session.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING + from trezor import log, loop from trezor.messages import ThpCreateNewSession, ThpNewSession -from trezor.wire.thp import ChannelContext, SessionState +from trezor.wire.thp import SessionState + +if TYPE_CHECKING: + from trezor.wire.thp import ChannelContext async def create_new_session( diff --git a/core/src/apps/thp/pairing.py b/core/src/apps/thp/pairing.py index 955b71167..1c78cb5e4 100644 --- a/core/src/apps/thp/pairing.py +++ b/core/src/apps/thp/pairing.py @@ -1,4 +1,4 @@ -from trezor import log, protobuf +from trezor import protobuf from trezor.enums import MessageType, ThpPairingMethod from trezor.messages import ( ThpCodeEntryChallenge, @@ -25,6 +25,9 @@ from trezor.wire.thp.thp_session import ThpError # TODO implement the following handlers +if __debug__: + from trezor import log + async def handle_pairing_request( ctx: PairingContext, message: protobuf.MessageType diff --git a/core/src/trezor/wire/thp/__init__.py b/core/src/trezor/wire/thp/__init__.py index 59a145c46..b06fa576d 100644 --- a/core/src/trezor/wire/thp/__init__.py +++ b/core/src/trezor/wire/thp/__init__.py @@ -3,12 +3,43 @@ from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] if TYPE_CHECKING: from enum import IntEnum from trezorio import WireInterface + from typing import Protocol from storage.cache_thp import ChannelCache from trezor import loop, protobuf, utils from trezor.enums import FailureType from trezor.wire.thp.pairing_context import PairingContext from trezor.wire.thp.session_context import SessionContext + + class ChannelContext(Protocol): + buffer: utils.BufferType + iface: WireInterface + channel_id: bytes + channel_cache: ChannelCache + selected_pairing_methods = [] # TODO add type + sessions: dict[int, SessionContext] + waiting_for_ack_timeout: loop.spawn | None + write_task_spawn: loop.spawn | None + connection_context: PairingContext | None + + def get_channel_state(self) -> int: ... + + def set_channel_state(self, state: "ChannelState") -> None: ... + + async def write( + self, msg: protobuf.MessageType, session_id: int = 0 + ) -> None: ... + + async def write_error(self, err_type: FailureType, message: str) -> None: ... + + async def write_handshake_message( + self, ctrl_byte: int, payload: bytes + ) -> None: ... + + def decrypt_buffer(self, message_length: int) -> None: ... + + def get_channel_id_int(self) -> int: ... + else: IntEnum = object @@ -36,29 +67,6 @@ class WireInterfaceType(IntEnum): BLE = 2 -class ChannelContext: - def __init__(self, iface: WireInterface, channel_cache: ChannelCache): - self.buffer: utils.BufferType - self.iface: WireInterface = iface - self.channel_id: bytes = channel_cache.channel_id - self.channel_cache: ChannelCache = channel_cache - self.selected_pairing_methods = [] - self.sessions: dict[int, SessionContext] = {} - self.waiting_for_ack_timeout: loop.spawn | None = None - self.write_task_spawn: loop.spawn | None = None - self.connection_context: PairingContext | None = None - - def get_channel_state(self) -> int: ... - def set_channel_state(self, state: ChannelState) -> None: ... - async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: ... - async def write_error(self, err_type: FailureType, message: str) -> None: ... - async def write_handshake_message(self, ctrl_byte: int, payload: bytes) -> None: ... - def decrypt_buffer(self, message_length: int) -> None: ... - - def get_channel_id_int(self) -> int: - return int.from_bytes(self.channel_id, "big") - - def is_channel_state_pairing(state: int) -> bool: if state in ( ChannelState.TP1, diff --git a/core/src/trezor/wire/thp/channel.py b/core/src/trezor/wire/thp/channel.py index 88697ccd6..9f33b9047 100644 --- a/core/src/trezor/wire/thp/channel.py +++ b/core/src/trezor/wire/thp/channel.py @@ -6,14 +6,7 @@ from trezor import log, loop, protobuf, utils, workflow from trezor.enums import FailureType from trezor.wire.thp import interface_manager, received_message_handler -from . import ( - ChannelContext, - ChannelState, - checksum, - control_byte, - crypto, - memory_manager, -) +from . import ChannelState, checksum, control_byte, crypto, memory_manager from . import thp_session as THP from .checksum import CHECKSUM_LENGTH from .thp_messages import ENCRYPTED_TRANSPORT, ERROR, InitHeader @@ -31,19 +24,32 @@ if __debug__: if TYPE_CHECKING: from trezorio import WireInterface # pyright: ignore[reportMissingImports] + from . import ChannelContext, PairingContext + from .session_context import SessionContext +else: + ChannelContext = object + -class Channel(ChannelContext): +class Channel: def __init__(self, channel_cache: ChannelCache) -> None: if __debug__: log.debug(__name__, "channel initialization") - iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) - super().__init__(iface, channel_cache) - self.channel_cache = channel_cache + self.iface: WireInterface = interface_manager.decode_iface(channel_cache.iface) + self.channel_cache: ChannelCache = channel_cache self.is_cont_packet_expected: bool = False self.expected_payload_length: int = 0 self.bytes_read: int = 0 + self.buffer: utils.BufferType + self.channel_id: bytes = channel_cache.channel_id + self.selected_pairing_methods = [] + self.sessions: dict[int, SessionContext] = {} + self.waiting_for_ack_timeout: loop.spawn | None = None + self.write_task_spawn: loop.spawn | None = None + self.connection_context: PairingContext | None = None # ACCESS TO CHANNEL_DATA + def get_channel_id_int(self) -> int: + return int.from_bytes(self.channel_id, "big") def get_channel_state(self) -> int: state = int.from_bytes(self.channel_cache.state, "big") @@ -168,7 +174,7 @@ class Channel(ChannelContext): if __debug__: log.debug(__name__, "write message: %s", msg.MESSAGE_NAME) noise_payload_len = memory_manager.encode_into_buffer( - memoryview(self.buffer), msg, session_id + self.buffer, msg, session_id ) await self.write_and_encrypt(self.buffer[:noise_payload_len]) diff --git a/core/src/trezor/wire/thp/memory_manager.py b/core/src/trezor/wire/thp/memory_manager.py index 7b5687e71..a230e0ed6 100644 --- a/core/src/trezor/wire/thp/memory_manager.py +++ b/core/src/trezor/wire/thp/memory_manager.py @@ -46,7 +46,7 @@ def select_buffer( def encode_into_buffer( - buffer: memoryview, msg: protobuf.MessageType, session_id: int + buffer: utils.BufferType, msg: protobuf.MessageType, session_id: int ) -> int: # cannot write message without wire type @@ -58,7 +58,7 @@ def encode_into_buffer( if required_min_size > len(buffer): # message is too big, we need to allocate a new buffer - buffer = memoryview(bytearray(required_min_size)) + buffer = bytearray(required_min_size) _encode_session_into_buffer(memoryview(buffer), session_id) _encode_message_type_into_buffer( diff --git a/core/src/trezor/wire/thp/pairing_context.py b/core/src/trezor/wire/thp/pairing_context.py index a111ac7a7..f1a7ea496 100644 --- a/core/src/trezor/wire/thp/pairing_context.py +++ b/core/src/trezor/wire/thp/pairing_context.py @@ -1,19 +1,23 @@ -from typing import TYPE_CHECKING # pyright: ignore[reportShadowedImports] +from typing import TYPE_CHECKING -from trezor import log, loop, protobuf, workflow +from trezor import loop, protobuf, workflow from trezor.wire import context, message_handler, protocol_common from trezor.wire.context import UnexpectedMessageWithId from trezor.wire.errors import ActionCancelled from trezor.wire.protocol_common import Context, MessageWithType -from . import ChannelContext from .session_context import UnexpectedMessageWithType if TYPE_CHECKING: - from typing import Container # pyright:ignore[reportShadowedImports] + from typing import Container + + from . import ChannelContext pass +if __debug__: + from trezor import log + class PairingContext(Context): def __init__(self, channel_ctx: ChannelContext) -> None: diff --git a/core/src/trezor/wire/thp/received_message_handler.py b/core/src/trezor/wire/thp/received_message_handler.py index 1735cd3aa..f82f22b6a 100644 --- a/core/src/trezor/wire/thp/received_message_handler.py +++ b/core/src/trezor/wire/thp/received_message_handler.py @@ -19,7 +19,6 @@ from trezor.wire.thp.thp_messages import ( ) from . import ( - ChannelContext, ChannelState, SessionState, checksum, @@ -33,6 +32,8 @@ from .writer import INIT_DATA_OFFSET, MESSAGE_TYPE_LENGTH, write_payload_to_wire if TYPE_CHECKING: from trezor.messages import ThpHandshakeCompletionReqNoisePayload + from . import ChannelContext + if __debug__: from . import state_to_str diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py index a8773ead7..5a7c54546 100644 --- a/core/src/trezor/wire/thp/session_context.py +++ b/core/src/trezor/wire/thp/session_context.py @@ -6,7 +6,7 @@ from trezor.wire import message_handler, protocol_common from trezor.wire.message_handler import AVOID_RESTARTING_FOR, failure from ..protocol_common import Context, MessageWithType -from . import ChannelContext, SessionState +from . import SessionState if TYPE_CHECKING: from typing import ( # pyright: ignore[reportShadowedImports] @@ -15,6 +15,8 @@ if TYPE_CHECKING: Container, ) + from . import ChannelContext + pass _EXIT_LOOP = True diff --git a/core/src/trezor/wire/thp/session_manager.py b/core/src/trezor/wire/thp/session_manager.py index 78a9d4903..947b7faf8 100644 --- a/core/src/trezor/wire/thp/session_manager.py +++ b/core/src/trezor/wire/thp/session_manager.py @@ -1,7 +1,15 @@ +from typing import TYPE_CHECKING + from storage import cache_thp -from trezor import log, loop -from trezor.wire.thp import ChannelContext -from trezor.wire.thp.session_context import SessionContext +from trezor import loop + +from .session_context import SessionContext + +if __debug__: + from trezor import log + +if TYPE_CHECKING: + from . import ChannelContext def create_new_session(channel_ctx: ChannelContext) -> SessionContext: