diff --git a/core/tests/test_protobuf.py b/core/tests/test_protobuf.py index c6de12331..524b28069 100644 --- a/core/tests/test_protobuf.py +++ b/core/tests/test_protobuf.py @@ -1,9 +1,7 @@ from common import * import protobuf - -if False: - from typing import Awaitable, Dict +from trezor.utils import BufferIO class Message(protobuf.MessageType): @@ -12,47 +10,22 @@ class Message(protobuf.MessageType): self.enum_field = enum_field @classmethod - def get_fields(cls) -> Dict: + def get_fields(cls): return { 1: ("sint_field", protobuf.SVarintType, 0), 2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0), } -class ByteReader: - def __init__(self, data: bytes) -> None: - self.data = data - self.pos = 0 - - async def areadinto(self, buf: bytearray) -> int: - remaining = len(self.data) - self.pos - limit = len(buf) - if remaining < limit: - raise EOFError - - buf[:] = self.data[self.pos : self.pos + limit] - self.pos += limit - return limit - - -class ByteArrayWriter: - def __init__(self) -> None: - self.buf = bytearray(0) - - async def awrite(self, buf: bytes) -> int: - self.buf.extend(buf) - return len(buf) - - def load_uvarint(data: bytes) -> int: - reader = ByteReader(data) - return await_result(protobuf.load_uvarint(reader)) + reader = BufferIO(data) + return protobuf.load_uvarint(reader) def dump_uvarint(value: int) -> bytearray: - writer = ByteArrayWriter() - await_result(protobuf.dump_uvarint(writer, value)) - return writer.buf + writer = BufferIO(bytearray(16)) + protobuf.dump_uvarint(writer, value) + return memoryview(writer.buffer)[:writer.offset] class TestProtobuf(unittest.TestCase): @@ -91,21 +64,23 @@ class TestProtobuf(unittest.TestCase): def test_validate_enum(self): # ok message: msg = Message(-42, 5) - writer = ByteArrayWriter() - await_result(protobuf.dump_message(writer, msg)) - reader = ByteReader(bytes(writer.buf)) - nmsg = await_result(protobuf.load_message(reader, Message)) + length = protobuf.count_message(msg) + buffer_io = BufferIO(bytearray(length)) + protobuf.dump_message(buffer_io, msg) + buffer_io.seek(0) + nmsg = protobuf.load_message(buffer_io, Message) self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.enum_field, nmsg.enum_field) # bad enum value: + buffer_io.seek(0) msg = Message(-42, 42) - writer = ByteArrayWriter() - await_result(protobuf.dump_message(writer, msg)) - reader = ByteReader(bytes(writer.buf)) + # XXX this assumes the message will have equal size + protobuf.dump_message(buffer_io, msg) + buffer_io.seek(0) with self.assertRaises(TypeError): - await_result(protobuf.load_message(reader, Message)) + protobuf.load_message(buffer_io, Message) if __name__ == "__main__": diff --git a/core/tests/test_trezor.wire.codec_v1.py b/core/tests/test_trezor.wire.codec_v1.py index 931c0d85e..0785c2f92 100644 --- a/core/tests/test_trezor.wire.codec_v1.py +++ b/core/tests/test_trezor.wire.codec_v1.py @@ -1,14 +1,14 @@ from common import * from ubinascii import unhexlify +import ustruct from trezor import io from trezor.loop import wait -from trezor.utils import chunks +from trezor.utils import chunks, BufferIO from trezor.wire import codec_v1 class MockHID: - def __init__(self, num): self.num = num self.data = [] @@ -20,150 +20,162 @@ class MockHID: self.data.append(bytearray(msg)) return len(msg) + def wait_object(self, mode): + return wait(mode | self.num) + + +MESSAGE_TYPE = 0x4242 -class TestWireCodecV1(unittest.TestCase): - def test_reader(self): - rep_len = 64 - interface_num = 0xdeadbeef - message_type = 0x4321 - message_len = 250 - interface = MockHID(interface_num) - reader = codec_v1.Reader(interface) - - message = bytearray(range(message_len)) - report_header = bytearray(unhexlify('3f23234321000000fa')) - - # open, expected one read - first_report = report_header + message[:rep_len - len(report_header)] - self.assertAsync(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ]) - self.assertEqual(reader.type, message_type) - self.assertEqual(reader.size, message_len) - - # empty read - empty_buffer = bytearray() - self.assertAsync(reader.areadinto(empty_buffer), [(None, StopIteration()), ]) - self.assertEqual(len(empty_buffer), 0) - self.assertEqual(reader.size, message_len) - - # short read, expected no read - short_buffer = bytearray(32) - self.assertAsync(reader.areadinto(short_buffer), [(None, StopIteration()), ]) - self.assertEqual(len(short_buffer), 32) - self.assertEqual(short_buffer, message[:len(short_buffer)]) - self.assertEqual(reader.size, message_len - len(short_buffer)) - - # aligned read, expected no read - aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer)) - self.assertAsync(reader.areadinto(aligned_buffer), [(None, StopIteration()), ]) - self.assertEqual(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) - self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) - - # one byte read, expected one read - next_report_header = bytearray(unhexlify('3f')) - next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] - onebyte_buffer = bytearray(1) - self.assertAsync(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ]) - self.assertEqual(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) - self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) - - # too long read, raises eof - self.assertAsync(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ]) - - # long read, expect multiple reads - start_size = reader.size - long_buffer = bytearray(start_size) - report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):] - report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] - report_payload_rest = report_payload[len(report_payload_head):] - report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header))) - report_payloads = [report_payload_head] + report_payload_rest - next_reports = [next_report_header + r for r in report_payloads] - expected_syscalls = [] - for i, _ in enumerate(next_reports): - prev_report = next_reports[i - 1] if i > 0 else None - expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num))) - expected_syscalls.append((next_reports[-1], StopIteration())) - self.assertAsync(reader.areadinto(long_buffer), expected_syscalls) - self.assertEqual(long_buffer, message[-start_size:]) - self.assertEqual(reader.size, 0) - - # one byte read, raises eof - self.assertAsync(reader.areadinto(onebyte_buffer), [(None, EOFError()), ]) - - - def test_writer(self): - rep_len = 64 - interface_num = 0xdeadbeef - message_type = 0x87654321 - message_len = 1024 - interface = MockHID(interface_num) - writer = codec_v1.Writer(interface) - writer.setheader(message_type, message_len) - - # init header corresponding to the data above - report_header = bytearray(unhexlify('3f2323432100000400')) - - self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header))) - - # empty write - start_size = writer.size - self.assertAsync(writer.awrite(bytearray()), [(None, StopIteration()), ]) - self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header))) - self.assertEqual(writer.size, start_size) - - # short write, expected no report - start_size = writer.size - short_payload = bytearray(range(4)) - self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ]) - self.assertEqual(writer.size, start_size - len(short_payload)) - self.assertEqual(writer.data, - report_header + - short_payload + - bytearray(rep_len - len(report_header) - len(short_payload))) - - # aligned write, expected one report - start_size = writer.size - aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) - self.assertAsync(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ]) - self.assertEqual(interface.data, [report_header + - short_payload + - aligned_payload + - bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ]) - self.assertEqual(writer.size, start_size - len(aligned_payload)) - interface.data.clear() - - # short write, expected no report, but data starts with correct seq and cont marker - report_header = bytearray(unhexlify('3f')) - start_size = writer.size - self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ]) - self.assertEqual(writer.size, start_size - len(short_payload)) - self.assertEqual(writer.data[:len(report_header) + len(short_payload)], - report_header + short_payload) - - # long write, expected multiple reports - start_size = writer.size - long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload))) - long_payload_rest = bytearray(range(start_size - len(long_payload_head))) - long_payload = long_payload_head + long_payload_rest - expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header))) - expected_reports = [report_header + r for r in expected_payloads] - expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1]))) - # test write - expected_write_reports = expected_reports[:-1] - self.assertAsync(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) - self.assertEqual(interface.data, expected_write_reports) - self.assertEqual(writer.size, start_size - len(long_payload)) - interface.data.clear() - # test write raises eof - self.assertAsync(writer.awrite(bytearray(1)), [(None, EOFError())]) - self.assertEqual(interface.data, []) - # test close - expected_close_reports = expected_reports[-1:] - self.assertAsync(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) - self.assertEqual(interface.data, expected_close_reports) - self.assertEqual(writer.size, 0) - - -if __name__ == '__main__': +def make_header(mtype, length): + return b"?##" + ustruct.pack(">HL", mtype, length) + + +class TestWireCodecV1(unittest.TestCase): + def setUp(self): + self.interface = MockHID(0xDEADBEEF) + + def test_read_one_packet(self): + # zero length message - just a header + message_packet = make_header(mtype=MESSAGE_TYPE, length=0) + buffer = bytearray(64) + + gen = codec_v1.read_message(self.interface, buffer) + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + + with self.assertRaises(StopIteration) as e: + gen.send(message_packet) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data.buffer, b"") + + # message should have been read into the buffer + self.assertEqual(buffer, b"\x00" * 64) + + def test_read_many_packets(self): + message = bytes(range(256)) + + header = make_header(mtype=MESSAGE_TYPE, length=len(message)) + first_len = codec_v1._REP_LEN - len(header) + # first packet is header + (remaining)data + # other packets are "?" + 63 bytes of data + packets = [header + message[:first_len]] + [ + b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1) + ] + + buffer = bytearray(256) + gen = codec_v1.read_message(self.interface, buffer) + query = gen.send(None) + for packet in packets[:-1]: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + # last packet will stop + with self.assertRaises(StopIteration) as e: + gen.send(packets[-1]) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data.buffer, message) + + # message should have been read into the buffer + self.assertEqual(buffer, message) + + def test_read_large_message(self): + message = b"hello world" + header = make_header(mtype=MESSAGE_TYPE, length=len(message)) + + packet = header + message + # make sure we fit into one packet, to make this easier + self.assertTrue(len(packet) <= codec_v1._REP_LEN) + + buffer = bytearray(1) + self.assertTrue(len(buffer) <= len(packet)) + + gen = codec_v1.read_message(self.interface, buffer) + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + with self.assertRaises(StopIteration) as e: + gen.send(packet) + + # e.value is StopIteration. e.value.value is the return value of the call + result = e.value.value + self.assertEqual(result.type, MESSAGE_TYPE) + self.assertEqual(result.data.buffer, message) + + # read should have allocated its own buffer and not touch ours + self.assertEqual(buffer, b"\x00") + + def test_write_one_packet(self): + gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, b"") + + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + with self.assertRaises(StopIteration): + gen.send(None) + + header = make_header(mtype=MESSAGE_TYPE, length=0) + expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header)) + self.assertTrue(self.interface.data == [expected_message]) + + def test_write_multiple_packets(self): + message = bytes(range(256)) + gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message) + + header = make_header(mtype=MESSAGE_TYPE, length=len(message)) + first_len = codec_v1._REP_LEN - len(header) + # first packet is header + (remaining)data + # other packets are "?" + 63 bytes of data + packets = [header + message[:first_len]] + [ + b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1) + ] + + for _ in packets: + # we receive as many queries as there are packets + query = gen.send(None) + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + # the first sent None only started the generator. the len(packets)-th None + # will finish writing and raise StopIteration + with self.assertRaises(StopIteration): + gen.send(None) + + # packets must be identical up to the last one + self.assertListEqual(packets[:-1], self.interface.data[:-1]) + # last packet must be identical up to message length. remaining bytes in + # the 64-byte packets are garbage -- in particular, it's the bytes of the + # previous packet + last_packet = packets[-1] + packets[-2][len(packets[-1]):] + self.assertEqual(last_packet, self.interface.data[-1]) + + def test_roundtrip(self): + message = bytes(range(256)) + gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message) + + # exhaust the iterator: + # (XXX we can only do this because the iterator is only accepting None and returns None) + for query in gen: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE)) + + buffer = bytearray(1024) + gen = codec_v1.read_message(self.interface, buffer) + query = gen.send(None) + for packet in self.interface.data[:-1]: + self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ)) + query = gen.send(packet) + + with self.assertRaises(StopIteration) as e: + gen.send(self.interface.data[-1]) + + result = e.value.value + assert result.type == MESSAGE_TYPE + assert result.data.buffer == message + + +if __name__ == "__main__": unittest.main() diff --git a/core/tests/unittest.py b/core/tests/unittest.py index a78a68d03..dcade810c 100644 --- a/core/tests/unittest.py +++ b/core/tests/unittest.py @@ -9,6 +9,7 @@ class AssertRaisesContext: def __init__(self, exc): self.expected = exc + self.value = None def __enter__(self): return self @@ -17,6 +18,7 @@ class AssertRaisesContext: if exc_type is None: ensure(False, "%r not raised" % self.expected) if issubclass(exc_type, self.expected): + self.value = exc_value return True return False