1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-24 22:31:35 +00:00

core/tests: add synchronous protobuf tests

This commit is contained in:
matejcik 2020-07-13 12:59:10 +02:00 committed by Tomas Susanka
parent 31e2170766
commit 0a758b8181
3 changed files with 165 additions and 176 deletions

View File

@ -1,9 +1,7 @@
from common import * from common import *
import protobuf import protobuf
from trezor.utils import BufferIO
if False:
from typing import Awaitable, Dict
class Message(protobuf.MessageType): class Message(protobuf.MessageType):
@ -12,47 +10,22 @@ class Message(protobuf.MessageType):
self.enum_field = enum_field self.enum_field = enum_field
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls):
return { return {
1: ("sint_field", protobuf.SVarintType, 0), 1: ("sint_field", protobuf.SVarintType, 0),
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0), 2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
} }
class ByteReader:
def __init__(self, data: bytes) -> None:
self.data = data
self.pos = 0
async def areadinto(self, buf: bytearray) -> int:
remaining = len(self.data) - self.pos
limit = len(buf)
if remaining < limit:
raise EOFError
buf[:] = self.data[self.pos : self.pos + limit]
self.pos += limit
return limit
class ByteArrayWriter:
def __init__(self) -> None:
self.buf = bytearray(0)
async def awrite(self, buf: bytes) -> int:
self.buf.extend(buf)
return len(buf)
def load_uvarint(data: bytes) -> int: def load_uvarint(data: bytes) -> int:
reader = ByteReader(data) reader = BufferIO(data)
return await_result(protobuf.load_uvarint(reader)) return protobuf.load_uvarint(reader)
def dump_uvarint(value: int) -> bytearray: def dump_uvarint(value: int) -> bytearray:
writer = ByteArrayWriter() writer = BufferIO(bytearray(16))
await_result(protobuf.dump_uvarint(writer, value)) protobuf.dump_uvarint(writer, value)
return writer.buf return memoryview(writer.buffer)[:writer.offset]
class TestProtobuf(unittest.TestCase): class TestProtobuf(unittest.TestCase):
@ -91,21 +64,23 @@ class TestProtobuf(unittest.TestCase):
def test_validate_enum(self): def test_validate_enum(self):
# ok message: # ok message:
msg = Message(-42, 5) msg = Message(-42, 5)
writer = ByteArrayWriter() length = protobuf.count_message(msg)
await_result(protobuf.dump_message(writer, msg)) buffer_io = BufferIO(bytearray(length))
reader = ByteReader(bytes(writer.buf)) protobuf.dump_message(buffer_io, msg)
nmsg = await_result(protobuf.load_message(reader, Message)) buffer_io.seek(0)
nmsg = protobuf.load_message(buffer_io, Message)
self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.sint_field, nmsg.sint_field)
self.assertEqual(msg.enum_field, nmsg.enum_field) self.assertEqual(msg.enum_field, nmsg.enum_field)
# bad enum value: # bad enum value:
buffer_io.seek(0)
msg = Message(-42, 42) msg = Message(-42, 42)
writer = ByteArrayWriter() # XXX this assumes the message will have equal size
await_result(protobuf.dump_message(writer, msg)) protobuf.dump_message(buffer_io, msg)
reader = ByteReader(bytes(writer.buf)) buffer_io.seek(0)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
await_result(protobuf.load_message(reader, Message)) protobuf.load_message(buffer_io, Message)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,14 +1,14 @@
from common import * from common import *
from ubinascii import unhexlify from ubinascii import unhexlify
import ustruct
from trezor import io from trezor import io
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks, BufferIO
from trezor.wire import codec_v1 from trezor.wire import codec_v1
class MockHID: class MockHID:
def __init__(self, num): def __init__(self, num):
self.num = num self.num = num
self.data = [] self.data = []
@ -20,150 +20,162 @@ class MockHID:
self.data.append(bytearray(msg)) self.data.append(bytearray(msg))
return len(msg) return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242
def make_header(mtype, length):
return b"?##" + ustruct.pack(">HL", mtype, length)
class TestWireCodecV1(unittest.TestCase): class TestWireCodecV1(unittest.TestCase):
def setUp(self):
self.interface = MockHID(0xDEADBEEF)
def test_reader(self): def test_read_one_packet(self):
rep_len = 64 # zero length message - just a header
interface_num = 0xdeadbeef message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
message_type = 0x4321 buffer = bytearray(64)
message_len = 250
interface = MockHID(interface_num)
reader = codec_v1.Reader(interface)
message = bytearray(range(message_len)) gen = codec_v1.read_message(self.interface, buffer)
report_header = bytearray(unhexlify('3f23234321000000fa'))
# open, expected one read query = gen.send(None)
first_report = report_header + message[:rep_len - len(report_header)] self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
self.assertAsync(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ])
self.assertEqual(reader.type, message_type)
self.assertEqual(reader.size, message_len)
# empty read with self.assertRaises(StopIteration) as e:
empty_buffer = bytearray() gen.send(message_packet)
self.assertAsync(reader.areadinto(empty_buffer), [(None, StopIteration()), ])
self.assertEqual(len(empty_buffer), 0)
self.assertEqual(reader.size, message_len)
# short read, expected no read # e.value is StopIteration. e.value.value is the return value of the call
short_buffer = bytearray(32) result = e.value.value
self.assertAsync(reader.areadinto(short_buffer), [(None, StopIteration()), ]) self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(len(short_buffer), 32) self.assertEqual(result.data.buffer, b"")
self.assertEqual(short_buffer, message[:len(short_buffer)])
self.assertEqual(reader.size, message_len - len(short_buffer))
# aligned read, expected no read # message should have been read into the buffer
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) self.assertEqual(buffer, b"\x00" * 64)
self.assertAsync(reader.areadinto(aligned_buffer), [(None, StopIteration()), ])
self.assertEqual(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
# one byte read, expected one read def test_read_many_packets(self):
next_report_header = bytearray(unhexlify('3f')) message = bytes(range(256))
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
onebyte_buffer = bytearray(1)
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ])
self.assertEqual(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
# too long read, raises eof header = make_header(mtype=MESSAGE_TYPE, length=len(message))
self.assertAsync(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ]) first_len = codec_v1._REP_LEN - len(header)
# first packet is header + (remaining)data
# other packets are "?" + 63 bytes of data
packets = [header + message[:first_len]] + [
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
]
# long read, expect multiple reads buffer = bytearray(256)
start_size = reader.size gen = codec_v1.read_message(self.interface, buffer)
long_buffer = bytearray(start_size) query = gen.send(None)
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):] for packet in packets[:-1]:
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
report_payload_rest = report_payload[len(report_payload_head):] query = gen.send(packet)
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
report_payloads = [report_payload_head] + report_payload_rest
next_reports = [next_report_header + r for r in report_payloads]
expected_syscalls = []
for i, _ in enumerate(next_reports):
prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num)))
expected_syscalls.append((next_reports[-1], StopIteration()))
self.assertAsync(reader.areadinto(long_buffer), expected_syscalls)
self.assertEqual(long_buffer, message[-start_size:])
self.assertEqual(reader.size, 0)
# one byte read, raises eof # last packet will stop
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, EOFError()), ]) with self.assertRaises(StopIteration) as e:
gen.send(packets[-1])
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message)
# message should have been read into the buffer
self.assertEqual(buffer, message)
def test_read_large_message(self):
message = b"hello world"
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
packet = header + message
# make sure we fit into one packet, to make this easier
self.assertTrue(len(packet) <= codec_v1._REP_LEN)
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
gen = codec_v1.read_message(self.interface, buffer)
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
with self.assertRaises(StopIteration) as e:
gen.send(packet)
# e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
self.assertEqual(result.type, MESSAGE_TYPE)
self.assertEqual(result.data.buffer, message)
# read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
def test_write_one_packet(self):
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, b"")
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
with self.assertRaises(StopIteration):
gen.send(None)
header = make_header(mtype=MESSAGE_TYPE, length=0)
expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
self.assertTrue(self.interface.data == [expected_message])
def test_write_multiple_packets(self):
message = bytes(range(256))
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
first_len = codec_v1._REP_LEN - len(header)
# first packet is header + (remaining)data
# other packets are "?" + 63 bytes of data
packets = [header + message[:first_len]] + [
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
]
for _ in packets:
# we receive as many queries as there are packets
query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# the first sent None only started the generator. the len(packets)-th None
# will finish writing and raise StopIteration
with self.assertRaises(StopIteration):
gen.send(None)
# packets must be identical up to the last one
self.assertListEqual(packets[:-1], self.interface.data[:-1])
# last packet must be identical up to message length. remaining bytes in
# the 64-byte packets are garbage -- in particular, it's the bytes of the
# previous packet
last_packet = packets[-1] + packets[-2][len(packets[-1]):]
self.assertEqual(last_packet, self.interface.data[-1])
def test_roundtrip(self):
message = bytes(range(256))
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
# exhaust the iterator:
# (XXX we can only do this because the iterator is only accepting None and returns None)
for query in gen:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
buffer = bytearray(1024)
gen = codec_v1.read_message(self.interface, buffer)
query = gen.send(None)
for packet in self.interface.data[:-1]:
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
with self.assertRaises(StopIteration) as e:
gen.send(self.interface.data[-1])
result = e.value.value
assert result.type == MESSAGE_TYPE
assert result.data.buffer == message
def test_writer(self): if __name__ == "__main__":
rep_len = 64
interface_num = 0xdeadbeef
message_type = 0x87654321
message_len = 1024
interface = MockHID(interface_num)
writer = codec_v1.Writer(interface)
writer.setheader(message_type, message_len)
# init header corresponding to the data above
report_header = bytearray(unhexlify('3f2323432100000400'))
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
# empty write
start_size = writer.size
self.assertAsync(writer.awrite(bytearray()), [(None, StopIteration()), ])
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
self.assertEqual(writer.size, start_size)
# short write, expected no report
start_size = writer.size
short_payload = bytearray(range(4))
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ])
self.assertEqual(writer.size, start_size - len(short_payload))
self.assertEqual(writer.data,
report_header +
short_payload +
bytearray(rep_len - len(report_header) - len(short_payload)))
# aligned write, expected one report
start_size = writer.size
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
self.assertAsync(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ])
self.assertEqual(interface.data, [report_header +
short_payload +
aligned_payload +
bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ])
self.assertEqual(writer.size, start_size - len(aligned_payload))
interface.data.clear()
# short write, expected no report, but data starts with correct seq and cont marker
report_header = bytearray(unhexlify('3f'))
start_size = writer.size
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ])
self.assertEqual(writer.size, start_size - len(short_payload))
self.assertEqual(writer.data[:len(report_header) + len(short_payload)],
report_header + short_payload)
# long write, expected multiple reports
start_size = writer.size
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
long_payload = long_payload_head + long_payload_rest
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
expected_reports = [report_header + r for r in expected_payloads]
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
# test write
expected_write_reports = expected_reports[:-1]
self.assertAsync(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
self.assertEqual(interface.data, expected_write_reports)
self.assertEqual(writer.size, start_size - len(long_payload))
interface.data.clear()
# test write raises eof
self.assertAsync(writer.awrite(bytearray(1)), [(None, EOFError())])
self.assertEqual(interface.data, [])
# test close
expected_close_reports = expected_reports[-1:]
self.assertAsync(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
self.assertEqual(interface.data, expected_close_reports)
self.assertEqual(writer.size, 0)
if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -9,6 +9,7 @@ class AssertRaisesContext:
def __init__(self, exc): def __init__(self, exc):
self.expected = exc self.expected = exc
self.value = None
def __enter__(self): def __enter__(self):
return self return self
@ -17,6 +18,7 @@ class AssertRaisesContext:
if exc_type is None: if exc_type is None:
ensure(False, "%r not raised" % self.expected) ensure(False, "%r not raised" % self.expected)
if issubclass(exc_type, self.expected): if issubclass(exc_type, self.expected):
self.value = exc_value
return True return True
return False return False