move "generate union code" to aidl_to_cpp_common

This is a preparation to add NDK backend for union.

Bug: 170784707
Bug: 150948558
Test: aidl_unittests / aidl_integration_test
Change-Id: I347fc7e114b84eb03baf826c86388176d90bc81e
diff --git a/aidl_to_cpp_common.cpp b/aidl_to_cpp_common.cpp
index c1953ac..04e805f 100644
--- a/aidl_to_cpp_common.cpp
+++ b/aidl_to_cpp_common.cpp
@@ -15,8 +15,11 @@
  */
 #include "aidl_to_cpp_common.h"
 
+#include <android-base/format.h>
 #include <android-base/stringprintf.h>
 #include <android-base/strings.h>
+
+#include <set>
 #include <unordered_map>
 
 #include "ast_cpp.h"
@@ -390,6 +393,179 @@
   return decl;
 }
 
+void GenerateParcelableComparisonOperators(CodeWriter& out, const AidlParcelable& parcelable) {
+  std::set<string> operators{"<", ">", "==", ">=", "<=", "!="};
+  bool is_empty = false;
+
+  auto comparable = [&](const string& prefix) {
+    vector<string> fields;
+    if (auto p = parcelable.AsStructuredParcelable(); p != nullptr) {
+      is_empty = p->GetFields().empty();
+      for (const auto& f : p->GetFields()) {
+        fields.push_back(prefix + f->GetName());
+      }
+      return "std::tie(" + Join(fields, ", ") + ")";
+    } else if (auto p = parcelable.AsUnionDeclaration(); p != nullptr) {
+      return prefix + "_value";
+    } else {
+      AIDL_FATAL(parcelable) << "Unknown paracelable type";
+    }
+  };
+
+  string lhs = comparable("");
+  string rhs = comparable("rhs.");
+  for (const auto& op : operators) {
+    out << "inline bool operator" << op << "(const " << parcelable.GetName() << "&"
+        << (is_empty ? "" : " rhs") << ") const {\n"
+        << "  return " << lhs << " " << op << " " << rhs << ";\n"
+        << "}\n";
+  }
+  out << "\n";
+}
+
+const vector<string> UnionWriter::headers{
+    "type_traits",  // std::is_same_v
+    "utility",      // std::mode/forward for value
+    "variant",      // std::variant for value
+};
+
+void UnionWriter::PrivateFields(CodeWriter& out) const {
+  vector<string> field_types;
+  for (const auto& f : decl.GetFields()) {
+    field_types.push_back(name_of(f->GetType(), typenames));
+  }
+  out << "std::variant<" + Join(field_types, ", ") + "> _value;\n";
+}
+
+void UnionWriter::PublicFields(CodeWriter& out) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  out << "enum Tag : " << name_of(tag_type, typenames) << " {\n";
+  bool is_first = true;
+  for (const auto& f : decl.GetFields()) {
+    out << "  " << f->GetName() << (is_first ? " = 0" : "") << ",  // " << f->Signature() << ";\n";
+    is_first = false;
+  }
+  out << "};\n";
+
+  const auto& name = decl.GetName();
+
+  AIDL_FATAL_IF(decl.GetFields().empty(), decl) << "Union '" << name << "' is empty.";
+  const auto& first_field = decl.GetFields()[0];
+  const auto& default_name = first_field->GetName();
+  const auto& default_value =
+      name_of(first_field->GetType(), typenames) + "(" + first_field->ValueString(decorator) + ")";
+
+  auto tmpl = R"--(
+template<typename _Tp>
+static constexpr bool _not_self = !std::is_same_v<std::remove_cv_t<std::remove_reference_t<_Tp>>, {name}>;
+
+{name}() : _value(std::in_place_index<{default_name}>, {default_value}) {{ }}
+{name}(const {name}&) = default;
+{name}({name}&&) = default;
+{name}& operator=(const {name}&) = default;
+{name}& operator=({name}&&) = default;
+
+template <typename _Tp, std::enable_if_t<_not_self<_Tp>, int> = 0>
+constexpr {name}(_Tp&& _arg)
+    : _value(std::forward<_Tp>(_arg)) {{}}
+
+template <typename... _Tp>
+constexpr explicit {name}(_Tp&&... _args)
+    : _value(std::forward<_Tp>(_args)...) {{}}
+
+template <Tag _tag, typename... _Tp>
+static {name} make(_Tp&&... _args) {{
+  return {name}(std::in_place_index<_tag>, std::forward<_Tp>(_args)...);
+}}
+
+template <Tag _tag, typename _Tp, typename... _Up>
+static {name} make(std::initializer_list<_Tp> _il, _Up&&... _args) {{
+  return {name}(std::in_place_index<_tag>, std::move(_il), std::forward<_Up>(_args)...);
+}}
+
+Tag getTag() const {{
+  return static_cast<Tag>(_value.index());
+}}
+
+template <Tag _tag>
+const auto& get() const {{
+  if (getTag() != _tag) {{ abort(); }}
+  return std::get<_tag>(_value);
+}}
+
+template <Tag _tag>
+auto& get() {{
+  if (getTag() != _tag) {{ abort(); }}
+  return std::get<_tag>(_value);
+}}
+
+template <Tag _tag, typename... _Tp>
+void set(_Tp&&... _args) {{
+  _value.emplace<_tag>(std::forward<_Tp>(_args)...);
+}}
+
+)--";
+  out << fmt::format(tmpl, fmt::arg("name", name), fmt::arg("default_name", default_name),
+                     fmt::arg("default_value", default_value));
+}
+
+void UnionWriter::ReadFromParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  const string tag = "_aidl_tag";
+  const string value = "_aidl_value";
+  const string status = "_aidl_ret_status";
+
+  auto read_var = [&](const string& var, const AidlTypeSpecifier& type) {
+    out << fmt::format("{} {};\n", name_of(type, typenames), var);
+    out << fmt::format("if (({} = ", status);
+    ctx.read_func(out, var, type);
+    out << fmt::format(") != {}) return {};\n", ctx.status_ok, status);
+  };
+
+  out << fmt::format("{} {};\n", ctx.status_type, status);
+  read_var(tag, tag_type);
+  out << fmt::format("switch ({}) {{\n", tag);
+  for (const auto& variable : decl.GetFields()) {
+    out << fmt::format("case {}: {{\n", variable->GetName());
+    out.Indent();
+    read_var(value, variable->GetType());
+    out << fmt::format("set<{}>(std::move({}));\n", variable->GetName(), value);
+    out << fmt::format("return {}; }}\n", ctx.status_ok);
+    out.Dedent();
+  }
+  out << "}\n";
+  out << fmt::format("return {};\n", ctx.status_bad);
+}
+
+void UnionWriter::WriteToParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  const string tag = "_aidl_tag";
+  const string value = "_aidl_value";
+  const string status = "_aidl_ret_status";
+
+  out << fmt::format("{} {} = ", ctx.status_type, status);
+  ctx.write_func(out, "getTag()", tag_type);
+  out << ";\n";
+  out << fmt::format("if ({} != {}) return {};\n", status, ctx.status_ok, status);
+  out << "switch (getTag()) {\n";
+  for (const auto& variable : decl.GetFields()) {
+    out << fmt::format("case {}: return ", variable->GetName());
+    ctx.write_func(out, "get<" + variable->GetName() + ">()", variable->GetType());
+    out << ";\n";
+  }
+  out << "}\n";
+  out << "abort();\n";
+}
+
 }  // namespace cpp
 }  // namespace aidl
 }  // namespace android
