[OPENMP50]Initial support for use_device_addr clause.
Summary:
Added parsing, semantic analysis, and serialization support for the
use_device_addr clause.
Reviewers: jdoerfert
Subscribers: yaxunl, guansong, arphaman, sstefan1, llvm-commits, cfe-commits, caomhin
Tags: #clang, #llvm
Differential Revision: https://reviews.llvm.org/D80404
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 14c4c78..fa1c80f 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -136,6 +136,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -227,6 +228,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -1198,6 +1200,53 @@
return new (Mem) OMPUseDevicePtrClause(Sizes);
}
+OMPUseDeviceAddrClause *
+OMPUseDeviceAddrClause::Create(const ASTContext &C, const OMPVarListLocTy &Locs,
+ ArrayRef<Expr *> Vars,
+ ArrayRef<ValueDecl *> Declarations,
+ MappableExprComponentListsRef ComponentLists) {
+ OMPMappableExprListSizeTy Sizes;
+ Sizes.NumVars = Vars.size();
+ Sizes.NumUniqueDeclarations = getUniqueDeclarationsTotalNumber(Declarations);
+ Sizes.NumComponentLists = ComponentLists.size();
+ Sizes.NumComponents = getComponentsTotalNumber(ComponentLists);
+
+ // We need to allocate:
+ // NumVars x Expr* - we have an original list expression for each clause
+ // list entry (the allocation below reserves exactly Sizes.NumVars slots).
+ // NumUniqueDeclarations x ValueDecl* - unique base declarations associated
+ // with each component list.
+ // (NumUniqueDeclarations + NumComponentLists) x unsigned - we specify the
+ // number of lists for each unique declaration and the size of each component
+ // list.
+ // NumComponents x MappableComponent - the total of all the components in all
+ // the lists.
+ void *Mem = C.Allocate(
+ totalSizeToAlloc<Expr *, ValueDecl *, unsigned,
+ OMPClauseMappableExprCommon::MappableComponent>(
+ Sizes.NumVars, Sizes.NumUniqueDeclarations,
+ Sizes.NumUniqueDeclarations + Sizes.NumComponentLists,
+ Sizes.NumComponents));
+
+ auto *Clause = new (Mem) OMPUseDeviceAddrClause(Locs, Sizes);
+
+ Clause->setVarRefs(Vars);
+ Clause->setClauseInfo(Declarations, ComponentLists);
+ return Clause;
+}
+
+OMPUseDeviceAddrClause *
+OMPUseDeviceAddrClause::CreateEmpty(const ASTContext &C,
+ const OMPMappableExprListSizeTy &Sizes) {
+ void *Mem = C.Allocate( // Same trailing-storage layout as in Create().
+ totalSizeToAlloc<Expr *, ValueDecl *, unsigned,
+ OMPClauseMappableExprCommon::MappableComponent>(
+ Sizes.NumVars, Sizes.NumUniqueDeclarations,
+ Sizes.NumUniqueDeclarations + Sizes.NumComponentLists,
+ Sizes.NumComponents));
+ return new (Mem) OMPUseDeviceAddrClause(Sizes); // Lists are not populated here.
+}
+
OMPIsDevicePtrClause *
OMPIsDevicePtrClause::Create(const ASTContext &C, const OMPVarListLocTy &Locs,
ArrayRef<Expr *> Vars,
@@ -1934,6 +1983,15 @@
}
}
+void OMPClausePrinter::VisitOMPUseDeviceAddrClause(
+ OMPUseDeviceAddrClause *Node) {
+ if (!Node->varlist_empty()) { // Emit nothing for an empty variable list.
+ OS << "use_device_addr";
+ VisitOMPClauseList(Node, '('); // '(' is passed as the list's opening character.
+ OS << ")";
+ }
+}
+
void OMPClausePrinter::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *Node) {
if (!Node->varlist_empty()) {
OS << "is_device_ptr";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index bd2eeb6..e573c04 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -784,6 +784,10 @@
const OMPUseDevicePtrClause *C) {
VisitOMPClauseList(C);
}
+void OMPClauseProfiler::VisitOMPUseDeviceAddrClause(
+ const OMPUseDeviceAddrClause *C) {
+ VisitOMPClauseList(C); // Same handling as the adjacent use_device_ptr and is_device_ptr profilers.
+}
void OMPClauseProfiler::VisitOMPIsDevicePtrClause(
const OMPIsDevicePtrClause *C) {
VisitOMPClauseList(C);
diff --git a/clang/lib/Basic/OpenMPKinds.cpp b/clang/lib/Basic/OpenMPKinds.cpp
index 8dddb66..a000e4d 100644
--- a/clang/lib/Basic/OpenMPKinds.cpp
+++ b/clang/lib/Basic/OpenMPKinds.cpp
@@ -163,6 +163,7 @@
case OMPC_hint:
case OMPC_uniform:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -411,6 +412,7 @@
case OMPC_hint:
case OMPC_uniform:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index d12aa65..ae4e340 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -4730,6 +4730,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
diff --git a/clang/lib/Parse/ParseOpenMP.cpp b/clang/lib/Parse/ParseOpenMP.cpp
index bd40e6b..5161c7d 100644
--- a/clang/lib/Parse/ParseOpenMP.cpp
+++ b/clang/lib/Parse/ParseOpenMP.cpp
@@ -2497,7 +2497,7 @@
/// in_reduction-clause | allocator-clause | allocate-clause |
/// acq_rel-clause | acquire-clause | release-clause | relaxed-clause |
/// depobj-clause | destroy-clause | detach-clause | inclusive-clause |
-/// exclusive-clause | uses_allocators-clause
+/// exclusive-clause | uses_allocators-clause | use_device_addr-clause
///
OMPClause *Parser::ParseOpenMPClause(OpenMPDirectiveKind DKind,
OpenMPClauseKind CKind, bool FirstClause) {
@@ -2663,6 +2663,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_allocate:
case OMPC_nontemporal:
@@ -3581,6 +3582,8 @@
/// 'from' '(' [ mapper '(' mapper-identifier ')' ':' ] list ')'
/// use_device_ptr-clause:
/// 'use_device_ptr' '(' list ')'
+/// use_device_addr-clause:
+/// 'use_device_addr' '(' list ')'
/// is_device_ptr-clause:
/// 'is_device_ptr' '(' list ')'
/// allocate-clause:
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index e556969..a60a047 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -5408,6 +5408,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_nontemporal:
case OMPC_order:
@@ -10165,12 +10166,18 @@
assert(isa<CapturedStmt>(AStmt) && "Captured statement expected");
- // OpenMP [2.10.1, Restrictions, p. 97]
- // At least one map clause must appear on the directive.
- if (!hasClauses(Clauses, OMPC_map, OMPC_use_device_ptr)) {
+ // OpenMP [2.12.2, target data Construct, Restrictions]
+ // At least one map, use_device_addr or use_device_ptr clause must appear on
+ // the directive.
+ if (!hasClauses(Clauses, OMPC_map, OMPC_use_device_ptr) &&
+ (LangOpts.OpenMP < 50 || !hasClauses(Clauses, OMPC_use_device_addr))) {
+ StringRef Expected;
+ if (LangOpts.OpenMP < 50)
+ Expected = "'map' or 'use_device_ptr'";
+ else
+ Expected = "'map', 'use_device_ptr', or 'use_device_addr'";
Diag(StartLoc, diag::err_omp_no_clause_for_directive)
- << "'map' or 'use_device_ptr'"
- << getOpenMPDirectiveName(OMPD_target_data);
+ << Expected << getOpenMPDirectiveName(OMPD_target_data);
return StmtError();
}
@@ -11535,6 +11542,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -12289,6 +12297,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -12731,6 +12740,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -12956,6 +12966,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_unified_address:
case OMPC_unified_shared_memory:
@@ -13195,6 +13206,7 @@
case OMPC_to:
case OMPC_from:
case OMPC_use_device_ptr:
+ case OMPC_use_device_addr:
case OMPC_is_device_ptr:
case OMPC_atomic_default_mem_order:
case OMPC_device_type:
@@ -13406,6 +13418,9 @@
case OMPC_use_device_ptr:
Res = ActOnOpenMPUseDevicePtrClause(VarList, Locs);
break;
+ case OMPC_use_device_addr:
+ Res = ActOnOpenMPUseDeviceAddrClause(VarList, Locs);
+ break;
case OMPC_is_device_ptr:
Res = ActOnOpenMPIsDevicePtrClause(VarList, Locs);
break;
@@ -18389,6 +18404,54 @@
MVLI.VarBaseDeclarations, MVLI.VarComponents);
}
+OMPClause *Sema::ActOnOpenMPUseDeviceAddrClause(ArrayRef<Expr *> VarList,
+ const OMPVarListLocTy &Locs) {
+ MappableVarListInfo MVLI(VarList);
+
+ for (Expr *RefExpr : VarList) {
+ assert(RefExpr && "NULL expr in OpenMP use_device_addr clause.");
+ SourceLocation ELoc;
+ SourceRange ERange;
+ Expr *SimpleRefExpr = RefExpr;
+ auto Res = getPrivateItem(*this, SimpleRefExpr, ELoc, ERange,
+ /*AllowArraySection=*/true);
+ if (Res.second) {
+ // Keep the expression unchanged; it will be analyzed later.
+ MVLI.ProcessedVarList.push_back(RefExpr);
+ }
+ ValueDecl *D = Res.first;
+ if (!D)
+ continue;
+ auto *VD = dyn_cast<VarDecl>(D);
+
+ // For a declaration that is not a VarDecl, build a capture initialized
+ // with the current list item value (WithInit=true).
+ DeclRefExpr *Ref = nullptr;
+ if (!VD)
+ Ref = buildCapture(*this, D, SimpleRefExpr, /*WithInit=*/true);
+ MVLI.ProcessedVarList.push_back(VD ? RefExpr->IgnoreParens() : Ref);
+
+ // We need to add a data sharing attribute for this variable to make sure it
+ // is correctly captured. A variable that shows up in a use_device_addr
+ // clause has properties similar to those of a firstprivate variable.
+ DSAStack->addDSA(D, RefExpr->IgnoreParens(), OMPC_firstprivate, Ref);
+
+ // Create a mappable component for the list item. List items in this clause
+ // only need a single component.
+ MVLI.VarBaseDeclarations.push_back(D);
+ MVLI.VarComponents.emplace_back();
+ MVLI.VarComponents.back().push_back(
+ OMPClauseMappableExprCommon::MappableComponent(SimpleRefExpr, D));
+ }
+
+ if (MVLI.ProcessedVarList.empty())
+ return nullptr;
+
+ return OMPUseDeviceAddrClause::Create(Context, Locs, MVLI.ProcessedVarList,
+ MVLI.VarBaseDeclarations,
+ MVLI.VarComponents);
+}
+
OMPClause *Sema::ActOnOpenMPIsDevicePtrClause(ArrayRef<Expr *> VarList,
const OMPVarListLocTy &Locs) {
MappableVarListInfo MVLI(VarList);
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 923792f..e4c7155 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2036,6 +2036,15 @@
return getSema().ActOnOpenMPUseDevicePtrClause(VarList, Locs);
}
+ /// Build a new OpenMP 'use_device_addr' clause.
+ ///
+ /// By default, performs semantic analysis via ActOnOpenMPUseDeviceAddrClause.
+ /// Subclasses may override this routine to provide different behavior.
+ OMPClause *RebuildOMPUseDeviceAddrClause(ArrayRef<Expr *> VarList,
+ const OMPVarListLocTy &Locs) {
+ return getSema().ActOnOpenMPUseDeviceAddrClause(VarList, Locs);
+ }
+
/// Build a new OpenMP 'is_device_ptr' clause.
///
/// By default, performs semantic analysis to build the new OpenMP clause.
@@ -9741,6 +9750,21 @@
}
template <typename Derived>
+OMPClause *TreeTransform<Derived>::TransformOMPUseDeviceAddrClause(
+ OMPUseDeviceAddrClause *C) {
+ llvm::SmallVector<Expr *, 16> Vars;
+ Vars.reserve(C->varlist_size());
+ for (auto *VE : C->varlists()) {
+ ExprResult EVar = getDerived().TransformExpr(cast<Expr>(VE));
+ if (EVar.isInvalid())
+ return nullptr; // Give up as soon as one list item fails to transform.
+ Vars.push_back(EVar.get());
+ }
+ OMPVarListLocTy Locs(C->getBeginLoc(), C->getLParenLoc(), C->getEndLoc()); // Reuse the original clause's locations.
+ return getDerived().RebuildOMPUseDeviceAddrClause(Vars, Locs);
+}
+
+template <typename Derived>
OMPClause *
TreeTransform<Derived>::TransformOMPIsDevicePtrClause(OMPIsDevicePtrClause *C) {
llvm::SmallVector<Expr *, 16> Vars;
diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp
index 16bcb18..a5a1276 100644
--- a/clang/lib/Serialization/ASTReader.cpp
+++ b/clang/lib/Serialization/ASTReader.cpp
@@ -11918,6 +11918,15 @@
C = OMPUseDevicePtrClause::CreateEmpty(Context, Sizes);
break;
}
+ case llvm::omp::OMPC_use_device_addr: {
+ OMPMappableExprListSizeTy Sizes;
+ Sizes.NumVars = Record.readInt();
+ Sizes.NumUniqueDeclarations = Record.readInt();
+ Sizes.NumComponentLists = Record.readInt();
+ Sizes.NumComponents = Record.readInt();
+ C = OMPUseDeviceAddrClause::CreateEmpty(Context, Sizes);
+ break;
+ }
case llvm::omp::OMPC_is_device_ptr: {
OMPMappableExprListSizeTy Sizes;
Sizes.NumVars = Record.readInt();
@@ -12704,6 +12713,48 @@
C->setComponents(Components, ListSizes);
}
+void OMPClauseReader::VisitOMPUseDeviceAddrClause(OMPUseDeviceAddrClause *C) {
+ C->setLParenLoc(Record.readSourceLocation());
+ auto NumVars = C->varlist_size(); // The four counts below come from the clause itself, not from Record.
+ auto UniqueDecls = C->getUniqueDeclarationsNum();
+ auto TotalLists = C->getTotalComponentListNum();
+ auto TotalComponents = C->getTotalComponentsNum();
+
+ SmallVector<Expr *, 16> Vars;
+ Vars.reserve(NumVars);
+ for (unsigned i = 0; i != NumVars; ++i)
+ Vars.push_back(Record.readSubExpr());
+ C->setVarRefs(Vars);
+
+ SmallVector<ValueDecl *, 16> Decls;
+ Decls.reserve(UniqueDecls);
+ for (unsigned i = 0; i < UniqueDecls; ++i)
+ Decls.push_back(Record.readDeclAs<ValueDecl>());
+ C->setUniqueDecls(Decls);
+
+ SmallVector<unsigned, 16> ListsPerDecl;
+ ListsPerDecl.reserve(UniqueDecls);
+ for (unsigned i = 0; i < UniqueDecls; ++i)
+ ListsPerDecl.push_back(Record.readInt());
+ C->setDeclNumLists(ListsPerDecl);
+
+ SmallVector<unsigned, 32> ListSizes;
+ ListSizes.reserve(TotalLists);
+ for (unsigned i = 0; i < TotalLists; ++i)
+ ListSizes.push_back(Record.readInt());
+ C->setComponentListSizes(ListSizes);
+
+ SmallVector<OMPClauseMappableExprCommon::MappableComponent, 32> Components;
+ Components.reserve(TotalComponents);
+ for (unsigned i = 0; i < TotalComponents; ++i) {
+ Expr *AssociatedExpr = Record.readSubExpr(); // Each component is read as an (expression, declaration) pair.
+ auto *AssociatedDecl = Record.readDeclAs<ValueDecl>();
+ Components.push_back(OMPClauseMappableExprCommon::MappableComponent(
+ AssociatedExpr, AssociatedDecl));
+ }
+ C->setComponents(Components, ListSizes); // Reuses the per-list sizes read above.
+}
+
void OMPClauseReader::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *C) {
C->setLParenLoc(Record.readSourceLocation());
auto NumVars = C->varlist_size();
diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp
index 1e3adb5..9d81e13 100644
--- a/clang/lib/Serialization/ASTWriter.cpp
+++ b/clang/lib/Serialization/ASTWriter.cpp
@@ -6625,6 +6625,26 @@
}
}
+void OMPClauseWriter::VisitOMPUseDeviceAddrClause(OMPUseDeviceAddrClause *C) {
+ Record.push_back(C->varlist_size());
+ Record.push_back(C->getUniqueDeclarationsNum());
+ Record.push_back(C->getTotalComponentListNum());
+ Record.push_back(C->getTotalComponentsNum()); // The four counts precede the clause contents in the record.
+ Record.AddSourceLocation(C->getLParenLoc());
+ for (auto *E : C->varlists())
+ Record.AddStmt(E);
+ for (auto *D : C->all_decls())
+ Record.AddDeclRef(D);
+ for (auto N : C->all_num_lists())
+ Record.push_back(N);
+ for (auto N : C->all_lists_sizes())
+ Record.push_back(N);
+ for (auto &M : C->all_components()) {
+ Record.AddStmt(M.getAssociatedExpression()); // Each component is written as an (expression, declaration) pair.
+ Record.AddDeclRef(M.getAssociatedDeclaration());
+ }
+}
+
void OMPClauseWriter::VisitOMPIsDevicePtrClause(OMPIsDevicePtrClause *C) {
Record.push_back(C->varlist_size());
Record.push_back(C->getUniqueDeclarationsNum());