[coroutines] Add DependentCoawaitExpr and fix re-building CoroutineBodyStmt.

Summary:
The changes contained in this patch are:

1. Defines a new AST node `DependentCoawaitExpr` for representing co_await expressions while the promise type is still dependent.
2. Correctly detects and transforms the 'co_await' operand to `p.await_transform(<expr>)` when possible.
3. Changes the initial/final suspend points to be built during the initial parse, so they have the correct operator co_await lookup results.
4. Fixes the transformation of the CoroutineBodyStmt so that it doesn't re-build the initial/final suspends.


@rsmith: This change is a little big, but it's not trivial for me to split it up. Please let me know if you would prefer this submitted as multiple patches.



Reviewers: rsmith, GorNishanov

Reviewed By: rsmith

Subscribers: ABataev, rsmith, mehdi_amini, cfe-commits

Differential Revision: https://reviews.llvm.org/D26057

llvm-svn: 297093
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 31bef09..9fec855 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -21,6 +21,16 @@
 using namespace clang;
 using namespace sema;
 
+static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
+                         SourceLocation Loc) {
+  DeclarationName DN = S.PP.getIdentifierInfo(Name);
+  LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+  // Suppress diagnostics when a private member is selected. The same warnings
+  // will be produced again when building the call.
+  LR.suppressDiagnostics();
+  return S.LookupQualifiedName(LR, RD);
+}
+
 /// Look up the std::coroutine_traits<...>::promise_type for the given
 /// function type.
 static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
@@ -167,42 +177,48 @@
   return !Diagnosed;
 }
 
-/// Check that this is a context in which a coroutine suspension can appear.
-static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
-                                                StringRef Keyword) {
-  if (!isValidCoroutineContext(S, Loc, Keyword))
-    return nullptr;
+static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
+                                                 SourceLocation Loc) {
+  DeclarationName OpName =
+      SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
+  LookupResult Operators(SemaRef, OpName, SourceLocation(),
+                         Sema::LookupOperatorName);
+  SemaRef.LookupName(Operators, S);
 
-  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
-  auto *FD = cast<FunctionDecl>(S.CurContext);
-  auto *ScopeInfo = S.getCurFunction();
-  assert(ScopeInfo && "missing function scope for function");
+  assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+  const auto &Functions = Operators.asUnresolvedSet();
+  bool IsOverloaded =
+      Functions.size() > 1 ||
+      (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
+  Expr *CoawaitOp = UnresolvedLookupExpr::Create(
+      SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
+      DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
+      Functions.begin(), Functions.end());
+  assert(CoawaitOp);
+  return CoawaitOp;
+}
 
-  // If we don't have a promise variable, build one now.
-  if (!ScopeInfo->CoroutinePromise) {
-    QualType T = FD->getType()->isDependentType()
-                     ? S.Context.DependentTy
-                     : lookupPromiseType(
-                           S, FD->getType()->castAs<FunctionProtoType>(),
-                           Loc, FD->getLocation());
-    if (T.isNull())
-      return nullptr;
+/// Build a call to 'operator co_await' if there is a suitable operator for
+/// the given expression.
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
+                                           Expr *E,
+                                           UnresolvedLookupExpr *Lookup) {
+  UnresolvedSet<16> Functions;
+  Functions.append(Lookup->decls_begin(), Lookup->decls_end());
+  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
+}
 
-    // Create and default-initialize the promise.
-    ScopeInfo->CoroutinePromise =
-        VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
-                        &S.PP.getIdentifierTable().get("__promise"), T,
-                        S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
-    S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
-    if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
-      S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise);
-  }
-
-  return ScopeInfo;
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
+                                           SourceLocation Loc, Expr *E) {
+  ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
+  if (R.isInvalid())
+    return ExprError();
+  return buildOperatorCoawaitCall(SemaRef, Loc, E,
+                                  cast<UnresolvedLookupExpr>(R.get()));
 }
 
 static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
-                              MutableArrayRef<Expr *> CallArgs) {
+                              MultiExprArg CallArgs) {
   StringRef Name = S.Context.BuiltinInfo.getName(Id);
   LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
   S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
@@ -221,15 +237,6 @@
   return Call.get();
 }
 
-/// Build a call to 'operator co_await' if there is a suitable operator for
-/// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
-                                           SourceLocation Loc, Expr *E) {
-  UnresolvedSet<16> Functions;
-  SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
-                                       Functions);
-  return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
-}
 
 struct ReadySuspendResumeResult {
   bool IsInvalid;
@@ -237,8 +244,7 @@
 };
 
 static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
