CodeGen for CapturedStmts

EmitCapturedStmt creates a captured struct containing all of the captured
variables, and then emits a call to the outlined function.  This is similar in
principle to EmitBlockLiteral.

GenerateCapturedFunction actually produces the outlined function.  It is based
on GenerateBlockFunction, but is much simpler.  The function type is determined
by the parameters that are in the CapturedDecl.

This patch also includes some changes that were reviewed as part of the
serialization patch and the patch moving the parameters to the captured decl.

Differential Revision: http://llvm-reviews.chandlerc.com/D640


git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@181536 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/CGExpr.cpp b/lib/CodeGen/CGExpr.cpp
index 64670c5..8ba5a1c 100644
--- a/lib/CodeGen/CGExpr.cpp
+++ b/lib/CodeGen/CGExpr.cpp
@@ -1793,6 +1793,13 @@
   return CGF.MakeAddrLValue(V, E->getType(), Alignment);
 }
 
+static LValue EmitCapturedFieldLValue(CodeGenFunction &CGF, const FieldDecl *FD,
+                                      llvm::Value *ThisValue) {
+  QualType TagType = CGF.getContext().getTagDeclType(FD->getParent());
+  LValue LV = CGF.MakeNaturalAlignAddrLValue(ThisValue, TagType);
+  return CGF.EmitLValueForField(LV, FD); // FD within the record at ThisValue
+}
+
 LValue CodeGenFunction::EmitDeclRefLValue(const DeclRefExpr *E) {
   const NamedDecl *ND = E->getDecl();
   CharUnits Alignment = getContext().getDeclAlign(ND);
@@ -1844,10 +1851,11 @@
     // Use special handling for lambdas.
     if (!V) {
       if (FieldDecl *FD = LambdaCaptureFields.lookup(VD)) {
-        QualType LambdaTagType = getContext().getTagDeclType(FD->getParent());
-        LValue LambdaLV = MakeNaturalAlignAddrLValue(CXXABIThisValue,
-                                                     LambdaTagType);
-        return EmitLValueForField(LambdaLV, FD);
+        return EmitCapturedFieldLValue(*this, FD, CXXABIThisValue);
+      } else if (CapturedStmtInfo) {
+        if (const FieldDecl *FD = CapturedStmtInfo->lookup(VD))
+          return EmitCapturedFieldLValue(*this, FD,
+                                         CapturedStmtInfo->getContextValue());
       }
 
       assert(isa<BlockDecl>(CurCodeDecl) && E->refersToEnclosingLocal());
diff --git a/lib/CodeGen/CGStmt.cpp b/lib/CodeGen/CGStmt.cpp
index 208455c..353d6c3 100644
--- a/lib/CodeGen/CGStmt.cpp
+++ b/lib/CodeGen/CGStmt.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/InlineAsm.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/Support/CallSite.h"
 using namespace clang;
 using namespace CodeGen;
 
@@ -137,7 +138,7 @@
   case Stmt::GCCAsmStmtClass:   // Intentional fall-through.
   case Stmt::MSAsmStmtClass:    EmitAsmStmt(cast<AsmStmt>(*S));           break;
   case Stmt::CapturedStmtClass:
-    EmitCapturedStmt(cast<CapturedStmt>(*S));
+    EmitCapturedStmt(cast<CapturedStmt>(*S), CR_Default);
     break;
   case Stmt::ObjCAtTryStmtClass:
     EmitObjCAtTryStmt(cast<ObjCAtTryStmt>(*S));
@@ -1750,6 +1751,94 @@
   }
 }
 
