mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-12-24 23:38:09 +00:00
feat!: implement protobuf required fields and default values
BREAKING CHANGE: this makes arguments to protobuf constructors keyword-only, and arguments corresponding to required fields are now mandatory
This commit is contained in:
parent
0799b64b29
commit
90ee5f3d38
@ -51,7 +51,11 @@ class ProtoField:
|
||||
type_name = attr.ib()
|
||||
proto_type = attr.ib()
|
||||
py_type = attr.ib()
|
||||
py_inner_type = attr.ib()
|
||||
default_value = attr.ib()
|
||||
|
||||
@property
|
||||
def optional(self):
|
||||
return not self.required and not self.repeated
|
||||
|
||||
@classmethod
|
||||
def from_field(cls, descriptor, field):
|
||||
@ -75,9 +79,18 @@ class ProtoField:
|
||||
"Unknown field type {} for field {}".format(field.type, field.name)
|
||||
) from None
|
||||
|
||||
py_inner_type = py_type
|
||||
if repeated:
|
||||
py_type = "List[{}]".format(py_type)
|
||||
if not field.HasField("default_value"):
|
||||
default_value = None
|
||||
elif field.type == field.TYPE_ENUM:
|
||||
default_value = str(descriptor.enum_types[type_name][field.default_value])
|
||||
elif field.type == field.TYPE_STRING:
|
||||
default_value = f'"{field.default_value}"'
|
||||
elif field.type == field.TYPE_BYTES:
|
||||
default_value = f'b"{field.default_value}"'
|
||||
elif field.type == field.TYPE_BOOL:
|
||||
default_value = "True" if field.default_value == "true" else "False"
|
||||
else:
|
||||
default_value = field.default_value
|
||||
|
||||
return cls(
|
||||
name=field.name,
|
||||
@ -88,7 +101,7 @@ class ProtoField:
|
||||
type_name=type_name,
|
||||
proto_type=proto_type,
|
||||
py_type=py_type,
|
||||
py_inner_type=py_inner_type,
|
||||
default_value=default_value,
|
||||
)
|
||||
|
||||
|
||||
@ -216,20 +229,26 @@ class Descriptor:
|
||||
yield self.create_message_import(name)
|
||||
|
||||
def create_init_method(self, fields):
|
||||
required_fields = [f for f in fields if f.required]
|
||||
repeated_fields = [f for f in fields if f.repeated]
|
||||
optional_fields = [f for f in fields if f.optional]
|
||||
# please keep the yields aligned
|
||||
# fmt: off
|
||||
... # https://github.com/ambv/black/issues/385
|
||||
yield " def __init__("
|
||||
yield " self,"
|
||||
for field in fields:
|
||||
yield f" {field.name}: {field.py_type} = None,"
|
||||
yield " def __init__("
|
||||
yield " self,"
|
||||
yield " *,"
|
||||
for field in required_fields:
|
||||
yield f" {field.name}: {field.py_type},"
|
||||
for field in repeated_fields:
|
||||
yield f" {field.name}: List[{field.py_type}] = None,"
|
||||
for field in optional_fields:
|
||||
yield f" {field.name}: {field.py_type} = {field.default_value},"
|
||||
yield " ) -> None:"
|
||||
|
||||
for field in fields:
|
||||
if field.repeated:
|
||||
yield f" self.{field.name} = {field.name} if {field.name} is not None else []"
|
||||
else:
|
||||
yield f" self.{field.name} = {field.name}"
|
||||
for field in repeated_fields:
|
||||
yield f" self.{field.name} = {field.name} if {field.name} is not None else []"
|
||||
for field in required_fields + optional_fields:
|
||||
yield f" self.{field.name} = {field.name}"
|
||||
# fmt: on
|
||||
|
||||
def create_fields_method(self, fields):
|
||||
@ -239,10 +258,8 @@ class Descriptor:
|
||||
yield " return {"
|
||||
for field in fields:
|
||||
comments = []
|
||||
if field.required:
|
||||
comments.append("required")
|
||||
if field.orig.HasField("default_value"):
|
||||
comments.append("default={}".format(field.orig.default_value))
|
||||
if field.default_value is not None:
|
||||
comments.append(f"default={field.orig.default_value}")
|
||||
|
||||
if comments:
|
||||
comment = " # " + " ".join(comments)
|
||||
@ -251,8 +268,10 @@ class Descriptor:
|
||||
|
||||
if field.repeated:
|
||||
flags = "p.FLAG_REPEATED"
|
||||
elif field.required:
|
||||
flags = "p.FLAG_REQUIRED"
|
||||
else:
|
||||
flags = "0"
|
||||
flags = field.default_value
|
||||
|
||||
yield " {num}: ('{name}', {type}, {flags}),{comment}".format(
|
||||
num=field.number,
|
||||
@ -288,7 +307,7 @@ class Descriptor:
|
||||
for field in all_enums:
|
||||
allowed_values = self.enum_types[field.type_name].values()
|
||||
valuestr = ", ".join(str(v) for v in sorted(allowed_values))
|
||||
yield " {} = Literal[{}]".format(field.py_inner_type, valuestr)
|
||||
yield " {} = Literal[{}]".format(field.py_type, valuestr)
|
||||
|
||||
yield " except ImportError:"
|
||||
yield " pass"
|
||||
@ -335,7 +354,9 @@ class Descriptor:
|
||||
|
||||
def process_messages(self, messages, include_deprecated=False):
|
||||
for message in sorted(messages, key=lambda m: m.name):
|
||||
self.write_to_file(message.name, self.process_message(message, include_deprecated))
|
||||
self.write_to_file(
|
||||
message.name, self.process_message(message, include_deprecated)
|
||||
)
|
||||
|
||||
def process_enums(self, enums):
|
||||
for enum in sorted(enums, key=lambda e: e.name):
|
||||
|
@ -3,8 +3,6 @@ Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
||||
bytes, string, embedded message and repeated fields.
|
||||
"""
|
||||
|
||||
from micropython import const
|
||||
|
||||
if False:
|
||||
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||
from typing_extensions import Protocol
|
||||
@ -178,7 +176,8 @@ class LimitedReader:
|
||||
return nread
|
||||
|
||||
|
||||
FLAG_REPEATED = const(1)
|
||||
FLAG_REPEATED = object()
|
||||
FLAG_REQUIRED = object()
|
||||
|
||||
if False:
|
||||
MessageTypeDef = Union[
|
||||
@ -190,7 +189,7 @@ if False:
|
||||
Type[UnicodeType],
|
||||
Type[MessageType],
|
||||
]
|
||||
FieldDef = Tuple[str, MessageTypeDef, int]
|
||||
FieldDef = Tuple[str, MessageTypeDef, Any]
|
||||
FieldDict = Dict[int, FieldDef]
|
||||
|
||||
FieldCache = Dict[Type[MessageType], FieldDict]
|
||||
@ -201,7 +200,6 @@ if False:
|
||||
def load_message(
|
||||
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
|
||||
) -> LoadedMessageType:
|
||||
|
||||
if field_cache is None:
|
||||
field_cache = {}
|
||||
fields = field_cache.get(msg_type)
|
||||
@ -209,7 +207,13 @@ def load_message(
|
||||
fields = msg_type.get_fields()
|
||||
field_cache[msg_type] = fields
|
||||
|
||||
msg = msg_type()
|
||||
# we need to avoid calling __init__, which enforces required arguments
|
||||
msg = object.__new__(msg_type) # type: LoadedMessageType
|
||||
# pre-seed the object with defaults
|
||||
for fname, _, fdefault in fields.values():
|
||||
if fdefault is FLAG_REPEATED:
|
||||
fdefault = []
|
||||
setattr(msg, fname, fdefault)
|
||||
|
||||
if False:
|
||||
SingularValue = Union[int, bool, bytearray, str, MessageType]
|
||||
@ -237,7 +241,7 @@ def load_message(
|
||||
raise ValueError
|
||||
continue
|
||||
|
||||
fname, ftype, fflags = field
|
||||
fname, ftype, fdefault = field
|
||||
if wtype != ftype.WIRE_TYPE:
|
||||
raise TypeError # parsed wire type differs from the schema
|
||||
|
||||
@ -263,17 +267,16 @@ def load_message(
|
||||
else:
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
if fflags & FLAG_REPEATED:
|
||||
pvalue = getattr(msg, fname, [])
|
||||
pvalue.append(fvalue)
|
||||
fvalue = pvalue
|
||||
setattr(msg, fname, fvalue)
|
||||
if fdefault is FLAG_REPEATED:
|
||||
getattr(msg, fname).append(fvalue)
|
||||
else:
|
||||
setattr(msg, fname, fvalue)
|
||||
|
||||
# fill missing fields
|
||||
for tag in fields:
|
||||
field = fields[tag]
|
||||
if not hasattr(msg, field[0]):
|
||||
setattr(msg, field[0], None)
|
||||
for fname, _, _ in fields.values():
|
||||
if getattr(msg, fname) is FLAG_REQUIRED:
|
||||
# The message is intended to be user-facing when decoding from wire,
|
||||
# but not when used internally.
|
||||
raise ValueError("Required field '{}' was not received".format(fname))
|
||||
|
||||
return msg
|
||||
|
||||
@ -291,7 +294,7 @@ def dump_message(
|
||||
field_cache[type(msg)] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
fname, ftype, fdefault = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
@ -299,7 +302,7 @@ def dump_message(
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if not fflags & FLAG_REPEATED:
|
||||
if fdefault is not FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
@ -356,7 +359,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
|
||||
field_cache[type(msg)] = fields
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
fname, ftype, fdefault = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
@ -364,7 +367,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if not fflags & FLAG_REPEATED:
|
||||
if fdefault is not FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
|
@ -5,8 +5,8 @@ from trezor.utils import BufferReader, BufferWriter
|
||||
|
||||
|
||||
class Message(protobuf.MessageType):
|
||||
def __init__(self, uint_field: int = 0, enum_field: int = 0) -> None:
|
||||
self.sint_field = uint_field
|
||||
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
|
||||
self.sint_field = sint_field
|
||||
self.enum_field = enum_field
|
||||
|
||||
@classmethod
|
||||
@ -17,6 +17,19 @@ class Message(protobuf.MessageType):
|
||||
}
|
||||
|
||||
|
||||
class MessageWithRequiredAndDefault(protobuf.MessageType):
|
||||
def __init__(self, required_field, default_field) -> None:
|
||||
self.required_field = required_field
|
||||
self.default_field = default_field
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||
2: ("default_field", protobuf.SVarintType, -1),
|
||||
}
|
||||
|
||||
|
||||
def load_uvarint(data: bytes) -> int:
|
||||
reader = BufferReader(data)
|
||||
return protobuf.load_uvarint(reader)
|
||||
@ -25,7 +38,18 @@ def load_uvarint(data: bytes) -> int:
|
||||
def dump_uvarint(value: int) -> bytearray:
|
||||
writer = BufferWriter(bytearray(16))
|
||||
protobuf.dump_uvarint(writer, value)
|
||||
return memoryview(writer.buffer)[:writer.offset]
|
||||
return memoryview(writer.buffer)[: writer.offset]
|
||||
|
||||
|
||||
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
||||
length = protobuf.count_message(msg)
|
||||
buffer = bytearray(length)
|
||||
protobuf.dump_message(BufferWriter(buffer), msg)
|
||||
return buffer
|
||||
|
||||
|
||||
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
|
||||
return protobuf.load_message(BufferReader(buffer), msg_type)
|
||||
|
||||
|
||||
class TestProtobuf(unittest.TestCase):
|
||||
@ -58,30 +82,47 @@ class TestProtobuf(unittest.TestCase):
|
||||
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
|
||||
)
|
||||
self.assertEqual(
|
||||
protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)), -2 ** 32
|
||||
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
|
||||
)
|
||||
|
||||
def test_validate_enum(self):
|
||||
# ok message:
|
||||
msg = Message(-42, 5)
|
||||
length = protobuf.count_message(msg)
|
||||
buffer_writer = BufferWriter(bytearray(length))
|
||||
protobuf.dump_message(buffer_writer, msg)
|
||||
|
||||
buffer_reader = BufferReader(buffer_writer.buffer)
|
||||
nmsg = protobuf.load_message(buffer_reader, Message)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(Message, msg_encoded)
|
||||
|
||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
||||
|
||||
# bad enum value:
|
||||
buffer_writer.seek(0)
|
||||
msg = Message(-42, 42)
|
||||
# XXX this assumes the message will have equal size
|
||||
protobuf.dump_message(buffer_writer, msg)
|
||||
buffer_reader.seek(0)
|
||||
msg_encoded = dump_message(msg)
|
||||
with self.assertRaises(TypeError):
|
||||
protobuf.load_message(buffer_reader, Message)
|
||||
load_message(Message, msg_encoded)
|
||||
|
||||
def test_required(self):
|
||||
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||
|
||||
self.assertEqual(nmsg.required_field, 1)
|
||||
self.assertEqual(nmsg.default_field, 2)
|
||||
|
||||
# try a message without the required_field
|
||||
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
|
||||
# encoding always succeeds
|
||||
msg_encoded = dump_message(msg)
|
||||
with self.assertRaises(ValueError):
|
||||
load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||
|
||||
# try a message without the default field
|
||||
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
|
||||
msg_encoded = dump_message(msg)
|
||||
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||
|
||||
self.assertEqual(nmsg.required_field, 1)
|
||||
self.assertEqual(nmsg.default_field, -1)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -49,7 +49,7 @@ FieldType = Union[
|
||||
Type["UnicodeType"],
|
||||
Type["BytesType"],
|
||||
]
|
||||
FieldInfo = Tuple[str, FieldType, int]
|
||||
FieldInfo = Tuple[str, FieldType, Any]
|
||||
MT = TypeVar("MT", bound="MessageType")
|
||||
|
||||
|
||||
@ -239,12 +239,14 @@ class MessageType:
|
||||
|
||||
def _fill_missing(self) -> None:
|
||||
# fill missing fields
|
||||
for fname, _, fflags in self.get_fields().values():
|
||||
for fname, _, fdefault in self.get_fields().values():
|
||||
if not hasattr(self, fname):
|
||||
if fflags & FLAG_REPEATED:
|
||||
if fdefault is FLAG_REPEATED:
|
||||
setattr(self, fname, [])
|
||||
elif fdefault is FLAG_REQUIRED:
|
||||
raise ValueError("value for required field is missing")
|
||||
else:
|
||||
setattr(self, fname, None)
|
||||
setattr(self, fname, fdefault)
|
||||
|
||||
def ByteSize(self) -> int:
|
||||
data = BytesIO()
|
||||
@ -276,7 +278,8 @@ class CountingWriter:
|
||||
return nwritten
|
||||
|
||||
|
||||
FLAG_REPEATED = 1
|
||||
FLAG_REPEATED = object()
|
||||
FLAG_REQUIRED = object()
|
||||
|
||||
|
||||
def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
|
||||
@ -325,7 +328,14 @@ def decode_length_delimited_field(
|
||||
|
||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
fields = msg_type.get_fields()
|
||||
msg = msg_type()
|
||||
|
||||
msg_dict = {}
|
||||
# pre-seed the dict
|
||||
for fname, _, fdefault in fields.values():
|
||||
if fdefault is FLAG_REPEATED:
|
||||
msg_dict[fname] = []
|
||||
elif fdefault is not FLAG_REQUIRED:
|
||||
msg_dict[fname] = fdefault
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -348,9 +358,9 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
raise ValueError
|
||||
continue
|
||||
|
||||
fname, ftype, fflags = field
|
||||
fname, ftype, fdefault = field
|
||||
|
||||
if wtype == 2 and ftype.WIRE_TYPE == 0 and fflags & FLAG_REPEATED:
|
||||
if wtype == 2 and ftype.WIRE_TYPE == 0 and fdefault is FLAG_REPEATED:
|
||||
# packed array
|
||||
fvalues = decode_packed_array_field(ftype, reader)
|
||||
|
||||
@ -366,18 +376,17 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||
else:
|
||||
raise TypeError # unknown wire type
|
||||
|
||||
if fflags & FLAG_REPEATED:
|
||||
pvalue = getattr(msg, fname)
|
||||
pvalue.extend(fvalues)
|
||||
fvalue = pvalue
|
||||
if fdefault is FLAG_REPEATED:
|
||||
msg_dict[fname].extend(fvalues)
|
||||
elif len(fvalues) != 1:
|
||||
raise ValueError("Unexpected multiple values in non-repeating field")
|
||||
else:
|
||||
fvalue = fvalues[0]
|
||||
msg_dict[fname] = fvalues[0]
|
||||
|
||||
setattr(msg, fname, fvalue)
|
||||
|
||||
return msg
|
||||
for fname, _, fdefault in fields.values():
|
||||
if fdefault is FLAG_REQUIRED and fname not in msg_dict:
|
||||
raise ValueError # required field was not received
|
||||
return msg_type(**msg_dict)
|
||||
|
||||
|
||||
def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
@ -386,7 +395,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
fields = mtype.get_fields()
|
||||
|
||||
for ftag in fields:
|
||||
fname, ftype, fflags = fields[ftag]
|
||||
fname, ftype, fdefault = fields[ftag]
|
||||
|
||||
fvalue = getattr(msg, fname, None)
|
||||
if fvalue is None:
|
||||
@ -394,7 +403,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
|
||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||
|
||||
if not fflags & FLAG_REPEATED:
|
||||
if fdefault is not FLAG_REPEATED:
|
||||
repvalue[0] = fvalue
|
||||
fvalue = repvalue
|
||||
|
||||
@ -529,8 +538,8 @@ def value_to_proto(ftype: FieldType, value: Any) -> Any:
|
||||
|
||||
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||
params = {}
|
||||
for fname, ftype, fflags in message_type.get_fields().values():
|
||||
repeated = fflags & FLAG_REPEATED
|
||||
for fname, ftype, fdefault in message_type.get_fields().values():
|
||||
repeated = fdefault is FLAG_REPEATED
|
||||
value = d.get(fname)
|
||||
if value is None:
|
||||
continue
|
||||
|
@ -26,25 +26,25 @@ class PrimitiveMessage(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("uvarint", protobuf.UVarintType, 0),
|
||||
2: ("svarint", protobuf.SVarintType, 0),
|
||||
3: ("bool", protobuf.BoolType, 0),
|
||||
4: ("bytes", protobuf.BytesType, 0),
|
||||
5: ("unicode", protobuf.UnicodeType, 0),
|
||||
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 0),
|
||||
1: ("uvarint", protobuf.UVarintType, None),
|
||||
2: ("svarint", protobuf.SVarintType, None),
|
||||
3: ("bool", protobuf.BoolType, None),
|
||||
4: ("bytes", protobuf.BytesType, None),
|
||||
5: ("unicode", protobuf.UnicodeType, None),
|
||||
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), None),
|
||||
}
|
||||
|
||||
|
||||
class EnumMessageMoreValues(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), 0)}
|
||||
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), None)}
|
||||
|
||||
|
||||
class EnumMessageLessValues(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {1: ("enum", protobuf.EnumType("t", (0, 5)), 0)}
|
||||
return {1: ("enum", protobuf.EnumType("t", (0, 5)), None)}
|
||||
|
||||
|
||||
class RepeatedFields(protobuf.MessageType):
|
||||
@ -68,6 +68,17 @@ def dump_uvarint(value):
|
||||
return writer.getvalue()
|
||||
|
||||
|
||||
def load_message(buffer, msg_type):
|
||||
reader = BytesIO(buffer)
|
||||
return protobuf.load_message(reader, msg_type)
|
||||
|
||||
|
||||
def dump_message(msg):
|
||||
writer = BytesIO()
|
||||
protobuf.dump_message(writer, msg)
|
||||
return writer.getvalue()
|
||||
|
||||
|
||||
def test_dump_uvarint():
|
||||
assert dump_uvarint(0) == b"\x00"
|
||||
assert dump_uvarint(1) == b"\x01"
|
||||
@ -109,7 +120,7 @@ def test_sint_uint():
|
||||
|
||||
# roundtrip:
|
||||
assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011
|
||||
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)) == -2 ** 32
|
||||
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))) == -(2 ** 32)
|
||||
|
||||
|
||||
def test_simple_message():
|
||||
@ -122,11 +133,8 @@ def test_simple_message():
|
||||
enum=5,
|
||||
)
|
||||
|
||||
buf = BytesIO()
|
||||
|
||||
protobuf.dump_message(buf, msg)
|
||||
buf.seek(0)
|
||||
retr = protobuf.load_message(buf, PrimitiveMessage)
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, PrimitiveMessage)
|
||||
|
||||
assert msg == retr
|
||||
assert retr.uvarint == 12345678910
|
||||
@ -141,18 +149,15 @@ def test_validate_enum(caplog):
|
||||
caplog.set_level(logging.INFO)
|
||||
# round-trip of a valid value
|
||||
msg = EnumMessageMoreValues(enum=0)
|
||||
buf = BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
buf.seek(0)
|
||||
retr = protobuf.load_message(buf, EnumMessageLessValues)
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, EnumMessageLessValues)
|
||||
assert retr.enum == msg.enum
|
||||
|
||||
assert not caplog.records
|
||||
|
||||
# dumping an invalid enum value fails
|
||||
msg.enum = 19
|
||||
buf.seek(0)
|
||||
protobuf.dump_message(buf, msg)
|
||||
buf = dump_message(msg)
|
||||
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records.pop(0)
|
||||
@ -160,10 +165,8 @@ def test_validate_enum(caplog):
|
||||
assert record.getMessage() == "Value 19 unknown for type t"
|
||||
|
||||
msg.enum = 3
|
||||
buf.seek(0)
|
||||
protobuf.dump_message(buf, msg)
|
||||
buf.seek(0)
|
||||
protobuf.load_message(buf, EnumMessageLessValues)
|
||||
buf = dump_message(msg)
|
||||
load_message(buf, EnumMessageLessValues)
|
||||
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records.pop(0)
|
||||
@ -175,10 +178,8 @@ def test_repeated():
|
||||
msg = RepeatedFields(
|
||||
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
|
||||
)
|
||||
buf = BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
buf.seek(0)
|
||||
retr = protobuf.load_message(buf, RepeatedFields)
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, RepeatedFields)
|
||||
|
||||
assert retr == msg
|
||||
|
||||
@ -187,8 +188,7 @@ def test_enum_in_repeated(caplog):
|
||||
caplog.set_level(logging.INFO)
|
||||
|
||||
msg = RepeatedFields(enumlist=[0, 1, 2, 3])
|
||||
buf = BytesIO()
|
||||
protobuf.dump_message(buf, msg)
|
||||
dump_message(msg)
|
||||
assert len(caplog.records) == 2
|
||||
for record in caplog.records:
|
||||
assert record.levelname == "INFO"
|
||||
@ -202,8 +202,7 @@ def test_packed():
|
||||
field_len = len(packed_values)
|
||||
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
||||
|
||||
buf = BytesIO(message_bytes)
|
||||
msg = protobuf.load_message(buf, RepeatedFields)
|
||||
msg = load_message(message_bytes, RepeatedFields)
|
||||
assert msg
|
||||
assert msg.uintlist == values
|
||||
assert not msg.enumlist
|
||||
@ -217,9 +216,76 @@ def test_packed_enum():
|
||||
field_len = len(packed_values)
|
||||
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
||||
|
||||
buf = BytesIO(message_bytes)
|
||||
msg = protobuf.load_message(buf, RepeatedFields)
|
||||
msg = load_message(message_bytes, RepeatedFields)
|
||||
assert msg
|
||||
assert msg.enumlist == values
|
||||
assert not msg.uintlist
|
||||
assert not msg.strlist
|
||||
|
||||
|
||||
class RequiredFields(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("uvarint", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||
2: ("nested", PrimitiveMessage, protobuf.FLAG_REQUIRED),
|
||||
}
|
||||
|
||||
|
||||
def test_required():
|
||||
msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
|
||||
buf = dump_message(msg)
|
||||
msg_ok = load_message(buf, RequiredFields)
|
||||
|
||||
assert msg_ok == msg
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# cannot construct instance without the required fields
|
||||
msg = RequiredFields(uvarint=3)
|
||||
|
||||
msg = RequiredFields(uvarint=3, nested=None)
|
||||
# we can always encode an invalid message
|
||||
buf = dump_message(msg)
|
||||
with pytest.raises(ValueError):
|
||||
# required field `nested` is also not sent
|
||||
load_message(buf, RequiredFields)
|
||||
|
||||
msg = RequiredFields(uvarint=None, nested=PrimitiveMessage())
|
||||
buf = dump_message(msg)
|
||||
with pytest.raises(ValueError):
|
||||
# required field `uvarint` is not sent
|
||||
load_message(buf, RequiredFields)
|
||||
|
||||
|
||||
class DefaultFields(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("uvarint", protobuf.UVarintType, 42),
|
||||
2: ("svarint", protobuf.SVarintType, -42),
|
||||
3: ("bool", protobuf.BoolType, True),
|
||||
4: ("bytes", protobuf.BytesType, b"hello"),
|
||||
5: ("unicode", protobuf.UnicodeType, "hello"),
|
||||
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 5),
|
||||
}
|
||||
|
||||
|
||||
def test_default():
|
||||
# load empty message
|
||||
retr = load_message(b"", DefaultFields)
|
||||
assert retr.uvarint == 42
|
||||
assert retr.svarint == -42
|
||||
assert retr.bool is True
|
||||
assert retr.bytes == b"hello"
|
||||
assert retr.unicode == "hello"
|
||||
assert retr.enum == 5
|
||||
|
||||
msg = DefaultFields(uvarint=0)
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, DefaultFields)
|
||||
assert retr.uvarint == 0
|
||||
|
||||
msg = DefaultFields(uvarint=None)
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, DefaultFields)
|
||||
assert retr.uvarint == 42
|
||||
|
@ -31,12 +31,12 @@ class SimpleMessage(protobuf.MessageType):
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
return {
|
||||
1: ("uvarint", protobuf.UVarintType, 0),
|
||||
2: ("svarint", protobuf.SVarintType, 0),
|
||||
3: ("bool", protobuf.BoolType, 0),
|
||||
4: ("bytes", protobuf.BytesType, 0),
|
||||
5: ("unicode", protobuf.UnicodeType, 0),
|
||||
6: ("enum", SimpleEnumType, 0),
|
||||
1: ("uvarint", protobuf.UVarintType, None),
|
||||
2: ("svarint", protobuf.SVarintType, None),
|
||||
3: ("bool", protobuf.BoolType, None),
|
||||
4: ("bytes", protobuf.BytesType, None),
|
||||
5: ("unicode", protobuf.UnicodeType, None),
|
||||
6: ("enum", SimpleEnumType, None),
|
||||
7: ("rep_int", protobuf.UVarintType, protobuf.FLAG_REPEATED),
|
||||
8: ("rep_str", protobuf.UnicodeType, protobuf.FLAG_REPEATED),
|
||||
9: ("rep_enum", SimpleEnumType, protobuf.FLAG_REPEATED),
|
||||
|
Loading…
Reference in New Issue
Block a user