"""Tests for the THP v1 wire codec driven over a mock HID interface.

Exercises ``thp_v1.deprecated_read_message`` and
``thp_v1.deprecated_write_message`` by stepping their generators by hand and
comparing the yielded wait objects and the produced/consumed reports.
"""

from common import *
from storage.cache_thp import BROADCAST_CHANNEL_ID
from trezor.wire.thp.writer import REPORT_LENGTH
from ubinascii import hexlify
import ustruct
from trezor import io, utils
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import thp_v1
from trezor.wire.protocol_common import MessageWithId
import trezor.wire.thp.thp_session as THP
from trezor.wire.thp import checksum
from trezor.wire.thp.checksum import CHECKSUM_LENGTH


class MockHID:
    """Fake HID interface that records every report written to it."""

    def __init__(self, num):
        self.num = num
        # Each write() appends a bytearray copy of the written report here.
        self.data = []

    def iface_num(self):
        return self.num

    def write(self, msg):
        """Record a copy of ``msg`` and report the whole of it as written."""
        self.data.append(bytearray(msg))
        return len(msg)

    def wait_object(self, mode):
        # Tests compare the queries yielded by the codec against this value
        # using assertObjectEqual.
        return wait(mode | self.num)


MESSAGE_TYPE = 0x4242
MESSAGE_TYPE_BYTES = b"\x42\x42"
_MESSAGE_TYPE_LEN = 2
# Control bytes for the two values of the sync bit; see getPlaintext().
PLAINTEXT_0 = 0x01
PLAINTEXT_1 = 0x11
COMMON_CID = 4660  # 0x1234
CONT = 0x80

# Init header: ctrl byte (B) + cid (H) + length (H), as packed by make_header().
HEADER_INIT_LENGTH = 5
# Continuation header: ctrl byte (B) + cid (H), as packed by make_cont_header().
HEADER_CONT_LENGTH = 3
# Payload bytes left in the first report after the init header and message type.
INIT_MESSAGE_DATA_LENGTH = REPORT_LENGTH - HEADER_INIT_LENGTH - _MESSAGE_TYPE_LEN


def make_header(ctrl_byte, cid, length):
    """Pack an init header: big-endian u8 ctrl byte, u16 cid, u16 length."""
    return ustruct.pack(">BHH", ctrl_byte, cid, length)


def make_cont_header():
    """Pack a continuation header (CONT ctrl byte + COMMON_CID)."""
    return ustruct.pack(">BH", CONT, COMMON_CID)


def makeSimpleMessage(header, message_type, message_data):
    """Return ``header`` + big-endian u16 ``message_type`` + ``message_data``."""
    return header + ustruct.pack(">H", message_type) + message_data


def makeCidRequest(header, message_data):
    """Return ``header`` followed directly by ``message_data``."""
    return header + message_data


def printBytes(a):
    """Print ``a`` as a hex string (debugging aid)."""
    print(hexlify(a).decode("utf-8"))


def getPlaintext() -> bytes:
    """Return the plaintext ctrl byte matching the active session's sync bit.

    PLAINTEXT_1 when the session's "receive expected" bit is 1, else PLAINTEXT_0.
    """
    if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
        return PLAINTEXT_1
    return PLAINTEXT_0


def getCid() -> int:
    """Return the channel id of the active THP session."""
    return THP.get_cid(THP.get_active_session())


