mirror of
https://github.com/trezor/trezor-firmware.git
synced 2024-11-14 03:30:02 +00:00
7c348a1d79
[no changelog]
766 lines
24 KiB
Python
Executable File
766 lines
24 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# Converts Google's protobuf python definitions of Trezor wire messages
|
|
# to plain-python objects as used in Trezor Core and python-trezor
|
|
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import re
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
import click
|
|
import construct as c
|
|
import mako
|
|
import mako.template
|
|
from google.protobuf import descriptor_pb2
|
|
|
|
# Short alias; the TYPE_* and LABEL_* constants used below live on this class.
FieldDescriptor = descriptor_pb2.FieldDescriptorProto

# Protobuf field types that are encoded as integers.
INT_TYPES = (
    FieldDescriptor.TYPE_UINT64,
    FieldDescriptor.TYPE_UINT32,
    FieldDescriptor.TYPE_SINT64,
    FieldDescriptor.TYPE_SINT32,
)

# Scalar protobuf type -> name of the Python type used in generated code.
# Types missing here (enums, messages) fall back to the field's own type name,
# see ProtoField.python_type.
FIELD_TYPES_PYTHON = {
    **{int_type: "int" for int_type in INT_TYPES},
    FieldDescriptor.TYPE_BOOL: "bool",
    FieldDescriptor.TYPE_BYTES: "bytes",
    FieldDescriptor.TYPE_STRING: "str",
}

# Scalar protobuf type -> protobuf type name, used as ProtoField.type_name
# for fields that do not reference an enum or message.
TYPE_NAMES = {
    FieldDescriptor.TYPE_UINT64: "uint64",
    FieldDescriptor.TYPE_UINT32: "uint32",
    FieldDescriptor.TYPE_SINT64: "sint64",
    FieldDescriptor.TYPE_SINT32: "sint32",
    FieldDescriptor.TYPE_BOOL: "bool",
    FieldDescriptor.TYPE_BYTES: "bytes",
    FieldDescriptor.TYPE_STRING: "string",
}

# Protobuf type -> numeric code written into the 4-bit "type" member of
# FIELD_STRUCT by RustBlobRenderer.encode_field.
FIELD_TYPES_RUST_BLOB = {
    FieldDescriptor.TYPE_UINT64: 0,
    FieldDescriptor.TYPE_UINT32: 0,
    FieldDescriptor.TYPE_SINT64: 1,
    FieldDescriptor.TYPE_SINT32: 1,
    FieldDescriptor.TYPE_BOOL: 2,
    FieldDescriptor.TYPE_BYTES: 3,
    FieldDescriptor.TYPE_STRING: 4,
    FieldDescriptor.TYPE_ENUM: 5,
    FieldDescriptor.TYPE_MESSAGE: 6,
}

# Name of the top-level enum that lists the wire messages.
MESSAGE_TYPE_ENUM = "MessageType"

# Wire type 2 payload: a varint length followed by that many raw bytes.
LengthDelimited = c.Struct(
    "len" / c.VarInt,
    "bytes" / c.Bytes(c.this.len),
)

# A stream of (key, value) entries. The low three bits of the key select how
# the value is parsed: 0 -> varint, 2 -> length-delimited.
ListOfSimpleValues = c.GreedyRange(
    c.Struct(
        "key" / c.VarInt,
        "value" / c.Switch(c.this.key & 0b111, {0: c.VarInt, 2: LengthDelimited}),
    )
)
|
|
|
|
|
|
def parse_protobuf_simple(data):
    """Decode protobuf-encoded bytes into a ``{field number: value}`` dict.

    Only a tiny subset of the wire format is understood: every entry is
    assumed to be of wire type 0 (varint) or 2 (length-delimited). When a
    field number appears more than once, the last occurrence wins.
    """
    fields = {}
    for entry in ListOfSimpleValues.parse(data):
        # the field number sits above the three wire-type bits of the key
        fields[entry.key >> 3] = entry.value
    return fields
|
|
|
|
|
|
# Locate the protoc binary on PATH; the script cannot do anything without it.
PROTOC = shutil.which("protoc")
if not PROTOC:
    print("protoc command not found")
    sys.exit(1)

