1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-12-24 07:18:09 +00:00

feat(common): change all protobuf message type hints to strings

In order to support recursive protobuf messages, which will be needed by Cardano's native scripts.
[no changelog]
This commit is contained in:
gabrielkerekes 2021-07-21 13:28:52 +02:00 committed by matejcik
parent 725d1bd961
commit 4f7c6b3586
8 changed files with 5596 additions and 5576 deletions

View File

@ -379,42 +379,6 @@ class Descriptor:
for enum in self.enums:
self.convert_enum_value_names(enum)
self.depsort_messages()
def depsort_messages(self):
# sort messages according to dependencies
messages_dict = {m.name: m for m in reversed(self.messages)}
messages_sorted = []
seen_messages = set()
# pop first message
_, message = messages_dict.popitem()
while True:
dependent_type = next(
(
f.type_name
for f in message.fields
if f.is_message and f.type_name not in seen_messages
),
None,
)
if dependent_type:
# return current message to unprocessed
messages_dict[message.name] = message
# pop the dependency
message = messages_dict.pop(dependent_type)
else:
# see message
seen_messages.add(message.name)
messages_sorted.append(message)
if not messages_dict:
break
else:
_, message = messages_dict.popitem()
assert len(messages_sorted) == len(self.messages)
self.messages[:] = messages_sorted
def _filter_items(self, iter):
return [

View File

@ -41,20 +41,20 @@ def member_type(field):
class ${message.name}(protobuf.MessageType):
% if message.fields:
% for field in message.fields:
${field.name}: ${member_type(field)}
${field.name}: "${member_type(field)}"
% endfor
def __init__(
self,
*,
% for field in required_fields:
${field.name}: ${field.python_type},
${field.name}: "${field.python_type}",
% endfor
% for field in repeated_fields:
${field.name}: list[${field.python_type}] | None = None,
${field.name}: "list[${field.python_type}] | None" = None,
% endfor
% for field in optional_fields:
${field.name}: ${field.python_type} | None = None,
${field.name}: "${field.python_type} | None" = None,
% endfor
) -> None:
pass

File diff suppressed because it is too large Load Diff

View File

@ -21,22 +21,13 @@ class ${enum.name}(IntEnum):
required_fields = [f for f in message.fields if f.required]
repeated_fields = [f for f in message.fields if f.repeated]
optional_fields = [f for f in message.fields if f.optional]
def type_name(field):
if field.type_object is not None:
return field.type_name
else:
return '"' + field.type_name + '"'
%>\
class ${message.name}(protobuf.MessageType):
MESSAGE_WIRE_TYPE = ${message.wire_type}
% if message.fields:
FIELDS = {
% for field in message.fields:
${field.number}: protobuf.Field("${field.name}", ${type_name(field)}, repeated=${field.repeated}, required=${field.required}),
${field.number}: protobuf.Field("${field.name}", "${field.type_name}", repeated=${field.repeated}, required=${field.required}),
% endfor
}
@ -44,13 +35,13 @@ class ${message.name}(protobuf.MessageType):
self,
*,
% for field in required_fields:
${field.name}: ${field.python_type},
${field.name}: "${field.python_type}",
% endfor
% for field in repeated_fields:
${field.name}: Optional[List[${field.python_type}]] = None,
${field.name}: Optional[List["${field.python_type}"]] = None,
% endfor
% for field in optional_fields:
${field.name}: Optional[${field.python_type}] = ${field.default_value_repr},
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
% endfor
) -> None:
% for field in repeated_fields:

File diff suppressed because it is too large Load Diff

View File