-                                  StringRef Name,
-                                  MutableArrayRef<Expr *> Args) {
+                                  StringRef Name, MultiExprArg Args) {
   DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
 
   // FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
@@ -276,25 +282,174 @@
   return Calls;
 }
 
+static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
+                                   SourceLocation Loc, StringRef Name,
+                                   MultiExprArg Args) {
+
+  // Form a reference to the promise.
+  ExprResult PromiseRef = S.BuildDeclRefExpr(
+      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
+  if (PromiseRef.isInvalid())
+    return ExprError();
+
+  // Call 'yield_value', passing in E.
+  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
+}
+
+VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
+  assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
+  auto *FD = cast<FunctionDecl>(CurContext);
+
+  QualType T =
+      FD->getType()->isDependentType()
+          ? Context.DependentTy
+          : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(),
+                              Loc, FD->getLocation());
+  if (T.isNull())
+    return nullptr;
+
+  auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
+                             &PP.getIdentifierTable().get("__promise"), T,
+                             Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
+  CheckVariableDeclarationType(VD);
+  if (VD->isInvalidDecl())
+    return nullptr;
+  ActOnUninitializedDecl(VD);
+  assert(!VD->isInvalidDecl());
+  return VD;
+}
+
+/// Check that this is a context in which a coroutine suspension can appear.
+static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
+                                                StringRef Keyword) {
+  if (!isValidCoroutineContext(S, Loc, Keyword))
+    return nullptr;
+
+  assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
+  auto *FD = cast<FunctionDecl>(S.CurContext);
+
+  auto *ScopeInfo = S.getCurFunction();
+  assert(ScopeInfo && "missing function scope for function");
+
+  if (ScopeInfo->CoroutinePromise)
+    return ScopeInfo;
+
+  ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
+  if (!ScopeInfo->CoroutinePromise)
+    return nullptr;
+
+  return ScopeInfo;
+}
+
+static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc,
+                                    StringRef Keyword) {
+  if (!checkCoroutineContext(S, KWLoc, Keyword))
+    return false;
+  auto *ScopeInfo = S.getCurFunction();
+  assert(ScopeInfo->CoroutinePromise);
+
+  // If we have existing coroutine statements then we have already built
+  // the initial and final suspend points.
+  if (!ScopeInfo->NeedsCoroutineSuspends)
+    return true;
+
+  ScopeInfo->setNeedsCoroutineSuspends(false);
+
+  auto *Fn = cast<FunctionDecl>(S.CurContext);
+  SourceLocation Loc = Fn->getLocation();
+  // Build the initial suspend point
+  auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
+    ExprResult Suspend =
+        buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None);
+    if (Suspend.isInvalid())
+      return StmtError();
+    Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get());
+    if (Suspend.isInvalid())
+      return StmtError();
+    Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(),
+                                         /*IsImplicit*/ true);
+    Suspend = S.ActOnFinishFullExpr(Suspend.get());
+    if (Suspend.isInvalid()) {
+      S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+          << ((Name == "initial_suspend") ? 0 : 1);
+      S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
+      return StmtError();
+    }
+    return cast<Stmt>(Suspend.get());
+  };
+
+  StmtResult InitSuspend = buildSuspends("initial_suspend");
+  if (InitSuspend.isInvalid())
+    return true;
+
+  StmtResult FinalSuspend = buildSuspends("final_suspend");
+  if (FinalSuspend.isInvalid())
+    return true;
+
+  ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
+
+  return true;
+}
+
 ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
-  if (!Coroutine) {
+  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) {
     CorrectDelayedTyposInExpr(E);
     return ExprError();
   }
+
   if (E->getType()->isPlaceholderType()) {
     ExprResult R = CheckPlaceholderExpr(E);
     if (R.isInvalid()) return ExprError();
     E = R.get();
   }
+  ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
+  if (Lookup.isInvalid())
+    return ExprError();
+  return BuildUnresolvedCoawaitExpr(Loc, E,
+                                   cast<UnresolvedLookupExpr>(Lookup.get()));
+}
 
