[PartialInlining] Shrinkwrap allocas with live range contained in outline region.

Differential Revision: http://reviews.llvm.org/D33618

llvm-svn: 304245
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index ed72099..24d28a6 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -27,6 +27,7 @@
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
@@ -141,16 +142,93 @@
   return false;
 }
 
-void CodeExtractor::findInputsOutputs(ValueSet &Inputs,
-                                      ValueSet &Outputs) const {
+/// Collect allocas defined outside the region whose uses, including at least
+/// one lifetime.start and one lifetime.end, all lie inside the region. Such
+/// allocas can be sunk into the outlined function instead of becoming inputs.
+///
+/// The markers may be applied to the alloca itself or to pointer casts of it
+/// (users that stripPointerCasts() folds back to the alloca); every other use
+/// must then be inside the region. Marked casts defined outside the region are
+/// sunk as well and are inserted into \p SinkCands before their alloca: the
+/// caller moves each candidate to the front of the new entry block in
+/// iteration order (assumed to be insertion order), so the alloca ends up
+/// ahead of the casts that use it.
+void CodeExtractor::findAllocas(ValueSet &SinkCands) const {
+  Function *Func = (*Blocks.begin())->getParent();
+
+  // Returns true if every user of Addr is inside the region and those users
+  // include at least one lifetime.start and one lifetime.end marker.
+  auto HasLifetimeMarkersInRegion = [&](Instruction *Addr) {
+    bool HasStart = false, HasEnd = false;
+    for (User *U : Addr->users()) {
+      if (!definedInRegion(Blocks, U))
+        return false;
+      if (auto *IntrInst = dyn_cast<IntrinsicInst>(U)) {
+        if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start)
+          HasStart = true;
+        else if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end)
+          HasEnd = true;
+      }
+    }
+    return HasStart && HasEnd;
+  };
+
+  for (BasicBlock &BB : *Func) {
+    if (Blocks.count(&BB))
+      continue;
+    for (Instruction &II : BB) {
+      auto *AI = dyn_cast<AllocaInst>(&II);
+      if (!AI)
+        continue;
+
+      // Markers applied directly to the alloca.
+      if (HasLifetimeMarkersInRegion(AI)) {
+        SinkCands.insert(AI);
+        continue;
+      }
+
+      // Otherwise look for markers on pointer casts of the alloca. All marked
+      // casts must be sunk together with it: a cast left in the caller would
+      // still reference the alloca after it moves to the outlined function.
+      SmallVector<Instruction *, 2> MarkerAddrs;
+      bool HasUnknownUse = false;
+      for (User *U : AI->users()) {
+        if (U->stripPointerCasts() == AI) {
+          auto *CastAddr = cast<Instruction>(U);
+          if (HasLifetimeMarkersInRegion(CastAddr)) {
+            MarkerAddrs.push_back(CastAddr);
+            continue;
+          }
+        }
+        // Any other use outside the region keeps the alloca in the caller.
+        if (!definedInRegion(Blocks, U)) {
+          HasUnknownUse = true;
+          break;
+        }
+      }
+      if (HasUnknownUse || MarkerAddrs.empty())
+        continue;
+
+      for (Instruction *MarkerAddr : MarkerAddrs)
+        if (!definedInRegion(Blocks, MarkerAddr))
+          SinkCands.insert(MarkerAddr);
+      SinkCands.insert(AI);
+    }
+  }
+}
+
+void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs,
+                                      const ValueSet &SinkCands) const {
   for (BasicBlock *BB : Blocks) {
     // If a used value is defined outside the region, it's an input.  If an
     // instruction is used outside the region, it's an output.
     for (Instruction &II : *BB) {
       for (User::op_iterator OI = II.op_begin(), OE = II.op_end(); OI != OE;
-           ++OI)
-        if (definedInCaller(Blocks, *OI))
-          Inputs.insert(*OI);
+           ++OI) {
+        Value *V = *OI;
+        if (!SinkCands.count(V) && definedInCaller(Blocks, V))
+          Inputs.insert(V);
+      }
 
       for (User *U : II.users())
         if (!definedInRegion(Blocks, U)) {
@@ -718,7 +780,7 @@
   if (!isEligible())
     return nullptr;
 
-  ValueSet inputs, outputs;
+  ValueSet inputs, outputs, SinkingCands;
 
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
@@ -757,8 +819,15 @@
                                                "newFuncRoot");
   newFuncRoot->getInstList().push_back(BranchInst::Create(header));
 
+  findAllocas(SinkingCands);
+
   // Find inputs to, outputs from the code region.
-  findInputsOutputs(inputs, outputs);
+  findInputsOutputs(inputs, outputs, SinkingCands);
+  // Sink the candidates found by findAllocas into newFuncRoot. Each one is
+  // moved to the front in turn, so later ones end up before earlier ones.
+  for (auto *II : SinkingCands)
+    cast<Instruction>(II)->moveBefore(*newFuncRoot,
+                                      newFuncRoot->getFirstInsertionPt());
 
   // Calculate the exit blocks for the extracted region and the total exit
   //  weights for each of those blocks.