[OPENMP] Additional checking for the 'collapse' clause: verify the required number of nested for loops.
llvm-svn: 211589
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 561251c..16def4e 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -1391,13 +1391,22 @@
/// \brief Called on a for stmt to check and extract its iteration space
/// for further processing (such as collapsing).
static bool CheckOpenMPIterationSpace(OpenMPDirectiveKind DKind, Stmt *S,
- Sema &SemaRef, DSAStackTy &DSA) {
+ Sema &SemaRef, DSAStackTy &DSA,
+ unsigned CurrentNestedLoopCount, // 0-based depth of S in the nest.
+ unsigned NestedLoopCount, // Number of loops that must be present.
+ Expr *NestedLoopCountExpr) { // 'collapse' argument, or null if none.
// OpenMP [2.6, Canonical Loop Form]
// for (init-expr; test-expr; incr-expr) structured-block
auto For = dyn_cast_or_null<ForStmt>(S);
if (!For) {
SemaRef.Diag(S->getLocStart(), diag::err_omp_not_for)
- << getOpenMPDirectiveName(DKind);
+ << (NestedLoopCountExpr != nullptr) << getOpenMPDirectiveName(DKind)
+ << NestedLoopCount << (CurrentNestedLoopCount > 0)
+ << CurrentNestedLoopCount;
+ if (NestedLoopCount > 1 && NestedLoopCountExpr)
+ SemaRef.Diag(NestedLoopCountExpr->getExprLoc(),
+ diag::note_omp_collapse_expr)
+ << NestedLoopCountExpr->getSourceRange();
return true;
}
assert(For->getBody());
@@ -1491,14 +1500,21 @@
}
/// \brief Called on a for stmt to check itself and nested loops (if any).
-static bool CheckOpenMPLoop(OpenMPDirectiveKind DKind, unsigned NestedLoopCount,
+static bool CheckOpenMPLoop(OpenMPDirectiveKind DKind, Expr *NestedLoopCountExpr,
Stmt *AStmt, Sema &SemaRef, DSAStackTy &DSA) {
+ unsigned NestedLoopCount = 1;
+ // Use the 'collapse' value when present and evaluable; a value-dependent
+ // expression (e.g. a template parameter) keeps the default of 1 for now.
+ llvm::APSInt Result;
+ if (NestedLoopCountExpr && !NestedLoopCountExpr->isValueDependent() &&
+ NestedLoopCountExpr->EvaluateAsInt(Result, SemaRef.getASTContext()))
+ NestedLoopCount = Result.getLimitedValue(~0U);
// This is helper routine for loop directives (e.g., 'for', 'simd',
// 'for simd', etc.).
- assert(NestedLoopCount == 1);
Stmt *CurStmt = IgnoreContainerStmts(AStmt, true);
for (unsigned Cnt = 0; Cnt < NestedLoopCount; ++Cnt) {
- if (CheckOpenMPIterationSpace(DKind, CurStmt, SemaRef, DSA))
+ if (CheckOpenMPIterationSpace(DKind, CurStmt, SemaRef, DSA, Cnt,
+ NestedLoopCount, NestedLoopCountExpr))
return true;
// Move on to the next nested for loop, or to the loop body.
CurStmt = IgnoreContainerStmts(cast<ForStmt>(CurStmt)->getBody(), false);
@@ -1509,12 +1525,29 @@
return false;
}
+namespace {
+struct OMPCollapseClauseFilter { // Matches only 'collapse' clauses.
+ OMPCollapseClauseFilter() {}
+ bool operator()(const OMPClause *C) const {
+ return C->getClauseKind() == OMPC_collapse;
+ }
+};
+} // namespace
+
+/// Returns the loop-count expression of the first 'collapse' clause in
+/// \p Clauses, or null if there is no such clause.
+static Expr *GetCollapseNumberExpr(ArrayRef<OMPClause *> Clauses) {
+ OMPExecutableDirective::filtered_clause_iterator<OMPCollapseClauseFilter> I(
+ Clauses);
+ return I ? cast<OMPCollapseClause>(*I)->getNumForLoops() : nullptr;
+}
+
StmtResult Sema::ActOnOpenMPSimdDirective(ArrayRef<OMPClause *> Clauses,
Stmt *AStmt, SourceLocation StartLoc,
SourceLocation EndLoc) {
// In presence of clause 'collapse', it will define the nested loops number.
- // For now, pass default value of 1.
- if (CheckOpenMPLoop(OMPD_simd, 1, AStmt, *this, *DSAStack))
+ if (CheckOpenMPLoop(OMPD_simd, GetCollapseNumberExpr(Clauses),
+ AStmt, *this, *DSAStack))
return StmtError();
getCurFunction()->setHasBranchProtectedScope();
@@ -1525,8 +1558,8 @@
Stmt *AStmt, SourceLocation StartLoc,
SourceLocation EndLoc) {
// In presence of clause 'collapse', it will define the nested loops number.
- // For now, pass default value of 1.
- if (CheckOpenMPLoop(OMPD_for, 1, AStmt, *this, *DSAStack))
+ Expr *CollapseExpr = GetCollapseNumberExpr(Clauses);
+ if (CheckOpenMPLoop(OMPD_for, CollapseExpr, AStmt, *this, *DSAStack))
return StmtError();
getCurFunction()->setHasBranchProtectedScope();