Implement discriminated unions (safe_union) in HIDL

This change augments hidl-gen to generate C++ code for HIDL safe_union
constructs. The HIDL declaration is similar to regular structs/unions:

safe_union SafeUnion {
    uint8_t a;
    string b;
};

The emitted C++ code is:

struct SafeUnion final {
    enum class hidl_discriminator : uint8_t {
        a,
        b,
        hidl_no_init
    };

    SafeUnion();
    ~SafeUnion();
    SafeUnion(const SafeUnion&);

    void a(uint8_t);
    uint8_t& a();
    uint8_t a() const;

    void b(const ::android::hardware::hidl_string&);
    ::android::hardware::hidl_string& b();
    const ::android::hardware::hidl_string& b() const;

    // Utility method
    hidl_discriminator getDiscriminator() const;

private:
    void hidl_destructUnion();

    union hidl_union final {
        uint8_t a __attribute__ ((aligned(1)));
        ::android::hardware::hidl_string b __attribute__ ((aligned(8)));

        hidl_union();
        ~hidl_union();
    } hidl_u;

    hidl_discriminator hidl_d __attribute__ ((aligned(1)))
        {hidl_discriminator::hidl_no_init};

    // (The layout static_asserts hidl-gen emits here and after the struct
    // are omitted from this example.)
};

The change also synthesizes the corresponding toString, operator==, and
readEmbeddedFromParcel/writeEmbeddedToParcel methods.

Bug: 79878527
Test: Ran the existing hidl_test suite, tested safe_union get/set
methods on a test HAL.

Change-Id: I894c144204ab3bdbe446022b2805a64a211d28b3
Merged-In: I894c144204ab3bdbe446022b2805a64a211d28b3
diff --git a/CompoundType.cpp b/CompoundType.cpp
index e0759bc..23059f0 100644
--- a/CompoundType.cpp
+++ b/CompoundType.cpp
@@ -17,6 +17,7 @@
 #include "CompoundType.h"
 
 #include "ArrayType.h"
+#include "ScalarType.h"
 #include "VectorType.h"
 
 #include <android-base/logging.h>
@@ -66,6 +67,9 @@
     status_t err = validateUniqueNames();
     if (err != OK) return err;
 
+    err = validateSubTypeNames();
+    if (err != OK) return err;
+
     return Scope::validate();
 }
 
@@ -84,6 +88,36 @@
     return OK;
 }
 
+// Prints the diagnostic for a nested type whose name collides with a member
+// function generated for this safe_union.
+void CompoundType::emitInvalidSubTypeNamesError(const std::string& subTypeName,
+                                                const Location& location) const {
+    std::cerr << "ERROR: Type name '" << subTypeName << "' defined at "
+              << location << " conflicts with a member function of "
+              << "safe_union " << localName() << ". Consider renaming or "
+              << "moving its definition outside the safe_union scope.\n";
+}
+
+// Fails validation when a safe_union declares a nested type named
+// "getDiscriminator", which would clash with the generated accessor of the
+// same name. Every other compound style passes unconditionally.
+status_t CompoundType::validateSubTypeNames() const {
+    if (mStyle != STYLE_SAFE_UNION) {
+        return OK;
+    }
+
+    for (const auto& nested : Scope::getSubTypes()) {
+        if (nested->localName() != "getDiscriminator") {
+            continue;
+        }
+
+        emitInvalidSubTypeNamesError(nested->localName(), nested->location());
+        return UNKNOWN_ERROR;
+    }
+
+    return OK;
+}
+
 bool CompoundType::isCompoundType() const {
     return true;
 }
@@ -108,6 +135,9 @@
         case STYLE_UNION: {
             return "union " + localName();
         }
+        case STYLE_SAFE_UNION: {
+            return "safe_union " + localName();
+        }
     }
     CHECK(!"Should not be here");
 }
@@ -127,6 +157,7 @@
         case StorageMode_Result:
             return base + (containsInterface()?"":"*");
     }
+    CHECK(!"Should not be here");
 }
 
 std::string CompoundType::getJavaType(bool /* forInitializer */) const {
@@ -143,6 +174,10 @@
         {
             return "TYPE_UNION";
         }
+        case STYLE_SAFE_UNION:
+        {
+            return "TYPE_SAFE_UNION";
+        }
     }
     CHECK(!"Should not be here");
 }