@ -144,7 +144,7 @@ REQUIRED_FIELD_PLACEHOLDER = object()
@attr.s(auto_attribs=True)
class Field:
name: str
type: Union[str, "MessageType", IntEnum]
type: str
repeated: bool = attr.ib(default=False)
required: bool = attr.ib(default=False)
default: object = attr.ib(default=None)
@ -154,10 +154,11 @@ class Field:
if self.type in WIRE_TYPES:
return WIRE_TYPES[self.type]
if safe_issubclass(self.type, MessageType):
field_type_object = get_field_type_object(self)
if safe_issubclass(field_type_object, MessageType):
return WIRE_TYPE_LENGTH
if safe_issubclass(self.type, IntEnum):
if safe_issubclass(field_type_object, IntEnum):
return WIRE_TYPE_INT
raise ValueError(f"Unrecognized type for field {self.name}")
@ -269,6 +270,15 @@ class CountingWriter:
return nwritten
def get_field_type_object(field: Field) -> Optional[type]:
from . import messages
field_type_object = getattr(messages, field.type, None)
if not safe_issubclass(field_type_object, (IntEnum, MessageType)):
return None
return field_type_object
def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
assert field.repeated, "Not decoding packed array into non-repeated field"
length = load_uvarint(reader)
@ -285,9 +295,11 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]:
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
value = load_uvarint(reader)
if safe_issubclass(field.type, IntEnum):
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, IntEnum):
try:
return field.type(value)
return field_type_object(value)
except ValueError as e:
# treat enum errors as warnings
LOG.info(f"On field {field.name}: {e}")
@ -328,8 +340,9 @@ def decode_length_delimited_field(
reader.readinto(buf)
return buf.decode()
if safe_issubclass(field.type, MessageType):
return load_message(LimitedReader(reader, value), field.type)
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
return load_message(LimitedReader(reader, value), field_type_object)
raise TypeError # field type is unknown
@ -420,16 +433,17 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
for svalue in fvalue:
dump_uvarint(writer, fkey)
if safe_issubclass(field.type, MessageType):
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
counter = CountingWriter()
dump_message(counter, svalue)
dump_uvarint(writer, counter.size)
dump_message(writer, svalue)
elif safe_issubclass(field.type, IntEnum):
if svalue not in field.type.__members__.values():
elif safe_issubclass(field_type_object, IntEnum):
if svalue not in field_type_object.__members__.values():
raise ValueError(
f"Value {svalue} in field {field.name} unknown for {field.type.__name__}"
f"Value {svalue} in field {field.name} unknown for {field.type}"
)
dump_uvarint(writer, svalue)
@ -531,15 +545,16 @@ def format_message(
def value_to_proto(field: Field, value: Any) -> Any:
if safe_issubclass(field.type, MessageType):
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
raise TypeError("value_to_proto only converts simple values")
if safe_issubclass(field.type, IntEnum):
if safe_issubclass(field_type_object, IntEnum):
if isinstance(value, str):
return field.type.__members__[value]
return field_type_object.__members__[value]
else:
try:
return field.type(value)
return field_type_object(value)
except ValueError as e:
LOG.info(f"On field {field.name}: {e}")
return int(value)
@ -572,8 +587,9 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
if not field.repeated:
value = [value]
if safe_issubclass(field.type, MessageType):
newvalue = [dict_to_proto(field.type, v) for v in value]
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
newvalue = [dict_to_proto(field_type_object, v) for v in value]
else:
newvalue = [value_to_proto(field, v) for v in value]

View File

@ -20,7 +20,7 @@ import logging
import pytest
from trezorlib import protobuf
from trezorlib import messages, protobuf
class SomeEnum(IntEnum):
@ -49,26 +49,64 @@ class PrimitiveMessage(protobuf.MessageType):
3: protobuf.Field("bool", "bool"),
4: protobuf.Field("bytes", "bytes"),
5: protobuf.Field("unicode", "string"),
6: protobuf.Field("enum", SomeEnum),
6: protobuf.Field("enum", "SomeEnum"),
}
class EnumMessageMoreValues(protobuf.MessageType):
FIELDS = {1: protobuf.Field("enum", WiderEnum)}
FIELDS = {1: protobuf.Field("enum", "WiderEnum")}
class EnumMessageLessValues(protobuf.MessageType):
FIELDS = {1: protobuf.Field("enum", NarrowerEnum)}
FIELDS = {1: protobuf.Field("enum", "NarrowerEnum")}
class RepeatedFields(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("uintlist", "uint64", repeated=True),
2: protobuf.Field("enumlist", SomeEnum, repeated=True),
2: protobuf.Field("enumlist", "SomeEnum", repeated=True),
3: protobuf.Field("strlist", "string", repeated=True),
}
class RequiredFields(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("uvarint", "uint64", required=True),
2: protobuf.Field("nested", "PrimitiveMessage", required=True),
}
class DefaultFields(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("uvarint", "uint32", default=42),
2: protobuf.Field("svarint", "sint32", default=-42),
3: protobuf.Field("bool", "bool", default=True),
4: protobuf.Field("bytes", "bytes", default=b"hello"),
5: protobuf.Field("unicode", "string", default="hello"),
6: protobuf.Field("enum", "SomeEnum", default=SomeEnum.Five),
}
class RecursiveMessage(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("uvarint", "uint64"),
2: protobuf.Field("recursivefield", "RecursiveMessage", required=False)
}
# message types are read from the messages module so we need to "include" these messages there for now
messages.SomeEnum = SomeEnum
messages.WiderEnum = WiderEnum
messages.NarrowerEnum = NarrowerEnum
messages.PrimitiveMessage = PrimitiveMessage
messages.EnumMessageMoreValues = EnumMessageMoreValues
messages.EnumMessageLessValues = EnumMessageLessValues
messages.RepeatedFields = RepeatedFields
messages.RequiredFields = RequiredFields
messages.DefaultFields = DefaultFields
messages.RecursiveMessage = RecursiveMessage
def load_uvarint(buffer):
reader = BytesIO(buffer)
return protobuf.load_uvarint(reader)
@ -224,13 +262,6 @@ def test_packed_enum():
assert not msg.strlist
class RequiredFields(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("uvarint", "uint64", required=True),
2: protobuf.Field("nested", PrimitiveMessage, required=True),
}
def test_required():
msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
buf = dump_message(msg)
@ -258,17 +289,6 @@ def test_required():
load_message(buf, RequiredFields)
class DefaultFields(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("uvarint", "uint32", default=42),
2: protobuf.Field("svarint", "sint32", default=-42),
3: protobuf.Field("bool", "bool", default=True),
4: protobuf.Field("bytes", "bytes", default=b"hello"),
5: protobuf.Field("unicode", "string", default="hello"),
6: protobuf.Field("enum", SomeEnum, default=SomeEnum.Five),
}
def test_default():
# load empty message
retr = load_message(b"", DefaultFields)
@ -288,3 +308,25 @@ def test_default():
buf = dump_message(msg)
retr = load_message(buf, DefaultFields)
assert retr.uvarint == 42
def test_recursive():
msg = RecursiveMessage(
uvarint=1,
recursivefield=RecursiveMessage(
uvarint=2,
recursivefield=RecursiveMessage(
uvarint=3
)
)
)
buf = dump_message(msg)
retr = load_message(buf, RecursiveMessage)
assert msg == retr
assert retr.uvarint == 1
assert type(retr.recursivefield) == RecursiveMessage
assert retr.recursivefield.uvarint == 2
assert type(retr.recursivefield.recursivefield) == RecursiveMessage
assert retr.recursivefield.recursivefield.uvarint == 3

View File

@ -18,7 +18,7 @@ from enum import IntEnum
import pytest
from trezorlib import protobuf
from trezorlib import messages, protobuf
class SimpleEnum(IntEnum):
@ -34,18 +34,18 @@ class SimpleMessage(protobuf.MessageType):
3: protobuf.Field("bool", "bool"),
4: protobuf.Field("bytes", "bytes"),
5: protobuf.Field("unicode", "string"),
6: protobuf.Field("enum", SimpleEnum),
6: protobuf.Field("enum", "SimpleEnum"),
7: protobuf.Field("rep_int", "uint64", repeated=True),
8: protobuf.Field("rep_str", "string", repeated=True),
9: protobuf.Field("rep_enum", SimpleEnum, repeated=True),
9: protobuf.Field("rep_enum", "SimpleEnum", repeated=True),
}
class NestedMessage(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("scalar", "uint64"),
2: protobuf.Field("nested", SimpleMessage),
3: protobuf.Field("repeated", SimpleMessage, repeated=True),
2: protobuf.Field("nested", "SimpleMessage"),
3: protobuf.Field("repeated", "SimpleMessage", repeated=True),
}
@ -55,6 +55,13 @@ class RequiredFields(protobuf.MessageType):
}
# message types are read from the messages module so we need to "include" these messages there for now
messages.SimpleEnum = SimpleEnum
messages.SimpleMessage = SimpleMessage
messages.NestedMessage = NestedMessage
messages.RequiredFields = RequiredFields
def test_get_field():
# smoke test
field = SimpleMessage.get_field("bool")