Forbid RS objects from being contained in unions.

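This change also refactors RS variable validation in general, and drops the
redundant SourceManager parameter from the type-export helpers in favor of
the Diagnostic's own SourceManager.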
BUG=4283858

Change-Id: I4527986a07c9cf2babdc5b855cdb1f00e3535d5b
diff --git a/slang_rs_export_type.cpp b/slang_rs_export_type.cpp
index 3e84b7e..288f439 100644
--- a/slang_rs_export_type.cpp
+++ b/slang_rs_export_type.cpp
@@ -47,28 +47,28 @@
     const clang::Type *T,
     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
     clang::Diagnostic *Diags,
-    clang::SourceManager *SM,
     const clang::VarDecl *VD,
     const clang::RecordDecl *TopLevelRecord);
 
 static void ReportTypeError(clang::Diagnostic *Diags,
-                            const clang::SourceManager *SM,
                             const clang::VarDecl *VD,
                             const clang::RecordDecl *TopLevelRecord,
                             const char *Message) {
-  if (!Diags || !SM) {
+  if (!Diags) {
     return;
   }
 
+  const clang::SourceManager &SM = Diags->getSourceManager();
+
   // Attempt to use the type declaration first (if we have one).
   // Fall back to the variable definition, if we are looking at something
   // like an array declaration that can't be exported.
   if (TopLevelRecord) {
-    Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(), *SM),
+    Diags->Report(clang::FullSourceLoc(TopLevelRecord->getLocation(), SM),
                   Diags->getCustomDiagID(clang::Diagnostic::Error, Message))
          << TopLevelRecord->getName();
   } else if (VD) {
-    Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
+    Diags->Report(clang::FullSourceLoc(VD->getLocation(), SM),
                   Diags->getCustomDiagID(clang::Diagnostic::Error, Message))
          << VD->getName();
   } else {
@@ -82,13 +82,12 @@
     const clang::ConstantArrayType *CAT,
     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
     clang::Diagnostic *Diags,
-    clang::SourceManager *SM,
     const clang::VarDecl *VD,
     const clang::RecordDecl *TopLevelRecord) {
   // Check element type
   const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
   if (ElementType->isArrayType()) {
-    ReportTypeError(Diags, SM, VD, TopLevelRecord,
+    ReportTypeError(Diags, VD, TopLevelRecord,
         "multidimensional arrays cannot be exported: '%0'");
     return NULL;
   } else if (ElementType->isExtVectorType()) {
@@ -98,20 +97,19 @@
 
     const clang::Type *BaseElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
     if (!RSExportPrimitiveType::IsPrimitiveType(BaseElementType)) {
-      ReportTypeError(Diags, SM, VD, TopLevelRecord,
+      ReportTypeError(Diags, VD, TopLevelRecord,
           "vectors of non-primitive types cannot be exported: '%0'");
       return NULL;
     }
 
     if (numElements == 3 && CAT->getSize() != 1) {
-      ReportTypeError(Diags, SM, VD, TopLevelRecord,
+      ReportTypeError(Diags, VD, TopLevelRecord,
           "arrays of width 3 vector types cannot be exported: '%0'");
       return NULL;
     }
   }
 
-  if (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
-                           TopLevelRecord) == NULL)
+  if (TypeExportableHelper(ElementType, SPS, Diags, VD, TopLevelRecord) == NULL)
     return NULL;
   else
     return CAT;
@@ -121,7 +119,6 @@
     const clang::Type *T,
     llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
     clang::Diagnostic *Diags,
-    clang::SourceManager *SM,
     const clang::VarDecl *VD,
     const clang::RecordDecl *TopLevelRecord) {
   // Normalize first
@@ -153,7 +150,7 @@
 
       // Check internal struct
       if (T->isUnionType()) {
-        ReportTypeError(Diags, SM, NULL, T->getAsUnionType()->getDecl(),
+        ReportTypeError(Diags, NULL, T->getAsUnionType()->getDecl(),
             "unions cannot be exported: '%0'");
         return NULL;
       } else if (!T->isStructureType()) {
@@ -165,7 +162,7 @@
       if (RD != NULL) {
         RD = RD->getDefinition();
         if (RD == NULL) {
-          ReportTypeError(Diags, SM, NULL, T->getAsStructureType()->getDecl(),
+          ReportTypeError(Diags, NULL, T->getAsStructureType()->getDecl(),
               "struct is not defined in this module");
           return NULL;
         }
@@ -175,7 +172,7 @@
         TopLevelRecord = RD;
       }
       if (RD->getName().empty()) {
-        ReportTypeError(Diags, SM, NULL, RD,
+        ReportTypeError(Diags, NULL, RD,
             "anonymous structures cannot be exported");
         return NULL;
       }
@@ -196,7 +193,7 @@
         const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
         FT = GET_CANONICAL_TYPE(FT);
 
-        if (!TypeExportableHelper(FT, SPS, Diags, SM, VD, TopLevelRecord)) {
+        if (!TypeExportableHelper(FT, SPS, Diags, VD, TopLevelRecord)) {
           return NULL;
         }
 
@@ -204,8 +201,9 @@
         //
         // TODO(zonr/srhines): allow bit fields of size 8, 16, 32
         if (FD->isBitField()) {
-          if (Diags && SM) {
-            Diags->Report(clang::FullSourceLoc(FD->getLocation(), *SM),
+          if (Diags) {
+            Diags->Report(clang::FullSourceLoc(FD->getLocation(),
+                                               Diags->getSourceManager()),
                           Diags->getCustomDiagID(clang::Diagnostic::Error,
                           "bit fields are not able to be exported: '%0.%1'"))
                 << RD->getName()
@@ -219,7 +217,7 @@
     }
     case clang::Type::Pointer: {
       if (TopLevelRecord) {
-        ReportTypeError(Diags, SM, NULL, TopLevelRecord,
+        ReportTypeError(Diags, NULL, TopLevelRecord,
             "structures containing pointers cannot be exported: '%0'");
         return NULL;
       }
@@ -233,7 +231,7 @@
       // We don't support pointer with array-type pointee or unsupported pointee
       // type
       if (PointeeType->isArrayType() ||
-          (TypeExportableHelper(PointeeType, SPS, Diags, SM, VD,
+          (TypeExportableHelper(PointeeType, SPS, Diags, VD,
                                 TopLevelRecord) == NULL))
         return NULL;
       else
@@ -250,7 +248,7 @@
       const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
 
       if ((ElementType->getTypeClass() != clang::Type::Builtin) ||
-          (TypeExportableHelper(ElementType, SPS, Diags, SM, VD,
+          (TypeExportableHelper(ElementType, SPS, Diags, VD,
                                 TopLevelRecord) == NULL))
         return NULL;
       else
@@ -260,7 +258,7 @@
       const clang::ConstantArrayType *CAT =
           UNSAFE_CAST_TYPE(const clang::ConstantArrayType, T);
 
-      return ConstantArrayTypeExportableHelper(CAT, SPS, Diags, SM, VD,
+      return ConstantArrayTypeExportableHelper(CAT, SPS, Diags, VD,
                                                TopLevelRecord);
     }
     default: {
@@ -271,19 +269,138 @@

 // Return the type that can be used to create RSExportType, will always return
 // the canonical type
-// If the Type T is not exportable, this function returns NULL. Diags and SM
-// are used to generate proper Clang diagnostic messages when a
+// If the Type T is not exportable, this function returns NULL. Diags is
+// used to generate proper Clang diagnostic messages when a
 // non-exportable type is detected. TopLevelRecord is used to capture the
 // highest struct (in the case of a nested hierarchy) for detecting other
 // types that cannot be exported (mostly pointers within a struct).
 static const clang::Type *TypeExportable(const clang::Type *T,
                                          clang::Diagnostic *Diags,
-                                         clang::SourceManager *SM,
                                          const clang::VarDecl *VD) {
   llvm::SmallPtrSet<const clang::Type*, 8> SPS =
       llvm::SmallPtrSet<const clang::Type*, 8>();

-  return TypeExportableHelper(T, SPS, Diags, SM, VD, NULL);
+  return TypeExportableHelper(T, SPS, Diags, VD, NULL);
+}
+
+// Recursively validate type T, which is the type of VD or a type reachable
+// from it. RS object types (see RSExportPrimitiveType::IsRSObjectType) found
+// anywhere inside a union -- including behind pointers -- are reported as an
+// error against UnionDecl. Records with a flexible array member or an object
+// member are rejected as well, without a diagnostic.
+//
+// T is canonicalized in place. SPS holds the record types on the current
+// recursion path so that self-referential types (e.g. a struct containing a
+// pointer to itself) terminate. UnionDecl is the innermost union enclosing T,
+// or NULL when T is not contained in any union.
+//
+// Returns true if no problem was found, false otherwise.
+static bool ValidateVarDeclHelper(
+    clang::VarDecl *VD,
+    const clang::Type *&T,
+    llvm::SmallPtrSet<const clang::Type*, 8>& SPS,
+    clang::RecordDecl *UnionDecl) {
+  if ((T = GET_CANONICAL_TYPE(T)) == NULL)
+    return true;
+
+  // T is already being validated by an outer frame (a cycle through a
+  // pointer member); that frame checks its fields.
+  if (SPS.count(T))
+    return true;
+
+  switch (T->getTypeClass()) {
+    case clang::Type::Record: {
+      if (RSExportPrimitiveType::GetRSSpecificType(T) !=
+          RSExportPrimitiveType::DataTypeUnknown) {
+        if (!UnionDecl) {
+          return true;
+        } else if (RSExportPrimitiveType::IsRSObjectType(T)) {
+          clang::ASTContext &C = VD->getASTContext();
+          ReportTypeError(&C.getDiagnostics(), VD, UnionDecl,
+              "unions containing RS object types are not allowed");
+          return false;
+        }
+        // Other RS-specific types inside a union are checked like any other
+        // record below.
+      }
+
+      clang::RecordDecl *RD = NULL;
+
+      // Check internal struct
+      if (T->isUnionType()) {
+        RD = T->getAsUnionType()->getDecl();
+        // Everything reachable from here is contained in this union.
+        UnionDecl = RD;
+      } else if (T->isStructureType()) {
+        RD = T->getAsStructureType()->getDecl();
+      } else {
+        slangAssert(false && "Unknown type cannot be exported");
+        return false;
+      }
+
+      RD = RD->getDefinition();
+      if (RD == NULL) {
+        // FIXME: the record has no visible definition, so there are no
+        // fields to inspect; accept it for now.
+        return true;
+      }
+
+      // Fast check
+      // NOTE(review): this rejects VD without emitting a diagnostic; confirm
+      // that callers report an error in this case.
+      if (RD->hasFlexibleArrayMember() || RD->hasObjectMember())
+        return false;
+
+      // Mark T as in progress so that self-referential types terminate.
+      SPS.insert(T);
+
+      // Check all fields, stopping at the first invalid one.
+      bool Valid = true;
+      for (clang::RecordDecl::field_iterator FI = RD->field_begin(),
+               FE = RD->field_end();
+           Valid && FI != FE;
+           ++FI) {
+        const clang::FieldDecl *FD = *FI;
+        const clang::Type *FT = RSExportType::GetTypeOfDecl(FD);
+        Valid = ValidateVarDeclHelper(VD, FT, SPS, UnionDecl);
+      }
+
+      // T is fully visited; forget it. Keeping it in SPS would skip this
+      // record the next time it is reached, e.g. when a struct is used once
+      // as a plain field and again inside a union, where RS object members
+      // must be rejected.
+      SPS.erase(T);
+      return Valid;
+    }
+
+    case clang::Type::Pointer: {
+      const clang::PointerType *PT =
+        UNSAFE_CAST_TYPE(const clang::PointerType, T);
+      const clang::Type *PointeeType = GET_POINTEE_TYPE(PT);
+      return ValidateVarDeclHelper(VD, PointeeType, SPS, UnionDecl);
+    }
+
+    case clang::Type::ExtVector: {
+      const clang::ExtVectorType *EVT =
+          UNSAFE_CAST_TYPE(const clang::ExtVectorType, T);
+      const clang::Type *ElementType = GET_EXT_VECTOR_ELEMENT_TYPE(EVT);
+      return ValidateVarDeclHelper(VD, ElementType, SPS, UnionDecl);
+    }
+
+    case clang::Type::ConstantArray: {
+      const clang::ConstantArrayType *CAT =
+          UNSAFE_CAST_TYPE(const clang::ConstantArrayType, T);
+      const clang::Type *ElementType = GET_CONSTANT_ARRAY_ELEMENT_TYPE(CAT);
+      return ValidateVarDeclHelper(VD, ElementType, SPS, UnionDecl);
+    }
+
+    default: {
+      // Builtin and all other type classes need no further validation.
+      break;
+    }
+  }
+
+  return true;
 }

 }  // namespace
@@ -292,17 +393,17 @@
 bool RSExportType::NormalizeType(const clang::Type *&T,
                                  llvm::StringRef &TypeName,
                                  clang::Diagnostic *Diags,
-                                 clang::SourceManager *SM,
                                  const clang::VarDecl *VD) {
-  if ((T = TypeExportable(T, Diags, SM, VD)) == NULL) {
+  if ((T = TypeExportable(T, Diags, VD)) == NULL) {
     return false;
   }
   // Get type name
   TypeName = RSExportType::GetTypeName(T);
   if (TypeName.empty()) {
-    if (Diags && SM) {
+    if (Diags) {
       if (VD) {
-        Diags->Report(clang::FullSourceLoc(VD->getLocation(), *SM),
+        Diags->Report(clang::FullSourceLoc(VD->getLocation(),
+                                           Diags->getSourceManager()),
                       Diags->getCustomDiagID(clang::Diagnostic::Error,
                                              "anonymous types cannot "
                                              "be exported"));
@@ -318,6 +419,17 @@
   return true;
 }

+// Validate the type of the variable declaration VD. Errors are reported
+// through the diagnostics of VD's ASTContext; currently the check rejects RS
+// object types contained (directly or transitively) in a union. Returns false
+// if a problem was found.
+bool RSExportType::ValidateVarDecl(clang::VarDecl *VD) {
+  const clang::Type *T = VD->getType().getTypePtr();
+  llvm::SmallPtrSet<const clang::Type*, 8> SPS;
+
+  return ValidateVarDeclHelper(VD, T, SPS, NULL);
+}
+
 const clang::Type
 *RSExportType::GetTypeOfDecl(const clang::DeclaratorDecl *DD) {
   if (DD) {
@@ -390,7 +499,7 @@
       // "*" plus pointee name
       const clang::Type *PT = GET_POINTEE_TYPE(T);
       llvm::StringRef PointeeName;
-      if (NormalizeType(PT, PointeeName, NULL, NULL, NULL)) {
+      if (NormalizeType(PT, PointeeName, NULL, NULL)) {
         char *Name = new char[ 1 /* * */ + PointeeName.size() + 1 ];
         Name[0] = '*';
         memcpy(Name + 1, PointeeName.data(), PointeeName.size());
@@ -512,7 +621,7 @@
 
 RSExportType *RSExportType::Create(RSContext *Context, const clang::Type *T) {
   llvm::StringRef TypeName;
-  if (NormalizeType(T, TypeName, NULL, NULL, NULL))
+  if (NormalizeType(T, TypeName, Context->getDiagnostics(), NULL))
     return Create(Context, T, TypeName);
   else
     return NULL;
@@ -753,8 +862,8 @@
                                                      const clang::Type *T,
                                                      DataKind DK) {
   llvm::StringRef TypeName;
-  if (RSExportType::NormalizeType(T, TypeName, NULL, NULL, NULL) &&
-      IsPrimitiveType(T)) {
+  if (RSExportType::NormalizeType(T, TypeName, Context->getDiagnostics(), NULL)
+      && IsPrimitiveType(T)) {
     return Create(Context, T, TypeName, DK);
   } else {
     return NULL;
@@ -1163,7 +1272,6 @@
        FI != FE;
        FI++, Index++) {
     clang::Diagnostic *Diags = Context->getDiagnostics();
-    const clang::SourceManager *SM = Context->getSourceManager();
 
     // FIXME: All fields should be primitive type
     slangAssert((*FI)->getKind() == clang::Decl::Field);
@@ -1182,7 +1290,8 @@
           new Field(ET, FD->getName(), ERT,
                     static_cast<size_t>(RL->getFieldOffset(Index) >> 3)));
     } else {
-      Diags->Report(clang::FullSourceLoc(RD->getLocation(), *SM),
+      Diags->Report(clang::FullSourceLoc(RD->getLocation(),
+                                         Diags->getSourceManager()),
                     Diags->getCustomDiagID(clang::Diagnostic::Error,
                     "field type cannot be exported: '%0.%1'"))
           << RD->getName()