-  ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
+ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
+                                           UnresolvedLookupExpr *Lookup) {
+  auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
+  if (!FSI)
+    return ExprError();
+
+  if (E->getType()->isPlaceholderType()) {
+    ExprResult R = CheckPlaceholderExpr(E);
+    if (R.isInvalid())
+      return ExprError();
+    E = R.get();
+  }
+
+  auto *Promise = FSI->CoroutinePromise;
+  if (Promise->getType()->isDependentType()) {
+    Expr *Res =
+        new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
+    FSI->CoroutineStmts.push_back(Res);
+    return Res;
+  }
+
+  auto *RD = Promise->getType()->getAsCXXRecordDecl();
+  if (lookupMember(*this, "await_transform", RD, Loc)) {
+    ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
+    if (R.isInvalid()) {
+      Diag(Loc,
+           diag::note_coroutine_promise_implicit_await_transform_required_here)
+          << E->getSourceRange();
+      return ExprError();
+    }
+    E = R.get();
+  }
+  ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
   if (Awaitable.isInvalid())
     return ExprError();
 
-  return BuildCoawaitExpr(Loc, Awaitable.get());
+  return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
 }
-ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
+
+ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
+                                  bool IsImplicit) {
   auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
   if (!Coroutine)
     return ExprError();
@@ -306,8 +461,10 @@
   }
 
   if (E->getType()->isDependentType()) {
-    Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
-    Coroutine->CoroutineStmts.push_back(Res);
+    Expr *Res = new (Context)
+        CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
+    if (!IsImplicit)
+      Coroutine->CoroutineStmts.push_back(Res);
     return Res;
   }
 
@@ -322,37 +479,21 @@
     return ExprError();
 
   Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
-                                        RSS.Results[2]);
-  Coroutine->CoroutineStmts.push_back(Res);
+                                        RSS.Results[2], IsImplicit);
+  if (!IsImplicit)
+    Coroutine->CoroutineStmts.push_back(Res);
   return Res;
 }
 
-static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
-                                   SourceLocation Loc, StringRef Name,
-                                   MutableArrayRef<Expr *> Args) {
-  assert(Coroutine->CoroutinePromise && "no promise for coroutine");
-
-  // Form a reference to the promise.
-  auto *Promise = Coroutine->CoroutinePromise;
-  ExprResult PromiseRef = S.BuildDeclRefExpr(
-      Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
-  if (PromiseRef.isInvalid())
-    return ExprError();
-
-  // Call 'yield_value', passing in E.
-  return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
-}
-
 ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
-  if (!Coroutine) {
+  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) {
     CorrectDelayedTyposInExpr(E);
     return ExprError();
   }
 
   // Build yield_value call.
-  ExprResult Awaitable =
-      buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
+  ExprResult Awaitable = buildPromiseCall(
+      *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
   if (Awaitable.isInvalid())
     return ExprError();
 
@@ -396,18 +537,18 @@
   return Res;
 }
 
-StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
-  if (!Coroutine) {
+StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
+  if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) {
     CorrectDelayedTyposInExpr(E);
     return StmtError();
   }
   return BuildCoreturnStmt(Loc, E);
 }
 
-StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
-  auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
-  if (!Coroutine)
+StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
+                                   bool IsImplicit) {
+  auto *FSI = checkCoroutineContext(*this, Loc, "co_return");
+  if (!FSI)
     return StmtError();
 
   if (E && E->getType()->isPlaceholderType() &&
@@ -420,20 +561,22 @@
   // FIXME: If the operand is a reference to a variable that's about to go out
   // of scope, we should treat the operand as an xvalue for this overload
   // resolution.
+  VarDecl *Promise = FSI->CoroutinePromise;
   ExprResult PC;
   if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
-    PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
+    PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
   } else {
     E = MakeFullDiscardedValueExpr(E).get();
-    PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
+    PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
   }
   if (PC.isInvalid())
     return StmtError();
 
   Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
 
-  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
-  Coroutine->CoroutineStmts.push_back(Res);
+  Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
+  if (!IsImplicit)
+    FSI->CoroutineStmts.push_back(Res);
   return Res;
 }
 
@@ -490,14 +633,91 @@
   return OperatorDelete;
 }
 
