|
|
|
@ -67,7 +67,7 @@ def printBytes(a):
|
|
|
|
|
def getPlaintext() -> bytes:
|
|
|
|
|
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
|
|
|
|
return PLAINTEXT_1
|
|
|
|
|
PLAINTEXT_0
|
|
|
|
|
return PLAINTEXT_0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def getCid() -> int:
|
|
|
|
@ -97,7 +97,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
buffer = bytearray(64)
|
|
|
|
|
printBytes(cid_req_message)
|
|
|
|
|
gen = thp_v1.read_message(self.interface, buffer)
|
|
|
|
|
query = gen.send(None)
|
|
|
|
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
|
|
|
@ -219,6 +218,32 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
# read should have allocated its own buffer and not touch ours
|
|
|
|
|
self.assertEqual(buffer, b"\x00")
|
|
|
|
|
|
|
|
|
|
def test_roundtrip(self):
|
|
|
|
|
message_payload = bytes(range(256))
|
|
|
|
|
message = MessageWithId(
|
|
|
|
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
|
|
|
|
)
|
|
|
|
|
gen = thp_v1.write_message(self.interface, message)
|
|
|
|
|
# exhaust the iterator:
|
|
|
|
|
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
|
|
|
|
for query in gen:
|
|
|
|
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
|
|
|
|
|
|
|
|
|
buffer = bytearray(1024)
|
|
|
|
|
gen = thp_v1.read_message(self.interface, buffer)
|
|
|
|
|
query = gen.send(None)
|
|
|
|
|
for packet in self.interface.data:
|
|
|
|
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
|
|
|
|
printBytes(packet)
|
|
|
|
|
query = gen.send(packet)
|
|
|
|
|
|
|
|
|
|
with self.assertRaises(StopIteration) as e:
|
|
|
|
|
gen.send(None)
|
|
|
|
|
|
|
|
|
|
result = e.value.value
|
|
|
|
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
|
|
|
|
self.assertEqual(result.data, message.data)
|
|
|
|
|
|
|
|
|
|
def test_write_one_packet(self):
|
|
|
|
|
message = MessageWithId(
|
|
|
|
|
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
|
|
|
@ -231,7 +256,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
gen.send(None)
|
|
|
|
|
|
|
|
|
|
header = make_header(
|
|
|
|
|
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
|
|
|
|
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
|
|
|
|
)
|
|
|
|
|
expected_message = (
|
|
|
|
|
header
|
|
|
|
@ -285,32 +310,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|
|
|
|
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
|
|
|
|
self.assertEqual(last_packet, self.interface.data[-1])
|
|
|
|
|
|
|
|
|
|
def test_roundtrip(self):
|
|
|
|
|
message_payload = bytes(range(256))
|
|
|
|
|
message = MessageWithId(
|
|
|
|
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
|
|
|
|
)
|
|
|
|
|
gen = thp_v1.write_message(self.interface, message)
|
|
|
|
|
|
|
|
|
|
# exhaust the iterator:
|
|
|
|
|
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
|
|
|
|
for query in gen:
|
|
|
|
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
|
|
|
|
|
|
|
|
|
buffer = bytearray(1024)
|
|
|
|
|
gen = thp_v1.read_message(self.interface, buffer)
|
|
|
|
|
query = gen.send(None)
|
|
|
|
|
for packet in self.interface.data:
|
|
|
|
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
|
|
|
|
query = gen.send(packet)
|
|
|
|
|
|
|
|
|
|
with self.assertRaises(StopIteration) as e:
|
|
|
|
|
gen.send(None)
|
|
|
|
|
|
|
|
|
|
result = e.value.value
|
|
|
|
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
|
|
|
|
self.assertEqual(result.data, message.data)
|
|
|
|
|
|
|
|
|
|
def test_read_huge_packet(self):
|
|
|
|
|
PACKET_COUNT = 1180
|
|
|
|
|
# message that takes up 1 180 USB packets
|
|
|
|
|