From e928fdbe2261e987d04d32256a1d4b8a31f72526 Mon Sep 17 00:00:00 2001 From: M1nd3r Date: Wed, 31 Jul 2024 11:51:41 +0200 Subject: [PATCH] test(core): improve thp tests [no changelog] --- core/tests/test_trezor.wire.thp.crypto.py | 6 +++ core/tests/test_trezor.wire.thp.writer.py | 60 ++++++++++------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/core/tests/test_trezor.wire.thp.crypto.py b/core/tests/test_trezor.wire.thp.crypto.py index d6a6281a0..faff51424 100644 --- a/core/tests/test_trezor.wire.thp.crypto.py +++ b/core/tests/test_trezor.wire.thp.crypto.py @@ -3,6 +3,12 @@ from trezorcrypto import aesgcm, curve25519 import storage +if __debug__: + # Disable log.debug for the test + from trezor import log + + log.debug = lambda name, msg, *args: None + if utils.USE_THP: from trezor.wire.thp import crypto from trezor.wire.thp.crypto import IV_1, IV_2, Handshake diff --git a/core/tests/test_trezor.wire.thp.writer.py b/core/tests/test_trezor.wire.thp.writer.py index ddbefbd4d..87299c514 100644 --- a/core/tests/test_trezor.wire.thp.writer.py +++ b/core/tests/test_trezor.wire.thp.writer.py @@ -1,5 +1,7 @@ from common import * # isort:skip +from typing import Any, Awaitable + if utils.USE_THP: from trezor.wire.thp import writer from trezor.wire.thp.thp_messages import ENCRYPTED_TRANSPORT, PacketHeader @@ -26,6 +28,7 @@ if __debug__: log.debug = lambda name, msg, *args: None + @unittest.skipUnless(utils.USE_THP, "only needed for THP") class TestTrezorHostProtocolWriter(unittest.TestCase): short_payload_expected = b"04123400050700000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000" @@ -80,46 +83,41 @@ class TestTrezorHostProtocolWriter(unittest.TestCase): b"801234b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1", b"801234f2f3f4f5f6f7f8f9fafbfcfdfefff40c65ee00000000000000000000000000000000000000000000000000000000000000000000000000000000000000", ] + def await_until_result(self, task: Awaitable) -> Any: + with self.assertRaises(StopIteration): + while True: + task.send(None) def setUp(self): self.interface = MockHID(0xDEADBEEF) def test_write_empty_packet(self): - gen = writer.write_packet_to_wire(self.interface, b"") - with self.assertRaises(StopIteration): - gen.send(None) - gen.send(None) + self.await_until_result(writer.write_packet_to_wire(self.interface, b"")) + print(self.interface.data[0]) self.assertEqual(len(self.interface.data), 1) self.assertEqual(self.interface.data[0], b"") def test_write_empty_payload(self): header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4) - gen = writer.write_payloads_to_wire(self.interface, header, (b"",)) - - with self.assertRaises(StopIteration): - gen.send(None) + await_result(writer.write_payloads_to_wire(self.interface, header, (b"",))) self.assertEqual(len(self.interface.data), 0) def test_write_short_payload(self): header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 5) data = b"\x07" - gen = writer.write_payloads_to_wire(self.interface, header, (data,)) - - gen.send(None) - with self.assertRaises(StopIteration): - gen.send(None) + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) self.assertEqual(hexlify(self.interface.data[0]), self.short_payload_expected) def test_write_longer_payload(self): data = bytearray(range(256)) header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256) - gen = writer.write_payloads_to_wire(self.interface, header, (data,)) + self.await_until_result( + writer.write_payloads_to_wire(self.interface, header, (data,)) + ) - for i in range(5): - gen.send(None) - with self.assertRaises(StopIteration): - gen.send(None) for i in range(len(self.longer_payload_expected)): self.assertEqual( hexlify(self.interface.data[i]), self.longer_payload_expected[i] @@ -128,14 +126,11 @@ class TestTrezorHostProtocolWriter(unittest.TestCase): def test_write_eight_longer_payloads(self): data = bytearray(range(256)) header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 2048) - gen = writer.write_payloads_to_wire( - self.interface, header, (data, data, data, data, data, data, data, data) + self.await_until_result( + writer.write_payloads_to_wire( + self.interface, header, (data, data, data, data, data, data, data, data) + ) ) - - for i in range(34): - gen.send(None) - with self.assertRaises(StopIteration): - gen.send(None) for i in range(len(self.eight_longer_payloads_expected)): self.assertEqual( hexlify(self.interface.data[i]), self.eight_longer_payloads_expected[i] @@ -143,11 +138,10 @@ class TestTrezorHostProtocolWriter(unittest.TestCase): def test_write_empty_payload_with_checksum(self): header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 4) - gen = writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"") + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, b"") + ) - gen.send(None) - with self.assertRaises(StopIteration): - gen.send(None) self.assertEqual( hexlify(self.interface.data[0]), self.empty_payload_with_checksum_expected ) @@ -155,14 +149,10 @@ class TestTrezorHostProtocolWriter(unittest.TestCase): def test_write_longer_payload_with_checksum(self): data = bytearray(range(256)) header = PacketHeader(ENCRYPTED_TRANSPORT, 4660, 256) - gen = writer.write_payload_to_wire_and_add_checksum( - self.interface, header, data + self.await_until_result( + writer.write_payload_to_wire_and_add_checksum(self.interface, header, data) ) - for i in range(5): - gen.send(None) - with self.assertRaises(StopIteration): - gen.send(None) for i in range(len(self.longer_payload_with_checksum_expected)): self.assertEqual( hexlify(self.interface.data[i]),