[coroutines] Add DependentCoawaitExpr and fix re-building CoroutineBodyStmt.
Summary:
The changes contained in this patch are:
1. Define a new AST node, `DependentCoawaitExpr`, for representing `co_await` expressions while the promise type is still dependent.
2. Correctly detect when the promise type declares `await_transform`, and in that case transform the `co_await` operand to `p.await_transform(<expr>)`.
3. Build the initial/final suspend points during the initial parse, so that they get the correct `operator co_await` lookup results.
4. Fix transformation of the `CoroutineBodyStmt` so that it does not rebuild the initial/final suspend points.
@rsmith: This change is fairly large, but it is not easy to split up. Please let me know if you would prefer it submitted as multiple patches.
Reviewers: rsmith, GorNishanov
Reviewed By: rsmith
Subscribers: ABataev, rsmith, mehdi_amini, cfe-commits
Differential Revision: https://reviews.llvm.org/D26057
llvm-svn: 297093
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 31bef09..9fec855 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -21,6 +21,16 @@
using namespace clang;
using namespace sema;
+static bool lookupMember(Sema &S, const char *Name, CXXRecordDecl *RD,
+ SourceLocation Loc) {
+ DeclarationName DN = S.PP.getIdentifierInfo(Name);
+ LookupResult LR(S, DN, Loc, Sema::LookupMemberName);
+ // Suppress diagnostics when a private member is selected. The same warnings
+ // will be produced again when building the call.
+ LR.suppressDiagnostics();
+ return S.LookupQualifiedName(LR, RD);
+}
+
/// Look up the std::coroutine_traits<...>::promise_type for the given
/// function type.
static QualType lookupPromiseType(Sema &S, const FunctionProtoType *FnType,
@@ -167,42 +177,48 @@
return !Diagnosed;
}
-/// Check that this is a context in which a coroutine suspension can appear.
-static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
- StringRef Keyword) {
- if (!isValidCoroutineContext(S, Loc, Keyword))
- return nullptr;
+static ExprResult buildOperatorCoawaitLookupExpr(Sema &SemaRef, Scope *S,
+ SourceLocation Loc) {
+ DeclarationName OpName =
+ SemaRef.Context.DeclarationNames.getCXXOperatorName(OO_Coawait);
+ LookupResult Operators(SemaRef, OpName, SourceLocation(),
+ Sema::LookupOperatorName);
+ SemaRef.LookupName(Operators, S);
- assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
- auto *FD = cast<FunctionDecl>(S.CurContext);
- auto *ScopeInfo = S.getCurFunction();
- assert(ScopeInfo && "missing function scope for function");
+ assert(!Operators.isAmbiguous() && "Operator lookup cannot be ambiguous");
+ const auto &Functions = Operators.asUnresolvedSet();
+ bool IsOverloaded =
+ Functions.size() > 1 ||
+ (Functions.size() == 1 && isa<FunctionTemplateDecl>(*Functions.begin()));
+ Expr *CoawaitOp = UnresolvedLookupExpr::Create(
+ SemaRef.Context, /*NamingClass*/ nullptr, NestedNameSpecifierLoc(),
+ DeclarationNameInfo(OpName, Loc), /*RequiresADL*/ true, IsOverloaded,
+ Functions.begin(), Functions.end());
+ assert(CoawaitOp);
+ return CoawaitOp;
+}
- // If we don't have a promise variable, build one now.
- if (!ScopeInfo->CoroutinePromise) {
- QualType T = FD->getType()->isDependentType()
- ? S.Context.DependentTy
- : lookupPromiseType(
- S, FD->getType()->castAs<FunctionProtoType>(),
- Loc, FD->getLocation());
- if (T.isNull())
- return nullptr;
+/// Build a call to 'operator co_await' if there is a suitable operator for
+/// the given expression.
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, SourceLocation Loc,
+ Expr *E,
+ UnresolvedLookupExpr *Lookup) {
+ UnresolvedSet<16> Functions;
+ Functions.append(Lookup->decls_begin(), Lookup->decls_end());
+ return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
+}
- // Create and default-initialize the promise.
- ScopeInfo->CoroutinePromise =
- VarDecl::Create(S.Context, FD, FD->getLocation(), FD->getLocation(),
- &S.PP.getIdentifierTable().get("__promise"), T,
- S.Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
- S.CheckVariableDeclarationType(ScopeInfo->CoroutinePromise);
- if (!ScopeInfo->CoroutinePromise->isInvalidDecl())
- S.ActOnUninitializedDecl(ScopeInfo->CoroutinePromise);
- }
-
- return ScopeInfo;
+static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
+ SourceLocation Loc, Expr *E) {
+ ExprResult R = buildOperatorCoawaitLookupExpr(SemaRef, S, Loc);
+ if (R.isInvalid())
+ return ExprError();
+ return buildOperatorCoawaitCall(SemaRef, Loc, E,
+ cast<UnresolvedLookupExpr>(R.get()));
}
static Expr *buildBuiltinCall(Sema &S, SourceLocation Loc, Builtin::ID Id,
- MutableArrayRef<Expr *> CallArgs) {
+ MultiExprArg CallArgs) {
StringRef Name = S.Context.BuiltinInfo.getName(Id);
LookupResult R(S, &S.Context.Idents.get(Name), Loc, Sema::LookupOrdinaryName);
S.LookupName(R, S.TUScope, /*AllowBuiltinCreation=*/true);
@@ -221,15 +237,6 @@
return Call.get();
}
-/// Build a call to 'operator co_await' if there is a suitable operator for
-/// the given expression.
-static ExprResult buildOperatorCoawaitCall(Sema &SemaRef, Scope *S,
- SourceLocation Loc, Expr *E) {
- UnresolvedSet<16> Functions;
- SemaRef.LookupOverloadedOperatorName(OO_Coawait, S, E->getType(), QualType(),
- Functions);
- return SemaRef.CreateOverloadedUnaryOp(Loc, UO_Coawait, Functions, E);
-}
struct ReadySuspendResumeResult {
bool IsInvalid;
@@ -237,8 +244,7 @@
};
static ExprResult buildMemberCall(Sema &S, Expr *Base, SourceLocation Loc,
- StringRef Name,
- MutableArrayRef<Expr *> Args) {
+ StringRef Name, MultiExprArg Args) {
DeclarationNameInfo NameInfo(&S.PP.getIdentifierTable().get(Name), Loc);
// FIXME: Fix BuildMemberReferenceExpr to take a const CXXScopeSpec&.
@@ -276,25 +282,174 @@
return Calls;
}
+static ExprResult buildPromiseCall(Sema &S, VarDecl *Promise,
+ SourceLocation Loc, StringRef Name,
+ MultiExprArg Args) {
+ // Builds 'Promise.Name(Args...)' at Loc; returns ExprError() on failure.
+ // Form a reference to the promise.
+ ExprResult PromiseRef = S.BuildDeclRefExpr(
+ Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
+ if (PromiseRef.isInvalid())
+ return ExprError();
+
+ // Call the promise member function 'Name', passing in Args.
+ return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
+}
+
+VarDecl *Sema::buildCoroutinePromise(SourceLocation Loc) {
+ assert(isa<FunctionDecl>(CurContext) && "not in a function scope");
+ auto *FD = cast<FunctionDecl>(CurContext);
+
+ QualType T =
+ FD->getType()->isDependentType()
+ ? Context.DependentTy
+ : lookupPromiseType(*this, FD->getType()->castAs<FunctionProtoType>(),
+ Loc, FD->getLocation());
+ if (T.isNull())
+ return nullptr;
+
+ auto *VD = VarDecl::Create(Context, FD, FD->getLocation(), FD->getLocation(),
+ &PP.getIdentifierTable().get("__promise"), T,
+ Context.getTrivialTypeSourceInfo(T, Loc), SC_None);
+ CheckVariableDeclarationType(VD);
+ if (VD->isInvalidDecl())
+ return nullptr;
+ ActOnUninitializedDecl(VD);
+ assert(!VD->isInvalidDecl());
+ return VD;
+}
+
+/// Check that a coroutine suspension can appear here and ensure a promise exists.
+static FunctionScopeInfo *checkCoroutineContext(Sema &S, SourceLocation Loc,
+ StringRef Keyword) {
+ if (!isValidCoroutineContext(S, Loc, Keyword))
+ return nullptr;
+
+ assert(isa<FunctionDecl>(S.CurContext) && "not in a function scope");
+ // Returns the current FunctionScopeInfo with CoroutinePromise set, or null.
+
+ auto *ScopeInfo = S.getCurFunction();
+ assert(ScopeInfo && "missing function scope for function");
+
+ if (ScopeInfo->CoroutinePromise)
+ return ScopeInfo;
+ // No promise yet: build it now. On failure it stays null and we return null.
+ ScopeInfo->CoroutinePromise = S.buildCoroutinePromise(Loc);
+ if (!ScopeInfo->CoroutinePromise)
+ return nullptr;
+
+ return ScopeInfo;
+}
+
+static bool actOnCoroutineBodyStart(Sema &S, Scope *SC, SourceLocation KWLoc,
+ StringRef Keyword) {
+ if (!checkCoroutineContext(S, KWLoc, Keyword))
+ return false;
+ auto *ScopeInfo = S.getCurFunction();
+ assert(ScopeInfo->CoroutinePromise);
+
+ // If an earlier suspension point already attempted to build the initial
+ // and final suspends (NeedsCoroutineSuspends cleared), don't build them again.
+ if (!ScopeInfo->NeedsCoroutineSuspends)
+ return true;
+
+ ScopeInfo->setNeedsCoroutineSuspends(false);
+
+ auto *Fn = cast<FunctionDecl>(S.CurContext);
+ SourceLocation Loc = Fn->getLocation();
+ // Build the implicit 'co_await p.<Name>()' full-expression (initial or final).
+ auto buildSuspends = [&](StringRef Name) mutable -> StmtResult {
+ ExprResult Suspend =
+ buildPromiseCall(S, ScopeInfo->CoroutinePromise, Loc, Name, None);
+ if (Suspend.isInvalid())
+ return StmtError();
+ Suspend = buildOperatorCoawaitCall(S, SC, Loc, Suspend.get());
+ if (Suspend.isInvalid())
+ return StmtError();
+ Suspend = S.BuildResolvedCoawaitExpr(Loc, Suspend.get(),
+ /*IsImplicit*/ true);
+ Suspend = S.ActOnFinishFullExpr(Suspend.get());
+ if (Suspend.isInvalid()) {
+ S.Diag(Loc, diag::note_coroutine_promise_call_implicitly_required)
+ << ((Name == "initial_suspend") ? 0 : 1);
+ S.Diag(KWLoc, diag::note_declared_coroutine_here) << Keyword;
+ return StmtError();
+ }
+ return cast<Stmt>(Suspend.get());
+ };
+ // A failed suspend still returns true: the context is valid, suspends stay unset.
+ StmtResult InitSuspend = buildSuspends("initial_suspend");
+ if (InitSuspend.isInvalid())
+ return true;
+
+ StmtResult FinalSuspend = buildSuspends("final_suspend");
+ if (FinalSuspend.isInvalid())
+ return true;
+ // Only record the suspends once both were built successfully.
+ ScopeInfo->setCoroutineSuspends(InitSuspend.get(), FinalSuspend.get());
+
+ return true;
+}
+
ExprResult Sema::ActOnCoawaitExpr(Scope *S, SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
- if (!Coroutine) {
+ if (!actOnCoroutineBodyStart(*this, S, Loc, "co_await")) {
CorrectDelayedTyposInExpr(E);
return ExprError();
}
+
if (E->getType()->isPlaceholderType()) {
ExprResult R = CheckPlaceholderExpr(E);
if (R.isInvalid()) return ExprError();
E = R.get();
}
+ ExprResult Lookup = buildOperatorCoawaitLookupExpr(*this, S, Loc);
+ if (Lookup.isInvalid())
+ return ExprError();
+ return BuildUnresolvedCoawaitExpr(Loc, E,
+ cast<UnresolvedLookupExpr>(Lookup.get()));
+}
- ExprResult Awaitable = buildOperatorCoawaitCall(*this, S, Loc, E);
+ExprResult Sema::BuildUnresolvedCoawaitExpr(SourceLocation Loc, Expr *E,
+ UnresolvedLookupExpr *Lookup) {
+ auto *FSI = checkCoroutineContext(*this, Loc, "co_await");
+ if (!FSI)
+ return ExprError();
+
+ if (E->getType()->isPlaceholderType()) {
+ ExprResult R = CheckPlaceholderExpr(E);
+ if (R.isInvalid())
+ return ExprError();
+ E = R.get();
+ }
+
+ auto *Promise = FSI->CoroutinePromise;
+ if (Promise->getType()->isDependentType()) {
+ Expr *Res =
+ new (Context) DependentCoawaitExpr(Loc, Context.DependentTy, E, Lookup);
+ FSI->CoroutineStmts.push_back(Res);
+ return Res;
+ }
+
+ auto *RD = Promise->getType()->getAsCXXRecordDecl();
+ if (lookupMember(*this, "await_transform", RD, Loc)) {
+ ExprResult R = buildPromiseCall(*this, Promise, Loc, "await_transform", E);
+ if (R.isInvalid()) {
+ Diag(Loc,
+ diag::note_coroutine_promise_implicit_await_transform_required_here)
+ << E->getSourceRange();
+ return ExprError();
+ }
+ E = R.get();
+ }
+ ExprResult Awaitable = buildOperatorCoawaitCall(*this, Loc, E, Lookup);
if (Awaitable.isInvalid())
return ExprError();
- return BuildCoawaitExpr(Loc, Awaitable.get());
+ return BuildResolvedCoawaitExpr(Loc, Awaitable.get());
}
-ExprResult Sema::BuildCoawaitExpr(SourceLocation Loc, Expr *E) {
+
+ExprResult Sema::BuildResolvedCoawaitExpr(SourceLocation Loc, Expr *E,
+ bool IsImplicit) {
auto *Coroutine = checkCoroutineContext(*this, Loc, "co_await");
if (!Coroutine)
return ExprError();
@@ -306,8 +461,10 @@
}
if (E->getType()->isDependentType()) {
- Expr *Res = new (Context) CoawaitExpr(Loc, Context.DependentTy, E);
- Coroutine->CoroutineStmts.push_back(Res);
+ Expr *Res = new (Context)
+ CoawaitExpr(Loc, Context.DependentTy, E, IsImplicit);
+ if (!IsImplicit)
+ Coroutine->CoroutineStmts.push_back(Res);
return Res;
}
@@ -322,37 +479,21 @@
return ExprError();
Expr *Res = new (Context) CoawaitExpr(Loc, E, RSS.Results[0], RSS.Results[1],
- RSS.Results[2]);
- Coroutine->CoroutineStmts.push_back(Res);
+ RSS.Results[2], IsImplicit);
+ if (!IsImplicit)
+ Coroutine->CoroutineStmts.push_back(Res);
return Res;
}
-static ExprResult buildPromiseCall(Sema &S, FunctionScopeInfo *Coroutine,
- SourceLocation Loc, StringRef Name,
- MutableArrayRef<Expr *> Args) {
- assert(Coroutine->CoroutinePromise && "no promise for coroutine");
-
- // Form a reference to the promise.
- auto *Promise = Coroutine->CoroutinePromise;
- ExprResult PromiseRef = S.BuildDeclRefExpr(
- Promise, Promise->getType().getNonReferenceType(), VK_LValue, Loc);
- if (PromiseRef.isInvalid())
- return ExprError();
-
- // Call 'yield_value', passing in E.
- return buildMemberCall(S, PromiseRef.get(), Loc, Name, Args);
-}
-
ExprResult Sema::ActOnCoyieldExpr(Scope *S, SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_yield");
- if (!Coroutine) {
+ if (!actOnCoroutineBodyStart(*this, S, Loc, "co_yield")) {
CorrectDelayedTyposInExpr(E);
return ExprError();
}
// Build yield_value call.
- ExprResult Awaitable =
- buildPromiseCall(*this, Coroutine, Loc, "yield_value", E);
+ ExprResult Awaitable = buildPromiseCall(
+ *this, getCurFunction()->CoroutinePromise, Loc, "yield_value", E);
if (Awaitable.isInvalid())
return ExprError();
@@ -396,18 +537,18 @@
return Res;
}
-StmtResult Sema::ActOnCoreturnStmt(SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
- if (!Coroutine) {
+StmtResult Sema::ActOnCoreturnStmt(Scope *S, SourceLocation Loc, Expr *E) {
+ if (!actOnCoroutineBodyStart(*this, S, Loc, "co_return")) {
CorrectDelayedTyposInExpr(E);
return StmtError();
}
return BuildCoreturnStmt(Loc, E);
}
-StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E) {
- auto *Coroutine = checkCoroutineContext(*this, Loc, "co_return");
- if (!Coroutine)
+StmtResult Sema::BuildCoreturnStmt(SourceLocation Loc, Expr *E,
+ bool IsImplicit) {
+ auto *FSI = checkCoroutineContext(*this, Loc, "co_return");
+ if (!FSI)
return StmtError();
if (E && E->getType()->isPlaceholderType() &&
@@ -420,20 +561,22 @@
// FIXME: If the operand is a reference to a variable that's about to go out
// of scope, we should treat the operand as an xvalue for this overload
// resolution.
+ VarDecl *Promise = FSI->CoroutinePromise;
ExprResult PC;
if (E && (isa<InitListExpr>(E) || !E->getType()->isVoidType())) {
- PC = buildPromiseCall(*this, Coroutine, Loc, "return_value", E);
+ PC = buildPromiseCall(*this, Promise, Loc, "return_value", E);
} else {
E = MakeFullDiscardedValueExpr(E).get();
- PC = buildPromiseCall(*this, Coroutine, Loc, "return_void", None);
+ PC = buildPromiseCall(*this, Promise, Loc, "return_void", None);
}
if (PC.isInvalid())
return StmtError();
Expr *PCE = ActOnFinishFullExpr(PC.get()).get();
- Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE);
- Coroutine->CoroutineStmts.push_back(Res);
+ Stmt *Res = new (Context) CoreturnStmt(Loc, E, PCE, IsImplicit);
+ if (!IsImplicit)
+ FSI->CoroutineStmts.push_back(Res);
return Res;
}
@@ -490,14 +633,91 @@
return OperatorDelete;
}
-// Builds allocation and deallocation for the coroutine. Returns false on
-// failure.
-static bool buildAllocationAndDeallocation(Sema &S, SourceLocation Loc,
- FunctionScopeInfo *Fn,
- Expr *&Allocation,
- Expr *&Deallocation) {
- TypeSourceInfo *TInfo = Fn->CoroutinePromise->getTypeSourceInfo();
- QualType PromiseType = TInfo->getType();
+namespace {
+class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
+ Sema &S;
+ FunctionDecl &FD;
+ FunctionScopeInfo &Fn;
+ bool IsValid;
+ SourceLocation Loc;
+ QualType RetType;
+ SmallVector<Stmt *, 4> ParamMovesVector;
+ const bool IsPromiseDependentType;
+ CXXRecordDecl *PromiseRecordDecl = nullptr;
+
+public:
+ SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
+ : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
+ IsPromiseDependentType(
+ !Fn.CoroutinePromise ||
+ Fn.CoroutinePromise->getType()->isDependentType()) {
+ this->Body = Body;
+ if (!IsPromiseDependentType) {
+ PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
+ assert(PromiseRecordDecl && "Type should have already been checked");
+ }
+ this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() &&
+ makeOnException() && makeOnFallthrough() &&
+ makeNewAndDeleteExpr() && makeReturnObject() &&
+ makeParamMoves();
+ }
+
+ bool isInvalid() const { return !this->IsValid; }
+
+ bool makePromiseStmt();
+ bool makeInitialAndFinalSuspend();
+ bool makeNewAndDeleteExpr();
+ bool makeOnFallthrough();
+ bool makeOnException();
+ bool makeReturnObject();
+ bool makeParamMoves();
+};
+}
+
+void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
+ FunctionScopeInfo *Fn = getCurFunction();
+ assert(Fn && Fn->CoroutinePromise && "not a coroutine");
+ // CoroutineStmts may be empty (implicit stmts aren't recorded), so guard it.
+ // Coroutines [stmt.return]p1:
+ // A return statement shall not appear in a coroutine.
+ if (Fn->FirstReturnLoc.isValid()) {
+ Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
+ if (Stmt *First = Fn->CoroutineStmts.empty() ? nullptr : Fn->CoroutineStmts[0])
+ Diag(First->getLocStart(), diag::note_declared_coroutine_here)
+ << (isa<CoawaitExpr>(First) || isa<DependentCoawaitExpr>(First) ? "co_await"
+ : isa<CoyieldExpr>(First) ? "co_yield" : "co_return");
+ }
+ SubStmtBuilder Builder(*this, *FD, *Fn, Body);
+ if (Builder.isInvalid())
+ return FD->setInvalidDecl();
+ // SubStmtBuilder assembled the promise, suspends, new/delete and return object.
+ // Build body for the coroutine wrapper statement.
+ Body = CoroutineBodyStmt::Create(Context, Builder);
+}
+
+bool SubStmtBuilder::makePromiseStmt() {
+ // Form a declaration statement for the promise declaration, so that AST
+ // visitors can more easily find it.
+ StmtResult PromiseStmt =
+ S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
+ if (PromiseStmt.isInvalid())
+ return false;
+
+ this->Promise = PromiseStmt.get();
+ return true;
+}
+
+bool SubStmtBuilder::makeInitialAndFinalSuspend() {
+ if (Fn.hasInvalidCoroutineSuspends())
+ return false;
+ this->InitialSuspend = cast<Expr>(Fn.CoroutineSuspends.first);
+ this->FinalSuspend = cast<Expr>(Fn.CoroutineSuspends.second);
+ return true;
+}
+
+bool SubStmtBuilder::makeNewAndDeleteExpr() {
+ // Form and check allocation and deallocation calls.
+ QualType PromiseType = Fn.CoroutinePromise->getType();
if (PromiseType->isDependentType())
return true;
@@ -540,8 +760,6 @@
if (NewExpr.isInvalid())
return false;
- Allocation = NewExpr.get();
-
// Make delete call.
QualType OpDeleteQualType = OperatorDelete->getType();
@@ -567,122 +785,12 @@
if (DeleteExpr.isInvalid())
return false;
- Deallocation = DeleteExpr.get();
+ this->Allocate = NewExpr.get();
+ this->Deallocate = DeleteExpr.get();
return true;
}
-namespace {
-class SubStmtBuilder : public CoroutineBodyStmt::CtorArgs {
- Sema &S;
- FunctionDecl &FD;
- FunctionScopeInfo &Fn;
- bool IsValid;
- SourceLocation Loc;
- QualType RetType;
- SmallVector<Stmt *, 4> ParamMovesVector;
- const bool IsPromiseDependentType;
- CXXRecordDecl *PromiseRecordDecl = nullptr;
-
-public:
- SubStmtBuilder(Sema &S, FunctionDecl &FD, FunctionScopeInfo &Fn, Stmt *Body)
- : S(S), FD(FD), Fn(Fn), Loc(FD.getLocation()),
- IsPromiseDependentType(
- !Fn.CoroutinePromise ||
- Fn.CoroutinePromise->getType()->isDependentType()) {
- this->Body = Body;
- if (!IsPromiseDependentType) {
- PromiseRecordDecl = Fn.CoroutinePromise->getType()->getAsCXXRecordDecl();
- assert(PromiseRecordDecl && "Type should have already been checked");
- }
- this->IsValid = makePromiseStmt() && makeInitialSuspend() &&
- makeFinalSuspend() && makeOnException() &&
- makeOnFallthrough() && makeNewAndDeleteExpr() &&
- makeReturnObject() && makeParamMoves();
- }
-
- bool isInvalid() const { return !this->IsValid; }
-
- bool makePromiseStmt();
- bool makeInitialSuspend();
- bool makeFinalSuspend();
- bool makeNewAndDeleteExpr();
- bool makeOnFallthrough();
- bool makeOnException();
- bool makeReturnObject();
- bool makeParamMoves();
-};
-}
-
-void Sema::CheckCompletedCoroutineBody(FunctionDecl *FD, Stmt *&Body) {
- FunctionScopeInfo *Fn = getCurFunction();
- assert(Fn && !Fn->CoroutineStmts.empty() && "not a coroutine");
-
- // Coroutines [stmt.return]p1:
- // A return statement shall not appear in a coroutine.
- if (Fn->FirstReturnLoc.isValid()) {
- Diag(Fn->FirstReturnLoc, diag::err_return_in_coroutine);
- auto *First = Fn->CoroutineStmts[0];
- Diag(First->getLocStart(), diag::note_declared_coroutine_here)
- << (isa<CoawaitExpr>(First) ? 0 :
- isa<CoyieldExpr>(First) ? 1 : 2);
- }
- SubStmtBuilder Builder(*this, *FD, *Fn, Body);
- if (Builder.isInvalid())
- return FD->setInvalidDecl();
-
- // Build body for the coroutine wrapper statement.
- Body = CoroutineBodyStmt::Create(Context, Builder);
-}
-
-bool SubStmtBuilder::makePromiseStmt() {
- // Form a declaration statement for the promise declaration, so that AST
- // visitors can more easily find it.
- StmtResult PromiseStmt =
- S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(Fn.CoroutinePromise), Loc, Loc);
- if (PromiseStmt.isInvalid())
- return false;
-
- this->Promise = PromiseStmt.get();
- return true;
-}
-
-bool SubStmtBuilder::makeInitialSuspend() {
- // Form and check implicit 'co_await p.initial_suspend();' statement.
- ExprResult InitialSuspend =
- buildPromiseCall(S, &Fn, Loc, "initial_suspend", None);
- // FIXME: Support operator co_await here.
- if (!InitialSuspend.isInvalid())
- InitialSuspend = S.BuildCoawaitExpr(Loc, InitialSuspend.get());
- InitialSuspend = S.ActOnFinishFullExpr(InitialSuspend.get());
- if (InitialSuspend.isInvalid())
- return false;
-
- this->InitialSuspend = InitialSuspend.get();
- return true;
-}
-
-bool SubStmtBuilder::makeFinalSuspend() {
- // Form and check implicit 'co_await p.final_suspend();' statement.
- ExprResult FinalSuspend =
- buildPromiseCall(S, &Fn, Loc, "final_suspend", None);
- // FIXME: Support operator co_await here.
- if (!FinalSuspend.isInvalid())
- FinalSuspend = S.BuildCoawaitExpr(Loc, FinalSuspend.get());
- FinalSuspend = S.ActOnFinishFullExpr(FinalSuspend.get());
- if (FinalSuspend.isInvalid())
- return false;
-
- this->FinalSuspend = FinalSuspend.get();
- return true;
-}
-
-bool SubStmtBuilder::makeNewAndDeleteExpr() {
- // Form and check allocation and deallocation calls.
- return buildAllocationAndDeallocation(S, Loc, &Fn, this->Allocate,
- this->Deallocate);
-}
-
bool SubStmtBuilder::makeOnFallthrough() {
if (!PromiseRecordDecl)
return true;
@@ -690,13 +798,8 @@
// [dcl.fct.def.coroutine]/4
// The unqualified-ids 'return_void' and 'return_value' are looked up in
// the scope of class P. If both are found, the program is ill-formed.
- DeclarationName RVoidDN = S.PP.getIdentifierInfo("return_void");
- LookupResult RVoidResult(S, RVoidDN, Loc, Sema::LookupMemberName);
- const bool HasRVoid = S.LookupQualifiedName(RVoidResult, PromiseRecordDecl);
-
- DeclarationName RValueDN = S.PP.getIdentifierInfo("return_value");
- LookupResult RValueResult(S, RValueDN, Loc, Sema::LookupMemberName);
- const bool HasRValue = S.LookupQualifiedName(RValueResult, PromiseRecordDecl);
+ const bool HasRVoid = lookupMember(S, "return_void", PromiseRecordDecl, Loc);
+ const bool HasRValue = lookupMember(S, "return_value", PromiseRecordDecl, Loc);
StmtResult Fallthrough;
if (HasRVoid && HasRValue) {
@@ -708,7 +811,8 @@
// If the unqualified-id return_void is found, flowing off the end of a
// coroutine is equivalent to a co_return with no operand. Otherwise,
// flowing off the end of a coroutine results in undefined behavior.
- Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr);
+ Fallthrough = S.BuildCoreturnStmt(FD.getLocation(), nullptr,
+ /*IsImplicit*/false);
Fallthrough = S.ActOnFinishFullStmt(Fallthrough.get());
if (Fallthrough.isInvalid())
return false;
@@ -736,15 +840,13 @@
// [dcl.fct.def.coroutine]/3
// The unqualified-id set_exception is found in the scope of P by class
// member access lookup (3.4.5).
- DeclarationName SetExDN = S.PP.getIdentifierInfo("set_exception");
- LookupResult SetExResult(S, SetExDN, Loc, Sema::LookupMemberName);
- if (S.LookupQualifiedName(SetExResult, PromiseRecordDecl)) {
+ if (lookupMember(S, "set_exception", PromiseRecordDecl, Loc)) {
// Form the call 'p.set_exception(std::current_exception())'
SetException = buildStdCurrentExceptionCall(S, Loc);
if (SetException.isInvalid())
return false;
Expr *E = SetException.get();
- SetException = buildPromiseCall(S, &Fn, Loc, "set_exception", E);
+ SetException = buildPromiseCall(S, Fn.CoroutinePromise, Loc, "set_exception", E);
SetException = S.ActOnFinishFullExpr(SetException.get(), Loc);
if (SetException.isInvalid())
return false;
@@ -759,7 +861,7 @@
// Build implicit 'p.get_return_object()' expression and form initialization
// of return type from it.
ExprResult ReturnObject =
- buildPromiseCall(S, &Fn, Loc, "get_return_object", None);
+ buildPromiseCall(S, Fn.CoroutinePromise, Loc, "get_return_object", None);
if (ReturnObject.isInvalid())
return false;
QualType RetType = FD.getReturnType();
@@ -783,3 +885,10 @@
// FIXME: Perform move-initialization of parameters into frame-local copies.
return true;
}
+
+StmtResult Sema::BuildCoroutineBodyStmt(CoroutineBodyStmt::CtorArgs Args) {
+ CoroutineBodyStmt *Res = CoroutineBodyStmt::Create(Context, Args);
+ if (!Res)
+ return StmtError();
+ return Res;
+}