Fix thp tests

M1nd3r/thp5
M1nd3r 2 months ago
parent 6b9a2c64f1
commit 3e781b24ab

@ -67,7 +67,7 @@ def printBytes(a):
def getPlaintext() -> bytes:
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
PLAINTEXT_0
return PLAINTEXT_0
def getCid() -> int:
@ -97,7 +97,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
)
buffer = bytearray(64)
printBytes(cid_req_message)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -219,6 +218,32 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_roundtrip(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
printBytes(packet)
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_write_one_packet(self):
message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
@ -231,7 +256,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
gen.send(None)
header = make_header(
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
@ -285,32 +310,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_roundtrip(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets

Loading…
Cancel
Save