-// Builds allocation and deallocation for the coroutine. Returns false on
-// failure.
-static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
-                                           FunctionScopeInfo *Fn,
-                                           Expr *&Allocation,
-                                           Expr *&Deallocation) {
-  TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
-  QualType PromiseType = TInfo->getType();
+namespace {
+class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
+  Sema &S;
+  FunctionDecl &FD;
+  FunctionScopeInfo &Fn;
+  bool IsValid;
+  SourceLocation Loc;
+  QualType RetType;
+  SmallVector<Stmt *, 4> ParamMovesVector;
+  const bool IsPromiseDependentType;
+  CXXRecordDecl *PromiseRecordDecl = nullptr;
+
+public:
+  SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
+      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
+        IsPromiseDependentType(
+            !Fn.CoroutinePromise ||
+            Fn.CoroutinePromise->getType()->isDependentType()) {
+    this->Body = Body;
+    if (!IsPromiseDependentType) {
+      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
+      assert(PromiseRecordDecl && "Type should have already been checked");
+    }
+    this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() &&
+                    makeOnException() && makeOnFallthrough() &&
+                    makeNewAndDeleteExpr() && makeReturnObject() &&
+                    makeParamMoves();
+  }
+
+  bool isInvalid() const { return !this->IsValid; }
+
+  bool makePromiseStmt();
+  bool makeInitialAndFinalSuspend();
+  bool makeNewAndDeleteExpr();
+  bool makeOnFallthrough();
+  bool makeOnException();
+  bool makeReturnObject();
+  bool makeParamMoves();
+};
+}
+
+void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
+  FunctionScopeInfo *Fn = getCurFunction();
+  assert(Fn && Fn->CoroutinePromise && "not a coroutine");
+
+  // Coroutines [stmt.return]p1:
+  //   A return statement shall not appear in a coroutine.
+  if (Fn->FirstReturnLoc.isValid()) {
+    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
+    auto *First = Fn->CoroutineStmts[0];
+    Diag(First->getLocStart(), diag::note_declared_coroutine_here)
+        << (isa<CoawaitExpr>(First) ? "co_await" :
+            isa<CoyieldExpr>(First) ? "co_yield" : "co_return");
+  }
+  SubStmtBuilder Builder(*this, *FD, *Fn, Body);
+  if (Builder.isInvalid())
+    return FD->setInvalidDecl();
+
+  // Build body for the coroutine wrapper statement.
+  Body = CoroutineBodyStmt::Create(Context, Builder);
+}
+
+bool SubStmtBuilder::makePromiseStmt() {
+  // Form a declaration statement for the promise declaration, so that AST
+  // visitors can more easily find it.
+  StmtResult PromiseStmt =
+      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
+  if (PromiseStmt.isInvalid())
+    return false;
+
+  this->Promise = PromiseStmt.get();
+  return true;
+}
+
+bool SubStmtBuilder::makeInitialAndFinalSuspend() {
+  if (Fn.hasInvalidCoroutineSuspends())
+    return false;
+  this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
+  this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
+  return true;
+}
+
+bool SubStmtBuilder::makeNewAndDeleteExpr() {
+  // Form and check allocation and deallocation calls.
+  QualType PromiseType = Fn.CoroutinePromise->getType();
   if (PromiseType->isDependentType())
     return true;
 
@@ -540,8 +760,6 @@
   if (NewExpr.isInvalid())
     return false;
 
-  Allocation = NewExpr.get();
-
   // Make delete call.
 
   QualType OpDeleteQualType = OperatorDelete->getType();
@@ -567,122 +785,12 @@
   if (DeleteExpr.isInvalid())
     return false;
 
-  Deallocation = DeleteExpr.get();
+  this->Allocate = NewExpr.get();
+  this->Deallocate = DeleteExpr.get();
 
   return true;
 }
 