@@ -202,10 +237,11 @@
                 << ");\n";
             handleError(out, mode);
         }
-        if (mStyle != STYLE_STRUCT) {
-            return;
-        }
-        if (needsEmbeddedReadWrite()) {
+
+        bool needEmbeddedReadWrite = needsEmbeddedReadWrite();
+        CHECK(mStyle != STYLE_UNION || !needEmbeddedReadWrite);
+
+        if (needEmbeddedReadWrite) {
             emitReaderWriterEmbedded(out, 0 /* depth */, name, name, /* sanitizedName */
                                      isReader /* nameIsPointer */, parcelObj, parcelObjIsPointer,
                                      isReader, mode, parentName, "0 /* parentOffset */");
@@ -374,7 +410,133 @@
     handleError(out, mode);
 }
 
+// Emits static_asserts pinning sizeof/__alignof of `fullName() + layoutName`
+// to the values in `layout` (layoutName is "" for the type itself, or e.g.
+// "::hidl_union" for a nested member type).
+void CompoundType::emitLayoutAsserts(Formatter& out, const Layout& layout,
+                                     const std::string& layoutName) const {
+    const std::string subject = fullName() + layoutName;
+
+    out << "static_assert(sizeof(" << subject << ") == " << layout.size
+        << ", \"wrong size\");\n";
+    out << "static_assert(__alignof(" << subject << ") == " << layout.align
+        << ", \"wrong alignment\");\n";
+}
+
+// Emits the C++ declaration of a safe_union: a final struct exposing a
+// hidl_discriminator enum, special members, per-field setter/getters and
+// getDiscriminator(), and privately holding the fields in a union (hidl_u)
+// next to a discriminator (hidl_d) that starts out as hidl_no_init. A
+// safe_union without fields becomes a struct holding only its nested types.
+void CompoundType::emitSafeUnionTypeDeclarations(Formatter& out) const {
+    out << "struct " << localName() << " final {\n";
+    out.indent();
+
+    // Nested type declarations come first so later members can use them.
+    Scope::emitTypeDeclarations(out);
+
+    const CompoundLayout layout = getCompoundAlignmentAndSize();
+
+    if (mFields->empty()) {
+        out.unindent();
+        out << "};\n\n";
+
+        emitLayoutAsserts(out, layout.overall, "");
+        out << "\n";
+        return;
+    }
+
+    // Discriminator enum: one enumerator per field, in declaration order,
+    // followed by the hidl_no_init marker.
+    const std::string discriminatorType =
+        getUnionDiscriminatorType()->getCppType(StorageMode_Stack, false);
+
+    out << "enum class hidl_discriminator : " << discriminatorType << " ";
+    out.block([&] {
+        for (const auto& field : *mFields) {
+            out << field->name() << ",\n";
+        }
+        out << "hidl_no_init\n";
+    });
+    out << ";\n\n";
+
+    // Special members, all defined out of line.
+    // NOTE(review): no copy or move assignment operator is declared.
+    out << localName() << "();\n"
+        << "~" << localName() << "();\n"
+        << localName() << "(const " << localName() << "&);\n\n";
+
+    // Setter, mutable getter and const getter for each field.
+    for (const auto& field : *mFields) {
+        const auto& fieldName = field->name();
+        const auto& fieldType = field->type();
+
+        out << "void " << fieldName << "(" << fieldType.getCppArgumentType() << ");\n";
+        out << fieldType.getCppStackType() << "& " << fieldName << "();\n";
+        out << fieldType.getCppArgumentType() << " " << fieldName << "() const;\n\n";
+    }
+
+    out << "// Utility method\n";
+    out << "hidl_discriminator getDiscriminator() const;\n\n";
+
+    out.unindent();
+    out << "private:\n";
+    out.indent();
+
+    out << "void hidl_destructUnion();\n\n";
+
+    // Explicit alignment attributes and layout static_asserts are only
+    // emitted for types that contain no pointers.
+    const bool alignExplicitly = !containsPointer();
+
+    out << "union hidl_union final {\n";
+    out.indent();
+    for (const auto& field : *mFields) {
+        size_t fieldAlign, fieldSize;
+        field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
+
+        out << field->type().getCppStackType() << " " << field->name();
+        if (alignExplicitly) {
+            out << " __attribute__ ((aligned(" << fieldAlign << ")))";
+        }
+        out << ";\n";
+    }
+    out << "\n"
+        << "hidl_union();\n"
+        << "~hidl_union();\n";
+    out.unindent();
+    out << "} hidl_u;\n\n";
+
+    out << "hidl_discriminator hidl_d";
+    if (alignExplicitly) {
+        out << " __attribute__ ((aligned("
+            << layout.discriminator.align << "))) ";
+    }
+    out << "{hidl_discriminator::hidl_no_init};\n";
+
+    if (alignExplicitly) {
+        out << "\n";
+        emitLayoutAsserts(out, layout.innerStruct, "::hidl_union");
+        emitLayoutAsserts(out, layout.discriminator, "::hidl_discriminator");
+    }
+
+    out.unindent();
+    out << "};\n\n";
+
+    // The enclosing struct is only complete after its closing brace, so its
+    // own size/alignment asserts are emitted there.
+    if (alignExplicitly) {
+        emitLayoutAsserts(out, layout.overall, "");
+        out << "\n";
+    }
+}
+
 void CompoundType::emitTypeDeclarations(Formatter& out) const {
+    if (mStyle == STYLE_SAFE_UNION) {
+        emitSafeUnionTypeDeclarations(out);
+        return;
+    }
+
     out << ((mStyle == STYLE_STRUCT) ? "struct" : "union")
         << " "
         << localName()
@@ -405,10 +575,7 @@
             size_t fieldAlign, fieldSize;
             field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
 
-            size_t pad = offset % fieldAlign;
-            if (pad > 0) {
-                offset += fieldAlign - pad;
-            }
+            offset += Layout::getPad(offset, fieldAlign);
 
             if (pass == 0) {
                 out << field->type().getCppStackType()
@@ -438,24 +605,27 @@
         }
     }
 
-    size_t structAlign, structSize;
-    getAlignmentAndSize(&structAlign, &structSize);
-
-    out << "static_assert(sizeof("
-        << fullName()
-        << ") == "
-        << structSize
-        << ", \"wrong size\");\n";
-
-    out << "static_assert(__alignof("
-        << fullName()
-        << ") == "
-        << structAlign
-        << ", \"wrong alignment\");\n\n";
+    CompoundLayout layout = getCompoundAlignmentAndSize();
+    emitLayoutAsserts(out, layout.overall, "");
+    out << "\n";
 }
 