# Two directory levels above the resolved protoc binary; protoc() adds the
# "include" directory below this prefix to the include path.
PROTOC_PREFIX = Path(PROTOC).resolve().parent.parent


# --- binary layouts built by RustBlobRenderer ---

# One enum: a one-byte count followed by that many little-endian 16-bit values.
ENUM_ENTRY = c.PrefixedArray(c.Byte, c.Int16ul)

# Flag bits plus the 4-bit type code, packed into a single byte.
_FIELD_FLAGS_AND_TYPE = c.BitStruct(
    "is_required" / c.Flag,
    "is_repeated" / c.Flag,
    "is_experimental" / c.Flag,
    c.Padding(1),
    "type" / c.BitsInteger(4),
)

# One message field.
FIELD_STRUCT = c.Struct(
    "tag" / c.Byte,
    "flags_and_type" / _FIELD_FLAGS_AND_TYPE,
    "enum_or_msg_offset" / c.Int16ul,
    "name" / c.Int16ul,
)

# One message: counts are derived from the payload when building.
MSG_ENTRY = c.Struct(
    "fields_count" / c.Rebuild(c.Byte, c.len_(c.this.fields)),
    "defaults_size" / c.Rebuild(c.Byte, c.len_(c.this.defaults)),
    # highest bit = is_experimental
    # the rest = wire_id, 0x7FFF iff unset
    "flags_and_wire_type" / c.Int16ul,
    "fields" / c.Array(c.this.fields_count, FIELD_STRUCT),
    "defaults" / c.Bytes(c.this.defaults_size),
)

# Default value entries: a one-byte field number followed by the value,
# either as a varint or as varint-length-prefixed bytes.
DEFAULT_VARINT_ENTRY = c.Sequence(c.Byte, c.VarInt)
DEFAULT_LENGTH_ENTRY = c.Sequence(c.Byte, c.Prefixed(c.VarInt, c.GreedyRange(c.Byte)))

# Lookup table rows: (message name, message offset) ...
NAME_ENTRY = c.Sequence(
    "msg_name" / c.Int16ul,
    "msg_offset" / c.Int16ul,
)

# ... and (wire id, message offset).
WIRETYPE_ENTRY = c.Sequence(
    "wire_id" / c.Int16ul,
    "msg_offset" / c.Int16ul,
)
|
|
|
|
# Matches one generated Qstr definition line, for example:
#   QDEF(MP_QSTR_copysign, 5171, 8, "copysign")
# Groups: 1 = identifier suffix after "MP_QSTR", 2 = first number,
# 3 = second number, 4 = the string itself.
# The quantifier of group 3 is inside the parentheses so that the group holds
# the whole number, not only its last digit.
QDEF_RE = re.compile(
    r'^QDEF\(MP_QSTR(\S+), ([0-9]+), ([0-9]+), "(.*)"\)$'
)
|
|
|
|
|
|
@dataclass
class ProtoField:
    """A single message field together with the metadata the templates need."""

    name: str
    number: int
    type: object
    extensions: dict
    orig: object
    # name of type without package path
    type_name: str

    descriptor: "Descriptor"

    @property
    def repeated(self):
        """True when the field carries the REPEATED label."""
        label = self.orig.label
        return label == FieldDescriptor.LABEL_REPEATED

    @property
    def required(self):
        """True when the field carries the REQUIRED label."""
        label = self.orig.label
        return label == FieldDescriptor.LABEL_REQUIRED

    @property
    def optional(self):
        """True when the field is neither required nor repeated."""
        return not (self.required or self.repeated)

    @property
    def experimental(self):
        """True when a truthy "experimental_field" extension is set."""
        return True if self.extensions.get("experimental_field") else False

    @property
    def is_message(self):
        """True when the field holds a nested message."""
        return self.type == FieldDescriptor.TYPE_MESSAGE

    @property
    def is_enum(self):
        """True when the field holds an enum value."""
        return self.type == FieldDescriptor.TYPE_ENUM

    @property
    def python_type(self):
        """Python type name for scalars, otherwise the field's own type name."""
        return FIELD_TYPES_PYTHON.get(self.type, self.type_name)

    @property
    def default_value(self):
        """The default as stored in the field descriptor, or None when unset."""
        if self.orig.HasField("default_value"):
            return self.orig.default_value
        return None

    @property
    def default_value_repr(self):
        """The default value rendered as Python source text."""
        value = self.default_value
        if value is None:
            return "None"

        if self.is_enum:
            member = strip_enum_prefix(self.type_name, value)
            return f"{self.type_name}.{member}"

        if self.type == FieldDescriptor.TYPE_STRING:
            return repr(value)
        if self.type == FieldDescriptor.TYPE_BYTES:
            return "b" + repr(value)
        if self.type == FieldDescriptor.TYPE_BOOL:
            # the descriptor stores booleans as the text "true" / "false"
            return "True" if value == "true" else "False"
        return str(value)

    @property
    def type_object(self):
        """The enum or message this field refers to; None for scalar fields."""
        if self.is_enum:
            pool = self.descriptor.enums
        elif self.is_message:
            pool = self.descriptor.messages
        else:
            return None
        return find_by_name(pool, self.type_name)

    @classmethod
    def from_field(cls, descriptor, field):
        """Wrap a raw field descriptor found in ``descriptor``."""
        if field.type_name:
            # keep only the last component of the dotted type path
            short_type_name = field.type_name.split(".")[-1]
        else:
            short_type_name = TYPE_NAMES[field.type]

        return cls(
            name=field.name,
            number=field.number,
            type=field.type,
            orig=field,
            extensions=descriptor.get_extensions(field),
            type_name=short_type_name,
            descriptor=descriptor,
        )
