1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2024-11-18 21:48:13 +00:00

refactor(python/protobuf): allow field types imported in the same module

This commit is contained in:
matejcik 2024-06-12 15:55:49 +02:00
parent 27fef37cc9
commit cd55d32407
4 changed files with 122 additions and 139 deletions

View File

@ -25,12 +25,13 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ
from __future__ import annotations
import logging
import sys
import typing as t
import warnings
from dataclasses import dataclass
from enum import IntEnum
from io import BytesIO
from itertools import zip_longest
import typing as t
import typing_extensions as tx
@ -62,10 +63,6 @@ _UVARINT_BUFFER = bytearray(1)
LOG = logging.getLogger(__name__)
def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
return isinstance(value, type) and issubclass(value, cls)
def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER
result = 0
@ -135,14 +132,14 @@ def uint_to_sint(uint: int) -> int:
WIRE_TYPE_INT = 0
WIRE_TYPE_LENGTH = 2
WIRE_TYPES = {
"uint32": WIRE_TYPE_INT,
"uint64": WIRE_TYPE_INT,
"sint32": WIRE_TYPE_INT,
"sint64": WIRE_TYPE_INT,
"bool": WIRE_TYPE_INT,
"bytes": WIRE_TYPE_LENGTH,
"string": WIRE_TYPE_LENGTH,
PROTO_TYPES = {
"uint32": int,
"uint64": int,
"sint32": int,
"sint64": int,
"bool": bool,
"bytes": bytes,
"string": str,
}
REQUIRED_FIELD_PLACEHOLDER = object()
@ -151,50 +148,76 @@ REQUIRED_FIELD_PLACEHOLDER = object()
@dataclass
class Field:
name: str
type: str
proto_type: str
repeated: bool = False
required: bool = False
default: object = None
_py_type: type | None = None
_owner: type[MessageType] | None = None
@property
def py_type(self) -> type:
if self._py_type is None:
self._py_type = self._resolve_type()
# pyright issue https://github.com/microsoft/pyright/issues/8136
return self._py_type # type: ignore [Type ["Unknown | None"]]
def _resolve_type(self) -> type:
# look for a type in the builtins
py_type = PROTO_TYPES.get(self.proto_type)
if py_type is not None:
return py_type
# look for a type in the class locals
assert self._owner is not None, "Field is not owned by a MessageType"
py_type = self._owner.__dict__.get(self.proto_type)
if py_type is not None:
return py_type
# look for a type in the class globals
cls_module = sys.modules.get(self._owner.__module__, None)
cls_globals = getattr(cls_module, "__dict__", {})
py_type = cls_globals.get(self.proto_type)
if py_type is not None:
return py_type
raise TypeError(f"Could not resolve field type {self.proto_type}")
@property
def wire_type(self) -> int:
if self.type in WIRE_TYPES:
return WIRE_TYPES[self.type]
field_type_object = get_field_type_object(self)
if safe_issubclass(field_type_object, MessageType):
if issubclass(self.py_type, (MessageType, bytes, str)):
return WIRE_TYPE_LENGTH
if safe_issubclass(field_type_object, IntEnum):
if issubclass(self.py_type, int):
return WIRE_TYPE_INT
raise ValueError(f"Unrecognized type for field {self.name}")
def value_fits(self, value: int) -> bool:
if self.type == "uint32":
if self.proto_type == "uint32":
return 0 <= value < 2**32
if self.type == "uint64":
if self.proto_type == "uint64":
return 0 <= value < 2**64
if self.type == "sint32":
if self.proto_type == "sint32":
return -(2**31) <= value < 2**31
if self.type == "sint64":
if self.proto_type == "sint64":
return -(2**63) <= value < 2**63
raise ValueError(f"Cannot check range bounds for {self.type}")
raise ValueError(f"Cannot check range bounds for {self.proto_type}")
class _MessageTypeMeta(type):
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
super().__init__(name, bases, d)
if name != "MessageType":
cls.__init__ = MessageType.__init__ # type: ignore [Parameter]
class MessageType(metaclass=_MessageTypeMeta):
class MessageType:
MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None
FIELDS: t.ClassVar[dict[int, Field]] = {}
def __init_subclass__(cls) -> None:
super().__init_subclass__()
# override the generated __init__ methods by the parent method
cls.__init__ = MessageType.__init__
for field in cls.FIELDS.values():
field._owner = cls
@classmethod
def get_field(cls, name: str) -> Field | None:
return next((f for f in cls.FIELDS.values() if f.name == name), None)
@ -278,15 +301,6 @@ class CountingWriter:
return nwritten
def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
from . import messages
field_type_object = getattr(messages, field.type, None)
if not safe_issubclass(field_type_object, (IntEnum, MessageType)):
return None
return field_type_object
def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
assert field.repeated, "Not decoding packed array into non-repeated field"
length = load_uvarint(reader)
@ -304,33 +318,26 @@ def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
value = load_uvarint(reader)
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, IntEnum):
if issubclass(field.py_type, IntEnum):
try:
return field_type_object(value)
return field.py_type(value)
except ValueError as e:
# treat enum errors as warnings
LOG.info(f"On field {field.name}: {e}")
return value
if field.type.startswith("uint"):
if not field.value_fits(value):
LOG.info(
f"On field {field.name}: value {value} out of range for {field.type}"
)
return value
if field.type.startswith("sint"):
value = uint_to_sint(value)
if not field.value_fits(value):
LOG.info(
f"On field {field.name}: value {value} out of range for {field.type}"
)
return value
if field.type == "bool":
if issubclass(field.py_type, bool):
return bool(value)
if issubclass(field.py_type, int):
if field.proto_type.startswith("sint"):
value = uint_to_sint(value)
if not field.value_fits(value):
LOG.info(
f"On field {field.name}: value {value} out of range for {field.proto_type}"
)
return value
raise TypeError # not a varint field or unknown type
@ -341,19 +348,18 @@ def decode_length_delimited_field(
if value > MAX_FIELD_SIZE:
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
if field.type == "bytes":
if issubclass(field.py_type, bytes):
buf = bytearray(value)
reader.readinto(buf)
return bytes(buf)
if field.type == "string":
if issubclass(field.py_type, str):
buf = bytearray(value)
reader.readinto(buf)
return buf.decode()
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
return load_message(LimitedReader(reader, value), field_type_object)
if issubclass(field.py_type, MessageType):
return load_message(LimitedReader(reader, value), field.py_type)
raise TypeError # field type is unknown
@ -446,47 +452,41 @@ def dump_message(writer: Writer, msg: "MessageType") -> None:
for svalue in fvalue:
dump_uvarint(writer, fkey)
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
if not isinstance(svalue, field_type_object):
if issubclass(field.py_type, MessageType):
if not isinstance(svalue, field.py_type):
raise ValueError(
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
f"Value {svalue} in field {field.name} is not {field.py_type.__name__}"
)
counter = CountingWriter()
dump_message(counter, svalue)
dump_uvarint(writer, counter.size)
dump_message(writer, svalue)
elif safe_issubclass(field_type_object, IntEnum):
if svalue not in field_type_object.__members__.values():
elif issubclass(field.py_type, IntEnum):
if svalue not in field.py_type.__members__.values():
raise ValueError(
f"Value {svalue} in field {field.name} unknown for {field.type}"
f"Value {svalue} in field {field.name} unknown for {field.proto_type}"
)
dump_uvarint(writer, svalue)
elif field.type.startswith("uint"):
if not field.value_fits(svalue):
raise ValueError(
f"Value {svalue} in field {field.name} does not fit into {field.type}"
)
dump_uvarint(writer, svalue)
elif field.type.startswith("sint"):
if not field.value_fits(svalue):
raise ValueError(
f"Value {svalue} in field {field.name} does not fit into {field.type}"
)
dump_uvarint(writer, sint_to_uint(svalue))
elif field.type == "bool":
elif issubclass(field.py_type, bool):
dump_uvarint(writer, int(svalue))
elif field.type == "bytes":
elif issubclass(field.py_type, int):
if not field.value_fits(svalue):
raise ValueError(
f"Value {svalue} in field {field.name} does not fit into {field.proto_type}"
)
if field.proto_type.startswith("sint"):
svalue = sint_to_uint(svalue)
dump_uvarint(writer, svalue)
elif issubclass(field.py_type, bytes):
assert isinstance(svalue, (bytes, bytearray))
dump_uvarint(writer, len(svalue))
writer.write(svalue)
elif field.type == "string":
elif issubclass(field.py_type, str):
assert isinstance(svalue, str)
svalue_bytes = svalue.encode()
dump_uvarint(writer, len(svalue_bytes))
@ -549,9 +549,9 @@ def format_message(
field = pb.get_field(name)
if field is not None:
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
if isinstance(value, int) and issubclass(field.py_type, IntEnum):
try:
return f"{field.type(value).name} ({value})"
return f"{field.py_type(value).name} ({value})"
except ValueError:
return str(value)
@ -569,30 +569,20 @@ def format_message(
def value_to_proto(field: Field, value: t.Any) -> t.Any:
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
if issubclass(field.py_type, MessageType):
raise TypeError("value_to_proto only converts simple values")
if safe_issubclass(field_type_object, IntEnum):
if issubclass(field.py_type, IntEnum):
if isinstance(value, str):
return field_type_object.__members__[value]
return field.py_type.__members__[value]
else:
try:
return field_type_object(value)
return field.py_type(value)
except ValueError as e:
LOG.info(f"On field {field.name}: {e}")
return int(value)
if "int" in field.type:
return int(value)
if field.type == "bool":
return bool(value)
if field.type == "string":
return str(value)
if field.type == "bytes":
if issubclass(field.py_type, bytes):
if isinstance(value, str):
return bytes.fromhex(value)
elif isinstance(value, bytes):
@ -600,6 +590,8 @@ def value_to_proto(field: Field, value: t.Any) -> t.Any:
else:
raise TypeError(f"can't convert {type(value)} value to bytes")
return field.py_type(value)
def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
params = {}
@ -611,9 +603,8 @@ def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
if not field.repeated:
value = [value]
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
newvalue = [dict_to_proto(field_type_object, v) for v in value]
if issubclass(field.py_type, MessageType):
newvalue = [dict_to_proto(field.py_type, v) for v in value]
else:
newvalue = [value_to_proto(field, v) for v in value]

View File

@ -20,7 +20,7 @@ from io import BytesIO
import pytest
from trezorlib import messages, protobuf
from trezorlib import protobuf
class SomeEnum(IntEnum):
@ -94,19 +94,6 @@ class RecursiveMessage(protobuf.MessageType):
}
# message types are read from the messages module so we need to "include" these messages there for now
messages.SomeEnum = SomeEnum
messages.WiderEnum = WiderEnum
messages.NarrowerEnum = NarrowerEnum
messages.PrimitiveMessage = PrimitiveMessage
messages.EnumMessageMoreValues = EnumMessageMoreValues
messages.EnumMessageLessValues = EnumMessageLessValues
messages.RepeatedFields = RepeatedFields
messages.RequiredFields = RequiredFields
messages.DefaultFields = DefaultFields
messages.RecursiveMessage = RecursiveMessage
def load_uvarint(buffer):
reader = BytesIO(buffer)
return protobuf.load_uvarint(reader)

View File

@ -18,7 +18,7 @@ from enum import IntEnum
import pytest
from trezorlib import messages, protobuf
from trezorlib import protobuf
class SimpleEnum(IntEnum):
@ -55,18 +55,11 @@ class RequiredFields(protobuf.MessageType):
}
# message types are read from the messages module so we need to "include" these messages there for now
messages.SimpleEnum = SimpleEnum
messages.SimpleMessage = SimpleMessage
messages.NestedMessage = NestedMessage
messages.RequiredFields = RequiredFields
def test_get_field():
# smoke test
field = SimpleMessage.get_field("bool")
assert field.name == "bool"
assert field.type == "bool"
assert field.proto_type == "bool"
assert field.repeated is False
assert field.required is False
assert field.default is None
@ -91,6 +84,19 @@ def test_dict_roundtrip():
assert recovered == msg
def test_dict_to_proto_fresh():
class FreshMessage(protobuf.MessageType):
FIELDS = {
1: protobuf.Field("scalar", "uint64"),
2: protobuf.Field("nested", "SimpleMessage"),
}
dictdata = {"scalar": 5, "nested": {"uvarint": 5}}
recovered = protobuf.dict_to_proto(FreshMessage, dictdata)
assert recovered.scalar == 5
assert recovered.nested.uvarint == 5
def test_to_dict():
msg = SimpleMessage(
uvarint=5,
@ -204,7 +210,6 @@ def test_nested_recover():
assert isinstance(recovered.nested, SimpleMessage)
@pytest.mark.xfail(reason="formatting broken because of size counting")
def test_unknown_enum_to_str():
simple = SimpleMessage(enum=SimpleEnum.QUUX)
string = protobuf.format_message(simple)

View File

@ -124,13 +124,13 @@ def _make_bad_params():
if field.name in DRY_RUN_ALLOWED_FIELDS:
continue
if "int" in field.type:
if field.py_type is int:
yield field.name, 1
elif field.type == "bool":
elif field.py_type is bool:
yield field.name, True
elif field.type == "string":
elif field.py_type is str:
yield field.name, "test"
elif field.type == "RecoveryType":
elif field.py_type is messages.RecoveryType:
yield field.name, 1
else:
# Someone added a field to RecoveryDevice of a type that has no assigned