From 90ee5f3d38a89f30edf37e87b8545ce2f2f65ebe Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 14 Sep 2020 12:47:06 +0200 Subject: [PATCH] feat!: implement protobuf required fields and default values BREAKING CHANGE: this makes arguments to protobuf constructors keyword-only, and arguments corresponding to required fields are now mandatory --- common/protob/pb2py | 65 ++++++++---- core/src/protobuf.py | 45 +++++---- core/tests/test_protobuf.py | 71 ++++++++++--- python/src/trezorlib/protobuf.py | 49 +++++---- python/tests/test_protobuf_encoding.py | 134 ++++++++++++++++++------- python/tests/test_protobuf_misc.py | 12 +-- 6 files changed, 258 insertions(+), 118 deletions(-) diff --git a/common/protob/pb2py b/common/protob/pb2py index bfc333b7fd..a21b90bbf2 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -51,7 +51,11 @@ class ProtoField: type_name = attr.ib() proto_type = attr.ib() py_type = attr.ib() - py_inner_type = attr.ib() + default_value = attr.ib() + + @property + def optional(self): + return not self.required and not self.repeated @classmethod def from_field(cls, descriptor, field): @@ -75,9 +79,18 @@ class ProtoField: "Unknown field type {} for field {}".format(field.type, field.name) ) from None - py_inner_type = py_type - if repeated: - py_type = "List[{}]".format(py_type) + if not field.HasField("default_value"): + default_value = None + elif field.type == field.TYPE_ENUM: + default_value = str(descriptor.enum_types[type_name][field.default_value]) + elif field.type == field.TYPE_STRING: + default_value = f'"{field.default_value}"' + elif field.type == field.TYPE_BYTES: + default_value = f'b"{field.default_value}"' + elif field.type == field.TYPE_BOOL: + default_value = "True" if field.default_value == "true" else "False" + else: + default_value = field.default_value return cls( name=field.name, @@ -88,7 +101,7 @@ class ProtoField: type_name=type_name, proto_type=proto_type, py_type=py_type, - py_inner_type=py_inner_type, + default_value=default_value, ) @@ -216,20 +229,26 @@ class Descriptor: yield self.create_message_import(name) def create_init_method(self, fields): + required_fields = [f for f in fields if f.required] + repeated_fields = [f for f in fields if f.repeated] + optional_fields = [f for f in fields if f.optional] # please keep the yields aligned # fmt: off - ... # https://github.com/ambv/black/issues/385 - yield " def __init__(" - yield " self," - for field in fields: - yield f" {field.name}: {field.py_type} = None," + yield " def __init__(" + yield " self," + yield " *," + for field in required_fields: + yield f" {field.name}: {field.py_type}," + for field in repeated_fields: + yield f" {field.name}: List[{field.py_type}] = None," + for field in optional_fields: + yield f" {field.name}: {field.py_type} = {field.default_value}," yield " ) -> None:" - for field in fields: - if field.repeated: - yield f" self.{field.name} = {field.name} if {field.name} is not None else []" - else: - yield f" self.{field.name} = {field.name}" + for field in repeated_fields: + yield f" self.{field.name} = {field.name} if {field.name} is not None else []" + for field in required_fields + optional_fields: + yield f" self.{field.name} = {field.name}" # fmt: on def create_fields_method(self, fields): @@ -239,10 +258,8 @@ class Descriptor: yield " return {" for field in fields: comments = [] - if field.required: - comments.append("required") - if field.orig.HasField("default_value"): - comments.append("default={}".format(field.orig.default_value)) + if field.default_value is not None: + comments.append(f"default={field.orig.default_value}") if comments: comment = " # " + " ".join(comments) @@ -251,8 +268,10 @@ class Descriptor: if field.repeated: flags = "p.FLAG_REPEATED" + elif field.required: + flags = "p.FLAG_REQUIRED" else: - flags = "0" + flags = field.default_value yield " {num}: ('{name}', {type}, {flags}),{comment}".format( num=field.number, @@ -288,7 +307,7 @@ class Descriptor: for field in all_enums: allowed_values = self.enum_types[field.type_name].values() valuestr = ", ".join(str(v) for v in sorted(allowed_values)) - yield " {} = Literal[{}]".format(field.py_inner_type, valuestr) + yield " {} = Literal[{}]".format(field.py_type, valuestr) yield " except ImportError:" yield " pass" @@ -335,7 +354,9 @@ class Descriptor: def process_messages(self, messages, include_deprecated=False): for message in sorted(messages, key=lambda m: m.name): - self.write_to_file(message.name, self.process_message(message, include_deprecated)) + self.write_to_file( + message.name, self.process_message(message, include_deprecated) + ) def process_enums(self, enums): for enum in sorted(enums, key=lambda e: e.name): diff --git a/core/src/protobuf.py b/core/src/protobuf.py index 40c93dc6cb..b55176f433 100644 --- a/core/src/protobuf.py +++ b/core/src/protobuf.py @@ -3,8 +3,6 @@ Extremely minimal streaming codec for a subset of protobuf. Supports uint32, bytes, string, embedded message and repeated fields. """ -from micropython import const - if False: from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union from typing_extensions import Protocol @@ -178,7 +176,8 @@ class LimitedReader: return nread -FLAG_REPEATED = const(1) +FLAG_REPEATED = object() +FLAG_REQUIRED = object() if False: MessageTypeDef = Union[ @@ -190,7 +189,7 @@ if False: Type[UnicodeType], Type[MessageType], ] - FieldDef = Tuple[str, MessageTypeDef, int] + FieldDef = Tuple[str, MessageTypeDef, Any] FieldDict = Dict[int, FieldDef] FieldCache = Dict[Type[MessageType], FieldDict] @@ -201,7 +200,6 @@ if False: def load_message( reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None ) -> LoadedMessageType: - if field_cache is None: field_cache = {} fields = field_cache.get(msg_type) @@ -209,7 +207,13 @@ def load_message( fields = msg_type.get_fields() field_cache[msg_type] = fields - msg = msg_type() + # we need to avoid calling __init__, which enforces required arguments + msg = object.__new__(msg_type) # type: LoadedMessageType + # pre-seed the object with defaults + for fname, _, fdefault in fields.values(): + if fdefault is FLAG_REPEATED: + fdefault = [] + setattr(msg, fname, fdefault) if False: SingularValue = Union[int, bool, bytearray, str, MessageType] @@ -237,7 +241,7 @@ def load_message( raise ValueError continue - fname, ftype, fflags = field + fname, ftype, fdefault = field if wtype != ftype.WIRE_TYPE: raise TypeError # parsed wire type differs from the schema @@ -263,17 +267,16 @@ def load_message( else: raise TypeError # field type is unknown - if fflags & FLAG_REPEATED: - pvalue = getattr(msg, fname, []) - pvalue.append(fvalue) - fvalue = pvalue - setattr(msg, fname, fvalue) + if fdefault is FLAG_REPEATED: + getattr(msg, fname).append(fvalue) + else: + setattr(msg, fname, fvalue) - # fill missing fields - for tag in fields: - field = fields[tag] - if not hasattr(msg, field[0]): - setattr(msg, field[0], None) + for fname, _, _ in fields.values(): + if getattr(msg, fname) is FLAG_REQUIRED: + # The message is intended to be user-facing when decoding from wire, + # but not when used internally. + raise ValueError("Required field '{}' was not received".format(fname)) return msg @@ -291,7 +294,7 @@ def dump_message( field_cache[type(msg)] = fields for ftag in fields: - fname, ftype, fflags = fields[ftag] + fname, ftype, fdefault = fields[ftag] fvalue = getattr(msg, fname, None) if fvalue is None: @@ -299,7 +302,7 @@ def dump_message( fkey = (ftag << 3) | ftype.WIRE_TYPE - if not fflags & FLAG_REPEATED: + if fdefault is not FLAG_REPEATED: repvalue[0] = fvalue fvalue = repvalue @@ -356,7 +359,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int: field_cache[type(msg)] = fields for ftag in fields: - fname, ftype, fflags = fields[ftag] + fname, ftype, fdefault = fields[ftag] fvalue = getattr(msg, fname, None) if fvalue is None: @@ -364,7 +367,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int: fkey = (ftag << 3) | ftype.WIRE_TYPE - if not fflags & FLAG_REPEATED: + if fdefault is not FLAG_REPEATED: repvalue[0] = fvalue fvalue = repvalue diff --git a/core/tests/test_protobuf.py b/core/tests/test_protobuf.py index 3089469cb0..180e6eca19 100644 --- a/core/tests/test_protobuf.py +++ b/core/tests/test_protobuf.py @@ -5,8 +5,8 @@ from trezor.utils import BufferReader, BufferWriter class Message(protobuf.MessageType): - def __init__(self, uint_field: int = 0, enum_field: int = 0) -> None: - self.sint_field = uint_field + def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None: + self.sint_field = sint_field self.enum_field = enum_field @classmethod @@ -17,6 +17,19 @@ class Message(protobuf.MessageType): } +class MessageWithRequiredAndDefault(protobuf.MessageType): + def __init__(self, required_field, default_field) -> None: + self.required_field = required_field + self.default_field = default_field + + @classmethod + def get_fields(cls): + return { + 1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED), + 2: ("default_field", protobuf.SVarintType, -1), + } + + def load_uvarint(data: bytes) -> int: reader = BufferReader(data) return protobuf.load_uvarint(reader) @@ -25,7 +38,18 @@ def load_uvarint(data: bytes) -> int: def dump_uvarint(value: int) -> bytearray: writer = BufferWriter(bytearray(16)) protobuf.dump_uvarint(writer, value) - return memoryview(writer.buffer)[:writer.offset] + return memoryview(writer.buffer)[: writer.offset] + + +def dump_message(msg: protobuf.MessageType) -> bytearray: + length = protobuf.count_message(msg) + buffer = bytearray(length) + protobuf.dump_message(BufferWriter(buffer), msg) + return buffer + + +def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType: + return protobuf.load_message(BufferReader(buffer), msg_type) class TestProtobuf(unittest.TestCase): @@ -58,30 +82,47 @@ class TestProtobuf(unittest.TestCase): protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011 ) self.assertEqual( - protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)), -2 ** 32 + protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32) ) def test_validate_enum(self): # ok message: msg = Message(-42, 5) - length = protobuf.count_message(msg) - buffer_writer = BufferWriter(bytearray(length)) - protobuf.dump_message(buffer_writer, msg) - - buffer_reader = BufferReader(buffer_writer.buffer) - nmsg = protobuf.load_message(buffer_reader, Message) + msg_encoded = dump_message(msg) + nmsg = load_message(Message, msg_encoded) self.assertEqual(msg.sint_field, nmsg.sint_field) self.assertEqual(msg.enum_field, nmsg.enum_field) # bad enum value: - buffer_writer.seek(0) msg = Message(-42, 42) - # XXX this assumes the message will have equal size - protobuf.dump_message(buffer_writer, msg) - buffer_reader.seek(0) + msg_encoded = dump_message(msg) with self.assertRaises(TypeError): - protobuf.load_message(buffer_reader, Message) + load_message(Message, msg_encoded) + + def test_required(self): + msg = MessageWithRequiredAndDefault(required_field=1, default_field=2) + msg_encoded = dump_message(msg) + nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded) + + self.assertEqual(nmsg.required_field, 1) + self.assertEqual(nmsg.default_field, 2) + + # try a message without the required_field + msg = MessageWithRequiredAndDefault(required_field=None, default_field=2) + # encoding always succeeds + msg_encoded = dump_message(msg) + with self.assertRaises(ValueError): + load_message(MessageWithRequiredAndDefault, msg_encoded) + + # try a message without the default field + msg = MessageWithRequiredAndDefault(required_field=1, default_field=None) + msg_encoded = dump_message(msg) + nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded) + + self.assertEqual(nmsg.required_field, 1) + self.assertEqual(nmsg.default_field, -1) + if __name__ == "__main__": diff --git a/python/src/trezorlib/protobuf.py b/python/src/trezorlib/protobuf.py index f9e49bd5e5..c388063315 100644 --- a/python/src/trezorlib/protobuf.py +++ b/python/src/trezorlib/protobuf.py @@ -49,7 +49,7 @@ FieldType = Union[ Type["UnicodeType"], Type["BytesType"], ] -FieldInfo = Tuple[str, FieldType, int] +FieldInfo = Tuple[str, FieldType, Any] MT = TypeVar("MT", bound="MessageType") @@ -239,12 +239,14 @@ class MessageType: def _fill_missing(self) -> None: # fill missing fields - for fname, _, fflags in self.get_fields().values(): + for fname, _, fdefault in self.get_fields().values(): if not hasattr(self, fname): - if fflags & FLAG_REPEATED: + if fdefault is FLAG_REPEATED: setattr(self, fname, []) + elif fdefault is FLAG_REQUIRED: + raise ValueError("value for required field is missing") else: - setattr(self, fname, None) + setattr(self, fname, fdefault) def ByteSize(self) -> int: data = BytesIO() @@ -276,7 +278,8 @@ class CountingWriter: return nwritten -FLAG_REPEATED = 1 +FLAG_REPEATED = object() +FLAG_REQUIRED = object() def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]: @@ -325,7 +328,14 @@ def decode_length_delimited_field( def load_message(reader: Reader, msg_type: Type[MT]) -> MT: fields = msg_type.get_fields() - msg = msg_type() + + msg_dict = {} + # pre-seed the dict + for fname, _, fdefault in fields.values(): + if fdefault is FLAG_REPEATED: + msg_dict[fname] = [] + elif fdefault is not FLAG_REQUIRED: + msg_dict[fname] = fdefault while True: try: @@ -348,9 +358,9 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: raise ValueError continue - fname, ftype, fflags = field + fname, ftype, fdefault = field - if wtype == 2 and ftype.WIRE_TYPE == 0 and fflags & FLAG_REPEATED: + if wtype == 2 and ftype.WIRE_TYPE == 0 and fdefault is FLAG_REPEATED: # packed array fvalues = decode_packed_array_field(ftype, reader) @@ -366,18 +376,17 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT: else: raise TypeError # unknown wire type - if fflags & FLAG_REPEATED: - pvalue = getattr(msg, fname) - pvalue.extend(fvalues) - fvalue = pvalue + if fdefault is FLAG_REPEATED: + msg_dict[fname].extend(fvalues) elif len(fvalues) != 1: raise ValueError("Unexpected multiple values in non-repeating field") else: - fvalue = fvalues[0] + msg_dict[fname] = fvalues[0] - setattr(msg, fname, fvalue) - - return msg + for fname, _, fdefault in fields.values(): + if fdefault is FLAG_REQUIRED and fname not in msg_dict: + raise ValueError # required field was not received + return msg_type(**msg_dict) def dump_message(writer: Writer, msg: MessageType) -> None: @@ -386,7 +395,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None: fields = mtype.get_fields() for ftag in fields: - fname, ftype, fflags = fields[ftag] + fname, ftype, fdefault = fields[ftag] fvalue = getattr(msg, fname, None) if fvalue is None: @@ -394,7 +403,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None: fkey = (ftag << 3) | ftype.WIRE_TYPE - if not fflags & FLAG_REPEATED: + if fdefault is not FLAG_REPEATED: repvalue[0] = fvalue fvalue = repvalue @@ -529,8 +538,8 @@ def value_to_proto(ftype: FieldType, value: Any) -> Any: def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT: params = {} - for fname, ftype, fflags in message_type.get_fields().values(): - repeated = fflags & FLAG_REPEATED + for fname, ftype, fdefault in message_type.get_fields().values(): + repeated = fdefault is FLAG_REPEATED value = d.get(fname) if value is None: continue diff --git a/python/tests/test_protobuf_encoding.py b/python/tests/test_protobuf_encoding.py index 7d7709f28e..7a3bb80b86 100644 --- a/python/tests/test_protobuf_encoding.py +++ b/python/tests/test_protobuf_encoding.py @@ -26,25 +26,25 @@ class PrimitiveMessage(protobuf.MessageType): @classmethod def get_fields(cls): return { - 1: ("uvarint", protobuf.UVarintType, 0), - 2: ("svarint", protobuf.SVarintType, 0), - 3: ("bool", protobuf.BoolType, 0), - 4: ("bytes", protobuf.BytesType, 0), - 5: ("unicode", protobuf.UnicodeType, 0), - 6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 0), + 1: ("uvarint", protobuf.UVarintType, None), + 2: ("svarint", protobuf.SVarintType, None), + 3: ("bool", protobuf.BoolType, None), + 4: ("bytes", protobuf.BytesType, None), + 5: ("unicode", protobuf.UnicodeType, None), + 6: ("enum", protobuf.EnumType("t", (0, 5, 25)), None), } class EnumMessageMoreValues(protobuf.MessageType): @classmethod def get_fields(cls): - return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), 0)} + return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), None)} class EnumMessageLessValues(protobuf.MessageType): @classmethod def get_fields(cls): - return {1: ("enum", protobuf.EnumType("t", (0, 5)), 0)} + return {1: ("enum", protobuf.EnumType("t", (0, 5)), None)} class RepeatedFields(protobuf.MessageType): @@ -68,6 +68,17 @@ def dump_uvarint(value): return writer.getvalue() +def load_message(buffer, msg_type): + reader = BytesIO(buffer) + return protobuf.load_message(reader, msg_type) + + +def dump_message(msg): + writer = BytesIO() + protobuf.dump_message(writer, msg) + return writer.getvalue() + + def test_dump_uvarint(): assert dump_uvarint(0) == b"\x00" assert dump_uvarint(1) == b"\x01" @@ -109,7 +120,7 @@ def test_sint_uint(): # roundtrip: assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011 - assert protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)) == -2 ** 32 + assert protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))) == -(2 ** 32) def test_simple_message(): @@ -122,11 +133,8 @@ def test_simple_message(): enum=5, ) - buf = BytesIO() - - protobuf.dump_message(buf, msg) - buf.seek(0) - retr = protobuf.load_message(buf, PrimitiveMessage) + buf = dump_message(msg) + retr = load_message(buf, PrimitiveMessage) assert msg == retr assert retr.uvarint == 12345678910 @@ -141,18 +149,15 @@ def test_validate_enum(caplog): caplog.set_level(logging.INFO) # round-trip of a valid value msg = EnumMessageMoreValues(enum=0) - buf = BytesIO() - protobuf.dump_message(buf, msg) - buf.seek(0) - retr = protobuf.load_message(buf, EnumMessageLessValues) + buf = dump_message(msg) + retr = load_message(buf, EnumMessageLessValues) assert retr.enum == msg.enum assert not caplog.records # dumping an invalid enum value fails msg.enum = 19 - buf.seek(0) - protobuf.dump_message(buf, msg) + buf = dump_message(msg) assert len(caplog.records) == 1 record = caplog.records.pop(0) @@ -160,10 +165,8 @@ def test_validate_enum(caplog): assert record.getMessage() == "Value 19 unknown for type t" msg.enum = 3 - buf.seek(0) - protobuf.dump_message(buf, msg) - buf.seek(0) - protobuf.load_message(buf, EnumMessageLessValues) + buf = dump_message(msg) + load_message(buf, EnumMessageLessValues) assert len(caplog.records) == 1 record = caplog.records.pop(0) @@ -175,10 +178,8 @@ def test_repeated(): msg = RepeatedFields( uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"] ) - buf = BytesIO() - protobuf.dump_message(buf, msg) - buf.seek(0) - retr = protobuf.load_message(buf, RepeatedFields) + buf = dump_message(msg) + retr = load_message(buf, RepeatedFields) assert retr == msg @@ -187,8 +188,7 @@ def test_enum_in_repeated(caplog): caplog.set_level(logging.INFO) msg = RepeatedFields(enumlist=[0, 1, 2, 3]) - buf = BytesIO() - protobuf.dump_message(buf, msg) + dump_message(msg) assert len(caplog.records) == 2 for record in caplog.records: assert record.levelname == "INFO" @@ -202,8 +202,7 @@ def test_packed(): field_len = len(packed_values) message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values - buf = BytesIO(message_bytes) - msg = protobuf.load_message(buf, RepeatedFields) + msg = load_message(message_bytes, RepeatedFields) assert msg assert msg.uintlist == values assert not msg.enumlist @@ -217,9 +216,76 @@ def test_packed_enum(): field_len = len(packed_values) message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values - buf = BytesIO(message_bytes) - msg = protobuf.load_message(buf, RepeatedFields) + msg = load_message(message_bytes, RepeatedFields) assert msg assert msg.enumlist == values assert not msg.uintlist assert not msg.strlist + + +class RequiredFields(protobuf.MessageType): + @classmethod + def get_fields(cls): + return { + 1: ("uvarint", protobuf.UVarintType, protobuf.FLAG_REQUIRED), + 2: ("nested", PrimitiveMessage, protobuf.FLAG_REQUIRED), + } + + +def test_required(): + msg = RequiredFields(uvarint=3, nested=PrimitiveMessage()) + buf = dump_message(msg) + msg_ok = load_message(buf, RequiredFields) + + assert msg_ok == msg + + with pytest.raises(ValueError): + # cannot construct instance without the required fields + msg = RequiredFields(uvarint=3) + + msg = RequiredFields(uvarint=3, nested=None) + # we can always encode an invalid message + buf = dump_message(msg) + with pytest.raises(ValueError): + # required field `nested` is also not sent + load_message(buf, RequiredFields) + + msg = RequiredFields(uvarint=None, nested=PrimitiveMessage()) + buf = dump_message(msg) + with pytest.raises(ValueError): + # required field `uvarint` is not sent + load_message(buf, RequiredFields) + + +class DefaultFields(protobuf.MessageType): + @classmethod + def get_fields(cls): + return { + 1: ("uvarint", protobuf.UVarintType, 42), + 2: ("svarint", protobuf.SVarintType, -42), + 3: ("bool", protobuf.BoolType, True), + 4: ("bytes", protobuf.BytesType, b"hello"), + 5: ("unicode", protobuf.UnicodeType, "hello"), + 6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 5), + } + + +def test_default(): + # load empty message + retr = load_message(b"", DefaultFields) + assert retr.uvarint == 42 + assert retr.svarint == -42 + assert retr.bool is True + assert retr.bytes == b"hello" + assert retr.unicode == "hello" + assert retr.enum == 5 + + msg = DefaultFields(uvarint=0) + buf = dump_message(msg) + retr = load_message(buf, DefaultFields) + assert retr.uvarint == 0 + + msg = DefaultFields(uvarint=None) + buf = dump_message(msg) + retr = load_message(buf, DefaultFields) + assert retr.uvarint == 42 diff --git a/python/tests/test_protobuf_misc.py b/python/tests/test_protobuf_misc.py index 73d6efd44b..644d5becba 100644 --- a/python/tests/test_protobuf_misc.py +++ b/python/tests/test_protobuf_misc.py @@ -31,12 +31,12 @@ class SimpleMessage(protobuf.MessageType): @classmethod def get_fields(cls): return { - 1: ("uvarint", protobuf.UVarintType, 0), - 2: ("svarint", protobuf.SVarintType, 0), - 3: ("bool", protobuf.BoolType, 0), - 4: ("bytes", protobuf.BytesType, 0), - 5: ("unicode", protobuf.UnicodeType, 0), - 6: ("enum", SimpleEnumType, 0), + 1: ("uvarint", protobuf.UVarintType, None), + 2: ("svarint", protobuf.SVarintType, None), + 3: ("bool", protobuf.BoolType, None), + 4: ("bytes", protobuf.BytesType, None), + 5: ("unicode", protobuf.UnicodeType, None), + 6: ("enum", SimpleEnumType, None), 7: ("rep_int", protobuf.UVarintType, protobuf.FLAG_REPEATED), 8: ("rep_str", protobuf.UnicodeType, protobuf.FLAG_REPEATED), 9: ("rep_enum", SimpleEnumType, protobuf.FLAG_REPEATED),