[coroutines] Fix co_await in range-based for statements

Summary:
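Currently we build the co_await expressions on the wrong implicit statements of the implicitly generated range-based for loop. Specifically, we build the co_await expression wrapping the range initializer (the expression that initializes `__range`), but it should wrap the begin expression that initializes `__begin`. The increment expression (`++__begin`) is co_awaited as well.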

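This patch fixes `for co_await` by moving the implicit co_await so that it wraps the begin expression, for both array and non-array ranges. Previously, the co_await on the range initializer also set up the coroutine state. To keep that setup happening during the initial parse rather than later during template instantiation, the static helper `actOnCoroutineBodyStart` becomes the Sema member `ActOnCoroutineBodyStart`, which the range-based for action now calls directly.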

Reviewers: rsmith, GorNishanov

Reviewed By: GorNishanov

Subscribers: cfe-commits

Differential Revision: https://reviews.llvm.org/D34021

llvm-svn: 305363
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 06ae660..b05c099 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -470,11 +470,11 @@
   return ScopeInfo;
 }
 
-static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc,
-                                    StringRef Keyword) {
-  if (!checkCoroutineContext(S, KWLoc, Keyword))
+bool Sema::ActOnCoroutineBodyStart(Scope *SC, SourceLocation KWLoc,
+                                   StringRef Keyword) {
+  if (!checkCoroutineContext(*this, KWLoc, Keyword))
     return false;
-  auto *ScopeInfo = S.getCurFunction();
+  auto *ScopeInfo = getCurFunction();
   assert(ScopeInfo->CoroutinePromise);
 
   // If we have existing coroutine statements then we have already built
@@ -484,24 +484,24 @@
 
   ScopeInfo->setNeedsCoroutineSuspends(false);
 
-  auto *Fn = cast<FunctionDecl>(S.CurContext);
+  auto *Fn = cast<FunctionDecl>(CurContext);
   SourceLocation Loc = Fn->getLocation();
   // Build the initial suspend point
   auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
     ExprResult Suspend =
-        buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None);
+        buildPromiseCall(*this, ScopeInfo->CoroutinePromise, Loc, Name, None);
     if (Suspend.isInvalid())
       return StmtError();
-    Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get());
+    Suspend = buildOperatorCoawaitCall(*this, SC, Loc, Suspend.get());
     if (Suspend.isInvalid())
       return StmtError();
-    Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(),
-                                         /*IsImplicit*/ true);
-    Suspend = S.ActOnFinishFullExpr(Suspend.get());
+    Suspend = BuildResolvedCoawaitExpr(Loc, Suspend.get(),
+                                       /*IsImplicit*/ true);
+    Suspend = ActOnFinishFullExpr(Suspend.get());
     if (Suspend.isInvalid()) {
-      S.Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
+      Diag(Loc, diag::note_coroutine_promise_suspend_implicitly_required)
           << ((Name == "initial_suspend") ? 0 : 1);
-      S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
+      Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
       return StmtError();
     }
     return cast<Stmt>(Suspend.get());
@@ -521,7 +521,7 @@
 }
 
 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
-  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) {
+  if (!ActOnCoroutineBodyStart(S, Loc, "co_await")) {
     CorrectDelayedTyposInExpr(E);
     return ExprError();
   }
@@ -613,7 +613,7 @@
 }
 
 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
-  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) {
+  if (!ActOnCoroutineBodyStart(S, Loc, "co_yield")) {
     CorrectDelayedTyposInExpr(E);
     return ExprError();
   }
@@ -658,14 +658,15 @@
   if (RSS.IsInvalid)
     return ExprError();
 
-  Expr *Res = new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
-                                        RSS.Results[2], RSS.OpaqueValue);
+  Expr *Res =
+      new (Context) CoyieldExpr(Loc, E, RSS.Results[0], RSS.Results[1],
+                                RSS.Results[2], RSS.OpaqueValue);
 
   return Res;
 }
 
 StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
-  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) {
+  if (!ActOnCoroutineBodyStart(S, Loc, "co_return")) {
     CorrectDelayedTyposInExpr(E);
     return StmtError();
   }
diff --git a/clang/lib/Sema/SemaStmt.cpp b/clang/lib/Sema/SemaStmt.cpp
index 0c35f9a..eed10b0 100644
--- a/clang/lib/Sema/SemaStmt.cpp
+++ b/clang/lib/Sema/SemaStmt.cpp
@@ -1989,11 +1989,11 @@
     return StmtError();
   }
 
-  // Coroutines: 'for co_await' implicitly co_awaits its range.
-  if (CoawaitLoc.isValid()) {
-    ExprResult Coawait = ActOnCoawaitExpr(S, CoawaitLoc, Range);
-    if (Coawait.isInvalid()) return StmtError();
-    Range = Coawait.get();
+  // Build the coroutine state immediately and not later during template
+  // instantiation
+  if (!CoawaitLoc.isInvalid()) {
+    if (!ActOnCoroutineBodyStart(S, CoawaitLoc, "co_await"))
+      return StmtError();
   }
 
   // Build  auto && __range = range-init
@@ -2031,16 +2031,12 @@
 /// BeginExpr and EndExpr are set and FRS_Success is returned on success;
 /// CandidateSet and BEF are set and some non-success value is returned on
 /// failure.
-static Sema::ForRangeStatus BuildNonArrayForRange(Sema &SemaRef,
-                                            Expr *BeginRange, Expr *EndRange,
-                                            QualType RangeType,
-                                            VarDecl *BeginVar,
-                                            VarDecl *EndVar,
-                                            SourceLocation ColonLoc,
-                                            OverloadCandidateSet *CandidateSet,
-                                            ExprResult *BeginExpr,
-                                            ExprResult *EndExpr,
-                                            BeginEndFunction *BEF) {
+static Sema::ForRangeStatus
+BuildNonArrayForRange(Sema &SemaRef, Expr *BeginRange, Expr *EndRange,
+                      QualType RangeType, VarDecl *BeginVar, VarDecl *EndVar,
+                      SourceLocation ColonLoc, SourceLocation CoawaitLoc,
+                      OverloadCandidateSet *CandidateSet, ExprResult *BeginExpr,
+                      ExprResult *EndExpr, BeginEndFunction *BEF) {
   DeclarationNameInfo BeginNameInfo(
       &SemaRef.PP.getIdentifierTable().get("begin"), ColonLoc);
   DeclarationNameInfo EndNameInfo(&SemaRef.PP.getIdentifierTable().get("end"),
@@ -2087,6 +2083,15 @@
           << ColonLoc << BEF_begin << BeginRange->getType();
     return RangeStatus;
   }
+  if (!CoawaitLoc.isInvalid()) {
+    // FIXME: getCurScope() should not be used during template instantiation.
+    // We should pick up the set of unqualified lookup results for operator
+    // co_await during the initial parse.
+    *BeginExpr = SemaRef.ActOnCoawaitExpr(SemaRef.getCurScope(), ColonLoc,
+                                          BeginExpr->get());
+    if (BeginExpr->isInvalid())
+      return Sema::FRS_DiagnosticIssued;
+  }
   if (FinishForRangeVarDecl(SemaRef, BeginVar, BeginExpr->get(), ColonLoc,
                             diag::err_for_range_iter_deduction_failure)) {
     NoteForRangeBeginEndFunction(SemaRef, BeginExpr->get(), *BEF);
@@ -2253,6 +2258,11 @@
 
       // begin-expr is __range.
       BeginExpr = BeginRangeRef;
+      if (!CoawaitLoc.isInvalid()) {
+        BeginExpr = ActOnCoawaitExpr(S, ColonLoc, BeginExpr.get());
+        if (BeginExpr.isInvalid())
+          return StmtError();
+      }
       if (FinishForRangeVarDecl(*this, BeginVar, BeginRangeRef.get(), ColonLoc,
                                 diag::err_for_range_iter_deduction_failure)) {
         NoteForRangeBeginEndFunction(*this, BeginExpr.get(), BEF_begin);
@@ -2335,11 +2345,10 @@
       OverloadCandidateSet CandidateSet(RangeLoc,
                                         OverloadCandidateSet::CSK_Normal);
       BeginEndFunction BEFFailure;
-      ForRangeStatus RangeStatus =
-          BuildNonArrayForRange(*this, BeginRangeRef.get(),
-                                EndRangeRef.get(), RangeType,
-                                BeginVar, EndVar, ColonLoc, &CandidateSet,
-                                &BeginExpr, &EndExpr, &BEFFailure);
+      ForRangeStatus RangeStatus = BuildNonArrayForRange(
+          *this, BeginRangeRef.get(), EndRangeRef.get(), RangeType, BeginVar,
+          EndVar, ColonLoc, CoawaitLoc, &CandidateSet, &BeginExpr, &EndExpr,
+          &BEFFailure);
 
       if (Kind == BFRK_Build && RangeStatus == FRS_NoViableFunction &&
           BEFFailure == BEF_begin) {
@@ -2436,6 +2445,9 @@
 
     IncrExpr = ActOnUnaryOp(S, ColonLoc, tok::plusplus, BeginRef.get());
     if (!IncrExpr.isInvalid() && CoawaitLoc.isValid())
+      // FIXME: getCurScope() should not be used during template instantiation.
+      // We should pick up the set of unqualified lookup results for operator
+      // co_await during the initial parse.
       IncrExpr = ActOnCoawaitExpr(S, CoawaitLoc, IncrExpr.get());
     if (!IncrExpr.isInvalid())
       IncrExpr = ActOnFinishFullExpr(IncrExpr.get());