from common import *

import protobuf

if False:
    # Guarded by `if False` so this import never executes at runtime; only
    # static type checkers read it. `Any` is needed by run_until_complete's
    # return annotation below.
    from typing import Any, Awaitable, Dict


class Message(protobuf.MessageType):
    """Test message with a signed-varint field and a restricted enum field."""

    def __init__(self, uint_field: int = 0, enum_field: int = 0) -> None:
        # NOTE: despite its name, `uint_field` populates the signed-varint
        # field `sint_field` (the tests pass negative values such as -42).
        # The parameter name is kept so existing keyword callers still work.
        self.sint_field = uint_field
        self.enum_field = enum_field

    @classmethod
    def get_fields(cls) -> Dict:
        """Return the field map keyed by protobuf tag number.

        Each value is (attribute name, field type, 0); the meaning of the
        trailing 0 is defined by the `protobuf` module (presumably per-field
        flags -- confirm there).
        """
        return {
            1: ("sint_field", protobuf.SVarintType, 0),
            # Enum restricted to the listed values; test_validate_enum relies
            # on 5 being accepted and 42 being rejected on load.
            2: ("enum_field", protobuf.EnumType("t", (0, 5, 25)), 0),
        }


class ByteReader:
    """Async-style reader that serves bytes from an in-memory buffer."""

    def __init__(self, data: bytes) -> None:
        self.data = data
        self.pos = 0  # offset of the next unread byte in `data`

    async def areadinto(self, buf: bytearray) -> int:
        """Fill `buf` completely from the current position and advance.

        Returns the number of bytes read (always len(buf)).
        Raises EOFError if fewer than len(buf) bytes remain; in that case
        nothing is consumed.
        """
        remaining = len(self.data) - self.pos
        limit = len(buf)
        if remaining < limit:
            raise EOFError
        buf[:] = self.data[self.pos : self.pos + limit]
        self.pos += limit
        return limit


class ByteArrayWriter:
    """Async-style writer that appends everything written to `self.buf`."""

    def __init__(self) -> None:
        self.buf = bytearray()

    async def awrite(self, buf: bytes) -> int:
        """Append `buf` to the internal buffer and return its length."""
        self.buf.extend(buf)
        return len(buf)


def run_until_complete(task: Awaitable) -> Any:
    """Drive `task` synchronously to completion and return its result.

    Values are pushed into the task with `send`. Whenever the task yields a
    truthy object, that object is itself run to completion recursively, and
    its result is sent back into `task`. A falsy yield resumes `task` with
    None. The task's return value (carried by StopIteration) is returned.
    """
    value = None
    while True:
        try:
            result = task.send(value)
        except StopIteration as e:
            return e.value
        # A truthy yield is a nested awaitable to run first; its result is
        # fed back into the outer task on the next iteration.
        value = run_until_complete(result) if result else None


def load_uvarint(data: bytes) -> int:
    """Decode one unsigned varint from the start of `data`."""
    reader = ByteReader(data)
    return run_until_complete(protobuf.load_uvarint(reader))


def dump_uvarint(value: int) -> bytearray:
    """Encode `value` as an unsigned varint and return the written bytes."""
    writer = ByteArrayWriter()
    run_until_complete(protobuf.dump_uvarint(writer, value))
    return writer.buf


class TestProtobuf(unittest.TestCase):
    def test_dump_uvarint(self):
        self.assertEqual(dump_uvarint(0), b"\x00")
        self.assertEqual(dump_uvarint(1), b"\x01")
        self.assertEqual(dump_uvarint(0xFF), b"\xff\x01")
        self.assertEqual(dump_uvarint(123456), b"\xc0\xc4\x07")
        # Negative values cannot be encoded as an unsigned varint.
        with self.assertRaises(ValueError):
            dump_uvarint(-1)

    def test_load_uvarint(self):
        self.assertEqual(load_uvarint(b"\x00"), 0)
        self.assertEqual(load_uvarint(b"\x01"), 1)
        self.assertEqual(load_uvarint(b"\xff\x01"), 0xFF)
        self.assertEqual(load_uvarint(b"\xc0\xc4\x07"), 123456)

    def test_sint_uint(self):
        self.assertEqual(protobuf.uint_to_sint(0), 0)
        self.assertEqual(protobuf.sint_to_uint(0), 0)

        self.assertEqual(protobuf.sint_to_uint(-1), 1)
        self.assertEqual(protobuf.sint_to_uint(1), 2)

        self.assertEqual(protobuf.uint_to_sint(1), -1)
        self.assertEqual(protobuf.uint_to_sint(2), 1)

        # roundtrip:
        self.assertEqual(
            protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
        )
        # Parenthesized to make the intended value -(2**32) explicit
        # (`-2 ** 32` parses the same way but reads ambiguously).
        self.assertEqual(
            protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
        )

    def test_validate_enum(self):
        # ok message:
        msg = Message(-42, 5)
        writer = ByteArrayWriter()
        run_until_complete(protobuf.dump_message(writer, msg))
        reader = ByteReader(bytes(writer.buf))
        nmsg = run_until_complete(protobuf.load_message(reader, Message))

        self.assertEqual(msg.sint_field, nmsg.sint_field)
        self.assertEqual(msg.enum_field, nmsg.enum_field)

        # bad enum value: dumping succeeds, but loading must reject 42
        # because it is not one of the values allowed by the EnumType.
        msg = Message(-42, 42)
        writer = ByteArrayWriter()
        run_until_complete(protobuf.dump_message(writer, msg))
        reader = ByteReader(bytes(writer.buf))
        with self.assertRaises(TypeError):
            run_until_complete(protobuf.load_message(reader, Message))


if __name__ == "__main__":
    unittest.main()