Implementing Discriminated Unions in HIDL
This change augments hidl-gen to generate C++ code for HIDL safe_union
constructs. The HIDL declaration is similar to regular structs/unions:
safe_union SafeUnion {
uint8_t a;
string b;
}
The emitted C++ code is:
struct SafeUnion final {
public:
enum class hidl_discriminator : uint8_t {
a,
b,
hidl_no_init
};
SafeUnion();
~SafeUnion();
SafeUnion(const SafeUnion&);
void a(uint8_t);
uint8_t& a();
uint8_t a() const;
void b(const ::android::hardware::hidl_string&);
::android::hardware::hidl_string& b();
const ::android::hardware::hidl_string& b() const;
// Utility method
hidl_discriminator getDiscriminator() const;
private:
void hidl_destructUnion();
union hidl_union final {
uint8_t a __attribute__ ((aligned(1)));
::android::hardware::hidl_string b __attribute__ ((aligned(8)));
hidl_union();
~hidl_union();
} hidl_u;
hidl_discriminator hidl_d __attribute__ ((aligned(1))) {hidl_discriminator::hidl_no_init};
};
Also synthesizes the appropriate toString, operator== and
{read,write}Embedded{From,To}Parcel methods.
Bug: 79878527
Test: Ran the existing hidl_test suite, tested safe_union get/set
methods on a test HAL.
Change-Id: I894c144204ab3bdbe446022b2805a64a211d28b3
Merged-In: I894c144204ab3bdbe446022b2805a64a211d28b3
diff --git a/CompoundType.cpp b/CompoundType.cpp
index e0759bc..23059f0 100644
--- a/CompoundType.cpp
+++ b/CompoundType.cpp
@@ -17,6 +17,7 @@
#include "CompoundType.h"
#include "ArrayType.h"
+#include "ScalarType.h"
#include "VectorType.h"
#include <android-base/logging.h>
@@ -66,6 +67,9 @@
status_t err = validateUniqueNames();
if (err != OK) return err;
+ err = validateSubTypeNames();
+ if (err != OK) return err;
+
return Scope::validate();
}
@@ -84,6 +88,29 @@
return OK;
}
+// Reports on stderr that a type declared inside this safe_union uses a name
+// that clashes with a member function of the safe_union.
+void CompoundType::emitInvalidSubTypeNamesError(const std::string& subTypeName,
+                                                const Location& location) const {
+    std::cerr << "ERROR: Type name '" << subTypeName << "' defined at " << location;
+    std::cerr << " conflicts with a member function of safe_union " << localName();
+    std::cerr << ". Consider renaming or moving its definition outside the safe_union scope.\n";
+}
+
+// For safe unions, rejects a nested type named "getDiscriminator", since that
+// name is taken by a member function of the safe_union. Any other compound
+// style passes unconditionally.
+status_t CompoundType::validateSubTypeNames() const {
+    if (mStyle == STYLE_SAFE_UNION) {
+        for (const auto& nested : Scope::getSubTypes()) {
+            if (nested->localName() != "getDiscriminator") {
+                continue;
+            }
+
+            // Validation stops at the first clash.
+            emitInvalidSubTypeNamesError(nested->localName(), nested->location());
+            return UNKNOWN_ERROR;
+        }
+    }
+
+    return OK;
+}
+
bool CompoundType::isCompoundType() const {
return true;
}
@@ -108,6 +135,9 @@
case STYLE_UNION: {
return "union " + localName();
}
+ case STYLE_SAFE_UNION: {
+ return "safe_union " + localName();
+ }
}
CHECK(!"Should not be here");
}
@@ -127,6 +157,7 @@
case StorageMode_Result:
return base + (containsInterface()?"":"*");
}
+ CHECK(!"Should not be here");
}
std::string CompoundType::getJavaType(bool /* forInitializer */) const {
@@ -143,6 +174,10 @@
{
return "TYPE_UNION";
}
+ case STYLE_SAFE_UNION:
+ {
+ return "TYPE_SAFE_UNION";
+ }
}
CHECK(!"Should not be here");
}
@@ -202,10 +237,11 @@
<< ");\n";
handleError(out, mode);
}
- if (mStyle != STYLE_STRUCT) {
- return;
- }
- if (needsEmbeddedReadWrite()) {
+
+ bool needEmbeddedReadWrite = needsEmbeddedReadWrite();
+ CHECK(mStyle != STYLE_UNION || !needEmbeddedReadWrite);
+
+ if (needEmbeddedReadWrite) {
emitReaderWriterEmbedded(out, 0 /* depth */, name, name, /* sanitizedName */
isReader /* nameIsPointer */, parcelObj, parcelObjIsPointer,
isReader, mode, parentName, "0 /* parentOffset */");
@@ -374,7 +410,141 @@
handleError(out, mode);
}
+// Emits two static_asserts that pin sizeof and __alignof of the C++ type named
+// fullName() + layoutName to layout.size and layout.align respectively.
+void CompoundType::emitLayoutAsserts(Formatter& out, const Layout& layout,
+                                     const std::string& layoutName) const {
+    const std::string typeName = fullName() + layoutName;
+
+    out << "static_assert(sizeof(" << typeName << ") == " << layout.size
+        << ", \"wrong size\");\n";
+
+    out << "static_assert(__alignof(" << typeName << ") == " << layout.align
+        << ", \"wrong alignment\");\n";
+}
+
+// Emits the C++ declaration of a safe_union: a final struct that holds an
+// inner union `hidl_u` and a discriminator `hidl_d`, exposes one setter and
+// two getters per field, and is followed by layout static_asserts.
+void CompoundType::emitSafeUnionTypeDeclarations(Formatter& out) const {
+ out << "struct "
+ << localName()
+ << " final {\n";
+
+ out.indent();
+
+ // Declarations of types nested in this scope are emitted first.
+ Scope::emitTypeDeclarations(out);
+
+ CompoundLayout layout = getCompoundAlignmentAndSize();
+ // Without fields there is no discriminator, union or accessor to emit:
+ // close the struct and only assert its overall layout.
+ if (mFields->empty()) {
+ out.unindent();
+ out << "};\n\n";
+
+ emitLayoutAsserts(out, layout.overall, "");
+ out << "\n";
+ return;
+ }
+
+ // Discriminator enum; its underlying type comes from
+ // getUnionDiscriminatorType().
+ out << "enum class hidl_discriminator : "
+ << getUnionDiscriminatorType()->getCppType(StorageMode_Stack, false)
+ << " ";
+
+ // One enumerator per field, in field order, followed by hidl_no_init.
+ out.block([&] {
+ for (const auto& field : *mFields) {
+ out << field->name() << ",\n";
+ }
+ out << "hidl_no_init\n";
+ });
+ out << ";\n\n";
+
+ // Default constructor, destructor and copy constructor declarations.
+ out << localName() << "();\n"
+ << "~" << localName() << "();\n"
+ << localName() << "(const " << localName() << "&);\n\n";
+
+ // Per field: a setter, a mutable getter and a const getter, all named
+ // after the field.
+ for (const auto& field : *mFields) {
+ out << "void "
+ << field->name()
+ << "("
+ << field->type().getCppArgumentType()
+ << ");\n";
+
+ out << field->type().getCppStackType()
+ << "& "
+ << field->name()
+ << "();\n";
+
+ out << field->type().getCppArgumentType()
+ << " "
+ << field->name()
+ << "() const;\n\n";
+ }
+
+ out << "// Utility method\n";
+ out << "hidl_discriminator getDiscriminator() const;\n\n";
+
+ // Everything below goes into the private section of the emitted struct.
+ out.unindent();
+ out << "private:\n";
+ out.indent();
+
+ out << "void hidl_destructUnion();\n\n";
+ out << "union hidl_union final {\n";
+ out.indent();
+
+ // Explicit alignment attributes (and, further down, the layout asserts)
+ // are only emitted when containsPointer() is false.
+ bool hasPointer = containsPointer();
+ for (const auto& field : *mFields) {
+
+ size_t fieldAlign, fieldSize;
+ field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
+
+ out << field->type().getCppStackType()
+ << " "
+ << field->name();
+
+ if (!hasPointer) {
+ out << " __attribute__ ((aligned("
+ << fieldAlign
+ << ")))";
+ }
+ out << ";\n";
+ }
+
+ // The inner union declares its own constructor and destructor.
+ out << "\n"
+ << "hidl_union();\n"
+ << "~hidl_union();\n";
+
+ out.unindent();
+ out << "} hidl_u;\n\n";
+ out << "hidl_discriminator hidl_d";
+
+ if (!hasPointer) {
+ out << " __attribute__ ((aligned("
+ << layout.discriminator.align << "))) ";
+ }
+ // The discriminator's default member initializer is hidl_no_init.
+ out << "{hidl_discriminator::hidl_no_init};\n";
+
+ if (!hasPointer) {
+ out << "\n";
+
+ // Asserts for the private nested types are emitted inside the struct.
+ emitLayoutAsserts(out, layout.innerStruct, "::hidl_union");
+ emitLayoutAsserts(out, layout.discriminator, "::hidl_discriminator");
+ }
+
+ out.unindent();
+ out << "};\n\n";
+
+ // The overall size/alignment assert follows the closed struct.
+ if (!hasPointer) {
+ emitLayoutAsserts(out, layout.overall, "");
+ out << "\n";
+ }
+}
+
void CompoundType::emitTypeDeclarations(Formatter& out) const {
+ if (mStyle == STYLE_SAFE_UNION) {
+ emitSafeUnionTypeDeclarations(out);
+ return;
+ }
+
out << ((mStyle == STYLE_STRUCT) ? "struct" : "union")
<< " "
<< localName()
@@ -405,10 +575,7 @@
size_t fieldAlign, fieldSize;
field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
- size_t pad = offset % fieldAlign;
- if (pad > 0) {
- offset += fieldAlign - pad;
- }
+ offset += Layout::getPad(offset, fieldAlign);
if (pass == 0) {
out << field->type().getCppStackType()
@@ -438,24 +605,27 @@
}
}
- size_t structAlign, structSize;
- getAlignmentAndSize(&structAlign, &structSize);
-
- out << "static_assert(sizeof("
- << fullName()
- << ") == "
- << structSize
- << ", \"wrong size\");\n";
-
- out << "static_assert(__alignof("
- << fullName()
- << ") == "
- << structAlign
- << ", \"wrong alignment\");\n\n";
+ CompoundLayout layout = getCompoundAlignmentAndSize();
+ emitLayoutAsserts(out, layout.overall, "");
+ out << "\n";
}
+// Emits a forward declaration "<keyword> <localName>;" where the keyword is
+// "union" for STYLE_UNION and "struct" for STYLE_STRUCT and STYLE_SAFE_UNION.
void CompoundType::emitTypeForwardDeclaration(Formatter& out) const {
- out << ((mStyle == STYLE_STRUCT) ? "struct" : "union") << " " << localName() << ";\n";
+ switch (mStyle) {
+ case STYLE_UNION: {
+ out << "union";
+ break;
+ }
+ // A safe_union is declared with the "struct" keyword, same as a struct.
+ case STYLE_STRUCT:
+ case STYLE_SAFE_UNION: {
+ out << "struct";
+ break;
+ }
+ default: {
+ CHECK(!"Should not be here");
+ }
+ }
+ out << " " << localName() << ";\n";
}
void CompoundType::emitPackageTypeDeclarations(Formatter& out) const {
@@ -472,15 +642,54 @@
<< "std::string os;\n";
out << "os += \"{\";\n";
- for (const NamedReference<Type>* field : *mFields) {
- out << "os += \"";
- if (field != *(mFields->begin())) {
- out << ", ";
- }
- out << "." << field->name() << " = \";\n";
- field->type().emitDump(out, "os", "o." + field->name());
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ out << "\nswitch (o.getDiscriminator()) {\n";
+ out.indent();
}
+ for (const NamedReference<Type>* field : *mFields) {
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "case "
+ << fullName()
+ << "::hidl_discriminator::"
+ << field->name()
+ << ": ";
+
+ out.block([&] {
+ out << "os += \"."
+ << field->name()
+ << " = \";\n"
+ << "os += toString(o."
+ << field->name()
+ << "());\n"
+ << "break;\n";
+ }).endl();
+ } else {
+ out << "os += \"";
+ if (field != *(mFields->begin())) {
+ out << ", ";
+ }
+ out << "." << field->name() << " = \";\n";
+ field->type().emitDump(out, "os", "o." + field->name());
+ }
+ }
+
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ out << "case " << fullName() << "::hidl_discriminator::"
+ << "hidl_no_init: ";
+
+ out.block([&] {
+ out << "break;\n";
+ }).endl();
+
+ out << "default: ";
+ out.block([&] {
+ out << "details::logAlwaysFatal(\"Unknown union discriminator.\");\n";
+ }).endl();
+
+ out.unindent();
+ out << "}\n";
+ }
out << "os += \"}\"; return os;\n";
}).endl().endl();
@@ -489,10 +698,52 @@
<< getCppArgumentType() << " " << (mFields->empty() ? "/* lhs */" : "lhs") << ", "
<< getCppArgumentType() << " " << (mFields->empty() ? "/* rhs */" : "rhs") << ") ";
out.block([&] {
- for (const auto &field : *mFields) {
- out.sIf("lhs." + field->name() + " != rhs." + field->name(), [&] {
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ out.sIf("lhs.getDiscriminator() != rhs.getDiscriminator()", [&] {
out << "return false;\n";
}).endl();
+
+ out << "switch (lhs.getDiscriminator()) {\n";
+ out.indent();
+ }
+
+ for (const auto& field : *mFields) {
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "case "
+ << fullName()
+ << "::hidl_discriminator::"
+ << field->name()
+ << ": ";
+
+ out.block([&] {
+ out << "return (lhs."
+ << field->name()
+ << "() == rhs."
+ << field->name()
+ << "());\n";
+ }).endl();
+ } else {
+ out.sIf("lhs." + field->name() + " != rhs." + field->name(), [&] {
+ out << "return false;\n";
+ }).endl();
+ }
+ }
+
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ out << "case " << fullName() << "::hidl_discriminator::"
+ << "hidl_no_init: ";
+
+ out.block([&] {
+ out << "return false;\n";
+ }).endl();
+
+ out << "default: ";
+ out.block([&] {
+ out << "details::logAlwaysFatal(\"Unknown union discriminator.\");\n";
+ }).endl();
+
+ out.unindent();
+ out << "}\n";
}
out << "return true;\n";
}).endl().endl();
@@ -548,6 +799,200 @@
}
}
+// Emits the body shared by both getters of a safe_union field: a fatal log
+// call if the discriminator does not select `fieldName`, followed by a return
+// of that union member.
+void emitSafeUnionGetterDefinition(const std::string& fieldName,
+                                   Formatter& out) {
+    const auto emitBadAccess = [&] {
+        out << "details::logAlwaysFatal(\"Bad safe_union access.\");\n";
+    };
+
+    const auto emitGetterBody = [&] {
+        out << "if (CC_UNLIKELY(hidl_d != hidl_discriminator::" << fieldName << ")) ";
+        out.block(emitBadAccess).endl().endl();
+
+        out << "return hidl_u." << fieldName << ";\n";
+    };
+
+    out.block(emitGetterBody).endl().endl();
+}
+
+// Emits the out-of-line definitions of the safe_union's default constructor,
+// destructor and copy constructor.
+void CompoundType::emitSafeUnionTypeConstructors(Formatter& out) const {
+
+ // Default constructor
+ // (emitted with an empty body)
+ out << fullName()
+ << "::"
+ << localName()
+ << "() {}\n\n";
+
+ // Destructor
+ // (its body only calls hidl_destructUnion())
+ out << fullName()
+ << "::~"
+ << localName()
+ << "() ";
+
+ out.block([&] {
+ out << "hidl_destructUnion();\n";
+ }).endl().endl();
+
+ // Copy constructor
+ out << fullName()
+ << "::"
+ << localName()
+ << "(const "
+ << localName()
+ << "& other) ";
+
+ out.block([&] {
+ // Switch on the source object's discriminator; each field's case
+ // placement-news a copy of that member into this object's union.
+ out << "switch(other.hidl_d) ";
+ out.block([&] {
+
+ for (const auto& field : *mFields) {
+ out << "case hidl_discriminator::"
+ << field->name()
+ << ": ";
+
+ out.block([&] {
+ out << "new (&hidl_u."
+ << field->name()
+ << ") "
+ << field->type().getCppStackType()
+ << "(other.hidl_u."
+ << field->name()
+ << ");\n"
+ << "break;\n";
+ }).endl();
+ }
+
+ // hidl_no_init copies nothing; any other value is a fatal error.
+ out << "case hidl_discriminator::hidl_no_init: { break; }\n";
+ out << "default: { details::logAlwaysFatal("
+ << "\"Unknown union discriminator.\"); }\n";
+ }).endl().endl();
+
+ // The discriminator is copied after the switch.
+ out << "hidl_d = other.hidl_d;\n";
+ }).endl().endl();
+}
+
+// Emits the out-of-line member definitions of a safe_union: the constructors
+// (via emitSafeUnionTypeConstructors), hidl_destructUnion(), a setter plus a
+// mutable and a const getter per field, the trivial constructor/destructor of
+// the inner hidl_union, and getDiscriminator().
+// Nothing is emitted for a safe_union without fields.
+void CompoundType::emitSafeUnionTypeDefinitions(Formatter& out) const {
+    if (mFields->empty()) { return; }
+    emitSafeUnionTypeConstructors(out);
+
+    // hidl_destructUnion(): run the destructor of the member selected by the
+    // discriminator, then reset the discriminator to hidl_no_init.
+    out << "void "
+        << fullName()
+        << "::hidl_destructUnion() ";
+
+    out.block([&] {
+        out << "switch(hidl_d) ";
+        out.block([&] {
+
+            for (const auto& field : *mFields) {
+                out << "case hidl_discriminator::"
+                    << field->name()
+                    << ": ";
+
+                out.block([&] {
+                    const std::string fullFieldName = "hidl_u." + field->name();
+                    field->type().emitTypeDestructorCall(out, fullFieldName);
+                    out << "break;\n";
+                }).endl();
+            }
+
+            out << "case hidl_discriminator::hidl_no_init: { break; }\n";
+            out << "default: { details::logAlwaysFatal("
+                << "\"Unknown union discriminator.\"); }\n";
+        }).endl().endl();
+
+        out << "hidl_d = hidl_discriminator::hidl_no_init;\n";
+    }).endl().endl();
+
+    CompoundLayout layout = getCompoundAlignmentAndSize();
+    for (const NamedReference<Type>* field : *mFields) {
+        // Setter
+        out << "void "
+            << fullName()
+            << "::"
+            << field->name()
+            << "("
+            << field->type().getCppArgumentType()
+            << " o) ";
+
+        out.block([&] {
+            // A different member (or none) is active: destroy it, zero the
+            // union storage and construct the new member in place.
+            out << "if (hidl_d != hidl_discriminator::"
+                << field->name()
+                << ") ";
+
+            out.block([&] {
+                out << "hidl_destructUnion();\n"
+                    << "::std::memset(&hidl_u, 0, sizeof(hidl_u));\n\n";
+
+                out << "new (&hidl_u."
+                    << field->name()
+                    << ") "
+                    << field->type().getCppStackType()
+                    << "(o);\n";
+
+                out << "hidl_d = hidl_discriminator::"
+                    << field->name()
+                    << ";\n";
+            }).endl();
+
+            // The same member is already active: plain assignment, skipped
+            // when the argument is the member itself.
+            out << "else if (&(hidl_u."
+                << field->name()
+                << ") != &o) ";
+
+            out.block([&] {
+                out << "hidl_u."
+                    << field->name()
+                    << " = o;\n";
+            }).endl();
+        }).endl().endl();
+
+        // Getter (mutable)
+        out << field->type().getCppStackType()
+            << "& ("
+            << fullName()
+            << "::"
+            << field->name()
+            << ")() ";
+
+        emitSafeUnionGetterDefinition(field->name(), out);
+
+        // Getter (immutable)
+        out << field->type().getCppArgumentType()
+            << " ("
+            << fullName()
+            << "::"
+            << field->name()
+            << ")() const ";
+
+        emitSafeUnionGetterDefinition(field->name(), out);
+    }
+
+    // Trivial constructor/destructor for internal union
+    out << fullName() << "::hidl_union::hidl_union() {}\n\n"
+        << fullName() << "::hidl_union::~hidl_union() {}\n\n";
+
+    // Utility method. The member name is qualified with fullName(), like every
+    // other definition emitted here: qualifying it with localName() alone does
+    // not name the class when the safe_union is nested inside another type.
+    // The declarator is parenthesized the same way as the getters above so it
+    // stays separate from the preceding return type.
+    out << fullName() << "::hidl_discriminator ("
+        << fullName() << "::getDiscriminator)() const ";
+
+    out.block([&] {
+        // The union is expected at offset 0 of the emitted struct.
+        out << "static_assert(offsetof("
+            << fullName()
+            << ", hidl_u) == 0"
+            << ", \"wrong offset\");\n";
+
+        // The discriminator offset is only asserted when containsPointer()
+        // is false.
+        if (!containsPointer()) {
+            out << "static_assert(offsetof("
+                << fullName()
+                << ", hidl_d) == "
+                << layout.discriminator.offset
+                << ", \"wrong offset\");\n";
+        }
+
+        out << "\nreturn hidl_d;\n";
+    }).endl().endl();
+}
+
void CompoundType::emitTypeDefinitions(Formatter& out, const std::string& prefix) const {
std::string space = prefix.empty() ? "" : (prefix + "::");
Scope::emitTypeDefinitions(out, space + localName());
@@ -561,8 +1006,13 @@
emitResolveReferenceDef(out, prefix, true /* isReader */);
emitResolveReferenceDef(out, prefix, false /* isReader */);
}
+
+ if (mStyle == STYLE_SAFE_UNION) {
+ emitSafeUnionTypeDefinitions(out);
+ }
}
+// TODO(b/79878527): Implement Java bindings for safe unions
void CompoundType::emitJavaTypeDeclarations(Formatter& out, bool atTopLevel) const {
out << "public final ";
@@ -713,10 +1163,7 @@
size_t fieldAlign, fieldSize;
field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
- size_t pad = offset % fieldAlign;
- if (pad > 0) {
- offset += fieldAlign - pad;
- }
+ offset += Layout::getPad(offset, fieldAlign);
field->type().emitJavaFieldReaderWriter(
out, 0 /* depth */, "parcel", "_hidl_blob", field->name(),
@@ -784,10 +1231,9 @@
for (const auto& field : *mFields) {
size_t fieldAlign, fieldSize;
field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
- size_t pad = offset % fieldAlign;
- if (pad > 0) {
- offset += fieldAlign - pad;
- }
+
+ offset += Layout::getPad(offset, fieldAlign);
+
field->type().emitJavaFieldReaderWriter(
out, 0 /* depth */, "parcel", "_hidl_blob", field->name(),
"_hidl_offset + " + std::to_string(offset), false /* isReader */);
@@ -844,15 +1290,34 @@
out << "::android::status_t _hidl_err = ::android::OK;\n\n";
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ out << "switch (" << name << ".getDiscriminator()) {\n";
+ out.indent();
+ }
+
for (const auto &field : *mFields) {
if (!field->type().needsEmbeddedReadWrite()) {
continue;
}
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "case " << fullName() << "::hidl_discriminator::"
+ << field->name() << ": {\n";
+ out.indent();
+ }
+
+ const std::string fieldName = (mStyle == STYLE_SAFE_UNION)
+ ? (name + "." + field->name() + "()" + error)
+ : (name + "." + field->name() + error);
+
+ const std::string fieldOffset = (mStyle == STYLE_SAFE_UNION)
+ ? "0 /* safe_union: union offset into struct */"
+ : ("offsetof(" + fullName() + ", " + field->name() + ")");
+
field->type().emitReaderWriterEmbedded(
out,
0 /* depth */,
- name + "." + field->name() + error,
+ fieldName,
field->name() /* sanitizedName */,
false /* nameIsPointer */,
"parcel",
@@ -860,11 +1325,19 @@
isReader,
ErrorMode_Return,
"parentHandle",
- "parentOffset + offsetof("
- + fullName()
- + ", "
- + field->name()
- + ")");
+ "parentOffset + " + fieldOffset);
+
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "break;\n";
+ out.unindent();
+ out << "}\n";
+ }
+ }
+
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ out << "default: { break; }\n";
+ out.unindent();
+ out << "}\n";
}
out << "return _hidl_err;\n";
@@ -918,15 +1391,34 @@
// should not be used at all, then the #error should not be emitted.
std::string error = useParent ? "" : "\n#error\n";
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "switch (" << nameDeref << "getDiscriminator()) {\n";
+ out.indent();
+ }
+
for (const auto &field : *mFields) {
if (!field->type().needsResolveReferences()) {
continue;
}
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "case " << fullName() << "::hidl_discriminator::"
+ << field->name() << ": {\n";
+ out.indent();
+ }
+
+ const std::string fieldName = (mStyle == STYLE_SAFE_UNION)
+ ? (nameDeref + field->name() + "()")
+ : (nameDeref + field->name());
+
+ const std::string fieldOffset = (mStyle == STYLE_SAFE_UNION)
+ ? "0 /* safe_union: union offset into struct */"
+ : ("offsetof(" + fullName() + ", " + field->name() + ")");
+
field->type().emitResolveReferencesEmbedded(
out,
0 /* depth */,
- nameDeref + field->name(),
+ fieldName,
field->name() /* sanitizedName */,
false, // nameIsPointer
"parcel", // const std::string &parcelObj,
@@ -935,12 +1427,21 @@
ErrorMode_Return,
parentHandleName + error,
parentOffsetName
- + " + offsetof("
- + fullName()
- + ", "
- + field->name()
- + ")"
+ + " + "
+ + fieldOffset
+ error);
+
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "break;\n";
+ out.unindent();
+ out << "}\n";
+ }
+ }
+
+ if (mStyle == STYLE_SAFE_UNION) {
+ out << "default: { _hidl_err = ::android::BAD_VALUE; break; }\n";
+ out.unindent();
+ out << "}\n";
}
out << "return _hidl_err;\n";
@@ -950,7 +1451,7 @@
}
bool CompoundType::needsEmbeddedReadWrite() const {
- if (mStyle != STYLE_STRUCT) {
+ if (mStyle == STYLE_UNION) {
return false;
}
@@ -964,7 +1465,7 @@
}
bool CompoundType::deepNeedsResolveReferences(std::unordered_set<const Type*>* visited) const {
- if (mStyle != STYLE_STRUCT) {
+ if (mStyle == STYLE_UNION) {
return false;
}
@@ -998,6 +1499,15 @@
out << "sub_union: {\n";
break;
}
+ case STYLE_SAFE_UNION:
+ {
+ out << "sub_safe_union: {\n";
+ break;
+ }
+ default:
+ {
+ CHECK(!"Should not be here");
+ }
}
out.indent();
type->emitVtsTypeDeclarations(out);
@@ -1018,6 +1528,15 @@
out << "union_value: {\n";
break;
}
+ case STYLE_SAFE_UNION:
+ {
+ out << "safe_union_value: {\n";
+ break;
+ }
+ default:
+ {
+ CHECK(!"Should not be here");
+ }
}
out.indent();
out << "name: \"" << field->name() << "\"\n";
@@ -1037,7 +1556,7 @@
}
bool CompoundType::deepIsJavaCompatible(std::unordered_set<const Type*>* visited) const {
- if (mStyle != STYLE_STRUCT) {
+ if (mStyle != STYLE_STRUCT) { // TODO(natre): Update
return false;
}
@@ -1061,48 +1580,86 @@
}
+// Reports the overall alignment and size of this compound type, as computed
+// by getCompoundAlignmentAndSize().
void CompoundType::getAlignmentAndSize(size_t *align, size_t *size) const {
- *align = 1;
- *size = 0;
+ CompoundLayout layout = getCompoundAlignmentAndSize();
+ *align = layout.overall.align;
+ *size = layout.overall.size;
+}
- size_t offset = 0;
+// Computes three layouts for this compound type: the fields on their own
+// (innerStruct), the safe_union discriminator (discriminator), and the type
+// as a whole (overall).
+CompoundType::CompoundLayout CompoundType::getCompoundAlignmentAndSize() const {
+ CompoundLayout compoundLayout;
+
+ // Local aliases for convenience
+ Layout& overall = compoundLayout.overall;
+ Layout& innerStruct = compoundLayout.innerStruct;
+ Layout& discriminator = compoundLayout.discriminator;
+
for (const auto &field : *mFields) {
+
// Each field is aligned according to its alignment requirement.
// The surrounding structure's alignment is the maximum of its
// fields' aligments.
-
size_t fieldAlign, fieldSize;
field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
+ size_t lPad = Layout::getPad(innerStruct.size, fieldAlign);
- size_t pad = offset % fieldAlign;
- if (pad > 0) {
- offset += fieldAlign - pad;
- }
+ // STYLE_STRUCT appends each field after its leading padding; the other
+ // styles overlay the fields, so the size is that of the largest field.
+ innerStruct.size = (mStyle == STYLE_STRUCT)
+ ? (innerStruct.size + lPad + fieldSize)
+ : std::max(innerStruct.size, fieldSize);
- if (mStyle == STYLE_STRUCT) {
- offset += fieldSize;
- } else {
- *size = std::max(*size, fieldSize);
- }
+ innerStruct.align = std::max(innerStruct.align, fieldAlign);
+ }
- if (fieldAlign > (*align)) {
- *align = fieldAlign;
+ // Padding for the inner structure
+ innerStruct.size += Layout::getPad(innerStruct.size,
+ innerStruct.align);
+
+ // A non-empty safe_union places its discriminator after the inner
+ // structure, at the next offset that satisfies the discriminator's
+ // alignment.
+ if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+ getUnionDiscriminatorType()->getAlignmentAndSize(
+ &(discriminator.align), &(discriminator.size));
+
+ discriminator.offset = innerStruct.size;
+ discriminator.offset += Layout::getPad(discriminator.offset,
+ discriminator.align);
+
+ overall.size = discriminator.offset + discriminator.size;
+ } else {
+ overall.size = innerStruct.size;
+ }
+
+ // An empty struct/union still occupies a byte of space in C++.
+ if (overall.size == 0) {
+ overall.size = 1;
+ }
+
+ // Padding for the overall structure
+ overall.align = std::max(innerStruct.align, discriminator.align);
+ overall.size += Layout::getPad(overall.size, overall.align);
+
+ return compoundLayout;
+}
+
+// Returns a newly allocated scalar type for the safe_union discriminator: the
+// first of uint8/uint16/uint32 whose bit width can number every field plus
+// the extra no_init value, or uint64 if none of them can.
+std::unique_ptr<ScalarType> CompoundType::getUnionDiscriminatorType() const {
+ // Candidate (bit width, kind) pairs, narrowest first.
+ static const std::vector<std::pair<int, ScalarType::Kind> > scalars {
+ {8, ScalarType::Kind::KIND_UINT8},
+ {16, ScalarType::Kind::KIND_UINT16},
+ {32, ScalarType::Kind::KIND_UINT32},
+ };
+
+ size_t numFields = mFields->size() + 1; // +1 for no_init
+ auto kind = ScalarType::Kind::KIND_UINT64;
+
+ // Stop at the first width with at least numFields distinct values.
+ for (const auto& scalar : scalars) {
+ if (numFields <= (1ULL << scalar.first)) {
+ kind = scalar.second; break;
}
}
- if (mStyle == STYLE_STRUCT) {
- *size = offset;
- }
+ return std::unique_ptr<ScalarType>(new ScalarType(kind, NULL));
+}
- // Final padding to account for the structure's alignment.
- size_t pad = (*size) % (*align);
- if (pad > 0) {
- (*size) += (*align) - pad;
- }
-
- if (*size == 0) {
- // An empty struct still occupies a byte of space in C++.
- *size = 1;
- }
+// Returns how many bytes must be added to `offset` to reach the next multiple
+// of `align`; 0 when `offset` is already a multiple.
+size_t CompoundType::Layout::getPad(size_t offset, size_t align) {
+    const size_t misalignment = offset % align;
+    if (misalignment == 0) { return 0; }
+    return align - misalignment;
}
} // namespace android