move "generate union code" to aidl_to_cpp_common
This is a preparation to add NDK backend for union.
Bug: 170784707
Bug: 150948558
Test: aidl_unittests / aidl_integration_test
Change-Id: I347fc7e114b84eb03baf826c86388176d90bc81e
diff --git a/aidl_to_cpp_common.cpp b/aidl_to_cpp_common.cpp
index c1953ac..04e805f 100644
--- a/aidl_to_cpp_common.cpp
+++ b/aidl_to_cpp_common.cpp
@@ -15,8 +15,11 @@
*/
#include "aidl_to_cpp_common.h"
+#include <android-base/format.h>
#include <android-base/stringprintf.h>
#include <android-base/strings.h>
+
+#include <set>
#include <unordered_map>
#include "ast_cpp.h"
@@ -390,6 +393,179 @@
return decl;
}
+void GenerateParcelableComparisonOperators(CodeWriter& out, const AidlParcelable& parcelable) {
+ std::set<string> operators{"<", ">", "==", ">=", "<=", "!="};
+ bool is_empty = false;
+
+ auto comparable = [&](const string& prefix) {
+ vector<string> fields;
+ if (auto p = parcelable.AsStructuredParcelable(); p != nullptr) {
+ is_empty = p->GetFields().empty();
+ for (const auto& f : p->GetFields()) {
+ fields.push_back(prefix + f->GetName());
+ }
+ return "std::tie(" + Join(fields, ", ") + ")";
+ } else if (auto p = parcelable.AsUnionDeclaration(); p != nullptr) {
+ return prefix + "_value";
+ } else {
+ AIDL_FATAL(parcelable) << "Unknown paracelable type";
+ }
+ };
+
+ string lhs = comparable("");
+ string rhs = comparable("rhs.");
+ for (const auto& op : operators) {
+ out << "inline bool operator" << op << "(const " << parcelable.GetName() << "&"
+ << (is_empty ? "" : " rhs") << ") const {\n"
+ << " return " << lhs << " " << op << " " << rhs << ";\n"
+ << "}\n";
+ }
+ out << "\n";
+}
+
+const vector<string> UnionWriter::headers{
+ "type_traits", // std::is_same_v
+ "utility", // std::mode/forward for value
+ "variant", // std::variant for value
+};
+
+void UnionWriter::PrivateFields(CodeWriter& out) const {
+ vector<string> field_types;
+ for (const auto& f : decl.GetFields()) {
+ field_types.push_back(name_of(f->GetType(), typenames));
+ }
+ out << "std::variant<" + Join(field_types, ", ") + "> _value;\n";
+}
+
+void UnionWriter::PublicFields(CodeWriter& out) const {
+ AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+ /* type_params= */ nullptr, /* comments= */ "");
+ tag_type.Resolve(typenames);
+
+ out << "enum Tag : " << name_of(tag_type, typenames) << " {\n";
+ bool is_first = true;
+ for (const auto& f : decl.GetFields()) {
+ out << " " << f->GetName() << (is_first ? " = 0" : "") << ", // " << f->Signature() << ";\n";
+ is_first = false;
+ }
+ out << "};\n";
+
+ const auto& name = decl.GetName();
+
+ AIDL_FATAL_IF(decl.GetFields().empty(), decl) << "Union '" << name << "' is empty.";
+ const auto& first_field = decl.GetFields()[0];
+ const auto& default_name = first_field->GetName();
+ const auto& default_value =
+ name_of(first_field->GetType(), typenames) + "(" + first_field->ValueString(decorator) + ")";
+
+ auto tmpl = R"--(
+template<typename _Tp>
+static constexpr bool _not_self = !std::is_same_v<std::remove_cv_t<std::remove_reference_t<_Tp>>, {name}>;
+
+{name}() : _value(std::in_place_index<{default_name}>, {default_value}) {{ }}
+{name}(const {name}&) = default;
+{name}({name}&&) = default;
+{name}& operator=(const {name}&) = default;
+{name}& operator=({name}&&) = default;
+
+template <typename _Tp, std::enable_if_t<_not_self<_Tp>, int> = 0>
+constexpr {name}(_Tp&& _arg)
+ : _value(std::forward<_Tp>(_arg)) {{}}
+
+template <typename... _Tp>
+constexpr explicit {name}(_Tp&&... _args)
+ : _value(std::forward<_Tp>(_args)...) {{}}
+
+template <Tag _tag, typename... _Tp>
+static {name} make(_Tp&&... _args) {{
+ return {name}(std::in_place_index<_tag>, std::forward<_Tp>(_args)...);
+}}
+
+template <Tag _tag, typename _Tp, typename... _Up>
+static {name} make(std::initializer_list<_Tp> _il, _Up&&... _args) {{
+ return {name}(std::in_place_index<_tag>, std::move(_il), std::forward<_Up>(_args)...);
+}}
+
+Tag getTag() const {{
+ return static_cast<Tag>(_value.index());
+}}
+
+template <Tag _tag>
+const auto& get() const {{
+ if (getTag() != _tag) {{ abort(); }}
+ return std::get<_tag>(_value);
+}}
+
+template <Tag _tag>
+auto& get() {{
+ if (getTag() != _tag) {{ abort(); }}
+ return std::get<_tag>(_value);
+}}
+
+template <Tag _tag, typename... _Tp>
+void set(_Tp&&... _args) {{
+ _value.emplace<_tag>(std::forward<_Tp>(_args)...);
+}}
+
+)--";
+ out << fmt::format(tmpl, fmt::arg("name", name), fmt::arg("default_name", default_name),
+ fmt::arg("default_value", default_value));
+}
+
+void UnionWriter::ReadFromParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+ AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+ /* type_params= */ nullptr, /* comments= */ "");
+ tag_type.Resolve(typenames);
+
+ const string tag = "_aidl_tag";
+ const string value = "_aidl_value";
+ const string status = "_aidl_ret_status";
+
+ auto read_var = [&](const string& var, const AidlTypeSpecifier& type) {
+ out << fmt::format("{} {};\n", name_of(type, typenames), var);
+ out << fmt::format("if (({} = ", status);
+ ctx.read_func(out, var, type);
+ out << fmt::format(") != {}) return {};\n", ctx.status_ok, status);
+ };
+
+ out << fmt::format("{} {};\n", ctx.status_type, status);
+ read_var(tag, tag_type);
+ out << fmt::format("switch ({}) {{\n", tag);
+ for (const auto& variable : decl.GetFields()) {
+ out << fmt::format("case {}: {{\n", variable->GetName());
+ out.Indent();
+ read_var(value, variable->GetType());
+ out << fmt::format("set<{}>(std::move({}));\n", variable->GetName(), value);
+ out << fmt::format("return {}; }}\n", ctx.status_ok);
+ out.Dedent();
+ }
+ out << "}\n";
+ out << fmt::format("return {};\n", ctx.status_bad);
+}
+
+void UnionWriter::WriteToParcel(CodeWriter& out, const ParcelWriterContext& ctx) const {
+ AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
+ /* type_params= */ nullptr, /* comments= */ "");
+ tag_type.Resolve(typenames);
+
+ const string tag = "_aidl_tag";
+ const string value = "_aidl_value";
+ const string status = "_aidl_ret_status";
+
+ out << fmt::format("{} {} = ", ctx.status_type, status);
+ ctx.write_func(out, "getTag()", tag_type);
+ out << ";\n";
+ out << fmt::format("if ({} != {}) return {};\n", status, ctx.status_ok, status);
+ out << "switch (getTag()) {\n";
+ for (const auto& variable : decl.GetFields()) {
+ out << fmt::format("case {}: return ", variable->GetName());
+ ctx.write_func(out, "get<" + variable->GetName() + ">()", variable->GetType());
+ out << ";\n";
+ }
+ out << "}\n";
+ out << "abort();\n";
+}
+
} // namespace cpp
} // namespace aidl
} // namespace android
diff --git a/aidl_to_cpp_common.h b/aidl_to_cpp_common.h
index 2892f88..f49edb0 100644
--- a/aidl_to_cpp_common.h
+++ b/aidl_to_cpp_common.h
@@ -16,6 +16,7 @@
#pragma once
+#include <functional>
#include <string>
#include <type_traits>
@@ -74,6 +75,31 @@
const std::vector<std::string>& enclosing_namespaces_of_enum_decl);
std::string TemplateDecl(const AidlStructuredParcelable& defined_type);
+void GenerateParcelableComparisonOperators(CodeWriter& out, const AidlParcelable& parcelable);
+
+struct ParcelWriterContext {
+ string status_type;
+ string status_ok;
+ string status_bad;
+ std::function<void(CodeWriter& out, const std::string& var, const AidlTypeSpecifier& type)>
+ read_func;
+ std::function<void(CodeWriter& out, const std::string& value, const AidlTypeSpecifier& type)>
+ write_func;
+};
+
+struct UnionWriter {
+ const AidlUnionDecl& decl;
+ const AidlTypenames& typenames;
+ const std::function<std::string(const AidlTypeSpecifier&, const AidlTypenames&)> name_of;
+ const ::ConstantValueDecorator& decorator;
+ static const std::vector<std::string> headers;
+
+ void PrivateFields(CodeWriter& out) const;
+ void PublicFields(CodeWriter& out) const;
+ void ReadFromParcel(CodeWriter& out, const ParcelWriterContext&) const;
+ void WriteToParcel(CodeWriter& out, const ParcelWriterContext&) const;
+};
+
} // namespace cpp
} // namespace aidl
} // namespace android
diff --git a/aidl_unittest.cpp b/aidl_unittest.cpp
index ddfa0a5..3dd4e62 100644
--- a/aidl_unittest.cpp
+++ b/aidl_unittest.cpp
@@ -2701,6 +2701,7 @@
#include <binder/Parcel.h>
#include <binder/Status.h>
#include <cstdint>
+#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
diff --git a/generate_cpp.cpp b/generate_cpp.cpp
index ad32352..616d472 100644
--- a/generate_cpp.cpp
+++ b/generate_cpp.cpp
@@ -24,6 +24,7 @@
#include <set>
#include <string>
+#include <android-base/format.h>
#include <android-base/stringprintf.h>
#include "aidl_language.h"
@@ -1044,7 +1045,6 @@
template <typename ParcelableType>
struct ParcelableTraits {
static void AddIncludes(set<string>& includes);
- static string GetComparable(const ParcelableType& decl, const string& var_prefix);
static void AddFields(ClassDecl& clazz, const ParcelableType& decl,
const AidlTypenames& typenames);
static void GenReadFromParcel(const ParcelableType& parcel, const AidlTypenames& typenames,
@@ -1058,13 +1058,6 @@
static void AddIncludes(set<string>& includes) {
includes.insert("tuple"); // std::tie in comparison operators
}
- static string GetComparable(const AidlStructuredParcelable& decl, const string& var_prefix) {
- std::vector<std::string> var_names;
- for (const auto& variable : decl.GetFields()) {
- var_names.push_back(var_prefix + variable->GetName());
- }
- return "std::tie(" + Join(var_names, ", ") + ")";
- }
static void AddFields(ClassDecl& clazz, const AidlStructuredParcelable& decl,
const AidlTypenames& typenames) {
for (const auto& variable : decl.GetFields()) {
@@ -1142,183 +1135,61 @@
}
};
+// Adapter to cpp::UnionWriter
template <>
struct ParcelableTraits<AidlUnionDecl> {
static void AddIncludes(set<string>& includes) {
- includes.insert("variant"); // std::variant for value
- includes.insert("utility"); // std::mode/forward for value
- }
- static string GetComparable(const AidlUnionDecl&, const string& var_prefix) {
- return var_prefix + "_value";
+ includes.insert(std::begin(UnionWriter::headers), std::end(UnionWriter::headers));
}
static void AddFields(ClassDecl& clazz, const AidlUnionDecl& decl,
const AidlTypenames& typenames) {
- AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
- /* type_params= */ nullptr, /* comments= */ "");
- tag_type.Resolve(typenames);
-
- std::ostringstream out;
- out << "enum Tag : " << CppNameOf(tag_type, typenames) << " {\n";
- bool is_first = true;
- for (const auto& f : decl.GetFields()) {
- out << " " << f->GetName() << (is_first ? " = 0" : "") << ", // " << f->Signature()
- << ";\n";
- is_first = false;
- }
- out << "};\n\n";
- clazz.AddPublic(std::make_unique<LiteralDecl>(out.str()));
-
- const auto& name = decl.GetName();
-
- AIDL_FATAL_IF(decl.GetFields().empty(), decl) << "Union '" << name << "' is empty.";
- const auto& first_field = decl.GetFields()[0];
- const auto& first_name = first_field->GetName();
- const auto& first_value = GetInitializer(typenames, *first_field);
-
- // clang-format off
- auto helper_methods = vector<string>{
- // type classification
- "template<typename _Tp>\n"
- "static constexpr bool _not_self = !std::is_same_v<std::remove_cv_t<std::remove_reference_t<_Tp>>, " + name + ">;\n\n",
-
- // default ctor inits with the first member's default value
- name + "() : _value(std::in_place_index<" + first_name + ">, " + first_value + ") { }\n",
-
- // other ctors with default implementation
- name + "(const " + name + "&) = default;\n",
- name + "(" + name + "&&) = default;\n",
- name + "& operator=(const " + name + "&) = default;\n",
- name + "& operator=(" + name + "&&) = default;\n\n",
-
- // conversion ctor from value
- "template <typename _Tp, std::enable_if_t<_not_self<_Tp>, int> = 0>\n"
- "constexpr " + name + "(_Tp&& _arg)\n"
- " : _value(std::forward<_Tp>(_arg)) {}\n\n",
-
- // ctor to support in-place construction using in_place_index/in_place_type
- "template <typename... _Tp>\n"
- "constexpr explicit " + name + "(_Tp&&... _args)\n"
- " : _value(std::forward<_Tp>(_args)...) {}\n\n",
-
- // value ctor: make<tag>(args...)
- "template <Tag _tag, typename... _Tp>\n"
- "static " + name + " make(_Tp&&... _args) {\n"
- " return " + name + "(std::in_place_index<_tag>, std::forward<_Tp>(_args)...);\n"
- "}\n\n",
-
- // value ctor: make<tag>({initializer_list})
- "template <Tag _tag, typename _Tp, typename... _Up>\n"
- "static " + name + " make(std::initializer_list<_Tp> _il, _Up&&... _args) {\n"
- " return " + name + "(std::in_place_index<_tag>, std::move(_il), std::forward<_Up>(_args)...);\n"
- "}\n\n",
-
- // getTag
- "Tag getTag() const {\n"
- " return static_cast<Tag>(_value.index());\n"
- "}\n\n",
-
- // const-getter
- "template <Tag _tag>\n"
- "const auto& get() const {\n"
- " if (getTag() != _tag) { abort(); }\n"
- " return std::get<_tag>(_value);\n"
- "}\n\n",
-
- // getter
- "template <Tag _tag>\n"
- "auto& get() {\n"
- " if (getTag() != _tag) { abort(); }\n"
- " return std::get<_tag>(_value);\n"
- "}\n\n",
-
- // setter
- "template <Tag _tag, typename... _Tp>\n"
- "void set(_Tp&&... _args) {\n"
- " _value.emplace<_tag>(std::forward<_Tp>(_args)...);\n"
- "}\n\n",
- };
- // clang-format on
- for (const auto& helper_method : helper_methods) {
- clazz.AddPublic(std::make_unique<LiteralDecl>(helper_method));
- }
-
- vector<string> field_types;
- for (const auto& f : decl.GetFields()) {
- field_types.push_back(CppNameOf(f->GetType(), typenames));
- }
- clazz.AddPrivate(
- std::make_unique<LiteralDecl>("std::variant<" + Join(field_types, ", ") + "> _value;\n"));
+ UnionWriter uw{decl, typenames, &CppNameOf, &ConstantValueDecorator};
+ const string public_fields = RunWriter([&](auto& out) { uw.PublicFields(out); });
+ const string private_fields = RunWriter([&](auto& out) { uw.PrivateFields(out); });
+ clazz.AddPublic(std::make_unique<LiteralDecl>(public_fields));
+ clazz.AddPrivate(std::make_unique<LiteralDecl>(private_fields));
}
- static void GenReadFromParcel(const AidlUnionDecl& parcel, const AidlTypenames& typenames,
+ static void GenReadFromParcel(const AidlUnionDecl& decl, const AidlTypenames& typenames,
StatementBlock* read_block) {
- const AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
- /* type_params= */ nullptr, /* comments= */ "");
- const string tag = "_aidl_tag";
- const string value = "_aidl_value";
-
- string block;
- CodeWriterPtr out = CodeWriter::ForString(&block);
-
- auto read_var = [&](const string& var, const AidlTypeSpecifier& type) {
- *out << StringPrintf("%s %s;\n", CppNameOf(type, typenames).c_str(), var.c_str());
- *out << StringPrintf("if ((%s = %s->%s(%s)) != %s) return %s;\n", kAndroidStatusVarName,
- kParcelVarName, ParcelReadMethodOf(type, typenames).c_str(),
- ParcelReadCastOf(type, typenames, "&" + var).c_str(), kAndroidStatusOk,
- kAndroidStatusVarName);
- };
-
- // begin
- *out << StringPrintf("%s %s;\n", kAndroidStatusLiteral, kAndroidStatusVarName);
- read_var(tag, tag_type);
- *out << StringPrintf("switch (%s) {\n", tag.c_str());
- for (const auto& variable : parcel.GetFields()) {
- *out << StringPrintf("case %s: {\n", variable->GetName().c_str());
- out->Indent();
- read_var(value, variable->GetType());
- *out << StringPrintf("set<%s>(std::move(%s));\n", variable->GetName().c_str(), value.c_str());
- *out << StringPrintf("return %s; }\n", kAndroidStatusOk);
- out->Dedent();
- }
- *out << "}\n";
- *out << StringPrintf("return %s;\n", kAndroidStatusBadValue);
- // end
-
- out->Close();
- read_block->AddLiteral(block, /*add_semicolon=*/false);
+ const string body = RunWriter([&](auto& out) {
+ UnionWriter uw{decl, typenames, &CppNameOf, &ConstantValueDecorator};
+ uw.ReadFromParcel(out, GetParcelWriterContext(typenames));
+ });
+ read_block->AddLiteral(body, /*add_semicolon=*/false);
}
- static void GenWriteToParcel(const AidlUnionDecl& parcel, const AidlTypenames& typenames,
+ static void GenWriteToParcel(const AidlUnionDecl& decl, const AidlTypenames& typenames,
StatementBlock* write_block) {
- const AidlTypeSpecifier tag_type(AIDL_LOCATION_HERE, "int", /* is_array= */ false,
- /* type_params= */ nullptr, /* comments= */ "");
- const string tag = "_aidl_tag";
- const string value = "_aidl_value";
+ const string body = RunWriter([&](auto& out) {
+ UnionWriter uw{decl, typenames, &CppNameOf, &ConstantValueDecorator};
+ uw.WriteToParcel(out, GetParcelWriterContext(typenames));
+ });
+ write_block->AddLiteral(body, /*add_semicolon=*/false);
+ }
- string block;
- CodeWriterPtr out = CodeWriter::ForString(&block);
-
- auto write_value = [&](const string& value, const AidlTypeSpecifier& type) {
- return StringPrintf("%s->%s(%s)", kParcelVarName,
- ParcelWriteMethodOf(type, typenames).c_str(),
- ParcelWriteCastOf(type, typenames, value).c_str());
- };
-
- // begin
- *out << StringPrintf("%s %s = %s;\n", kAndroidStatusLiteral, kAndroidStatusVarName,
- write_value("getTag()", tag_type).c_str());
- *out << StringPrintf("if (%s != %s) return %s;\n", kAndroidStatusVarName, kAndroidStatusOk,
- kAndroidStatusVarName);
- *out << "switch (getTag()) {\n";
- for (const auto& variable : parcel.GetFields()) {
- const string value = "get<" + variable->GetName() + ">()";
- *out << StringPrintf("case %s: return %s;\n", variable->GetName().c_str(),
- write_value(value, variable->GetType()).c_str());
- }
- *out << "}\n";
- *out << "abort();\n";
- // end
-
+ private:
+ static string RunWriter(std::function<void(CodeWriter&)> writer) {
+ string code;
+ CodeWriterPtr out = CodeWriter::ForString(&code);
+ writer(*out);
out->Close();
- write_block->AddLiteral(block, /*add_semicolon=*/false);
+ return code;
+ }
+ static ParcelWriterContext GetParcelWriterContext(const AidlTypenames& typenames) {
+ return ParcelWriterContext{
+ .status_type = kAndroidStatusLiteral,
+ .status_ok = kAndroidStatusOk,
+ .status_bad = kAndroidStatusBadValue,
+ .read_func =
+ [&](CodeWriter& out, const string& var, const AidlTypeSpecifier& type) {
+ out << fmt::format("{}->{}({})", kParcelVarName, ParcelReadMethodOf(type, typenames),
+ ParcelReadCastOf(type, typenames, "&" + var));
+ },
+ .write_func =
+ [&](CodeWriter& out, const string& value, const AidlTypeSpecifier& type) {
+ out << fmt::format("{}->{}({})", kParcelVarName, ParcelWriteMethodOf(type, typenames),
+ ParcelWriteCastOf(type, typenames, value));
+ },
+ };
}
};
@@ -1336,19 +1207,9 @@
AddHeaders(variable->GetType(), typenames, &includes);
}
- set<string> operators = {"<", ">", "==", ">=", "<=", "!="};
- string lhs = Traits::GetComparable(parcel, "");
- string rhs = Traits::GetComparable(parcel, "rhs.");
- bool is_empty = parcel.GetFields().empty();
- std::ostringstream operator_code;
- for (const auto& op : operators) {
- operator_code << "inline bool operator" << op << "(const " << parcel.GetName() << "&"
- << (is_empty ? "" : " rhs") << ") const {\n"
- << " return " << lhs << " " << op << " " << rhs << ";\n"
- << "}\n";
- }
- operator_code << "\n";
- parcel_class->AddPublic(std::make_unique<LiteralDecl>(operator_code.str()));
+ string operator_code;
+ GenerateParcelableComparisonOperators(*CodeWriter::ForString(&operator_code), parcel);
+ parcel_class->AddPublic(std::make_unique<LiteralDecl>(operator_code));
Traits::AddFields(*parcel_class, parcel, typenames);