diff --git a/core/tests/test_trezor.wire.thp_v1.py b/core/tests/test_trezor.wire.thp_v1.py index a3120d43e..a2ffb2aec 100644 --- a/core/tests/test_trezor.wire.thp_v1.py +++ b/core/tests/test_trezor.wire.thp_v1.py @@ -67,7 +67,7 @@ def printBytes(a): def getPlaintext() -> bytes: if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1: return PLAINTEXT_1 - PLAINTEXT_0 + return PLAINTEXT_0 def getCid() -> int: @@ -97,7 +97,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): ) buffer = bytearray(64) - printBytes(cid_req_message) gen = thp_v1.read_message(self.interface, buffer) query = gen.send(None) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) @@ -219,6 +218,32 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): # read should have allocated its own buffer and not touch ours self.assertEqual(buffer, b"\x00") + def test_roundtrip(self): + message_payload = bytes(range(256)) + message = MessageWithId( + MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) + ) + gen = thp_v1.write_message(self.interface, message) + # exhaust the iterator: + # (XXX we can only do this because the iterator is only accepting None and returns None) + for query in gen: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + buffer = bytearray(1024) + gen = thp_v1.read_message(self.interface, buffer) + query = gen.send(None) + for packet in self.interface.data: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + printBytes(packet) + query = gen.send(packet) + + with self.assertRaises(StopIteration) as e: + gen.send(None) + + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data, message.data) + def test_write_one_packet(self): message = MessageWithId( MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID) @@ -231,7 +256,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): gen.send(None) header = make_header( - PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH + getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH ) expected_message = ( header @@ -285,32 +310,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase): last_packet = packets[-1] + packets[-2][len(packets[-1]) :] self.assertEqual(last_packet, self.interface.data[-1]) - def test_roundtrip(self): - message_payload = bytes(range(256)) - message = MessageWithId( - MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID) - ) - gen = thp_v1.write_message(self.interface, message) - - # exhaust the iterator: - # (XXX we can only do this because the iterator is only accepting None and returns None) - for query in gen: - self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) - - buffer = bytearray(1024) - gen = thp_v1.read_message(self.interface, buffer) - query = gen.send(None) - for packet in self.interface.data: - self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) - query = gen.send(packet) - - with self.assertRaises(StopIteration) as e: - gen.send(None) - - result = e.value.value - self.assertEqual(result.type, MESSAGE_TYPE) - self.assertEqual(result.data, message.data) - def test_read_huge_packet(self): PACKET_COUNT = 1180 # message that takes up 1 180 USB packets