mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-22 07:28:10 +00:00
feat!: implement protobuf required fields and default values
BREAKING CHANGE: this makes arguments to protobuf constructors keyword-only, and arguments corresponding to required fields are now mandatory
This commit is contained in:
parent
0799b64b29
commit
90ee5f3d38
@ -51,7 +51,11 @@ class ProtoField:
|
|||||||
type_name = attr.ib()
|
type_name = attr.ib()
|
||||||
proto_type = attr.ib()
|
proto_type = attr.ib()
|
||||||
py_type = attr.ib()
|
py_type = attr.ib()
|
||||||
py_inner_type = attr.ib()
|
default_value = attr.ib()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def optional(self):
|
||||||
|
return not self.required and not self.repeated
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_field(cls, descriptor, field):
|
def from_field(cls, descriptor, field):
|
||||||
@ -75,9 +79,18 @@ class ProtoField:
|
|||||||
"Unknown field type {} for field {}".format(field.type, field.name)
|
"Unknown field type {} for field {}".format(field.type, field.name)
|
||||||
) from None
|
) from None
|
||||||
|
|
||||||
py_inner_type = py_type
|
if not field.HasField("default_value"):
|
||||||
if repeated:
|
default_value = None
|
||||||
py_type = "List[{}]".format(py_type)
|
elif field.type == field.TYPE_ENUM:
|
||||||
|
default_value = str(descriptor.enum_types[type_name][field.default_value])
|
||||||
|
elif field.type == field.TYPE_STRING:
|
||||||
|
default_value = f'"{field.default_value}"'
|
||||||
|
elif field.type == field.TYPE_BYTES:
|
||||||
|
default_value = f'b"{field.default_value}"'
|
||||||
|
elif field.type == field.TYPE_BOOL:
|
||||||
|
default_value = "True" if field.default_value == "true" else "False"
|
||||||
|
else:
|
||||||
|
default_value = field.default_value
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
name=field.name,
|
name=field.name,
|
||||||
@ -88,7 +101,7 @@ class ProtoField:
|
|||||||
type_name=type_name,
|
type_name=type_name,
|
||||||
proto_type=proto_type,
|
proto_type=proto_type,
|
||||||
py_type=py_type,
|
py_type=py_type,
|
||||||
py_inner_type=py_inner_type,
|
default_value=default_value,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -216,20 +229,26 @@ class Descriptor:
|
|||||||
yield self.create_message_import(name)
|
yield self.create_message_import(name)
|
||||||
|
|
||||||
def create_init_method(self, fields):
|
def create_init_method(self, fields):
|
||||||
|
required_fields = [f for f in fields if f.required]
|
||||||
|
repeated_fields = [f for f in fields if f.repeated]
|
||||||
|
optional_fields = [f for f in fields if f.optional]
|
||||||
# please keep the yields aligned
|
# please keep the yields aligned
|
||||||
# fmt: off
|
# fmt: off
|
||||||
... # https://github.com/ambv/black/issues/385
|
yield " def __init__("
|
||||||
yield " def __init__("
|
yield " self,"
|
||||||
yield " self,"
|
yield " *,"
|
||||||
for field in fields:
|
for field in required_fields:
|
||||||
yield f" {field.name}: {field.py_type} = None,"
|
yield f" {field.name}: {field.py_type},"
|
||||||
|
for field in repeated_fields:
|
||||||
|
yield f" {field.name}: List[{field.py_type}] = None,"
|
||||||
|
for field in optional_fields:
|
||||||
|
yield f" {field.name}: {field.py_type} = {field.default_value},"
|
||||||
yield " ) -> None:"
|
yield " ) -> None:"
|
||||||
|
|
||||||
for field in fields:
|
for field in repeated_fields:
|
||||||
if field.repeated:
|
yield f" self.{field.name} = {field.name} if {field.name} is not None else []"
|
||||||
yield f" self.{field.name} = {field.name} if {field.name} is not None else []"
|
for field in required_fields + optional_fields:
|
||||||
else:
|
yield f" self.{field.name} = {field.name}"
|
||||||
yield f" self.{field.name} = {field.name}"
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def create_fields_method(self, fields):
|
def create_fields_method(self, fields):
|
||||||
@ -239,10 +258,8 @@ class Descriptor:
|
|||||||
yield " return {"
|
yield " return {"
|
||||||
for field in fields:
|
for field in fields:
|
||||||
comments = []
|
comments = []
|
||||||
if field.required:
|
if field.default_value is not None:
|
||||||
comments.append("required")
|
comments.append(f"default={field.orig.default_value}")
|
||||||
if field.orig.HasField("default_value"):
|
|
||||||
comments.append("default={}".format(field.orig.default_value))
|
|
||||||
|
|
||||||
if comments:
|
if comments:
|
||||||
comment = " # " + " ".join(comments)
|
comment = " # " + " ".join(comments)
|
||||||
@ -251,8 +268,10 @@ class Descriptor:
|
|||||||
|
|
||||||
if field.repeated:
|
if field.repeated:
|
||||||
flags = "p.FLAG_REPEATED"
|
flags = "p.FLAG_REPEATED"
|
||||||
|
elif field.required:
|
||||||
|
flags = "p.FLAG_REQUIRED"
|
||||||
else:
|
else:
|
||||||
flags = "0"
|
flags = field.default_value
|
||||||
|
|
||||||
yield " {num}: ('{name}', {type}, {flags}),{comment}".format(
|
yield " {num}: ('{name}', {type}, {flags}),{comment}".format(
|
||||||
num=field.number,
|
num=field.number,
|
||||||
@ -288,7 +307,7 @@ class Descriptor:
|
|||||||
for field in all_enums:
|
for field in all_enums:
|
||||||
allowed_values = self.enum_types[field.type_name].values()
|
allowed_values = self.enum_types[field.type_name].values()
|
||||||
valuestr = ", ".join(str(v) for v in sorted(allowed_values))
|
valuestr = ", ".join(str(v) for v in sorted(allowed_values))
|
||||||
yield " {} = Literal[{}]".format(field.py_inner_type, valuestr)
|
yield " {} = Literal[{}]".format(field.py_type, valuestr)
|
||||||
|
|
||||||
yield " except ImportError:"
|
yield " except ImportError:"
|
||||||
yield " pass"
|
yield " pass"
|
||||||
@ -335,7 +354,9 @@ class Descriptor:
|
|||||||
|
|
||||||
def process_messages(self, messages, include_deprecated=False):
|
def process_messages(self, messages, include_deprecated=False):
|
||||||
for message in sorted(messages, key=lambda m: m.name):
|
for message in sorted(messages, key=lambda m: m.name):
|
||||||
self.write_to_file(message.name, self.process_message(message, include_deprecated))
|
self.write_to_file(
|
||||||
|
message.name, self.process_message(message, include_deprecated)
|
||||||
|
)
|
||||||
|
|
||||||
def process_enums(self, enums):
|
def process_enums(self, enums):
|
||||||
for enum in sorted(enums, key=lambda e: e.name):
|
for enum in sorted(enums, key=lambda e: e.name):
|
||||||
|
@ -3,8 +3,6 @@ Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
|
|||||||
bytes, string, embedded message and repeated fields.
|
bytes, string, embedded message and repeated fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from micropython import const
|
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
@ -178,7 +176,8 @@ class LimitedReader:
|
|||||||
return nread
|
return nread
|
||||||
|
|
||||||
|
|
||||||
FLAG_REPEATED = const(1)
|
FLAG_REPEATED = object()
|
||||||
|
FLAG_REQUIRED = object()
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
MessageTypeDef = Union[
|
MessageTypeDef = Union[
|
||||||
@ -190,7 +189,7 @@ if False:
|
|||||||
Type[UnicodeType],
|
Type[UnicodeType],
|
||||||
Type[MessageType],
|
Type[MessageType],
|
||||||
]
|
]
|
||||||
FieldDef = Tuple[str, MessageTypeDef, int]
|
FieldDef = Tuple[str, MessageTypeDef, Any]
|
||||||
FieldDict = Dict[int, FieldDef]
|
FieldDict = Dict[int, FieldDef]
|
||||||
|
|
||||||
FieldCache = Dict[Type[MessageType], FieldDict]
|
FieldCache = Dict[Type[MessageType], FieldDict]
|
||||||
@ -201,7 +200,6 @@ if False:
|
|||||||
def load_message(
|
def load_message(
|
||||||
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
|
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
|
||||||
) -> LoadedMessageType:
|
) -> LoadedMessageType:
|
||||||
|
|
||||||
if field_cache is None:
|
if field_cache is None:
|
||||||
field_cache = {}
|
field_cache = {}
|
||||||
fields = field_cache.get(msg_type)
|
fields = field_cache.get(msg_type)
|
||||||
@ -209,7 +207,13 @@ def load_message(
|
|||||||
fields = msg_type.get_fields()
|
fields = msg_type.get_fields()
|
||||||
field_cache[msg_type] = fields
|
field_cache[msg_type] = fields
|
||||||
|
|
||||||
msg = msg_type()
|
# we need to avoid calling __init__, which enforces required arguments
|
||||||
|
msg = object.__new__(msg_type) # type: LoadedMessageType
|
||||||
|
# pre-seed the object with defaults
|
||||||
|
for fname, _, fdefault in fields.values():
|
||||||
|
if fdefault is FLAG_REPEATED:
|
||||||
|
fdefault = []
|
||||||
|
setattr(msg, fname, fdefault)
|
||||||
|
|
||||||
if False:
|
if False:
|
||||||
SingularValue = Union[int, bool, bytearray, str, MessageType]
|
SingularValue = Union[int, bool, bytearray, str, MessageType]
|
||||||
@ -237,7 +241,7 @@ def load_message(
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fname, ftype, fflags = field
|
fname, ftype, fdefault = field
|
||||||
if wtype != ftype.WIRE_TYPE:
|
if wtype != ftype.WIRE_TYPE:
|
||||||
raise TypeError # parsed wire type differs from the schema
|
raise TypeError # parsed wire type differs from the schema
|
||||||
|
|
||||||
@ -263,17 +267,16 @@ def load_message(
|
|||||||
else:
|
else:
|
||||||
raise TypeError # field type is unknown
|
raise TypeError # field type is unknown
|
||||||
|
|
||||||
if fflags & FLAG_REPEATED:
|
if fdefault is FLAG_REPEATED:
|
||||||
pvalue = getattr(msg, fname, [])
|
getattr(msg, fname).append(fvalue)
|
||||||
pvalue.append(fvalue)
|
else:
|
||||||
fvalue = pvalue
|
setattr(msg, fname, fvalue)
|
||||||
setattr(msg, fname, fvalue)
|
|
||||||
|
|
||||||
# fill missing fields
|
for fname, _, _ in fields.values():
|
||||||
for tag in fields:
|
if getattr(msg, fname) is FLAG_REQUIRED:
|
||||||
field = fields[tag]
|
# The message is intended to be user-facing when decoding from wire,
|
||||||
if not hasattr(msg, field[0]):
|
# but not when used internally.
|
||||||
setattr(msg, field[0], None)
|
raise ValueError("Required field '{}' was not received".format(fname))
|
||||||
|
|
||||||
return msg
|
return msg
|
||||||
|
|
||||||
@ -291,7 +294,7 @@ def dump_message(
|
|||||||
field_cache[type(msg)] = fields
|
field_cache[type(msg)] = fields
|
||||||
|
|
||||||
for ftag in fields:
|
for ftag in fields:
|
||||||
fname, ftype, fflags = fields[ftag]
|
fname, ftype, fdefault = fields[ftag]
|
||||||
|
|
||||||
fvalue = getattr(msg, fname, None)
|
fvalue = getattr(msg, fname, None)
|
||||||
if fvalue is None:
|
if fvalue is None:
|
||||||
@ -299,7 +302,7 @@ def dump_message(
|
|||||||
|
|
||||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||||
|
|
||||||
if not fflags & FLAG_REPEATED:
|
if fdefault is not FLAG_REPEATED:
|
||||||
repvalue[0] = fvalue
|
repvalue[0] = fvalue
|
||||||
fvalue = repvalue
|
fvalue = repvalue
|
||||||
|
|
||||||
@ -356,7 +359,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
|
|||||||
field_cache[type(msg)] = fields
|
field_cache[type(msg)] = fields
|
||||||
|
|
||||||
for ftag in fields:
|
for ftag in fields:
|
||||||
fname, ftype, fflags = fields[ftag]
|
fname, ftype, fdefault = fields[ftag]
|
||||||
|
|
||||||
fvalue = getattr(msg, fname, None)
|
fvalue = getattr(msg, fname, None)
|
||||||
if fvalue is None:
|
if fvalue is None:
|
||||||
@ -364,7 +367,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
|
|||||||
|
|
||||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||||
|
|
||||||
if not fflags & FLAG_REPEATED:
|
if fdefault is not FLAG_REPEATED:
|
||||||
repvalue[0] = fvalue
|
repvalue[0] = fvalue
|
||||||
fvalue = repvalue
|
fvalue = repvalue
|
||||||
|
|
||||||
|
@ -5,8 +5,8 @@ from trezor.utils import BufferReader, BufferWriter
|
|||||||
|
|
||||||
|
|
||||||
class Message(protobuf.MessageType):
|
class Message(protobuf.MessageType):
|
||||||
def __init__(self, uint_field: int = 0, enum_field: int = 0) -> None:
|
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
|
||||||
self.sint_field = uint_field
|
self.sint_field = sint_field
|
||||||
self.enum_field = enum_field
|
self.enum_field = enum_field
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -17,6 +17,19 @@ class Message(protobuf.MessageType):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class MessageWithRequiredAndDefault(protobuf.MessageType):
|
||||||
|
def __init__(self, required_field, default_field) -> None:
|
||||||
|
self.required_field = required_field
|
||||||
|
self.default_field = default_field
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls):
|
||||||
|
return {
|
||||||
|
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||||
|
2: ("default_field", protobuf.SVarintType, -1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def load_uvarint(data: bytes) -> int:
|
def load_uvarint(data: bytes) -> int:
|
||||||
reader = BufferReader(data)
|
reader = BufferReader(data)
|
||||||
return protobuf.load_uvarint(reader)
|
return protobuf.load_uvarint(reader)
|
||||||
@ -25,7 +38,18 @@ def load_uvarint(data: bytes) -> int:
|
|||||||
def dump_uvarint(value: int) -> bytearray:
|
def dump_uvarint(value: int) -> bytearray:
|
||||||
writer = BufferWriter(bytearray(16))
|
writer = BufferWriter(bytearray(16))
|
||||||
protobuf.dump_uvarint(writer, value)
|
protobuf.dump_uvarint(writer, value)
|
||||||
return memoryview(writer.buffer)[:writer.offset]
|
return memoryview(writer.buffer)[: writer.offset]
|
||||||
|
|
||||||
|
|
||||||
|
def dump_message(msg: protobuf.MessageType) -> bytearray:
|
||||||
|
length = protobuf.count_message(msg)
|
||||||
|
buffer = bytearray(length)
|
||||||
|
protobuf.dump_message(BufferWriter(buffer), msg)
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
|
||||||
|
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
|
||||||
|
return protobuf.load_message(BufferReader(buffer), msg_type)
|
||||||
|
|
||||||
|
|
||||||
class TestProtobuf(unittest.TestCase):
|
class TestProtobuf(unittest.TestCase):
|
||||||
@ -58,30 +82,47 @@ class TestProtobuf(unittest.TestCase):
|
|||||||
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
|
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)), -2 ** 32
|
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_validate_enum(self):
|
def test_validate_enum(self):
|
||||||
# ok message:
|
# ok message:
|
||||||
msg = Message(-42, 5)
|
msg = Message(-42, 5)
|
||||||
length = protobuf.count_message(msg)
|
msg_encoded = dump_message(msg)
|
||||||
buffer_writer = BufferWriter(bytearray(length))
|
nmsg = load_message(Message, msg_encoded)
|
||||||
protobuf.dump_message(buffer_writer, msg)
|
|
||||||
|
|
||||||
buffer_reader = BufferReader(buffer_writer.buffer)
|
|
||||||
nmsg = protobuf.load_message(buffer_reader, Message)
|
|
||||||
|
|
||||||
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
self.assertEqual(msg.sint_field, nmsg.sint_field)
|
||||||
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
self.assertEqual(msg.enum_field, nmsg.enum_field)
|
||||||
|
|
||||||
# bad enum value:
|
# bad enum value:
|
||||||
buffer_writer.seek(0)
|
|
||||||
msg = Message(-42, 42)
|
msg = Message(-42, 42)
|
||||||
# XXX this assumes the message will have equal size
|
msg_encoded = dump_message(msg)
|
||||||
protobuf.dump_message(buffer_writer, msg)
|
|
||||||
buffer_reader.seek(0)
|
|
||||||
with self.assertRaises(TypeError):
|
with self.assertRaises(TypeError):
|
||||||
protobuf.load_message(buffer_reader, Message)
|
load_message(Message, msg_encoded)
|
||||||
|
|
||||||
|
def test_required(self):
|
||||||
|
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||||
|
|
||||||
|
self.assertEqual(nmsg.required_field, 1)
|
||||||
|
self.assertEqual(nmsg.default_field, 2)
|
||||||
|
|
||||||
|
# try a message without the required_field
|
||||||
|
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
|
||||||
|
# encoding always succeeds
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||||
|
|
||||||
|
# try a message without the default field
|
||||||
|
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
|
||||||
|
msg_encoded = dump_message(msg)
|
||||||
|
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
|
||||||
|
|
||||||
|
self.assertEqual(nmsg.required_field, 1)
|
||||||
|
self.assertEqual(nmsg.default_field, -1)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -49,7 +49,7 @@ FieldType = Union[
|
|||||||
Type["UnicodeType"],
|
Type["UnicodeType"],
|
||||||
Type["BytesType"],
|
Type["BytesType"],
|
||||||
]
|
]
|
||||||
FieldInfo = Tuple[str, FieldType, int]
|
FieldInfo = Tuple[str, FieldType, Any]
|
||||||
MT = TypeVar("MT", bound="MessageType")
|
MT = TypeVar("MT", bound="MessageType")
|
||||||
|
|
||||||
|
|
||||||
@ -239,12 +239,14 @@ class MessageType:
|
|||||||
|
|
||||||
def _fill_missing(self) -> None:
|
def _fill_missing(self) -> None:
|
||||||
# fill missing fields
|
# fill missing fields
|
||||||
for fname, _, fflags in self.get_fields().values():
|
for fname, _, fdefault in self.get_fields().values():
|
||||||
if not hasattr(self, fname):
|
if not hasattr(self, fname):
|
||||||
if fflags & FLAG_REPEATED:
|
if fdefault is FLAG_REPEATED:
|
||||||
setattr(self, fname, [])
|
setattr(self, fname, [])
|
||||||
|
elif fdefault is FLAG_REQUIRED:
|
||||||
|
raise ValueError("value for required field is missing")
|
||||||
else:
|
else:
|
||||||
setattr(self, fname, None)
|
setattr(self, fname, fdefault)
|
||||||
|
|
||||||
def ByteSize(self) -> int:
|
def ByteSize(self) -> int:
|
||||||
data = BytesIO()
|
data = BytesIO()
|
||||||
@ -276,7 +278,8 @@ class CountingWriter:
|
|||||||
return nwritten
|
return nwritten
|
||||||
|
|
||||||
|
|
||||||
FLAG_REPEATED = 1
|
FLAG_REPEATED = object()
|
||||||
|
FLAG_REQUIRED = object()
|
||||||
|
|
||||||
|
|
||||||
def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
|
def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
|
||||||
@ -325,7 +328,14 @@ def decode_length_delimited_field(
|
|||||||
|
|
||||||
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
||||||
fields = msg_type.get_fields()
|
fields = msg_type.get_fields()
|
||||||
msg = msg_type()
|
|
||||||
|
msg_dict = {}
|
||||||
|
# pre-seed the dict
|
||||||
|
for fname, _, fdefault in fields.values():
|
||||||
|
if fdefault is FLAG_REPEATED:
|
||||||
|
msg_dict[fname] = []
|
||||||
|
elif fdefault is not FLAG_REQUIRED:
|
||||||
|
msg_dict[fname] = fdefault
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -348,9 +358,9 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
|||||||
raise ValueError
|
raise ValueError
|
||||||
continue
|
continue
|
||||||
|
|
||||||
fname, ftype, fflags = field
|
fname, ftype, fdefault = field
|
||||||
|
|
||||||
if wtype == 2 and ftype.WIRE_TYPE == 0 and fflags & FLAG_REPEATED:
|
if wtype == 2 and ftype.WIRE_TYPE == 0 and fdefault is FLAG_REPEATED:
|
||||||
# packed array
|
# packed array
|
||||||
fvalues = decode_packed_array_field(ftype, reader)
|
fvalues = decode_packed_array_field(ftype, reader)
|
||||||
|
|
||||||
@ -366,18 +376,17 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
|
|||||||
else:
|
else:
|
||||||
raise TypeError # unknown wire type
|
raise TypeError # unknown wire type
|
||||||
|
|
||||||
if fflags & FLAG_REPEATED:
|
if fdefault is FLAG_REPEATED:
|
||||||
pvalue = getattr(msg, fname)
|
msg_dict[fname].extend(fvalues)
|
||||||
pvalue.extend(fvalues)
|
|
||||||
fvalue = pvalue
|
|
||||||
elif len(fvalues) != 1:
|
elif len(fvalues) != 1:
|
||||||
raise ValueError("Unexpected multiple values in non-repeating field")
|
raise ValueError("Unexpected multiple values in non-repeating field")
|
||||||
else:
|
else:
|
||||||
fvalue = fvalues[0]
|
msg_dict[fname] = fvalues[0]
|
||||||
|
|
||||||
setattr(msg, fname, fvalue)
|
for fname, _, fdefault in fields.values():
|
||||||
|
if fdefault is FLAG_REQUIRED and fname not in msg_dict:
|
||||||
return msg
|
raise ValueError # required field was not received
|
||||||
|
return msg_type(**msg_dict)
|
||||||
|
|
||||||
|
|
||||||
def dump_message(writer: Writer, msg: MessageType) -> None:
|
def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||||
@ -386,7 +395,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|||||||
fields = mtype.get_fields()
|
fields = mtype.get_fields()
|
||||||
|
|
||||||
for ftag in fields:
|
for ftag in fields:
|
||||||
fname, ftype, fflags = fields[ftag]
|
fname, ftype, fdefault = fields[ftag]
|
||||||
|
|
||||||
fvalue = getattr(msg, fname, None)
|
fvalue = getattr(msg, fname, None)
|
||||||
if fvalue is None:
|
if fvalue is None:
|
||||||
@ -394,7 +403,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
|||||||
|
|
||||||
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
fkey = (ftag << 3) | ftype.WIRE_TYPE
|
||||||
|
|
||||||
if not fflags & FLAG_REPEATED:
|
if fdefault is not FLAG_REPEATED:
|
||||||
repvalue[0] = fvalue
|
repvalue[0] = fvalue
|
||||||
fvalue = repvalue
|
fvalue = repvalue
|
||||||
|
|
||||||
@ -529,8 +538,8 @@ def value_to_proto(ftype: FieldType, value: Any) -> Any:
|
|||||||
|
|
||||||
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||||
params = {}
|
params = {}
|
||||||
for fname, ftype, fflags in message_type.get_fields().values():
|
for fname, ftype, fdefault in message_type.get_fields().values():
|
||||||
repeated = fflags & FLAG_REPEATED
|
repeated = fdefault is FLAG_REPEATED
|
||||||
value = d.get(fname)
|
value = d.get(fname)
|
||||||
if value is None:
|
if value is None:
|
||||||
continue
|
continue
|
||||||
|
@ -26,25 +26,25 @@ class PrimitiveMessage(protobuf.MessageType):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls):
|
def get_fields(cls):
|
||||||
return {
|
return {
|
||||||
1: ("uvarint", protobuf.UVarintType, 0),
|
1: ("uvarint", protobuf.UVarintType, None),
|
||||||
2: ("svarint", protobuf.SVarintType, 0),
|
2: ("svarint", protobuf.SVarintType, None),
|
||||||
3: ("bool", protobuf.BoolType, 0),
|
3: ("bool", protobuf.BoolType, None),
|
||||||
4: ("bytes", protobuf.BytesType, 0),
|
4: ("bytes", protobuf.BytesType, None),
|
||||||
5: ("unicode", protobuf.UnicodeType, 0),
|
5: ("unicode", protobuf.UnicodeType, None),
|
||||||
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 0),
|
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), None),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class EnumMessageMoreValues(protobuf.MessageType):
|
class EnumMessageMoreValues(protobuf.MessageType):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls):
|
def get_fields(cls):
|
||||||
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), 0)}
|
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), None)}
|
||||||
|
|
||||||
|
|
||||||
class EnumMessageLessValues(protobuf.MessageType):
|
class EnumMessageLessValues(protobuf.MessageType):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls):
|
def get_fields(cls):
|
||||||
return {1: ("enum", protobuf.EnumType("t", (0, 5)), 0)}
|
return {1: ("enum", protobuf.EnumType("t", (0, 5)), None)}
|
||||||
|
|
||||||
|
|
||||||
class RepeatedFields(protobuf.MessageType):
|
class RepeatedFields(protobuf.MessageType):
|
||||||
@ -68,6 +68,17 @@ def dump_uvarint(value):
|
|||||||
return writer.getvalue()
|
return writer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
|
def load_message(buffer, msg_type):
|
||||||
|
reader = BytesIO(buffer)
|
||||||
|
return protobuf.load_message(reader, msg_type)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_message(msg):
|
||||||
|
writer = BytesIO()
|
||||||
|
protobuf.dump_message(writer, msg)
|
||||||
|
return writer.getvalue()
|
||||||
|
|
||||||
|
|
||||||
def test_dump_uvarint():
|
def test_dump_uvarint():
|
||||||
assert dump_uvarint(0) == b"\x00"
|
assert dump_uvarint(0) == b"\x00"
|
||||||
assert dump_uvarint(1) == b"\x01"
|
assert dump_uvarint(1) == b"\x01"
|
||||||
@ -109,7 +120,7 @@ def test_sint_uint():
|
|||||||
|
|
||||||
# roundtrip:
|
# roundtrip:
|
||||||
assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011
|
assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011
|
||||||
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)) == -2 ** 32
|
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))) == -(2 ** 32)
|
||||||
|
|
||||||
|
|
||||||
def test_simple_message():
|
def test_simple_message():
|
||||||
@ -122,11 +133,8 @@ def test_simple_message():
|
|||||||
enum=5,
|
enum=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
buf = BytesIO()
|
buf = dump_message(msg)
|
||||||
|
retr = load_message(buf, PrimitiveMessage)
|
||||||
protobuf.dump_message(buf, msg)
|
|
||||||
buf.seek(0)
|
|
||||||
retr = protobuf.load_message(buf, PrimitiveMessage)
|
|
||||||
|
|
||||||
assert msg == retr
|
assert msg == retr
|
||||||
assert retr.uvarint == 12345678910
|
assert retr.uvarint == 12345678910
|
||||||
@ -141,18 +149,15 @@ def test_validate_enum(caplog):
|
|||||||
caplog.set_level(logging.INFO)
|
caplog.set_level(logging.INFO)
|
||||||
# round-trip of a valid value
|
# round-trip of a valid value
|
||||||
msg = EnumMessageMoreValues(enum=0)
|
msg = EnumMessageMoreValues(enum=0)
|
||||||
buf = BytesIO()
|
buf = dump_message(msg)
|
||||||
protobuf.dump_message(buf, msg)
|
retr = load_message(buf, EnumMessageLessValues)
|
||||||
buf.seek(0)
|
|
||||||
retr = protobuf.load_message(buf, EnumMessageLessValues)
|
|
||||||
assert retr.enum == msg.enum
|
assert retr.enum == msg.enum
|
||||||
|
|
||||||
assert not caplog.records
|
assert not caplog.records
|
||||||
|
|
||||||
# dumping an invalid enum value fails
|
# dumping an invalid enum value fails
|
||||||
msg.enum = 19
|
msg.enum = 19
|
||||||
buf.seek(0)
|
buf = dump_message(msg)
|
||||||
protobuf.dump_message(buf, msg)
|
|
||||||
|
|
||||||
assert len(caplog.records) == 1
|
assert len(caplog.records) == 1
|
||||||
record = caplog.records.pop(0)
|
record = caplog.records.pop(0)
|
||||||
@ -160,10 +165,8 @@ def test_validate_enum(caplog):
|
|||||||
assert record.getMessage() == "Value 19 unknown for type t"
|
assert record.getMessage() == "Value 19 unknown for type t"
|
||||||
|
|
||||||
msg.enum = 3
|
msg.enum = 3
|
||||||
buf.seek(0)
|
buf = dump_message(msg)
|
||||||
protobuf.dump_message(buf, msg)
|
load_message(buf, EnumMessageLessValues)
|
||||||
buf.seek(0)
|
|
||||||
protobuf.load_message(buf, EnumMessageLessValues)
|
|
||||||
|
|
||||||
assert len(caplog.records) == 1
|
assert len(caplog.records) == 1
|
||||||
record = caplog.records.pop(0)
|
record = caplog.records.pop(0)
|
||||||
@ -175,10 +178,8 @@ def test_repeated():
|
|||||||
msg = RepeatedFields(
|
msg = RepeatedFields(
|
||||||
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
|
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
|
||||||
)
|
)
|
||||||
buf = BytesIO()
|
buf = dump_message(msg)
|
||||||
protobuf.dump_message(buf, msg)
|
retr = load_message(buf, RepeatedFields)
|
||||||
buf.seek(0)
|
|
||||||
retr = protobuf.load_message(buf, RepeatedFields)
|
|
||||||
|
|
||||||
assert retr == msg
|
assert retr == msg
|
||||||
|
|
||||||
@ -187,8 +188,7 @@ def test_enum_in_repeated(caplog):
|
|||||||
caplog.set_level(logging.INFO)
|
caplog.set_level(logging.INFO)
|
||||||
|
|
||||||
msg = RepeatedFields(enumlist=[0, 1, 2, 3])
|
msg = RepeatedFields(enumlist=[0, 1, 2, 3])
|
||||||
buf = BytesIO()
|
dump_message(msg)
|
||||||
protobuf.dump_message(buf, msg)
|
|
||||||
assert len(caplog.records) == 2
|
assert len(caplog.records) == 2
|
||||||
for record in caplog.records:
|
for record in caplog.records:
|
||||||
assert record.levelname == "INFO"
|
assert record.levelname == "INFO"
|
||||||
@ -202,8 +202,7 @@ def test_packed():
|
|||||||
field_len = len(packed_values)
|
field_len = len(packed_values)
|
||||||
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
||||||
|
|
||||||
buf = BytesIO(message_bytes)
|
msg = load_message(message_bytes, RepeatedFields)
|
||||||
msg = protobuf.load_message(buf, RepeatedFields)
|
|
||||||
assert msg
|
assert msg
|
||||||
assert msg.uintlist == values
|
assert msg.uintlist == values
|
||||||
assert not msg.enumlist
|
assert not msg.enumlist
|
||||||
@ -217,9 +216,76 @@ def test_packed_enum():
|
|||||||
field_len = len(packed_values)
|
field_len = len(packed_values)
|
||||||
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
|
||||||
|
|
||||||
buf = BytesIO(message_bytes)
|
msg = load_message(message_bytes, RepeatedFields)
|
||||||
msg = protobuf.load_message(buf, RepeatedFields)
|
|
||||||
assert msg
|
assert msg
|
||||||
assert msg.enumlist == values
|
assert msg.enumlist == values
|
||||||
assert not msg.uintlist
|
assert not msg.uintlist
|
||||||
assert not msg.strlist
|
assert not msg.strlist
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredFields(protobuf.MessageType):
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls):
|
||||||
|
return {
|
||||||
|
1: ("uvarint", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
|
||||||
|
2: ("nested", PrimitiveMessage, protobuf.FLAG_REQUIRED),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_required():
|
||||||
|
msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
|
||||||
|
buf = dump_message(msg)
|
||||||
|
msg_ok = load_message(buf, RequiredFields)
|
||||||
|
|
||||||
|
assert msg_ok == msg
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# cannot construct instance without the required fields
|
||||||
|
msg = RequiredFields(uvarint=3)
|
||||||
|
|
||||||
|
msg = RequiredFields(uvarint=3, nested=None)
|
||||||
|
# we can always encode an invalid message
|
||||||
|
buf = dump_message(msg)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# required field `nested` is also not sent
|
||||||
|
load_message(buf, RequiredFields)
|
||||||
|
|
||||||
|
msg = RequiredFields(uvarint=None, nested=PrimitiveMessage())
|
||||||
|
buf = dump_message(msg)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# required field `uvarint` is not sent
|
||||||
|
load_message(buf, RequiredFields)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultFields(protobuf.MessageType):
|
||||||
|
@classmethod
|
||||||
|
def get_fields(cls):
|
||||||
|
return {
|
||||||
|
1: ("uvarint", protobuf.UVarintType, 42),
|
||||||
|
2: ("svarint", protobuf.SVarintType, -42),
|
||||||
|
3: ("bool", protobuf.BoolType, True),
|
||||||
|
4: ("bytes", protobuf.BytesType, b"hello"),
|
||||||
|
5: ("unicode", protobuf.UnicodeType, "hello"),
|
||||||
|
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 5),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_default():
|
||||||
|
# load empty message
|
||||||
|
retr = load_message(b"", DefaultFields)
|
||||||
|
assert retr.uvarint == 42
|
||||||
|
assert retr.svarint == -42
|
||||||
|
assert retr.bool is True
|
||||||
|
assert retr.bytes == b"hello"
|
||||||
|
assert retr.unicode == "hello"
|
||||||
|
assert retr.enum == 5
|
||||||
|
|
||||||
|
msg = DefaultFields(uvarint=0)
|
||||||
|
buf = dump_message(msg)
|
||||||
|
retr = load_message(buf, DefaultFields)
|
||||||
|
assert retr.uvarint == 0
|
||||||
|
|
||||||
|
msg = DefaultFields(uvarint=None)
|
||||||
|
buf = dump_message(msg)
|
||||||
|
retr = load_message(buf, DefaultFields)
|
||||||
|
assert retr.uvarint == 42
|
||||||
|
@ -31,12 +31,12 @@ class SimpleMessage(protobuf.MessageType):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_fields(cls):
|
def get_fields(cls):
|
||||||
return {
|
return {
|
||||||
1: ("uvarint", protobuf.UVarintType, 0),
|
1: ("uvarint", protobuf.UVarintType, None),
|
||||||
2: ("svarint", protobuf.SVarintType, 0),
|
2: ("svarint", protobuf.SVarintType, None),
|
||||||
3: ("bool", protobuf.BoolType, 0),
|
3: ("bool", protobuf.BoolType, None),
|
||||||
4: ("bytes", protobuf.BytesType, 0),
|
4: ("bytes", protobuf.BytesType, None),
|
||||||
5: ("unicode", protobuf.UnicodeType, 0),
|
5: ("unicode", protobuf.UnicodeType, None),
|
||||||
6: ("enum", SimpleEnumType, 0),
|
6: ("enum", SimpleEnumType, None),
|
||||||
7: ("rep_int", protobuf.UVarintType, protobuf.FLAG_REPEATED),
|
7: ("rep_int", protobuf.UVarintType, protobuf.FLAG_REPEATED),
|
||||||
8: ("rep_str", protobuf.UnicodeType, protobuf.FLAG_REPEATED),
|
8: ("rep_str", protobuf.UnicodeType, protobuf.FLAG_REPEATED),
|
||||||
9: ("rep_enum", SimpleEnumType, protobuf.FLAG_REPEATED),
|
9: ("rep_enum", SimpleEnumType, protobuf.FLAG_REPEATED),
|
||||||
|
Loading…
Reference in New Issue
Block a user