AST, Sema, Serialization: add CUDAKernelCallExpr and related semantic actions

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@125217 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AST/ExprClassification.cpp b/lib/AST/ExprClassification.cpp
index ba3b88b..593f783 100644
--- a/lib/AST/ExprClassification.cpp
+++ b/lib/AST/ExprClassification.cpp
@@ -238,6 +238,7 @@
   case Expr::CallExprClass:
   case Expr::CXXOperatorCallExprClass:
   case Expr::CXXMemberCallExprClass:
+  case Expr::CUDAKernelCallExprClass:
     return ClassifyUnnamed(Ctx, cast<CallExpr>(E)->getCallReturnType());
 
     // __builtin_choose_expr is equivalent to the chosen expression.
diff --git a/lib/AST/ExprConstant.cpp b/lib/AST/ExprConstant.cpp
index 227c60e..0c3f647 100644
--- a/lib/AST/ExprConstant.cpp
+++ b/lib/AST/ExprConstant.cpp
@@ -2599,6 +2599,7 @@
   case Expr::AddrLabelExprClass:
   case Expr::StmtExprClass:
   case Expr::CXXMemberCallExprClass:
+  case Expr::CUDAKernelCallExprClass:
   case Expr::CXXDynamicCastExprClass:
   case Expr::CXXTypeidExprClass:
   case Expr::CXXUuidofExprClass:
diff --git a/lib/AST/ItaniumMangle.cpp b/lib/AST/ItaniumMangle.cpp
index 0e05041..6608571 100644
--- a/lib/AST/ItaniumMangle.cpp
+++ b/lib/AST/ItaniumMangle.cpp
@@ -1732,7 +1732,8 @@
   case Expr::BinaryTypeTraitExprClass:
   case Expr::VAArgExprClass:
   case Expr::CXXUuidofExprClass:
