diff --git a/trezorlib/protobuf.py b/trezorlib/protobuf.py index 3b9d8dd28d..2edf2252c2 100644 --- a/trezorlib/protobuf.py +++ b/trezorlib/protobuf.py @@ -127,7 +127,10 @@ class UnicodeType: class MessageType: WIRE_TYPE = 2 - FIELDS = {} + + @classmethod + def get_fields(cls): + return {} def __init__(self, **kwargs): for kw in kwargs: @@ -171,7 +174,7 @@ class MessageType: def _additem(self, attr): # Add new item for repeated field type - for v in self.FIELDS.values(): + for v in self.get_fields().values(): if v[0] != attr: continue if not (v[2] & FLAG_REPEATED): @@ -191,7 +194,7 @@ class MessageType: def _fill_missing(self): # fill missing fields - for fname, ftype, fflags in self.FIELDS.values(): + for fname, ftype, fflags in self.get_fields().values(): if not hasattr(self, fname): if fflags & FLAG_REPEATED: setattr(self, fname, []) @@ -235,7 +238,7 @@ FLAG_REPEATED = 1 def load_message(reader, msg_type): - fields = msg_type.FIELDS + fields = msg_type.get_fields() msg = msg_type() while True: @@ -296,7 +299,7 @@ def load_message(reader, msg_type): def dump_message(writer, msg): repvalue = [0] mtype = msg.__class__ - fields = mtype.FIELDS + fields = mtype.get_fields() for ftag in fields: fname, ftype, fflags = fields[ftag] @@ -424,7 +427,7 @@ def value_to_proto(ftype, value): def dict_to_proto(message_type, d): params = {} - for fname, ftype, fflags in message_type.FIELDS.values(): + for fname, ftype, fflags in message_type.get_fields().values(): repeated = fflags & FLAG_REPEATED value = d.get(fname) if value is None: diff --git a/trezorlib/tests/unit_tests/test_protobuf.py b/trezorlib/tests/unit_tests/test_protobuf.py index d4299e6e2d..0ea04ea094 100644 --- a/trezorlib/tests/unit_tests/test_protobuf.py +++ b/trezorlib/tests/unit_tests/test_protobuf.py @@ -22,13 +22,15 @@ from trezorlib import protobuf class PrimitiveMessage(protobuf.MessageType): - FIELDS = { - 0: ("uvarint", protobuf.UVarintType, 0), - 1: ("svarint", protobuf.SVarintType, 0), - 2: ("bool", protobuf.BoolType, 0), - 3: ("bytes", protobuf.BytesType, 0), - 4: ("unicode", protobuf.UnicodeType, 0), - } + @classmethod + def get_fields(cls): + return { + 0: ("uvarint", protobuf.UVarintType, 0), + 1: ("svarint", protobuf.SVarintType, 0), + 2: ("bool", protobuf.BoolType, 0), + 3: ("bytes", protobuf.BytesType, 0), + 4: ("unicode", protobuf.UnicodeType, 0), + } def load_uvarint(buffer): diff --git a/vendor/trezor-common b/vendor/trezor-common index 4e7df217d3..f60b722638 160000 --- a/vendor/trezor-common +++ b/vendor/trezor-common @@ -1 +1 @@ -Subproject commit 4e7df217d337c8800fc91bac61038918bf3773e4 +Subproject commit f60b722638116a878d88b9f9393f311f8b45834e