# This test suite is an adaptation of test_trezor.wire.codec_v1
class TestWireTrezorHostProtocolV1(unittest.TestCase):
    """Read/write tests for the THP v1 codec.

    NOTE(review): every test method below is prefixed with ``_`` instead of
    ``test``, so a runner that collects ``test*`` methods will skip all of
    them -- presumably disabled on purpose; confirm before relying on it.
    """

    def setUp(self):
        self.interface = MockHID(0xDEADBEEF)
        # NOTE(review): storage.cache_thp is already imported at module level;
        # presumably kept for its import side effects when THP is disabled.
        if not utils.USE_THP:
            import storage.cache_thp  # noqa: F401

    def _simple(self):
        # A CID request on the broadcast channel, then a message on COMMON_CID.
        cid_req_header = make_header(
            ctrl_byte=0x40, cid=BROADCAST_CHANNEL_ID, length=12
        )
        cid_request_dummy_data = b"\x00\x11\x22\x33\x44\x55\x66\x77\x96\x64\x3c\x6c"
        cid_req_message = makeCidRequest(cid_req_header, cid_request_dummy_data)

        # length 18 = 2 (message type) + 12 (data) + 4 (checksum)
        message_header = make_header(ctrl_byte=0x01, cid=COMMON_CID, length=18)
        # NOTE(review): hard-coded checksum vector; other tests compute it with
        # checksum.compute(header + type + data) -- presumably the same input.
        cid_request_dummy_data_checksum = b"\x67\x8e\xac\xe0"
        message = makeSimpleMessage(
            message_header,
            MESSAGE_TYPE,
            cid_request_dummy_data + cid_request_dummy_data_checksum,
        )

        buffer = bytearray(64)
        gen = thp_v1.deprecated_read_message(self.interface, buffer)
        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
        gen.send(cid_req_message)
        gen.send(None)
        gen.send(message)
        with self.assertRaises(StopIteration) as e:
            gen.send(None)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, cid_request_dummy_data)

        buffer_without_zeroes = buffer[: len(message) - HEADER_INIT_LENGTH]
        message_without_header = message[HEADER_INIT_LENGTH:]
        # message should have been read into the buffer
        self.assertEqual(buffer_without_zeroes, message_without_header)

    def _read_one_packet(self):
        # zero length message - just a header
        PLAINTEXT = getPlaintext()
        header = make_header(
            PLAINTEXT, cid=COMMON_CID, length=_MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
        )
        chksum = checksum.compute(header + MESSAGE_TYPE_BYTES)
        message = header + MESSAGE_TYPE_BYTES + chksum

        buffer = bytearray(64)
        gen = thp_v1.deprecated_read_message(self.interface, buffer)

        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
        gen.send(message)
        with self.assertRaises(StopIteration) as e:
            gen.send(None)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, b"")

        # message should have been read into the buffer
        # (2 type bytes + 4 checksum bytes + 58 zero bytes = 64-byte buffer)
        self.assertEqual(buffer, MESSAGE_TYPE_BYTES + chksum + b"\x00" * 58)

    def _read_many_packets(self):
        message = bytes(range(256))
        header = make_header(
            getPlaintext(),
            COMMON_CID,
            len(message) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
        )
        chksum = checksum.compute(header + MESSAGE_TYPE_BYTES + message)
        # message = MESSAGE_TYPE_BYTES + message + checksum

        # first packet is init header + message type + INIT_MESSAGE_DATA_LENGTH
        # bytes of data; other packets are cont header + the rest of the report
        cont_header = make_cont_header()
        packets = [header + MESSAGE_TYPE_BYTES + message[:INIT_MESSAGE_DATA_LENGTH]] + [
            cont_header + chunk
            for chunk in chunks(
                message[INIT_MESSAGE_DATA_LENGTH:] + chksum,
                REPORT_LENGTH - HEADER_CONT_LENGTH,
            )
        ]

        # 262 = 2 (message type) + 256 (message) + 4 (checksum)
        buffer = bytearray(262)
        gen = thp_v1.deprecated_read_message(self.interface, buffer)
        query = gen.send(None)
        for packet in packets:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            query = gen.send(packet)

        # last packet will stop
        with self.assertRaises(StopIteration) as e:
            gen.send(None)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, message)

        # message should have been read into the buffer
        self.assertEqual(buffer, MESSAGE_TYPE_BYTES + message + chksum)

    def _read_large_message(self):
        message = b"hello world"
        header = make_header(
            getPlaintext(),
            COMMON_CID,
            _MESSAGE_TYPE_LEN + len(message) + CHECKSUM_LENGTH,
        )

        packet = (
            header
            + MESSAGE_TYPE_BYTES
            + message
            + checksum.compute(header + MESSAGE_TYPE_BYTES + message)
        )

        # make sure we fit into one packet, to make this easier
        self.assertTrue(len(packet) <= thp_v1._REPORT_LENGTH)

        # a buffer too small for the message
        buffer = bytearray(1)
        self.assertTrue(len(buffer) <= len(packet))

        gen = thp_v1.deprecated_read_message(self.interface, buffer)
        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
        gen.send(packet)
        with self.assertRaises(StopIteration) as e:
            gen.send(None)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, message)

        # read should have allocated its own buffer and not touch ours
        self.assertEqual(buffer, b"\x00")

    def _roundtrip(self):
        message_payload = bytes(range(256))
        message = MessageWithId(
            MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
        )
        gen = thp_v1.deprecated_write_message(self.interface, message)
        # exhaust the iterator:
        # (XXX we can only do this because the iterator is only accepting None and returns None)
        for query in gen:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))

        # feed every report the writer produced back into the reader
        buffer = bytearray(1024)
        gen = thp_v1.deprecated_read_message(self.interface, buffer)
        query = gen.send(None)
        for packet in self.interface.data:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            printBytes(packet)
            query = gen.send(packet)

        with self.assertRaises(StopIteration) as e:
            gen.send(None)

        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, message.data)

    def _write_one_packet(self):
        message = MessageWithId(
            MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
        )
        gen = thp_v1.deprecated_write_message(self.interface, message)

        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
        with self.assertRaises(StopIteration):
            gen.send(None)

        header = make_header(
            getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
        )
        # header + type + checksum, zero-padded to a full report
        expected_message = (
            header
            + MESSAGE_TYPE_BYTES
            + checksum.compute(header + MESSAGE_TYPE_BYTES)
            + b"\x00" * (INIT_MESSAGE_DATA_LENGTH - CHECKSUM_LENGTH)
        )
        # assertEqual (not assertTrue(a == b)) so a failure shows both values
        self.assertEqual(self.interface.data, [expected_message])

    def _write_multiple_packets(self):
        message_payload = bytes(range(256))
        message = MessageWithId(
            MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
        )
        gen = thp_v1.deprecated_write_message(self.interface, message)

        header = make_header(
            PLAINTEXT_1,
            COMMON_CID,
            len(message.data) + _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH,
        )
        cont_header = make_cont_header()
        chksum = checksum.compute(
            header + message.type.to_bytes(2, "big") + message.data
        )
        packets = [
            header + MESSAGE_TYPE_BYTES + message.data[:INIT_MESSAGE_DATA_LENGTH]
        ] + [
            cont_header + chunk
            for chunk in chunks(
                message.data[INIT_MESSAGE_DATA_LENGTH:] + chksum,
                thp_v1._REPORT_LENGTH - HEADER_CONT_LENGTH,
            )
        ]

        for _ in packets:
            # we receive as many queries as there are packets
            query = gen.send(None)
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))

        # the first sent None only started the generator. the len(packets)-th None
        # will finish writing and raise StopIteration
        with self.assertRaises(StopIteration):
            gen.send(None)

        # packets must be identical up to the last one
        self.assertListEqual(packets[:-1], self.interface.data[:-1])
        # last packet must be identical up to message length. remaining bytes in
        # the 64-byte packets are garbage -- in particular, it's the bytes of the
        # previous packet
        last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
        self.assertEqual(last_packet, self.interface.data[-1])

    def _read_huge_packet(self):
        PACKET_COUNT = 1180  # message that takes up 1 180 USB packets
        # NOTE(review): the per-packet term also subtracts CHECKSUM_LENGTH and
        # _MESSAGE_TYPE_LEN, which does not match a real continuation payload
        # (REPORT_LENGTH - HEADER_CONT_LENGTH). The result must still fit the
        # 2-byte length field that make_header packs with ">H", so it is kept.
        message_size = (PACKET_COUNT - 1) * (
            REPORT_LENGTH - HEADER_CONT_LENGTH - CHECKSUM_LENGTH - _MESSAGE_TYPE_LEN
        ) + INIT_MESSAGE_DATA_LENGTH

        # ensure that a message this big won't fit into memory
        # Note: this control is changed, because THP has only 2 byte length field
        self.assertTrue(message_size > thp_v1.MAX_PAYLOAD_LEN)
        # self.assertRaises(MemoryError, bytearray, message_size)
        header = make_header(PLAINTEXT_1, COMMON_CID, message_size)
        packet = header + MESSAGE_TYPE_BYTES + (b"\x00" * INIT_MESSAGE_DATA_LENGTH)
        buffer = bytearray(65536)
        gen = thp_v1.deprecated_read_message(self.interface, buffer)

        query = gen.send(None)

        # THP returns "Message too large" error after reading the message size,
        # it is different from codec_v1 as it does not allow big enough messages
        # to raise MemoryError in this test
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
        with self.assertRaises(thp_v1.ThpError) as e:
            gen.send(packet)

        self.assertEqual(e.value.args[0], "Message too large")


if __name__ == "__main__":
    unittest.main()