|
|
|
|
|
|
@dataclass
class ProtoMessage:
    """A message definition together with its wire type and wrapped fields."""

    name: str
    wire_type: Optional[int]
    orig: object
    extensions: dict

    fields: List[ProtoField]

    @classmethod
    def from_message(cls, descriptor: "Descriptor", message):
        """Wrap a raw message descriptor found in ``descriptor``.

        The wire type is taken from the "wire_type" extension when present,
        otherwise from the number of the equally named MessageType entry,
        otherwise it is None.
        """
        type_entry = find_by_name(descriptor.message_type_enum.value, message.name)

        # Start with extensions of the MessageType entry (if any); extensions
        # set directly on the message take precedence.
        merged = descriptor.get_extensions(type_entry)
        merged.update(descriptor.get_extensions(message))

        wire_type = None
        if "wire_type" in merged:
            wire_type = merged["wire_type"]
        elif type_entry is not None:
            wire_type = type_entry.number

        wrapped_fields = [
            ProtoField.from_field(descriptor, raw_field)
            for raw_field in descriptor._filter_items(message.field)
        ]
        return cls(
            name=message.name,
            wire_type=wire_type,
            orig=message,
            extensions=merged,
            fields=wrapped_fields,
        )
|
|
|
|
|
|
def protoc(files):
    """Compile the given .proto files with protoc and return the raw output.

    The include path consists of, in this order: the "include" directory
    under PROTOC_PREFIX, the PROTOC_INCLUDE environment variable (if set),
    and the directory of every input file. Duplicates are dropped while the
    first-seen order is kept, so the search order handed to protoc is the
    same on every run.

    Returns the bytes protoc wrote to /dev/stdout via --descriptor_set_out.
    Raises subprocess.CalledProcessError when protoc exits with an error.
    """
    include_dirs = [str(PROTOC_PREFIX / "include")]
    if "PROTOC_INCLUDE" in os.environ:
        include_dirs.append(os.environ["PROTOC_INCLUDE"])
    include_dirs.extend(os.path.dirname(path) or "." for path in files)

    # dict.fromkeys() de-duplicates and, unlike a set, preserves order
    protoc_includes = [
        "-I" + include_dir for include_dir in dict.fromkeys(include_dirs) if include_dir
    ]

    return subprocess.check_output(
        [PROTOC, "--descriptor_set_out=/dev/stdout"] + protoc_includes + list(files)
    )
