1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-14 03:30:02 +00:00

fix(core): fix bugs in thp unittests

This commit is contained in:
M1nd3r 2024-03-19 14:29:30 +01:00
parent 947cd8fa1d
commit 45b0293371

View File

@ -67,7 +67,7 @@ def printBytes(a):
def getPlaintext() -> bytes:
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
return PLAINTEXT_1
PLAINTEXT_0
return PLAINTEXT_0
def getCid() -> int:
@ -97,7 +97,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
)
buffer = bytearray(64)
printBytes(cid_req_message)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
@ -219,6 +218,32 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_roundtrip(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
printBytes(packet)
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_write_one_packet(self):
message = MessageWithId(
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
@ -231,7 +256,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
gen.send(None)
header = make_header(
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
)
expected_message = (
header
@ -285,32 +310,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
self.assertEqual(last_packet, self.interface.data[-1])
def test_roundtrip(self):
message_payload = bytes(range(256))
message = MessageWithId(
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
)
gen = thp_v1.write_message(self.interface, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = thp_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(None)
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data, message.data)
def test_read_huge_packet(self):
PACKET_COUNT = 1180
# message that takes up 1 180 USB packets