mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-18 21:48:13 +00:00
refactor(python/protobuf): allow field types imported in the same module
This commit is contained in:
parent
27fef37cc9
commit
cd55d32407
@ -25,12 +25,13 @@ For serializing (dumping) protobuf types, object with `Writer` interface is requ
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import typing as t
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from io import BytesIO
|
||||
from itertools import zip_longest
|
||||
import typing as t
|
||||
|
||||
import typing_extensions as tx
|
||||
|
||||
@ -62,10 +63,6 @@ _UVARINT_BUFFER = bytearray(1)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
|
||||
return isinstance(value, type) and issubclass(value, cls)
|
||||
|
||||
|
||||
def load_uvarint(reader: Reader) -> int:
|
||||
buffer = _UVARINT_BUFFER
|
||||
result = 0
|
||||
@ -135,14 +132,14 @@ def uint_to_sint(uint: int) -> int:
|
||||
WIRE_TYPE_INT = 0
|
||||
WIRE_TYPE_LENGTH = 2
|
||||
|
||||
WIRE_TYPES = {
|
||||
"uint32": WIRE_TYPE_INT,
|
||||
"uint64": WIRE_TYPE_INT,
|
||||
"sint32": WIRE_TYPE_INT,
|
||||
"sint64": WIRE_TYPE_INT,
|
||||
"bool": WIRE_TYPE_INT,
|
||||
"bytes": WIRE_TYPE_LENGTH,
|
||||
"string": WIRE_TYPE_LENGTH,
|
||||
PROTO_TYPES = {
|
||||
"uint32": int,
|
||||
"uint64": int,
|
||||
"sint32": int,
|
||||
"sint64": int,
|
||||
"bool": bool,
|
||||
"bytes": bytes,
|
||||
"string": str,
|
||||
}
|
||||
|
||||
REQUIRED_FIELD_PLACEHOLDER = object()
|
||||
@ -151,50 +148,76 @@ REQUIRED_FIELD_PLACEHOLDER = object()
|
||||
@dataclass
|
||||
class Field:
|
||||
name: str
|
||||
type: str
|
||||
proto_type: str
|
||||
repeated: bool = False
|
||||
required: bool = False
|
||||
default: object = None
|
||||
|
||||
_py_type: type | None = None
|
||||
_owner: type[MessageType] | None = None
|
||||
|
||||
@property
|
||||
def py_type(self) -> type:
|
||||
if self._py_type is None:
|
||||
self._py_type = self._resolve_type()
|
||||
# pyright issue https://github.com/microsoft/pyright/issues/8136
|
||||
return self._py_type # type: ignore [Type ["Unknown | None"]]
|
||||
|
||||
def _resolve_type(self) -> type:
|
||||
# look for a type in the builtins
|
||||
py_type = PROTO_TYPES.get(self.proto_type)
|
||||
if py_type is not None:
|
||||
return py_type
|
||||
|
||||
# look for a type in the class locals
|
||||
assert self._owner is not None, "Field is not owned by a MessageType"
|
||||
py_type = self._owner.__dict__.get(self.proto_type)
|
||||
if py_type is not None:
|
||||
return py_type
|
||||
|
||||
# look for a type in the class globals
|
||||
cls_module = sys.modules.get(self._owner.__module__, None)
|
||||
cls_globals = getattr(cls_module, "__dict__", {})
|
||||
py_type = cls_globals.get(self.proto_type)
|
||||
if py_type is not None:
|
||||
return py_type
|
||||
|
||||
raise TypeError(f"Could not resolve field type {self.proto_type}")
|
||||
|
||||
@property
|
||||
def wire_type(self) -> int:
|
||||
if self.type in WIRE_TYPES:
|
||||
return WIRE_TYPES[self.type]
|
||||
|
||||
field_type_object = get_field_type_object(self)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
if issubclass(self.py_type, (MessageType, bytes, str)):
|
||||
return WIRE_TYPE_LENGTH
|
||||
|
||||
if safe_issubclass(field_type_object, IntEnum):
|
||||
if issubclass(self.py_type, int):
|
||||
return WIRE_TYPE_INT
|
||||
|
||||
raise ValueError(f"Unrecognized type for field {self.name}")
|
||||
|
||||
def value_fits(self, value: int) -> bool:
|
||||
if self.type == "uint32":
|
||||
if self.proto_type == "uint32":
|
||||
return 0 <= value < 2**32
|
||||
if self.type == "uint64":
|
||||
if self.proto_type == "uint64":
|
||||
return 0 <= value < 2**64
|
||||
if self.type == "sint32":
|
||||
if self.proto_type == "sint32":
|
||||
return -(2**31) <= value < 2**31
|
||||
if self.type == "sint64":
|
||||
if self.proto_type == "sint64":
|
||||
return -(2**63) <= value < 2**63
|
||||
|
||||
raise ValueError(f"Cannot check range bounds for {self.type}")
|
||||
raise ValueError(f"Cannot check range bounds for {self.proto_type}")
|
||||
|
||||
|
||||
class _MessageTypeMeta(type):
|
||||
def __init__(cls, name: str, bases: tuple, d: dict) -> None:
|
||||
super().__init__(name, bases, d)
|
||||
if name != "MessageType":
|
||||
cls.__init__ = MessageType.__init__ # type: ignore [Parameter]
|
||||
|
||||
|
||||
class MessageType(metaclass=_MessageTypeMeta):
|
||||
class MessageType:
|
||||
MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None
|
||||
|
||||
FIELDS: t.ClassVar[dict[int, Field]] = {}
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
super().__init_subclass__()
|
||||
# override the generated __init__ methods by the parent method
|
||||
cls.__init__ = MessageType.__init__
|
||||
for field in cls.FIELDS.values():
|
||||
field._owner = cls
|
||||
|
||||
@classmethod
|
||||
def get_field(cls, name: str) -> Field | None:
|
||||
return next((f for f in cls.FIELDS.values() if f.name == name), None)
|
||||
@ -278,15 +301,6 @@ class CountingWriter:
|
||||
return nwritten
|
||||
|
||||
|
||||
def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
|
||||
from . import messages
|
||||
|
||||
field_type_object = getattr(messages, field.type, None)
|
||||
if not safe_issubclass(field_type_object, (IntEnum, MessageType)):
|
||||
return None
|
||||
return field_type_object
|
||||
|
||||
|
||||
def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
|
||||
assert field.repeated, "Not decoding packed array into non-repeated field"
|
||||
length = load_uvarint(reader)
|
||||
@ -304,33 +318,26 @@ def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
|
||||
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
|
||||
value = load_uvarint(reader)
|
||||
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, IntEnum):
|
||||
if issubclass(field.py_type, IntEnum):
|
||||
try:
|
||||
return field_type_object(value)
|
||||
return field.py_type(value)
|
||||
except ValueError as e:
|
||||
# treat enum errors as warnings
|
||||
LOG.info(f"On field {field.name}: {e}")
|
||||
return value
|
||||
|
||||
if field.type.startswith("uint"):
|
||||
if not field.value_fits(value):
|
||||
LOG.info(
|
||||
f"On field {field.name}: value {value} out of range for {field.type}"
|
||||
)
|
||||
return value
|
||||
|
||||
if field.type.startswith("sint"):
|
||||
value = uint_to_sint(value)
|
||||
if not field.value_fits(value):
|
||||
LOG.info(
|
||||
f"On field {field.name}: value {value} out of range for {field.type}"
|
||||
)
|
||||
return value
|
||||
|
||||
if field.type == "bool":
|
||||
if issubclass(field.py_type, bool):
|
||||
return bool(value)
|
||||
|
||||
if issubclass(field.py_type, int):
|
||||
if field.proto_type.startswith("sint"):
|
||||
value = uint_to_sint(value)
|
||||
if not field.value_fits(value):
|
||||
LOG.info(
|
||||
f"On field {field.name}: value {value} out of range for {field.proto_type}"
|
||||
)
|
||||
return value
|
||||
|
||||
raise TypeError # not a varint field or unknown type
|
||||
|
||||
|
||||
@ -341,19 +348,18 @@ def decode_length_delimited_field(
|
||||
if value > MAX_FIELD_SIZE:
|
||||
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")
|
||||
|
||||
if field.type == "bytes":
|
||||
if issubclass(field.py_type, bytes):
|
||||
buf = bytearray(value)
|
||||
reader.readinto(buf)
|
||||
return bytes(buf)
|
||||
|
||||
if field.type == "string":
|
||||
if issubclass(field.py_type, str):
|
||||
buf = bytearray(value)
|
||||
reader.readinto(buf)
|
||||
return buf.decode()
|
||||
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
return load_message(LimitedReader(reader, value), field_type_object)
|
||||
if issubclass(field.py_type, MessageType):
|
||||
return load_message(LimitedReader(reader, value), field.py_type)
|
||||
|
||||
raise TypeError # field type is unknown
|
||||
|
||||
@ -446,47 +452,41 @@ def dump_message(writer: Writer, msg: "MessageType") -> None:
|
||||
for svalue in fvalue:
|
||||
dump_uvarint(writer, fkey)
|
||||
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
if not isinstance(svalue, field_type_object):
|
||||
if issubclass(field.py_type, MessageType):
|
||||
if not isinstance(svalue, field.py_type):
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
|
||||
f"Value {svalue} in field {field.name} is not {field.py_type.__name__}"
|
||||
)
|
||||
counter = CountingWriter()
|
||||
dump_message(counter, svalue)
|
||||
dump_uvarint(writer, counter.size)
|
||||
dump_message(writer, svalue)
|
||||
|
||||
elif safe_issubclass(field_type_object, IntEnum):
|
||||
if svalue not in field_type_object.__members__.values():
|
||||
elif issubclass(field.py_type, IntEnum):
|
||||
if svalue not in field.py_type.__members__.values():
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} unknown for {field.type}"
|
||||
f"Value {svalue} in field {field.name} unknown for {field.proto_type}"
|
||||
)
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
elif field.type.startswith("uint"):
|
||||
if not field.value_fits(svalue):
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} does not fit into {field.type}"
|
||||
)
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
elif field.type.startswith("sint"):
|
||||
if not field.value_fits(svalue):
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} does not fit into {field.type}"
|
||||
)
|
||||
dump_uvarint(writer, sint_to_uint(svalue))
|
||||
|
||||
elif field.type == "bool":
|
||||
elif issubclass(field.py_type, bool):
|
||||
dump_uvarint(writer, int(svalue))
|
||||
|
||||
elif field.type == "bytes":
|
||||
elif issubclass(field.py_type, int):
|
||||
if not field.value_fits(svalue):
|
||||
raise ValueError(
|
||||
f"Value {svalue} in field {field.name} does not fit into {field.proto_type}"
|
||||
)
|
||||
if field.proto_type.startswith("sint"):
|
||||
svalue = sint_to_uint(svalue)
|
||||
dump_uvarint(writer, svalue)
|
||||
|
||||
elif issubclass(field.py_type, bytes):
|
||||
assert isinstance(svalue, (bytes, bytearray))
|
||||
dump_uvarint(writer, len(svalue))
|
||||
writer.write(svalue)
|
||||
|
||||
elif field.type == "string":
|
||||
elif issubclass(field.py_type, str):
|
||||
assert isinstance(svalue, str)
|
||||
svalue_bytes = svalue.encode()
|
||||
dump_uvarint(writer, len(svalue_bytes))
|
||||
@ -549,9 +549,9 @@ def format_message(
|
||||
|
||||
field = pb.get_field(name)
|
||||
if field is not None:
|
||||
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
|
||||
if isinstance(value, int) and issubclass(field.py_type, IntEnum):
|
||||
try:
|
||||
return f"{field.type(value).name} ({value})"
|
||||
return f"{field.py_type(value).name} ({value})"
|
||||
except ValueError:
|
||||
return str(value)
|
||||
|
||||
@ -569,30 +569,20 @@ def format_message(
|
||||
|
||||
|
||||
def value_to_proto(field: Field, value: t.Any) -> t.Any:
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
if issubclass(field.py_type, MessageType):
|
||||
raise TypeError("value_to_proto only converts simple values")
|
||||
|
||||
if safe_issubclass(field_type_object, IntEnum):
|
||||
if issubclass(field.py_type, IntEnum):
|
||||
if isinstance(value, str):
|
||||
return field_type_object.__members__[value]
|
||||
return field.py_type.__members__[value]
|
||||
else:
|
||||
try:
|
||||
return field_type_object(value)
|
||||
return field.py_type(value)
|
||||
except ValueError as e:
|
||||
LOG.info(f"On field {field.name}: {e}")
|
||||
return int(value)
|
||||
|
||||
if "int" in field.type:
|
||||
return int(value)
|
||||
|
||||
if field.type == "bool":
|
||||
return bool(value)
|
||||
|
||||
if field.type == "string":
|
||||
return str(value)
|
||||
|
||||
if field.type == "bytes":
|
||||
if issubclass(field.py_type, bytes):
|
||||
if isinstance(value, str):
|
||||
return bytes.fromhex(value)
|
||||
elif isinstance(value, bytes):
|
||||
@ -600,6 +590,8 @@ def value_to_proto(field: Field, value: t.Any) -> t.Any:
|
||||
else:
|
||||
raise TypeError(f"can't convert {type(value)} value to bytes")
|
||||
|
||||
return field.py_type(value)
|
||||
|
||||
|
||||
def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
|
||||
params = {}
|
||||
@ -611,9 +603,8 @@ def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
|
||||
if not field.repeated:
|
||||
value = [value]
|
||||
|
||||
field_type_object = get_field_type_object(field)
|
||||
if safe_issubclass(field_type_object, MessageType):
|
||||
newvalue = [dict_to_proto(field_type_object, v) for v in value]
|
||||
if issubclass(field.py_type, MessageType):
|
||||
newvalue = [dict_to_proto(field.py_type, v) for v in value]
|
||||
else:
|
||||
newvalue = [value_to_proto(field, v) for v in value]
|
||||
|
||||
|
@ -20,7 +20,7 @@ from io import BytesIO
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import messages, protobuf
|
||||
from trezorlib import protobuf
|
||||
|
||||
|
||||
class SomeEnum(IntEnum):
|
||||
@ -94,19 +94,6 @@ class RecursiveMessage(protobuf.MessageType):
|
||||
}
|
||||
|
||||
|
||||
# message types are read from the messages module so we need to "include" these messages there for now
|
||||
messages.SomeEnum = SomeEnum
|
||||
messages.WiderEnum = WiderEnum
|
||||
messages.NarrowerEnum = NarrowerEnum
|
||||
messages.PrimitiveMessage = PrimitiveMessage
|
||||
messages.EnumMessageMoreValues = EnumMessageMoreValues
|
||||
messages.EnumMessageLessValues = EnumMessageLessValues
|
||||
messages.RepeatedFields = RepeatedFields
|
||||
messages.RequiredFields = RequiredFields
|
||||
messages.DefaultFields = DefaultFields
|
||||
messages.RecursiveMessage = RecursiveMessage
|
||||
|
||||
|
||||
def load_uvarint(buffer):
|
||||
reader = BytesIO(buffer)
|
||||
return protobuf.load_uvarint(reader)
|
||||
|
@ -18,7 +18,7 @@ from enum import IntEnum
|
||||
|
||||
import pytest
|
||||
|
||||
from trezorlib import messages, protobuf
|
||||
from trezorlib import protobuf
|
||||
|
||||
|
||||
class SimpleEnum(IntEnum):
|
||||
@ -55,18 +55,11 @@ class RequiredFields(protobuf.MessageType):
|
||||
}
|
||||
|
||||
|
||||
# message types are read from the messages module so we need to "include" these messages there for now
|
||||
messages.SimpleEnum = SimpleEnum
|
||||
messages.SimpleMessage = SimpleMessage
|
||||
messages.NestedMessage = NestedMessage
|
||||
messages.RequiredFields = RequiredFields
|
||||
|
||||
|
||||
def test_get_field():
|
||||
# smoke test
|
||||
field = SimpleMessage.get_field("bool")
|
||||
assert field.name == "bool"
|
||||
assert field.type == "bool"
|
||||
assert field.proto_type == "bool"
|
||||
assert field.repeated is False
|
||||
assert field.required is False
|
||||
assert field.default is None
|
||||
@ -91,6 +84,19 @@ def test_dict_roundtrip():
|
||||
assert recovered == msg
|
||||
|
||||
|
||||
def test_dict_to_proto_fresh():
|
||||
class FreshMessage(protobuf.MessageType):
|
||||
FIELDS = {
|
||||
1: protobuf.Field("scalar", "uint64"),
|
||||
2: protobuf.Field("nested", "SimpleMessage"),
|
||||
}
|
||||
|
||||
dictdata = {"scalar": 5, "nested": {"uvarint": 5}}
|
||||
recovered = protobuf.dict_to_proto(FreshMessage, dictdata)
|
||||
assert recovered.scalar == 5
|
||||
assert recovered.nested.uvarint == 5
|
||||
|
||||
|
||||
def test_to_dict():
|
||||
msg = SimpleMessage(
|
||||
uvarint=5,
|
||||
@ -204,7 +210,6 @@ def test_nested_recover():
|
||||
assert isinstance(recovered.nested, SimpleMessage)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="formatting broken because of size counting")
|
||||
def test_unknown_enum_to_str():
|
||||
simple = SimpleMessage(enum=SimpleEnum.QUUX)
|
||||
string = protobuf.format_message(simple)
|
||||
|
@ -124,13 +124,13 @@ def _make_bad_params():
|
||||
if field.name in DRY_RUN_ALLOWED_FIELDS:
|
||||
continue
|
||||
|
||||
if "int" in field.type:
|
||||
if field.py_type is int:
|
||||
yield field.name, 1
|
||||
elif field.type == "bool":
|
||||
elif field.py_type is bool:
|
||||
yield field.name, True
|
||||
elif field.type == "string":
|
||||
elif field.py_type is str:
|
||||
yield field.name, "test"
|
||||
elif field.type == "RecoveryType":
|
||||
elif field.py_type is messages.RecoveryType:
|
||||
yield field.name, 1
|
||||
else:
|
||||
# Someone added a field to RecoveryDevice of a type that has no assigned
|
||||
|
Loading…
Reference in New Issue
Block a user