[OPENMP] Add support for mapping memory pointed to by a member pointer.

Added support for mappings of the form map(s, s.ptr[0:1]).

llvm-svn: 342648
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 9452e96..2188172 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6752,7 +6752,9 @@
       MapBaseValuesArrayTy &BasePointers, MapValuesArrayTy &Pointers,
       MapValuesArrayTy &Sizes, MapFlagsArrayTy &Types,
       StructRangeInfoTy &PartialStruct, bool IsFirstComponentList,
-      bool IsImplicit) const {
+      bool IsImplicit,
+      ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
+          OverlappedElements = llvm::None) const {
     // The following summarizes what has to be generated for each map and the
     // types below. The generated information is expressed in this order:
     // base pointer, section pointer, size, flags
@@ -7023,7 +7025,6 @@
 
         Address LB =
             CGF.EmitOMPSharedLValue(I->getAssociatedExpression()).getAddress();
-        llvm::Value *Size = getExprTypeSize(I->getAssociatedExpression());
 
         // If this component is a pointer inside the base struct then we don't
         // need to create any entry for it - it will be combined with the object
@@ -7032,6 +7033,70 @@
             IsPointer && EncounteredME &&
             (dyn_cast<MemberExpr>(I->getAssociatedExpression()) ==
              EncounteredME);
+        if (!OverlappedElements.empty()) {
+          // Handle base element with the info for overlapped elements.
+          assert(!PartialStruct.Base.isValid() && "The base element is set.");
+          assert(Next == CE &&
+                 "Expected last element for the overlapped elements.");
+          assert(!IsPointer &&
+                 "Unexpected base element with the pointer type.");
+          // Mark the whole struct as the struct that requires allocation on the
+          // device.
+          PartialStruct.LowestElem = {0, LB};
+          CharUnits TypeSize = CGF.getContext().getTypeSizeInChars(
+              I->getAssociatedExpression()->getType());
+          Address HB = CGF.Builder.CreateConstGEP(
+              CGF.Builder.CreatePointerBitCastOrAddrSpaceCast(LB,
+                                                              CGF.VoidPtrTy),
+              TypeSize.getQuantity() - 1, CharUnits::One());
+          PartialStruct.HighestElem = {
+              std::numeric_limits<decltype(
+                  PartialStruct.HighestElem.first)>::max(),
+              HB};
+          PartialStruct.Base = BP;
+          // Emit data for non-overlapped data.
+          OpenMPOffloadMappingFlags Flags =
+              OMP_MAP_MEMBER_OF |
+              getMapTypeBits(MapType, MapTypeModifier, IsImplicit,
+                             /*AddPtrFlag=*/false,
+                             /*AddIsTargetParamFlag=*/false);
+          LB = BP;
+          llvm::Value *Size = nullptr;
+          // Do bitcopy of all non-overlapped structure elements.
+          for (OMPClauseMappableExprCommon::MappableExprComponentListRef
+                   Component : OverlappedElements) {
+            Address ComponentLB = Address::invalid();
+            for (const OMPClauseMappableExprCommon::MappableComponent &MC :
+                 Component) {
+              if (MC.getAssociatedDeclaration()) {
+                ComponentLB =
+                    CGF.EmitOMPSharedLValue(MC.getAssociatedExpression())
+                        .getAddress();
+                Size = CGF.Builder.CreatePtrDiff(
+                    CGF.EmitCastToVoidPtr(ComponentLB.getPointer()),
+                    CGF.EmitCastToVoidPtr(LB.getPointer()));
+                break;
+              }
+            }
+            BasePointers.push_back(BP.getPointer());
+            Pointers.push_back(LB.getPointer());
+            Sizes.push_back(Size);
+            Types.push_back(Flags);
+            LB = CGF.Builder.CreateConstGEP(ComponentLB, 1,
+                                            CGF.getPointerSize());
+          }
+          BasePointers.push_back(BP.getPointer());
+          Pointers.push_back(LB.getPointer());
+          Size = CGF.Builder.CreatePtrDiff(
+              CGF.EmitCastToVoidPtr(
+                  CGF.Builder.CreateConstGEP(HB, 1, CharUnits::One())
+                      .getPointer()),
+              CGF.EmitCastToVoidPtr(LB.getPointer()));
+          Sizes.push_back(Size);
+          Types.push_back(Flags);
+          break;
+        }
+        llvm::Value *Size = getExprTypeSize(I->getAssociatedExpression());
         if (!IsMemberPointer) {
           BasePointers.push_back(BP.getPointer());
           Pointers.push_back(LB.getPointer());
@@ -7136,6 +7201,66 @@
     Flags |= MemberOfFlag;
   }
 
+  void getPlainLayout(const CXXRecordDecl *RD,
+                      llvm::SmallVectorImpl<const FieldDecl *> &Layout,
+                      bool AsBase) const {
+    const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
+    // Use the base-subobject LLVM struct when RD is laid out as a base.
+    llvm::StructType *St =
+        AsBase ? RL.getBaseSubobjectLLVMType() : RL.getLLVMType();
+
+    unsigned NumElements = St->getNumElements();
+    llvm::SmallVector<
+        llvm::PointerUnion<const CXXRecordDecl *, const FieldDecl *>, 4>
+        RecordLayout(NumElements);
+    // One slot per LLVM struct element; slots never assigned stay null.
+    // Fill in the non-virtual bases.
+    for (const auto &I : RD->bases()) {
+      if (I.isVirtual())
+        continue;
+      const auto *Base = I.getType()->getAsCXXRecordDecl();
+      // Ignore empty bases and bases whose non-virtual size is zero.
+      if (Base->isEmpty() || CGF.getContext()
+                                 .getASTRecordLayout(Base)
+                                 .getNonVirtualSize()
+                                 .isZero())
+        continue;
+
+      unsigned FieldIndex = RL.getNonVirtualBaseLLVMFieldNo(Base);
+      RecordLayout[FieldIndex] = Base;
+    }
+    // Fill in virtual bases.
+    for (const auto &I : RD->vbases()) {
+      const auto *Base = I.getType()->getAsCXXRecordDecl();
+      // Ignore empty bases.
+      if (Base->isEmpty())
+        continue;
+      unsigned FieldIndex = RL.getVirtualBaseIndex(Base);
+      if (RecordLayout[FieldIndex]) // Slot already taken; keep its entry.
+        continue;
+      RecordLayout[FieldIndex] = Base;
+    }
+    // Fill in all the fields.
+    assert(!RD->isUnion() && "Unexpected union.");
+    for (const auto *Field : RD->fields()) {
+      // Record non-bitfields only. Bitfields get no slot entry here, so they
+      // are left out of Layout.
+      if (!Field->isBitField()) {
+        unsigned FieldIndex = RL.getLLVMFieldNo(Field);
+        RecordLayout[FieldIndex] = Field;
+      }
+    }
+    for (const llvm::PointerUnion<const CXXRecordDecl *, const FieldDecl *>
+             &Data : RecordLayout) {
+      if (Data.isNull())
+        continue;
+      if (const auto *Base = Data.dyn_cast<const CXXRecordDecl *>())
+        getPlainLayout(Base, Layout, /*AsBase=*/true);
+      else
+        Layout.push_back(Data.get<const FieldDecl *>());
+    }
+  }
+
 public:
   MappableExprsHandler(const OMPExecutableDirective &Dir, CodeGenFunction &CGF)
       : CurDir(Dir), CGF(CGF) {
@@ -7376,9 +7501,6 @@
            "Not expecting to generate map info for a variable array type!");
 
     // We need to know when we generating information for the first component
-    // associated with a capture, because the mapping flags depend on it.
-    bool IsFirstComponentList = true;
-
     const ValueDecl *VD = Cap->capturesThis()
                               ? nullptr
                               : Cap->getCapturedVar()->getCanonicalDecl();
@@ -7394,19 +7516,145 @@
       return;
     }
 