+// Emits a C++ forward declaration. safe_unions are generated as structs, so
+// they are forward-declared with the `struct` keyword.
 void CompoundType::emitTypeForwardDeclaration(Formatter& out) const {
-    out << ((mStyle == STYLE_STRUCT) ? "struct" : "union") << " " << localName() << ";\n";
+    const char* keyword = "";
+    switch (mStyle) {
+        case STYLE_STRUCT:
+        case STYLE_SAFE_UNION:
+            keyword = "struct";
+            break;
+        case STYLE_UNION:
+            keyword = "union";
+            break;
+        default:
+            CHECK(!"Should not be here");
+    }
+    out << keyword << " " << localName() << ";\n";
 }
 
 void CompoundType::emitPackageTypeDeclarations(Formatter& out) const {
@@ -472,15 +642,54 @@
             << "std::string os;\n";
         out << "os += \"{\";\n";
 
-        for (const NamedReference<Type>* field : *mFields) {
-            out << "os += \"";
-            if (field != *(mFields->begin())) {
-                out << ", ";
-            }
-            out << "." << field->name() << " = \";\n";
-            field->type().emitDump(out, "os", "o." + field->name());
+        if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+            out << "\nswitch (o.getDiscriminator()) {\n";
+            out.indent();
         }
 
+        for (const NamedReference<Type>* field : *mFields) {
+            if (mStyle == STYLE_SAFE_UNION) {
+                out << "case "
+                    << fullName()
+                    << "::hidl_discriminator::"
+                    << field->name()
+                    << ": ";
+
+                out.block([&] {
+                    out << "os += \"."
+                    << field->name()
+                    << " = \";\n"
+                    << "os += toString(o."
+                    << field->name()
+                    << "());\n"
+                    << "break;\n";
+                }).endl();
+            } else {
+                out << "os += \"";
+                if (field != *(mFields->begin())) {
+                    out << ", ";
+                }
+                out << "." << field->name() << " = \";\n";
+                field->type().emitDump(out, "os", "o." + field->name());
+            }
+        }
+
+        if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+            out << "case " << fullName() << "::hidl_discriminator::"
+                << "hidl_no_init: ";
+
+            out.block([&] {
+                out << "break;\n";
+            }).endl();
+
+            out << "default: ";
+            out.block([&] {
+                out << "details::logAlwaysFatal(\"Unknown union discriminator.\");\n";
+            }).endl();
+
+            out.unindent();
+            out << "}\n";
+        }
         out << "os += \"}\"; return os;\n";
     }).endl().endl();
 
