[cxx2a] P0614R1: Support init-statements in range-based for loops.

We don't yet support this for the case where a range-based for loop is
implicitly rewritten to an ObjC for..in statement; an init-statement in
such a loop is rejected with an error (err_objc_for_range_init_stmt).

llvm-svn: 343350
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index ed2fe4e..f0d1947 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -1647,6 +1647,8 @@
     void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
       // Only visit the initialization of a for loop; the body
       // has a different break/continue scope.
+      if (const Stmt *Init = S->getInit())
+        Visit(Init);
       if (const Stmt *Range = S->getRangeStmt())
         Visit(Range);
       if (const Stmt *Begin = S->getBeginStmt())
@@ -2065,15 +2067,20 @@
 /// The body of the loop is not available yet, since it cannot be analysed until
 /// we have determined the type of the for-range-declaration.
 StmtResult Sema::ActOnCXXForRangeStmt(Scope *S, SourceLocation ForLoc,
-                                      SourceLocation CoawaitLoc, Stmt *First,
-                                      SourceLocation ColonLoc, Expr *Range,
-                                      SourceLocation RParenLoc,
+                                      SourceLocation CoawaitLoc, Stmt *InitStmt,
+                                      Stmt *First, SourceLocation ColonLoc,
+                                      Expr *Range, SourceLocation RParenLoc,
                                       BuildForRangeKind Kind) {
   if (!First)
     return StmtError();
 
-  if (Range && ObjCEnumerationCollection(Range))
+  if (Range && ObjCEnumerationCollection(Range)) {
+    // FIXME: Support init-statements in Objective-C++20 ranged for statement.
+    if (InitStmt)
+      return Diag(InitStmt->getBeginLoc(), diag::err_objc_for_range_init_stmt)
+                 << InitStmt->getSourceRange();
     return ActOnObjCForCollectionStmt(ForLoc, First, Range, RParenLoc);
+  }
 
   DeclStmt *DS = dyn_cast<DeclStmt>(First);
   assert(DS && "first part of for range not a decl stmt");
@@ -2119,10 +2126,10 @@
     return StmtError();
   }
 
-  return BuildCXXForRangeStmt(ForLoc, CoawaitLoc, ColonLoc, RangeDecl.get(),
-                              /*BeginStmt=*/nullptr, /*EndStmt=*/nullptr,
-                              /*Cond=*/nullptr, /*Inc=*/nullptr,
-                              DS, RParenLoc, Kind);
+  return BuildCXXForRangeStmt(
+      ForLoc, CoawaitLoc, InitStmt, ColonLoc, RangeDecl.get(),
+      /*BeginStmt=*/nullptr, /*EndStmt=*/nullptr,
+      /*Cond=*/nullptr, /*Inc=*/nullptr, DS, RParenLoc, Kind);
 }
 
 /// Create the initialization, compare, and increment steps for
@@ -2270,6 +2277,7 @@
 static StmtResult RebuildForRangeWithDereference(Sema &SemaRef, Scope *S,
                                                  SourceLocation ForLoc,
                                                  SourceLocation CoawaitLoc,
+                                                 Stmt *InitStmt,
                                                  Stmt *LoopVarDecl,
                                                  SourceLocation ColonLoc,
                                                  Expr *Range,
@@ -2286,8 +2294,8 @@
       return StmtResult();
 
     StmtResult SR = SemaRef.ActOnCXXForRangeStmt(
-        S, ForLoc, CoawaitLoc, LoopVarDecl, ColonLoc, AdjustedRange.get(),
-        RParenLoc, Sema::BFRK_Check);
+        S, ForLoc, CoawaitLoc, InitStmt, LoopVarDecl, ColonLoc,
+        AdjustedRange.get(), RParenLoc, Sema::BFRK_Check);
     if (SR.isInvalid())
       return StmtResult();
   }
@@ -2297,9 +2305,9 @@
   // case there are any other (non-fatal) problems with it.
   SemaRef.Diag(RangeLoc, diag::err_for_range_dereference)
     << Range->getType() << FixItHint::CreateInsertion(RangeLoc, "*");
-  return SemaRef.ActOnCXXForRangeStmt(S, ForLoc, CoawaitLoc, LoopVarDecl,
-                                      ColonLoc, AdjustedRange.get(), RParenLoc,
-                                      Sema::BFRK_Rebuild);
+  return SemaRef.ActOnCXXForRangeStmt(
+      S, ForLoc, CoawaitLoc, InitStmt, LoopVarDecl, ColonLoc,
+      AdjustedRange.get(), RParenLoc, Sema::BFRK_Rebuild);
 }
 
 namespace {
@@ -2319,12 +2327,13 @@
 }
 
 /// BuildCXXForRangeStmt - Build or instantiate a C++11 for-range statement.