-  case Expr::CXXNoexceptExprClass: {
+  case Expr::CXXNoexceptExprClass:
+  case Expr::CUDAKernelCallExprClass: {
     // As bad as this diagnostic is, it's better than crashing.
     Diagnostic &Diags = Context.getDiags();
     unsigned DiagID = Diags.getCustomDiagID(Diagnostic::Error,
diff --git a/lib/AST/StmtPrinter.cpp b/lib/AST/StmtPrinter.cpp
index 5d31fd6..fa1736f 100644
--- a/lib/AST/StmtPrinter.cpp
+++ b/lib/AST/StmtPrinter.cpp
@@ -961,6 +961,15 @@
   VisitCallExpr(cast<CallExpr>(Node));
 }
 
+void StmtPrinter::VisitCUDAKernelCallExpr(CUDAKernelCallExpr *Node) {
+  PrintExpr(Node->getCallee());      // callee expression comes first
+  OS << "<<<";
+  PrintCallArgs(Node->getConfig());  // arguments of the config call
+  OS << ">>>(";
+  PrintCallArgs(Node);               // arguments of the kernel call itself
+  OS << ")";
+}
+
 void StmtPrinter::VisitCXXNamedCastExpr(CXXNamedCastExpr *Node) {
   OS << Node->getCastName() << '<';
   OS << Node->getTypeAsWritten().getAsString(Policy) << ">(";
diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp
index 842a2d9..707cac4 100644
--- a/lib/AST/StmtProfile.cpp
+++ b/lib/AST/StmtProfile.cpp
@@ -644,6 +644,10 @@
   VisitCallExpr(S);
 }
 
+void StmtProfiler::VisitCUDAKernelCallExpr(CUDAKernelCallExpr *S) {
+  VisitCallExpr(S);  // profiled via the CallExpr visitor only
+}
+
 void StmtProfiler::VisitCXXNamedCastExpr(CXXNamedCastExpr *S) {
   VisitExplicitCastExpr(S);
 }
diff --git a/lib/Sema/SemaExpr.cpp b/lib/Sema/SemaExpr.cpp
index 429804d..fcc56f1 100644
--- a/lib/Sema/SemaExpr.cpp
+++ b/lib/Sema/SemaExpr.cpp
@@ -4362,7 +4362,8 @@
 /// locations.
 ExprResult
 Sema::ActOnCallExpr(Scope *S, Expr *Fn, SourceLocation LParenLoc,
-                    MultiExprArg args, SourceLocation RParenLoc) {
+                    MultiExprArg args, SourceLocation RParenLoc,
+                    Expr *ExecConfig) {
   unsigned NumArgs = args.size();
 
   // Since this might be a postfix expression, get rid of ParenListExprs.
@@ -4399,10 +4400,17 @@
     else if (Expr::hasAnyTypeDependentArguments(Args, NumArgs))
       Dependent = true;
 
-    if (Dependent)
-      return Owned(new (Context) CallExpr(Context, Fn, Args, NumArgs,
-                                          Context.DependentTy, VK_RValue,
-                                          RParenLoc));
+    if (Dependent) {
+      if (ExecConfig) {
+        return Owned(new (Context) CUDAKernelCallExpr(
+            Context, Fn, cast<CallExpr>(ExecConfig), Args, NumArgs,
+            Context.DependentTy, VK_RValue, RParenLoc));
+      } else {
+        return Owned(new (Context) CallExpr(Context, Fn, Args, NumArgs,
+                                            Context.DependentTy, VK_RValue,
+                                            RParenLoc));
+      }
+    }
 
     // Determine whether this is a call to an object (C++ [over.call.object]).
     if (Fn->getType()->isRecordType())
@@ -4495,7 +4503,7 @@
   if (isa<UnresolvedLookupExpr>(NakedFn)) {
     UnresolvedLookupExpr *ULE = cast<UnresolvedLookupExpr>(NakedFn);
     return BuildOverloadedCallExpr(S, Fn, ULE, LParenLoc, Args, NumArgs,
-                                   RParenLoc);
+                                   RParenLoc, ExecConfig);
   }
 
   NamedDecl *NDecl = 0;
@@ -4506,7 +4514,23 @@
   if (isa<DeclRefExpr>(NakedFn))
     NDecl = cast<DeclRefExpr>(NakedFn)->getDecl();
 
-  return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, Args, NumArgs, RParenLoc);
+  return BuildResolvedCallExpr(Fn, NDecl, LParenLoc, Args, NumArgs, RParenLoc,
+                               ExecConfig);
+}
+
+/// Builds the call to the cudaConfigureCall declaration for an exec config.
+ExprResult
+Sema::ActOnCUDAExecConfigExpr(Scope *S, SourceLocation LLLLoc,
+                              MultiExprArg execConfig, SourceLocation GGGLoc) {
+  FunctionDecl *ConfigureFn = Context.getcudaConfigureCallDecl();
+  if (!ConfigureFn) // No declaration available: report it as undeclared.
+    return ExprError(Diag(LLLLoc, diag::err_undeclared_var_use)
+                          << "cudaConfigureCall");
+
+  // Reference the declaration at LLLLoc, then call it with execConfig.
+  DeclRefExpr *ConfigureRef = new (Context) DeclRefExpr(
+      ConfigureFn, ConfigureFn->getType(), VK_LValue, LLLLoc);
+  return ActOnCallExpr(S, ConfigureRef, LLLLoc, execConfig, GGGLoc, 0);
 }
 
 /// BuildResolvedCallExpr - Build a call to a resolved expression,