diff --git a/aidl_to_cpp_common.h b/aidl_to_cpp_common.h
index 2892f88..f49edb0 100644
--- a/aidl_to_cpp_common.h
+++ b/aidl_to_cpp_common.h
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include <functional>
 #include <string>
 #include <type_traits>
 
@@ -74,6 +75,31 @@
                                const std::vector<std::string>& enclosing_namespaces_of_enum_decl);
 std::string TemplateDecl(const AidlStructuredParcelable& defined_type);
 
+void GenerateParcelableComparisonOperators(CodeWriter& out, const AidlParcelable& parcelable);
+
+struct ParcelWriterContext {
+  string status_type;
+  string status_ok;
+  string status_bad;
+  std::function<void(CodeWriter& out, const std::string& var, const AidlTypeSpecifier& type)>
+      read_func;
+  std::function<void(CodeWriter& out, const std::string& value, const AidlTypeSpecifier& type)>
+      write_func;
+};
+
+struct UnionWriter {
+  const AidlUnionDecl& decl;
+  const AidlTypenames& typenames;
+  const std::function<std::string(const AidlTypeSpecifier&, const AidlTypenames&)> name_of;
+  const ::ConstantValueDecorator& decorator;
+  static const std::vector<std::string> headers;
+
+  void PrivateFields(CodeWriter& out) const;
+  void PublicFields(CodeWriter& out) const;
+  void ReadFromParcel(CodeWriter& out, const ParcelWriterContext&) const;
+  void WriteToParcel(CodeWriter& out, const ParcelWriterContext&) const;
+};
+
 }  // namespace cpp
 }  // namespace aidl
 }  // namespace android
