Implement CapturedStmt AST

CapturedStmt can be used to implement generic function outlining as described in
http://lists.cs.uiuc.edu/pipermail/cfe-dev/2013-January/027540.html.

CapturedStmt is not exposed through the C API.

Serialization, template instantiation, and code generation support are still
pending; those code paths currently hit llvm_unreachable.

Author: Wei Pan <wei.pan@intel.com>

Differential Revision: http://llvm-reviews.chandlerc.com/D370


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@179615 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp
index 2ae5a12..e120c6a 100644
--- a/lib/AST/Stmt.cpp
+++ b/lib/AST/Stmt.cpp
@@ -1023,3 +1023,105 @@
                                        Stmt *Block) {
   return new(C)SEHFinallyStmt(Loc,Block);
 }
+
+/// Returns the byte offset of the first trailing Capture object in a
+/// CapturedStmt with \p NumCaptures captures (see the layout in Create()).
+static unsigned getCaptureArrayOffset(unsigned NumCaptures) {
+  unsigned Size = sizeof(CapturedStmt) + sizeof(Stmt *) * (NumCaptures + 1);
+  return llvm::RoundUpToAlignment(Size, llvm::alignOf<CapturedStmt::Capture>());
+}
+
+/// Returns the total allocation size for a CapturedStmt with \p NumCaptures.
+static unsigned getCapturedStmtAllocSize(unsigned NumCaptures) {
+  return getCaptureArrayOffset(NumCaptures) +
+         sizeof(CapturedStmt::Capture) * NumCaptures;
+}
+
+/// Returns the trailing Capture array, at the offset used for allocation.
+CapturedStmt::Capture *CapturedStmt::getStoredCaptures() const {
+  char *Base = reinterpret_cast<char *>(const_cast<CapturedStmt *>(this));
+  return reinterpret_cast<Capture *>(Base +
+                                     getCaptureArrayOffset(NumCaptures));
+}
+
+/// Constructs a CapturedStmt in memory allocated by Create(), copying the
+/// capture initializers, the captured statement, and the Capture objects.
+CapturedStmt::CapturedStmt(Stmt *S, ArrayRef<Capture> Captures,
+                           ArrayRef<Expr *> CaptureInits,
+                           FunctionDecl *FD,
+                           RecordDecl *RD)
+  : Stmt(CapturedStmtClass), NumCaptures(Captures.size()),
+    TheFuncDecl(FD), TheRecordDecl(RD) {
+  assert(S && "null captured statement");
+  assert(FD && "null function declaration for captured statement");
+  assert(RD && "null record declaration for captured statement");
+  assert(CaptureInits.size() == NumCaptures && "wrong number of initializers");
+
+  // Copy initialization expressions.
+  Stmt **Stored = getStoredStmts();
+  for (unsigned I = 0, N = NumCaptures; I != N; ++I)
+    *Stored++ = CaptureInits[I];
+
+  // Copy the statement being captured.
+  *Stored = S;
+
+  // Copy all Capture objects.
+  Capture *Buffer = getStoredCaptures();
+  std::copy(Captures.begin(), Captures.end(), Buffer);
+}
+
+/// Constructs an empty CapturedStmt for deserialization; only the slot for
+/// the captured statement is cleared here.
+CapturedStmt::CapturedStmt(EmptyShell Empty, unsigned NumCaptures)
+  : Stmt(CapturedStmtClass, Empty), NumCaptures(NumCaptures),
+    TheFuncDecl(0), TheRecordDecl(0) {
+  getStoredStmts()[NumCaptures] = 0;
+}
+
+CapturedStmt *CapturedStmt::Create(ASTContext &Context, Stmt *S,
+                                   ArrayRef<Capture> Captures,
+                                   ArrayRef<Expr *> CaptureInits,
+                                   FunctionDecl *FD,
+                                   RecordDecl *RD) {
+  // The layout is
+  //
+  // -----------------------------------------------------------
+  // | CapturedStmt, Init, ..., Init, S, Capture, ..., Capture |
+  // ----------------^-------------------^----------------------
+  //                 getStoredStmts()    getStoredCaptures()
+  //
+  // where S is the statement being captured.
+  //
+  assert(CaptureInits.size() == Captures.size() && "wrong number of arguments");
+
+  void *Mem = Context.Allocate(getCapturedStmtAllocSize(Captures.size()));
+  return new (Mem) CapturedStmt(S, Captures, CaptureInits, FD, RD);
+}
+
+CapturedStmt *CapturedStmt::CreateDeserialized(ASTContext &Context,
+                                               unsigned NumCaptures) {
+  void *Mem = Context.Allocate(getCapturedStmtAllocSize(NumCaptures));
+  return new (Mem) CapturedStmt(EmptyShell(), NumCaptures);
+}
+
+Stmt::child_range CapturedStmt::children() {
+  // Children are captured field initializers and the statement being captured.
+  return child_range(getStoredStmts(), getStoredStmts() + NumCaptures + 1);
+}
+
+/// Returns true if \p Var is captured (by pointer identity) by this statement.
+bool CapturedStmt::capturesVariable(const VarDecl *Var) const {
+  for (capture_iterator I = capture_begin(),
+                        E = capture_end(); I != E; ++I) {
+    if (I->capturesThis())
+      continue;
+
+    // This does not handle variable redeclarations. This should be
+    // extended to capture variables with redeclarations, for example
+    // a thread-private variable in OpenMP.
+    if (I->getCapturedVar() == Var)
+      return true;
+  }
+
+  return false;
+}
diff --git a/lib/AST/StmtPrinter.cpp b/lib/AST/StmtPrinter.cpp
index a86159f..469c284 100644
--- a/lib/AST/StmtPrinter.cpp
+++ b/lib/AST/StmtPrinter.cpp
@@ -450,6 +450,10 @@
     Indent() << "}\n";
 }
 