+    using MapData =
+        std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
+                   OpenMPMapClauseKind, OpenMPMapClauseKind, bool>;
+    SmallVector<MapData, 4> DeclComponentLists;
     // FIXME: MSVC 2013 seems to require this-> to find member CurDir.
-    for (const auto *C : this->CurDir.getClausesOfKind<OMPMapClause>())
+    for (const auto *C : this->CurDir.getClausesOfKind<OMPMapClause>()) {
       for (const auto &L : C->decl_component_lists(VD)) {
         assert(L.first == VD &&
                "We got information for the wrong declaration??");
         assert(!L.second.empty() &&
                "Not expecting declaration with no component lists.");
-        generateInfoForComponentList(C->getMapType(), C->getMapTypeModifier(),
-                                     L.second, BasePointers, Pointers, Sizes,
-                                     Types, PartialStruct, IsFirstComponentList,
-                                     C->isImplicit());
-        IsFirstComponentList = false;
+        DeclComponentLists.emplace_back(L.second, C->getMapType(),
+                                        C->getMapTypeModifier(),
+                                        C->isImplicit());
       }
+    }
+
+    // Find overlapping elements (including the offset from the base element).
+    llvm::SmallDenseMap<
+        const MapData *,
+        llvm::SmallVector<
+            OMPClauseMappableExprCommon::MappableExprComponentListRef, 4>,
+        4>
+        OverlappedData;
+    size_t Count = 0;
+    for (const MapData &L : DeclComponentLists) {
+      OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
+      OpenMPMapClauseKind MapType;
+      OpenMPMapClauseKind MapTypeModifier;
+      bool IsImplicit;
+      std::tie(Components, MapType, MapTypeModifier, IsImplicit) = L;
+      ++Count;
+      for (const MapData &L1 : makeArrayRef(DeclComponentLists).slice(Count)) {
+        OMPClauseMappableExprCommon::MappableExprComponentListRef Components1;
+        std::tie(Components1, MapType, MapTypeModifier, IsImplicit) = L1;
+        auto CI = Components.rbegin();
+        auto CE = Components.rend();
+        auto SI = Components1.rbegin();
+        auto SE = Components1.rend();
+        for (; CI != CE && SI != SE; ++CI, ++SI) {
+          if (CI->getAssociatedExpression()->getStmtClass() !=
+              SI->getAssociatedExpression()->getStmtClass())
+            break;
+          // Are we dealing with different variables/fields?
+          if (CI->getAssociatedDeclaration() != SI->getAssociatedDeclaration())
+            break;
+        }
+        // Found overlapping if, at least for one component, reached the head of
+        // the components list.
+        if (CI == CE || SI == SE) {
+          assert((CI != CE || SI != SE) &&
+                 "Unexpected full match of the mapping components.");
+          const MapData &BaseData = CI == CE ? L : L1;
+          OMPClauseMappableExprCommon::MappableExprComponentListRef SubData =
+              SI == SE ? Components : Components1;
+          auto It = CI == CE ? SI : CI;
+          auto &OverlappedElements = OverlappedData.FindAndConstruct(&BaseData);
+          OverlappedElements.getSecond().push_back(SubData);
+        }
+      }
+    }
+    // Sort the overlapped elements for each item.
+    llvm::SmallVector<const FieldDecl *, 4> Layout;
+    if (!OverlappedData.empty()) {
+      if (const auto *CRD =
+              VD->getType().getCanonicalType()->getAsCXXRecordDecl())
+        getPlainLayout(CRD, Layout, /*AsBase=*/false);
+      else {
+        const auto *RD = VD->getType().getCanonicalType()->getAsRecordDecl();
+        Layout.append(RD->field_begin(), RD->field_end());
+      }
+    }
+    for (auto &Pair : OverlappedData) {
+      llvm::sort(
+          Pair.getSecond(),
+          [&Layout](
+              OMPClauseMappableExprCommon::MappableExprComponentListRef First,
+              OMPClauseMappableExprCommon::MappableExprComponentListRef
+                  Second) {
+            auto CI = First.rbegin();
+            auto CE = First.rend();
+            auto SI = Second.rbegin();
+            auto SE = Second.rend();
+            for (; CI != CE && SI != SE; ++CI, ++SI) {
+              if (CI->getAssociatedExpression()->getStmtClass() !=
+                  SI->getAssociatedExpression()->getStmtClass())
+                break;
+              // Are we dealing with different variables/fields?
+              if (CI->getAssociatedDeclaration() !=
+                  SI->getAssociatedDeclaration())
+                break;
+            }
+            assert(CI != CE && SI != SE &&
+                   "Unexpected end of the map components.");
+            const auto *FD1 = cast<FieldDecl>(CI->getAssociatedDeclaration());
+            const auto *FD2 = cast<FieldDecl>(SI->getAssociatedDeclaration());
+            if (FD1->getParent() == FD2->getParent())
+              return FD1->getFieldIndex() < FD2->getFieldIndex();
+            const auto It =
+                llvm::find_if(Layout, [FD1, FD2](const FieldDecl *FD) {
+                  return FD == FD1 || FD == FD2;
+                });
+            return *It == FD1;
+          });
+    }
+
+    // The mapping flags depend on whether a component list is the first one
+    // for the capture. Go through the lists with overlapped elements first.
+    for (const auto &Pair : OverlappedData) {
+      const MapData &L = *Pair.getFirst();
+      OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
+      OpenMPMapClauseKind MapType;
+      OpenMPMapClauseKind MapTypeModifier;
+      bool IsImplicit;
+      std::tie(Components, MapType, MapTypeModifier, IsImplicit) = L;
+      ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
+          OverlappedComponents = Pair.getSecond();
+      bool IsFirstComponentList = true;
+      generateInfoForComponentList(MapType, MapTypeModifier, Components,
+                                   BasePointers, Pointers, Sizes, Types,
+                                   PartialStruct, IsFirstComponentList,
+                                   IsImplicit, OverlappedComponents);
+    }
+    // Go through other elements without overlapped elements.
+    bool IsFirstComponentList = OverlappedData.empty();
+    for (const MapData &L : DeclComponentLists) {
+      OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
+      OpenMPMapClauseKind MapType;
+      OpenMPMapClauseKind MapTypeModifier;
+      bool IsImplicit;
+      std::tie(Components, MapType, MapTypeModifier, IsImplicit) = L;
+      auto It = OverlappedData.find(&L);
+      if (It == OverlappedData.end())
+        generateInfoForComponentList(MapType, MapTypeModifier, Components,
+                                     BasePointers, Pointers, Sizes, Types,
+                                     PartialStruct, IsFirstComponentList,
+                                     IsImplicit);
+      IsFirstComponentList = false;
+    }
   }
 
   /// Generate the base pointers, section pointers, sizes and map types