[OpenMP] Support for the if-clause on the combined directive 'target parallel'.
The if-clause on the combined directive potentially applies to both the
'target' and the 'parallel' regions. Generating code for the if-clause on
the combined directive requires additional support because the expression
in the clause must be captured by the 'target' capture statement but not
by the 'parallel' capture statement. Note that the same situation arises
for other clauses, such as num_threads.
The OMPIfClause class now inherits from OMPClauseWithPreInit to support
capturing of expressions in the clause. A member, CaptureRegion, is added
to OMPClauseWithPreInit to indicate which captured statement (in this case
'target' but not 'parallel') captures these expressions.
To ensure correct codegen of captured expressions in the presence of
combined 'target' directives, OMPParallelScope was added to 'parallel'
codegen.
Reviewers: ABataev
Differential Revision: https://reviews.llvm.org/D28781
llvm-svn: 292437
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 5b21f580..a4622e1 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -1863,6 +1863,7 @@
OMPOrderedClause *OC = nullptr;
OMPScheduleClause *SC = nullptr;
SmallVector<OMPLinearClause *, 4> LCs;
+ SmallVector<OMPClauseWithPreInit *, 8> PICs;
// This is required for proper codegen.
for (auto *Clause : Clauses) {
if (isOpenMPPrivate(Clause->getClauseKind()) ||
@@ -1879,15 +1880,8 @@
}
DSAStack->setForceVarCapturing(/*V=*/false);
} else if (isParallelOrTaskRegion(DSAStack->getCurrentDirective())) {
- // Mark all variables in private list clauses as used in inner region.
- // Required for proper codegen of combined directives.
- // TODO: add processing for other clauses.
- if (auto *C = OMPClauseWithPreInit::get(Clause)) {
- if (auto *DS = cast_or_null<DeclStmt>(C->getPreInitStmt())) {
- for (auto *D : DS->decls())
- MarkVariableReferenced(D->getLocation(), cast<VarDecl>(D));
- }
- }
+ if (auto *C = OMPClauseWithPreInit::get(Clause))
+ PICs.push_back(C);
if (auto *C = OMPClauseWithPostUpdate::get(Clause)) {
if (auto *E = C->getPostUpdateExpr())
MarkDeclarationsReferencedInExpr(E);
@@ -1933,10 +1927,31 @@
return StmtError();
}
StmtResult SR = S;
- int ThisCaptureLevel =
- getOpenMPCaptureLevels(DSAStack->getCurrentDirective());
- while (--ThisCaptureLevel >= 0)
+ SmallVector<OpenMPDirectiveKind, 4> CaptureRegions;
+ getOpenMPCaptureRegions(CaptureRegions, DSAStack->getCurrentDirective());
+ for (auto ThisCaptureRegion : llvm::reverse(CaptureRegions)) {
+ // Mark all variables in private list clauses as used in inner region.
+ // Required for proper codegen of combined directives.
+ // TODO: add processing for other clauses.
+ if (isParallelOrTaskRegion(DSAStack->getCurrentDirective())) {
+ for (auto *C : PICs) {
+ OpenMPDirectiveKind CaptureRegion = C->getCaptureRegion();
+ // Find the particular capture region for the clause if the
+ // directive is a combined one with multiple capture regions.
+ // If the directive is not a combined one, the capture region
+ // associated with the clause is OMPD_unknown and is generated
+ // only once.
+ if (CaptureRegion == ThisCaptureRegion ||
+ CaptureRegion == OMPD_unknown) {
+ if (auto *DS = cast_or_null<DeclStmt>(C->getPreInitStmt())) {
+ for (auto *D : DS->decls())
+ MarkVariableReferenced(D->getLocation(), cast<VarDecl>(D));
+ }
+ }
+ }
+ }
SR = ActOnCapturedRegionEnd(SR.get());
+ }
return SR;
}
@@ -6611,6 +6626,137 @@
return Res;
}
+// A combined directive such as 'target parallel' has two captured regions, one
+// for 'target' and one for 'parallel'. For directive DKind and clause CKind
+// (with directive-name modifier NameModifier) return the region in which the
+// clause's expressions are captured; OMPD_unknown means do not capture. Only
+// OMPC_if is handled here: every other clause kind reaches llvm_unreachable.
+static OpenMPDirectiveKind
+getOpenMPCaptureRegionForClause(OpenMPDirectiveKind DKind,
+ OpenMPClauseKind CKind,
+ OpenMPDirectiveKind NameModifier) {
+ OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
+
+ switch (CKind) {
+ case OMPC_if:
+ switch (DKind) {
+ case OMPD_target_parallel:
+ // If this clause applies to the nested 'parallel' region (NameModifier is
+ // OMPD_unknown or OMPD_parallel), capture in 'target'; else do not capture.
+ if (NameModifier == OMPD_unknown || NameModifier == OMPD_parallel)
+ CaptureRegion = OMPD_target;
+ break;
+ case OMPD_cancel:
+ case OMPD_parallel:
+ case OMPD_parallel_sections:
+ case OMPD_parallel_for:
+ case OMPD_parallel_for_simd:
+ case OMPD_target:
+ case OMPD_target_simd:
+ case OMPD_target_parallel_for:
+ case OMPD_target_parallel_for_simd:
+ case OMPD_target_teams:
+ case OMPD_target_teams_distribute:
+ case OMPD_target_teams_distribute_simd:
+ case OMPD_target_teams_distribute_parallel_for:
+ case OMPD_target_teams_distribute_parallel_for_simd:
+ case OMPD_teams_distribute_parallel_for:
+ case OMPD_teams_distribute_parallel_for_simd:
+ case OMPD_distribute_parallel_for:
+ case OMPD_distribute_parallel_for_simd:
+ case OMPD_task:
+ case OMPD_taskloop:
+ case OMPD_taskloop_simd:
+ case OMPD_target_data:
+ case OMPD_target_enter_data:
+ case OMPD_target_exit_data:
+ case OMPD_target_update:
+ // No capture for these directives: CaptureRegion stays OMPD_unknown.
+ break;
+ case OMPD_threadprivate:
+ case OMPD_taskyield:
+ case OMPD_barrier:
+ case OMPD_taskwait:
+ case OMPD_cancellation_point:
+ case OMPD_flush:
+ case OMPD_declare_reduction:
+ case OMPD_declare_simd:
+ case OMPD_declare_target:
+ case OMPD_end_declare_target:
+ case OMPD_teams:
+ case OMPD_simd:
+ case OMPD_for:
+ case OMPD_for_simd:
+ case OMPD_sections:
+ case OMPD_section:
+ case OMPD_single:
+ case OMPD_master:
+ case OMPD_critical:
+ case OMPD_taskgroup:
+ case OMPD_distribute:
+ case OMPD_ordered:
+ case OMPD_atomic:
+ case OMPD_distribute_simd:
+ case OMPD_teams_distribute:
+ case OMPD_teams_distribute_simd:
+ llvm_unreachable("Unexpected OpenMP directive with if-clause");
+ case OMPD_unknown:
+ llvm_unreachable("Unknown OpenMP directive");
+ }
+ break;
+ case OMPC_schedule:
+ case OMPC_dist_schedule:
+ case OMPC_firstprivate:
+ case OMPC_lastprivate:
+ case OMPC_reduction:
+ case OMPC_linear:
+ case OMPC_default:
+ case OMPC_proc_bind:
+ case OMPC_final:
+ case OMPC_num_threads:
+ case OMPC_safelen:
+ case OMPC_simdlen:
+ case OMPC_collapse:
+ case OMPC_private:
+ case OMPC_shared:
+ case OMPC_aligned:
+ case OMPC_copyin:
+ case OMPC_copyprivate:
+ case OMPC_ordered:
+ case OMPC_nowait:
+ case OMPC_untied:
+ case OMPC_mergeable:
+ case OMPC_threadprivate:
+ case OMPC_flush:
+ case OMPC_read:
+ case OMPC_write:
+ case OMPC_update:
+ case OMPC_capture:
+ case OMPC_seq_cst:
+ case OMPC_depend:
+ case OMPC_device:
+ case OMPC_threads:
+ case OMPC_simd:
+ case OMPC_map:
+ case OMPC_num_teams:
+ case OMPC_thread_limit:
+ case OMPC_priority:
+ case OMPC_grainsize:
+ case OMPC_nogroup:
+ case OMPC_num_tasks:
+ case OMPC_hint:
+ case OMPC_defaultmap:
+ case OMPC_unknown:
+ case OMPC_uniform:
+ case OMPC_to:
+ case OMPC_from:
+ case OMPC_use_device_ptr:
+ case OMPC_is_device_ptr:
+ llvm_unreachable("Unexpected OpenMP clause.");
+ }
+ return CaptureRegion;
+}
+
OMPClause *Sema::ActOnOpenMPIfClause(OpenMPDirectiveKind NameModifier,
Expr *Condition, SourceLocation StartLoc,
SourceLocation LParenLoc,
@@ -6618,6 +6764,8 @@
SourceLocation ColonLoc,
SourceLocation EndLoc) {
Expr *ValExpr = Condition;
+ Stmt *HelperValStmt = nullptr;
+ OpenMPDirectiveKind CaptureRegion = OMPD_unknown;
if (!Condition->isValueDependent() && !Condition->isTypeDependent() &&
!Condition->isInstantiationDependent() &&
!Condition->containsUnexpandedParameterPack()) {
@@ -6626,10 +6774,20 @@
return nullptr;
ValExpr = MakeFullExpr(Val.get()).get();
+
+ OpenMPDirectiveKind DKind = DSAStack->getCurrentDirective();
+ CaptureRegion =
+ getOpenMPCaptureRegionForClause(DKind, OMPC_if, NameModifier);
+ if (CaptureRegion != OMPD_unknown) {
+ llvm::MapVector<Expr *, DeclRefExpr *> Captures;
+ ValExpr = tryBuildCapture(*this, ValExpr, Captures).get();
+ HelperValStmt = buildPreInits(Context, Captures);
+ }
}
- return new (Context) OMPIfClause(NameModifier, ValExpr, StartLoc, LParenLoc,
- NameModifierLoc, ColonLoc, EndLoc);
+ return new (Context)
+ OMPIfClause(NameModifier, ValExpr, HelperValStmt, CaptureRegion, StartLoc,
+ LParenLoc, NameModifierLoc, ColonLoc, EndLoc);
}
OMPClause *Sema::ActOnOpenMPFinalClause(Expr *Condition,