+void StmtPrinter::VisitCapturedStmt(CapturedStmt *Node) {
+  PrintStmt(Node->getCapturedStmt()); // Only the wrapped statement is printed.
+}
+
 void StmtPrinter::VisitObjCAtTryStmt(ObjCAtTryStmt *Node) {
   Indent() << "@try";
   if (CompoundStmt *TS = dyn_cast<CompoundStmt>(Node->getTryBody())) {
diff --git a/lib/AST/StmtProfile.cpp b/lib/AST/StmtProfile.cpp
index 5525018..d99400c 100644
--- a/lib/AST/StmtProfile.cpp
+++ b/lib/AST/StmtProfile.cpp
@@ -215,6 +215,10 @@
   VisitStmt(S);
 }
 
+void StmtProfiler::VisitCapturedStmt(const CapturedStmt *S) {
+  VisitStmt(S); // No CapturedStmt-specific data is profiled.
+}
+
 void StmtProfiler::VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
   VisitStmt(S);
 }
diff --git a/lib/CodeGen/CGStmt.cpp b/lib/CodeGen/CGStmt.cpp
index 3153ca8..d10818c 100644
--- a/lib/CodeGen/CGStmt.cpp
+++ b/lib/CodeGen/CGStmt.cpp
@@ -134,7 +134,8 @@
   case Stmt::SwitchStmtClass:   EmitSwitchStmt(cast<SwitchStmt>(*S));     break;
   case Stmt::GCCAsmStmtClass:   // Intentional fall-through.
   case Stmt::MSAsmStmtClass:    EmitAsmStmt(cast<AsmStmt>(*S));           break;
+  case Stmt::CapturedStmtClass: EmitCapturedStmt(cast<CapturedStmt>(*S)); break;
 
   case Stmt::ObjCAtTryStmtClass:
     EmitObjCAtTryStmt(cast<ObjCAtTryStmt>(*S));
     break;
@@ -1735,3 +1736,9 @@
     EmitStoreThroughLValue(RValue::get(Tmp), ResultRegDests[i]);
   }
 }