@@ -489,10 +698,52 @@
             << getCppArgumentType() << " " << (mFields->empty() ? "/* lhs */" : "lhs") << ", "
             << getCppArgumentType() << " " << (mFields->empty() ? "/* rhs */" : "rhs") << ") ";
         out.block([&] {
-            for (const auto &field : *mFields) {
-                out.sIf("lhs." + field->name() + " != rhs." + field->name(), [&] {
+            if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+                out.sIf("lhs.getDiscriminator() != rhs.getDiscriminator()", [&] {
                     out << "return false;\n";
                 }).endl();
+
+                out << "switch (lhs.getDiscriminator()) {\n";
+                out.indent();
+            }
+
+            for (const auto& field : *mFields) {
+                if (mStyle == STYLE_SAFE_UNION) {
+                    out << "case "
+                        << fullName()
+                        << "::hidl_discriminator::"
+                        << field->name()
+                        << ": ";
+
+                    out.block([&] {
+                        out << "return (lhs."
+                        << field->name()
+                        << "() == rhs."
+                        << field->name()
+                        << "());\n";
+                    }).endl();
+                } else {
+                    out.sIf("lhs." + field->name() + " != rhs." + field->name(), [&] {
+                        out << "return false;\n";
+                    }).endl();
+                }
+            }
+
+            if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+                out << "case " << fullName() << "::hidl_discriminator::"
+                    << "hidl_no_init: ";
+
+                out.block([&] {
+                    out << "return false;\n";
+                }).endl();
+
+                out << "default: ";
+                out.block([&] {
+                    out << "details::logAlwaysFatal(\"Unknown union discriminator.\");\n";
+                }).endl();
+
+                out.unindent();
+                out << "}\n";
             }
             out << "return true;\n";
         }).endl().endl();
@@ -548,6 +799,217 @@
     }
 }
 
