|
|
|
@ -26,25 +26,25 @@ class PrimitiveMessage(protobuf.MessageType):
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_fields(cls):
|
|
|
|
|
return {
|
|
|
|
|
1: ("uvarint", protobuf.UVarintType, 0),
|
|
|
|
|
2: ("svarint", protobuf.SVarintType, 0),
|
|
|
|
|
3: ("bool", protobuf.BoolType, 0),
|
|
|
|
|
4: ("bytes", protobuf.BytesType, 0),
|
|
|
|
|
5: ("unicode", protobuf.UnicodeType, 0),
|
|
|
|
|
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 0),
|
|
|
|
|
1: ("uvarint", protobuf.UVarintType, None),
|
|
|
|
|
2: ("svarint", protobuf.SVarintType, None),
|
|
|
|
|
3: ("bool", protobuf.BoolType, None),
|
|
|
|
|
4: ("bytes", protobuf.BytesType, None),
|
|
|
|
|
5: ("unicode", protobuf.UnicodeType, None),
|
|
|
|
|
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), None),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EnumMessageMoreValues(protobuf.MessageType):
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_fields(cls):
|
|
|
|
|
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), 0)}
|
|
|
|
|
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), None)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class EnumMessageLessValues(protobuf.MessageType):
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_fields(cls):
|
|
|
|
|
return {1: ("enum", protobuf.EnumType("t", (0, 5)), 0)}
|
|
|
|
|
return {1: ("enum", protobuf.EnumType("t", (0, 5)), None)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RepeatedFields(protobuf.MessageType):
|
|
|
|
@ -68,6 +68,17 @@ def dump_uvarint(value):
|
|
|
|
|
return writer.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_message(buffer, msg_type):
|
|
|
|
|
reader = BytesIO(buffer)
|
|
|
|
|
return protobuf.load_message(reader, msg_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dump_message(msg):
|
|
|
|
|
writer = BytesIO()
|
|
|
|
|
protobuf.dump_message(writer, msg)
|
|
|
|
|
return writer.getvalue()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_dump_uvarint():
|
|
|
|
|
assert dump_uvarint(0) == b"\x00"
|
|
|
|
|
assert dump_uvarint(1) == b"\x01"
|
|
|
|
@ -109,7 +120,7 @@ def test_sint_uint():
|
|
|
|
|
|
|
|
|
|
# roundtrip:
|
|
|
|
|
assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011
|
|
|
|
|
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)) == -2 ** 32
|
|
|
|
|
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))) == -(2 ** 32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_simple_message():
|
|
|
|
@ -122,11 +133,8 @@ def test_simple_message():
|
|
|
|
|
enum=5,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
buf = BytesIO()
|
|
|
|
|
|
|
|
|
|
protobuf.dump_message(buf, msg)
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
retr = protobuf.load_message(buf, PrimitiveMessage)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
retr = load_message(buf, PrimitiveMessage)
|
|
|
|
|
|
|
|
|
|
assert msg == retr
|
|
|
|
|
assert retr.uvarint == 12345678910
|
|
|
|
@ -141,18 +149,15 @@ def test_validate_enum(caplog):
|
|
|
|
|
caplog.set_level(logging.INFO)
|
|
|
|
|
# round-trip of a valid value
|
|
|
|
|
msg = EnumMessageMoreValues(enum=0)
|
|
|
|
|
buf = BytesIO()
|
|
|
|
|
protobuf.dump_message(buf, msg)
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
retr = protobuf.load_message(buf, EnumMessageLessValues)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
retr = load_message(buf, EnumMessageLessValues)
|
|
|
|
|
assert retr.enum == msg.enum
|
|
|
|
|
|
|
|
|
|
assert not caplog.records
|
|
|
|
|
|
|
|
|
|
# dumping an invalid enum value fails
|
|
|
|
|
msg.enum = 19
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
protobuf.dump_message(buf, msg)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
|
|
|
|
|
assert len(caplog.records) == 1
|
|
|
|
|
record = caplog.records.pop(0)
|
|
|
|
@ -160,10 +165,8 @@ def test_validate_enum(caplog):
|
|
|
|
|
assert record.getMessage() == "Value 19 unknown for type t"
|
|
|
|
|
|
|
|
|
|
msg.enum = 3
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
protobuf.dump_message(buf, msg)
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
protobuf.load_message(buf, EnumMessageLessValues)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
load_message(buf, EnumMessageLessValues)
|
|
|
|
|
|
|
|
|
|
assert len(caplog.records) == 1
|
|
|
|
|
record = caplog.records.pop(0)
|
|
|
|
@ -175,10 +178,8 @@ def test_repeated():
|
|
|
|
|
msg = RepeatedFields(
|
|
|
|
|
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
|
|
|
|
|
)
|
|
|
|
|
buf = BytesIO()
|
|
|
|
|
protobuf.dump_message(buf, msg)
|
|
|
|
|
buf.seek(0)
|
|
|
|
|
retr = protobuf.load_message(buf, RepeatedFields)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
retr = load_message(buf, RepeatedFields)
|
|
|
|
|
|
|
|
|
|
assert retr == msg
|
|
|
|
|
|
|
|
|
@ -187,8 +188,7 @@ def test_enum_in_repeated(caplog):
|
|
|
|
|
caplog.set_level(logging.INFO)
|
|
|
|
|
|
|
|
|
|
msg = RepeatedFields(enumlist=[0, 1, 2, 3])
|
|
|
|
|
buf = BytesIO()
|
|
|
|
|
protobuf.dump_message(buf, msg)
|
|
|
|
|
dump_message(msg)
|
|
|
|
|
assert len(caplog.records) == 2
|
|
|
|
|
for record in caplog.records:
|
|
|
|
|
assert record.levelname == "INFO"
|
|
|
|
@ -202,8 +202,7 @@ def test_packed():
|
|
|
|
|
field_len = len(packed_values)
|
|
|
|
|
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
|
|
|
|
|
|
|
|
|
buf = BytesIO(message_bytes)
|
|
|
|
|
msg = protobuf.load_message(buf, RepeatedFields)
|
|
|
|
|
msg = load_message(message_bytes, RepeatedFields)
|
|
|
|
|
assert msg
|
|
|
|
|
assert msg.uintlist == values
|
|
|
|
|
assert not msg.enumlist
|
|
|
|
@ -217,9 +216,76 @@ def test_packed_enum():
|
|
|
|
|
field_len = len(packed_values)
|
|
|
|
|
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
|
|
|
|
|
|
|
|
|
buf = BytesIO(message_bytes)
|
|
|
|
|
msg = protobuf.load_message(buf, RepeatedFields)
|
|
|
|
|
msg = load_message(message_bytes, RepeatedFields)
|
|
|
|
|
assert msg
|
|
|
|
|
assert msg.enumlist == values
|
|
|
|
|
assert not msg.uintlist
|
|
|
|
|
assert not msg.strlist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RequiredFields(protobuf.MessageType):
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_fields(cls):
|
|
|
|
|
return {
|
|
|
|
|
1: ("uvarint", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
|
|
|
|
2: ("nested", PrimitiveMessage, protobuf.FLAG_REQUIRED),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_required():
|
|
|
|
|
msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
msg_ok = load_message(buf, RequiredFields)
|
|
|
|
|
|
|
|
|
|
assert msg_ok == msg
|
|
|
|
|
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
# cannot construct instance without the required fields
|
|
|
|
|
msg = RequiredFields(uvarint=3)
|
|
|
|
|
|
|
|
|
|
msg = RequiredFields(uvarint=3, nested=None)
|
|
|
|
|
# we can always encode an invalid message
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
# required field `nested` is also not sent
|
|
|
|
|
load_message(buf, RequiredFields)
|
|
|
|
|
|
|
|
|
|
msg = RequiredFields(uvarint=None, nested=PrimitiveMessage())
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
with pytest.raises(ValueError):
|
|
|
|
|
# required field `uvarint` is not sent
|
|
|
|
|
load_message(buf, RequiredFields)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DefaultFields(protobuf.MessageType):
|
|
|
|
|
@classmethod
|
|
|
|
|
def get_fields(cls):
|
|
|
|
|
return {
|
|
|
|
|
1: ("uvarint", protobuf.UVarintType, 42),
|
|
|
|
|
2: ("svarint", protobuf.SVarintType, -42),
|
|
|
|
|
3: ("bool", protobuf.BoolType, True),
|
|
|
|
|
4: ("bytes", protobuf.BytesType, b"hello"),
|
|
|
|
|
5: ("unicode", protobuf.UnicodeType, "hello"),
|
|
|
|
|
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 5),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_default():
|
|
|
|
|
# load empty message
|
|
|
|
|
retr = load_message(b"", DefaultFields)
|
|
|
|
|
assert retr.uvarint == 42
|
|
|
|
|
assert retr.svarint == -42
|
|
|
|
|
assert retr.bool is True
|
|
|
|
|
assert retr.bytes == b"hello"
|
|
|
|
|
assert retr.unicode == "hello"
|
|
|
|
|
assert retr.enum == 5
|
|
|
|
|
|
|
|
|
|
msg = DefaultFields(uvarint=0)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
retr = load_message(buf, DefaultFields)
|
|
|
|
|
assert retr.uvarint == 0
|
|
|
|
|
|
|
|
|
|
msg = DefaultFields(uvarint=None)
|
|
|
|
|
buf = dump_message(msg)
|
|
|
|
|
retr = load_message(buf, DefaultFields)
|
|
|
|
|
assert retr.uvarint == 42
|
|
|
|
|