pw_protobuf_compiler: Python proto compilation

The python_protos Python module dynamically compiles and loads protobufs
from .proto files.

Change-Id: I7676ac9fe4842f370fda3b24b11a97cf334c5953
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/12664
Reviewed-by: Alexei Frolov <frolv@google.com>
Commit-Queue: Wyatt Hepler <hepler@google.com>
diff --git a/pw_protobuf/pw_protobuf_protos/test_protos/full_test.proto b/pw_protobuf/pw_protobuf_protos/test_protos/full_test.proto
index 07d8995..7262d32 100644
--- a/pw_protobuf/pw_protobuf_protos/test_protos/full_test.proto
+++ b/pw_protobuf/pw_protobuf_protos/test_protos/full_test.proto
@@ -31,7 +31,7 @@
     enum Binary {
       ZERO = 0;
       ONE = 1;
-    };
+    }
 
     Bool status = 1;
   }
@@ -40,7 +40,7 @@
     enum Binary {
       ONE = 0;
       ZERO = 1;
-    };
+    }
 
     // We must go deeper.
     message Compiler {
@@ -81,7 +81,7 @@
   enum Binary {
     OFF = 0;
     ON = 1;
-  };
+  }
 
   message ID {
     uint32 id = 1;
diff --git a/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py b/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
new file mode 100644
index 0000000..4c55bd3
--- /dev/null
+++ b/pw_protobuf_compiler/py/pw_protobuf_compiler/python_protos.py
@@ -0,0 +1,185 @@
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tools for compiling and importing Python protos on the fly."""
+
+import logging
+import os
+from pathlib import Path
+import subprocess
+import shlex
+import tempfile
+from types import ModuleType
+from typing import Dict, Iterable, List, Set, Tuple, Union
+import importlib.util
+
+_LOG = logging.getLogger(__name__)
+
+PathOrStr = Union[Path, str]  # accepted by the public functions for paths
+
+
+def compile_protos(
+    output_dir: PathOrStr,
+    proto_files: Iterable[PathOrStr],
+    includes: Iterable[PathOrStr] = ()) -> None:
+    """Compiles proto files for Python by invoking the protobuf compiler.
+
+    Proto files not covered by one of the provided include paths will have their
+    directory added as an include path. Include paths are passed to protoc in
+    sorted order. Raises subprocess.CalledProcessError if protoc fails.
+    """
+    proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]
+    include_paths: Set[Path] = set(Path(d).resolve() for d in includes)
+
+    for path in proto_paths:
+        if not any(include in path.parents for include in include_paths):
+            include_paths.add(path.parent)
+
+    # Set iteration order varies between runs (str hashes are randomized), so
+    # sort to give protoc the same include search order on every invocation.
+    cmd: Tuple[PathOrStr, ...] = ('protoc', '--python_out',
+                                  os.path.abspath(output_dir),
+                                  *(f'-I{d}' for d in sorted(include_paths)),
+                                  *proto_paths)
+    cmd_str = shlex.join(str(c) for c in cmd)
+    _LOG.debug('%s', cmd_str)
+    process = subprocess.run(cmd, capture_output=True)
+
+    if process.returncode:
+        _LOG.error('protoc invocation failed!\n%s\n%s', cmd_str,
+                   process.stderr.decode())
+        process.check_returncode()
+
+
+def _import_module(name: str, path: str) -> ModuleType:
+    spec = importlib.util.spec_from_file_location(name, path)
+    module = importlib.util.module_from_spec(spec)  # empty until executed
+    spec.loader.exec_module(module)  # type: ignore[union-attr]
+    return module  # NOTE: this function never adds it to sys.modules
+
+
+def import_modules(directory: PathOrStr) -> Iterable[ModuleType]:
+    """Imports the .py files under a directory and yields the modules."""
+    # abspath() normalizes 'dir/' and '.', for which dirname() alone does not
+    # return the parent, which produced malformed names such as '..foo_pb2'.
+    parent = os.path.dirname(os.path.abspath(directory))
+
+    for dirpath, _, files in os.walk(directory):
+        path_parts = os.path.relpath(dirpath, parent).split(os.sep)
+        for file in files:
+            name, ext = os.path.splitext(file)
+            if ext == '.py':
+                yield _import_module(f'{".".join(path_parts)}.{name}',
+                                     os.path.join(dirpath, file))
+
+
+def compile_and_import(proto_files: Iterable[PathOrStr],
+                       includes: Iterable[PathOrStr] = (),
+                       output_dir: PathOrStr = None) -> Iterable[ModuleType]:
+    """Compiles .proto files and yields the imported Python modules.
+
+    Args:
+      proto_files: the .proto files to pass to the protobuf compiler
+      includes: include paths to pass to the protobuf compiler
+      output_dir: directory that receives the generated modules; if omitted
+          or empty, a temporary directory is used while iterating
+
+    Yields:
+      every module imported from the output directory
+    """
+    if not output_dir:
+        # The directory is deleted once the generator finishes or is closed.
+        with tempfile.TemporaryDirectory(prefix='protos_') as tempdir:
+            compile_protos(tempdir, proto_files, includes)
+            yield from import_modules(tempdir)
+        return
+    compile_protos(output_dir, proto_files, includes)
+    yield from import_modules(output_dir)
+
+
+def compile_and_import_file(proto_file: PathOrStr,
+                            includes: Iterable[PathOrStr] = (),
+                            output_dir: PathOrStr = None) -> ModuleType:
+    """Compiles a single .proto file and returns the first imported module."""
+    return next(iter(compile_and_import([proto_file], includes, output_dir)))
+
+
+class _ProtoPackage:
+    """Used by the Library class for accessing protocol buffer modules."""
+    def __init__(self, package: str):
+        self._packages: Dict[str, _ProtoPackage] = {}
+        self._modules: List[ModuleType] = []
+        self._package = package
+
+    def __getattr__(self, attr: str):
+        """Returns the subpackage or proto entity called attr, if any."""
+        # Use __dict__ so lookups before __init__ runs (e.g. copy) can't recurse.
+        if attr in self.__dict__.get('_packages', ()):
+            return self._packages[attr]
+        for module in self.__dict__.get('_modules', ()):
+            if hasattr(module, attr):
+                return getattr(module, attr)
+        package = self.__dict__.get('_package')
+        raise AttributeError(
+            f'Proto package "{package}" does not contain "{attr}"')
+
+
+class Library:
+    """A collection of protocol buffer modules sorted by package.
+
+    In Python, each .proto file is compiled into a Python module. The Library
+    class makes it simple to navigate a collection of Python modules
+    corresponding to .proto files, without relying on the location of these
+    compiled modules.
+
+    Proto messages and other types can be directly accessed by their protocol
+    buffer package name. For example, the foo.bar.Baz message can be accessed
+    in a Library called `protos` as:
+
+      protos.packages.foo.bar.Baz
+
+    A Library also provides the modules_by_package dictionary, for looking up
+    the list of modules in a particular package, and the modules() generator
+    for iterating over all modules.
+    """
+    def __init__(self, modules: Iterable[ModuleType]):
+        """Constructs a Library from an iterable of modules.
+
+        A Library can be constructed with modules dynamically compiled by
+        compile_and_import. For example:
+
+            protos = Library(compile_and_import(list_of_proto_files))
+        """
+        self.modules_by_package: Dict[str, List[ModuleType]] = {}
+        self.packages = _ProtoPackage('')
+
+        for module in modules:
+            package = module.DESCRIPTOR.package  # type: ignore[attr-defined]
+            self.modules_by_package.setdefault(package, []).append(module)
+
+            # Protos without a package report ''. ''.split('.') is [''], which
+            # would bury the module in a '' subpackage; keep it on the root.
+            subpackages = package.split('.') if package else []
+            entry = self.packages
+            for i, subpackage in enumerate(subpackages, 1):
+                if subpackage not in entry._packages:
+                    entry._packages[subpackage] = _ProtoPackage('.'.join(
+                        subpackages[:i]))
+                entry = entry._packages[subpackage]
+
+            entry._modules.append(module)
+
+    def modules(self) -> Iterable[ModuleType]:
+        """Allows iterating over all protobuf modules in this library."""
+        for module_list in self.modules_by_package.values():
+            yield from module_list
diff --git a/pw_protobuf_compiler/py/python_protos_test.py b/pw_protobuf_compiler/py/python_protos_test.py
new file mode 100755
index 0000000..2408fb3
--- /dev/null
+++ b/pw_protobuf_compiler/py/python_protos_test.py
@@ -0,0 +1,163 @@
+#!/usr/bin/env python3
+# Copyright 2020 The Pigweed Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+"""Tests compiling and importing Python protos on the fly."""
+
+from pathlib import Path
+import tempfile
+import unittest
+
+from pw_protobuf_compiler import python_protos
+
+PROTO_1 = b"""\
+syntax = "proto3";
+
+package pw.protobuf_compiler.test1;
+
+message SomeMessage {
+  uint32 magic_number = 1;
+}
+
+message AnotherMessage {
+  enum Result {
+    FAILED = 0;
+    FAILED_MISERABLY = 1;
+    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
+  }
+
+  Result result = 1;
+  string payload = 2;
+}
+
+service PublicService {
+  rpc Unary(SomeMessage) returns (AnotherMessage) {}
+  rpc ServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
+  rpc ClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
+  rpc BidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
+}
+"""
+
+PROTO_2 = b"""\
+syntax = "proto2";
+
+package pw.protobuf_compiler.test2;
+
+message Request {
+  optional float magic_number = 1;
+}
+
+message Response {
+}
+
+service Alpha {
+  rpc Unary(Request) returns (Response) {}
+}
+
+service Bravo {
+  rpc BidiStreaming(stream Request) returns (stream Response) {}
+}
+"""
+
+PROTO_3 = b"""\
+syntax = "proto3";
+
+package pw.protobuf_compiler.test2;
+
+enum Greeting {
+  YO = 0;
+  HI = 1;
+}
+
+message Hello {
+  repeated int64 value = 1;
+  Greeting hi = 2;
+}
+"""
+
+
+class TestCompileAndImport(unittest.TestCase):
+    """Tests compiling protos into a temporary directory and importing them."""
+    def setUp(self):
+        self._proto_dir = tempfile.TemporaryDirectory(prefix='proto_test')
+        # Unlike tearDown, cleanups still run if a later setUp step raises.
+        self.addCleanup(self._proto_dir.cleanup)
+        self._protos = []
+
+        for i, contents in enumerate([PROTO_1, PROTO_2, PROTO_3], 1):
+            self._protos.append(Path(self._proto_dir.name, f'test_{i}.proto'))
+            self._protos[-1].write_bytes(contents)
+
+    def test_compile_to_temp_dir_and_import(self):
+        modules = {
+            m.DESCRIPTOR.name: m
+            for m in python_protos.compile_and_import(self._protos)
+        }
+        self.assertEqual(3, len(modules))
+
+        # Make sure the protobuf modules contain what we expect them to.
+        mod = modules['test_1.proto']
+        self.assertEqual(
+            4, len(mod.DESCRIPTOR.services_by_name['PublicService'].methods))
+
+        mod = modules['test_2.proto']
+        self.assertEqual(mod.Request(magic_number=1.5).magic_number, 1.5)
+        self.assertEqual(2, len(mod.DESCRIPTOR.services_by_name))
+
+        mod = modules['test_3.proto']
+        self.assertEqual(mod.Hello(value=[123, 456]).value, [123, 456])
+
+
+class TestProtoLibrary(TestCompileAndImport):
+    """Tests the Library class; the inherited compile test runs here too."""
+    def setUp(self):
+        super().setUp()
+        self._library = python_protos.Library(
+            python_protos.compile_and_import(self._protos))
+
+    def test_packages_can_access_messages(self):
+        msg = self._library.packages.pw.protobuf_compiler.test1.SomeMessage
+        self.assertEqual(msg(magic_number=123).magic_number, 123)
+
+    def test_packages_finds_across_modules(self):
+        msg = self._library.packages.pw.protobuf_compiler.test2.Request
+        self.assertEqual(msg(magic_number=50).magic_number, 50)
+
+        val = self._library.packages.pw.protobuf_compiler.test2.YO
+        self.assertEqual(val, 0)
+
+    def test_packages_invalid_name(self):
+        with self.assertRaises(AttributeError):
+            _ = self._library.packages.nothing
+
+        with self.assertRaises(AttributeError):
+            _ = self._library.packages.pw.NOT_HERE
+
+        with self.assertRaises(AttributeError):
+            _ = self._library.packages.pw.protobuf_compiler.test1.NotARealMsg
+
+    def test_access_modules_by_package(self):
+        test1 = self._library.modules_by_package['pw.protobuf_compiler.test1']
+        self.assertEqual(len(test1), 1)
+        self.assertEqual(test1[0].AnotherMessage.Result.Value('FAILED'), 0)
+
+        test2 = self._library.modules_by_package['pw.protobuf_compiler.test2']
+        self.assertEqual(len(test2), 2)
+
+    def test_access_modules_by_package_unknown(self):
+        with self.assertRaises(KeyError):
+            _ = self._library.modules_by_package['pw.not_real']
+
+
+if __name__ == '__main__':
+    unittest.main()  # lets this file be run directly as a script