-namespace {
-class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
-  Sema &S;
-  FunctionDecl &FD;
-  FunctionScopeInfo &Fn;
-  bool IsValid;
-  SourceLocation Loc;
-  QualType RetType;
-  SmallVector<Stmt *, 4> ParamMovesVector;
-  const bool IsPromiseDependentType;
-  CXXRecordDecl *PromiseRecordDecl = nullptr;
-
-public:
-  SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
-      : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
-        IsPromiseDependentType(
-            !Fn.CoroutinePromise ||
-            Fn.CoroutinePromise->getType()->isDependentType()) {
-    this->Body = Body;
-    if (!IsPromiseDependentType) {
-      PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
-      assert(PromiseRecordDecl && "Type should have already been checked");
-    }
-    this->IsValid = makePromiseStmt() && makeInitialSuspend() &&
-                    makeFinalSuspend() && makeOnException() &&
-                    makeOnFallthrough() && makeNewAndDeleteExpr() &&
-                    makeReturnObject() && makeParamMoves();
-  }
-
-  bool isInvalid() const { return !this->IsValid; }
-
-  bool makePromiseStmt();
-  bool makeInitialSuspend();
-  bool makeFinalSuspend();
-  bool makeNewAndDeleteExpr();
-  bool makeOnFallthrough();
-  bool makeOnException();
-  bool makeReturnObject();
-  bool makeParamMoves();
-};
-}
-
-void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
-  FunctionScopeInfo *Fn = getCurFunction();
-  assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
-
-  // Coroutines [stmt.return]p1:
-  //   A return statement shall not appear in a coroutine.
-  if (Fn->FirstReturnLoc.isValid()) {
-    Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
-    auto *First = Fn->CoroutineStmts[0];
-    Diag(First->getLocStart(), diag::note_declared_coroutine_here)
-        << (isa<CoawaitExpr>(First) ? 0 :
-            isa<CoyieldExpr>(First) ? 1 : 2);
-  }
-  SubStmtBuilder Builder(*this, *FD, *Fn, Body);
-  if (Builder.isInvalid())
-    return FD->setInvalidDecl();
-
-  // Build body for the coroutine wrapper statement.
-  Body = CoroutineBodyStmt::Create(Context, Builder);
-}
-
-bool SubStmtBuilder::makePromiseStmt() {
-  // Form a declaration statement for the promise declaration, so that AST
-  // visitors can more easily find it.
-  StmtResult PromiseStmt =
-      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
-  if (PromiseStmt.isInvalid())
-    return false;
-
-  this->Promise = PromiseStmt.get();
-  return true;
-}
-
-bool SubStmtBuilder::makeInitialSuspend() {
-  // Form and check implicit 'co_await p.initial_suspend();' statement.
-  ExprResult InitialSuspend =
-      buildPromiseCall(S, &Fn, Loc, "initial_suspend", None);
-  // FIXME: Support operator co_await here.
-  if (!InitialSuspend.isInvalid())
-    InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get());
-  InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get());
-  if (InitialSuspend.isInvalid())
-    return false;
-
-  this->InitialSuspend = InitialSuspend.get();
-  return true;
-}
-
-bool SubStmtBuilder::makeFinalSuspend() {
-  // Form and check implicit 'co_await p.final_suspend();' statement.
-  ExprResult FinalSuspend =
-      buildPromiseCall(S, &Fn, Loc, "final_suspend", None);
-  // FIXME: Support operator co_await here.
-  if (!FinalSuspend.isInvalid())
-    FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get());
-  FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get());
-  if (FinalSuspend.isInvalid())
-    return false;
-
-  this->FinalSuspend = FinalSuspend.get();
-  return true;
-}
-
-bool SubStmtBuilder::makeNewAndDeleteExpr() {
-  // Form and check allocation and deallocation calls.
-  return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate,
-                                        this->Deallocate);
-}
-
 bool SubStmtBuilder::makeOnFallthrough() {
   if (!PromiseRecordDecl)
     return true;
@@ -690,13 +798,8 @@
   // [dcl.fct.def.coroutine]/4
   // The unqualified-ids 'return_void' and 'return_value' are looked up in
   // the scope of class P. If both are found, the program is ill-formed.
-  DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void");
-  LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName);
-  const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl);
-
-  DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value");
-  LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName);
-  const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl);
+  const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc);
+  const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc);
 
   StmtResult Fallthrough;
   if (HasRVoid && HasRValue) {
@@ -708,7 +811,8 @@
     // If the unqualified-id return_void is found, flowing off the end of a
     // coroutine is equivalent to a co_return with no operand. Otherwise,
     // flowing off the end of a coroutine results in undefined behavior.
-    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr);
+    Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
+                                      /*IsImplicit*/false);
     Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
     if (Fallthrough.isInvalid())
       return false;
@@ -736,15 +840,13 @@
   // [dcl.fct.def.coroutine]/3
   // The unqualified-id set_exception is found in the scope of P by class
   // member access lookup (3.4.5).
-  DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception");
-  LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName);
-  if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) {
+  if (lookupMember(S, "set_exception", PromiseRecordDecl, Loc)) {
     // Form the call 'p.set_exception(std::current_exception())'
     SetException = buildStdCurrentExceptionCall(S, Loc);
     if (SetException.isInvalid())
       return false;
     Expr *E = SetException.get();
-    SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E);
+    SetException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, "set_exception", E);
     SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
     if (SetException.isInvalid())
       return false;
@@ -759,7 +861,7 @@
   // Build implicit 'p.get_return_object()' expression and form initialization
   // of return type from it.
   ExprResult ReturnObject =
-      buildPromiseCall(S, &Fn, Loc, "get_return_object", None);
+      buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
   if (ReturnObject.isInvalid())
     return false;
   QualType RetType = FD.getReturnType();
@@ -783,3 +885,10 @@
   // FIXME: Perform move-initialization of parameters into frame-local copies.
   return true;
 }
+
+StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
+  CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
+  if (!Res)
+    return StmtError();
+  return Res;
+}