[OPENMP] Add support for mapping memory pointed to by a member pointer.
Added support for mappings of the form map(s, s.ptr[0:1]).
llvm-svn: 342648
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 9452e96..2188172 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -6752,7 +6752,9 @@
MapBaseValuesArrayTy &BasePointers, MapValuesArrayTy &Pointers,
MapValuesArrayTy &Sizes, MapFlagsArrayTy &Types,
StructRangeInfoTy &PartialStruct, bool IsFirstComponentList,
- bool IsImplicit) const {
+ bool IsImplicit,
+ ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
+ OverlappedElements = llvm::None) const {
// The following summarizes what has to be generated for each map and the
// types below. The generated information is expressed in this order:
// base pointer, section pointer, size, flags
@@ -7023,7 +7025,6 @@
Address LB =
CGF.EmitOMPSharedLValue(I->getAssociatedExpression()).getAddress();
- llvm::Value *Size = getExprTypeSize(I->getAssociatedExpression());
// If this component is a pointer inside the base struct then we don't
// need to create any entry for it - it will be combined with the object
@@ -7032,6 +7033,70 @@
IsPointer && EncounteredME &&
(dyn_cast<MemberExpr>(I->getAssociatedExpression()) ==
EncounteredME);
+ if (!OverlappedElements.empty()) {
+ // Handle base element with the info for overlapped elements.
+ assert(!PartialStruct.Base.isValid() && "The base element is set.");
+ assert(Next == CE &&
+ "Expected last element for the overlapped elements.");
+ assert(!IsPointer &&
+ "Unexpected base element with the pointer type.");
+ // Mark the whole struct as the struct that requires allocation on the
+ // device.
+ PartialStruct.LowestElem = {0, LB};
+ CharUnits TypeSize = CGF.getContext().getTypeSizeInChars(
+ I->getAssociatedExpression()->getType());
+ Address HB = CGF.Builder.CreateConstGEP(
+ CGF.Builder.CreatePointerBitCastOrAddrSpaceCast(LB,
+ CGF.VoidPtrTy),
+ TypeSize.getQuantity() - 1, CharUnits::One());
+ PartialStruct.HighestElem = {
+ std::numeric_limits<decltype(
+ PartialStruct.HighestElem.first)>::max(),
+ HB};
+ PartialStruct.Base = BP;
+ // Emit data for non-overlapped data.
+ OpenMPOffloadMappingFlags Flags =
+ OMP_MAP_MEMBER_OF |
+ getMapTypeBits(MapType, MapTypeModifier, IsImplicit,
+ /*AddPtrFlag=*/false,
+ /*AddIsTargetParamFlag=*/false);
+ LB = BP;
+ llvm::Value *Size = nullptr;
+ // Do bitcopy of all non-overlapped structure elements.
+ for (OMPClauseMappableExprCommon::MappableExprComponentListRef
+ Component : OverlappedElements) {
+ Address ComponentLB = Address::invalid();
+ for (const OMPClauseMappableExprCommon::MappableComponent &MC :
+ Component) {
+ if (MC.getAssociatedDeclaration()) {
+ ComponentLB =
+ CGF.EmitOMPSharedLValue(MC.getAssociatedExpression())
+ .getAddress();
+ Size = CGF.Builder.CreatePtrDiff(
+ CGF.EmitCastToVoidPtr(ComponentLB.getPointer()),
+ CGF.EmitCastToVoidPtr(LB.getPointer()));
+ break;
+ }
+ }
+ BasePointers.push_back(BP.getPointer());
+ Pointers.push_back(LB.getPointer());
+ Sizes.push_back(Size);
+ Types.push_back(Flags);
+ LB = CGF.Builder.CreateConstGEP(ComponentLB, 1,
+ CGF.getPointerSize());
+ }
+ BasePointers.push_back(BP.getPointer());
+ Pointers.push_back(LB.getPointer());
+ Size = CGF.Builder.CreatePtrDiff(
+ CGF.EmitCastToVoidPtr(
+ CGF.Builder.CreateConstGEP(HB, 1, CharUnits::One())
+ .getPointer()),
+ CGF.EmitCastToVoidPtr(LB.getPointer()));
+ Sizes.push_back(Size);
+ Types.push_back(Flags);
+ break;
+ }
+ llvm::Value *Size = getExprTypeSize(I->getAssociatedExpression());
if (!IsMemberPointer) {
BasePointers.push_back(BP.getPointer());
Pointers.push_back(LB.getPointer());
@@ -7136,6 +7201,66 @@
Flags |= MemberOfFlag;
}
+ void getPlainLayout(const CXXRecordDecl *RD, // Record whose fields are appended to Layout.
+ llvm::SmallVectorImpl<const FieldDecl *> &Layout, // Out: fields in LLVM struct element order.
+ bool AsBase) const { // True: use RD's base-subobject LLVM type.
+ const CGRecordLayout &RL = CGF.getTypes().getCGRecordLayout(RD);
+
+ llvm::StructType *St =
+ AsBase ? RL.getBaseSubobjectLLVMType() : RL.getLLVMType();
+
+ unsigned NumElements = St->getNumElements();
+ llvm::SmallVector<
+ llvm::PointerUnion<const CXXRecordDecl *, const FieldDecl *>, 4>
+ RecordLayout(NumElements); // One slot per LLVM struct element; null until filled.
+
+ // Fill in non-virtual bases at their LLVM field numbers.
+ for (const auto &I : RD->bases()) {
+ if (I.isVirtual())
+ continue;
+ const auto *Base = I.getType()->getAsCXXRecordDecl();
+ // Ignore empty bases and bases whose non-virtual size is zero.
+ if (Base->isEmpty() || CGF.getContext()
+ .getASTRecordLayout(Base)
+ .getNonVirtualSize()
+ .isZero())
+ continue;
+
+ unsigned FieldIndex = RL.getNonVirtualBaseLLVMFieldNo(Base);
+ RecordLayout[FieldIndex] = Base;
+ }
+ // Fill in virtual bases. NOTE(review): also runs when AsBase is set; confirm getVirtualBaseIndex() stays below NumElements.
+ for (const auto &I : RD->vbases()) {
+ const auto *Base = I.getType()->getAsCXXRecordDecl();
+ // Ignore empty bases.
+ if (Base->isEmpty())
+ continue;
+ unsigned FieldIndex = RL.getVirtualBaseIndex(Base);
+ if (RecordLayout[FieldIndex]) // Slot already taken; keep the existing entry.
+ continue;
+ RecordLayout[FieldIndex] = Base;
+ }
+ // Fill in all the fields.
+ assert(!RD->isUnion() && "Unexpected union.");
+ for (const auto *Field : RD->fields()) {
+ // Record non-bitfields only. Bitfields are given no slot here, so they
+ // never appear in Layout.
+ if (!Field->isBitField()) {
+ unsigned FieldIndex = RL.getLLVMFieldNo(Field);
+ RecordLayout[FieldIndex] = Field;
+ }
+ }
+ for (const llvm::PointerUnion<const CXXRecordDecl *, const FieldDecl *>
+ &Data : RecordLayout) { // Walk the slots in index order.
+ if (Data.isNull()) // Slot not claimed by a base or a non-bitfield field.
+ continue;
+ if (const auto *Base = Data.dyn_cast<const CXXRecordDecl *>())
+ getPlainLayout(Base, Layout, /*AsBase=*/true); // Recurse so the base's fields land at this position.
+ else
+ Layout.push_back(Data.get<const FieldDecl *>());
+ }
+ }
+
public:
MappableExprsHandler(const OMPExecutableDirective &Dir, CodeGenFunction &CGF)
: CurDir(Dir), CGF(CGF) {
@@ -7376,9 +7501,6 @@
"Not expecting to generate map info for a variable array type!");
// We need to know when we generating information for the first component
- // associated with a capture, because the mapping flags depend on it.
- bool IsFirstComponentList = true;
-
const ValueDecl *VD = Cap->capturesThis()
? nullptr
: Cap->getCapturedVar()->getCanonicalDecl();
@@ -7394,19 +7516,145 @@
return;
}
+ using MapData =
+ std::tuple<OMPClauseMappableExprCommon::MappableExprComponentListRef,
+ OpenMPMapClauseKind, OpenMPMapClauseKind, bool>;
+ SmallVector<MapData, 4> DeclComponentLists;
// FIXME: MSVC 2013 seems to require this-> to find member CurDir.
- for (const auto *C : this->CurDir.getClausesOfKind<OMPMapClause>())
+ for (const auto *C : this->CurDir.getClausesOfKind<OMPMapClause>()) {
for (const auto &L : C->decl_component_lists(VD)) {
assert(L.first == VD &&
"We got information for the wrong declaration??");
assert(!L.second.empty() &&
"Not expecting declaration with no component lists.");
- generateInfoForComponentList(C->getMapType(), C->getMapTypeModifier(),
- L.second, BasePointers, Pointers, Sizes,
- Types, PartialStruct, IsFirstComponentList,
- C->isImplicit());
- IsFirstComponentList = false;
+ DeclComponentLists.emplace_back(L.second, C->getMapType(),
+ C->getMapTypeModifier(),
+ C->isImplicit());
}
+ }
+
+ // Find overlapping elements (including the offset from the base element).
+ llvm::SmallDenseMap<
+ const MapData *,
+ llvm::SmallVector<
+ OMPClauseMappableExprCommon::MappableExprComponentListRef, 4>,
+ 4>
+ OverlappedData;
+ size_t Count = 0;
+ for (const MapData &L : DeclComponentLists) {
+ OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
+ OpenMPMapClauseKind MapType;
+ OpenMPMapClauseKind MapTypeModifier;
+ bool IsImplicit;
+ std::tie(Components, MapType, MapTypeModifier, IsImplicit) = L;
+ ++Count;
+ for (const MapData &L1 : makeArrayRef(DeclComponentLists).slice(Count)) {
+ OMPClauseMappableExprCommon::MappableExprComponentListRef Components1;
+ std::tie(Components1, MapType, MapTypeModifier, IsImplicit) = L1;
+ auto CI = Components.rbegin();
+ auto CE = Components.rend();
+ auto SI = Components1.rbegin();
+ auto SE = Components1.rend();
+ for (; CI != CE && SI != SE; ++CI, ++SI) {
+ if (CI->getAssociatedExpression()->getStmtClass() !=
+ SI->getAssociatedExpression()->getStmtClass())
+ break;
+ // Are we dealing with different variables/fields?
+ if (CI->getAssociatedDeclaration() != SI->getAssociatedDeclaration())
+ break;
+ }
+ // Found overlapping if, at least for one component, reached the head of
+ // the components list.
+ if (CI == CE || SI == SE) {
+ assert((CI != CE || SI != SE) &&
+ "Unexpected full match of the mapping components.");
+ const MapData &BaseData = CI == CE ? L : L1;
+ OMPClauseMappableExprCommon::MappableExprComponentListRef SubData =
+ SI == SE ? Components : Components1;
+ auto It = CI == CE ? SI : CI;
+ auto &OverlappedElements = OverlappedData.FindAndConstruct(&BaseData);
+ OverlappedElements.getSecond().push_back(SubData);
+ }
+ }
+ }
+ // Sort the overlapped elements for each item.
+ llvm::SmallVector<const FieldDecl *, 4> Layout;
+ if (!OverlappedData.empty()) {
+ if (const auto *CRD =
+ VD->getType().getCanonicalType()->getAsCXXRecordDecl())
+ getPlainLayout(CRD, Layout, /*AsBase=*/false);
+ else {
+ const auto *RD = VD->getType().getCanonicalType()->getAsRecordDecl();
+ Layout.append(RD->field_begin(), RD->field_end());
+ }
+ }
+ for (auto &Pair : OverlappedData) {
+ llvm::sort(
+ Pair.getSecond(),
+ [&Layout](
+ OMPClauseMappableExprCommon::MappableExprComponentListRef First,
+ OMPClauseMappableExprCommon::MappableExprComponentListRef
+ Second) {
+ auto CI = First.rbegin();
+ auto CE = First.rend();
+ auto SI = Second.rbegin();
+ auto SE = Second.rend();
+ for (; CI != CE && SI != SE; ++CI, ++SI) {
+ if (CI->getAssociatedExpression()->getStmtClass() !=
+ SI->getAssociatedExpression()->getStmtClass())
+ break;
+ // Are we dealing with different variables/fields?
+ if (CI->getAssociatedDeclaration() !=
+ SI->getAssociatedDeclaration())
+ break;
+ }
+ assert(CI != CE && SI != SE &&
+ "Unexpected end of the map components.");
+ const auto *FD1 = cast<FieldDecl>(CI->getAssociatedDeclaration());
+ const auto *FD2 = cast<FieldDecl>(SI->getAssociatedDeclaration());
+ if (FD1->getParent() == FD2->getParent())
+ return FD1->getFieldIndex() < FD2->getFieldIndex();
+ const auto It =
+ llvm::find_if(Layout, [FD1, FD2](const FieldDecl *FD) {
+ return FD == FD1 || FD == FD2;
+ });
+ return *It == FD1;
+ });
+ }
+
+ // Track the first component list emitted for a capture; the mapping flags depend on it.
+ // Go through all of the elements with the overlapped elements.
+ for (const auto &Pair : OverlappedData) {
+ const MapData &L = *Pair.getFirst();
+ OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
+ OpenMPMapClauseKind MapType;
+ OpenMPMapClauseKind MapTypeModifier;
+ bool IsImplicit;
+ std::tie(Components, MapType, MapTypeModifier, IsImplicit) = L;
+ ArrayRef<OMPClauseMappableExprCommon::MappableExprComponentListRef>
+ OverlappedComponents = Pair.getSecond();
+ bool IsFirstComponentList = true;
+ generateInfoForComponentList(MapType, MapTypeModifier, Components,
+ BasePointers, Pointers, Sizes, Types,
+ PartialStruct, IsFirstComponentList,
+ IsImplicit, OverlappedComponents);
+ }
+ // Go through other elements without overlapped elements.
+ bool IsFirstComponentList = OverlappedData.empty();
+ for (const MapData &L : DeclComponentLists) {
+ OMPClauseMappableExprCommon::MappableExprComponentListRef Components;
+ OpenMPMapClauseKind MapType;
+ OpenMPMapClauseKind MapTypeModifier;
+ bool IsImplicit;
+ std::tie(Components, MapType, MapTypeModifier, IsImplicit) = L;
+ auto It = OverlappedData.find(&L);
+ if (It == OverlappedData.end())
+ generateInfoForComponentList(MapType, MapTypeModifier, Components,
+ BasePointers, Pointers, Sizes, Types,
+ PartialStruct, IsFirstComponentList,
+ IsImplicit);
+ IsFirstComponentList = false;
+ }
}
/// Generate the base pointers, section pointers, sizes and map types