[OPENMP] Additional checking for 'collapse' clause.

llvm-svn: 211589
diff --git a/clang/lib/Sema/SemaOpenMP.cpp b/clang/lib/Sema/SemaOpenMP.cpp
index 561251c..16def4e 100644
--- a/clang/lib/Sema/SemaOpenMP.cpp
+++ b/clang/lib/Sema/SemaOpenMP.cpp
@@ -1391,13 +1391,22 @@
 /// \brief Called on a for stmt to check and extract its iteration space
 /// for further processing (such as collapsing).
 static bool CheckOpenMPIterationSpace(OpenMPDirectiveKind DKind, Stmt *S,
-                                      Sema &SemaRef, DSAStackTy &DSA) {
+                                      Sema &SemaRef, DSAStackTy &DSA,
+                                      unsigned CurrentNestedLoopCount,
+                                      unsigned NestedLoopCount,
+                                      Expr *NestedLoopCountExpr) {
   // OpenMP [2.6, Canonical Loop Form]
   //   for (init-expr; test-expr; incr-expr) structured-block
   auto For = dyn_cast_or_null<ForStmt>(S);
   if (!For) {
     SemaRef.Diag(S->getLocStart(), diag::err_omp_not_for)
-        << getOpenMPDirectiveName(DKind);
+        << (NestedLoopCountExpr != nullptr) << getOpenMPDirectiveName(DKind)
+        << NestedLoopCount << (CurrentNestedLoopCount > 0)
+        << CurrentNestedLoopCount;
+    if (NestedLoopCount > 1)
+      SemaRef.Diag(NestedLoopCountExpr->getExprLoc(),
+                   diag::note_omp_collapse_expr)
+          << NestedLoopCountExpr->getSourceRange();
     return true;
   }
   assert(For->getBody());
@@ -1491,14 +1500,21 @@
 }
 
 /// \brief Called on a for stmt to check itself and nested loops (if any).
-static bool CheckOpenMPLoop(OpenMPDirectiveKind DKind, unsigned NestedLoopCount,
+static bool CheckOpenMPLoop(OpenMPDirectiveKind DKind, Expr *NestedLoopCountExpr,
                             Stmt *AStmt, Sema &SemaRef, DSAStackTy &DSA) {
+  unsigned NestedLoopCount = 1;
+  if (NestedLoopCountExpr) {
+    // Found 'collapse' clause - calculate collapse number.
+    llvm::APSInt Result;
+    if (NestedLoopCountExpr->EvaluateAsInt(Result, SemaRef.getASTContext()))
+      NestedLoopCount = Result.getLimitedValue();
+  }
   // This is helper routine for loop directives (e.g., 'for', 'simd',
   // 'for simd', etc.).
-  assert(NestedLoopCount == 1);
   Stmt *CurStmt = IgnoreContainerStmts(AStmt, true);
   for (unsigned Cnt = 0; Cnt < NestedLoopCount; ++Cnt) {
-    if (CheckOpenMPIterationSpace(DKind, CurStmt, SemaRef, DSA))
+    if (CheckOpenMPIterationSpace(DKind, CurStmt, SemaRef, DSA, Cnt,
+                                  NestedLoopCount, NestedLoopCountExpr))
       return true;
     // Move on to the next nested for loop, or to the loop body.
     CurStmt = IgnoreContainerStmts(cast<ForStmt>(CurStmt)->getBody(), false);
@@ -1509,12 +1525,29 @@
   return false;
 }
 
+namespace {
+struct OMPCollapseClauseFilter { // Predicate: accepts only 'collapse' clauses.
+  OMPCollapseClauseFilter() {}
+  bool operator()(const OMPClause *Clause) {
+    return OMPC_collapse == Clause->getClauseKind();
+  }
+};
+} // namespace
+
+static Expr *GetCollapseNumberExpr(ArrayRef<OMPClause *> Clauses) {
+  // Position a filtering iterator on the first 'collapse' clause, if any.
+  typedef OMPCollapseClauseFilter CollapseOnly;
+  OMPExecutableDirective::filtered_clause_iterator<CollapseOnly> It(Clauses);
+  // Yield that clause's loop-count expression, or nullptr when none matched.
+  return It ? cast<OMPCollapseClause>(*It)->getNumForLoops() : nullptr;
+}
+
 StmtResult Sema::ActOnOpenMPSimdDirective(ArrayRef<OMPClause *> Clauses,
                                           Stmt *AStmt, SourceLocation StartLoc,
                                           SourceLocation EndLoc) {
   // In presence of clause 'collapse', it will define the nested loops number.
-  // For now, pass default value of 1.
-  if (CheckOpenMPLoop(OMPD_simd, 1, AStmt, *this, *DSAStack))
+  if (CheckOpenMPLoop(OMPD_simd, GetCollapseNumberExpr(Clauses),
+                      AStmt, *this, *DSAStack))
     return StmtError();
 
   getCurFunction()->setHasBranchProtectedScope();
@@ -1525,8 +1558,8 @@
                                          Stmt *AStmt, SourceLocation StartLoc,
                                          SourceLocation EndLoc) {
   // In presence of clause 'collapse', it will define the nested loops number.
-  // For now, pass default value of 1.
-  if (CheckOpenMPLoop(OMPD_for, 1, AStmt, *this, *DSAStack))
+  if (CheckOpenMPLoop(OMPD_for, GetCollapseNumberExpr(Clauses),
+                      AStmt, *this, *DSAStack))
     return StmtError();
 
   getCurFunction()->setHasBranchProtectedScope();