diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 9171d6bc3..f6dbdaa88 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -201,14 +201,20 @@ trezor.wire.context import trezor.wire.context trezor.wire.errors import trezor.wire.errors +trezor.wire.message_handler +import trezor.wire.message_handler trezor.wire.protocol import trezor.wire.protocol trezor.wire.protocol_common import trezor.wire.protocol_common trezor.wire.thp.ack_handler import trezor.wire.thp.ack_handler +trezor.wire.thp.channel_context +import trezor.wire.thp.channel_context trezor.wire.thp.checksum import trezor.wire.thp.checksum +trezor.wire.thp.session_context +import trezor.wire.thp.session_context trezor.wire.thp.thp_messages import trezor.wire.thp.thp_messages trezor.wire.thp.thp_session diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 1e2917183..1ac38b999 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -43,7 +43,7 @@ if __debug__: layout_change_chan = loop.chan() - DEBUG_CONTEXT: context.Context | None = None + DEBUG_CONTEXT: context.CodecContext | None = None LAYOUT_WATCHER_NONE = 0 LAYOUT_WATCHER_STATE = 1 diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 77a3dc6d0..35f97d629 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -27,11 +27,11 @@ from micropython import const from typing import TYPE_CHECKING from storage.cache_common import InvalidSessionError -from trezor import log, loop, protobuf, utils, workflow +from trezor import log, loop, protobuf, utils from trezor.enums import FailureType from trezor.messages import Failure -from trezor.wire import codec_v1, context, protocol_common -from trezor.wire.errors import ActionCancelled, DataError, Error +from trezor.wire import codec_v1, context, message_handler, protocol_common, thp_v1 +from trezor.wire.errors import DataError, Error # Import all errors into namespace, so that `wire.Error` is available from # other packages. @@ -88,113 +88,43 @@ if __debug__: WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) -async def _handle_single_message( - ctx: context.Context, msg: protocol_common.MessageWithId, use_workflow: bool -) -> protocol_common.MessageWithId | None: - """Handle a message that was loaded from USB by the caller. +async def handle_thp_session(iface: WireInterface, is_debug_session: bool = False): + if __debug__ and is_debug_session: + ctx_buffer = WIRE_BUFFER_DEBUG + else: + ctx_buffer = WIRE_BUFFER - Find the appropriate handler, run it and write its result on the wire. In case - a problem is encountered at any point, write the appropriate error on the wire. + thp_v1.set_buffer(ctx_buffer) + + if __debug__ and is_debug_session: + import apps.debug - If the workflow finished normally or with an error, the return value is None. + print(apps.debug.DEBUG_CONTEXT) # TODO remove - If an unexpected message had arrived on the wire while the workflow was processing, - the workflow is shut down with an `UnexpectedMessage` exception. This is not - considered an "error condition" to return over the wire -- instead the message - is processed as if starting a new workflow. - In such case, the `UnexpectedMessage` is caught and the message is returned - to the caller. It will then be processed in the next iteration of the message loop. - """ - if __debug__: + # TODO add debug context or smth to apps.debug + + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. + modules = utils.unimport_begin() + + while True: try: - msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME - except Exception: - msg_type = f"{msg.type} - unknown message type" - if ctx.session_id is not None: - sid = int.from_bytes(ctx.session_id, "big") - else: - sid = -1 - log.debug( - __name__, - "%s:%x receive: <%s>", - ctx.iface.iface_num(), - sid, - msg_type, - ) - - res_msg: protobuf.MessageType | None = None - - # We need to find a handler for this message type. Should not raise. - handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none - - if handler is None: - # If no handler is found, we can skip decoding and directly - # respond with failure. - await ctx.write(unexpected_message()) - return None - - if msg.type in workflow.ALLOW_WHILE_LOCKED: - workflow.autolock_interrupts_workflow = False - - # Here we make sure we always respond with a Failure response - # in case of any errors. - try: - # Find a protobuf.MessageType subclass that describes this - # message. Raises if the type is not found. - req_type = protobuf.type_for_wire(msg.type) - - # Try to decode the message according to schema from - # `req_type`. Raises if the message is malformed. - req_msg = wrap_protobuf_load(msg.data, req_type) - - # Create the handler task. - task = handler(req_msg) - - # Run the workflow task. Workflow can do more on-the-wire - # communication inside, but it should eventually return a - # response message, or raise an exception (a rather common - # thing to do). Exceptions are handled in the code below. - if use_workflow: - # Spawn a workflow around the task. This ensures that concurrent - # workflows are shut down. - res_msg = await workflow.spawn(context.with_context(ctx, task)) - else: - # For debug messages, ignore workflow processing and just await - # results of the handler. - res_msg = await task - - except context.UnexpectedMessage as exc: - # Workflow was trying to read a message from the wire, and - # something unexpected came in. See Context.read() for - # example, which expects some particular message and raises - # UnexpectedMessage if another one comes in. - # In order not to lose the message, we return it to the caller. - # TODO: - # We might handle only the few common cases here, like - # Initialize and Cancel. - return exc.msg - - except BaseException as exc: - # Either: - # - the message had a type that has a registered handler, but does not have - # a protobuf class - # - the message was not valid protobuf - # - workflow raised some kind of an exception while running - # - something canceled the workflow from the outside - if __debug__: - if isinstance(exc, ActionCancelled): - log.debug(__name__, "cancelled: %s", exc.message) - elif isinstance(exc, loop.TaskClosed): - log.debug(__name__, "cancelled: loop task was closed") - else: - log.exception(__name__, exc) - res_msg = failure(exc) + await thp_v1.thp_main_loop(iface, is_debug_session) - if res_msg is not None: - # perform the write outside the big try-except block, so that usb write - # problem bubbles up - await ctx.write(res_msg) - return None + if not __debug__ or not is_debug_session: + # Unload modules imported by the workflow. Should not raise. + # This is not done for the debug session because the snapshot taken + # in a debug session would clear modules which are in use by the + # workflow running on wire. + utils.unimport_end(modules) + loop.clear() + return + + except Exception as exc: + # Log and try again. The session handler can only exit explicitly via + # loop.clear() above. + if __debug__: + log.exception(__name__, exc) async def handle_session( @@ -205,7 +135,7 @@ async def handle_session( else: ctx_buffer = WIRE_BUFFER - ctx = context.Context(iface, ctx_buffer, session_id) + ctx = context.CodecContext(iface, ctx_buffer, session_id) next_msg: protocol_common.MessageWithId | None = None if __debug__ and is_debug_session: @@ -235,10 +165,10 @@ async def handle_session( next_msg = None # Set ctx.session_id to the value msg.session_id - ctx.session_id = msg.session_id + ctx.channel_id = msg.session_id try: - next_msg = await _handle_single_message( + next_msg = await message_handler.handle_single_message( ctx, msg, use_workflow=not is_debug_session ) except Exception as exc: diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py index f29b557c0..7dc410682 100644 --- a/core/src/trezor/wire/context.py +++ b/core/src/trezor/wire/context.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING import trezor.wire.protocol as protocol from trezor import log, loop, protobuf -from .protocol_common import MessageWithId +from .protocol_common import Context, MessageWithId if TYPE_CHECKING: from trezorio import WireInterface @@ -54,10 +54,10 @@ class UnexpectedMessage(Exception): self.msg = msg -class Context: +class CodecContext(Context): """Wire context. - Represents USB communication inside a particular session on a particular interface + Represents USB communication inside a particular session (channel) on a particular interface (i.e., wire, debug, single BT connection, etc.) """ @@ -65,11 +65,12 @@ class Context: self, iface: WireInterface, buffer: bytearray, - session_id: bytes | None = None, + channel_id: bytes | None = None, ) -> None: self.iface = iface self.buffer = buffer - self.session_id = session_id + self.channel_id = channel_id + super().__init__(iface, channel_id) def read_from_wire(self) -> Awaitable[MessageWithId]: """Read a whole message from the wire without parsing it.""" @@ -99,8 +100,8 @@ class Context: to save on having to decode the type code into a protobuf class. """ if __debug__: - if self.session_id is not None: - sid = int.from_bytes(self.session_id, "big") + if self.channel_id is not None: + sid = int.from_bytes(self.channel_id, "big") else: sid = -1 log.debug( @@ -126,8 +127,8 @@ class Context: expected_type = protobuf.type_for_wire(msg.type) if __debug__: - if self.session_id is not None: - sid = int.from_bytes(self.session_id, "big") + if self.channel_id is not None: + sid = int.from_bytes(self.channel_id, "big") else: sid = -1 log.debug( @@ -146,8 +147,8 @@ class Context: async def write(self, msg: protobuf.MessageType) -> None: """Write a message to the wire.""" if __debug__: - if self.session_id is not None: - sid = int.from_bytes(self.session_id, "big") + if self.channel_id is not None: + sid = int.from_bytes(self.channel_id, "big") else: sid = -1 log.debug( @@ -173,8 +174,8 @@ class Context: msg_size = protobuf.encode(buffer, msg) msg_session_id = None - if self.session_id is not None: - msg_session_id = bytearray(self.session_id) + if self.channel_id is not None: + msg_session_id = bytearray(self.channel_id) await protocol.write_message( self.iface, MessageWithId( @@ -185,7 +186,7 @@ class Context: ) -CURRENT_CONTEXT: Context | None = None +CURRENT_CONTEXT: CodecContext | None = None def wait(task: Awaitable[T]) -> Awaitable[T]: @@ -250,7 +251,7 @@ async def maybe_call( await call(msg, expected_type) -def get_context() -> Context: +def get_context() -> CodecContext: """Get the current session context. Can be needed in case the caller needs raw read and raw write capabilities, which @@ -264,7 +265,7 @@ def get_context() -> Context: return CURRENT_CONTEXT -def with_context(ctx: Context, workflow: loop.Task) -> Generator: +def with_context(ctx: CodecContext, workflow: loop.Task) -> Generator: """Run a workflow in a particular context. Stores the context in a closure and installs it into the global variable every time diff --git a/core/src/trezor/wire/message_handler.py b/core/src/trezor/wire/message_handler.py new file mode 100644 index 000000000..8be41ccd3 --- /dev/null +++ b/core/src/trezor/wire/message_handler.py @@ -0,0 +1,195 @@ +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_common import InvalidSessionError +from trezor import log, loop, protobuf, utils, workflow +from trezor.enums import FailureType +from trezor.messages import Failure +from trezor.wire import context, protocol_common +from trezor.wire.errors import ActionCancelled, DataError, Error + +# Import all errors into namespace, so that `wire.Error` is available from +# other packages. +from trezor.wire.errors import * # isort:skip # noqa: F401,F403 + + +if TYPE_CHECKING: + from trezorio import WireInterface + from typing import Any, Callable, Container, Coroutine, TypeVar + + Msg = TypeVar("Msg", bound=protobuf.MessageType) + HandlerTask = Coroutine[Any, Any, protobuf.MessageType] + Handler = Callable[[Msg], HandlerTask] + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + + +# If set to False protobuf messages marked with "experimental_message" option are rejected. +EXPERIMENTAL_ENABLED = False + + +def wrap_protobuf_load( + buffer: bytes, + expected_type: type[LoadedMessageType], +) -> LoadedMessageType: + try: + msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) + if __debug__ and utils.EMULATOR: + log.debug( + __name__, "received message contents:\n%s", utils.dump_protobuf(msg) + ) + return msg + except Exception as e: + if __debug__: + log.exception(__name__, e) + if e.args: + raise DataError("Failed to decode message: " + " ".join(e.args)) + else: + raise DataError("Failed to decode message") + + +_PROTOBUF_BUFFER_SIZE = const(8192) + +WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) + +if __debug__: + PROTOBUF_BUFFER_SIZE_DEBUG = 1024 + WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) + + +async def handle_single_message( + ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool +) -> protocol_common.MessageWithId | None: + """Handle a message that was loaded from USB by the caller. + + Find the appropriate handler, run it and write its result on the wire. In case + a problem is encountered at any point, write the appropriate error on the wire. + + If the workflow finished normally or with an error, the return value is None. + + If an unexpected message had arrived on the wire while the workflow was processing, + the workflow is shut down with an `UnexpectedMessage` exception. This is not + considered an "error condition" to return over the wire -- instead the message + is processed as if starting a new workflow. + In such case, the `UnexpectedMessage` is caught and the message is returned + to the caller. It will then be processed in the next iteration of the message loop. + """ + if __debug__: + try: + msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME + except Exception: + msg_type = f"{msg.type} - unknown message type" + if ctx.channel_id is not None: + sid = int.from_bytes(ctx.channel_id, "big") + else: + sid = -1 + log.debug( + __name__, + "%s:%x receive: <%s>", + ctx.iface.iface_num(), + sid, + msg_type, + ) + + res_msg: protobuf.MessageType | None = None + + # We need to find a handler for this message type. Should not raise. + handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none + + if handler is None: + # If no handler is found, we can skip decoding and directly + # respond with failure. + await ctx.write(unexpected_message()) + return None + + if msg.type in workflow.ALLOW_WHILE_LOCKED: + workflow.autolock_interrupts_workflow = False + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + req_type = protobuf.type_for_wire(msg.type) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = wrap_protobuf_load(msg.data, req_type) + + # Create the handler task. + task = handler(req_msg) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + if use_workflow: + # Spawn a workflow around the task. This ensures that concurrent + # workflows are shut down. + res_msg = await workflow.spawn(context.with_context(ctx, task)) + else: + # For debug messages, ignore workflow processing and just await + # results of the handler. + res_msg = await task + + except context.UnexpectedMessage as exc: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessage if another one comes in. + # In order not to lose the message, we return it to the caller. + # TODO: + # We might handle only the few common cases here, like + # Initialize and Cancel. + return exc.msg + + except BaseException as exc: + # Either: + # - the message had a type that has a registered handler, but does not have + # a protobuf class + # - the message was not valid protobuf + # - workflow raised some kind of an exception while running + # - something canceled the workflow from the outside + if __debug__: + if isinstance(exc, ActionCancelled): + log.debug(__name__, "cancelled: %s", exc.message) + elif isinstance(exc, loop.TaskClosed): + log.debug(__name__, "cancelled: loop task was closed") + else: + log.exception(__name__, exc) + res_msg = failure(exc) + + if res_msg is not None: + # perform the write outside the big try-except block, so that usb write + # problem bubbles up + await ctx.write(res_msg) + return None + + +def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None: + """Placeholder handler lookup before a proper one is registered.""" + return None + + +find_handler = _find_handler_placeholder +AVOID_RESTARTING_FOR: Container[int] = () + + +def failure(exc: BaseException) -> Failure: + if isinstance(exc, Error): + return Failure(code=exc.code, message=exc.message) + elif isinstance(exc, loop.TaskClosed): + return Failure(code=FailureType.ActionCancelled, message="Cancelled") + elif isinstance(exc, InvalidSessionError): + return Failure(code=FailureType.InvalidSession, message="Invalid session") + else: + # NOTE: when receiving generic `FirmwareError` on non-debug build, + # change the `if __debug__` to `if True` to get the full error message. + if __debug__: + message = str(exc) + else: + message = "Firmware error" + return Failure(code=FailureType.FirmwareError, message=message) + + +def unexpected_message() -> Failure: + return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") diff --git a/core/src/trezor/wire/protocol_common.py b/core/src/trezor/wire/protocol_common.py index 7c7ab80f2..93b77e617 100644 --- a/core/src/trezor/wire/protocol_common.py +++ b/core/src/trezor/wire/protocol_common.py @@ -1,3 +1,6 @@ +from trezor import protobuf + + class Message: def __init__( self, @@ -24,3 +27,11 @@ class MessageWithId(Message): class WireError(Exception): pass + + +class Context: + def __init__(self, iface, channel_id) -> None: + self.iface = iface + self.channel_id = channel_id + + async def write(self, msg: protobuf.MessageType) -> None: ... diff --git a/core/src/trezor/wire/thp/channel_context.py b/core/src/trezor/wire/thp/channel_context.py new file mode 100644 index 000000000..003798222 --- /dev/null +++ b/core/src/trezor/wire/thp/channel_context.py @@ -0,0 +1,106 @@ +import ustruct +from micropython import const +from typing import TYPE_CHECKING + +from storage.cache_thp import SessionThpCache +from trezor import loop, protobuf, utils + +from ..protocol_common import Context +from . import thp_session +from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT + +# from .thp_session import SessionState, ThpError + +if TYPE_CHECKING: + from trezorio import WireInterface + +_INIT_DATA_OFFSET = const(5) +_CONT_DATA_OFFSET = const(3) + + +class ChannelContext(Context): + def __init__( + self, iface: WireInterface, channel_id: int, session_data: SessionThpCache + ) -> None: + super().__init__(iface, channel_id) + self.session_data = session_data + self.buffer: utils.BufferType + self.waiting_for_ack_timeout: loop.Task | None + self.is_cont_packet_expected: bool = False + self.expected_payload_length: int = 0 + self.bytes_read = 0 + + # ACCESS TO SESSION_DATA + + def get_management_session_state(self): + return thp_session.get_state(self.session_data) + + # CALLED BY THP_MAIN_LOOP + + async def receive_packet(self, packet: utils.BufferType): + ctrl_byte = packet[0] + if _is_ctrl_byte_continuation(ctrl_byte): + await self._handle_cont_packet(packet) + else: + await self._handle_init_packet(packet) + + async def _handle_init_packet(self, packet): + ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet) + packet_payload = packet[5:] + + if _is_ctrl_byte_encrypted_transport(ctrl_byte): + packet_payload = self._decode(packet_payload) + + # session_id = packet_payload[0] # TODO handle handshake differently + self.expected_payload_length = payload_length + self.bytes_read = 0 + + await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET) + # TODO Set/Provide different buffer for management session + + if self.expected_payload_length == self.bytes_read: + self._finish_message() + else: + self.is_cont_packet_expected = True + + async def _handle_cont_packet(self, packet): + if not self.is_cont_packet_expected: + return # Continuation packet is not expected, ignoring + await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET) + + def _decode(self, payload) -> bytes: + return payload # TODO add decryption process + + async def _buffer_packet_data( + self, payload_buffer, packet: utils.BufferType, offset + ): + self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset) + + def _finish_message(self): + # TODO Provide loaded message to SessionContext or handle it with this ChannelContext + self.bytes_read = 0 + self.expected_payload_length = 0 + self.is_cont_packet_expected = False + + # CALLED BY WORKFLOW / SESSION CONTEXT + + async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None: + pass + # TODO protocol.write(self.iface, self.channel_id, session_id, msg) + + def create_new_session( + self, + passphrase="", + ) -> None: # TODO change it to output session data + pass + # TODO check, wheter a session with this passphrase already exists + # if not, create a new session with this passphrase + # if yes, what TODO TODO ??? + + +def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool: + return ctrl_byte & 0x80 == CONTINUATION_PACKET + + +def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool: + return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT diff --git a/core/src/trezor/wire/thp/session_context.py b/core/src/trezor/wire/thp/session_context.py new file mode 100644 index 000000000..9d56d37bb --- /dev/null +++ b/core/src/trezor/wire/thp/session_context.py @@ -0,0 +1,14 @@ +from trezor import protobuf + +from ..context import Context +from .channel_context import ChannelContext + + +class SessionContext(Context): + def __init__(self, channel_context: ChannelContext, session_id: int) -> None: + super().__init__(channel_context.iface, channel_context.channel_id) + self.channel_context = channel_context + self.session_id = session_id + + async def write(self, msg: protobuf.MessageType) -> None: + return await self.channel_context.write(msg, self.session_id) diff --git a/core/src/trezor/wire/thp/thp_messages.py b/core/src/trezor/wire/thp/thp_messages.py index 2837a0eda..12b4649dd 100644 --- a/core/src/trezor/wire/thp/thp_messages.py +++ b/core/src/trezor/wire/thp/thp_messages.py @@ -5,6 +5,7 @@ from storage.cache_thp import BROADCAST_CHANNEL_ID from ..protocol_common import Message CONTINUATION_PACKET = 0x80 +ENCRYPTED_TRANSPORT = 0x02 _ERROR = 0x41 _CHANNEL_ALLOCATION_RES = 0x40 diff --git a/core/src/trezor/wire/thp_v1.py b/core/src/trezor/wire/thp_v1.py index f0ceee329..56d694131 100644 --- a/core/src/trezor/wire/thp_v1.py +++ b/core/src/trezor/wire/thp_v1.py @@ -8,8 +8,14 @@ from trezor import io, loop, utils from .protocol_common import MessageWithId from .thp import ack_handler, checksum, thp_messages from .thp import thp_session as THP +from .thp.channel_context import ChannelContext from .thp.checksum import CHECKSUM_LENGTH -from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket +from .thp.thp_messages import ( + CONTINUATION_PACKET, + ENCRYPTED_TRANSPORT, + InitHeader, + InterruptingInitPacket, +) from .thp.thp_session import SessionState, ThpError if TYPE_CHECKING: @@ -21,12 +27,14 @@ _CHANNEL_ALLOCATION_REQ = 0x40 _ACK_MESSAGE = 0x20 _HANDSHAKE_INIT = 0x00 _PLAINTEXT = 0x01 -ENCRYPTED_TRANSPORT = 0x02 _REPORT_LENGTH = const(64) _REPORT_INIT_DATA_OFFSET = const(5) _REPORT_CONT_DATA_OFFSET = const(3) +_BUFFER: bytearray +_BUFFER_LOCK = None + async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId: msg = await read_message_or_init_packet(iface, buffer) @@ -38,6 +46,37 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag return msg +def set_buffer(buffer): + _BUFFER = buffer + print(_BUFFER) # TODO remove + + +async def thp_main_loop(iface: WireInterface, is_debug_session=False): + + CHANNELS: dict[int, ChannelContext] = {} + # TODO load cached channels/sessions + + read = loop.wait(iface.iface_num() | io.POLL_READ) + + while True: + packet = await read + ctrl_byte, cid = ustruct.unpack(">BH", packet) + + if cid == BROADCAST_CHANNEL_ID: + await _handle_broadcast(iface, ctrl_byte, packet) + continue + + if cid in CHANNELS: + channel = CHANNELS[cid] + if channel is None: + raise ThpError("Invalid state of a channel") + if channel.get_management_session_state != SessionState.UNALLOCATED: + await channel.receive_packet(packet) + continue + + await _handle_unallocated(iface, cid) + + async def read_message_or_init_packet( iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None ) -> MessageWithId | InterruptingInitPacket: @@ -50,10 +89,10 @@ async def read_message_or_init_packet( raise ThpError("Reading failed unexpectedly, report is None.") # Channel multiplexing - ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report) + ctrl_byte, cid = ustruct.unpack(">BH", report) if cid == BROADCAST_CHANNEL_ID: - await _handle_broadcast(iface, ctrl_byte, report) # TODO await + await _handle_broadcast(iface, ctrl_byte, report) report = None continue @@ -64,7 +103,7 @@ async def read_message_or_init_packet( # continuation packet is not expected - ignore report = None continue - + payload_length = ustruct.unpack(">H", report[3:])[0] payload = _get_buffer_for_payload(payload_length, buffer) header = InitHeader(ctrl_byte, cid, payload_length) @@ -255,15 +294,15 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None: async def _handle_broadcast( - iface: WireInterface, ctrl_byte, report + iface: WireInterface, ctrl_byte, packet ) -> MessageWithId | None: if ctrl_byte != _CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in broadcast channel packet") - length, nonce = ustruct.unpack(">H8s", report[3:]) + length, nonce = ustruct.unpack(">H8s", packet[3:]) header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length) - payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH) + payload = _get_buffer_for_payload(length, packet[5:], _MAX_CID_REQ_PAYLOAD_LENGTH) if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]): raise ThpError("Checksum is not valid") @@ -274,8 +313,8 @@ async def _handle_broadcast( response_header = InitHeader.get_channel_allocation_response_header( len(response_data) + CHECKSUM_LENGTH, ) - chksum = checksum.compute(response_header.to_bytes() + response_data) + await write_to_wire(iface, response_header, response_data + chksum)