|
|
|
|
|
|
def strip_enum_prefix(enum_name, value_name):
    """Return the enum value name without a redundant enum-name prefix.

    Enum values in the codebase are written in one of three ways:

    (1) New-style, no prefix:

        enum SomeEnum {
            First_Value = 1;
            SecondValue = 2;
        }

    (2) Old-style, prefixed with the enum name:

        enum SomeEnum {
            SomeEnum_First_Value = 1;
            SomeEnum_SecondValue = 2;
        }

    (3) Old-style, where the enum name has a "Type" suffix that the value
        prefix lacks:

        enum SomeEnumType {
            SomeEnum_First_Value = 1;
            SomeEnum_SecondValue = 2;
        }

    Given the enum name and a value name, the new-style value name is
    returned: "First_Value" and "SecondValue" for all of the above.
    """
    prefixes = [enum_name + "_"]
    if enum_name.endswith("Type"):
        prefixes.append(enum_name[: -len("Type")] + "_")

    # the full enum name is tried before the variant without "Type"
    for prefix in prefixes:
        if value_name.startswith(prefix):
            return value_name[len(prefix) :]

    return value_name
|
|
|
|
|
|
def find_by_name(haystack, name, default=None):
    """Return the first item of ``haystack`` whose ``.name`` equals ``name``.

    ``default`` is returned when no item matches.
    """
    for candidate in haystack:
        if candidate.name == name:
            return candidate
    return default
|
|
|
|
|
|
class Descriptor:
    """Parsed FileDescriptorSet with the messages and enums selected for output.

    Attributes set by the constructor:
      descriptor         -- the parsed descriptor_pb2.FileDescriptorSet
      files              -- file descriptors (only bitcoin-only ones if requested)
      extensions         -- extension name -> extension field number
      message_type_enum  -- the MessageType enum descriptor (empty if absent)
      messages           -- list of ProtoMessage, including nested messages
      enums              -- list of enum descriptors, including nested enums
    """

    def __init__(self, data, include_deprecated: bool, bitcoin_only: bool):
        """Parse ``data`` (a serialized FileDescriptorSet) and collect items.

        Raises RuntimeError when neither messages nor enums are found.
        """
        self.descriptor = descriptor_pb2.FileDescriptorSet()
        self.descriptor.ParseFromString(data)

        self.include_deprecated = include_deprecated
        self.bitcoin_only = bitcoin_only
        self.files = self.descriptor.file

        logging.debug(f"found {len(self.files)} files")

        # collect extensions across all files
        # this is required for self._get_extension() to work
        self.extensions = {
            ext.name: ext.number for file in self.files for ext in file.extension
        }

        if self.bitcoin_only:
            self.files = [
                f
                for f in self.files
                if self.get_extensions(f).get("include_in_bitcoin_only")
            ]
            logging.debug(f"found {len(self.files)} bitcoin-only files")

        # find message_type enum
        top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files)
        self.message_type_enum = find_by_name(top_level_enums, MESSAGE_TYPE_ENUM)
        if self.message_type_enum is None:
            # None of the selected files defines MessageType. Substitute an
            # empty enum so that `.value` lookups (here and in
            # ProtoMessage.from_message) find nothing instead of crashing.
            self.message_type_enum = descriptor_pb2.EnumDescriptorProto(
                name=MESSAGE_TYPE_ENUM
            )
        self.convert_enum_value_names(self.message_type_enum)

        # find messages and enums
        self.messages = []
        self.enums = []

        for file in self.files:
            messages = [
                ProtoMessage.from_message(self, m)
                for m in self._filter_items(file.message_type)
            ]
            self.messages += messages
            self.enums += self._filter_items(file.enum_type)

            for message in messages:
                # recursively search for nested types in newly added messages
                self._nested_types_from_message(message.orig)

        if not self.messages and not self.enums:
            raise RuntimeError("No messages and no enums found.")

        for enum in self.enums:
            self.convert_enum_value_names(enum)

    def _filter_items(self, iter):
        """Return the items of ``iter`` as a list, dropping deprecated ones
        unless ``include_deprecated`` is set."""
        return [
            item
            for item in iter
            # exclude deprecated items unless specified
            if (self.include_deprecated or not item.options.deprecated)
        ]

    def _get_extension(self, something, extension_name, default=None):
        """Return the value of one extension set on ``something.options``.

        ``default`` is returned when ``something`` is None, when the extension
        name is unknown, or when the extension is not set.
        """
        if something is None:
            return default
        if extension_name not in self.extensions:
            return default

        # There doesn't seem to be a sane way to access extensions on a descriptor
        # via the google.protobuf API.
        # We do have access to descriptors of the extensions...
        extension_num = self.extensions[extension_name]
        # ...and the "options" descriptor _does_ include the extension data. But while
        # the API provides access to unknown fields, it hides the extensions.
        # What we do is re-encode the options descriptor...
        options_bytes = something.options.SerializeToString()
        # ...and re-parse it as a dict of uvarints/strings...
        simple_values = parse_protobuf_simple(options_bytes)
        # ...and extract the value corresponding to the extension we care about.
        return simple_values.get(extension_num, default)

    def get_extensions(self, something):
        """Return a dict of all known extensions that are set on ``something``.

        Uses the same re-encode / re-parse approach as _get_extension(), but
        serializes and parses the options only once for all extensions.
        An empty dict is returned when ``something`` is None.
        """
        if something is None or not self.extensions:
            return {}

        simple_values = parse_protobuf_simple(something.options.SerializeToString())
        found = {}
        for extension_name, extension_num in self.extensions.items():
            value = simple_values.get(extension_num)
            if value is not None:
                found[extension_name] = value
        return found

    def _nested_types_from_message(self, message):
        """Recursively add messages and enums nested inside ``message``."""
        nested_messages = [
            ProtoMessage.from_message(self, m)
            for m in self._filter_items(message.nested_type)
        ]
        self.messages += nested_messages
        self.enums += self._filter_items(message.enum_type)
        for nested in nested_messages:
            self._nested_types_from_message(nested.orig)

    def convert_enum_value_names(self, enum):
        """Rename the values of ``enum`` in place to their prefix-less form."""
        for value in enum.value:
            value.name = strip_enum_prefix(enum.name, value.name)