@@ -4519,7 +4543,8 @@
 Sema::BuildResolvedCallExpr(Expr *Fn, NamedDecl *NDecl,
                             SourceLocation LParenLoc,
                             Expr **Args, unsigned NumArgs,
-                            SourceLocation RParenLoc) {
+                            SourceLocation RParenLoc,
+                            Expr *Config) {
   FunctionDecl *FDecl = dyn_cast_or_null<FunctionDecl>(NDecl);
 
   // Promote the function operand.
@@ -4527,11 +4552,21 @@
 
   // Make the call expr early, before semantic checks.  This guarantees cleanup
   // of arguments and function on error.
-  CallExpr *TheCall = new (Context) CallExpr(Context, Fn,
-                                             Args, NumArgs,
-                                             Context.BoolTy,
-                                             VK_RValue,
-                                             RParenLoc);
+  CallExpr *TheCall;
+  if (Config) {
+    TheCall = new (Context) CUDAKernelCallExpr(Context, Fn,
+                                               cast<CallExpr>(Config),
+                                               Args, NumArgs,
+                                               Context.BoolTy,
+                                               VK_RValue,
+                                               RParenLoc);
+  } else {
+    TheCall = new (Context) CallExpr(Context, Fn,
+                                     Args, NumArgs,
+                                     Context.BoolTy,
+                                     VK_RValue,
+                                     RParenLoc);
+  }
 
   const FunctionType *FuncT;
   if (!Fn->getType()->isBlockPointerType()) {
diff --git a/lib/Sema/SemaOverload.cpp b/lib/Sema/SemaOverload.cpp
index 916c5a1..42e2411 100644
--- a/lib/Sema/SemaOverload.cpp
+++ b/lib/Sema/SemaOverload.cpp
@@ -7426,7 +7426,8 @@
 Sema::BuildOverloadedCallExpr(Scope *S, Expr *Fn, UnresolvedLookupExpr *ULE,
                               SourceLocation LParenLoc,
                               Expr **Args, unsigned NumArgs,
-                              SourceLocation RParenLoc) {
+                              SourceLocation RParenLoc,
+                              Expr *ExecConfig) {
 #ifndef NDEBUG
   if (ULE->requiresADL()) {
     // To do ADL, we must have found an unqualified name.
@@ -7466,8 +7467,8 @@
     DiagnoseUseOfDecl(FDecl? FDecl : Best->FoundDecl.getDecl(),
                       ULE->getNameLoc());
     Fn = FixOverloadedFunctionReference(Fn, Best->FoundDecl, FDecl);
-    return BuildResolvedCallExpr(Fn, FDecl, LParenLoc, Args, NumArgs,
-                                 RParenLoc);
+    return BuildResolvedCallExpr(Fn, FDecl, LParenLoc, Args, NumArgs, RParenLoc,
+                                 ExecConfig);
   }
 
   case OR_No_Viable_Function:
diff --git a/lib/Sema/TreeTransform.h b/lib/Sema/TreeTransform.h
index bd110bd..ca5d1c1 100644
--- a/lib/Sema/TreeTransform.h
+++ b/lib/Sema/TreeTransform.h
@@ -1348,9 +1348,10 @@
   /// Subclasses may override this routine to provide different behavior.
   ExprResult RebuildCallExpr(Expr *Callee, SourceLocation LParenLoc,
                                    MultiExprArg Args,
-                                   SourceLocation RParenLoc) {
+                                   SourceLocation RParenLoc,
+                                   Expr *ExecConfig = 0) {
     return getSema().ActOnCallExpr(/*Scope=*/0, Callee, LParenLoc,
-                                   move(Args), RParenLoc);
+                                   move(Args), RParenLoc, ExecConfig);
   }
 
   /// \brief Build a new member access expression.
@@ -5921,6 +5922,39 @@
 
 template<typename Derived>
 ExprResult
+TreeTransform<Derived>::TransformCUDAKernelCallExpr(CUDAKernelCallExpr *E) {
+  // Transform the callee.
+  ExprResult Callee = getDerived().TransformExpr(E->getCallee());
+  if (Callee.isInvalid())
+    return ExprError();
+
+  // Transform the execution configuration, which is itself a call expression.
+  ExprResult EC = getDerived().TransformCallExpr(E->getConfig());
+  if (EC.isInvalid())
+    return ExprError();
+
+  // Transform the kernel arguments.
+  bool ArgChanged = false;
+  ASTOwningVector<Expr*> Args(SemaRef);
+  if (getDerived().TransformExprs(E->getArgs(), E->getNumArgs(), true, Args,
+                                  &ArgChanged))
+    return ExprError();
+
+  // Reuse E only when callee, exec config and arguments are all unchanged.
+  if (!getDerived().AlwaysRebuild() && Callee.get() == E->getCallee() &&
+      EC.get() == E->getConfig() && !ArgChanged)
+    return SemaRef.Owned(E);
+
+  // FIXME: Wrong source location information for the '('.
+  SourceLocation FakeLParenLoc
+    = ((Expr *)Callee.get())->getSourceRange().getBegin();
+  return getDerived().RebuildCallExpr(Callee.get(), FakeLParenLoc,
+                                      move_arg(Args),
+                                      E->getRParenLoc(), EC.get());
+}
+
+template<typename Derived>
+ExprResult
 TreeTransform<Derived>::TransformCXXNamedCastExpr(CXXNamedCastExpr *E) {
   TypeSourceInfo *Type = getDerived().TransformType(E->getTypeInfoAsWritten());
   if (!Type)
diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp
index 8bd94b4..864c042 100644
--- a/lib/Serialization/ASTReaderStmt.cpp
+++ b/lib/Serialization/ASTReaderStmt.cpp
@@ -182,6 +182,9 @@
     void VisitSubstNonTypeTemplateParmPackExpr(
                                            SubstNonTypeTemplateParmPackExpr *E);
     void VisitOpaqueValueExpr(OpaqueValueExpr *E);
+    
+    // CUDA Expressions
+    void VisitCUDAKernelCallExpr(CUDAKernelCallExpr *E);
   };
 }
 
@@ -1323,6 +1326,15 @@
   E->Loc = ReadSourceLocation(Record, Idx);
 }
 