+// Emits the body of a safe_union getter (shared by the mutable and const
+// overloads): calls details::logAlwaysFatal("Bad safe_union access.") when
+// `fieldName` is not the active member, then returns hidl_u.<fieldName>.
+// Only used in this file, so it has internal linkage.
+static void emitSafeUnionGetterDefinition(const std::string& fieldName,
+                                          Formatter& out) {
+    out.block([&] {
+        out << "if (CC_UNLIKELY(hidl_d != hidl_discriminator::"
+            << fieldName
+            << ")) ";
+
+        out.block([&] {
+            out << "details::logAlwaysFatal(\"Bad safe_union access.\");\n";
+        }).endl().endl();
+
+        out << "return hidl_u."
+            << fieldName
+            << ";\n";
+    }).endl().endl();
+}
+
+// Emits out-of-line definitions of the safe_union's default constructor,
+// destructor and copy constructor.
+void CompoundType::emitSafeUnionTypeConstructors(Formatter& out) const {
+    // Default constructor: hidl_d's in-class initializer already selects
+    // hidl_no_init, so the body is empty.
+    out << fullName()
+        << "::"
+        << localName()
+        << "() {}\n\n";
+
+    // Destructor: destroys whichever member is currently active.
+    out << fullName()
+        << "::~"
+        << localName()
+        << "() ";
+
+    out.block([&] {
+        out << "hidl_destructUnion();\n";
+    }).endl().endl();
+
+    // Copy constructor: copy-constructs other's active member in place inside
+    // hidl_u, then copies the discriminator.
+    out << fullName()
+        << "::"
+        << localName()
+        << "(const "
+        << localName()
+        << "& other) ";
+
+    out.block([&] {
+        out << "switch(other.hidl_d) ";
+        out.block([&] {
+            for (const auto& field : *mFields) {
+                out << "case hidl_discriminator::"
+                    << field->name()
+                    << ": ";
+
+                out.block([&] {
+                    out << "new (&hidl_u."
+                        << field->name()
+                        << ") "
+                        << field->type().getCppStackType()
+                        << "(other.hidl_u."
+                        << field->name()
+                        << ");\n"
+                        << "break;\n";
+                }).endl();
+            }
+
+            out << "case hidl_discriminator::hidl_no_init: { break; }\n";
+            out << "default: { details::logAlwaysFatal("
+                << "\"Unknown union discriminator.\"); }\n";
+        }).endl().endl();
+
+        out << "hidl_d = other.hidl_d;\n";
+    }).endl().endl();
+}
+
+// Emits the out-of-line definitions for a safe_union with at least one field:
+// special members, hidl_destructUnion(), the per-field setter and getters, the
+// inner union's trivial constructor/destructor, and getDiscriminator(). Every
+// member is qualified with fullName() so nested safe_unions resolve too.
+void CompoundType::emitSafeUnionTypeDefinitions(Formatter& out) const {
+    if (mFields->empty()) { return; }
+    emitSafeUnionTypeConstructors(out);
+
+    // hidl_destructUnion(): destroys the active member (if any) and resets the
+    // discriminator to hidl_no_init.
+    out << "void "
+        << fullName()
+        << "::hidl_destructUnion() ";
+
+    out.block([&] {
+        out << "switch(hidl_d) ";
+        out.block([&] {
+            for (const auto& field : *mFields) {
+                out << "case hidl_discriminator::"
+                    << field->name()
+                    << ": ";
+
+                out.block([&] {
+                    const std::string fullFieldName = "hidl_u." + field->name();
+                    field->type().emitTypeDestructorCall(out, fullFieldName);
+                    out << "break;\n";
+                }).endl();
+            }
+
+            out << "case hidl_discriminator::hidl_no_init: { break; }\n";
+            out << "default: { details::logAlwaysFatal("
+                << "\"Unknown union discriminator.\"); }\n";
+        }).endl().endl();
+
+        out << "hidl_d = hidl_discriminator::hidl_no_init;\n";
+    }).endl().endl();
+
+    for (const NamedReference<Type>* field : *mFields) {
+        // Setter: switching to another member destroys the old one, zeroes
+        // hidl_u and constructs the new member in place; if this member is
+        // already active it is assigned, unless `o` aliases the stored value.
+        out << "void "
+            << fullName()
+            << "::"
+            << field->name()
+            << "("
+            << field->type().getCppArgumentType()
+            << " o) ";
+
+        out.block([&] {
+            out << "if (hidl_d != hidl_discriminator::"
+                << field->name()
+                << ") ";
+
+            out.block([&] {
+                out << "hidl_destructUnion();\n"
+                    << "::std::memset(&hidl_u, 0, sizeof(hidl_u));\n\n";
+
+                out << "new (&hidl_u."
+                    << field->name()
+                    << ") "
+                    << field->type().getCppStackType()
+                    << "(o);\n";
+
+                out << "hidl_d = hidl_discriminator::"
+                    << field->name()
+                    << ";\n";
+            }).endl();
+
+            out << "else if (&(hidl_u."
+                << field->name()
+                << ") != &o) ";
+
+            out.block([&] {
+                out << "hidl_u."
+                    << field->name()
+                    << " = o;\n";
+            }).endl();
+        }).endl().endl();
+
+        // Getters. Their qualified declarators are parenthesized so that the
+        // leading "::" of fullName() is never read as part of the return type.
+        out << field->type().getCppStackType()
+            << "& ("
+            << fullName()
+            << "::"
+            << field->name()
+            << ")() ";
+
+        emitSafeUnionGetterDefinition(field->name(), out);
+
+        out << field->type().getCppArgumentType()
+            << " ("
+            << fullName()
+            << "::"
+            << field->name()
+            << ")() const ";
+
+        emitSafeUnionGetterDefinition(field->name(), out);
+    }
+
+    // Trivial constructor/destructor for internal union
+    out << fullName() << "::hidl_union::hidl_union() {}\n\n"
+        << fullName() << "::hidl_union::~hidl_union() {}\n\n";
+
+    // getDiscriminator(): qualified with fullName() like the definitions
+    // above (a bare localName() does not name a nested safe_union), and
+    // parenthesized for the same reason as the getters.
+    out << fullName() << "::hidl_discriminator ("
+        << fullName() << "::getDiscriminator)() const ";
+
+    out.block([&] {
+        out << "static_assert(offsetof("
+            << fullName()
+            << ", hidl_u) == 0"
+            << ", \"wrong offset\");\n";
+
+        // hidl_d's offset is only pinned for pointer-free types, matching the
+        // layout asserts emitted alongside the declaration.
+        if (!containsPointer()) {
+            const CompoundLayout layout = getCompoundAlignmentAndSize();
+            out << "static_assert(offsetof("
+                << fullName()
+                << ", hidl_d) == "
+                << layout.discriminator.offset
+                << ", \"wrong offset\");\n";
+        }
+
+        out << "\nreturn hidl_d;\n";
+    }).endl().endl();
+}
+
 void CompoundType::emitTypeDefinitions(Formatter& out, const std::string& prefix) const {
     std::string space = prefix.empty() ? "" : (prefix + "::");
     Scope::emitTypeDefinitions(out, space + localName());
@@ -561,8 +1006,13 @@
         emitResolveReferenceDef(out, prefix, true /* isReader */);
         emitResolveReferenceDef(out, prefix, false /* isReader */);
     }