|
|
|
|
|
|
class PythonRenderer:
    """Renders messages and enums of a Descriptor through Mako templates."""

    def __init__(self, descriptor: Descriptor, out_dir="", python_extension="py"):
        self.descriptor = descriptor
        self.out_dir = Path(out_dir)
        self.python_extension = python_extension

    def process_message(self, template, message):
        """Render ``template`` for one message and return the text."""
        logging.debug(f"Processing message {message.name}")
        return template.render(message=message)

    def process_enum(self, template, enum):
        """Render ``template`` for one enum and return the text.

        If the enum has the "has_bitcoin_only_values" extension set, its values
        are split into those marked "bitcoin_only" (``values_always``) and the
        remaining ones (``values_altcoin``). Otherwise every value goes to
        ``values_always``.
        """
        logging.debug(f"Processing enum {enum.name}")

        get_ext = self.descriptor._get_extension
        candidates = self.descriptor._filter_items(enum.value)

        if get_ext(enum, "has_bitcoin_only_values"):
            always = [value for value in candidates if get_ext(value, "bitcoin_only")]
            altcoin = [value for value in candidates if value not in always]
        else:
            always, altcoin = candidates, []

        return template.render(
            enum=enum,
            values_always=always,
            values_altcoin=altcoin,
        )

    def write_to_file(self, item_name, content):
        """Write ``content`` to ``<out_dir>/<item_name>.<python_extension>``."""
        file_name = ".".join((item_name, self.python_extension))
        (self.out_dir / file_name).write_text(content)

    def generate_messages(self, template_src):
        """Write one file per message, rendered from the template file."""
        template = mako.template.Template(filename=str(template_src))
        for message in self.descriptor.messages:
            rendered = self.process_message(template, message)
            self.write_to_file(message.name, rendered)

    def generate_enums(self, template_src):
        """Write one file per enum, rendered from the template file."""
        template = mako.template.Template(filename=str(template_src))
        for enum in self.descriptor.enums:
            rendered = self.process_enum(template, enum)
            self.write_to_file(enum.name, rendered)

    def render_singlefile(self, template_src):
        """Render all messages and enums through one template; return the text."""
        template = mako.template.Template(filename=str(template_src))
        context = dict(
            messages=self.descriptor.messages,
            enums=self.descriptor.enums,
        )
        return template.render(**context)

    def generate_python(self):
        """Generate files for every known template that exists in ``out_dir``.

        Each of the enum, message and __init__ templates is optional; a
        missing template simply skips that part.
        """
        enum_template = self.out_dir / "_proto_enum_class.mako"
        if enum_template.exists():
            self.generate_enums(enum_template)

        message_template = self.out_dir / "_proto_message_class.mako"
        if message_template.exists():
            self.generate_messages(message_template)

        init_template = self.out_dir / "_proto_init.mako"
        if init_template.exists():
            self.write_to_file("__init__", self.render_singlefile(init_template))
