diff --git a/core/src/trezor/wire/thp/thp_main.py b/core/src/trezor/wire/thp/thp_main.py index 1680a67abb..f4be23c1a1 100644 --- a/core/src/trezor/wire/thp/thp_main.py +++ b/core/src/trezor/wire/thp/thp_main.py @@ -14,6 +14,7 @@ from . import ( ThpErrorType, channel_manager, checksum, + control_byte, get_channel_allocation_response, writer, ) @@ -47,20 +48,20 @@ async def thp_main_loop(iface: WireInterface) -> None: assert packet_len == len(packet) iface.read(packet, 0) - ctrl_byte, cid = ustruct.unpack(">BH", packet) - - if ctrl_byte == CODEC_V1: + if _get_ctrl_byte(packet) == CODEC_V1: await _handle_codec_v1(iface, packet) continue + cid = ustruct.unpack(">BH", packet)[1] + if cid == BROADCAST_CHANNEL_ID: - await _handle_broadcast(iface, ctrl_byte, packet) + await _handle_broadcast(iface, packet) continue if cid in _CHANNELS: await _handle_allocated(iface, cid, packet) else: - await _handle_unallocated(iface, cid) + await _handle_unallocated(iface, cid, packet) except ThpError as e: if __debug__: @@ -77,10 +78,8 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None: await writer.write_packet_to_wire(iface, error_message) -async def _handle_broadcast( - iface: WireInterface, ctrl_byte: int, packet: utils.BufferType -) -> None: - if ctrl_byte != CHANNEL_ALLOCATION_REQ: +async def _handle_broadcast(iface: WireInterface, packet: utils.BufferType) -> None: + if _get_ctrl_byte(packet) != CHANNEL_ALLOCATION_REQ: raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") if __debug__: log.debug(__name__, "Received valid message on the broadcast channel") @@ -114,7 +113,7 @@ async def _handle_allocated( ) -> None: channel = _CHANNELS[cid] if channel is None: - await _handle_unallocated(iface, cid) + await _handle_unallocated(iface, cid, packet) raise ThpError("Invalid state of a channel") if channel.iface is not iface: # TODO send error message to wire @@ -126,7 +125,9 @@ async def _handle_allocated( await x -async def _handle_unallocated(iface: WireInterface, cid: int) -> None: +async def _handle_unallocated(iface: WireInterface, cid: int, packet: bytes) -> None: + if control_byte.is_continuation(_get_ctrl_byte(packet)): + return data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big") header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) await write_payload_to_wire_and_add_checksum(iface, header, data) @@ -159,6 +160,10 @@ def _reuse_existing_buffer( return memoryview(existing_buffer)[:payload_length] +def _get_ctrl_byte(packet: bytes) -> int: + return packet[0] + + def _get_codec_v1_error_message() -> bytes: # Codec_v1 magic constant "?##" + Failure message type + msg_size # + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B