Improve handling of channel allocation requests

M1nd3r/thp5
M1nd3r 2 months ago committed by M1nd3r
parent 22418a2fe0
commit 123365b2bb

@ -13,6 +13,7 @@ if TYPE_CHECKING:
from trezorio import WireInterface
_MAX_PAYLOAD_LEN = const(60000)
_MAX_CID_REQ_PAYLOAD_LENGTH = const(12) # TODO set to reasonable value
_CHECKSUM_LENGTH = const(4)
_CHANNEL_ALLOCATION_REQ = 0x40
_CHANNEL_ALLOCATION_RES = 0x40
@ -166,9 +167,9 @@ def _get_loop_wait_read(iface: WireInterface):
def _get_buffer_for_payload(
payload_length: int, existing_buffer: utils.BufferType
payload_length: int, existing_buffer: utils.BufferType, max_length=_MAX_PAYLOAD_LEN
) -> utils.BufferType:
if payload_length > _MAX_PAYLOAD_LEN:
if payload_length > max_length:
raise ThpError("Message too large")
if payload_length > len(existing_buffer):
# allocate a new buffer to fit the message
@ -285,14 +286,18 @@ async def _write_report(write, iface: WireInterface, report: bytearray) -> None:
async def _handle_broadcast(iface: WireInterface, ctrl_byte, report) -> Message | None:
if ctrl_byte != _CHANNEL_ALLOCATION_REQ:
raise ThpError("Unexpected ctrl_byte in broadcast channel packet")
nonce, checksum = ustruct.unpack(">8s4s", report[5:])
# Note that the length field of the channel allocation request is ignored.
if not _is_checksum_valid(checksum, data=report[:-4]):
length, nonce = ustruct.unpack(">H8s", report[3:])
header = InitHeader(ctrl_byte, BROADCAST_CHANNEL_ID, length)
payload = _get_buffer_for_payload(length, report[5:], _MAX_CID_REQ_PAYLOAD_LENGTH)
if not _is_checksum_valid(payload[-4:], header.to_bytes() + payload[:-4]):
raise ThpError("Checksum is not valid")
channel_id = _get_new_channel_id()
THP.create_new_unauthenticated_session(iface, channel_id)
response_data = (
ustruct.pack(">8sH", nonce, channel_id) + _ENCODED_PROTOBUF_DEVICE_PROPERTIES
)
@ -379,5 +384,5 @@ def _add_sync_bit_to_ctrl_byte(ctrl_byte, sync_bit):
raise ThpError("Unexpected synchronization bit")
def _compute_checksum_bytes(data: bytes | utils.BufferType):
def _compute_checksum_bytes(data: bytes | utils.BufferType) -> bytes:
return crc.crc32(data).to_bytes(4, "big")

@ -399,10 +399,10 @@ class DebugLink:
self.screen_text_file = file_path
def open(self) -> None:
self.transport.begin_session()
self.transport.begin_connection()
def close(self) -> None:
self.transport.end_session()
self.transport.end_connection()
def _call(self, msg: protobuf.MessageType, nowait: bool = False) -> Any:
LOG.debug(
@ -1275,8 +1275,8 @@ def prodtest_t1(client: "TrezorClient") -> protobuf.MessageType:
raise RuntimeError("Device must be in bootloader mode")
return client.call(
messages.ProdTestT1(
payload=b"\x00\xFF\x55\xAA\x66\x99\x33\xCCABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xFF\x55\xAA\x66\x99\x33\xCC"
messages.SelfTest(
payload=b"\x00\xff\x55\xaa\x66\x99\x33\xccABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!\x00\xff\x55\xaa\x66\x99\x33\xcc"
)
)

Loading…
Cancel
Save