[coroutines] Handle get_return_object_on_allocation_failure

Summary:
If promise_type has get_return_object_on_allocation_failure defined,
check if an allocation function returns nullptr, and if so,
return the result of get_return_object_on_allocation_failure().

Reviewers: rsmith, EricWF

Reviewed By: EricWF

Subscribers: mehdi_amini, cfe-commits

Differential Revision: https://reviews.llvm.org/D31399

llvm-svn: 298891
diff --git a/clang/lib/Sema/SemaCoroutine.cpp b/clang/lib/Sema/SemaCoroutine.cpp
index 309194f..d8e7b9b 100644
--- a/clang/lib/Sema/SemaCoroutine.cpp
+++ b/clang/lib/Sema/SemaCoroutine.cpp
@@ -708,8 +708,8 @@
     }
     this->IsValid = makePromiseStmt() && makeInitialAndFinalSuspend() &&
                     makeOnException() && makeOnFallthrough() &&
-                    makeNewAndDeleteExpr() && makeReturnObject() &&
-                    makeParamMoves();
+                    makeReturnOnAllocFailure() && makeNewAndDeleteExpr() &&
+                    makeReturnObject() && makeParamMoves();
   }
 
   bool isInvalid() const { return !this->IsValid; }
@@ -720,6 +720,7 @@
   bool makeOnFallthrough();
   bool makeOnException();
   bool makeReturnObject();
+  bool makeReturnOnAllocFailure();
   bool makeParamMoves();
 };
 }
@@ -777,6 +778,66 @@
   return true;
 }
 
+// Returns true iff E names a static method; otherwise emits an error + note.
+static bool diagReturnOnAllocFailure(Sema &S, Expr *E,
+                                     CXXRecordDecl *PromiseRecordDecl,
+                                     FunctionScopeInfo &Fn) {
+  auto Loc = E->getExprLoc();
+  if (auto *DeclRef = dyn_cast<DeclRefExpr>(E)) {
+    auto *Decl = DeclRef->getDecl();
+    if (CXXMethodDecl *Method = dyn_cast_or_null<CXXMethodDecl>(Decl)) {
+      if (Method->isStatic())
+        return true;
+      Loc = Decl->getLocation(); // Report at the non-static declaration.
+    }
+  }
+
+  S.Diag(
+      Loc,
+      diag::err_coroutine_promise_get_return_object_on_allocation_failure)
+      << PromiseRecordDecl;
+  S.Diag(Fn.FirstCoroutineStmtLoc, diag::note_declared_coroutine_here)
+      << Fn.getFirstCoroutineStmtKeyword();
+  return false;
+}
+
+bool SubStmtBuilder::makeReturnOnAllocFailure() {
+  if (!PromiseRecordDecl)
+    return true;
+
+  // [dcl.fct.def.coroutine]/8: get_return_object_on_allocation_failure is
+  // looked up in the scope of class P by class member access lookup (3.4.5).
+  // If an allocation function returns nullptr, the coroutine return value is
+  // obtained by a call to get_return_object_on_allocation_failure().
+  DeclarationName Name =
+      S.PP.getIdentifierInfo("get_return_object_on_allocation_failure");
+  LookupResult Members(S, Name, Loc, Sema::LookupMemberName);
+  // Stay quiet if a private member is selected; building the call below
+  // produces the same warnings again.
+  Members.suppressDiagnostics();
+  if (!S.LookupQualifiedName(Members, PromiseRecordDecl))
+    return true; // The member is optional.
+
+  CXXScopeSpec NoQualifier;
+  ExprResult Callee =
+      S.BuildDeclarationNameExpr(NoQualifier, Members, /*NeedsADL=*/false);
+  // Only a static member function is accepted; anything else is diagnosed.
+  if (Callee.isInvalid() ||
+      !diagReturnOnAllocFailure(S, Callee.get(), PromiseRecordDecl, Fn))
+    return false;
+
+  ExprResult Call = S.ActOnCallExpr(nullptr, Callee.get(), Loc, {}, Loc);
+  if (Call.isInvalid())
+    return false;
+
+  StmtResult OnFailure = S.ActOnReturnStmt(Loc, Call.get(), S.getCurScope());
+  if (OnFailure.isInvalid())
+    return false;
+
+  this->ReturnStmtOnAllocFailure = OnFailure.get();
+  return true;
+}
+
 bool SubStmtBuilder::makeNewAndDeleteExpr() {
   // Form and check allocation and deallocation calls.
   QualType PromiseType = Fn.CoroutinePromise->getType();
@@ -786,7 +847,8 @@
   if (S.RequireCompleteType(Loc, PromiseType, diag::err_incomplete_type))
     return false;
 
-  // FIXME: Add support for get_return_object_on_allocation failure.
+  // FIXME: Pass a std::nothrow_t placement argument to the global allocation
+  //        function when ReturnStmtOnAllocFailure is non-null.
   // FIXME: Add support for stateful allocators.
 
   FunctionDecl *OperatorNew = nullptr;