|
|
|
@ -2,74 +2,35 @@ import ustruct
|
|
|
|
|
from micropython import const
|
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
|
|
|
|
|
|
import trezor.wire.thp_session as THP
|
|
|
|
|
from storage.cache_thp import BROADCAST_CHANNEL_ID, SessionThpCache
|
|
|
|
|
from trezor import io, loop, utils
|
|
|
|
|
from trezor.crypto import crc
|
|
|
|
|
from trezor.wire.protocol_common import Message
|
|
|
|
|
from trezor.wire.thp_session import SessionState, ThpError
|
|
|
|
|
|
|
|
|
|
from .protocol_common import MessageWithId
|
|
|
|
|
from .thp import ack_handler, checksum, thp_messages
|
|
|
|
|
from .thp import thp_session as THP
|
|
|
|
|
from .thp.checksum import CHECKSUM_LENGTH
|
|
|
|
|
from .thp.thp_messages import CONTINUATION_PACKET, InitHeader, InterruptingInitPacket
|
|
|
|
|
from .thp.thp_session import SessionState, ThpError
|
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
|
from trezorio import WireInterface
|
|
|
|
|
|
|
|
|
|
_MAX_PAYLOAD_LEN = const(60000)
|
|
|
|
|
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
|
|
|
|
|
_CHECKSUM_LENGTH = const(4)
|
|
|
|
|
_CHANNEL_ALLOCATION_REQ = 0x40
|
|
|
|
|
_CHANNEL_ALLOCATION_RES = 0x40
|
|
|
|
|
_ERROR = 0x41
|
|
|
|
|
_CONTINUATION_PACKET = 0x80
|
|
|
|
|
_ACK_MESSAGE = 0x20
|
|
|
|
|
_HANDSHAKE_INIT = 0x00
|
|
|
|
|
_PLAINTEXT = 0x01
|
|
|
|
|
ENCRYPTED_TRANSPORT = 0x02
|
|
|
|
|
_ENCODED_PROTOBUF_DEVICE_PROPERTIES = (
|
|
|
|
|
b"\x0a\x04\x54\x33\x57\x31\x10\x05\x18\x00\x20\x01\x28\x01\x28\x02"
|
|
|
|
|
)
|
|
|
|
|
_UNALLOCATED_SESSION_ERROR = (
|
|
|
|
|
b"\x55\x4e\x41\x4c\x4c\x4f\x43\x41\x54\x45\x44\x5f\x53\x45\x53\x53\x49\x4f\x4e"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
_REPORT_LENGTH = const(64)
|
|
|
|
|
_REPORT_INIT_DATA_OFFSET = const(5)
|
|
|
|
|
_REPORT_CONT_DATA_OFFSET = const(3)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InitHeader:
|
|
|
|
|
format_str = ">BHH"
|
|
|
|
|
|
|
|
|
|
def __init__(self, ctrl_byte, cid, length) -> None:
|
|
|
|
|
self.ctrl_byte = ctrl_byte
|
|
|
|
|
self.cid = cid
|
|
|
|
|
self.length = length
|
|
|
|
|
|
|
|
|
|
def to_bytes(self) -> bytes:
|
|
|
|
|
return ustruct.pack(
|
|
|
|
|
InitHeader.format_str, self.ctrl_byte, self.cid, self.length
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def pack_to_buffer(self, buffer, buffer_offset=0) -> None:
|
|
|
|
|
ustruct.pack_into(
|
|
|
|
|
InitHeader.format_str,
|
|
|
|
|
buffer,
|
|
|
|
|
buffer_offset,
|
|
|
|
|
self.ctrl_byte,
|
|
|
|
|
self.cid,
|
|
|
|
|
self.length,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def pack_to_cont_buffer(self, buffer, buffer_offset=0) -> None:
|
|
|
|
|
ustruct.pack_into(">BH", buffer, buffer_offset, _CONTINUATION_PACKET, self.cid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InterruptingInitPacket:
|
|
|
|
|
def __init__(self, report: bytes) -> None:
|
|
|
|
|
self.initReport = report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Message:
|
|
|
|
|
async def read_message(iface: WireInterface, buffer: utils.BufferType) -> MessageWithId:
|
|
|
|
|
msg = await read_message_or_init_packet(iface, buffer)
|
|
|
|
|
while type(msg) is not Message:
|
|
|
|
|
while type(msg) is not MessageWithId:
|
|
|
|
|
if isinstance(msg, InterruptingInitPacket):
|
|
|
|
|
msg = await read_message_or_init_packet(iface, buffer, msg.initReport)
|
|
|
|
|
else:
|
|
|
|
@ -79,7 +40,7 @@ async def read_message(iface: WireInterface, buffer: utils.BufferType) -> Messag
|
|
|
|
|
|
|
|
|
|
async def read_message_or_init_packet(
|
|
|
|
|
iface: WireInterface, buffer: utils.BufferType, firstReport: bytes | None = None
|
|
|
|
|
) -> Message | InterruptingInitPacket:
|
|
|
|
|
) -> MessageWithId | InterruptingInitPacket:
|
|
|
|
|
report = firstReport
|
|
|
|
|
while True:
|
|
|
|
|
# Wait for an initial report
|
|
|
|
@ -89,7 +50,7 @@ async def read_message_or_init_packet(
|
|
|
|
|
raise ThpError("Reading failed unexpectedly, report is None.")
|
|
|
|
|
|
|
|
|
|
# Channel multiplexing
|
|
|
|
|
ctrl_byte, cid = ustruct.unpack(">BH", report)
|
|
|
|
|
ctrl_byte, cid, payload_length = ustruct.unpack(">BHH", report)
|
|
|
|
|
|
|
|
|
|
if cid == BROADCAST_CHANNEL_ID:
|
|
|
|
|
await _handle_broadcast(iface, ctrl_byte, report) # TODO await
|
|
|
|
@ -104,7 +65,6 @@ async def read_message_or_init_packet(
|
|
|
|
|
report = None
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
payload_length = ustruct.unpack(">H", report[3:])[0]
|
|
|
|
|
payload = _get_buffer_for_payload(payload_length, buffer)
|
|
|
|
|
header = InitHeader(ctrl_byte, cid, payload_length)
|
|
|
|
|
|
|
|
|
@ -114,7 +74,7 @@ async def read_message_or_init_packet(
|
|
|
|
|
return interruptingPacket
|
|
|
|
|
|
|
|
|
|
# Check CRC
|
|
|
|
|
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
|
|
|
|
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
|
|
|
|
# checksum is not valid -> ignore message
|
|
|
|
|
report = None
|
|
|
|
|
continue
|
|
|
|
@ -141,7 +101,7 @@ async def read_message_or_init_packet(
|
|
|
|
|
|
|
|
|
|
# 1: Handle ACKs
|
|
|
|
|
if _is_ctrl_byte_ack(ctrl_byte):
|
|
|
|
|
_handle_received_ACK(session, sync_bit)
|
|
|
|
|
ack_handler.handle_received_ACK(session, sync_bit)
|
|
|
|
|
report = None
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
@ -215,8 +175,19 @@ async def _buffer_received_data(
|
|
|
|
|
nread += utils.memcpy(payload, nread, report, _REPORT_CONT_DATA_OFFSET)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def write_message_with_sync_control(
|
|
|
|
|
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
|
|
|
|
) -> None:
|
|
|
|
|
session = THP.get_session_from_id(message.session_id)
|
|
|
|
|
if session is None:
|
|
|
|
|
raise ThpError("Invalid session")
|
|
|
|
|
if (not THP.sync_can_send_message(session)) and (not is_retransmission):
|
|
|
|
|
raise ThpError("Cannot send another message before ACK is received.")
|
|
|
|
|
await write_message(iface, message, is_retransmission)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def write_message(
|
|
|
|
|
iface: WireInterface, message: Message, is_retransmission: bool = False
|
|
|
|
|
iface: WireInterface, message: MessageWithId, is_retransmission: bool = False
|
|
|
|
|
) -> None:
|
|
|
|
|
session = THP.get_session_from_id(message.session_id)
|
|
|
|
|
if session is None:
|
|
|
|
@ -245,9 +216,9 @@ async def write_message(
|
|
|
|
|
ctrl_byte, 1 - THP.sync_get_send_bit(session)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
header = InitHeader(ctrl_byte, cid, payload_length + _CHECKSUM_LENGTH)
|
|
|
|
|
checksum = _compute_checksum_bytes(header.to_bytes() + payload)
|
|
|
|
|
await write_to_wire(iface, header, payload + checksum)
|
|
|
|
|
header = InitHeader(ctrl_byte, cid, payload_length + CHECKSUM_LENGTH)
|
|
|
|
|
chksum = checksum.compute(header.to_bytes() + payload)
|
|
|
|
|
await write_to_wire(iface, header, payload + chksum)
|
|
|
|
|
# TODO set timeout for retransmission
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -283,7 +254,9 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None:
|
|
|
|
|
async def _handle_broadcast(
|
|
|
|
|
iface: WireInterface, ctrl_byte, report
|
|
|
|
|
) -> MessageWithId | None:
|
|
|
|
|
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
|
|
|
|
|
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
|
|
|
|
|
|
|
|
|
@ -291,67 +264,51 @@ async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message
|
|
|
|
|
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
|
|
|
|
|
|
|
|
|
|
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
|
|
|
|
|
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
|
|
|
|
if not checksum.is_valid(payload[-4:], header.to_bytes() + payload[:-4]):
|
|
|
|
|
raise ThpError("Checksum is not valid")
|
|
|
|
|
|
|
|
|
|
channel_id = _get_new_channel_id()
|
|
|
|
|
THP.create_new_unauthenticated_session(iface, channel_id)
|
|
|
|
|
|
|
|
|
|
response_data = (
|
|
|
|
|
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
|
|
|
|
|
response_data = thp_messages.get_channel_allocation_response(nonce, channel_id)
|
|
|
|
|
response_header = InitHeader.get_channel_allocation_response_header(
|
|
|
|
|
len(response_data) + CHECKSUM_LENGTH,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
response_header = InitHeader(
|
|
|
|
|
_CHANNEL_ALLOCATION_RES,
|
|
|
|
|
BROADCAST_CHANNEL_ID,
|
|
|
|
|
len(response_data) + _CHECKSUM_LENGTH,
|
|
|
|
|
)
|
|
|
|
|
chksum = checksum.compute(response_header.to_bytes() + response_data)
|
|
|
|
|
await write_to_wire(iface, response_header, response_data + chksum)
|
|
|
|
|
|
|
|
|
|
checksum = _compute_checksum_bytes(response_header.to_bytes() + response_data)
|
|
|
|
|
await write_to_wire(iface, response_header, response_data + checksum)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_allocated(ctrl_byte, session: SessionThpCache, payload) -> Message:
|
|
|
|
|
async def _handle_allocated(
|
|
|
|
|
ctrl_byte, session: SessionThpCache, payload
|
|
|
|
|
) -> MessageWithId:
|
|
|
|
|
# Parameters session and ctrl_byte will be used to determine if the
|
|
|
|
|
# communication should be encrypted or not
|
|
|
|
|
|
|
|
|
|
message_type = ustruct.unpack(">H", payload)[0]
|
|
|
|
|
|
|
|
|
|
# trim message type and checksum from payload
|
|
|
|
|
message_data = payload[2:-_CHECKSUM_LENGTH]
|
|
|
|
|
return Message(message_type, message_data, session.session_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _handle_received_ACK(session: SessionThpCache, sync_bit: int) -> None:
|
|
|
|
|
# No ACKs expected
|
|
|
|
|
if THP.sync_can_send_message(session):
|
|
|
|
|
return
|
|
|
|
|
message_data = payload[2:-CHECKSUM_LENGTH]
|
|
|
|
|
return MessageWithId(message_type, message_data, session.session_id)
|
|
|
|
|
|
|
|
|
|
# ACK has incorrect sync bit
|
|
|
|
|
if THP.sync_get_send_bit(session) != sync_bit:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# ACK is expected and it has correct sync bit
|
|
|
|
|
THP.sync_set_can_send_message(session, True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_unallocated(iface, cid) -> Message | None:
|
|
|
|
|
data = _UNALLOCATED_SESSION_ERROR
|
|
|
|
|
header = InitHeader(_ERROR, cid, len(data) + _CHECKSUM_LENGTH)
|
|
|
|
|
checksum = _compute_checksum_bytes(header.to_bytes() + data)
|
|
|
|
|
await write_to_wire(iface, header, data + checksum)
|
|
|
|
|
async def _handle_unallocated(iface, cid) -> MessageWithId | None:
|
|
|
|
|
data = thp_messages.get_error_unallocated_channel()
|
|
|
|
|
header = InitHeader.get_error_header(cid, len(data) + CHECKSUM_LENGTH)
|
|
|
|
|
chksum = checksum.compute(header.to_bytes() + data)
|
|
|
|
|
await write_to_wire(iface, header, data + chksum)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _sendAck(iface: WireInterface, cid: int, ack_bit: int) -> None:
|
|
|
|
|
ctrl_byte = _add_sync_bit_to_ctrl_byte(_ACK_MESSAGE, ack_bit)
|
|
|
|
|
header = InitHeader(ctrl_byte, cid, _CHECKSUM_LENGTH)
|
|
|
|
|
checksum = _compute_checksum_bytes(header.to_bytes())
|
|
|
|
|
await write_to_wire(iface, header, checksum)
|
|
|
|
|
header = InitHeader(ctrl_byte, cid, CHECKSUM_LENGTH)
|
|
|
|
|
chksum = checksum.compute(header.to_bytes())
|
|
|
|
|
await write_to_wire(iface, header, chksum)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def _handle_unexpected_sync_bit(
|
|
|
|
|
iface: WireInterface, cid: int, sync_bit: int
|
|
|
|
|
) -> Message | None:
|
|
|
|
|
) -> MessageWithId | None:
|
|
|
|
|
await _sendAck(iface, cid, sync_bit)
|
|
|
|
|
|
|
|
|
|
# TODO handle cancelation messages and messages on allocated channels without synchronization
|
|
|
|
@ -362,13 +319,8 @@ def _get_new_channel_id() -> int:
|
|
|
|
|
return THP.get_next_channel_id()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_checksum_valid(checksum: bytes | utils.BufferType, data: bytes) -> bool:
|
|
|
|
|
data_checksum = _compute_checksum_bytes(data)
|
|
|
|
|
return checksum == data_checksum
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_continuation(ctrl_byte) -> bool:
|
|
|
|
|
return ctrl_byte & 0x80 == _CONTINUATION_PACKET
|
|
|
|
|
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_ctrl_byte_ack(ctrl_byte) -> bool:
|
|
|
|
@ -381,7 +333,3 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
|
|
|
|
|
if sync_bit == 1:
|
|
|
|
|
return ctrl_byte | 0x10
|
|
|
|
|
raise ThpError("Unexpected synchronization bit")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes:
|
|
|
|
|
return crc.crc32(data).to_bytes(4, "big")
|
|
|
|
|