mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-11 16:00:57 +00:00
feat(common): change all protobuf message type hints to strings
In order to support recursive protobuf messages, which will be needed by Cardano's native scripts. [no changelog]
This commit is contained in:
parent
725d1bd961
commit
4f7c6b3586
@ -379,42 +379,6 @@ class Descriptor:
|
||||
for enum in self.enums:
|
||||
self.convert_enum_value_names(enum)
|
||||
|
||||
self.depsort_messages()
|
||||
|
||||
def depsort_messages(self):
|
||||
# sort messages according to dependencies
|
||||
messages_dict = {m.name: m for m in reversed(self.messages)}
|
||||
messages_sorted = []
|
||||
seen_messages = set()
|
||||
|
||||
# pop first message
|
||||
_, message = messages_dict.popitem()
|
||||
|
||||
while True:
|
||||
dependent_type = next(
|
||||
(
|
||||
f.type_name
|
||||
for f in message.fields
|
||||
if f.is_message and f.type_name not in seen_messages
|
||||
),
|
||||
None,
|
||||
)
|
||||
if dependent_type:
|
||||
# return current message to unprocessed
|
||||
messages_dict[message.name] = message
|
||||
# pop the dependency
|
||||
message = messages_dict.pop(dependent_type)
|
||||
else:
|
||||
# see message
|
||||
seen_messages.add(message.name)
|
||||
messages_sorted.append(message)
|
||||
if not messages_dict:
|
||||
break
|
||||
else:
|
||||
_, message = messages_dict.popitem()
|
||||
|
||||
assert len(messages_sorted) == len(self.messages)
|
||||
self.messages[:] = messages_sorted
|
||||
|
||||
def _filter_items(self, iter):
|
||||
return [
|
||||
|
@ -41,20 +41,20 @@ def member_type(field):
|
||||
class ${message.name}(protobuf.MessageType):
|
||||
% if message.fields:
|
||||
% for field in message.fields:
|
||||
${field.name}: ${member_type(field)}
|
||||
${field.name}: "${member_type(field)}"
|
||||
% endfor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
% for field in required_fields:
|
||||
${field.name}: ${field.python_type},
|
||||
${field.name}: "${field.python_type}",
|
||||
% endfor
|
||||
% for field in repeated_fields:
|
||||
${field.name}: list[${field.python_type}] | None = None,
|
||||
${field.name}: "list[${field.python_type}] | None" = None,
|
||||
% endfor
|
||||
% for field in optional_fields:
|
||||
${field.name}: ${field.python_type} | None = None,
|
||||
${field.name}: "${field.python_type} | None" = None,
|
||||
% endfor
|
||||
) -> None:
|
||||
pass
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -21,22 +21,13 @@ class ${enum.name}(IntEnum):
|
||||
required_fields = [f for f in message.fields if f.required]
|
||||
repeated_fields = [f for f in message.fields if f.repeated]
|
||||
optional_fields = [f for f in message.fields if f.optional]
|
||||
|
||||
|
||||
def type_name(field):
|
||||
if field.type_object is not None:
|
||||
return field.type_name
|
||||
else:
|
||||
return '"' + field.type_name + '"'
|
||||
|
||||
|
||||
%>\
|
||||
class ${message.name}(protobuf.MessageType):
|
||||
MESSAGE_WIRE_TYPE = ${message.wire_type}
|
||||
% if message.fields:
|
||||
FIELDS = {
|
||||
% for field in message.fields:
|
||||
${field.number}: protobuf.Field("${field.name}", ${type_name(field)}, repeated=${field.repeated}, required=${field.required}),
|
||||
${field.number}: protobuf.Field("${field.name}", "${field.type_name}", repeated=${field.repeated}, required=${field.required}),
|
||||
% endfor
|
||||
}
|
||||
|
||||
@ -44,13 +35,13 @@ class ${message.name}(protobuf.MessageType):
|
||||
self,
|
||||
*,
|
||||
% for field in required_fields:
|
||||
${field.name}: ${field.python_type},
|
||||
${field.name}: "${field.python_type}",
|
||||
% endfor
|
||||
% for field in repeated_fields:
|
||||
${field.name}: Optional[List[${field.python_type}]] = None,
|
||||
${field.name}: Optional[List["${field.python_type}"]] = None,
|
||||
% endfor
|
||||
% for field in optional_fields:
|
||||
${field.name}: Optional[${field.python_type}] = ${field.default_value_repr},
|
||||
${field.name}: Optional["${field.python_type}"] = ${field.default_value_repr},
|
||||
% endfor
|
||||
) -> None:
|
||||
% for field in repeated_fields:
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -144,7 +144,7 @@ REQUIRED_FIELD_PLACEHOLDER = object()
|
||||
@attr.s(auto_attribs=True)
|
||||
class Field:
|
||||
name: str
|
||||
type: Union[str, "MessageType", IntEnum]
|
||||
type: str
|
||||
repeated: bool = attr.ib(default=False)
|
||||
required: bool = attr.ib(default=False)
|
||||
default: object = attr.ib(default=None)
|
||||
@ -154,10 +154,11 @@ class Field:
|
||||
if self.type in WIRE_TYPES:
|
||||
return WIRE_TYPES[self.type]
|
||||
|
||||
if safe_issubclass(self.type, MessageType):
|
||||
field_type_object = get_field_type_object(self)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
return WIRE_TYPE_LENGTH
|
||||
|
||||
if safe_issubclass(self.type, IntEnum):
|
||||
if safe_issubclass(field_type_object, IntEnum):
|
||||
return WIRE_TYPE_INT
|
||||
|
||||
raise ValueError(f"Unrecognized type for field {self.name}")
|
||||
@ -269,6 +270,15 @@ class CountingWriter:
|
||||
return nwritten
|
||||
|
||||
|
||||
def get_field_type_object(field: Field) -> Optional[type]:
|
||||
from . import messages
|
||||
|
||||
field_type_object = getattr(messages, field.type, None)
|
||||
if not safe_issubclass(field_type_object, (IntEnum, MessageType)):
|
||||
return None
|
||||
return field_type_object
|
||||
|
||||
|
||||
def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
|
||||
assert field.repeated, "Not decoding packed array into non-repeated field"
|
||||
length = load_uvarint(reader)
|
||||
@ -285,9 +295,11 @@ def decode_packed_array_field(field: Field, reader: Reader) -> List[Any]:
|
||||
def decode_varint_field(field: Field, reader: Reader) -> Union[int, bool, IntEnum]:
|
||||
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
|
||||
value = load_uvarint(reader)
|
||||
if safe_issubclass(field.type, IntEnum):
|
||||
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, IntEnum):
|
||||
try:
|
||||
return field.type(value)
|
||||
return field_type_object(value)
|
||||
except ValueError as e:
|
||||
# treat enum errors as warnings
|
||||
LOG.info(f"On field {field.name}: {e}")
|
||||
@ -328,8 +340,9 @@ def decode_length_delimited_field(
|
||||
reader.readinto(buf)
|
||||
return buf.decode()
|
||||
|
||||
if safe_issubclass(field.type, MessageType):
|
||||
return load_message(LimitedReader(reader, value), field.type)
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
return load_message(LimitedReader(reader, value), field_type_object)
|
||||
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
@ -420,16 +433,17 @@ def dump_message(writer: Writer, msg: MessageType) -> None:
|
||||
for svalue in fvalue:
|
||||
dump_uvarint(writer, fkey)
|
||||
|
||||
if safe_issubclass(field.type, MessageType):
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
counter = CountingWriter()
|
||||
dump_message(counter, svalue)
|
||||
dump_uvarint(writer, counter.size)
|
||||
dump_message(writer, svalue)
|
||||
|
||||
elif safe_issubclass(field.type, IntEnum):
|
||||
if svalue not in field.type.__members__.values():
|
||||
elif safe_issubclass(field_type_object, IntEnum):
|
||||
if svalue not in field_type_object.__members__.values():
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} unknown for {field.type.__name__}"
|
||||
f"Value {svalue} in field {field.name} unknown for {field.type}"
|
||||
)
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
@ -531,15 +545,16 @@ def format_message(
|
||||
|
||||
|
||||
def value_to_proto(field: Field, value: Any) -> Any:
|
||||
if safe_issubclass(field.type, MessageType):
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
raise TypeError("value_to_proto only converts simple values")
|
||||
|
||||
if safe_issubclass(field.type, IntEnum):
|
||||
if safe_issubclass(field_type_object, IntEnum):
|
||||
if isinstance(value, str):
|
||||
return field.type.__members__[value]
|
||||
return field_type_object.__members__[value]
|
||||
else:
|
||||
try:
|
||||
return field.type(value)
|
||||
return field_type_object(value)
|
||||
except ValueError as e:
|
||||
LOG.info(f"On field {field.name}: {e}")
|
||||
return int(value)
|
||||
@ -572,8 +587,9 @@ def dict_to_proto(message_type: Type[MT], d: Dict[str, Any]) -> MT:
|
||||
if not field.repeated:
|
||||
value = [value]
|
||||
|
||||
if safe_issubclass(field.type, MessageType):
|
||||
newvalue = [dict_to_proto(field.type, v) for v in value]
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
newvalue = [dict_to_proto(field_type_object, v) for v in value]
|
||||
else:
|
||||
newvalue = [value_to_proto(field, v) for v in value]
|
||||
|
||||
|
@ -20,7 +20,7 @@ import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import protobuf
|
||||
from trezorlib import messages, protobuf
|
||||
|
||||
|
||||
class SomeEnum(IntEnum):
|
||||
@ -49,26 +49,64 @@ class PrimitiveMessage(protobuf.MessageType):
|
||||
3: protobuf.Field("bool", "bool"),
|
||||
4: protobuf.Field("bytes", "bytes"),
|
||||
5: protobuf.Field("unicode", "string"),
|
||||
6: protobuf.Field("enum", SomeEnum),
|
||||
6: protobuf.Field("enum", "SomeEnum"),
|
||||
}
|
||||
|
||||
|
||||
class EnumMessageMoreValues(protobuf.MessageType):
|
||||
FIELDS = {1: protobuf.Field("enum", WiderEnum)}
|
||||
FIELDS = {1: protobuf.Field("enum", "WiderEnum")}
|
||||
|
||||
|
||||
class EnumMessageLessValues(protobuf.MessageType):
|
||||
FIELDS = {1: protobuf.Field("enum", NarrowerEnum)}
|
||||
FIELDS = {1: protobuf.Field("enum", "NarrowerEnum")}
|
||||
|
||||
|
||||
class RepeatedFields(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("uintlist", "uint64", repeated=True),
|
||||
2: protobuf.Field("enumlist", SomeEnum, repeated=True),
|
||||
2: protobuf.Field("enumlist", "SomeEnum", repeated=True),
|
||||
3: protobuf.Field("strlist", "string", repeated=True),
|
||||
}
|
||||
|
||||
|
||||
class RequiredFields(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("uvarint", "uint64", required=True),
|
||||
2: protobuf.Field("nested", "PrimitiveMessage", required=True),
|
||||
}
|
||||
|
||||
|
||||
class DefaultFields(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("uvarint", "uint32", default=42),
|
||||
2: protobuf.Field("svarint", "sint32", default=-42),
|
||||
3: protobuf.Field("bool", "bool", default=True),
|
||||
4: protobuf.Field("bytes", "bytes", default=b"hello"),
|
||||
5: protobuf.Field("unicode", "string", default="hello"),
|
||||
6: protobuf.Field("enum", "SomeEnum", default=SomeEnum.Five),
|
||||
}
|
||||
|
||||
|
||||
class RecursiveMessage(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("uvarint", "uint64"),
|
||||
2: protobuf.Field("recursivefield", "RecursiveMessage", required=False)
|
||||
}
|
||||
|
||||
|
||||
# message types are read from the messages module so we need to "include" these messages there for now
|
||||
messages.SomeEnum = SomeEnum
|
||||
messages.WiderEnum = WiderEnum
|
||||
messages.NarrowerEnum = NarrowerEnum
|
||||
messages.PrimitiveMessage = PrimitiveMessage
|
||||
messages.EnumMessageMoreValues = EnumMessageMoreValues
|
||||
messages.EnumMessageLessValues = EnumMessageLessValues
|
||||
messages.RepeatedFields = RepeatedFields
|
||||
messages.RequiredFields = RequiredFields
|
||||
messages.DefaultFields = DefaultFields
|
||||
messages.RecursiveMessage = RecursiveMessage
|
||||
|
||||
|
||||
def load_uvarint(buffer):
|
||||
reader = BytesIO(buffer)
|
||||
return protobuf.load_uvarint(reader)
|
||||
@ -224,13 +262,6 @@ def test_packed_enum():
|
||||
assert not msg.strlist
|
||||
|
||||
|
||||
class RequiredFields(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("uvarint", "uint64", required=True),
|
||||
2: protobuf.Field("nested", PrimitiveMessage, required=True),
|
||||
}
|
||||
|
||||
|
||||
def test_required():
|
||||
msg = RequiredFields(uvarint=3, nested=PrimitiveMessage())
|
||||
buf = dump_message(msg)
|
||||
@ -258,17 +289,6 @@ def test_required():
|
||||
load_message(buf, RequiredFields)
|
||||
|
||||
|
||||
class DefaultFields(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("uvarint", "uint32", default=42),
|
||||
2: protobuf.Field("svarint", "sint32", default=-42),
|
||||
3: protobuf.Field("bool", "bool", default=True),
|
||||
4: protobuf.Field("bytes", "bytes", default=b"hello"),
|
||||
5: protobuf.Field("unicode", "string", default="hello"),
|
||||
6: protobuf.Field("enum", SomeEnum, default=SomeEnum.Five),
|
||||
}
|
||||
|
||||
|
||||
def test_default():
|
||||
# load empty message
|
||||
retr = load_message(b"", DefaultFields)
|
||||
@ -288,3 +308,25 @@ def test_default():
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, DefaultFields)
|
||||
assert retr.uvarint == 42
|
||||
|
||||
|
||||
def test_recursive():
|
||||
msg = RecursiveMessage(
|
||||
uvarint=1,
|
||||
recursivefield=RecursiveMessage(
|
||||
uvarint=2,
|
||||
recursivefield=RecursiveMessage(
|
||||
uvarint=3
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
buf = dump_message(msg)
|
||||
retr = load_message(buf, RecursiveMessage)
|
||||
|
||||
assert msg == retr
|
||||
assert retr.uvarint == 1
|
||||
assert type(retr.recursivefield) == RecursiveMessage
|
||||
assert retr.recursivefield.uvarint == 2
|
||||
assert type(retr.recursivefield.recursivefield) == RecursiveMessage
|
||||
assert retr.recursivefield.recursivefield.uvarint == 3
|
||||
|
@ -18,7 +18,7 @@ from enum import IntEnum
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import protobuf
|
||||
from trezorlib import messages, protobuf
|
||||
|
||||
|
||||
class SimpleEnum(IntEnum):
|
||||
@ -34,18 +34,18 @@ class SimpleMessage(protobuf.MessageType):
|
||||
3: protobuf.Field("bool", "bool"),
|
||||
4: protobuf.Field("bytes", "bytes"),
|
||||
5: protobuf.Field("unicode", "string"),
|
||||
6: protobuf.Field("enum", SimpleEnum),
|
||||
6: protobuf.Field("enum", "SimpleEnum"),
|
||||
7: protobuf.Field("rep_int", "uint64", repeated=True),
|
||||
8: protobuf.Field("rep_str", "string", repeated=True),
|
||||
9: protobuf.Field("rep_enum", SimpleEnum, repeated=True),
|
||||
9: protobuf.Field("rep_enum", "SimpleEnum", repeated=True),
|
||||
}
|
||||
|
||||
|
||||
class NestedMessage(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("scalar", "uint64"),
|
||||
2: protobuf.Field("nested", SimpleMessage),
|
||||
3: protobuf.Field("repeated", SimpleMessage, repeated=True),
|
||||
2: protobuf.Field("nested", "SimpleMessage"),
|
||||
3: protobuf.Field("repeated", "SimpleMessage", repeated=True),
|
||||
}
|
||||
|
||||
|
||||
@ -55,6 +55,13 @@ class RequiredFields(protobuf.MessageType):
|
||||
}
|
||||
|
||||
|
||||
# message types are read from the messages module so we need to "include" these messages there for now
|
||||
messages.SimpleEnum = SimpleEnum
|
||||
messages.SimpleMessage = SimpleMessage
|
||||
messages.NestedMessage = NestedMessage
|
||||
messages.RequiredFields = RequiredFields
|
||||
|
||||
|
||||
def test_get_field():
|
||||
# smoke test
|
||||
field = SimpleMessage.get_field("bool")
|
||||
|
Loading…
Reference in New Issue
Block a user