You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/tests/test_trezor.wire.codec_v1.py

207 lines
7.2 KiB

from common import *
from ubinascii import unhexlify
import ustruct
from trezor import io
from trezor.loop import wait
from trezor.utils import chunks
from trezor.wire import codec_v1
class MockHID:
    """Fake HID wire interface that records every report written to it."""

    def __init__(self, num):
        # interface number, combined with a poll mode in wait_object()
        self.num = num
        # one bytearray copy per write() call, in call order
        self.data = []

    def iface_num(self):
        return self.num

    def write(self, msg):
        # store a copy: the caller may reuse and overwrite its buffer later
        snapshot = bytearray(msg)
        self.data.append(snapshot)
        return len(msg)

    def wait_object(self, mode):
        return wait(self.num | mode)
MESSAGE_TYPE = 0x4242
# Payload bytes left in the first report once the b"?##" marker and the
# ">HL" (message type, message length) header fields are accounted for.
HEADER_PAYLOAD_LENGTH = codec_v1._REP_LEN - len(b"?##") - ustruct.calcsize(">HL")
def make_header(mtype, length):
    """Return the initial-report header bytes.

    Layout: the b"?##" marker, then *mtype* as a big-endian unsigned 16-bit
    integer and *length* as a big-endian unsigned 32-bit integer (">HL").
    """
    fields = ustruct.pack(">HL", mtype, length)
    return b"?##" + fields
class TestWireCodecV1(unittest.TestCase):
    """Tests for codec_v1 framing, driven by stepping its coroutines by hand.

    read_message / write_message are advanced with ``gen.send(...)``. Every
    yielded value is compared with the wait object MockHID produces for the
    expected poll mode. The coroutine's return value is taken from the
    StopIteration it raises when it finishes.
    """

    def setUp(self):
        self.interface = MockHID(0xDEADBEEF)

    def _make_packets(self, message):
        """Return the wire packets for *message* sent with MESSAGE_TYPE.

        The first packet is the header followed by up to HEADER_PAYLOAD_LENGTH
        payload bytes. Every further packet is b"?" followed by up to
        codec_v1._REP_LEN - 1 payload bytes.
        """
        header = make_header(mtype=MESSAGE_TYPE, length=len(message))
        return [header + message[:HEADER_PAYLOAD_LENGTH]] + [
            b"?" + chunk
            for chunk in chunks(message[HEADER_PAYLOAD_LENGTH:], codec_v1._REP_LEN - 1)
        ]

    def test_read_one_packet(self):
        # zero length message - just a header
        message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
        buffer = bytearray(64)
        gen = codec_v1.read_message(self.interface, buffer)

        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))

        with self.assertRaises(StopIteration) as e:
            gen.send(message_packet)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, b"")

        # a zero-length message has no payload, so the buffer must be untouched
        self.assertEqual(buffer, b"\x00" * 64)

    def test_read_many_packets(self):
        message = bytes(range(256))
        packets = self._make_packets(message)

        buffer = bytearray(256)
        gen = codec_v1.read_message(self.interface, buffer)
        query = gen.send(None)
        for packet in packets[:-1]:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            query = gen.send(packet)

        # last packet will stop
        with self.assertRaises(StopIteration) as e:
            gen.send(packets[-1])

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, message)

        # the message is exactly as long as the supplied buffer, so it must
        # have been read into that buffer
        self.assertEqual(buffer, message)

    def test_read_large_message(self):
        message = b"hello world"
        header = make_header(mtype=MESSAGE_TYPE, length=len(message))
        packet = header + message
        # make sure we fit into one packet, to make this easier
        self.assertTrue(len(packet) <= codec_v1._REP_LEN)

        # the supplied buffer must be smaller than the message itself,
        # otherwise the codec could legitimately read into it
        buffer = bytearray(1)
        self.assertTrue(len(buffer) < len(message))

        gen = codec_v1.read_message(self.interface, buffer)
        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
        with self.assertRaises(StopIteration) as e:
            gen.send(packet)

        # e.value is StopIteration. e.value.value is the return value of the call
        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, message)

        # read should have allocated its own buffer and not touch ours
        self.assertEqual(buffer, b"\x00")

    def test_write_one_packet(self):
        gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, b"")

        query = gen.send(None)
        self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
        with self.assertRaises(StopIteration):
            gen.send(None)

        # exactly one _REP_LEN-sized report: the header padded with zeros
        header = make_header(mtype=MESSAGE_TYPE, length=0)
        expected_message = header + b"\x00" * HEADER_PAYLOAD_LENGTH
        self.assertEqual(self.interface.data, [expected_message])

    def test_write_multiple_packets(self):
        message = bytes(range(256))
        gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
        packets = self._make_packets(message)

        for _ in packets:
            # we receive as many queries as there are packets
            query = gen.send(None)
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))

        # The first None only started the generator, and each write is
        # preceded by a yielded wait query. After len(packets) sends the last
        # packet is therefore still pending. The (len(packets) + 1)-th None
        # writes it and raises StopIteration.
        with self.assertRaises(StopIteration):
            gen.send(None)

        # packets must be identical up to the last one
        self.assertListEqual(packets[:-1], self.interface.data[:-1])

        # last packet must be identical up to message length. remaining bytes in
        # the 64-byte packets are garbage -- in particular, it's the bytes of the
        # previous packet
        last_packet = packets[-1] + packets[-2][len(packets[-1]) :]
        self.assertEqual(last_packet, self.interface.data[-1])

    def test_roundtrip(self):
        message = bytes(range(256))
        gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)

        # exhaust the iterator: plain iteration sends None at every step, which
        # is all write_message expects from its driver here
        for query in gen:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))

        buffer = bytearray(1024)
        gen = codec_v1.read_message(self.interface, buffer)
        query = gen.send(None)
        for packet in self.interface.data[:-1]:
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            query = gen.send(packet)

        with self.assertRaises(StopIteration) as e:
            gen.send(self.interface.data[-1])

        result = e.value.value
        self.assertEqual(result.type, MESSAGE_TYPE)
        self.assertEqual(result.data, message)

    def test_read_huge_packet(self):
        packet_count = 100_000
        # a message spanning packet_count USB packets: the header packet holds
        # HEADER_PAYLOAD_LENGTH bytes, every continuation _REP_LEN - 1 bytes
        message_size = (packet_count - 1) * (
            codec_v1._REP_LEN - 1
        ) + HEADER_PAYLOAD_LENGTH
        # ensure that a message this big won't fit into memory
        self.assertRaises(MemoryError, bytearray, message_size)

        header = make_header(mtype=MESSAGE_TYPE, length=message_size)
        # the same header packet is reused for every continuation; like real
        # continuation packets it begins with the b"?" marker
        packet = header + b"\x00" * HEADER_PAYLOAD_LENGTH
        buffer = bytearray(65536)
        gen = codec_v1.read_message(self.interface, buffer)

        query = gen.send(None)
        for _ in range(packet_count - 1):
            self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
            query = gen.send(packet)

        # the codec only reports the oversize message after consuming it all
        with self.assertRaises(codec_v1.CodecError) as e:
            gen.send(packet)

        self.assertEqual(e.value.args[0], "Message too large")
if __name__ == "__main__":  # run the test cases only when executed directly
    unittest.main()