core/tests: add synchronous protobuf tests

pull/1128/head
matejcik 4 years ago committed by Tomas Susanka
parent 31e2170766
commit 0a758b8181

@ -1,9 +1,7 @@
from common import * from common import *
import protobuf import protobuf
from trezor.utils import BufferIO
if False:
from typing import Awaitable, Dict
class Message(protobuf.MessageType): class Message(protobuf.MessageType):
@ -12,47 +10,22 @@ class Message(protobuf.MessageType):
self.enum_field = enum_field self.enum_field = enum_field
@classmethod @classmethod
def get_fields(cls) -> Dict: def get_fields(cls):
return { return {
1: ("sint_field", protobuf.SVarintType, 0), 1: ("sint_field", protobuf.SVarintType, 0),
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0), 2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
} }
class ByteReader:
def __init__(self, data: bytes) -> None:
self.data = data
self.pos = 0
async def areadinto(self, buf: bytearray) -> int:
remaining = len(self.data) - self.pos
limit = len(buf)
if remaining < limit:
raise EOFError
buf[:] = self.data[self.pos : self.pos + limit]
self.pos += limit
return limit
class ByteArrayWriter:
def __init__(self) -> None:
self.buf = bytearray(0)
async def awrite(self, buf: bytes) -> int:
self.buf.extend(buf)
return len(buf)
def load_uvarint(data: bytes) -> int: def load_uvarint(data: bytes) -> int:
reader = ByteReader(data) reader = BufferIO(data)
return await_result(protobuf.load_uvarint(reader)) return protobuf.load_uvarint(reader)
def dump_uvarint(value: int) -> bytearray: def dump_uvarint(value: int) -> bytearray:
writer = ByteArrayWriter() writer = BufferIO(bytearray(16))
await_result(protobuf.dump_uvarint(writer, value)) protobuf.dump_uvarint(writer, value)
return writer.buf return memoryview(writer.buffer)[:writer.offset]
class TestProtobuf(unittest.TestCase): class TestProtobuf(unittest.TestCase):
@ -91,21 +64,23 @@ class TestProtobuf(unittest.TestCase):
def test_validate_enum(self): def test_validate_enum(self):
# ok message: # ok message:
msg = Message(-42, 5) msg = Message(-42, 5)
writer = ByteArrayWriter() length = protobuf.count_message(msg)
await_result(protobuf.dump_message(writer, msg)) buffer_io = BufferIO(bytearray(length))
reader = ByteReader(bytes(writer.buf)) protobuf.dump_message(buffer_io, msg)
nmsg = await_result(protobuf.load_message(reader, Message)) buffer_io.seek(0)
nmsg = protobuf.load_message(buffer_io, Message)
self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.sint_field, nmsg.sint_field)
self.assertEqual(msg.enum_field, nmsg.enum_field) self.assertEqual(msg.enum_field, nmsg.enum_field)
# bad enum value: # bad enum value:
buffer_io.seek(0)
msg = Message(-42, 42) msg = Message(-42, 42)
writer = ByteArrayWriter() # XXX this assumes the message will have equal size
await_result(protobuf.dump_message(writer, msg)) protobuf.dump_message(buffer_io, msg)
reader = ByteReader(bytes(writer.buf)) buffer_io.seek(0)
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
await_result(protobuf.load_message(reader, Message)) protobuf.load_message(buffer_io, Message)
if __name__ == "__main__": if __name__ == "__main__":

