mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-24 22:31:35 +00:00
core/tests: add synchronous protobuf tests
This commit is contained in:
parent
31e2170766
commit
0a758b8181
@ -1,9 +1,7 @@
|
|||||||
from common import *
|
from common import *
|
||||||
|
|
||||||
import protobuf
|
import protobuf
|
||||||
|
from trezor.utils import BufferIO
|
||||||
if False:
|
|
||||||
from typing import Awaitable, Dict
|
|
||||||
|
|
||||||
|
|
||||||
class Message(protobuf.MessageType):
|
class Message(protobuf.MessageType):
|
||||||
@ -12,47 +10,22 @@ class Message(protobuf.MessageType):
|
|||||||
self.enum_field = enum_field
|
self.enum_field = enum_field
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls) -> Dict:
|
def get_fields(cls):
|
||||||
return {
|
return {
|
||||||
1: ("sint_field", protobuf.SVarintType, 0),
|
1: ("sint_field", protobuf.SVarintType, 0),
|
||||||
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
|
2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ByteReader:
|
|
||||||
def __init__(self, data: bytes) -> None:
|
|
||||||
self.data = data
|
|
||||||
self.pos = 0
|
|
||||||
|
|
||||||
async def areadinto(self, buf: bytearray) -> int:
|
|
||||||
remaining = len(self.data) - self.pos
|
|
||||||
limit = len(buf)
|
|
||||||
if remaining < limit:
|
|
||||||
raise EOFError
|
|
||||||
|
|
||||||
buf[:] = self.data[self.pos : self.pos + limit]
|
|
||||||
self.pos += limit
|
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
class ByteArrayWriter:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.buf = bytearray(0)
|
|
||||||
|
|
||||||
async def awrite(self, buf: bytes) -> int:
|
|
||||||
self.buf.extend(buf)
|
|
||||||
return len(buf)
|
|
||||||
|
|
||||||
|
|
||||||
def load_uvarint(data: bytes) -> int:
|
def load_uvarint(data: bytes) -> int:
|
||||||
reader = ByteReader(data)
|
reader = BufferIO(data)
|
||||||
return await_result(protobuf.load_uvarint(reader))
|
return protobuf.load_uvarint(reader)
|
||||||
|
|
||||||
|
|
||||||
def dump_uvarint(value: int) -> bytearray:
|
def dump_uvarint(value: int) -> bytearray:
|
||||||
writer = ByteArrayWriter()
|
writer = BufferIO(bytearray(16))
|
||||||
await_result(protobuf.dump_uvarint(writer, value))
|
protobuf.dump_uvarint(writer, value)
|
||||||
return writer.buf
|
return memoryview(writer.buffer)[:writer.offset]
|
||||||
|
|
||||||
|
|
||||||
class TestProtobuf(unittest.TestCase):
|
class TestProtobuf(unittest.TestCase):
|
||||||
@ -91,21 +64,23 @@ class TestProtobuf(unittest.TestCase):
|
|||||||
def test_validate_enum(self):
|
def test_validate_enum(self):
|
||||||
# ok message:
|
# ok message:
|
||||||
msg = Message(-42, 5)
|
msg = Message(-42, 5)
|
||||||
writer = ByteArrayWriter()
|
length = protobuf.count_message(msg)
|
||||||
await_result(protobuf.dump_message(writer, msg))
|
buffer_io = BufferIO(bytearray(length))
|
||||||
reader = ByteReader(bytes(writer.buf))
|
protobuf.dump_message(buffer_io, msg)
|
||||||
nmsg = await_result(protobuf.load_message(reader, Message))
|
buffer_io.seek(0)
|
||||||
|
nmsg = protobuf.load_message(buffer_io, Message)
|
||||||
|
|
||||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
||||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
||||||
|
|
||||||
# bad enum value:
|
# bad enum value:
|
||||||
|
buffer_io.seek(0)
|
||||||
msg = Message(-42, 42)
|
msg = Message(-42, 42)
|
||||||
writer = ByteArrayWriter()
|
# XXX this assumes the message will have equal size
|
||||||
await_result(protobuf.dump_message(writer, msg))
|
protobuf.dump_message(buffer_io, msg)
|
||||||
reader = ByteReader(bytes(writer.buf))
|
buffer_io.seek(0)
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
await_result(protobuf.load_message(reader, Message))
|
protobuf.load_message(buffer_io, Message)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
from common import *
|
from common import *
|
||||||
from ubinascii import unhexlify
|
from ubinascii import unhexlify
|
||||||
|
import ustruct
|
||||||
|
|
||||||
from trezor import io
|
from trezor import io
|
||||||
from trezor.loop import wait
|
from trezor.loop import wait
|
||||||
from trezor.utils import chunks
|
from trezor.utils import chunks, BufferIO
|
||||||
from trezor.wire import codec_v1
|
from trezor.wire import codec_v1
|
||||||
|
|
||||||
|
|
||||||
class MockHID:
|
class MockHID:
|
||||||
|
|
||||||
def __init__(self, num):
|
def __init__(self, num):
|
||||||
self.num = num
|
self.num = num
|
||||||
self.data = []
|
self.data = []
|
||||||
@ -20,150 +20,162 @@ class MockHID:
|
|||||||
self.data.append(bytearray(msg))
|
self.data.append(bytearray(msg))
|
||||||
return len(msg)
|
return len(msg)
|
||||||
|
|
||||||
|
def wait_object(self, mode):
|
||||||
|
return wait(mode | self.num)
|
||||||
|
|
||||||
|
|
||||||
|
MESSAGE_TYPE = 0x4242
|
||||||
|
|
||||||
|
|
||||||
|
def make_header(mtype, length):
|
||||||
|
return b"?##" + ustruct.pack(">HL", mtype, length)
|
||||||
|
|
||||||
|
|
||||||
class TestWireCodecV1(unittest.TestCase):
|
class TestWireCodecV1(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.interface = MockHID(0xDEADBEEF)
|
||||||
|
|
||||||
def test_reader(self):
|
def test_read_one_packet(self):
|
||||||
rep_len = 64
|
# zero length message - just a header
|
||||||
interface_num = 0xdeadbeef
|
message_packet = make_header(mtype=MESSAGE_TYPE, length=0)
|
||||||
message_type = 0x4321
|
buffer = bytearray(64)
|
||||||
message_len = 250
|
|
||||||
interface = MockHID(interface_num)
|
|
||||||
reader = codec_v1.Reader(interface)
|
|
||||||
|
|
||||||
message = bytearray(range(message_len))
|
gen = codec_v1.read_message(self.interface, buffer)
|
||||||
report_header = bytearray(unhexlify('3f23234321000000fa'))
|
|
||||||
|
|
||||||
# open, expected one read
|
query = gen.send(None)
|
||||||
first_report = report_header + message[:rep_len - len(report_header)]
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
self.assertAsync(reader.aopen(), [(None, wait(io.POLL_READ | interface_num)), (first_report, StopIteration()), ])
|
|
||||||
self.assertEqual(reader.type, message_type)
|
|
||||||
self.assertEqual(reader.size, message_len)
|
|
||||||
|
|
||||||
# empty read
|
with self.assertRaises(StopIteration) as e:
|
||||||
empty_buffer = bytearray()
|
gen.send(message_packet)
|
||||||
self.assertAsync(reader.areadinto(empty_buffer), [(None, StopIteration()), ])
|
|
||||||
self.assertEqual(len(empty_buffer), 0)
|
|
||||||
self.assertEqual(reader.size, message_len)
|
|
||||||
|
|
||||||
# short read, expected no read
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
short_buffer = bytearray(32)
|
result = e.value.value
|
||||||
self.assertAsync(reader.areadinto(short_buffer), [(None, StopIteration()), ])
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
self.assertEqual(len(short_buffer), 32)
|
self.assertEqual(result.data.buffer, b"")
|
||||||
self.assertEqual(short_buffer, message[:len(short_buffer)])
|
|
||||||
self.assertEqual(reader.size, message_len - len(short_buffer))
|
|
||||||
|
|
||||||
# aligned read, expected no read
|
# message should have been read into the buffer
|
||||||
aligned_buffer = bytearray(rep_len - len(report_header) - len(short_buffer))
|
self.assertEqual(buffer, b"\x00" * 64)
|
||||||
self.assertAsync(reader.areadinto(aligned_buffer), [(None, StopIteration()), ])
|
|
||||||
self.assertEqual(aligned_buffer, message[len(short_buffer):][:len(aligned_buffer)])
|
|
||||||
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer))
|
|
||||||
|
|
||||||
# one byte read, expected one read
|
def test_read_many_packets(self):
|
||||||
next_report_header = bytearray(unhexlify('3f'))
|
message = bytes(range(256))
|
||||||
next_report = next_report_header + message[rep_len - len(report_header):][:rep_len - len(next_report_header)]
|
|
||||||
onebyte_buffer = bytearray(1)
|
|
||||||
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, wait(io.POLL_READ | interface_num)), (next_report, StopIteration()), ])
|
|
||||||
self.assertEqual(onebyte_buffer, message[len(short_buffer):][len(aligned_buffer):][:len(onebyte_buffer)])
|
|
||||||
self.assertEqual(reader.size, message_len - len(short_buffer) - len(aligned_buffer) - len(onebyte_buffer))
|
|
||||||
|
|
||||||
# too long read, raises eof
|
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||||
self.assertAsync(reader.areadinto(bytearray(reader.size + 1)), [(None, EOFError()), ])
|
first_len = codec_v1._REP_LEN - len(header)
|
||||||
|
# first packet is header + (remaining)data
|
||||||
|
# other packets are "?" + 63 bytes of data
|
||||||
|
packets = [header + message[:first_len]] + [
|
||||||
|
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
|
||||||
|
]
|
||||||
|
|
||||||
# long read, expect multiple reads
|
buffer = bytearray(256)
|
||||||
start_size = reader.size
|
gen = codec_v1.read_message(self.interface, buffer)
|
||||||
long_buffer = bytearray(start_size)
|
query = gen.send(None)
|
||||||
report_payload = message[rep_len - len(report_header) + rep_len - len(next_report_header):]
|
for packet in packets[:-1]:
|
||||||
report_payload_head = report_payload[:rep_len - len(next_report_header) - len(onebyte_buffer)]
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
report_payload_rest = report_payload[len(report_payload_head):]
|
query = gen.send(packet)
|
||||||
report_payload_rest = list(chunks(report_payload_rest, rep_len - len(next_report_header)))
|
|
||||||
report_payloads = [report_payload_head] + report_payload_rest
|
|
||||||
next_reports = [next_report_header + r for r in report_payloads]
|
|
||||||
expected_syscalls = []
|
|
||||||
for i, _ in enumerate(next_reports):
|
|
||||||
prev_report = next_reports[i - 1] if i > 0 else None
|
|
||||||
expected_syscalls.append((prev_report, wait(io.POLL_READ | interface_num)))
|
|
||||||
expected_syscalls.append((next_reports[-1], StopIteration()))
|
|
||||||
self.assertAsync(reader.areadinto(long_buffer), expected_syscalls)
|
|
||||||
self.assertEqual(long_buffer, message[-start_size:])
|
|
||||||
self.assertEqual(reader.size, 0)
|
|
||||||
|
|
||||||
# one byte read, raises eof
|
# last packet will stop
|
||||||
self.assertAsync(reader.areadinto(onebyte_buffer), [(None, EOFError()), ])
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(packets[-1])
|
||||||
|
|
||||||
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data.buffer, message)
|
||||||
|
|
||||||
|
# message should have been read into the buffer
|
||||||
|
self.assertEqual(buffer, message)
|
||||||
|
|
||||||
|
def test_read_large_message(self):
|
||||||
|
message = b"hello world"
|
||||||
|
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||||
|
|
||||||
|
packet = header + message
|
||||||
|
# make sure we fit into one packet, to make this easier
|
||||||
|
self.assertTrue(len(packet) <= codec_v1._REP_LEN)
|
||||||
|
|
||||||
|
buffer = bytearray(1)
|
||||||
|
self.assertTrue(len(buffer) <= len(packet))
|
||||||
|
|
||||||
|
gen = codec_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(packet)
|
||||||
|
|
||||||
|
# e.value is StopIteration. e.value.value is the return value of the call
|
||||||
|
result = e.value.value
|
||||||
|
self.assertEqual(result.type, MESSAGE_TYPE)
|
||||||
|
self.assertEqual(result.data.buffer, message)
|
||||||
|
|
||||||
|
# read should have allocated its own buffer and not touch ours
|
||||||
|
self.assertEqual(buffer, b"\x00")
|
||||||
|
|
||||||
|
def test_write_one_packet(self):
|
||||||
|
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, b"")
|
||||||
|
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
gen.send(None)
|
||||||
|
|
||||||
|
header = make_header(mtype=MESSAGE_TYPE, length=0)
|
||||||
|
expected_message = header + b"\x00" * (codec_v1._REP_LEN - len(header))
|
||||||
|
self.assertTrue(self.interface.data == [expected_message])
|
||||||
|
|
||||||
|
def test_write_multiple_packets(self):
|
||||||
|
message = bytes(range(256))
|
||||||
|
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
|
||||||
|
|
||||||
|
header = make_header(mtype=MESSAGE_TYPE, length=len(message))
|
||||||
|
first_len = codec_v1._REP_LEN - len(header)
|
||||||
|
# first packet is header + (remaining)data
|
||||||
|
# other packets are "?" + 63 bytes of data
|
||||||
|
packets = [header + message[:first_len]] + [
|
||||||
|
b"?" + chunk for chunk in chunks(message[first_len:], codec_v1._REP_LEN - 1)
|
||||||
|
]
|
||||||
|
|
||||||
|
for _ in packets:
|
||||||
|
# we receive as many queries as there are packets
|
||||||
|
query = gen.send(None)
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
|
||||||
|
# the first sent None only started the generator. the len(packets)-th None
|
||||||
|
# will finish writing and raise StopIteration
|
||||||
|
with self.assertRaises(StopIteration):
|
||||||
|
gen.send(None)
|
||||||
|
|
||||||
|
# packets must be identical up to the last one
|
||||||
|
self.assertListEqual(packets[:-1], self.interface.data[:-1])
|
||||||
|
# last packet must be identical up to message length. remaining bytes in
|
||||||
|
# the 64-byte packets are garbage -- in particular, it's the bytes of the
|
||||||
|
# previous packet
|
||||||
|
last_packet = packets[-1] + packets[-2][len(packets[-1]):]
|
||||||
|
self.assertEqual(last_packet, self.interface.data[-1])
|
||||||
|
|
||||||
|
def test_roundtrip(self):
|
||||||
|
message = bytes(range(256))
|
||||||
|
gen = codec_v1.write_message(self.interface, MESSAGE_TYPE, message)
|
||||||
|
|
||||||
|
# exhaust the iterator:
|
||||||
|
# (XXX we can only do this because the iterator is only accepting None and returns None)
|
||||||
|
for query in gen:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_WRITE))
|
||||||
|
|
||||||
|
buffer = bytearray(1024)
|
||||||
|
gen = codec_v1.read_message(self.interface, buffer)
|
||||||
|
query = gen.send(None)
|
||||||
|
for packet in self.interface.data[:-1]:
|
||||||
|
self.assertObjectEqual(query, self.interface.wait_object(io.POLL_READ))
|
||||||
|
query = gen.send(packet)
|
||||||
|
|
||||||
|
with self.assertRaises(StopIteration) as e:
|
||||||
|
gen.send(self.interface.data[-1])
|
||||||
|
|
||||||
|
result = e.value.value
|
||||||
|
assert result.type == MESSAGE_TYPE
|
||||||
|
assert result.data.buffer == message
|
||||||
|
|
||||||
|
|
||||||
def test_writer(self):
|
if __name__ == "__main__":
|
||||||
rep_len = 64
|
|
||||||
interface_num = 0xdeadbeef
|
|
||||||
message_type = 0x87654321
|
|
||||||
message_len = 1024
|
|
||||||
interface = MockHID(interface_num)
|
|
||||||
writer = codec_v1.Writer(interface)
|
|
||||||
writer.setheader(message_type, message_len)
|
|
||||||
|
|
||||||
# init header corresponding to the data above
|
|
||||||
report_header = bytearray(unhexlify('3f2323432100000400'))
|
|
||||||
|
|
||||||
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
|
||||||
|
|
||||||
# empty write
|
|
||||||
start_size = writer.size
|
|
||||||
self.assertAsync(writer.awrite(bytearray()), [(None, StopIteration()), ])
|
|
||||||
self.assertEqual(writer.data, report_header + bytearray(rep_len - len(report_header)))
|
|
||||||
self.assertEqual(writer.size, start_size)
|
|
||||||
|
|
||||||
# short write, expected no report
|
|
||||||
start_size = writer.size
|
|
||||||
short_payload = bytearray(range(4))
|
|
||||||
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ])
|
|
||||||
self.assertEqual(writer.size, start_size - len(short_payload))
|
|
||||||
self.assertEqual(writer.data,
|
|
||||||
report_header +
|
|
||||||
short_payload +
|
|
||||||
bytearray(rep_len - len(report_header) - len(short_payload)))
|
|
||||||
|
|
||||||
# aligned write, expected one report
|
|
||||||
start_size = writer.size
|
|
||||||
aligned_payload = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
|
||||||
self.assertAsync(writer.awrite(aligned_payload), [(None, wait(io.POLL_WRITE | interface_num)), (None, StopIteration()), ])
|
|
||||||
self.assertEqual(interface.data, [report_header +
|
|
||||||
short_payload +
|
|
||||||
aligned_payload +
|
|
||||||
bytearray(rep_len - len(report_header) - len(short_payload) - len(aligned_payload)), ])
|
|
||||||
self.assertEqual(writer.size, start_size - len(aligned_payload))
|
|
||||||
interface.data.clear()
|
|
||||||
|
|
||||||
# short write, expected no report, but data starts with correct seq and cont marker
|
|
||||||
report_header = bytearray(unhexlify('3f'))
|
|
||||||
start_size = writer.size
|
|
||||||
self.assertAsync(writer.awrite(short_payload), [(None, StopIteration()), ])
|
|
||||||
self.assertEqual(writer.size, start_size - len(short_payload))
|
|
||||||
self.assertEqual(writer.data[:len(report_header) + len(short_payload)],
|
|
||||||
report_header + short_payload)
|
|
||||||
|
|
||||||
# long write, expected multiple reports
|
|
||||||
start_size = writer.size
|
|
||||||
long_payload_head = bytearray(range(rep_len - len(report_header) - len(short_payload)))
|
|
||||||
long_payload_rest = bytearray(range(start_size - len(long_payload_head)))
|
|
||||||
long_payload = long_payload_head + long_payload_rest
|
|
||||||
expected_payloads = [short_payload + long_payload_head] + list(chunks(long_payload_rest, rep_len - len(report_header)))
|
|
||||||
expected_reports = [report_header + r for r in expected_payloads]
|
|
||||||
expected_reports[-1] += bytearray(bytes(1) * (rep_len - len(expected_reports[-1])))
|
|
||||||
# test write
|
|
||||||
expected_write_reports = expected_reports[:-1]
|
|
||||||
self.assertAsync(writer.awrite(long_payload), len(expected_write_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
|
|
||||||
self.assertEqual(interface.data, expected_write_reports)
|
|
||||||
self.assertEqual(writer.size, start_size - len(long_payload))
|
|
||||||
interface.data.clear()
|
|
||||||
# test write raises eof
|
|
||||||
self.assertAsync(writer.awrite(bytearray(1)), [(None, EOFError())])
|
|
||||||
self.assertEqual(interface.data, [])
|
|
||||||
# test close
|
|
||||||
expected_close_reports = expected_reports[-1:]
|
|
||||||
self.assertAsync(writer.aclose(), len(expected_close_reports) * [(None, wait(io.POLL_WRITE | interface_num))] + [(None, StopIteration())])
|
|
||||||
self.assertEqual(interface.data, expected_close_reports)
|
|
||||||
self.assertEqual(writer.size, 0)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -9,6 +9,7 @@ class AssertRaisesContext:
|
|||||||
|
|
||||||
def __init__(self, exc):
|
def __init__(self, exc):
|
||||||
self.expected = exc
|
self.expected = exc
|
||||||
|
self.value = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self
|
return self
|
||||||
@ -17,6 +18,7 @@ class AssertRaisesContext:
|
|||||||
if exc_type is None:
|
if exc_type is None:
|
||||||
ensure(False, "%r not raised" % self.expected)
|
ensure(False, "%r not raised" % self.expected)
|
||||||
if issubclass(exc_type, self.expected):
|
if issubclass(exc_type, self.expected):
|
||||||
|
self.value = exc_value
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user