[OpenMP] Support for num_teams-clause on the 'target teams' directive.
The num_teams clause on the combined directive applies to the
'teams' region of this construct. Modify the OMPNumTeamsClause
class so that it carries a pre-init statement, and capture the
clause expression within the enclosing 'target' region.
Reviewers: ABataev
Differential Revision: https://reviews.llvm.org/D29085
llvm-svn: 293048
diff --git a/clang/lib/AST/OpenMPClause.cpp b/clang/lib/AST/OpenMPClause.cpp
index 55f5ca5..0dcf82d 100644
--- a/clang/lib/AST/OpenMPClause.cpp
+++ b/clang/lib/AST/OpenMPClause.cpp
@@ -52,6 +52,8 @@
return static_cast<const OMPIfClause *>(C);
case OMPC_num_threads:
return static_cast<const OMPNumThreadsClause *>(C);
+ case OMPC_num_teams:
+ return static_cast<const OMPNumTeamsClause *>(C);
case OMPC_default:
case OMPC_proc_bind:
case OMPC_final:
@@ -79,7 +81,6 @@
case OMPC_threads:
case OMPC_simd:
case OMPC_map:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index fa086a3..b122580 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -497,6 +497,7 @@
VisitOMPClauseList(C);
}
void OMPClauseProfiler::VisitOMPNumTeamsClause(const OMPNumTeamsClause *C) {
+ VistOMPClauseWithPreInit(C); // Profile the pre-init part before the num_teams expression. NOTE(review): 'Vist' (sic) presumably matches the helper's declared name in OMPClauseProfiler - confirm before "fixing" the spelling.
if (C->getNumTeams())
Profiler->VisitStmt(C->getNumTeams());
}
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 040bddc..f9474b3 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -6771,6 +6771,69 @@
llvm_unreachable("Unknown OpenMP directive");
}
break;
+ case OMPC_num_teams:
+ switch (DKind) {
+ case OMPD_target_teams:
+ CaptureRegion = OMPD_target;
+ break;
+ case OMPD_cancel:
+ case OMPD_parallel:
+ case OMPD_parallel_sections:
+ case OMPD_parallel_for:
+ case OMPD_parallel_for_simd:
+ case OMPD_target:
+ case OMPD_target_simd:
+ case OMPD_target_parallel:
+ case OMPD_target_parallel_for:
+ case OMPD_target_parallel_for_simd:
+ case OMPD_target_teams_distribute:
+ case OMPD_target_teams_distribute_simd:
+ case OMPD_target_teams_distribute_parallel_for:
+ case OMPD_target_teams_distribute_parallel_for_simd:
+ case OMPD_teams_distribute_parallel_for:
+ case OMPD_teams_distribute_parallel_for_simd:
+ case OMPD_distribute_parallel_for:
+ case OMPD_distribute_parallel_for_simd:
+ case OMPD_task:
+ case OMPD_taskloop:
+ case OMPD_taskloop_simd:
+ case OMPD_target_data:
+ case OMPD_target_enter_data:
+ case OMPD_target_exit_data:
+ case OMPD_target_update:
+ case OMPD_teams:
+ case OMPD_teams_distribute:
+ case OMPD_teams_distribute_simd:
+ // Do not capture num_teams-clause expressions.
+ break;
+ case OMPD_threadprivate:
+ case OMPD_taskyield:
+ case OMPD_barrier:
+ case OMPD_taskwait:
+ case OMPD_cancellation_point:
+ case OMPD_flush:
+ case OMPD_declare_reduction:
+ case OMPD_declare_simd:
+ case OMPD_declare_target:
+ case OMPD_end_declare_target:
+ case OMPD_simd:
+ case OMPD_for:
+ case OMPD_for_simd:
+ case OMPD_sections:
+ case OMPD_section:
+ case OMPD_single:
+ case OMPD_master:
+ case OMPD_critical:
+ case OMPD_taskgroup:
+ case OMPD_distribute:
+ case OMPD_ordered:
+ case OMPD_atomic:
+ case OMPD_distribute_simd:
+ llvm_unreachable("Unexpected OpenMP directive with num_teams-clause");
+ case OMPD_unknown:
+ llvm_unreachable("Unknown OpenMP directive");
+ }
+ break;
case OMPC_schedule:
case OMPC_dist_schedule:
case OMPC_firstprivate:
@@ -6804,7 +6867,6 @@
case OMPC_threads:
case OMPC_simd:
case OMPC_map:
- case OMPC_num_teams:
case OMPC_thread_limit:
case OMPC_priority:
case OMPC_grainsize:
@@ -10860,6 +10922,8 @@
SourceLocation LParenLoc,
SourceLocation EndLoc) {
Expr *ValExpr = NumTeams;
+ Stmt *HelperValStmt = nullptr;
+ OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
// OpenMP [teams Constrcut, Restrictions]
// The num_teams expression must evaluate to a positive integer value.
@@ -10867,7 +10931,16 @@
/*StrictlyPositive=*/true))
return nullptr;
- return new (Context) OMPNumTeamsClause(ValExpr, StartLoc, LParenLoc, EndLoc);
+ OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
+ CaptureRegion = getOpenMPCaptureRegionForClause(DKind, OMPC_num_teams);
+ if (CaptureRegion != OMPD_unknown) {
+ llvm::MapVector<Expr *, DeclRefExpr *> Captures;
+ ValExpr = tryBuildCapture(*this, ValExpr, Captures).get();
+ HelperValStmt = buildPreInits(Context, Captures);
+ }
+
+ return new (Context) OMPNumTeamsClause(ValExpr, HelperValStmt, CaptureRegion,
+ StartLoc, LParenLoc, EndLoc);
}
OMPClause *Sema::ActOnOpenMPThreadLimitClause(Expr *ThreadLimit,
diff --git a/clang/lib/Serialization/ASTReaderStmt.cpp b/clang/lib/Serialization/ASTReaderStmt.cpp
index 55eca50..ce136bf 100644
--- a/clang/lib/Serialization/ASTReaderStmt.cpp
+++ b/clang/lib/Serialization/ASTReaderStmt.cpp
@@ -2300,6 +2300,7 @@
}
void OMPClauseReader::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
+ VisitOMPClauseWithPreInit(C); // Must run before reading the num_teams expression: OMPClauseWriter::VisitOMPNumTeamsClause emits the pre-init part first.
C->setNumTeams(Reader->Record.readSubExpr());
C->setLParenLoc(Reader->ReadSourceLocation());
}
diff --git a/clang/lib/Serialization/ASTWriterStmt.cpp b/clang/lib/Serialization/ASTWriterStmt.cpp
index ea6f579..5d9080e 100644
--- a/clang/lib/Serialization/ASTWriterStmt.cpp
+++ b/clang/lib/Serialization/ASTWriterStmt.cpp
@@ -2067,6 +2067,7 @@
}
void OMPClauseWriter::VisitOMPNumTeamsClause(OMPNumTeamsClause *C) {
+ VisitOMPClauseWithPreInit(C); // Emitted before the num_teams expression; OMPClauseReader::VisitOMPNumTeamsClause reads the fields back in this same order.
Record.AddStmt(C->getNumTeams());
Record.AddSourceLocation(C->getLParenLoc());
}