|
|
|
|
|
|
class RustBlobRenderer:
    """Builds the binary definition blobs and the Qstr list from a Descriptor.

    Attributes:
      qstr_map -- Qstr string -> numeric Qstr value (filled by build_qstr_map)
      enum_map -- enum name -> byte offset in the enum blob
      msg_map  -- message name -> byte offset in the message blob
    """

    def __init__(self, descriptor: "Descriptor", qstr_defs: Optional[str] = None):
        """Remember ``descriptor``; load Qstr values when ``qstr_defs`` is given."""
        self.descriptor = descriptor

        self.qstr_map = {}
        self.enum_map = {}
        self.msg_map = {}

        if qstr_defs:
            self.build_qstr_map(qstr_defs)

    def write_qstrs(self, qstr_path):
        """Write a sorted ``Q(name)`` line for every message and field name."""
        logging.debug(f"Writing qstrings to {qstr_path}")
        message_names = {m.name for m in self.descriptor.messages}
        field_names = {
            f.name for message in self.descriptor.messages for f in message.fields
        }
        with open(qstr_path, "w") as f:
            for name in sorted(message_names | field_names):
                f.write(f"Q({name})\n")

    def write_blobs(self, blob_dir):
        """Build all four blobs and write them into ``blob_dir``."""
        logging.debug(f"Writing blobs to {blob_dir}")
        blob_dir = Path(blob_dir)

        enum_blob = self.build_enums_with_offsets()

        # build msg entries and fill out map
        msg_entries = self.build_message_entries()
        # fill message offsets
        self.fill_enum_or_msg_offsets(msg_entries)
        # encode blob
        msg_blob = self.build_message_blob(msg_entries)

        name_blob = self.build_blob_names()
        wire_blob = self.build_blob_wire()

        (blob_dir / "proto_enums.data").write_bytes(enum_blob)
        (blob_dir / "proto_msgs.data").write_bytes(msg_blob)
        (blob_dir / "proto_names.data").write_bytes(name_blob)
        (blob_dir / "proto_wire.data").write_bytes(wire_blob)

    def build_qstr_map(self, qstr_defs):
        """Fill ``qstr_map`` from the QDEF lines of the file ``qstr_defs``."""
        # QSTR defs are rolled out into an enum in py/qstr.h, the numeric
        # value is simply an incremented integer.
        qstr_counter = 0
        with open(qstr_defs, "r") as f:
            for line in f:
                match = QDEF_RE.match(line)
                if not match:
                    # only matching lines consume a Qstr number
                    continue
                string = match.group(4)
                self.qstr_map[string] = qstr_counter
                qstr_counter += 1
        logging.debug(f"Found {qstr_counter} Qstr defs")

    def build_enums_with_offsets(self):
        """Return the enum blob and record each enum's offset in ``enum_map``.

        Enums are emitted sorted by name, each as the sorted list of its
        value numbers. The MessageType enum is left out.
        """
        enums = []
        cursor = 0
        for enum in sorted(self.descriptor.enums, key=lambda e: e.name):
            if enum.name == MESSAGE_TYPE_ENUM:
                continue
            self.enum_map[enum.name] = cursor
            enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value))
            enums.append(enum_blob)
            cursor += len(enum_blob)

        return b"".join(enums)

    def encode_flags_and_wire_type(self, message):
        """Return the 16-bit flags_and_wire_type value for ``message``.

        The low 15 bits hold the wire type, with 0x7FFF meaning "unset"; the
        highest bit is set for messages with a truthy "experimental_message"
        extension.

        Raises ValueError for a wire type that does not fit: negative, larger
        than 15 bits, or equal to 0x7FFF, which would be indistinguishable
        from "unset".
        """
        wire_type = message.wire_type
        if wire_type is None:
            wire_type = 0x7FFF
        elif not 0 <= wire_type < 0x7FFF:
            raise ValueError("Unsupported wire type")

        flags_and_wire_type = wire_type
        if message.extensions.get("experimental_message"):
            flags_and_wire_type |= 0x8000

        return flags_and_wire_type

    def encode_field(self, field):
        """Return the FIELD_STRUCT dict for ``field``.

        The offset is a placeholder filled in by fill_enum_or_msg_offsets();
        ``orig_field`` is kept in the dict so that step can find the field.
        Raises KeyError when the field name has no Qstr or the field type has
        no blob type code.
        """
        return dict(
            tag=field.number,
            flags_and_type=dict(
                is_required=field.required,
                is_repeated=field.repeated,
                is_experimental=field.experimental,
                type=FIELD_TYPES_RUST_BLOB[field.type],
            ),
            enum_or_msg_offset=0,
            name=self.qstr_map[field.name],
            orig_field=field,
        )

    def fill_enum_or_msg_offsets(self, msg_entries):
        """Set the offset of every enum or message field in ``msg_entries``.

        Must run after ``enum_map`` and ``msg_map`` have been filled.
        """
        for msg_dict in msg_entries:
            for field_dict in msg_dict["fields"]:
                field = field_dict["orig_field"]
                if field.is_enum:
                    field_dict["enum_or_msg_offset"] = self.enum_map[field.type_name]
                elif field.is_message:
                    field_dict["enum_or_msg_offset"] = self.msg_map[field.type_name]

    def build_message_entries(self):
        """Return MSG_ENTRY dicts for all messages and fill ``msg_map``.

        Messages are sorted by name and their fields by field number. Each
        entry is built once here only to measure its size for the offsets.
        """
        messages = []
        cursor = 0

        for message in sorted(self.descriptor.messages, key=lambda m: m.name):
            self.msg_map[message.name] = cursor
            fields = sorted(message.fields, key=lambda f: f.number)

            defaults = b"".join(self.encode_field_default(f) for f in fields)
            flags_and_wire_type = self.encode_flags_and_wire_type(message)
            entry = dict(
                flags_and_wire_type=flags_and_wire_type,
                fields=[self.encode_field(f) for f in fields],
                defaults=defaults,
            )

            messages.append(entry)
            cursor += len(MSG_ENTRY.build(entry))

        return messages

    def build_message_blob(self, msg_entries):
        """Return the concatenation of all built message entries."""
        return b"".join(MSG_ENTRY.build(entry) for entry in msg_entries)

    def encode_field_default(self, field):
        """Return the encoded default value of ``field``; empty if it has none.

        Raises ValueError when the field number does not fit into one byte,
        when a bytes field has a non-empty default, when an enum default names
        an unknown value, or when the field type cannot carry a default.
        """
        if field.number > 0xFF:
            raise ValueError("Invalid field number")

        default = field.default_value

        if default is None:
            return b""

        elif field.type in INT_TYPES:
            return DEFAULT_VARINT_ENTRY.build((field.number, int(default)))

        elif field.type == FieldDescriptor.TYPE_BOOL:
            # the descriptor stores booleans as the text "true" / "false"
            return DEFAULT_VARINT_ENTRY.build((field.number, int(default == "true")))

        elif field.type == FieldDescriptor.TYPE_BYTES:
            if default != "":
                raise ValueError(
                    "Bytes fields can only have empty bytes for default value"
                )
            return DEFAULT_LENGTH_ENTRY.build((field.number, b""))

        elif field.type == FieldDescriptor.TYPE_STRING:
            return DEFAULT_LENGTH_ENTRY.build((field.number, default.encode()))

        elif field.is_enum:
            # find the right value
            value = find_by_name(field.type_object.value, default)
            if value is None:
                raise ValueError(f"Default not found for field {field.name}")
            return DEFAULT_VARINT_ENTRY.build((field.number, value.number))

        else:
            raise ValueError(f"Cannot encode default value for field {field.name}")

    def build_blob_names(self):
        """Return the (name Qstr, message offset) table, sorted by Qstr value."""
        # sorting by Qstr value of the message name
        messages = sorted(self.descriptor.messages, key=lambda m: self.qstr_map[m.name])
        return b"".join(
            NAME_ENTRY.build((self.qstr_map[m.name], self.msg_map[m.name]))
            for m in messages
        )

    def build_blob_wire(self):
        """Return the (wire type, message offset) table, sorted by wire type.

        Messages without a wire type are left out.
        """
        # create wire-type -> message mapping
        wire_messages = [m for m in self.descriptor.messages if m.wire_type is not None]
        # sorting by wire-type
        wire_messages.sort(key=lambda m: m.wire_type)
        return b"".join(
            WIRETYPE_ENTRY.build((m.wire_type, self.msg_map[m.name]))
            for m in wire_messages
        )