@ -1,14 +1,14 @@
from common import * from common import *
from ubinascii import unhexlify from ubinascii import unhexlify
import ustruct
from trezor import io from trezor import io
from trezor.loop import wait from trezor.loop import wait
from trezor.utils import chunks from trezor.utils import chunks, BufferIO
from trezor.wire import codec_v1 from trezor.wire import codec_v1
class MockHID: class MockHID:
def __init__(self, num): def __init__(self, num):
self.num = num self.num = num
self.data = [] self.data = []
@ -20,150 +20,162 @@ class MockHID:
self.data.append(bytearray(msg)) self.data.append(bytearray(msg))
return len(msg) return len(msg)
def wait_object(self, mode):
return wait(mode | self.num)
MESSAGE_TYPE = 0x4242
class TestWireCodecV1(unittest.TestCase):
def test_reader(self): def make_header(mtype, length):
rep_len = 64 return b"?##" + ustruct.pack(">HL", mtype, length)
interface_num = 0xdeadbeef
message_type = 0x4321
message_len = 250 class TestWireCodecV1(unittest.TestCase):
interface = MockHID(interface_num) def setUp(self):
reader = codec_v1.Reader(interface) self.interface = MockHID(0xDEADBEEF)
message = bytearray(range(message_len)) def test_read_one_packet(self):
report_header = bytearray(unhexlify('3f23234321000000fa')) # zero length message - just a header
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
# open, expected one read buffer = bytearray(64)
first_report = report_header + message[:rep_len - len(report_header)]
self.assertAsync(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ]) gen = codec_v1.read_message(self.interface, buffer)
self.assertEqual(reader.type, message_type)
self.assertEqual(reader.size, message_len) query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
# empty read
empty_buffer = bytearray() with self.assertRaises(StopIteration) as e:
self.assertAsync(reader.areadinto(empty_buffer), [(None, StopIteration()), ]) gen.send(message_packet)
self.assertEqual(len(empty_buffer), 0)
self.assertEqual(reader.size, message_len) # e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
# short read, expected no read self.assertEqual(result.type, MESSAGE_TYPE)
short_buffer = bytearray(32) self.assertEqual(result.data.buffer, b"")
self.assertAsync(reader.areadinto(short_buffer), [(None, StopIteration()), ])
self.assertEqual(len(short_buffer), 32) # message should have been read into the buffer
self.assertEqual(short_buffer, message[:len(short_buffer)]) self.assertEqual(buffer, b"\x00" * 64)
self.assertEqual(reader.size, message_len - len(short_buffer))
def test_read_many_packets(self):
# aligned read, expected no read message = bytes(range(256))
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
self.assertAsync(reader.areadinto(aligned_buffer), [(None, StopIteration()), ]) header = make_header(mtype=MESSAGE_TYPE, length=len(message))
self.assertEqual(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)]) first_len = codec_v1._REP_LEN - len(header)
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer)) # first packet is header + (remaining)data
# other packets are "?" + 63 bytes of data
# one byte read, expected one read packets = [header + message[:first_len]] + [
next_report_header = bytearray(unhexlify('3f')) b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)] ]
onebyte_buffer = bytearray(1)
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ]) buffer = bytearray(256)
self.assertEqual(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)]) gen = codec_v1.read_message(self.interface, buffer)
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer)) query = gen.send(None)
for packet in packets[:-1]:
# too long read, raises eof self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
self.assertAsync(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ]) query = gen.send(packet)
# long read, expect multiple reads # last packet will stop
start_size = reader.size with self.assertRaises(StopIteration) as e:
long_buffer = bytearray(start_size) gen.send(packets[-1])
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)] # e.value is StopIteration. e.value.value is the return value of the call
report_payload_rest = report_payload[len(report_payload_head):] result = e.value.value
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header))) self.assertEqual(result.type, MESSAGE_TYPE)
report_payloads = [report_payload_head] + report_payload_rest self.assertEqual(result.data.buffer, message)
next_reports = [next_report_header + r for r in report_payloads]
expected_syscalls = [] # message should have been read into the buffer
for i, _ in enumerate(next_reports): self.assertEqual(buffer, message)
prev_report = next_reports[i - 1] if i > 0 else None
expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num))) def test_read_large_message(self):
expected_syscalls.append((next_reports[-1], StopIteration())) message = b"hello world"
self.assertAsync(reader.areadinto(long_buffer), expected_syscalls) header = make_header(mtype=MESSAGE_TYPE, length=len(message))
self.assertEqual(long_buffer, message[-start_size:])
self.assertEqual(reader.size, 0) packet = header + message
# make sure we fit into one packet, to make this easier
# one byte read, raises eof self.assertTrue(len(packet) <= codec_v1._REP_LEN)
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, EOFError()), ])
buffer = bytearray(1)
self.assertTrue(len(buffer) <= len(packet))
def test_writer(self):
rep_len = 64 gen = codec_v1.read_message(self.interface, buffer)
interface_num = 0xdeadbeef query = gen.send(None)
message_type = 0x87654321 self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
message_len = 1024 with self.assertRaises(StopIteration) as e:
interface = MockHID(interface_num) gen.send(packet)
writer = codec_v1.Writer(interface)
writer.setheader(message_type, message_len) # e.value is StopIteration. e.value.value is the return value of the call
result = e.value.value
# init header corresponding to the data above self.assertEqual(result.type, MESSAGE_TYPE)
report_header = bytearray(unhexlify('3f2323432100000400')) self.assertEqual(result.data.buffer, message)
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header))) # read should have allocated its own buffer and not touch ours
self.assertEqual(buffer, b"\x00")
# empty write
start_size = writer.size def test_write_one_packet(self):
self.assertAsync(writer.awrite(bytearray()), [(None, StopIteration()), ]) gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, b"")
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
self.assertEqual(writer.size, start_size) query = gen.send(None)
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
# short write, expected no report with self.assertRaises(StopIteration):
start_size = writer.size gen.send(None)
short_payload = bytearray(range(4))
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ]) header = make_header(mtype=MESSAGE_TYPE, length=0)
self.assertEqual(writer.size, start_size - len(short_payload)) expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
self.assertEqual(writer.data, self.assertTrue(self.interface.data == [expected_message])
report_header +
short_payload + def test_write_multiple_packets(self):
bytearray(rep_len - len(report_header) - len(short_payload))) message = bytes(range(256))
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
# aligned write, expected one report
start_size = writer.size header = make_header(mtype=MESSAGE_TYPE, length=len(message))
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload))) first_len = codec_v1._REP_LEN - len(header)
self.assertAsync(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ]) # first packet is header + (remaining)data
self.assertEqual(interface.data, [report_header + # other packets are "?" + 63 bytes of data
short_payload + packets = [header + message[:first_len]] + [
aligned_payload + b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ]) ]
self.assertEqual(writer.size, start_size - len(aligned_payload))
interface.data.clear() for _ in packets:
# we receive as many queries as there are packets
# short write, expected no report, but data starts with correct seq and cont marker query = gen.send(None)
report_header = bytearray(unhexlify('3f')) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
start_size = writer.size
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ]) # the first sent None only started the generator. the len(packets)-th None
self.assertEqual(writer.size, start_size - len(short_payload)) # will finish writing and raise StopIteration
self.assertEqual(writer.data[:len(report_header) + len(short_payload)], with self.assertRaises(StopIteration):
report_header + short_payload) gen.send(None)
# long write, expected multiple reports # packets must be identical up to the last one
start_size = writer.size self.assertListEqual(packets[:-1], self.interface.data[:-1])
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload))) # last packet must be identical up to message length. remaining bytes in
long_payload_rest = bytearray(range(start_size - len(long_payload_head))) # the 64-byte packets are garbage -- in particular, it's the bytes of the
long_payload = long_payload_head + long_payload_rest # previous packet
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header))) last_packet = packets[-1] + packets[-2][len(packets[-1]):]
expected_reports = [report_header + r for r in expected_payloads] self.assertEqual(last_packet, self.interface.data[-1])
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
# test write def test_roundtrip(self):
expected_write_reports = expected_reports[:-1] message = bytes(range(256))
self.assertAsync(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
self.assertEqual(interface.data, expected_write_reports)
self.assertEqual(writer.size, start_size - len(long_payload)) # exhaust the iterator:
interface.data.clear() # (XXX we can only do this because the iterator is only accepting None and returns None)
# test write raises eof for query in gen:
self.assertAsync(writer.awrite(bytearray(1)), [(None, EOFError())]) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
self.assertEqual(interface.data, [])
# test close buffer = bytearray(1024)
expected_close_reports = expected_reports[-1:] gen = codec_v1.read_message(self.interface, buffer)
self.assertAsync(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())]) query = gen.send(None)
self.assertEqual(interface.data, expected_close_reports) for packet in self.interface.data[:-1]:
self.assertEqual(writer.size, 0) self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
query = gen.send(packet)
if __name__ == '__main__': with self.assertRaises(StopIteration) as e:
gen.send(self.interface.data[-1])
result = e.value.value
assert result.type == MESSAGE_TYPE
assert result.data.buffer == message
if __name__ == "__main__":
unittest.main() unittest.main()

@ -9,6 +9,7 @@ class AssertRaisesContext:
def __init__(self, exc): def __init__(self, exc):
self.expected = exc self.expected = exc
self.value = None
def __enter__(self): def __enter__(self):
return self return self
@ -17,6 +18,7 @@ class AssertRaisesContext:
if exc_type is None: if exc_type is None:
ensure(False, "%r not raised" % self.expected) ensure(False, "%r not raised" % self.expected)
if issubclass(exc_type, self.expected): if issubclass(exc_type, self.expected):
self.value = exc_value
return True return True
return False return False

Loading…
Cancel
Save