mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-02-24 05:12:02 +00:00
Fix thp tests
This commit is contained in:
parent
e4a4f8f125
commit
d8079bfd24
@ -67,7 +67,7 @@ def printBytes(a):
|
|||||||
def getPlaintext() -> bytes:
|
def getPlaintext() -> bytes:
|
||||||
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
if THP.sync_get_receive_expected_bit(THP.get_active_session()) == 1:
|
||||||
return PLAINTEXT_1
|
return PLAINTEXT_1
|
||||||
PLAINTEXT_0
|
return PLAINTEXT_0
|
||||||
|
|
||||||
|
|
||||||
def getCid() -> int:
|
def getCid() -> int:
|
||||||
@ -97,7 +97,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
buffer = bytearray(64)
|
buffer = bytearray(64)
|
||||||
printBytes(cid_req_message)
|
|
||||||
gen = thp_v1.read_message(self.interface, buffer)
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
query = gen.send(None)
|
query = gen.send(None)
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
@ -219,6 +218,32 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
# read should have allocated its own buffer and not touch ours
|
# read should have allocated its own buffer and not touch ours
|
||||||
self.assertEqual(buffer, b"\x00")
|
self.assertEqual(buffer, b"\x00")
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
message_payload = bytes(range(256))
|
||||||
|
message = MessageWithId(
|
||||||
|
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
||||||
|
)
|
||||||
|
gen = thp_v1.write_message(self.interface, message)
|
||||||
|
# exhaust the iterator:
|
||||||
|
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||||
|
for query in gen:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
|
||||||
|
buffer = bytearray(1024)
|
||||||
|
gen = thp_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
for packet in self.interface.data:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
printBytes(packet)
|
||||||
|
query = gen.send(packet)
|
||||||
|
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(None)
|
||||||
|
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data, message.data)
|
||||||
|
|
||||||
def test_write_one_packet(self):
|
def test_write_one_packet(self):
|
||||||
message = MessageWithId(
|
message = MessageWithId(
|
||||||
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
MESSAGE_TYPE, b"", THP._get_id(self.interface, COMMON_CID)
|
||||||
@ -231,7 +256,7 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
gen.send(None)
|
gen.send(None)
|
||||||
|
|
||||||
header = make_header(
|
header = make_header(
|
||||||
PLAINTEXT_0, COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
getPlaintext(), COMMON_CID, _MESSAGE_TYPE_LEN + CHECKSUM_LENGTH
|
||||||
)
|
)
|
||||||
expected_message = (
|
expected_message = (
|
||||||
header
|
header
|
||||||
@ -285,32 +310,6 @@ class TestWireTrezorHostProtocolV1(unittest.TestCase):
|
|||||||
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
|
||||||
self.assertEqual(last_packet, self.interface.data[-1])
|
self.assertEqual(last_packet, self.interface.data[-1])
|
||||||
|
|
||||||
def test_roundtrip(self):
|
|
||||||
message_payload = bytes(range(256))
|
|
||||||
message = MessageWithId(
|
|
||||||
MESSAGE_TYPE, message_payload, THP._get_id(self.interface, COMMON_CID)
|
|
||||||
)
|
|
||||||
gen = thp_v1.write_message(self.interface, message)
|
|
||||||
|
|
||||||
# exhaust the iterator:
|
|
||||||
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
|
||||||
for query in gen:
|
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
|
||||||
|
|
||||||
buffer = bytearray(1024)
|
|
||||||
gen = thp_v1.read_message(self.interface, buffer)
|
|
||||||
query = gen.send(None)
|
|
||||||
for packet in self.interface.data:
|
|
||||||
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
|
||||||
query = gen.send(packet)
|
|
||||||
|
|
||||||
with self.assertRaises(StopIteration) as e:
|
|
||||||
gen.send(None)
|
|
||||||
|
|
||||||
result = e.value.value
|
|
||||||
self.assertEqual(result.type, MESSAGE_TYPE)
|
|
||||||
self.assertEqual(result.data, message.data)
|
|
||||||
|
|
||||||
def test_read_huge_packet(self):
|
def test_read_huge_packet(self):
|
||||||
PACKET_COUNT = 1180
|
PACKET_COUNT = 1180
|
||||||
# message that takes up 1 180 USB packets
|
# message that takes up 1 180 USB packets
|
||||||
|
Loading…
Reference in New Issue
Block a user