+
+    if (mStyle == STYLE_SAFE_UNION) {
+        emitSafeUnionTypeDefinitions(out);
+    }
 }
 
+// TODO(b/79878527): Implement Java bindings for safe unions
 void CompoundType::emitJavaTypeDeclarations(Formatter& out, bool atTopLevel) const {
     out << "public final ";
 
@@ -713,10 +1163,7 @@
             size_t fieldAlign, fieldSize;
             field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
 
-            size_t pad = offset % fieldAlign;
-            if (pad > 0) {
-                offset += fieldAlign - pad;
-            }
+            offset += Layout::getPad(offset, fieldAlign);
 
             field->type().emitJavaFieldReaderWriter(
                 out, 0 /* depth */, "parcel", "_hidl_blob", field->name(),
@@ -784,10 +1231,9 @@
         for (const auto& field : *mFields) {
             size_t fieldAlign, fieldSize;
             field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
-            size_t pad = offset % fieldAlign;
-            if (pad > 0) {
-                offset += fieldAlign - pad;
-            }
+
+            offset += Layout::getPad(offset, fieldAlign);
+
             field->type().emitJavaFieldReaderWriter(
                 out, 0 /* depth */, "parcel", "_hidl_blob", field->name(),
                 "_hidl_offset + " + std::to_string(offset), false /* isReader */);
@@ -844,15 +1290,34 @@
 
     out << "::android::status_t _hidl_err = ::android::OK;\n\n";
 
+    if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+        out << "switch (" << name << ".getDiscriminator()) {\n";
+        out.indent();
+    }
+
     for (const auto &field : *mFields) {
         if (!field->type().needsEmbeddedReadWrite()) {
             continue;
         }
 
+        if (mStyle == STYLE_SAFE_UNION) {
+            out << "case " << fullName() << "::hidl_discriminator::"
+                << field->name() << ": {\n";
+            out.indent();
+        }
+
+        const std::string fieldName = (mStyle == STYLE_SAFE_UNION)
+                                        ? (name + "." + field->name() + "()" + error)
+                                        : (name + "." + field->name() + error);
+
+        const std::string fieldOffset = (mStyle == STYLE_SAFE_UNION)
+                                        ? "0 /* safe_union: union offset into struct */"
+                                        : ("offsetof(" + fullName() + ", " + field->name() + ")");
+
         field->type().emitReaderWriterEmbedded(
                 out,
                 0 /* depth */,
-                name + "." + field->name() + error,
+                fieldName,
                 field->name() /* sanitizedName */,
                 false /* nameIsPointer */,
                 "parcel",
@@ -860,11 +1325,19 @@
                 isReader,
                 ErrorMode_Return,
                 "parentHandle",
-                "parentOffset + offsetof("
-                    + fullName()
-                    + ", "
-                    + field->name()
-                    + ")");
+                "parentOffset + " + fieldOffset);
+
+        if (mStyle == STYLE_SAFE_UNION) {
+            out << "break;\n";
+            out.unindent();
+            out << "}\n";
+        }
+    }
+
+    if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+        out << "default: { break; }\n";
+        out.unindent();
+        out << "}\n";
     }
 
     out << "return _hidl_err;\n";