-void CodeGenFunction::EmitCapturedStmt(const CapturedStmt &S) {
-  llvm_unreachable("not implemented yet");
+static LValue InitCapturedStruct(CodeGenFunction &CGF, const CapturedStmt &S) {
+  const RecordDecl *RD = S.getCapturedRecordDecl();
+  QualType RecordTy = CGF.getContext().getRecordType(RD);
+
+  // Create the "agg.captured" temporary; its fields are initialized below.
+  LValue SlotLV = CGF.MakeNaturalAlignAddrLValue(
+                    CGF.CreateMemTemp(RecordTy, "agg.captured"), RecordTy);
+
+  RecordDecl::field_iterator CurField = RD->field_begin();
+  for (CapturedStmt::capture_init_iterator I = S.capture_init_begin(),
+                                           E = S.capture_init_end();
+       I != E; ++I, ++CurField) { // Captures and fields advance in lockstep.
+    LValue LV = CGF.EmitLValueForFieldInitialization(SlotLV, *CurField);
+    CGF.EmitInitializerForField(*CurField, LV, *I, ArrayRef<VarDecl *>());
+  }
+
+  return SlotLV;
+}
+
+/// Store the captured variables into the captured struct, emit the outlined
+/// function for the body of \p S, call it, and return the outlined function.
+llvm::Function *
+CodeGenFunction::EmitCapturedStmt(const CapturedStmt &S, CapturedRegionKind K) {
+  const CapturedDecl *CD = S.getCapturedDecl();
+  const RecordDecl *RD = S.getCapturedRecordDecl();
+  assert(CD->hasBody() && "missing CapturedDecl body");
+
+  LValue CapStruct = InitCapturedStruct(*this, S);
+
+  // Emit the CapturedDecl; CapInfo is declared first so that it outlives CGF.
+  CGCapturedStmtInfo CapInfo(S, K);
+  CodeGenFunction CGF(CGM, true);
+  CGF.CapturedStmtInfo = &CapInfo;
+  llvm::Function *F = CGF.GenerateCapturedStmtFunction(CD, RD);
+
+  // Emit call to the helper function.
+  EmitCallOrInvoke(F, CapStruct.getAddress());
+
+  return F;
+}
+
+/// Creates and emits the outlined function for a CapturedStmt; returns it.
+llvm::Function *
+CodeGenFunction::GenerateCapturedStmtFunction(const CapturedDecl *CD,
+                                              const RecordDecl *RD) {
+  assert(CapturedStmtInfo &&
+    "CapturedStmtInfo should be set when generating the captured function");
+
+  // Check if we should generate debug info for this function.
+  maybeInitializeDebugInfo();
+
+  // Build the argument list from the parameters of the CapturedDecl.
+  ASTContext &Ctx = CGM.getContext();
+  FunctionArgList Args;
+  Args.append(CD->param_begin(), CD->param_end());
+
+  // Create the function declaration: returns void, internal linkage.
+  FunctionType::ExtInfo ExtInfo;
+  const CGFunctionInfo &FuncInfo =
+    CGM.getTypes().arrangeFunctionDeclaration(Ctx.VoidTy, Args, ExtInfo,
+                                              /*IsVariadic=*/false);
+  llvm::FunctionType *FuncLLVMTy = CGM.getTypes().GetFunctionType(FuncInfo);
+
+  llvm::Function *F =
+    llvm::Function::Create(FuncLLVMTy, llvm::GlobalValue::InternalLinkage,
+                           CapturedStmtInfo->getHelperName(), &CGM.getModule());
+  CGM.SetInternalFunctionAttributes(CD, F, FuncInfo);
+
+  // Generate the function.
+  StartFunction(CD, Ctx.VoidTy, F, FuncInfo, Args, CD->getBody()->getLocStart());
+
+  // Load the context parameter and record its value in CapturedStmtInfo.
+  llvm::Value *DeclPtr = LocalDeclMap[CD->getContextParam()];
+  assert(DeclPtr && "missing context parameter for CapturedStmt");
+  CapturedStmtInfo->setContextValue(Builder.CreateLoad(DeclPtr));
+
+  // If 'this' is captured, load it from its field into CXXThisValue.
+  if (CapturedStmtInfo->isCXXThisExprCaptured()) {
+    FieldDecl *FD = CapturedStmtInfo->getThisFieldDecl();
+    LValue LV = MakeNaturalAlignAddrLValue(CapturedStmtInfo->getContextValue(),
+                                           Ctx.getTagDeclType(RD));
+    LValue ThisLValue = EmitLValueForField(LV, FD);
+
+    CXXThisValue = EmitLoadOfLValue(ThisLValue).getScalarVal();
+  }
+
+  CapturedStmtInfo->EmitBody(*this, CD->getBody());
+  FinishFunction(CD->getBodyRBrace());
+
+  return F;
 }
diff --git a/lib/CodeGen/CodeGenFunction.cpp b/lib/CodeGen/CodeGenFunction.cpp
index 75c60ed..071c08e 100644
--- a/lib/CodeGen/CodeGenFunction.cpp
+++ b/lib/CodeGen/CodeGenFunction.cpp
@@ -33,6 +33,7 @@
 CodeGenFunction::CodeGenFunction(CodeGenModule &cgm, bool suppressNewContext)
   : CodeGenTypeCache(cgm), CGM(cgm), Target(cgm.getTarget()),
     Builder(cgm.getModule().getContext()),
+    CapturedStmtInfo(0),
     SanitizePerformTypeCheck(CGM.getSanOpts().Null |
                              CGM.getSanOpts().Alignment |
                              CGM.getSanOpts().ObjectSize |
@@ -1447,3 +1448,5 @@
 
   return V;
 }
+
+CodeGenFunction::CGCapturedStmtInfo::~CGCapturedStmtInfo() { } // Out of line.
diff --git a/lib/CodeGen/CodeGenFunction.h b/lib/CodeGen/CodeGenFunction.h
index ff74c15..6caf168 100644
--- a/lib/CodeGen/CodeGenFunction.h
+++ b/lib/CodeGen/CodeGenFunction.h
@@ -23,6 +23,7 @@
 #include "clang/AST/ExprObjC.h"
 #include "clang/AST/Type.h"
 #include "clang/Basic/ABI.h"
+#include "clang/Basic/CapturedStmt.h"
 #include "clang/Basic/TargetInfo.h"
 #include "clang/Frontend/CodeGenOptions.h"
 #include "llvm/ADT/ArrayRef.h"
@@ -606,6 +607,65 @@
   /// we prefer to insert allocas.
   llvm::AssertingVH<llvm::Instruction> AllocaInsertPt;
 
+  /// \brief API for captured statement code generation.
+  class CGCapturedStmtInfo {
+  public:
+    explicit CGCapturedStmtInfo(const CapturedStmt &S,
+                                CapturedRegionKind K = CR_Default)
+      : Kind(K), ThisValue(0), CXXThisFieldDecl(0) {
+
+      RecordDecl::field_iterator Field =
+        S.getCapturedRecordDecl()->field_begin();
+      for (CapturedStmt::const_capture_iterator I = S.capture_begin(),
+                                                E = S.capture_end();
+           I != E; ++I, ++Field) {
+        if (I->capturesThis())
+          CXXThisFieldDecl = *Field;
+        else
+          CaptureFields[I->getCapturedVar()] = *Field;
+      }
+    }
+
+    virtual ~CGCapturedStmtInfo();
+
+    CapturedRegionKind getKind() const { return Kind; }
+
+    void setContextValue(llvm::Value *V) { ThisValue = V; }
+    /// \brief Retrieve the value of the context parameter.
+    llvm::Value *getContextValue() const { return ThisValue; }
+
+    /// \brief Lookup the captured field decl for a variable.
+    const FieldDecl *lookup(const VarDecl *VD) const {
+      return CaptureFields.lookup(VD);
+    }
+
+    bool isCXXThisExprCaptured() const { return CXXThisFieldDecl != 0; }
+    FieldDecl *getThisFieldDecl() const { return CXXThisFieldDecl; }
+
+    /// \brief Emit the captured statement body.
+    virtual void EmitBody(CodeGenFunction &CGF, Stmt *S) {
+      CGF.EmitStmt(S);
+    }
+
+    /// \brief Get the name of the capture helper.
+    virtual StringRef getHelperName() const { return "__captured_stmt"; }
+
+  private:
+    /// \brief The kind of captured statement being generated.
+    CapturedRegionKind Kind;
+
+    /// \brief Map from each captured variable to the field that stores it.
+    llvm::SmallDenseMap<const VarDecl *, FieldDecl *> CaptureFields;
+
+    /// \brief The base address of the captured record, as loaded from the
+    /// context parameter of the outlined function.
+    llvm::Value *ThisValue;
+
+    /// \brief The field holding the captured 'this', or null if not captured.
+    FieldDecl *CXXThisFieldDecl;
+  };
+  CGCapturedStmtInfo *CapturedStmtInfo;
+
   /// BoundsChecking - Emit run-time bounds checks. Higher values mean
   /// potentially higher performance penalties.
   unsigned char BoundsChecking;
@@ -2188,7 +2248,6 @@
   void EmitCaseStmt(const CaseStmt &S);
   void EmitCaseStmtRange(const CaseStmt &S);
   void EmitAsmStmt(const AsmStmt &S);
-  void EmitCapturedStmt(const CapturedStmt &S);
 
   void EmitObjCForCollectionStmt(const ObjCForCollectionStmt &S);
   void EmitObjCAtTryStmt(const ObjCAtTryStmt &S);
@@ -2204,6 +2263,10 @@
   void EmitCXXTryStmt(const CXXTryStmt &S);
   void EmitCXXForRangeStmt(const CXXForRangeStmt &S);
 
+  llvm::Function *EmitCapturedStmt(const CapturedStmt &S, CapturedRegionKind K);
+  llvm::Function *GenerateCapturedStmtFunction(const CapturedDecl *CD,
+                                               const RecordDecl *RD);
+
   //===--------------------------------------------------------------------===//
   //                         LValue Expression Emission
   //===--------------------------------------------------------------------===//