1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-24 23:38:09 +00:00

feat!: implement protobuf required fields and default values

BREAKING CHANGE: this makes arguments to protobuf constructors
keyword-only, and arguments corresponding to required fields are now
mandatory
This commit is contained in:
matejcik 2020-09-14 12:47:06 +02:00 committed by matejcik
parent 0799b64b29
commit 90ee5f3d38
6 changed files with 258 additions and 118 deletions

View File

@ -51,7 +51,11 @@ class ProtoField:
type_name = attr.ib()
proto_type = attr.ib()
py_type = attr.ib()
py_inner_type = attr.ib()
default_value = attr.ib()
@property
def optional(self):
return not self.required and not self.repeated
@classmethod
def from_field(cls, descriptor, field):
@ -75,9 +79,18 @@ class ProtoField:
"Unknown field type {} for field {}".format(field.type, field.name)
) from None
py_inner_type = py_type
if repeated:
py_type = "List[{}]".format(py_type)
if not field.HasField("default_value"):
default_value = None
elif field.type == field.TYPE_ENUM:
default_value = str(descriptor.enum_types[type_name][field.default_value])
elif field.type == field.TYPE_STRING:
default_value = f'"{field.default_value}"'
elif field.type == field.TYPE_BYTES:
default_value = f'b"{field.default_value}"'
elif field.type == field.TYPE_BOOL:
default_value = "True" if field.default_value == "true" else "False"
else:
default_value = field.default_value
return cls(
name=field.name,
@ -88,7 +101,7 @@ class ProtoField:
type_name=type_name,
proto_type=proto_type,
py_type=py_type,
py_inner_type=py_inner_type,
default_value=default_value,
)
@ -216,20 +229,26 @@ class Descriptor:
yield self.create_message_import(name)
def create_init_method(self, fields):
required_fields = [f for f in fields if f.required]
repeated_fields = [f for f in fields if f.repeated]
optional_fields = [f for f in fields if f.optional]
# please keep the yields aligned
# fmt: off
... # https://github.com/ambv/black/issues/385
yield " def __init__("
yield " self,"
for field in fields:
yield f" {field.name}: {field.py_type} = None,"
yield " def __init__("
yield " self,"
yield " *,"
for field in required_fields:
yield f" {field.name}: {field.py_type},"
for field in repeated_fields:
yield f" {field.name}: List[{field.py_type}] = None,"
for field in optional_fields:
yield f" {field.name}: {field.py_type} = {field.default_value},"
yield " ) -> None:"
for field in fields:
if field.repeated:
yield f" self.{field.name} = {field.name} if {field.name} is not None else []"
else:
yield f" self.{field.name} = {field.name}"
for field in repeated_fields:
yield f" self.{field.name} = {field.name} if {field.name} is not None else []"
for field in required_fields + optional_fields:
yield f" self.{field.name} = {field.name}"
# fmt: on
def create_fields_method(self, fields):
@ -239,10 +258,8 @@ class Descriptor:
yield " return {"
for field in fields:
comments = []
if field.required:
comments.append("required")
if field.orig.HasField("default_value"):
comments.append("default={}".format(field.orig.default_value))
if field.default_value is not None:
comments.append(f"default={field.orig.default_value}")
if comments:
comment = " # " + " ".join(comments)
@ -251,8 +268,10 @@ class Descriptor:
if field.repeated:
flags = "p.FLAG_REPEATED"
elif field.required:
flags = "p.FLAG_REQUIRED"
else:
flags = "0"
flags = field.default_value
yield " {num}: ('{name}', {type}, {flags}),{comment}".format(
num=field.number,
@ -288,7 +307,7 @@ class Descriptor:
for field in all_enums:
allowed_values = self.enum_types[field.type_name].values()
valuestr = ", ".join(str(v) for v in sorted(allowed_values))
yield " {} = Literal[{}]".format(field.py_inner_type, valuestr)
yield " {} = Literal[{}]".format(field.py_type, valuestr)
yield " except ImportError:"
yield " pass"
@ -335,7 +354,9 @@ class Descriptor:
def process_messages(self, messages, include_deprecated=False):
for message in sorted(messages, key=lambda m: m.name):
self.write_to_file(message.name, self.process_message(message, include_deprecated))
self.write_to_file(
message.name, self.process_message(message, include_deprecated)
)
def process_enums(self, enums):
for enum in sorted(enums, key=lambda e: e.name):

View File

@ -3,8 +3,6 @@ Extremely minimal streaming codec for a subset of protobuf. Supports uint32,
bytes, string, embedded message and repeated fields.
"""
from micropython import const
if False:
from typing import Any, Dict, Iterable, List, Tuple, Type, TypeVar, Union
from typing_extensions import Protocol
@ -178,7 +176,8 @@ class LimitedReader:
return nread
FLAG_REPEATED = const(1)
FLAG_REPEATED = object()
FLAG_REQUIRED = object()
if False:
MessageTypeDef = Union[
@ -190,7 +189,7 @@ if False:
Type[UnicodeType],
Type[MessageType],
]
FieldDef = Tuple[str, MessageTypeDef, int]
FieldDef = Tuple[str, MessageTypeDef, Any]
FieldDict = Dict[int, FieldDef]
FieldCache = Dict[Type[MessageType], FieldDict]
@ -201,7 +200,6 @@ if False:
def load_message(
reader: Reader, msg_type: Type[LoadedMessageType], field_cache: FieldCache = None
) -> LoadedMessageType:
if field_cache is None:
field_cache = {}
fields = field_cache.get(msg_type)
@ -209,7 +207,13 @@ def load_message(
fields = msg_type.get_fields()
field_cache[msg_type] = fields
msg = msg_type()
# we need to avoid calling __init__, which enforces required arguments
msg = object.__new__(msg_type) # type: LoadedMessageType
# pre-seed the object with defaults
for fname, _, fdefault in fields.values():
if fdefault is FLAG_REPEATED:
fdefault = []
setattr(msg, fname, fdefault)
if False:
SingularValue = Union[int, bool, bytearray, str, MessageType]
@ -237,7 +241,7 @@ def load_message(
raise ValueError
continue
fname, ftype, fflags = field
fname, ftype, fdefault = field
if wtype != ftype.WIRE_TYPE:
raise TypeError # parsed wire type differs from the schema
@ -263,17 +267,16 @@ def load_message(
else:
raise TypeError # field type is unknown
if fflags & FLAG_REPEATED:
pvalue = getattr(msg, fname, [])
pvalue.append(fvalue)
fvalue = pvalue
setattr(msg, fname, fvalue)
if fdefault is FLAG_REPEATED:
getattr(msg, fname).append(fvalue)
else:
setattr(msg, fname, fvalue)
# fill missing fields
for tag in fields:
field = fields[tag]
if not hasattr(msg, field[0]):
setattr(msg, field[0], None)
for fname, _, _ in fields.values():
if getattr(msg, fname) is FLAG_REQUIRED:
# The message is intended to be user-facing when decoding from wire,
# but not when used internally.
raise ValueError("Required field '{}' was not received".format(fname))
return msg
@ -291,7 +294,7 @@ def dump_message(
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fflags = fields[ftag]
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
@ -299,7 +302,7 @@ def dump_message(
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
@ -356,7 +359,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
field_cache[type(msg)] = fields
for ftag in fields:
fname, ftype, fflags = fields[ftag]
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
@ -364,7 +367,7 @@ def count_message(msg: MessageType, field_cache: FieldCache = None) -> int:
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue

View File

@ -5,8 +5,8 @@ from trezor.utils import BufferReader, BufferWriter
class Message(protobuf.MessageType):
def __init__(self, uint_field: int = 0, enum_field: int = 0) -> None:
self.sint_field = uint_field
def __init__(self, sint_field: int = 0, enum_field: int = 0) -> None:
self.sint_field = sint_field
self.enum_field = enum_field
@classmethod
@ -17,6 +17,19 @@ class Message(protobuf.MessageType):
}
class MessageWithRequiredAndDefault(protobuf.MessageType):
def __init__(self, required_field, default_field) -> None:
self.required_field = required_field
self.default_field = default_field
@classmethod
def get_fields(cls):
return {
1: ("required_field", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
2: ("default_field", protobuf.SVarintType, -1),
}
def load_uvarint(data: bytes) -> int:
reader = BufferReader(data)
return protobuf.load_uvarint(reader)
@ -25,7 +38,18 @@ def load_uvarint(data: bytes) -> int:
def dump_uvarint(value: int) -> bytearray:
writer = BufferWriter(bytearray(16))
protobuf.dump_uvarint(writer, value)
return memoryview(writer.buffer)[:writer.offset]
return memoryview(writer.buffer)[: writer.offset]
def dump_message(msg: protobuf.MessageType) -> bytearray:
length = protobuf.count_message(msg)
buffer = bytearray(length)
protobuf.dump_message(BufferWriter(buffer), msg)
return buffer
def load_message(msg_type, buffer: bytearray) -> protobuf.MessageType:
return protobuf.load_message(BufferReader(buffer), msg_type)
class TestProtobuf(unittest.TestCase):
@ -58,30 +82,47 @@ class TestProtobuf(unittest.TestCase):
protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)), 1234567891011
)
self.assertEqual(
protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)), -2 ** 32
protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))), -(2 ** 32)
)
def test_validate_enum(self):
# ok message:
msg = Message(-42, 5)
length = protobuf.count_message(msg)
buffer_writer = BufferWriter(bytearray(length))
protobuf.dump_message(buffer_writer, msg)
buffer_reader = BufferReader(buffer_writer.buffer)
nmsg = protobuf.load_message(buffer_reader, Message)
msg_encoded = dump_message(msg)
nmsg = load_message(Message, msg_encoded)
self.assertEqual(msg.sint_field, nmsg.sint_field)
self.assertEqual(msg.enum_field, nmsg.enum_field)
# bad enum value:
buffer_writer.seek(0)
msg = Message(-42, 42)
# XXX this assumes the message will have equal size
protobuf.dump_message(buffer_writer, msg)
buffer_reader.seek(0)
msg_encoded = dump_message(msg)
with self.assertRaises(TypeError):
protobuf.load_message(buffer_reader, Message)
load_message(Message, msg_encoded)
def test_required(self):
msg = MessageWithRequiredAndDefault(required_field=1, default_field=2)
msg_encoded = dump_message(msg)
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
self.assertEqual(nmsg.required_field, 1)
self.assertEqual(nmsg.default_field, 2)
# try a message without the required_field
msg = MessageWithRequiredAndDefault(required_field=None, default_field=2)
# encoding always succeeds
msg_encoded = dump_message(msg)
with self.assertRaises(ValueError):
load_message(MessageWithRequiredAndDefault, msg_encoded)
# try a message without the default field
msg = MessageWithRequiredAndDefault(required_field=1, default_field=None)
msg_encoded = dump_message(msg)
nmsg = load_message(MessageWithRequiredAndDefault, msg_encoded)
self.assertEqual(nmsg.required_field, 1)
self.assertEqual(nmsg.default_field, -1)
if __name__ == "__main__":

View File

@ -49,7 +49,7 @@ FieldType = Union[
Type["UnicodeType"],
Type["BytesType"],
]
FieldInfo = Tuple[str, FieldType, int]
FieldInfo = Tuple[str, FieldType, Any]
MT = TypeVar("MT", bound="MessageType")
@ -239,12 +239,14 @@ class MessageType:
def _fill_missing(self) -> None:
# fill missing fields
for fname, _, fflags in self.get_fields().values():
for fname, _, fdefault in self.get_fields().values():
if not hasattr(self, fname):
if fflags & FLAG_REPEATED:
if fdefault is FLAG_REPEATED:
setattr(self, fname, [])
elif fdefault is FLAG_REQUIRED:
raise ValueError("value for required field is missing")
else:
setattr(self, fname, None)
setattr(self, fname, fdefault)
def ByteSize(self) -> int:
data = BytesIO()
@ -276,7 +278,8 @@ class CountingWriter:
return nwritten
FLAG_REPEATED = 1
FLAG_REPEATED = object()
FLAG_REQUIRED = object()
def decode_packed_array_field(ftype: FieldType, reader: Reader) -> List[Any]:
@ -325,7 +328,14 @@ def decode_length_delimited_field(
def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
fields = msg_type.get_fields()
msg = msg_type()
msg_dict = {}
# pre-seed the dict
for fname, _, fdefault in fields.values():
if fdefault is FLAG_REPEATED:
msg_dict[fname] = []
elif fdefault is not FLAG_REQUIRED:
msg_dict[fname] = fdefault
while True:
try:
@ -348,9 +358,9 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
raise ValueError
continue
fname, ftype, fflags = field
fname, ftype, fdefault = field
if wtype == 2 and ftype.WIRE_TYPE == 0 and fflags & FLAG_REPEATED:
if wtype == 2 and ftype.WIRE_TYPE == 0 and fdefault is FLAG_REPEATED:
# packed array
fvalues = decode_packed_array_field(ftype, reader)
@ -366,18 +376,17 @@ def load_message(reader: Reader, msg_type: Type[MT]) -> MT:
else:
raise TypeError # unknown wire type
if fflags & FLAG_REPEATED:
pvalue = getattr(msg, fname)
pvalue.extend(fvalues)
fvalue = pvalue
if fdefault is FLAG_REPEATED:
msg_dict[fname].extend(fvalues)
elif len(fvalues) != 1:
raise ValueError("Unexpected multiple values in non-repeating field")
else:
fvalue = fvalues[0]
msg_dict[fname] = fvalues[0]
setattr(msg, fname, fvalue)
return msg
for fname, _, fdefault in fields.values():
if fdefault is FLAG_REQUIRED and fname not in msg_dict:
raise ValueError # required field was not received
return msg_type(**msg_dict)
def dump_message(writer: Writer, msg: MessageType) -> None:
@ -386,7 +395,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
fields = mtype.get_fields()
for ftag in fields:
fname, ftype, fflags = fields[ftag]
fname, ftype, fdefault = fields[ftag]
fvalue = getattr(msg, fname, None)
if fvalue is None:
@ -394,7 +403,7 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
fkey = (ftag << 3) | ftype.WIRE_TYPE
if not fflags & FLAG_REPEATED:
if fdefault is not FLAG_REPEATED:
repvalue[0] = fvalue
fvalue = repvalue
@ -529,8 +538,8 @@ def value_to_proto(ftype: FieldType, value: Any) -> Any:
def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
params = {}
for fname, ftype, fflags in message_type.get_fields().values():
repeated = fflags & FLAG_REPEATED
for fname, ftype, fdefault in message_type.get_fields().values():
repeated = fdefault is FLAG_REPEATED
value = d.get(fname)
if value is None:
continue

View File

@ -26,25 +26,25 @@ class PrimitiveMessage(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {
1: ("uvarint", protobuf.UVarintType, 0),
2: ("svarint", protobuf.SVarintType, 0),
3: ("bool", protobuf.BoolType, 0),
4: ("bytes", protobuf.BytesType, 0),
5: ("unicode", protobuf.UnicodeType, 0),
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 0),
1: ("uvarint", protobuf.UVarintType, None),
2: ("svarint", protobuf.SVarintType, None),
3: ("bool", protobuf.BoolType, None),
4: ("bytes", protobuf.BytesType, None),
5: ("unicode", protobuf.UnicodeType, None),
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), None),
}
class EnumMessageMoreValues(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), 0)}
return {1: ("enum", protobuf.EnumType("t", (0, 1, 2, 3, 4, 5)), None)}
class EnumMessageLessValues(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {1: ("enum", protobuf.EnumType("t", (0, 5)), 0)}
return {1: ("enum", protobuf.EnumType("t", (0, 5)), None)}
class RepeatedFields(protobuf.MessageType):
@ -68,6 +68,17 @@ def dump_uvarint(value):
return writer.getvalue()
def load_message(buffer, msg_type):
reader = BytesIO(buffer)
return protobuf.load_message(reader, msg_type)
def dump_message(msg):
writer = BytesIO()
protobuf.dump_message(writer, msg)
return writer.getvalue()
def test_dump_uvarint():
assert dump_uvarint(0) == b"\x00"
assert dump_uvarint(1) == b"\x01"
@ -109,7 +120,7 @@ def test_sint_uint():
# roundtrip:
assert protobuf.uint_to_sint(protobuf.sint_to_uint(1234567891011)) == 1234567891011
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-2 ** 32)) == -2 ** 32
assert protobuf.uint_to_sint(protobuf.sint_to_uint(-(2 ** 32))) == -(2 ** 32)
def test_simple_message():
@ -122,11 +133,8 @@ def test_simple_message():
enum=5,
)
buf = BytesIO()
protobuf.dump_message(buf, msg)
buf.seek(0)
retr = protobuf.load_message(buf, PrimitiveMessage)
buf = dump_message(msg)
retr = load_message(buf, PrimitiveMessage)
assert msg == retr
assert retr.uvarint == 12345678910
@ -141,18 +149,15 @@ def test_validate_enum(caplog):
caplog.set_level(logging.INFO)
# round-trip of a valid value
msg = EnumMessageMoreValues(enum=0)
buf = BytesIO()
protobuf.dump_message(buf, msg)
buf.seek(0)
retr = protobuf.load_message(buf, EnumMessageLessValues)
buf = dump_message(msg)
retr = load_message(buf, EnumMessageLessValues)
assert retr.enum == msg.enum
assert not caplog.records
# dumping an invalid enum value fails
msg.enum = 19
buf.seek(0)
protobuf.dump_message(buf, msg)
buf = dump_message(msg)
assert len(caplog.records) == 1
record = caplog.records.pop(0)
@ -160,10 +165,8 @@ def test_validate_enum(caplog):
assert record.getMessage() == "Value 19 unknown for type t"
msg.enum = 3
buf.seek(0)
protobuf.dump_message(buf, msg)
buf.seek(0)
protobuf.load_message(buf, EnumMessageLessValues)
buf = dump_message(msg)
load_message(buf, EnumMessageLessValues)
assert len(caplog.records) == 1
record = caplog.records.pop(0)
@ -175,10 +178,8 @@ def test_repeated():
msg = RepeatedFields(
uintlist=[1, 2, 3], enumlist=[0, 1, 0, 1], strlist=["hello", "world"]
)
buf = BytesIO()
protobuf.dump_message(buf, msg)
buf.seek(0)
retr = protobuf.load_message(buf, RepeatedFields)
buf = dump_message(msg)
retr = load_message(buf, RepeatedFields)
assert retr == msg
@ -187,8 +188,7 @@ def test_enum_in_repeated(caplog):
caplog.set_level(logging.INFO)
msg = RepeatedFields(enumlist=[0, 1, 2, 3])
buf = BytesIO()
protobuf.dump_message(buf, msg)
dump_message(msg)
assert len(caplog.records) == 2
for record in caplog.records:
assert record.levelname == "INFO"
@ -202,8 +202,7 @@ def test_packed():
field_len = len(packed_values)
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
buf = BytesIO(message_bytes)
msg = protobuf.load_message(buf, RepeatedFields)
msg = load_message(message_bytes, RepeatedFields)
assert msg
assert msg.uintlist == values
assert not msg.enumlist
@ -217,9 +216,76 @@ def test_packed_enum():
field_len = len(packed_values)
message_bytes = dump_uvarint(field_id) + dump_uvarint(field_len) + packed_values
buf = BytesIO(message_bytes)
msg = protobuf.load_message(buf, RepeatedFields)
msg = load_message(message_bytes, RepeatedFields)
assert msg
assert msg.enumlist == values
assert not msg.uintlist
assert not msg.strlist
class RequiredFields(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {
1: ("uvarint", protobuf.UVarintType, protobuf.FLAG_REQUIRED),
2: ("nested", PrimitiveMessage, protobuf.FLAG_REQUIRED),
}
def test_required():
msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
buf = dump_message(msg)
msg_ok = load_message(buf, RequiredFields)
assert msg_ok == msg
with pytest.raises(ValueError):
# cannot construct instance without the required fields
msg = RequiredFields(uvarint=3)
msg = RequiredFields(uvarint=3, nested=None)
# we can always encode an invalid message
buf = dump_message(msg)
with pytest.raises(ValueError):
# required field `nested` is also not sent
load_message(buf, RequiredFields)
msg = RequiredFields(uvarint=None, nested=PrimitiveMessage())
buf = dump_message(msg)
with pytest.raises(ValueError):
# required field `uvarint` is not sent
load_message(buf, RequiredFields)
class DefaultFields(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {
1: ("uvarint", protobuf.UVarintType, 42),
2: ("svarint", protobuf.SVarintType, -42),
3: ("bool", protobuf.BoolType, True),
4: ("bytes", protobuf.BytesType, b"hello"),
5: ("unicode", protobuf.UnicodeType, "hello"),
6: ("enum", protobuf.EnumType("t", (0, 5, 25)), 5),
}
def test_default():
# load empty message
retr = load_message(b"", DefaultFields)
assert retr.uvarint == 42
assert retr.svarint == -42
assert retr.bool is True
assert retr.bytes == b"hello"
assert retr.unicode == "hello"
assert retr.enum == 5
msg = DefaultFields(uvarint=0)
buf = dump_message(msg)
retr = load_message(buf, DefaultFields)
assert retr.uvarint == 0
msg = DefaultFields(uvarint=None)
buf = dump_message(msg)
retr = load_message(buf, DefaultFields)
assert retr.uvarint == 42

View File

@ -31,12 +31,12 @@ class SimpleMessage(protobuf.MessageType):
@classmethod
def get_fields(cls):
return {
1: ("uvarint", protobuf.UVarintType, 0),
2: ("svarint", protobuf.SVarintType, 0),
3: ("bool", protobuf.BoolType, 0),
4: ("bytes", protobuf.BytesType, 0),
5: ("unicode", protobuf.UnicodeType, 0),
6: ("enum", SimpleEnumType, 0),
1: ("uvarint", protobuf.UVarintType, None),
2: ("svarint", protobuf.SVarintType, None),
3: ("bool", protobuf.BoolType, None),
4: ("bytes", protobuf.BytesType, None),
5: ("unicode", protobuf.UnicodeType, None),
6: ("enum", SimpleEnumType, None),
7: ("rep_int", protobuf.UVarintType, protobuf.FLAG_REPEATED),
8: ("rep_str", protobuf.UnicodeType, protobuf.FLAG_REPEATED),
9: ("rep_enum", SimpleEnumType, protobuf.FLAG_REPEATED),