@@ -918,15 +1391,34 @@
     // should not be used at all, then the #error should not be emitted.
     std::string error = useParent ? "" : "\n#error\n";
 
+    if (mStyle == STYLE_SAFE_UNION) {
+        out << "switch (" << nameDeref << "getDiscriminator()) {\n";
+        out.indent();
+    }
+
     for (const auto &field : *mFields) {
         if (!field->type().needsResolveReferences()) {
             continue;
         }
 
+        if (mStyle == STYLE_SAFE_UNION) {
+            out << "case " << fullName() << "::hidl_discriminator::"
+                << field->name() << ": {\n";
+            out.indent();
+        }
+
+        const std::string fieldName = (mStyle == STYLE_SAFE_UNION)
+                                        ? (nameDeref + field->name() + "()")
+                                        : (nameDeref + field->name());
+
+        const std::string fieldOffset = (mStyle == STYLE_SAFE_UNION)
+                                        ? "0 /* safe_union: union offset into struct */"
+                                        : ("offsetof(" + fullName() + ", " + field->name() + ")");
+
         field->type().emitResolveReferencesEmbedded(
             out,
             0 /* depth */,
-            nameDeref + field->name(),
+            fieldName,
             field->name() /* sanitizedName */,
             false,    // nameIsPointer
             "parcel", // const std::string &parcelObj,
@@ -935,12 +1427,21 @@
             ErrorMode_Return,
             parentHandleName + error,
             parentOffsetName
-                + " + offsetof("
-                + fullName()
-                + ", "
-                + field->name()
-                + ")"
+                + " + "
+                + fieldOffset
                 + error);
+
+        if (mStyle == STYLE_SAFE_UNION) {
+            out << "break;\n";
+            out.unindent();
+            out << "}\n";
+        }
+    }
+
+    if (mStyle == STYLE_SAFE_UNION) {
+        out << "default: { _hidl_err = ::android::BAD_VALUE; break; }\n";
+        out.unindent();
+        out << "}\n";
     }
 
     out << "return _hidl_err;\n";
@@ -950,7 +1451,7 @@
 }
 
 bool CompoundType::needsEmbeddedReadWrite() const {
-    if (mStyle != STYLE_STRUCT) {
+    if (mStyle == STYLE_UNION) {
         return false;
     }
 
@@ -964,7 +1465,7 @@
 }
 
 bool CompoundType::deepNeedsResolveReferences(std::unordered_set<const Type*>* visited) const {
-    if (mStyle != STYLE_STRUCT) {
+    if (mStyle == STYLE_UNION) {
         return false;
     }
 
@@ -998,6 +1499,15 @@
                 out << "sub_union: {\n";
                 break;
             }
+            case STYLE_SAFE_UNION:
+            {
+                out << "sub_safe_union: {\n";
+                break;
+            }
+            default:
+            {
+                CHECK(!"Should not be here");
+            }
         }
         out.indent();
         type->emitVtsTypeDeclarations(out);
@@ -1018,6 +1528,15 @@
                 out << "union_value: {\n";
                 break;
             }
+            case STYLE_SAFE_UNION:
+            {
+                out << "safe_union_value: {\n";
+                break;
+            }
+            default:
+            {
+                CHECK(!"Should not be here");
+            }
         }
         out.indent();
         out << "name: \"" << field->name() << "\"\n";
@@ -1037,7 +1556,7 @@
 }
 
 bool CompoundType::deepIsJavaCompatible(std::unordered_set<const Type*>* visited) const {
-    if (mStyle != STYLE_STRUCT) {
+    if (mStyle != STYLE_STRUCT) {  // TODO(natre): Update
         return false;
     }
 
@@ -1061,48 +1580,98 @@
 }
 
+// Reports the overall alignment and size of this type (for a non-empty
+// safe_union this includes the trailing discriminator).
 void CompoundType::getAlignmentAndSize(size_t *align, size_t *size) const {
-    *align = 1;
-    *size = 0;
+    CompoundLayout layout = getCompoundAlignmentAndSize();
+    *align = layout.overall.align;
+    *size = layout.overall.size;
+}
 
-    size_t offset = 0;
+// Computes the C++ layout of this type. `innerStruct` covers the fields (laid
+// out sequentially for a struct, overlapping for union/safe_union styles);
+// `discriminator` is the hidl_d member that follows them in a non-empty
+// safe_union; `overall` is the whole type padded to its alignment. A type with
+// no storage is still given a size of one byte, as in C++.
+CompoundType::CompoundLayout CompoundType::getCompoundAlignmentAndSize() const {
+    CompoundLayout compoundLayout;
+
+    // Local aliases for convenience
+    Layout& overall = compoundLayout.overall;
+    Layout& innerStruct = compoundLayout.innerStruct;
+    Layout& discriminator = compoundLayout.discriminator;
+
     for (const auto &field : *mFields) {
+
         // Each field is aligned according to its alignment requirement.
         // The surrounding structure's alignment is the maximum of its
         // fields' aligments.
-
         size_t fieldAlign, fieldSize;
         field->type().getAlignmentAndSize(&fieldAlign, &fieldSize);
+        size_t lPad = Layout::getPad(innerStruct.size, fieldAlign);
 
-        size_t pad = offset % fieldAlign;
-        if (pad > 0) {
-            offset += fieldAlign - pad;
-        }
+        innerStruct.size = (mStyle == STYLE_STRUCT)
+                            ? (innerStruct.size + lPad + fieldSize)
+                            : std::max(innerStruct.size, fieldSize);
 
-        if (mStyle == STYLE_STRUCT) {
-            offset += fieldSize;
-        } else {
-            *size = std::max(*size, fieldSize);
-        }
+        innerStruct.align = std::max(innerStruct.align, fieldAlign);
+    }
 
-        if (fieldAlign > (*align)) {
-            *align = fieldAlign;
+    // Padding for the inner structure
+    innerStruct.size += Layout::getPad(innerStruct.size,
+                                       innerStruct.align);
+
+    if (mStyle == STYLE_SAFE_UNION && !mFields->empty()) {
+        getUnionDiscriminatorType()->getAlignmentAndSize(
+            &(discriminator.align), &(discriminator.size));
+
+        discriminator.offset = innerStruct.size;
+        discriminator.offset += Layout::getPad(discriminator.offset,
+                                               discriminator.align);
+
+        overall.size = discriminator.offset + discriminator.size;
+    } else {
+        overall.size = innerStruct.size;
+    }
+
+    // An empty struct/union still occupies a byte of space in C++.
+    if (overall.size == 0) {
+        overall.size = 1;
+    }
+
+    // Padding for the overall structure
+    overall.align = std::max(innerStruct.align, discriminator.align);
+    overall.size += Layout::getPad(overall.size, overall.align);
+
+    return compoundLayout;
+}
+
+// Picks the narrowest unsigned integer type able to hold one enumerator per
+// field plus hidl_no_init (uint8 for up to 256 values, then uint16, uint32,
+// falling back to uint64).
+std::unique_ptr<ScalarType> CompoundType::getUnionDiscriminatorType() const {
+    static const std::vector<std::pair<int, ScalarType::Kind> > scalars {
+        {8, ScalarType::Kind::KIND_UINT8},
+        {16, ScalarType::Kind::KIND_UINT16},
+        {32, ScalarType::Kind::KIND_UINT32},
+    };
+
+    size_t numFields = mFields->size() + 1;  // +1 for no_init
+    auto kind = ScalarType::Kind::KIND_UINT64;
+
+    for (const auto& scalar : scalars) {
+        if (numFields <= (1ULL << scalar.first)) {
+            kind = scalar.second; break;
         }
     }
 
-    if (mStyle == STYLE_STRUCT) {
-        *size = offset;
-    }
+    return std::make_unique<ScalarType>(kind, nullptr);
+}
 
-    // Final padding to account for the structure's alignment.
-    size_t pad = (*size) % (*align);
-    if (pad > 0) {
-        (*size) += (*align) - pad;
-    }
-
-    if (*size == 0) {
-        // An empty struct still occupies a byte of space in C++.
-        *size = 1;
-    }
+// Returns the number of padding bytes needed to round `offset` up to a
+// multiple of `align` (which must be non-zero, since it is used as a divisor).
+size_t CompoundType::Layout::getPad(size_t offset, size_t align) {
+    size_t remainder = offset % align;
+    return (remainder > 0) ? (align - remainder) : 0;
 }
 
 }  // namespace android