Handling interfaces in HIDL safe_unions (C++)

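This change implements the previously missing parcelling/unparcelling
code required to handle interface types in HIDL safe unions. Such unions
are now (un)parcelled member by member: the discriminator first,
followed by the active member.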

Bug: 79878527
Bug: 110269925
Test: Ran the existing hidl_test suite; added a new safe_union
containing an interface type (in a separate CL).

Change-Id: I6711fba0ef22a64dc04a692a6fa95f149d0eb429
diff --git a/CompoundType.cpp b/CompoundType.cpp
index 2410784..4134f80 100644
--- a/CompoundType.cpp
+++ b/CompoundType.cpp
@@ -198,6 +198,95 @@
     return false;
 }
 
+void CompoundType::emitSafeUnionReaderWriterForInterfaces(
+        Formatter &out,
+        const std::string &name,
+        const std::string &parcelObj,
+        bool parcelObjIsPointer,
+        bool isReader,
+        ErrorMode mode) const {
+    // Emits (un)parcelling code for a safe_union that has interface members.
+    CHECK(mStyle == STYLE_SAFE_UNION);
+    if (mFields->empty()) { return; }
+    // Wire layout: the discriminator, followed by the active member (if any).
+    out.block([&] {
+        const auto discriminatorType = getUnionDiscriminatorType();
+        if (isReader) {
+            out << discriminatorType->getCppStackType()
+                << " _hidl_d_primitive;\n";
+        } else {
+            out << "const "
+                << discriminatorType->getCppStackType()
+                << " _hidl_d_primitive = "
+                << discriminatorType->getCppTypeCast(name + ".getDiscriminator()")
+                << ";\n";
+        }
+        // (Un)parcel the discriminator first; the switch below dispatches on it.
+        discriminatorType->emitReaderWriter(out, "_hidl_d_primitive", parcelObj,
+                                            parcelObjIsPointer, isReader, mode);
+        out << "switch (("
+            << fullName()
+            << "::hidl_discriminator) _hidl_d_primitive) ";
+
+        out.block([&] {
+            for (const auto& field : *mFields) {
+                out << "case "
+                    << fullName()
+                    << "::hidl_discriminator::"
+                    << field->name()
+                    << ": ";
+                // Members are reached via accessors, so stage each in a temporary.
+                const std::string tempFieldName = "_hidl_temp_" + field->name();
+                out.block([&] {
+                    if (isReader) {
+                        out << field->type().getCppResultType()
+                            << " "
+                            << tempFieldName
+                            << ";\n";
+                        // Read into the temporary, then pass it to the member's setter.
+                        field->type().emitReaderWriter(out, tempFieldName, parcelObj,
+                                                       parcelObjIsPointer, isReader, mode);
+                        // Result types flagged by resultNeedsDeref() must be dereferenced.
+                        const std::string derefOperator = field->type().resultNeedsDeref()
+                                                          ? "*" : "";
+                        out << name
+                            << "."
+                            << field->name()
+                            << "("
+                            << derefOperator
+                            << tempFieldName
+                            << ");\n";
+                    } else {
+                        const std::string fieldValue = name + "." + field->name() + "()";
+                        out << field->type().getCppArgumentType()
+                            << " "
+                            << tempFieldName
+                            << " = "
+                            << fieldValue
+                            << ";\n";
+
+                        field->type().emitReaderWriter(out, tempFieldName, parcelObj,
+                                                       parcelObjIsPointer, isReader, mode);
+                    }
+                    out << "break;\n";
+                }).endl();
+            }
+            // hidl_no_init has no active member, so nothing follows the discriminator.
+            out << "case " << fullName() << "::hidl_discriminator::"
+                << "hidl_no_init: ";
+            out.block([&] {
+                out << "break;\n";
+            }).endl();
+            // Unknown discriminator: fully qualified so it resolves in any package namespace.
+            out << "default: ";
+            out.block([&] {
+                out << "::android::hardware::details::logAlwaysFatal("
+                    << "\"Unknown union discriminator.\");\n";
+            }).endl();
+        }).endl();
+    }).endl();
+}
+
 void CompoundType::emitReaderWriter(
         Formatter &out,
         const std::string &name,
@@ -210,9 +299,17 @@
         parcelObj + (parcelObjIsPointer ? "->" : ".");
 
     if(containsInterface()){
+        if (mStyle == STYLE_SAFE_UNION) {
+            emitSafeUnionReaderWriterForInterfaces(out, name, parcelObj,
+                                                   parcelObjIsPointer,
+                                                   isReader, mode);
+            return;
+        }
+
         for (const auto& field : *mFields) {
             field->type().emitReaderWriter(out, name + "." + field->name(),
-                                               parcelObj, parcelObjIsPointer, isReader, mode);
+                                            parcelObj, parcelObjIsPointer,
+                                            isReader, mode);
         }
     } else {
         const std::string parentName = "_hidl_" + name + "_parent";
@@ -1008,11 +1105,11 @@
         << fullName() << "::hidl_union::~hidl_union() {}\n\n";
 
     // Utility method
-    out << fullName() << "::hidl_discriminator "
-        << localName() << "::getDiscriminator() const ";
+    out << fullName() << "::hidl_discriminator ("
+        << fullName() << "::getDiscriminator)() const ";
 
     out.block([&] {
-        out << "\nreturn hidl_d;\n";
+        out << "return hidl_d;\n";
     }).endl().endl();
 }