|
|
|
|
|
|
# Reusable click parameter types for the command-line options below.
ReadableFile = click.Path(exists=True, dir_okay=False, readable=True)
WritableFile = click.Path(dir_okay=False, writable=True)
WritableDirectory = click.Path(exists=True, file_okay=False, writable=True)


@click.command()
# fmt: off
@click.argument("proto", nargs=-1, type=ReadableFile, required=True)
@click.option("--python-outdir", type=WritableDirectory, help="Output directory for Python classes (contents will be deleted)")
@click.option("--python-extension", default="py", help="Use .pyi to generate type stubs")
@click.option("--outfile", type=WritableFile, help="Output file for single-file generated definitions")
@click.option("--template", type=ReadableFile, help="Template for single-file entry")
@click.option("--blob-outdir", type=WritableDirectory, help="Output directory for protobuf blob files")
@click.option("--qstr-defs", type=ReadableFile, help="Collected Qstr definitions")
@click.option("--qstr-out", type=WritableFile, help="Output Qstr header")
@click.option("-v", "--verbose", is_flag=True)
@click.option("-d", "--include-deprecated", is_flag=True, help="Include deprecated fields, messages and enums")
@click.option("-b", "--bitcoin-only", type=int, default=0, help="Exclude fields, messages and enums that do not belong to bitcoin_only builds")
# fmt: on
def main(
    proto,
    python_outdir,
    python_extension,
    outfile,
    template,
    blob_outdir,
    qstr_defs,
    qstr_out,
    verbose,
    include_deprecated,
    bitcoin_only,
):
    # Compile the proto files once, then produce every requested output.
    if verbose:
        logging.basicConfig(level=logging.DEBUG)

    descriptor = Descriptor(
        protoc(proto),
        include_deprecated=include_deprecated,
        bitcoin_only=bitcoin_only,
    )

    if python_outdir:
        target_dir = Path(python_outdir)
        with tempfile.TemporaryDirectory() as scratch:
            scratch_dir = Path(scratch)

            # Render into a scratch directory first, so the target directory
            # is only modified after generation has succeeded.
            for template_file in target_dir.glob("_proto*.mako"):
                shutil.copy(template_file, scratch)

            PythonRenderer(descriptor, scratch_dir, python_extension).generate_python()

            # Remove previously generated files; __init__ is not deleted here.
            init_name = "__init__." + python_extension
            for stale in target_dir.glob("*." + python_extension):
                if stale.name != init_name:
                    stale.unlink()

            for produced in scratch_dir.iterdir():
                shutil.copy(produced, target_dir)

    if outfile:
        if not template:
            raise click.ClickException("Please specify --template")

        single_file_text = PythonRenderer(descriptor).render_singlefile(template)
        with open(outfile, "w") as f:
            f.write(single_file_text)

    if qstr_out:
        RustBlobRenderer(descriptor).write_qstrs(qstr_out)

    if blob_outdir:
        if not qstr_defs:
            raise click.ClickException("Qstr defs not provided")

        RustBlobRenderer(descriptor, qstr_defs).write_blobs(blob_outdir)


if __name__ == "__main__":
    main()
|