+
+/// Code generation for CapturedStmt is not implemented yet; reaching this
+/// function triggers llvm_unreachable.
+void CodeGenFunction::EmitCapturedStmt(const CapturedStmt &S) {
+  llvm_unreachable("not implemented yet");
+}
diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h
index 645d5ff..941eebe 100644
--- a/lib/CodeGen/CodeGenFunction.h
+++ b/lib/CodeGen/CodeGenFunction.h
@@ -2133,6 +2133,7 @@
   void EmitCaseStmt(const CaseStmt &S);
   void EmitCaseStmtRange(const CaseStmt &S);
   void EmitAsmStmt(const AsmStmt &S);
+  void EmitCapturedStmt(const CapturedStmt &S); // Not implemented yet.
 
   void EmitObjCForCollectionStmt(const ObjCForCollectionStmt &S);
   void EmitObjCAtTryStmt(const ObjCAtTryStmt &S);
diff --git a/lib/Sema/TreeTransform.h b/lib/Sema/TreeTransform.h
index b4083e9..55f1587 100644
--- a/lib/Sema/TreeTransform.h
+++ b/lib/Sema/TreeTransform.h
@@ -9377,6 +9377,12 @@
                                             /*TemplateArgs*/ 0);
 }
 
+template<typename Derived>
+StmtResult
+TreeTransform<Derived>::TransformCapturedStmt(CapturedStmt *S) {
+  llvm_unreachable("not implemented yet"); // Template support is pending.
+}
+
 } // end namespace clang
 
 #endif // LLVM_CLANG_SEMA_TREETRANSFORM_H
diff --git a/lib/Serialization/ASTReaderStmt.cpp b/lib/Serialization/ASTReaderStmt.cpp
index 567d50e..b18114f 100644
--- a/lib/Serialization/ASTReaderStmt.cpp
+++ b/lib/Serialization/ASTReaderStmt.cpp
@@ -324,6 +324,10 @@
   VisitStmt(S);
 }
 
+void ASTStmtReader::VisitCapturedStmt(CapturedStmt *S) {
+  llvm_unreachable("not implemented yet"); // No reader support yet.
+}
+
 void ASTStmtReader::VisitExpr(Expr *E) {
   VisitStmt(E);
   E->setType(Reader.readType(F, Record, Idx));
@@ -1724,6 +1728,10 @@
       S = new (Context) MSAsmStmt(Empty);
       break;
 
+    case STMT_CAPTURED:
+      // Deserializing a CapturedStmt is not implemented yet.
+      llvm_unreachable("not implemented yet");
+
     case EXPR_PREDEFINED:
       S = new (Context) PredefinedExpr(Empty);
       break;
diff --git a/lib/Serialization/ASTWriterStmt.cpp b/lib/Serialization/ASTWriterStmt.cpp
index 920730f..61ddec0 100644
--- a/lib/Serialization/ASTWriterStmt.cpp
+++ b/lib/Serialization/ASTWriterStmt.cpp
@@ -255,6 +255,13 @@
   Code = serialization::STMT_MSASM;
 }
 
+void ASTStmtWriter::VisitCapturedStmt(CapturedStmt *S) {
+  VisitStmt(S);
+  Code = serialization::STMT_CAPTURED;
+  // No CapturedStmt-specific data is written yet.
+  llvm_unreachable("not implemented yet");
+}
+
 void ASTStmtWriter::VisitExpr(Expr *E) {
   VisitStmt(E);
   Writer.AddTypeRef(E->getType(), Record);
diff --git a/lib/StaticAnalyzer/Core/ExprEngine.cpp b/lib/StaticAnalyzer/Core/ExprEngine.cpp
index cf75deb..4759b51 100644
--- a/lib/StaticAnalyzer/Core/ExprEngine.cpp
+++ b/lib/StaticAnalyzer/Core/ExprEngine.cpp
@@ -656,6 +656,7 @@
     case Stmt::SwitchStmtClass:
     case Stmt::WhileStmtClass:
     case Expr::MSDependentExistsStmtClass:
+    case Stmt::CapturedStmtClass: // Must not reach the analyzer loop.
       llvm_unreachable("Stmt should not be in analyzer evaluation loop");
 
     case Stmt::ObjCSubscriptRefExprClass: