diff --git a/tools/pb2py b/tools/pb2py index 3cda885686..3e2318dbb4 100755 --- a/tools/pb2py +++ b/tools/pb2py @@ -30,16 +30,34 @@ def remove_from_start(s, prefix): return s +def process_message_imports(descriptor): + imports = set() + + for field in descriptor.fields: + if field.type == field.TYPE_MESSAGE: + imports.add(field.message_type.name) + + for name in sorted(imports): + yield create_message_import(name) + + def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy): print(" * type %s" % descriptor.name) - imports = [] - out = ["", "", "class %s(p.MessageType):" % descriptor.name, ] + if is_upy: + yield "import protobuf as p" + else: + yield "from .. import protobuf as p" + yield from process_message_imports(descriptor) + + yield "" + yield "" + yield "class %s(p.MessageType):" % descriptor.name if descriptor.fields_by_number: - out.append(" FIELDS = {") + yield " FIELDS = {" elif msg_id is None: - out.append(" pass") + yield " pass" for number, field in descriptor.fields_by_number.items(): field_name = field.name @@ -60,7 +78,6 @@ def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy): if field.type == field.TYPE_MESSAGE: field_type = field.message_type.name - imports.append(create_message_import(field_type)) else: try: field_type = types[field.type] @@ -83,35 +100,23 @@ def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy): else: flags = '0' - out.append(" %d: ('%s', %s, %s),%s" % - (number, field_name, field_type, flags, comment)) + yield " %d: ('%s', %s, %s),%s" % (number, field_name, field_type, flags, comment) if descriptor.fields_by_name: - out.append(" }") + yield " }" if msg_id is not None: - out.append(" MESSAGE_WIRE_TYPE = %d" % msg_id) + yield " MESSAGE_WIRE_TYPE = %d" % msg_id if indexfile is not None: indexfile.write(create_const(t, msg_id, is_upy)) - # Remove duplicate imports - imports = sorted(list(set(imports))) - - if is_upy: - imports = ['import protobuf as p'] + imports - else: - imports = ['from .. import protobuf as p'] + imports - - return imports + out - def process_enum(descriptor, is_upy): - out = [] + print(" * enum %s" % descriptor.name) if is_upy: - out += ("from micropython import const", "") - - print(" * enum %s" % descriptor.name) + yield "from micropython import const" + yield "" for name, value in descriptor.values_by_name.items(): # Remove type name from the beginning of the constant @@ -125,9 +130,7 @@ def process_enum(descriptor, is_upy): enum_prefix, _ = enum_prefix.rsplit("Type", 1) name = remove_from_start(name, "%s_" % enum_prefix) - out.append(create_const(name, value.number, is_upy)) - - return out + yield create_const(name, value.number, is_upy) def process_file(descriptor, protobuf_module, genpath, indexfile, modlist, is_upy): @@ -144,7 +147,6 @@ def process_file(descriptor, protobuf_module, genpath, indexfile, modlist, is_up msg_id = None out = process_message(message_descriptor, protobuf_module, msg_id, indexfile, is_upy) - write_to_file(genpath, name, out) if modlist: modlist.write(create_message_import(name) + "\n") @@ -158,13 +160,10 @@ def process_file(descriptor, protobuf_module, genpath, indexfile, modlist, is_up def write_to_file(genpath, t, out): # Write generated sourcecode to given file - f = open(os.path.join(genpath, "%s.py" % t), 'w') - out = ["# Automatically generated by pb2py"] + out - - data = "\n".join(out) + "\n" - - f.write(data) - f.close() + with open(os.path.join(genpath, "%s.py" % t), 'w') as f: + f.write("# Automatically generated by pb2py\n") + for line in out: + f.write(line + "\n") if __name__ == '__main__':