+//===----------------------------------------------------------------------===//
+// CUDA Expressions and Statements
+//===----------------------------------------------------------------------===//
+
+void ASTStmtReader::VisitCUDAKernelCallExpr(CUDAKernelCallExpr *E) {
+  VisitCallExpr(E);  // read the CallExpr part first
+  E->setConfig(cast<CallExpr>(Reader.ReadSubExpr()));  // then the exec config
+}
+
 Stmt *ASTReader::ReadStmt(PerFileData &F) {
   switch (ReadingKind) {
   case Read_Decl:
@@ -1872,6 +1884,10 @@
     case EXPR_OPAQUE_VALUE:
       S = new (Context) OpaqueValueExpr(Empty);
       break;
+
+    case EXPR_CUDA_KERNEL_CALL:
+      S = new (Context) CUDAKernelCallExpr(*Context, Empty);
+      break;
     }
     
     // We hit a STMT_STOP, so we're done with this expression.
diff --git a/lib/Serialization/ASTWriter.cpp b/lib/Serialization/ASTWriter.cpp
index a5af03c..624cf10 100644
--- a/lib/Serialization/ASTWriter.cpp
+++ b/lib/Serialization/ASTWriter.cpp
@@ -691,6 +691,7 @@
   RECORD(EXPR_PACK_EXPANSION);
   RECORD(EXPR_SIZEOF_PACK);
   RECORD(EXPR_SUBST_NON_TYPE_TEMPLATE_PARM_PACK);
+  RECORD(EXPR_CUDA_KERNEL_CALL);
 #undef RECORD
 }
 
diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp
index d721c18..8a90ef1 100644
--- a/lib/Serialization/ASTWriterStmt.cpp
+++ b/lib/Serialization/ASTWriterStmt.cpp
@@ -156,6 +156,9 @@
     void VisitSubstNonTypeTemplateParmPackExpr(
                                            SubstNonTypeTemplateParmPackExpr *E);
     void VisitOpaqueValueExpr(OpaqueValueExpr *E);
+
+    // CUDA Expressions
+    void VisitCUDAKernelCallExpr(CUDAKernelCallExpr *E);
   };
 }
 
@@ -1325,6 +1328,16 @@
 }
 
 //===----------------------------------------------------------------------===//
+// CUDA Expressions and Statements.
+//===----------------------------------------------------------------------===//
+
+void ASTStmtWriter::VisitCUDAKernelCallExpr(CUDAKernelCallExpr *E) {
+  VisitCallExpr(E);                // write the CallExpr part first
+  Writer.AddStmt(E->getConfig());  // then the exec config sub-expression
+  Code = serialization::EXPR_CUDA_KERNEL_CALL;  // assigned last for this node
+}
+
+//===----------------------------------------------------------------------===//
 // ASTWriter Implementation
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/StaticAnalyzer/Checkers/ExprEngine.cpp b/lib/StaticAnalyzer/Checkers/ExprEngine.cpp
index e01d171..0b202a1 100644
--- a/lib/StaticAnalyzer/Checkers/ExprEngine.cpp
+++ b/lib/StaticAnalyzer/Checkers/ExprEngine.cpp
@@ -894,6 +894,7 @@
     case Stmt::PredefinedExprClass:
     case Stmt::ShuffleVectorExprClass:
     case Stmt::VAArgExprClass:
+    case Stmt::CUDAKernelCallExprClass:
         // Fall through.
 
     // Cases we intentionally don't evaluate, since they don't need