move "generate union code" to aidl_to_cpp_common

This is in preparation for adding an NDK backend for unions.

Bug: 170784707
Bug: 150948558
Test: aidl_unittests / aidl_integration_test
Change-Id: I347fc7e114b84eb03baf826c86388176d90bc81e
diff --git a/aidl_to_cpp_common.cpp b/aidl_to_cpp_common.cpp
index c1953ac..04e805f 100644
--- a/aidl_to_cpp_common.cpp
+++ b/aidl_to_cpp_common.cpp
@@ -15,8 +15,11 @@
  */
 #include "aidl_to_cpp_common.h"
 
+#include <android-base/format.h>
 #include <android-base/stringprintf.h>
 #include <android-base/strings.h>
+
+#include <set>
 #include <unordered_map>
 
 #include "ast_cpp.h"
@@ -390,6 +393,179 @@
   return decl;
 }
 
+// Emits inline comparison operators (<, >, ==, >=, <=, !=) for a structured parcelable or union.
+void GenerateParcelableComparisonOperators(CodeWriter& out, const AidlParcelable& parcelable) {
+  const std::set<string> operators{"<", ">", "==", ">=", "<=", "!="};  // emitted in sorted order
+  bool is_empty = false;  // set by comparable(); true leaves the 'rhs' parameter unnamed
+  // Returns the expression each operator compares, prefixing every member with 'prefix'.
+  auto comparable = [&](const string& prefix) -> string {
+    vector<string> fields;
+    if (auto p = parcelable.AsStructuredParcelable(); p != nullptr) {
+      is_empty = p->GetFields().empty();
+      for (const auto& f : p->GetFields()) fields.push_back(prefix + f->GetName());
+      return "std::tie(" + Join(fields, ", ") + ")";
+    } else if (auto p = parcelable.AsUnionDeclaration(); p != nullptr) {
+      return prefix + "_value";
+    }
+    AIDL_FATAL(parcelable) << "Unknown parcelable type";
+    return "";  // keeps every path returning a value; not expected to be reached
+  };
+
+  // Both calls also refresh is_empty, which the loop below reads.
+  const string lhs = comparable("");
+  const string rhs = comparable("rhs.");
+  for (const auto& op : operators) {
+    out << "inline bool operator" << op << "(const " << parcelable.GetName() << "&"
+        << (is_empty ? "" : " rhs") << ") const {\n"
+        << "  return " << lhs << " " << op << " " << rhs << ";\n"
+        << "}\n";
+  }
+  out << "\n";
+}
+
+const vector<string> UnionWriter::headers{
+    "cstdlib",                 // abort() in get<>()
+    "type_traits", "utility",  // std::is_same_v; std::move/forward for value
+    "variant",                 // std::variant for value
+};
+
+void UnionWriter::PrivateFields(CodeWriter& out) const {
+  // Storage is one std::variant whose alternatives follow decl.GetFields() order.
+  vector<string> alternatives;
+  for (const auto& f : decl.GetFields()) alternatives.push_back(name_of(f->GetType(), typenames));
+  const string variant_type = "std::variant<" + Join(alternatives, ", ") + ">";
+  out << variant_type + " _value;\n";
+}
+
+// Emits the union's public members: Tag enum, constructors, make<>(), getTag(), get<>(), set<>().
+void UnionWriter::PublicFields(CodeWriter& out) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+  // One enumerator per field, starting at 0, so getTag() can cast _value.index() to Tag.
+  out << "enum Tag : " << name_of(tag_type, typenames) << " {\n";
+  bool is_first = true;
+  for (const auto& f : decl.GetFields()) {
+    out << "  " << f->GetName() << (is_first ? " = 0" : "") << ",  // " << f->Signature() << ";\n";
+    is_first = false;
+  }
+  out << "};\n";
+  const auto& name = decl.GetName();
+  AIDL_FATAL_IF(decl.GetFields().empty(), decl) << "Union '" << name << "' is empty.";
+  // A default-constructed union holds the first field, built from its ValueString() initializer.
+  const auto& first_field = decl.GetFields()[0];
+  const auto& default_name = first_field->GetName();
+  const auto& default_value =
+      name_of(first_field->GetType(), typenames) + "(" + first_field->ValueString(decorator) + ")";
+  // Variadic ctor needs 2+ args; with one it would beat the copy ctor for a non-const lvalue.
+  auto tmpl = R"--(
+template<typename _Tp>
+static constexpr bool _not_self = !std::is_same_v<std::remove_cv_t<std::remove_reference_t<_Tp>>, {name}>;
+
+{name}() : _value(std::in_place_index<{default_name}>, {default_value}) {{ }}
+{name}(const {name}&) = default;
+{name}({name}&&) = default;
+{name}& operator=(const {name}&) = default;
+{name}& operator=({name}&&) = default;
+
+template <typename _Tp, std::enable_if_t<_not_self<_Tp>, int> = 0>
+constexpr {name}(_Tp&& _arg)
+    : _value(std::forward<_Tp>(_arg)) {{}}
+
+template <typename... _Tp, std::enable_if_t<(sizeof...(_Tp) > 1), int> = 0>
+constexpr explicit {name}(_Tp&&... _args)
+    : _value(std::forward<_Tp>(_args)...) {{}}
+
+template <Tag _tag, typename... _Tp>
+static {name} make(_Tp&&... _args) {{
+  return {name}(std::in_place_index<_tag>, std::forward<_Tp>(_args)...);
+}}
+
+template <Tag _tag, typename _Tp, typename... _Up>
+static {name} make(std::initializer_list<_Tp> _il, _Up&&... _args) {{
+  return {name}(std::in_place_index<_tag>, std::move(_il), std::forward<_Up>(_args)...);
+}}
+
+Tag getTag() const {{
+  return static_cast<Tag>(_value.index());
+}}
+
+template <Tag _tag>
+const auto& get() const {{
+  if (getTag() != _tag) {{ abort(); }}
+  return std::get<_tag>(_value);
+}}
+
+template <Tag _tag>
+auto& get() {{
+  if (getTag() != _tag) {{ abort(); }}
+  return std::get<_tag>(_value);
+}}
+
+template <Tag _tag, typename... _Tp>
+void set(_Tp&&... _args) {{
+  _value.emplace<_tag>(std::forward<_Tp>(_args)...);
+}}
+
+)--";
+  out << fmt::format(tmpl, fmt::arg("name", name), fmt::arg("default_name", default_name),
+                     fmt::arg("default_value", default_value));
+}
+
+// Emits code that reads this union from a parcel: the int tag, then the matching field's value.
+void UnionWriter::ReadFromParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+  const string tag_var = "_aidl_tag";
+  const string value_var = "_aidl_value";
+  const string status_var = "_aidl_ret_status";
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+  // Declares 'var' of 'type', reads it with ctx.read_func, and returns early on a non-OK status.
+  auto emit_read = [&](const string& var, const AidlTypeSpecifier& type) {
+    out << fmt::format("{} {};\n", name_of(type, typenames), var)
+        << fmt::format("if (({} = ", status_var);
+    ctx.read_func(out, var, type);
+    out << fmt::format(") != {}) return {};\n", ctx.status_ok, status_var);
+  };
+
+  out << fmt::format("{} {};\n", ctx.status_type, status_var);
+  emit_read(tag_var, tag_type);
+  out << fmt::format("switch ({}) {{\n", tag_var);
+  for (const auto& field : decl.GetFields()) {
+    const auto& field_name = field->GetName();
+    out << fmt::format("case {}: {{\n", field_name);
+    out.Indent();
+    emit_read(value_var, field->GetType());
+    out << fmt::format("set<{}>(std::move({}));\n", field_name, value_var)
+        << fmt::format("return {}; }}\n", ctx.status_ok);
+    out.Dedent();
+  }
+  out << "}\n" << fmt::format("return {};\n", ctx.status_bad);
+}
+
+// Emits code that writes this union to a parcel: the int tag first, then the active field.
+void UnionWriter::WriteToParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  const string status = "_aidl_ret_status";
+
+  out << fmt::format("{} {} = ", ctx.status_type, status);
+  ctx.write_func(out, "getTag()", tag_type);
+  out << ";\n";
+  out << fmt::format("if ({} != {}) return {};\n", status, ctx.status_ok, status);
+  out << "switch (getTag()) {\n";
+  for (const auto& variable : decl.GetFields()) {
+    out << fmt::format("case {}: return ", variable->GetName());
+    ctx.write_func(out, "get<" + variable->GetName() + ">()", variable->GetType());
+    out << ";\n";
+  }
+  out << "}\n";
+  // Every case returns; abort() covers a tag matching no case so control never falls off the end.
+  out << "abort();\n";
+}
+
 }  // namespace cpp
 }  // namespace aidl
 }  // namespace android