-StmtResult
-Sema::BuildCXXForRangeStmt(SourceLocation ForLoc, SourceLocation CoawaitLoc,
-                           SourceLocation ColonLoc, Stmt *RangeDecl,
-                           Stmt *Begin, Stmt *End, Expr *Cond,
-                           Expr *Inc, Stmt *LoopVarDecl,
-                           SourceLocation RParenLoc, BuildForRangeKind Kind) {
+StmtResult Sema::BuildCXXForRangeStmt(SourceLocation ForLoc,
+                                      SourceLocation CoawaitLoc, Stmt *InitStmt,
+                                      SourceLocation ColonLoc, Stmt *RangeDecl,
+                                      Stmt *Begin, Stmt *End, Expr *Cond,
+                                      Expr *Inc, Stmt *LoopVarDecl,
+                                      SourceLocation RParenLoc,
+                                      BuildForRangeKind Kind) {
   // FIXME: This should not be used during template instantiation. We should
   // pick up the set of unqualified lookup results for the != and + operators
   // in the initial parse.
@@ -2519,7 +2528,7 @@
         // If building the range failed, try dereferencing the range expression
         // unless a diagnostic was issued or the end function is problematic.
         StmtResult SR = RebuildForRangeWithDereference(*this, S, ForLoc,
-                                                       CoawaitLoc,
+                                                       CoawaitLoc, InitStmt,
                                                        LoopVarDecl, ColonLoc,
                                                        Range, RangeLoc,
                                                        RParenLoc);
@@ -2636,7 +2645,7 @@
     return StmtResult();
 
   return new (Context) CXXForRangeStmt(
-      RangeDS, cast_or_null<DeclStmt>(BeginDeclStmt.get()),
+      InitStmt, RangeDS, cast_or_null<DeclStmt>(BeginDeclStmt.get()),
       cast_or_null<DeclStmt>(EndDeclStmt.get()), NotEqExpr.get(),
       IncrExpr.get(), LoopVarDS, /*Body=*/nullptr, ForLoc, CoawaitLoc,
       ColonLoc, RParenLoc);
diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h
index 990e673..a75d307 100644
--- a/clang/lib/Sema/TreeTransform.h
+++ b/clang/lib/Sema/TreeTransform.h
@@ -2021,11 +2021,10 @@
   /// By default, performs semantic analysis to build the new statement.
   /// Subclasses may override this routine to provide different behavior.
   StmtResult RebuildCXXForRangeStmt(SourceLocation ForLoc,
-                                    SourceLocation CoawaitLoc,
-                                    SourceLocation ColonLoc,
-                                    Stmt *Range, Stmt *Begin, Stmt *End,
-                                    Expr *Cond, Expr *Inc,
-                                    Stmt *LoopVar,
+                                    SourceLocation CoawaitLoc, Stmt *Init,
+                                    SourceLocation ColonLoc, Stmt *Range,
+                                    Stmt *Begin, Stmt *End, Expr *Cond,
+                                    Expr *Inc, Stmt *LoopVar,
                                     SourceLocation RParenLoc) {
     // If we've just learned that the range is actually an Objective-C
     // collection, treat this as an Objective-C fast enumeration loop.
@@ -2037,17 +2036,24 @@
 
           Expr *RangeExpr = RangeVar->getInit();
           if (!RangeExpr->isTypeDependent() &&
-              RangeExpr->getType()->isObjCObjectPointerType())
-            return getSema().ActOnObjCForCollectionStmt(ForLoc, LoopVar, RangeExpr,
-                                                        RParenLoc);
+              RangeExpr->getType()->isObjCObjectPointerType()) {
+            // FIXME: Support init-statements in Objective-C++20 ranged for
+            // statement.
+            if (Init) {
+              return SemaRef.Diag(Init->getBeginLoc(),
+                                  diag::err_objc_for_range_init_stmt)
+                         << Init->getSourceRange();
+            }
+            return getSema().ActOnObjCForCollectionStmt(ForLoc, LoopVar,
+                                                        RangeExpr, RParenLoc);
+          }
         }
       }
     }
 
-    return getSema().BuildCXXForRangeStmt(ForLoc, CoawaitLoc, ColonLoc,
-                                          Range, Begin, End,
-                                          Cond, Inc, LoopVar, RParenLoc,
-                                          Sema::BFRK_Rebuild);
+    return getSema().BuildCXXForRangeStmt(ForLoc, CoawaitLoc, Init, ColonLoc,
+                                          Range, Begin, End, Cond, Inc, LoopVar,
+                                          RParenLoc, Sema::BFRK_Rebuild);
   }
 
   /// Build a new C++0x range-based for statement.
@@ -7412,6 +7418,11 @@
 template<typename Derived>
 StmtResult
 TreeTransform<Derived>::TransformCXXForRangeStmt(CXXForRangeStmt *S) {
+  StmtResult Init =
+      S->getInit() ? getDerived().TransformStmt(S->getInit()) : StmtResult();
+  if (Init.isInvalid())
+    return StmtError();
+
   StmtResult Range = getDerived().TransformStmt(S->getRangeStmt());
   if (Range.isInvalid())
     return StmtError();
@@ -7445,6 +7456,7 @@
 
   StmtResult NewStmt = S;
   if (getDerived().AlwaysRebuild() ||
+      Init.get() != S->getInit() ||
       Range.get() != S->getRangeStmt() ||
       Begin.get() != S->getBeginStmt() ||
       End.get() != S->getEndStmt() ||
@@ -7452,7 +7464,7 @@
       Inc.get() != S->getInc() ||
       LoopVar.get() != S->getLoopVarStmt()) {
     NewStmt = getDerived().RebuildCXXForRangeStmt(S->getForLoc(),
-                                                  S->getCoawaitLoc(),
+                                                  S->getCoawaitLoc(), Init.get(),
                                                   S->getColonLoc(), Range.get(),
                                                   Begin.get(), End.get(),
                                                   Cond.get(),
@@ -7470,7 +7482,7 @@
   // it now so we have a new statement to attach the body to.
   if (Body.get() != S->getBody() && NewStmt.get() == S) {
     NewStmt = getDerived().RebuildCXXForRangeStmt(S->getForLoc(),
-                                                  S->getCoawaitLoc(),
+                                                  S->getCoawaitLoc(), Init.get(),
                                                   S->getColonLoc(), Range.get(),
                                                   Begin.get(), End.get(),
                                                   Cond.get(),