pw_protobuf: Modularize codegen plugin
This change splits the pw_protobuf codegen plugin into three parts: the
ProtoNode tree, pwpb-specific code generation, and the compiler plugin.
This is done so that parts of the code can be reused in other protobuf
compiler plugins.
Change-Id: I36ff2e0970f41783135cfb10fc8d3bf8701ec7d0
diff --git a/pw_protobuf/BUILD.gn b/pw_protobuf/BUILD.gn
index 9e37211..f5f05b9 100644
--- a/pw_protobuf/BUILD.gn
+++ b/pw_protobuf/BUILD.gn
@@ -58,9 +58,9 @@
# Source files for pw_protobuf's protoc plugin.
pw_input_group("codegen_protoc_plugin") {
inputs = [
- "py/pw_protobuf/codegen.py",
- "py/pw_protobuf/methods.py",
- "py/pw_protobuf/proto_structures.py",
+ "py/pw_protobuf/codegen_pwpb.py",
+ "py/pw_protobuf/plugin.py",
+ "py/pw_protobuf/proto_tree.py",
]
}
diff --git a/pw_protobuf/py/pw_protobuf/codegen.py b/pw_protobuf/py/pw_protobuf/codegen.py
deleted file mode 100755
index a9329b5..0000000
--- a/pw_protobuf/py/pw_protobuf/codegen.py
+++ /dev/null
@@ -1,433 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2020 The Pigweed Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not
-# use this file except in compliance with the License. You may obtain a copy of
-# the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations under
-# the License.
-"""pw_protobuf compiler plugin.
-
-This file implements a protobuf compiler plugin which generates C++ headers for
-protobuf messages in the pw_protobuf format.
-"""
-
-import os
-import sys
-
-from typing import Iterable, List
-
-import google.protobuf.compiler.plugin_pb2 as plugin_pb2
-import google.protobuf.descriptor_pb2 as descriptor_pb2
-
-from pw_protobuf.methods import PROTO_FIELD_METHODS
-from pw_protobuf.proto_structures import ProtoEnum, ProtoExternal, ProtoMessage
-from pw_protobuf.proto_structures import ProtoMessageField, ProtoNode
-from pw_protobuf.proto_structures import ProtoPackage
-
-PLUGIN_NAME = 'pw_protobuf'
-PLUGIN_VERSION = '0.1.0'
-
-PROTO_H_EXTENSION = '.pwpb.h'
-PROTO_CC_EXTENSION = '.pwpb.cc'
-
-PROTOBUF_NAMESPACE = 'pw::protobuf'
-BASE_PROTO_CLASS = 'ProtoMessageEncoder'
-
-
-# protoc captures stdout, so we need to printf debug to stderr.
-def debug_print(*args, **kwargs):
- print(*args, file=sys.stderr, **kwargs)
-
-
-class OutputFile:
- """A buffer to which data is written.
-
- Example:
-
- ```
- output = Output("hello.c")
- output.write_line('int main(void) {')
- with output.indent():
- output.write_line('printf("Hello, world");')
- output.write_line('return 0;')
- output.write_line('}')
-
- print(output.content())
- ```
-
- Produces:
- ```
- int main(void) {
- printf("Hello, world");
- return 0;
- }
- ```
- """
-
- INDENT_WIDTH = 2
-
- def __init__(self, filename: str):
- self._filename: str = filename
- self._content: List[str] = []
- self._indentation: int = 0
-
- def write_line(self, line: str = '') -> None:
- if line:
- self._content.append(' ' * self._indentation)
- self._content.append(line)
- self._content.append('\n')
-
- def indent(self) -> 'OutputFile._IndentationContext':
- """Increases the indentation level of the output."""
- return self._IndentationContext(self)
-
- def name(self) -> str:
- return self._filename
-
- def content(self) -> str:
- return ''.join(self._content)
-
- class _IndentationContext:
- """Context that increases the output's indentation when it is active."""
- def __init__(self, output: 'OutputFile'):
- self._output = output
-
- def __enter__(self):
- self._output._indentation += OutputFile.INDENT_WIDTH
-
- def __exit__(self, typ, value, traceback):
- self._output._indentation -= OutputFile.INDENT_WIDTH
-
-
-def generate_code_for_message(message: ProtoNode, root: ProtoNode,
- output: OutputFile) -> None:
- """Creates a C++ class for a protobuf message."""
- assert message.type() == ProtoNode.Type.MESSAGE
-
- # Message classes inherit from the base proto message class in codegen.h
- # and use its constructor.
- base_class = f'{PROTOBUF_NAMESPACE}::{BASE_PROTO_CLASS}'
- output.write_line(
- f'class {message.cpp_namespace(root)}::Encoder : public {base_class} {{'
- )
- output.write_line(' public:')
-
- with output.indent():
- output.write_line(f'using {BASE_PROTO_CLASS}::{BASE_PROTO_CLASS};')
-
- # Generate methods for each of the message's fields.
- for field in message.fields():
- for method_class in PROTO_FIELD_METHODS[field.type()]:
- method = method_class(field, message, root)
- if not method.should_appear():
- continue
-
- output.write_line()
- method_signature = (
- f'{method.return_type()} '
- f'{method.name()}({method.param_string()})')
-
- if not method.in_class_definition():
- # Method will be defined outside of the class at the end of
- # the file.
- output.write_line(f'{method_signature};')
- continue
-
- output.write_line(f'{method_signature} {{')
- with output.indent():
- for line in method.body():
- output.write_line(line)
- output.write_line('}')
-
- output.write_line('};')
-
-
-def define_not_in_class_methods(message: ProtoNode, root: ProtoNode,
- output: OutputFile) -> None:
- """Defines methods for a message class that were previously declared."""
- assert message.type() == ProtoNode.Type.MESSAGE
-
- for field in message.fields():
- for method_class in PROTO_FIELD_METHODS[field.type()]:
- method = method_class(field, message, root)
- if not method.should_appear() or method.in_class_definition():
- continue
-
- output.write_line()
- class_name = f'{message.cpp_namespace(root)}::Encoder'
- method_signature = (
- f'inline {method.return_type(from_root=True)} '
- f'{class_name}::{method.name()}({method.param_string()})')
- output.write_line(f'{method_signature} {{')
- with output.indent():
- for line in method.body():
- output.write_line(line)
- output.write_line('}')
-
-
-def generate_code_for_enum(enum: ProtoNode, root: ProtoNode,
- output: OutputFile) -> None:
- """Creates a C++ enum for a proto enum."""
- assert enum.type() == ProtoNode.Type.ENUM
-
- output.write_line(f'enum class {enum.cpp_namespace(root)} {{')
- with output.indent():
- for name, number in enum.values():
- output.write_line(f'{name} = {number},')
- output.write_line('};')
-
-
-def forward_declare(node: ProtoNode, root: ProtoNode,
- output: OutputFile) -> None:
- """Generates code forward-declaring entities in a message's namespace."""
- if node.type() != ProtoNode.Type.MESSAGE:
- return
-
- namespace = node.cpp_namespace(root)
- output.write_line()
- output.write_line(f'namespace {namespace} {{')
-
- # Define an enum defining each of the message's fields and their numbers.
- output.write_line('enum class Fields {')
- with output.indent():
- for field in node.fields():
- output.write_line(f'{field.enum_name()} = {field.number()},')
- output.write_line('};')
-
- # Declare the message's encoder class and all of its enums.
- output.write_line()
- output.write_line('class Encoder;')
- for child in node.children():
- if child.type() == ProtoNode.Type.ENUM:
- output.write_line()
- generate_code_for_enum(child, node, output)
-
- output.write_line(f'}} // namespace {namespace}')
-
-
-def _proto_filename_to_generated_header(proto_file: str) -> str:
- """Returns the generated C++ header name for a .proto file."""
- return os.path.splitext(proto_file)[0] + PROTO_H_EXTENSION
-
-
-def generate_code_for_package(file_descriptor_proto, package: ProtoNode,
- output: OutputFile) -> None:
- """Generates code for a single .pb.h file corresponding to a .proto file."""
-
- assert package.type() == ProtoNode.Type.PACKAGE
-
- output.write_line(f'// {os.path.basename(output.name())} automatically '
- f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
- output.write_line('#pragma once\n')
- output.write_line('#include <cstddef>')
- output.write_line('#include <cstdint>\n')
- output.write_line('#include "pw_protobuf/codegen.h"')
-
- for imported_file in file_descriptor_proto.dependency:
- generated_header = _proto_filename_to_generated_header(imported_file)
- output.write_line(f'#include "{generated_header}"')
-
- if package.cpp_namespace():
- file_namespace = package.cpp_namespace()
- if file_namespace.startswith('::'):
- file_namespace = file_namespace[2:]
-
- output.write_line(f'\nnamespace {file_namespace} {{')
-
- for node in package:
- forward_declare(node, package, output)
-
- # Define all top-level enums.
- for node in package.children():
- if node.type() == ProtoNode.Type.ENUM:
- output.write_line()
- generate_code_for_enum(node, package, output)
-
- # Run through all messages in the file, generating a class for each.
- for node in package:
- if node.type() == ProtoNode.Type.MESSAGE:
- output.write_line()
- generate_code_for_message(node, package, output)
-
- # Run a second pass through the classes, this time defining all of the
- # methods which were previously only declared.
- for node in package:
- if node.type() == ProtoNode.Type.MESSAGE:
- define_not_in_class_methods(node, package, output)
-
- if package.cpp_namespace():
- output.write_line(f'\n}} // namespace {package.cpp_namespace()}')
-
-
-def add_enum_fields(enum: ProtoNode, proto_enum) -> None:
- """Adds fields from a protobuf enum descriptor to an enum node."""
- assert enum.type() == ProtoNode.Type.ENUM
- for value in proto_enum.value:
- enum.add_value(value.name, value.number)
-
-
-def create_external_nodes(root: ProtoNode, path: str) -> ProtoNode:
- """Creates external nodes for a path starting from the given root."""
-
- node = root
- for part in path.split('.'):
- child = node.find(part)
- if not child:
- child = ProtoExternal(part)
- node.add_child(child)
- node = child
-
- return node
-
-
-def add_message_fields(global_root: ProtoNode, package_root: ProtoNode,
- message: ProtoNode, proto_message) -> None:
- """Adds fields from a protobuf message descriptor to a message node."""
- assert message.type() == ProtoNode.Type.MESSAGE
-
- for field in proto_message.field:
- if field.type_name:
- # The "type_name" member contains the global .proto path of the
- # field's type object, for example ".pw.protobuf.test.KeyValuePair".
- # Try to find the node for this object within the current context.
-
- if field.type_name[0] == '.':
- # Fully qualified path.
- root_relative_path = field.type_name[1:]
- search_root = global_root
- else:
- root_relative_path = field.type_name
- search_root = package_root
-
- type_node = search_root.find(root_relative_path)
-
- if type_node is None:
- # Create nodes for field types that don't exist within this
- # compilation context, such as those imported from other .proto
- # files.
- type_node = create_external_nodes(search_root,
- root_relative_path)
- else:
- type_node = None
-
- repeated = \
- field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
- message.add_field(
- ProtoMessageField(
- field.name,
- field.number,
- field.type,
- type_node,
- repeated,
- ))
-
-
-def populate_fields(proto_file, global_root: ProtoNode,
- package_root: ProtoNode) -> None:
- """Traverses a proto file, adding all message and enum fields to a tree."""
- def populate_message(node, message):
- """Recursively populates nested messages and enums."""
- add_message_fields(global_root, package_root, node, message)
-
- for enum in message.enum_type:
- add_enum_fields(node.find(enum.name), enum)
- for msg in message.nested_type:
- populate_message(node.find(msg.name), msg)
-
- # Iterate through the proto file, populating top-level enums and messages.
- for enum in proto_file.enum_type:
- add_enum_fields(package_root.find(enum.name), enum)
- for message in proto_file.message_type:
- populate_message(package_root.find(message.name), message)
-
-
-def build_hierarchy(proto_file):
- """Creates a ProtoNode hierarchy from a proto file descriptor."""
-
- root = ProtoPackage('')
- package_root = root
-
- for part in proto_file.package.split('.'):
- package = ProtoPackage(part)
- package_root.add_child(package)
- package_root = package
-
- def build_message_subtree(proto_message):
- node = ProtoMessage(proto_message.name)
- for enum in proto_message.enum_type:
- node.add_child(ProtoEnum(enum.name))
- for submessage in proto_message.nested_type:
- node.add_child(build_message_subtree(submessage))
-
- return node
-
- for enum in proto_file.enum_type:
- package_root.add_child(ProtoEnum(enum.name))
-
- for message in proto_file.message_type:
- package_root.add_child(build_message_subtree(message))
-
- return root, package_root
-
-
-def process_proto_file(proto_file) -> Iterable[OutputFile]:
- """Generates code for a single .proto file."""
-
- # Two passes are made through the file. The first builds the tree of all
- # message/enum nodes, then the second creates the fields in each. This is
- # done as non-primitive fields need pointers to their types, which requires
- # the entire tree to have been parsed into memory.
- global_root, package_root = build_hierarchy(proto_file)
- populate_fields(proto_file, global_root, package_root)
-
- output_filename = _proto_filename_to_generated_header(proto_file.name)
- output_file = OutputFile(output_filename)
- generate_code_for_package(proto_file, package_root, output_file)
-
- return [output_file]
-
-
-def process_proto_request(req: plugin_pb2.CodeGeneratorRequest,
- res: plugin_pb2.CodeGeneratorResponse) -> None:
- """Handles a protoc CodeGeneratorRequest message.
-
- Generates code for the files in the request and writes the output to the
- specified CodeGeneratorResponse message.
-
- Args:
- req: A CodeGeneratorRequest for a proto compilation.
- res: A CodeGeneratorResponse to populate with the plugin's output.
- """
- for proto_file in req.proto_file:
- # TODO(frolv): Proto files are currently processed individually. Support
- # for multiple files with cross-dependencies should be added.
- output_files = process_proto_file(proto_file)
- for output_file in output_files:
- fd = res.file.add()
- fd.name = output_file.name()
- fd.content = output_file.content()
-
-
-def main() -> int:
- """Protobuf compiler plugin entrypoint.
-
- Reads a CodeGeneratorRequest proto from stdin and writes a
- CodeGeneratorResponse to stdout.
- """
- data = sys.stdin.buffer.read()
- request = plugin_pb2.CodeGeneratorRequest.FromString(data)
- response = plugin_pb2.CodeGeneratorResponse()
- process_proto_request(request, response)
- sys.stdout.buffer.write(response.SerializeToString())
- return 0
-
-
-if __name__ == '__main__':
- sys.exit(main())
diff --git a/pw_protobuf/py/pw_protobuf/methods.py b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py
similarity index 61%
rename from pw_protobuf/py/pw_protobuf/methods.py
rename to pw_protobuf/py/pw_protobuf/codegen_pwpb.py
index 023437d..5fb6661 100644
--- a/pw_protobuf/py/pw_protobuf/methods.py
+++ b/pw_protobuf/py/pw_protobuf/codegen_pwpb.py
@@ -11,14 +11,32 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
-"""This module defines methods for protobuf message C++ encoder classes."""
+"""This module defines the generated code for pw_protobuf C++ classes."""
import abc
-from typing import List, Tuple
+from datetime import datetime
+import os
+import sys
+from typing import Dict, Iterable, List, Tuple
import google.protobuf.descriptor_pb2 as descriptor_pb2
-from pw_protobuf.proto_structures import ProtoMessageField, ProtoNode
+from pw_protobuf.proto_tree import ProtoMessageField, ProtoNode
+from pw_protobuf.proto_tree import build_node_tree
+
+PLUGIN_NAME = 'pw_protobuf'
+PLUGIN_VERSION = '0.1.0'
+
+PROTO_H_EXTENSION = '.pwpb.h'
+PROTO_CC_EXTENSION = '.pwpb.cc'
+
+PROTOBUF_NAMESPACE = 'pw::protobuf'
+BASE_PROTO_CLASS = 'ProtoMessageEncoder'
+
+
+# protoc captures stdout, so we need to printf debug to stderr.
+def debug_print(*args, **kwargs):
+ print(*args, file=sys.stderr, **kwargs)
class ProtoMethod(abc.ABC):
@@ -447,7 +465,7 @@
# Mapping of protobuf field types to their method definitions.
-PROTO_FIELD_METHODS = {
+PROTO_FIELD_METHODS: Dict[int, List] = {
descriptor_pb2.FieldDescriptorProto.TYPE_DOUBLE:
[DoubleMethod, PackedDoubleMethod],
descriptor_pb2.FieldDescriptorProto.TYPE_FLOAT:
@@ -480,3 +498,240 @@
descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE: [SubMessageMethod],
descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: [EnumMethod],
}
+
+
+class OutputFile:
+ """A buffer to which data is written.
+
+ Example:
+
+ ```
+ output = Output("hello.c")
+ output.write_line('int main(void) {')
+ with output.indent():
+ output.write_line('printf("Hello, world");')
+ output.write_line('return 0;')
+ output.write_line('}')
+
+ print(output.content())
+ ```
+
+ Produces:
+ ```
+ int main(void) {
+ printf("Hello, world");
+ return 0;
+ }
+ ```
+ """
+
+ INDENT_WIDTH = 2
+
+ def __init__(self, filename: str):
+ self._filename: str = filename
+ self._content: List[str] = []
+ self._indentation: int = 0
+
+ def write_line(self, line: str = '') -> None:
+ if line:
+ self._content.append(' ' * self._indentation)
+ self._content.append(line)
+ self._content.append('\n')
+
+ def indent(self) -> 'OutputFile._IndentationContext':
+ """Increases the indentation level of the output."""
+ return self._IndentationContext(self)
+
+ def name(self) -> str:
+ return self._filename
+
+ def content(self) -> str:
+ return ''.join(self._content)
+
+ class _IndentationContext:
+ """Context that increases the output's indentation when it is active."""
+ def __init__(self, output: 'OutputFile'):
+ self._output = output
+
+ def __enter__(self):
+ self._output._indentation += OutputFile.INDENT_WIDTH
+
+ def __exit__(self, typ, value, traceback):
+ self._output._indentation -= OutputFile.INDENT_WIDTH
+
+
+def generate_code_for_message(message: ProtoNode, root: ProtoNode,
+ output: OutputFile) -> None:
+ """Creates a C++ class for a protobuf message."""
+ assert message.type() == ProtoNode.Type.MESSAGE
+
+ # Message classes inherit from the base proto message class in codegen.h
+ # and use its constructor.
+ base_class = f'{PROTOBUF_NAMESPACE}::{BASE_PROTO_CLASS}'
+ output.write_line(
+ f'class {message.cpp_namespace(root)}::Encoder : public {base_class} {{'
+ )
+ output.write_line(' public:')
+
+ with output.indent():
+ output.write_line(f'using {BASE_PROTO_CLASS}::{BASE_PROTO_CLASS};')
+
+ # Generate methods for each of the message's fields.
+ for field in message.fields():
+ for method_class in PROTO_FIELD_METHODS[field.type()]:
+ method = method_class(field, message, root)
+ if not method.should_appear():
+ continue
+
+ output.write_line()
+ method_signature = (
+ f'{method.return_type()} '
+ f'{method.name()}({method.param_string()})')
+
+ if not method.in_class_definition():
+ # Method will be defined outside of the class at the end of
+ # the file.
+ output.write_line(f'{method_signature};')
+ continue
+
+ output.write_line(f'{method_signature} {{')
+ with output.indent():
+ for line in method.body():
+ output.write_line(line)
+ output.write_line('}')
+
+ output.write_line('};')
+
+
+def define_not_in_class_methods(message: ProtoNode, root: ProtoNode,
+ output: OutputFile) -> None:
+ """Defines methods for a message class that were previously declared."""
+ assert message.type() == ProtoNode.Type.MESSAGE
+
+ for field in message.fields():
+ for method_class in PROTO_FIELD_METHODS[field.type()]:
+ method = method_class(field, message, root)
+ if not method.should_appear() or method.in_class_definition():
+ continue
+
+ output.write_line()
+ class_name = f'{message.cpp_namespace(root)}::Encoder'
+ method_signature = (
+ f'inline {method.return_type(from_root=True)} '
+ f'{class_name}::{method.name()}({method.param_string()})')
+ output.write_line(f'{method_signature} {{')
+ with output.indent():
+ for line in method.body():
+ output.write_line(line)
+ output.write_line('}')
+
+
+def generate_code_for_enum(enum: ProtoNode, root: ProtoNode,
+ output: OutputFile) -> None:
+ """Creates a C++ enum for a proto enum."""
+ assert enum.type() == ProtoNode.Type.ENUM
+
+ output.write_line(f'enum class {enum.cpp_namespace(root)} {{')
+ with output.indent():
+ for name, number in enum.values():
+ output.write_line(f'{name} = {number},')
+ output.write_line('};')
+
+
+def forward_declare(node: ProtoNode, root: ProtoNode,
+ output: OutputFile) -> None:
+ """Generates code forward-declaring entities in a message's namespace."""
+ if node.type() != ProtoNode.Type.MESSAGE:
+ return
+
+ namespace = node.cpp_namespace(root)
+ output.write_line()
+ output.write_line(f'namespace {namespace} {{')
+
+ # Define an enum defining each of the message's fields and their numbers.
+ output.write_line('enum class Fields {')
+ with output.indent():
+ for field in node.fields():
+ output.write_line(f'{field.enum_name()} = {field.number()},')
+ output.write_line('};')
+
+ # Declare the message's encoder class and all of its enums.
+ output.write_line()
+ output.write_line('class Encoder;')
+ for child in node.children():
+ if child.type() == ProtoNode.Type.ENUM:
+ output.write_line()
+ generate_code_for_enum(child, node, output)
+
+ output.write_line(f'}} // namespace {namespace}')
+
+
+def _proto_filename_to_generated_header(proto_file: str) -> str:
+ """Returns the generated C++ header name for a .proto file."""
+ return os.path.splitext(proto_file)[0] + PROTO_H_EXTENSION
+
+
+def generate_code_for_package(file_descriptor_proto, package: ProtoNode,
+ output: OutputFile) -> None:
+ """Generates code for a single .pb.h file corresponding to a .proto file."""
+
+ assert package.type() == ProtoNode.Type.PACKAGE
+
+ output.write_line(f'// {os.path.basename(output.name())} automatically '
+ f'generated by {PLUGIN_NAME} {PLUGIN_VERSION}')
+ output.write_line(f'// on {datetime.now()}')
+ output.write_line('#pragma once\n')
+ output.write_line('#include <cstddef>')
+ output.write_line('#include <cstdint>\n')
+ output.write_line('#include "pw_protobuf/codegen.h"')
+
+ for imported_file in file_descriptor_proto.dependency:
+ generated_header = _proto_filename_to_generated_header(imported_file)
+ output.write_line(f'#include "{generated_header}"')
+
+ if package.cpp_namespace():
+ file_namespace = package.cpp_namespace()
+ if file_namespace.startswith('::'):
+ file_namespace = file_namespace[2:]
+
+ output.write_line(f'\nnamespace {file_namespace} {{')
+
+ for node in package:
+ forward_declare(node, package, output)
+
+ # Define all top-level enums.
+ for node in package.children():
+ if node.type() == ProtoNode.Type.ENUM:
+ output.write_line()
+ generate_code_for_enum(node, package, output)
+
+ # Run through all messages in the file, generating a class for each.
+ for node in package:
+ if node.type() == ProtoNode.Type.MESSAGE:
+ output.write_line()
+ generate_code_for_message(node, package, output)
+
+ # Run a second pass through the classes, this time defining all of the
+ # methods which were previously only declared.
+ for node in package:
+ if node.type() == ProtoNode.Type.MESSAGE:
+ define_not_in_class_methods(node, package, output)
+
+ if package.cpp_namespace():
+ output.write_line(f'\n}} // namespace {package.cpp_namespace()}')
+
+
+def process_proto_file(proto_file) -> Iterable[OutputFile]:
+ """Generates code for a single .proto file."""
+
+ # Two passes are made through the file. The first builds the tree of all
+ # message/enum nodes, then the second creates the fields in each. This is
+ # done as non-primitive fields need pointers to their types, which requires
+ # the entire tree to have been parsed into memory.
+ _, package_root = build_node_tree(proto_file)
+
+ output_filename = _proto_filename_to_generated_header(proto_file.name)
+ output_file = OutputFile(output_filename)
+ generate_code_for_package(proto_file, package_root, output_file)
+
+ return [output_file]
diff --git a/pw_protobuf/py/pw_protobuf/plugin.py b/pw_protobuf/py/pw_protobuf/plugin.py
new file mode 100755
index 0000000..3702a6e
--- /dev/null
+++ b/pw_protobuf/py/pw_protobuf/plugin.py
@@ -0,0 +1,62 @@
+#!/usr/bin/env python3
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""pw_protobuf compiler plugin.
+
+This file implements a protobuf compiler plugin which generates C++ headers for
+protobuf messages in the pw_protobuf format.
+"""
+
+import sys
+
+import google.protobuf.compiler.plugin_pb2 as plugin_pb2
+
+import pw_protobuf.codegen_pwpb as codegen_pwpb
+
+
+def process_proto_request(req: plugin_pb2.CodeGeneratorRequest,
+ res: plugin_pb2.CodeGeneratorResponse) -> None:
+ """Handles a protoc CodeGeneratorRequest message.
+
+ Generates code for the files in the request and writes the output to the
+ specified CodeGeneratorResponse message.
+
+ Args:
+ req: A CodeGeneratorRequest for a proto compilation.
+ res: A CodeGeneratorResponse to populate with the plugin's output.
+ """
+ for proto_file in req.proto_file:
+ output_files = codegen_pwpb.process_proto_file(proto_file)
+ for output_file in output_files:
+ fd = res.file.add()
+ fd.name = output_file.name()
+ fd.content = output_file.content()
+
+
+def main() -> int:
+ """Protobuf compiler plugin entrypoint.
+
+ Reads a CodeGeneratorRequest proto from stdin and writes a
+ CodeGeneratorResponse to stdout.
+ """
+ data = sys.stdin.buffer.read()
+ request = plugin_pb2.CodeGeneratorRequest.FromString(data)
+ response = plugin_pb2.CodeGeneratorResponse()
+ process_proto_request(request, response)
+ sys.stdout.buffer.write(response.SerializeToString())
+ return 0
+
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/pw_protobuf/py/pw_protobuf/proto_structures.py b/pw_protobuf/py/pw_protobuf/proto_structures.py
deleted file mode 100644
index 600e494..0000000
--- a/pw_protobuf/py/pw_protobuf/proto_structures.py
+++ /dev/null
@@ -1,287 +0,0 @@
-# Copyright 2020 The Pigweed Authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License"); you may not
-# use this file except in compliance with the License. You may obtain a copy of
-# the License at
-#
-# https://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-# License for the specific language governing permissions and limitations under
-# the License.
-"""This module defines data structures for protobuf entities."""
-
-import abc
-import collections
-import enum
-
-from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
-
-T = TypeVar('T') # pylint: disable=invalid-name
-
-
-class ProtoNode(abc.ABC):
- """A ProtoNode represents a C++ scope mapping of an entity in a .proto file.
-
- Nodes form a tree beginning at a top-level (global) scope, descending into a
- hierarchy of .proto packages and the messages and enums defined within them.
- """
- class Type(enum.Enum):
- """The type of a ProtoNode.
-
- PACKAGE maps to a C++ namespace.
- MESSAGE maps to a C++ "Encoder" class within its own namespace.
- ENUM maps to a C++ enum within its parent's namespace.
- EXTERNAL represents a node defined within a different compilation unit.
- """
- PACKAGE = 1
- MESSAGE = 2
- ENUM = 3
- EXTERNAL = 4
-
- def __init__(self, name: str):
- self._name: str = name
- self._children: Dict[str, 'ProtoNode'] = collections.OrderedDict()
- self._parent: Optional['ProtoNode'] = None
-
- @abc.abstractmethod
- def type(self) -> 'ProtoNode.Type':
- """The type of the node."""
-
- def children(self) -> List['ProtoNode']:
- return list(self._children.values())
-
- def name(self) -> str:
- return self._name
-
- def cpp_name(self) -> str:
- """The name of this node in generated C++ code."""
- return self._name.replace('.', '::')
-
- def cpp_namespace(self, root: Optional['ProtoNode'] = None) -> str:
- """C++ namespace of the node, up to the specified root."""
- return '::'.join(
- self._attr_hierarchy(lambda node: node.cpp_name(), root))
-
- def common_ancestor(self, other: 'ProtoNode') -> Optional['ProtoNode']:
- """Finds the earliest common ancestor of this node and other."""
-
- if other is None:
- return None
-
- own_depth = self.depth()
- other_depth = other.depth()
- diff = abs(own_depth - other_depth)
-
- if own_depth < other_depth:
- first: Optional['ProtoNode'] = self
- second: Optional['ProtoNode'] = other
- else:
- first = other
- second = self
-
- while diff > 0:
- assert second is not None
- second = second.parent()
- diff -= 1
-
- while first != second:
- if first is None or second is None:
- return None
-
- first = first.parent()
- second = second.parent()
-
- return first
-
- def depth(self) -> int:
- """Returns the depth of this node from the root."""
- depth = 0
- node = self._parent
- while node:
- depth += 1
- node = node.parent()
- return depth
-
- def add_child(self, child: 'ProtoNode') -> None:
- """Inserts a new node into the tree as a child of this node.
-
- Args:
- child: The node to insert.
-
- Raises:
- ValueError: This node does not allow nesting the given type of child.
- """
- if not self._supports_child(child):
- raise ValueError('Invalid child %s for node of type %s' %
- (child.type(), self.type()))
-
- # pylint: disable=protected-access
- if child._parent is not None:
- del child._parent._children[child.name()]
-
- child._parent = self
- self._children[child.name()] = child
- # pylint: enable=protected-access
-
- def find(self, path: str) -> Optional['ProtoNode']:
- """Finds a node within this node's subtree."""
- node = self
-
- # pylint: disable=protected-access
- for section in path.split('.'):
- child = node._children.get(section)
- if child is None:
- return None
- node = child
- # pylint: enable=protected-access
-
- return node
-
- def parent(self) -> Optional['ProtoNode']:
- return self._parent
-
- def __iter__(self) -> Iterator['ProtoNode']:
- """Iterates depth-first through all nodes in this node's subtree."""
- yield self
- for child_iterator in self._children.values():
- for child in child_iterator:
- yield child
-
- def _attr_hierarchy(self, attr_accessor: Callable[['ProtoNode'], T],
- root: Optional['ProtoNode']) -> Iterator[T]:
- """Fetches node attributes at each level of the tree from the root.
-
- Args:
- attr_accessor: Function which extracts attributes from a ProtoNode.
- root: The node at which to terminate.
-
- Returns:
- An iterator to a list of the selected attributes from the root to the
- current node.
- """
- hierarchy = []
- node: Optional['ProtoNode'] = self
- while node is not None and node != root:
- hierarchy.append(attr_accessor(node))
- node = node.parent()
- return reversed(hierarchy)
-
- @abc.abstractmethod
- def _supports_child(self, child: 'ProtoNode') -> bool:
- """Returns True if child is a valid child type for the current node."""
-
-
-class ProtoPackage(ProtoNode):
- """A protobuf package."""
- def type(self) -> ProtoNode.Type:
- return ProtoNode.Type.PACKAGE
-
- def _supports_child(self, child: ProtoNode) -> bool:
- return True
-
-
-class ProtoEnum(ProtoNode):
- """Representation of an enum in a .proto file."""
- def __init__(self, name: str):
- super().__init__(name)
- self._values: List[Tuple[str, int]] = []
-
- def type(self) -> ProtoNode.Type:
- return ProtoNode.Type.ENUM
-
- def values(self) -> List[Tuple[str, int]]:
- return list(self._values)
-
- def add_value(self, name: str, value: int) -> None:
- self._values.append((ProtoMessageField.upper_snake_case(name), value))
-
- def _supports_child(self, child: ProtoNode) -> bool:
- # Enums cannot have nested children.
- return False
-
-
-class ProtoMessage(ProtoNode):
- """Representation of a message in a .proto file."""
- def __init__(self, name: str):
- super().__init__(name)
- self._fields: List['ProtoMessageField'] = []
-
- def type(self) -> ProtoNode.Type:
- return ProtoNode.Type.MESSAGE
-
- def fields(self) -> List['ProtoMessageField']:
- return list(self._fields)
-
- def add_field(self, field: 'ProtoMessageField') -> None:
- self._fields.append(field)
-
- def _supports_child(self, child: ProtoNode) -> bool:
- return (child.type() == self.Type.ENUM
- or child.type() == self.Type.MESSAGE)
-
-
-class ProtoExternal(ProtoNode):
- """A node from a different compilation unit.
-
- An external node is one that isn't defined within the current compilation
- unit, most likely as it comes from an imported proto file. Its type is not
- known, so it does not have any members or additional data. Its purpose
- within the node graph is to provide namespace resolution between compile
- units.
- """
- def type(self) -> ProtoNode.Type:
- return ProtoNode.Type.EXTERNAL
-
- def _supports_child(self, child: ProtoNode) -> bool:
- return True
-
-
-# This class is not a node and does not appear in the proto tree.
-# Fields belong to proto messages and are processed separately.
-class ProtoMessageField:
- """Representation of a field within a protobuf message."""
- def __init__(self,
- field_name: str,
- field_number: int,
- field_type: int,
- type_node: Optional[ProtoNode] = None,
- repeated: bool = False):
- self._field_name = field_name
- self._number: int = field_number
- self._type: int = field_type
- self._type_node: Optional[ProtoNode] = type_node
- self._repeated: bool = repeated
-
- def name(self) -> str:
- return self.upper_camel_case(self._field_name)
-
- def enum_name(self) -> str:
- return self.upper_snake_case(self._field_name)
-
- def number(self) -> int:
- return self._number
-
- def type(self) -> int:
- return self._type
-
- def type_node(self) -> Optional[ProtoNode]:
- return self._type_node
-
- def is_repeated(self) -> bool:
- return self._repeated
-
- @staticmethod
- def upper_camel_case(field_name: str) -> str:
- """Converts a field name to UpperCamelCase."""
- name_components = field_name.split('_')
- for i, _ in enumerate(name_components):
- name_components[i] = name_components[i].lower().capitalize()
- return ''.join(name_components)
-
- @staticmethod
- def upper_snake_case(field_name: str) -> str:
- """Converts a field name to UPPER_SNAKE_CASE."""
- return field_name.upper()
diff --git a/pw_protobuf/py/pw_protobuf/proto_tree.py b/pw_protobuf/py/pw_protobuf/proto_tree.py
new file mode 100644
index 0000000..7b9e361
--- /dev/null
+++ b/pw_protobuf/py/pw_protobuf/proto_tree.py
@@ -0,0 +1,500 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""This module defines data structures for protobuf entities."""
+
+import abc
+import collections
+import enum
+
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
+from typing import cast
+
+import google.protobuf.descriptor_pb2 as descriptor_pb2
+
+T = TypeVar('T') # pylint: disable=invalid-name
+
+
+class ProtoNode(abc.ABC):
+ """A ProtoNode represents a C++ scope mapping of an entity in a .proto file.
+
+ Nodes form a tree beginning at a top-level (global) scope, descending into a
+ hierarchy of .proto packages and the messages and enums defined within them.
+ """
+ class Type(enum.Enum):
+ """The type of a ProtoNode.
+
+ PACKAGE maps to a C++ namespace.
+ MESSAGE maps to a C++ "Encoder" class within its own namespace.
+ ENUM maps to a C++ enum within its parent's namespace.
+ EXTERNAL represents a node defined within a different compilation unit.
+ SERVICE represents an RPC service definition.
+ """
+ PACKAGE = 1
+ MESSAGE = 2
+ ENUM = 3
+ EXTERNAL = 4
+ SERVICE = 5
+
+ def __init__(self, name: str):
+ self._name: str = name
+ self._children: Dict[str, 'ProtoNode'] = collections.OrderedDict()
+ self._parent: Optional['ProtoNode'] = None
+
+ @abc.abstractmethod
+ def type(self) -> 'ProtoNode.Type':
+ """The type of the node."""
+
+ def children(self) -> List['ProtoNode']:
+ return list(self._children.values())
+
+ def name(self) -> str:
+ return self._name
+
+ def cpp_name(self) -> str:
+ """The name of this node in generated C++ code."""
+ return self._name.replace('.', '::')
+
+ def cpp_namespace(self, root: Optional['ProtoNode'] = None) -> str:
+ """C++ namespace of the node, up to the specified root."""
+ return '::'.join(
+ self._attr_hierarchy(lambda node: node.cpp_name(), root))
+
+ def common_ancestor(self, other: 'ProtoNode') -> Optional['ProtoNode']:
+ """Finds the earliest common ancestor of this node and other."""
+
+ if other is None:
+ return None
+
+ own_depth = self.depth()
+ other_depth = other.depth()
+ diff = abs(own_depth - other_depth)
+
+ if own_depth < other_depth:
+ first: Optional['ProtoNode'] = self
+ second: Optional['ProtoNode'] = other
+ else:
+ first = other
+ second = self
+
+ while diff > 0:
+ assert second is not None
+ second = second.parent()
+ diff -= 1
+
+ while first != second:
+ if first is None or second is None:
+ return None
+
+ first = first.parent()
+ second = second.parent()
+
+ return first
+
+ def depth(self) -> int:
+ """Returns the depth of this node from the root."""
+ depth = 0
+ node = self._parent
+ while node:
+ depth += 1
+ node = node.parent()
+ return depth
+
+ def add_child(self, child: 'ProtoNode') -> None:
+ """Inserts a new node into the tree as a child of this node.
+
+ Args:
+ child: The node to insert.
+
+ Raises:
+ ValueError: This node does not allow nesting the given type of child.
+ """
+ if not self._supports_child(child):
+ raise ValueError('Invalid child %s for node of type %s' %
+ (child.type(), self.type()))
+
+ # pylint: disable=protected-access
+ if child._parent is not None:
+ del child._parent._children[child.name()]
+
+ child._parent = self
+ self._children[child.name()] = child
+ # pylint: enable=protected-access
+
+ def find(self, path: str) -> Optional['ProtoNode']:
+ """Finds a node within this node's subtree."""
+ node = self
+
+ # pylint: disable=protected-access
+ for section in path.split('.'):
+ child = node._children.get(section)
+ if child is None:
+ return None
+ node = child
+ # pylint: enable=protected-access
+
+ return node
+
+ def parent(self) -> Optional['ProtoNode']:
+ return self._parent
+
+ def __iter__(self) -> Iterator['ProtoNode']:
+ """Iterates depth-first through all nodes in this node's subtree."""
+ yield self
+ for child_iterator in self._children.values():
+ for child in child_iterator:
+ yield child
+
+ def _attr_hierarchy(self, attr_accessor: Callable[['ProtoNode'], T],
+ root: Optional['ProtoNode']) -> Iterator[T]:
+ """Fetches node attributes at each level of the tree from the root.
+
+ Args:
+ attr_accessor: Function which extracts attributes from a ProtoNode.
+ root: The node at which to terminate.
+
+ Returns:
+ An iterator to a list of the selected attributes from the root to the
+ current node.
+ """
+ hierarchy = []
+ node: Optional['ProtoNode'] = self
+ while node is not None and node != root:
+ hierarchy.append(attr_accessor(node))
+ node = node.parent()
+ return reversed(hierarchy)
+
+ @abc.abstractmethod
+ def _supports_child(self, child: 'ProtoNode') -> bool:
+ """Returns True if child is a valid child type for the current node."""
+
+
+class ProtoPackage(ProtoNode):
+ """A protobuf package."""
+ def type(self) -> ProtoNode.Type:
+ return ProtoNode.Type.PACKAGE
+
+ def _supports_child(self, child: ProtoNode) -> bool:
+ return True
+
+
+class ProtoEnum(ProtoNode):
+ """Representation of an enum in a .proto file."""
+ def __init__(self, name: str):
+ super().__init__(name)
+ self._values: List[Tuple[str, int]] = []
+
+ def type(self) -> ProtoNode.Type:
+ return ProtoNode.Type.ENUM
+
+ def values(self) -> List[Tuple[str, int]]:
+ return list(self._values)
+
+ def add_value(self, name: str, value: int) -> None:
+ self._values.append((ProtoMessageField.upper_snake_case(name), value))
+
+ def _supports_child(self, child: ProtoNode) -> bool:
+ # Enums cannot have nested children.
+ return False
+
+
+class ProtoMessage(ProtoNode):
+ """Representation of a message in a .proto file."""
+ def __init__(self, name: str):
+ super().__init__(name)
+ self._fields: List['ProtoMessageField'] = []
+
+ def type(self) -> ProtoNode.Type:
+ return ProtoNode.Type.MESSAGE
+
+ def fields(self) -> List['ProtoMessageField']:
+ return list(self._fields)
+
+ def add_field(self, field: 'ProtoMessageField') -> None:
+ self._fields.append(field)
+
+ def _supports_child(self, child: ProtoNode) -> bool:
+ return (child.type() == self.Type.ENUM
+ or child.type() == self.Type.MESSAGE)
+
+
+class ProtoService(ProtoNode):
+ """Representation of a service in a .proto file."""
+ def __init__(self, name: str):
+ super().__init__(name)
+ self._methods: List['ProtoServiceMethod'] = []
+
+ def type(self) -> ProtoNode.Type:
+ return ProtoNode.Type.SERVICE
+
+ def methods(self) -> List['ProtoServiceMethod']:
+ return list(self._methods)
+
+ def add_method(self, method: 'ProtoServiceMethod') -> None:
+ self._methods.append(method)
+
+ def _supports_child(self, child: ProtoNode) -> bool:
+ return False
+
+
+class ProtoExternal(ProtoNode):
+ """A node from a different compilation unit.
+
+ An external node is one that isn't defined within the current compilation
+ unit, most likely as it comes from an imported proto file. Its type is not
+ known, so it does not have any members or additional data. Its purpose
+ within the node graph is to provide namespace resolution between compile
+ units.
+ """
+ def type(self) -> ProtoNode.Type:
+ return ProtoNode.Type.EXTERNAL
+
+ def _supports_child(self, child: ProtoNode) -> bool:
+ return True
+
+
+# This class is not a node and does not appear in the proto tree.
+# Fields belong to proto messages and are processed separately.
+class ProtoMessageField:
+ """Representation of a field within a protobuf message."""
+ def __init__(self,
+ field_name: str,
+ field_number: int,
+ field_type: int,
+ type_node: Optional[ProtoNode] = None,
+ repeated: bool = False):
+ self._field_name = field_name
+ self._number: int = field_number
+ self._type: int = field_type
+ self._type_node: Optional[ProtoNode] = type_node
+ self._repeated: bool = repeated
+
+ def name(self) -> str:
+ return self.upper_camel_case(self._field_name)
+
+ def enum_name(self) -> str:
+ return self.upper_snake_case(self._field_name)
+
+ def number(self) -> int:
+ return self._number
+
+ def type(self) -> int:
+ return self._type
+
+ def type_node(self) -> Optional[ProtoNode]:
+ return self._type_node
+
+ def is_repeated(self) -> bool:
+ return self._repeated
+
+ @staticmethod
+ def upper_camel_case(field_name: str) -> str:
+ """Converts a field name to UpperCamelCase."""
+ name_components = field_name.split('_')
+ for i, _ in enumerate(name_components):
+ name_components[i] = name_components[i].lower().capitalize()
+ return ''.join(name_components)
+
+ @staticmethod
+ def upper_snake_case(field_name: str) -> str:
+ """Converts a field name to UPPER_SNAKE_CASE."""
+ return field_name.upper()
+
+
+class ProtoServiceMethod:
+ """A method defined in a protobuf service."""
+ class Type(enum.Enum):
+ UNARY = 0
+ SERVER_STREAMING = 1
+ CLIENT_STREAMING = 2
+ BIDIRECTIONAL_STREAMING = 3
+
+ def __init__(self, name: str, method_type: Type, request_type: ProtoNode,
+ response_type: ProtoNode):
+ self._name = name
+ self._type = method_type
+ self._request_type = request_type
+ self._response_type = response_type
+
+ def name(self) -> str:
+ return self._name
+
+
+def _add_enum_fields(enum_node: ProtoNode, proto_enum) -> None:
+ """Adds fields from a protobuf enum descriptor to an enum node."""
+ assert enum_node.type() == ProtoNode.Type.ENUM
+ enum_node = cast(ProtoEnum, enum_node)
+
+ for value in proto_enum.value:
+ enum_node.add_value(value.name, value.number)
+
+
+def _create_external_nodes(root: ProtoNode, path: str) -> ProtoNode:
+ """Creates external nodes for a path starting from the given root."""
+
+ node = root
+ for part in path.split('.'):
+ child = node.find(part)
+ if not child:
+ child = ProtoExternal(part)
+ node.add_child(child)
+ node = child
+
+ return node
+
+
+def _find_or_create_node(global_root: ProtoNode, package_root: ProtoNode,
+ path: str) -> ProtoNode:
+ """Searches the proto tree for a node by path, creating it if not found."""
+
+ if path[0] == '.':
+ # Fully qualified path.
+ root_relative_path = path[1:]
+ search_root = global_root
+ else:
+ root_relative_path = path
+ search_root = package_root
+
+ node = search_root.find(root_relative_path)
+ if node is None:
+ # Create nodes for field types that don't exist within this
+ # compilation context, such as those imported from other .proto
+ # files.
+ node = _create_external_nodes(search_root, root_relative_path)
+
+ return node
+
+
+def _add_message_fields(global_root: ProtoNode, package_root: ProtoNode,
+ message: ProtoNode, proto_message) -> None:
+ """Adds fields from a protobuf message descriptor to a message node."""
+ assert message.type() == ProtoNode.Type.MESSAGE
+ message = cast(ProtoMessage, message)
+
+ type_node: Optional[ProtoNode]
+
+ for field in proto_message.field:
+ if field.type_name:
+ # The "type_name" member contains the global .proto path of the
+ # field's type object, for example ".pw.protobuf.test.KeyValuePair".
+ # Try to find the node for this object within the current context.
+ type_node = _find_or_create_node(global_root, package_root,
+ field.type_name)
+ else:
+ type_node = None
+
+ repeated = \
+ field.label == descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED
+ message.add_field(
+ ProtoMessageField(
+ field.name,
+ field.number,
+ field.type,
+ type_node,
+ repeated,
+ ))
+
+
+def _add_service_methods(global_root: ProtoNode, package_root: ProtoNode,
+ service: ProtoNode, proto_service) -> None:
+ assert service.type() == ProtoNode.Type.SERVICE
+ service = cast(ProtoService, service)
+
+ for method in proto_service.method:
+ if method.client_streaming and method.server_streaming:
+ method_type = ProtoServiceMethod.Type.BIDIRECTIONAL_STREAMING
+ elif method.client_streaming:
+ method_type = ProtoServiceMethod.Type.CLIENT_STREAMING
+ elif method.server_streaming:
+ method_type = ProtoServiceMethod.Type.SERVER_STREAMING
+ else:
+ method_type = ProtoServiceMethod.Type.UNARY
+
+ request_node = _find_or_create_node(global_root, package_root,
+ method.input_type)
+ response_node = _find_or_create_node(global_root, package_root,
+ method.output_type)
+
+ service.add_method(
+ ProtoServiceMethod(method.name, method_type, request_node,
+ response_node))
+
+
+def _populate_fields(proto_file, global_root: ProtoNode,
+ package_root: ProtoNode) -> None:
+ """Traverses a proto file, adding all message and enum fields to a tree."""
+ def populate_message(node, message):
+ """Recursively populates nested messages and enums."""
+ _add_message_fields(global_root, package_root, node, message)
+
+ for proto_enum in message.enum_type:
+ _add_enum_fields(node.find(proto_enum.name), proto_enum)
+ for msg in message.nested_type:
+ populate_message(node.find(msg.name), msg)
+
+ # Iterate through the proto file, populating top-level objects.
+ for proto_enum in proto_file.enum_type:
+ enum_node = package_root.find(proto_enum.name)
+ assert enum_node is not None
+ _add_enum_fields(enum_node, proto_enum)
+
+ for message in proto_file.message_type:
+ populate_message(package_root.find(message.name), message)
+
+ for service in proto_file.service:
+ service_node = package_root.find(service.name)
+ assert service_node is not None
+ _add_service_methods(global_root, package_root, service_node, service)
+
+
+def _build_hierarchy(proto_file):
+ """Creates a ProtoNode hierarchy from a proto file descriptor."""
+
+ root = ProtoPackage('')
+ package_root = root
+
+ for part in proto_file.package.split('.'):
+ package = ProtoPackage(part)
+ package_root.add_child(package)
+ package_root = package
+
+ def build_message_subtree(proto_message):
+ node = ProtoMessage(proto_message.name)
+ for proto_enum in proto_message.enum_type:
+ node.add_child(ProtoEnum(proto_enum.name))
+ for submessage in proto_message.nested_type:
+ node.add_child(build_message_subtree(submessage))
+
+ return node
+
+ for proto_enum in proto_file.enum_type:
+ package_root.add_child(ProtoEnum(proto_enum.name))
+
+ for message in proto_file.message_type:
+ package_root.add_child(build_message_subtree(message))
+
+ for service in proto_file.service:
+ package_root.add_child(ProtoService(service.name))
+
+ return root, package_root
+
+
+def build_node_tree(file_descriptor_proto) -> Tuple[ProtoNode, ProtoNode]:
+ """Constructs a tree of proto nodes from a file descriptor.
+
+ Returns the root node of the entire proto package tree and the node
+ representing the file's package.
+ """
+ global_root, package_root = _build_hierarchy(file_descriptor_proto)
+ _populate_fields(file_descriptor_proto, global_root, package_root)
+ return global_root, package_root
diff --git a/pw_protobuf/py/setup.py b/pw_protobuf/py/setup.py
index 06192fd..4d99196 100644
--- a/pw_protobuf/py/setup.py
+++ b/pw_protobuf/py/setup.py
@@ -31,7 +31,7 @@
packages=setuptools.find_packages(),
test_suite='setup.test_suite',
entry_points={
- 'console_scripts': ['pw_protobuf_codegen = pw_protobuf.codegen:main']
+ 'console_scripts': ['pw_protobuf_codegen = pw_protobuf.plugin:main']
},
install_requires=[
'protobuf',