move "generate union code" to aidl_to_cpp_common
This is in preparation for adding an NDK backend for unions.
Bug: 170784707
Bug: 150948558
Test: aidl_unittests / aidl_integration_test
Change-Id: I347fc7e114b84eb03baf826c86388176d90bc81e
diff --git a/aidl_to_cpp_common.cpp b/aidl_to_cpp_common.cpp
index c1953ac..04e805f 100644
--- a/aidl_to_cpp_common.cpp
+++ b/aidl_to_cpp_common.cpp
@@ -15,8 +15,11 @@
*/
#include "aidl_to_cpp_common.h"
+#include <android-base/format.h>
#include <android-base/stringprintf.h>
#include <android-base/strings.h>
+
+#include <set>
#include <unordered_map>
#include "ast_cpp.h"
@@ -390,6 +393,179 @@
return decl;
}
+// Emits the six relational operators (<, >, ==, >=, <=, !=) as inline const
+// member functions of the generated C++ type for `parcelable`.
+//
+//  - A structured parcelable compares member-wise with std::tie over all of
+//    its fields, in declaration order.
+//  - A union compares its `_value` member.
+//  - Any other parcelable kind is a fatal error.
+//
+// When a structured parcelable has no fields, the parameter name `rhs` is left
+// out of the signature because the generated body never uses it. Operators are
+// emitted in std::set iteration order, so the output order is deterministic.
+void GenerateParcelableComparisonOperators(CodeWriter& out, const AidlParcelable& parcelable) {
+  const std::set<string> operators{"<", ">", "==", ">=", "<=", "!="};
+
+  const auto* structured = parcelable.AsStructuredParcelable();
+  const auto* union_decl = parcelable.AsUnionDeclaration();
+
+  // Computed once up front rather than as a side effect of `comparable`.
+  // Only a structured parcelable can be field-less here; unions leave it false.
+  const bool is_empty = structured != nullptr && structured->GetFields().empty();
+
+  // Builds the expression compared for one operand. `prefix` is "" for *this
+  // and "rhs." for the other operand.
+  auto comparable = [&](const string& prefix) -> string {
+    if (structured != nullptr) {
+      vector<string> fields;
+      for (const auto& f : structured->GetFields()) {
+        fields.push_back(prefix + f->GetName());
+      }
+      return "std::tie(" + Join(fields, ", ") + ")";
+    }
+    if (union_decl != nullptr) {
+      return prefix + "_value";
+    }
+    AIDL_FATAL(parcelable) << "Unknown parcelable type";
+    // Not expected to be reached (AIDL_FATAL is presumably terminal); the
+    // return keeps the non-void lambda from flowing off its end.
+    return "";
+  };
+
+  const string lhs = comparable("");
+  const string rhs = comparable("rhs.");
+  for (const auto& op : operators) {
+    out << "inline bool operator" << op << "(const " << parcelable.GetName() << "&"
+        << (is_empty ? "" : " rhs") << ") const {\n"
+        << " return " << lhs << " " << op << " " << rhs << ";\n"
+        << "}\n";
+  }
+  out << "\n";
+}
+
+// Standard-library headers that the generated union code relies on.
+const vector<string> UnionWriter::headers = {
+    "type_traits",  // std::is_same_v / std::remove_cv_t / std::enable_if_t
+    "utility",      // std::move / std::forward / std::in_place_index
+    "variant",      // std::variant holding the union's value
+};
+
+// Emits the union's private storage: one std::variant member named `_value`
+// whose alternatives are the C++ types of the union's fields, listed in field
+// declaration order.
+void UnionWriter::PrivateFields(CodeWriter& out) const {
+  const auto& fields = decl.GetFields();
+  vector<string> alternatives;
+  alternatives.reserve(fields.size());
+  for (const auto& field : fields) {
+    alternatives.emplace_back(name_of(field->GetType(), typenames));
+  }
+  const string variant_type = "std::variant<" + Join(alternatives, ", ") + ">";
+  out << variant_type + " _value;\n";
+}
+
+// Emits the public part of the generated union class:
+//  - `enum Tag`, one enumerator per field in declaration order, the first
+//    explicitly set to 0;
+//  - a default constructor holding the first field, initialized from that
+//    field's value string as rendered by `decorator`;
+//  - defaulted copy/move operations, forwarding constructors and the
+//    `make<Tag>(...)` factories;
+//  - `getTag()`, `get<Tag>()` (aborts when the tag does not match) and
+//    `set<Tag>(...)`.
+// A union with no fields is a fatal error.
+void UnionWriter::PublicFields(CodeWriter& out) const {
+  const auto& name = decl.GetName();
+  // Checked before anything is written so an invalid union emits no output.
+  AIDL_FATAL_IF(decl.GetFields().empty(), decl) << "Union '" << name << "' is empty.";
+
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  out << "enum Tag : " << name_of(tag_type, typenames) << " {\n";
+  bool is_first = true;
+  for (const auto& f : decl.GetFields()) {
+    out << " " << f->GetName() << (is_first ? " = 0" : "") << ", // " << f->Signature() << ";\n";
+    is_first = false;
+  }
+  out << "};\n";
+
+  const auto& first_field = decl.GetFields()[0];
+  const auto& default_name = first_field->GetName();
+  const auto& default_value =
+      name_of(first_field->GetType(), typenames) + "(" + first_field->ValueString(decorator) + ")";
+
+  // Notes on the generated code:
+  //  - `_not_self` stops the single-argument forwarding constructor from
+  //    competing with the copy/move constructors.
+  //  - The variadic constructor is guarded the same way for the
+  //    single-argument case. Without the guard, direct-initialization from a
+  //    non-const lvalue of the union type (`U b(a);`, or
+  //    `std::vector<U>::emplace_back(a)`) selects the variadic template,
+  //    because binding `U&` beats `const U&`. It then tries to build the
+  //    std::variant from a U, which does not compile.
+  auto tmpl = R"--(
+template<typename _Tp>
+static constexpr bool _not_self = !std::is_same_v<std::remove_cv_t<std::remove_reference_t<_Tp>>, {name}>;
+
+{name}() : _value(std::in_place_index<{default_name}>, {default_value}) {{ }}
+{name}(const {name}&) = default;
+{name}({name}&&) = default;
+{name}& operator=(const {name}&) = default;
+{name}& operator=({name}&&) = default;
+
+template <typename _Tp, std::enable_if_t<_not_self<_Tp>, int> = 0>
+constexpr {name}(_Tp&& _arg)
+ : _value(std::forward<_Tp>(_arg)) {{}}
+
+template <typename... _Tp, std::enable_if_t<(sizeof...(_Tp) != 1 || (_not_self<_Tp> && ...)), int> = 0>
+constexpr explicit {name}(_Tp&&... _args)
+ : _value(std::forward<_Tp>(_args)...) {{}}
+
+template <Tag _tag, typename... _Tp>
+static {name} make(_Tp&&... _args) {{
+ return {name}(std::in_place_index<_tag>, std::forward<_Tp>(_args)...);
+}}
+
+template <Tag _tag, typename _Tp, typename... _Up>
+static {name} make(std::initializer_list<_Tp> _il, _Up&&... _args) {{
+ return {name}(std::in_place_index<_tag>, std::move(_il), std::forward<_Up>(_args)...);
+}}
+
+Tag getTag() const {{
+ return static_cast<Tag>(_value.index());
+}}
+
+template <Tag _tag>
+const auto& get() const {{
+ if (getTag() != _tag) {{ abort(); }}
+ return std::get<_tag>(_value);
+}}
+
+template <Tag _tag>
+auto& get() {{
+ if (getTag() != _tag) {{ abort(); }}
+ return std::get<_tag>(_value);
+}}
+
+template <Tag _tag, typename... _Tp>
+void set(_Tp&&... _args) {{
+ _value.emplace<_tag>(std::forward<_Tp>(_args)...);
+}}
+
+)--";
+  out << fmt::format(tmpl, fmt::arg("name", name), fmt::arg("default_name", default_name),
+                     fmt::arg("default_value", default_value));
+}
+
+// Emits the statements of the union's readFromParcel body.
+//
+// The generated code declares a status variable and reads an int tag, then
+// switches on the tag. Each case reads a value of the matching field's type,
+// stores it with set<Tag>() and returns ctx.status_ok. Any read whose status
+// differs from ctx.status_ok is returned immediately. A tag that matches no
+// field reaches the trailing `return ctx.status_bad`.
+void UnionWriter::ReadFromParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  // Names of the locals used inside the generated function.
+  const string status_var = "_aidl_ret_status";
+  const string tag_var = "_aidl_tag";
+  const string value_var = "_aidl_value";
+
+  // Emits a declaration of `var` with the C++ type of `type`, followed by a
+  // parcel read into it that returns early on a non-OK status.
+  const auto emit_checked_read = [&](const string& var, const AidlTypeSpecifier& type) {
+    const string declaration = fmt::format("{} {};\n", name_of(type, typenames), var);
+    out << declaration;
+    out << fmt::format("if (({} = ", status_var);
+    ctx.read_func(out, var, type);
+    out << fmt::format(") != {}) return {};\n", ctx.status_ok, status_var);
+  };
+
+  out << fmt::format("{} {};\n", ctx.status_type, status_var);
+  emit_checked_read(tag_var, tag_type);
+
+  out << fmt::format("switch ({}) {{\n", tag_var);
+  for (const auto& field : decl.GetFields()) {
+    const auto& field_name = field->GetName();
+    out << fmt::format("case {}: {{\n", field_name);
+    out.Indent();
+    emit_checked_read(value_var, field->GetType());
+    out << fmt::format("set<{}>(std::move({}));\n", field_name, value_var);
+    out << fmt::format("return {}; }}\n", ctx.status_ok);
+    out.Dedent();
+  }
+  out << "}\n";
+
+  out << fmt::format("return {};\n", ctx.status_bad);
+}
+
+// Emits the statements of the union's writeToParcel body.
+//
+// The generated code writes getTag() as an int and returns early if that
+// write's status differs from ctx.status_ok. It then switches on getTag(),
+// and each case returns the status of writing the active field via
+// get<Tag>(). An `abort()` follows the switch for control that falls out of
+// it without matching a field.
+void UnionWriter::WriteToParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+  AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+                             /* type_params= */ nullptr, /* comments= */ "");
+  tag_type.Resolve(typenames);
+
+  // Only the status local is needed. The tag is written straight from
+  // getTag() and each value straight from get<...>(), so no other locals
+  // (the previously declared but unused `tag` / `value` names) are required.
+  const string status = "_aidl_ret_status";
+
+  out << fmt::format("{} {} = ", ctx.status_type, status);
+  ctx.write_func(out, "getTag()", tag_type);
+  out << ";\n";
+  out << fmt::format("if ({} != {}) return {};\n", status, ctx.status_ok, status);
+  out << "switch (getTag()) {\n";
+  for (const auto& variable : decl.GetFields()) {
+    out << fmt::format("case {}: return ", variable->GetName());
+    ctx.write_func(out, "get<" + variable->GetName() + ">()", variable->GetType());
+    out << ";\n";
+  }
+  out << "}\n";
+  out << "abort();\n";
+}
+
} // namespace cpp
} // namespace aidl
} // namespace android