1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-31 01:41:18 +00:00

feat(core): send UNALLOCATED_CHANNEL error as a response only on init packet

[no changelog]
This commit is contained in:
M1nd3r 2025-01-13 13:03:01 +01:00
parent 818c45cc86
commit 34aaf6e645

View File

@ -14,6 +14,7 @@ from . import (
ThpErrorType, ThpErrorType,
channel_manager, channel_manager,
checksum, checksum,
control_byte,
get_channel_allocation_response, get_channel_allocation_response,
writer, writer,
) )
@ -47,20 +48,20 @@ async def thp_main_loop(iface: WireInterface) -> None:
assert packet_len == len(packet) assert packet_len == len(packet)
iface.read(packet, 0) iface.read(packet, 0)
ctrl_byte, cid = ustruct.unpack(">BH", packet) if _get_ctrl_byte(packet) == CODEC_V1:
if ctrl_byte == CODEC_V1:
await _handle_codec_v1(iface, packet) await _handle_codec_v1(iface, packet)
continue continue
cid = ustruct.unpack(">BH", packet)[1]
if cid == BROADCAST_CHANNEL_ID: if cid == BROADCAST_CHANNEL_ID:
await _handle_broadcast(iface, ctrl_byte, packet) await _handle_broadcast(iface, packet)
continue continue
if cid in _CHANNELS: if cid in _CHANNELS:
await _handle_allocated(iface, cid, packet) await _handle_allocated(iface, cid, packet)
else: else:
await _handle_unallocated(iface, cid) await _handle_unallocated(iface, cid, packet)
except ThpError as e: except ThpError as e:
if __debug__: if __debug__:
@ -77,10 +78,8 @@ async def _handle_codec_v1(iface: WireInterface, packet: bytes) -> None:
await writer.write_packet_to_wire(iface, error_message) await writer.write_packet_to_wire(iface, error_message)
async def _handle_broadcast( async def _handle_broadcast(iface: WireInterface, packet: utils.BufferType) -> None:
iface: WireInterface, ctrl_byte: int, packet: utils.BufferType if _get_ctrl_byte(packet) != CHANNEL_ALLOCATION_REQ:
) -> None:
if ctrl_byte != CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in a broadcast channel packet") raise ThpError("Unexpected ctrl_byte in a broadcast channel packet")
if __debug__: if __debug__:
log.debug(__name__, "Received valid message on the broadcast channel") log.debug(__name__, "Received valid message on the broadcast channel")
@ -114,7 +113,7 @@ async def _handle_allocated(
) -> None: ) -> None:
channel = _CHANNELS[cid] channel = _CHANNELS[cid]
if channel is None: if channel is None:
await _handle_unallocated(iface, cid) await _handle_unallocated(iface, cid, packet)
raise ThpError("Invalid state of a channel") raise ThpError("Invalid state of a channel")
if channel.iface is not iface: if channel.iface is not iface:
# TODO send error message to wire # TODO send error message to wire
@ -126,7 +125,9 @@ async def _handle_allocated(
await x await x
async def _handle_unallocated(iface: WireInterface, cid: int) -> None: async def _handle_unallocated(iface: WireInterface, cid: int, packet: bytes) -> None:
if control_byte.is_continuation(_get_ctrl_byte(packet)):
return
data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big") data = (ThpErrorType.UNALLOCATED_CHANNEL).to_bytes(1, "big")
header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH) header = PacketHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
await write_payload_to_wire_and_add_checksum(iface, header, data) await write_payload_to_wire_and_add_checksum(iface, header, data)
@ -159,6 +160,10 @@ def _reuse_existing_buffer(
return memoryview(existing_buffer)[:payload_length] return memoryview(existing_buffer)[:payload_length]
def _get_ctrl_byte(packet: bytes) -> int:
return packet[0]
def _get_codec_v1_error_message() -> bytes: def _get_codec_v1_error_message() -> bytes:
# Codec_v1 magic constant "?##" + Failure message type + msg_size # Codec_v1 magic constant "?##" + Failure message type + msg_size
# + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B # + msg_data (code = "Failure_InvalidProtocol") + padding to 64 B