diff --git a/aidl_unittest.cpp b/aidl_unittest.cpp
index ddfa0a5..3dd4e62 100644
--- a/aidl_unittest.cpp
+++ b/aidl_unittest.cpp
@@ -2701,6 +2701,7 @@
 #include <binder/Parcel.h>
 #include <binder/Status.h>
 #include <cstdint>
+#include <type_traits>
 #include <utility>
 #include <variant>
 #include <vector>
diff --git a/generate_cpp.cpp b/generate_cpp.cpp
index ad32352..616d472 100644
--- a/generate_cpp.cpp
+++ b/generate_cpp.cpp
@@ -24,6 +24,7 @@
 #include <set>
 #include <string>
 
+#include <android-base/format.h>
 #include <android-base/stringprintf.h>
 
 #include "aidl_language.h"
@@ -1044,7 +1045,6 @@
 template <typename ParcelableType>
 struct ParcelableTraits {
   static void AddIncludes(set<string>& includes);
-  static string GetComparable(const ParcelableType& decl, const string& var_prefix);
   static void AddFields(ClassDecl& clazz, const ParcelableType& decl,
                         const AidlTypenames& typenames);
   static void GenReadFromParcel(const ParcelableType& parcel, const AidlTypenames& typenames,
@@ -1058,13 +1058,6 @@
   static void AddIncludes(set<string>& includes) {
     includes.insert("tuple");  // std::tie in comparison operators
   }
-  static string GetComparable(const AidlStructuredParcelable& decl, const string& var_prefix) {
-    std::vector<std::string> var_names;
-    for (const auto& variable : decl.GetFields()) {
-      var_names.push_back(var_prefix + variable->GetName());
-    }
-    return "std::tie(" + Join(var_names, ", ") + ")";
-  }
   static void AddFields(ClassDecl& clazz, const AidlStructuredParcelable& decl,
                         const AidlTypenames& typenames) {
     for (const auto& variable : decl.GetFields()) {
@@ -1142,183 +1135,61 @@
   }
 };
 
+// Adapter to cpp::UnionWriter
 template <>
 struct ParcelableTraits<AidlUnionDecl> {
   static void AddIncludes(set<string>& includes) {
-    includes.insert("variant");  // std::variant for value
-    includes.insert("utility");  // std::mode/forward for value
-  }
-  static string GetComparable(const AidlUnionDecl&, const string& var_prefix) {
-    return var_prefix + "_value";
+    includes.insert(std::begin(UnionWriter::headers), std::end(UnionWriter::headers));
   }
   static void AddFields(ClassDecl& clazz, const AidlUnionDecl& decl,
                         const AidlTypenames& typenames) {
-    AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
-                               /* type_params= */ nullptr, /* comments= */ "");
-    tag_type.Resolve(typenames);
-
-    std::ostringstream out;
-    out << "enum Tag : " << CppNameOf(tag_type, typenames) << " {\n";
-    bool is_first = true;
-    for (const auto& f : decl.GetFields()) {
-      out << "  " << f->GetName() << (is_first ? " = 0" : "") << ",  // " << f->Signature()
-          << ";\n";
-      is_first = false;
-    }
-    out << "};\n\n";
-    clazz.AddPublic(std::make_unique<LiteralDecl>(out.str()));
-
-    const auto& name = decl.GetName();
-
-    AIDL_FATAL_IF(decl.GetFields().empty(), decl) << "Union '" << name << "' is empty.";
-    const auto& first_field = decl.GetFields()[0];
-    const auto& first_name = first_field->GetName();
-    const auto& first_value = GetInitializer(typenames, *first_field);
-
-    // clang-format off
-    auto helper_methods = vector<string>{
-        // type classification
-        "template<typename _Tp>\n"
-        "static constexpr bool _not_self = !std::is_same_v<std::remove_cv_t<std::remove_reference_t<_Tp>>, " + name + ">;\n\n",
-
-        // default ctor inits with the first member's default value
-        name + "() : _value(std::in_place_index<" + first_name + ">, " + first_value + ") { }\n",
-
-        // other ctors with default implementation
-        name + "(const " + name + "&) = default;\n",
-        name + "(" + name + "&&) = default;\n",
-        name + "& operator=(const " + name + "&) = default;\n",
-        name + "& operator=(" + name + "&&) = default;\n\n",
-
-        // conversion ctor from value
-        "template <typename _Tp, std::enable_if_t<_not_self<_Tp>, int> = 0>\n"
-        "constexpr " + name + "(_Tp&& _arg)\n"
-        "    : _value(std::forward<_Tp>(_arg)) {}\n\n",
-
-        // ctor to support in-place construction using in_place_index/in_place_type
-        "template <typename... _Tp>\n"
-        "constexpr explicit " + name + "(_Tp&&... _args)\n"
-        "    : _value(std::forward<_Tp>(_args)...) {}\n\n",
-
-        // value ctor: make<tag>(args...)
-        "template <Tag _tag, typename... _Tp>\n"
-        "static " + name + " make(_Tp&&... _args) {\n"
-        "  return " + name + "(std::in_place_index<_tag>, std::forward<_Tp>(_args)...);\n"
-        "}\n\n",
-
-        // value ctor: make<tag>({initializer_list})
-        "template <Tag _tag, typename _Tp, typename... _Up>\n"
-        "static " + name + " make(std::initializer_list<_Tp> _il, _Up&&... _args) {\n"
-        "  return " + name + "(std::in_place_index<_tag>, std::move(_il), std::forward<_Up>(_args)...);\n"
-        "}\n\n",
-
-        // getTag
-        "Tag getTag() const {\n"
-        "  return static_cast<Tag>(_value.index());\n"
-        "}\n\n",
-
-        // const-getter
-        "template <Tag _tag>\n"
-        "const auto& get() const {\n"
-        "  if (getTag() != _tag) { abort(); }\n"
-        "  return std::get<_tag>(_value);\n"
-        "}\n\n",
-
-        // getter
-        "template <Tag _tag>\n"
-        "auto& get() {\n"
-        "  if (getTag() != _tag) { abort(); }\n"
-        "  return std::get<_tag>(_value);\n"
-        "}\n\n",
-
-        // setter
-        "template <Tag _tag, typename... _Tp>\n"
-        "void set(_Tp&&... _args) {\n"
-        "  _value.emplace<_tag>(std::forward<_Tp>(_args)...);\n"
-        "}\n\n",
-    };
-    // clang-format on
-    for (const auto& helper_method : helper_methods) {
-      clazz.AddPublic(std::make_unique<LiteralDecl>(helper_method));
-    }
-
-    vector<string> field_types;
-    for (const auto& f : decl.GetFields()) {
-      field_types.push_back(CppNameOf(f->GetType(), typenames));
-    }
-    clazz.AddPrivate(
-        std::make_unique<LiteralDecl>("std::variant<" + Join(field_types, ", ") + "> _value;\n"));
+    UnionWriter uw{decl, typenames, &CppNameOf, &ConstantValueDecorator};
+    const string public_fields = RunWriter([&](auto& out) { uw.PublicFields(out); });
+    const string private_fields = RunWriter([&](auto& out) { uw.PrivateFields(out); });
+    clazz.AddPublic(std::make_unique<LiteralDecl>(public_fields));
+    clazz.AddPrivate(std::make_unique<LiteralDecl>(private_fields));
   }
-  static void GenReadFromParcel(const AidlUnionDecl& parcel, const AidlTypenames& typenames,
+  static void GenReadFromParcel(const AidlUnionDecl& decl, const AidlTypenames& typenames,
                                 StatementBlock* read_block) {
-    const AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
-                                     /* type_params= */ nullptr, /* comments= */ "");
-    const string tag = "_aidl_tag";
-    const string value = "_aidl_value";
-
-    string block;
-    CodeWriterPtr out = CodeWriter::ForString(&block);
-
-    auto read_var = [&](const string& var, const AidlTypeSpecifier& type) {
-      *out << StringPrintf("%s %s;\n", CppNameOf(type, typenames).c_str(), var.c_str());
-      *out << StringPrintf("if ((%s = %s->%s(%s)) != %s) return %s;\n", kAndroidStatusVarName,
-                           kParcelVarName, ParcelReadMethodOf(type, typenames).c_str(),
-                           ParcelReadCastOf(type, typenames, "&" + var).c_str(), kAndroidStatusOk,
-                           kAndroidStatusVarName);
-    };
-
-    // begin
-    *out << StringPrintf("%s %s;\n", kAndroidStatusLiteral, kAndroidStatusVarName);
-    read_var(tag, tag_type);
-    *out << StringPrintf("switch (%s) {\n", tag.c_str());
-    for (const auto& variable : parcel.GetFields()) {
-      *out << StringPrintf("case %s: {\n", variable->GetName().c_str());
-      out->Indent();
-      read_var(value, variable->GetType());
-      *out << StringPrintf("set<%s>(std::move(%s));\n", variable->GetName().c_str(), value.c_str());
-      *out << StringPrintf("return %s; }\n", kAndroidStatusOk);
-      out->Dedent();
-    }
-    *out << "}\n";
-    *out << StringPrintf("return %s;\n", kAndroidStatusBadValue);
-    // end
-
-    out->Close();
-    read_block->AddLiteral(block, /*add_semicolon=*/false);
+    const string body = RunWriter([&](auto& out) {
+      UnionWriter uw{decl, typenames, &CppNameOf, &ConstantValueDecorator};
+      uw.ReadFromParcel(out, GetParcelWriterContext(typenames));
+    });
+    read_block->AddLiteral(body, /*add_semicolon=*/false);
   }
-  static void GenWriteToParcel(const AidlUnionDecl& parcel, const AidlTypenames& typenames,
+  static void GenWriteToParcel(const AidlUnionDecl& decl, const AidlTypenames& typenames,
                                StatementBlock* write_block) {
-    const AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
-                                     /* type_params= */ nullptr, /* comments= */ "");
-    const string tag = "_aidl_tag";
-    const string value = "_aidl_value";
+    const string body = RunWriter([&](auto& out) {
+      UnionWriter uw{decl, typenames, &CppNameOf, &ConstantValueDecorator};
+      uw.WriteToParcel(out, GetParcelWriterContext(typenames));
+    });
+    write_block->AddLiteral(body, /*add_semicolon=*/false);
+  }
 
-    string block;
-    CodeWriterPtr out = CodeWriter::ForString(&block);
-
-    auto write_value = [&](const string& value, const AidlTypeSpecifier& type) {
-      return StringPrintf("%s->%s(%s)", kParcelVarName,
-                          ParcelWriteMethodOf(type, typenames).c_str(),
-                          ParcelWriteCastOf(type, typenames, value).c_str());
-    };
-
-    // begin
-    *out << StringPrintf("%s %s = %s;\n", kAndroidStatusLiteral, kAndroidStatusVarName,
-                         write_value("getTag()", tag_type).c_str());
-    *out << StringPrintf("if (%s != %s) return %s;\n", kAndroidStatusVarName, kAndroidStatusOk,
-                         kAndroidStatusVarName);
-    *out << "switch (getTag()) {\n";
-    for (const auto& variable : parcel.GetFields()) {
-      const string value = "get<" + variable->GetName() + ">()";
-      *out << StringPrintf("case %s: return %s;\n", variable->GetName().c_str(),
-                           write_value(value, variable->GetType()).c_str());
-    }
-    *out << "}\n";
-    *out << "abort();\n";
-    // end
-
+ private:
+  static string RunWriter(std::function<void(CodeWriter&)> writer) {
+    string code;
+    CodeWriterPtr out = CodeWriter::ForString(&code);
+    writer(*out);
     out->Close();
-    write_block->AddLiteral(block, /*add_semicolon=*/false);
+    return code;
+  }
+  static ParcelWriterContext GetParcelWriterContext(const AidlTypenames& typenames) {
+    return ParcelWriterContext{
+        .status_type = kAndroidStatusLiteral,
+        .status_ok = kAndroidStatusOk,
+        .status_bad = kAndroidStatusBadValue,
+        .read_func =
+            [&](CodeWriter& out, const string& var, const AidlTypeSpecifier& type) {
+              out << fmt::format("{}->{}({})", kParcelVarName, ParcelReadMethodOf(type, typenames),
+                                 ParcelReadCastOf(type, typenames, "&" + var));
+            },
+        .write_func =
+            [&](CodeWriter& out, const string& value, const AidlTypeSpecifier& type) {
+              out << fmt::format("{}->{}({})", kParcelVarName, ParcelWriteMethodOf(type, typenames),
+                                 ParcelWriteCastOf(type, typenames, value));
+            },
+    };
   }
 };
 
@@ -1336,19 +1207,9 @@
     AddHeaders(variable->GetType(), typenames, &includes);
   }
 
-  set<string> operators = {"<", ">", "==", ">=", "<=", "!="};
-  string lhs = Traits::GetComparable(parcel, "");
-  string rhs = Traits::GetComparable(parcel, "rhs.");
-  bool is_empty = parcel.GetFields().empty();
-  std::ostringstream operator_code;
-  for (const auto& op : operators) {
-    operator_code << "inline bool operator" << op << "(const " << parcel.GetName() << "&"
-                  << (is_empty ? "" : " rhs") << ") const {\n"
-                  << "  return " << lhs << " " << op << " " << rhs << ";\n"
-                  << "}\n";
-  }
-  operator_code << "\n";
-  parcel_class->AddPublic(std::make_unique<LiteralDecl>(operator_code.str()));
+  string operator_code;
+  GenerateParcelableComparisonOperators(*CodeWriter::ForString(&operator_code), parcel);
+  parcel_class->AddPublic(std::make_unique<LiteralDecl>(operator_code));
 
   Traits::AddFields(*parcel_class, parcel, typenames);