[PartialInlining] Support shrink-wrapping of lifetime markers

Differential Revision: http://reviews.llvm.org/D33847

llvm-svn: 305170
diff --git a/llvm/lib/Transforms/Utils/CodeExtractor.cpp b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
index 24d28a6..5d57ed9 100644
--- a/llvm/lib/Transforms/Utils/CodeExtractor.cpp
+++ b/llvm/lib/Transforms/Utils/CodeExtractor.cpp
@@ -142,59 +142,237 @@
   return false;
 }
 
-void CodeExtractor::findAllocas(ValueSet &SinkCands) const {
+static BasicBlock *getCommonExitBlock(const SetVector<BasicBlock *> &Blocks) {
+  BasicBlock *CommonExitBlock = nullptr;
+  auto hasNonCommonExitSucc = [&](BasicBlock *Block) {
+    for (auto *Succ : successors(Block)) {
+      // Internal edges, ok.
+      if (Blocks.count(Succ))
+        continue;
+      if (!CommonExitBlock) {
+        CommonExitBlock = Succ;
+        continue;
+      }
+      if (CommonExitBlock == Succ)
+        continue;
+
+      return true;
+    }
+    return false;
+  };
+
+  if (any_of(Blocks, hasNonCommonExitSucc))
+    return nullptr;
+
+  return CommonExitBlock;
+}
+
+bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers(
+    Instruction *Addr) const {
+  AllocaInst *AI = cast<AllocaInst>(Addr->stripInBoundsConstantOffsets());
   Function *Func = (*Blocks.begin())->getParent();
   for (BasicBlock &BB : *Func) {
     if (Blocks.count(&BB))
       continue;
     for (Instruction &II : BB) {
+
+      if (isa<DbgInfoIntrinsic>(II))
+        continue;
+
+      unsigned Opcode = II.getOpcode();
+      Value *MemAddr = nullptr;
+      switch (Opcode) {
+      case Instruction::Store:
+      case Instruction::Load: {
+        if (Opcode == Instruction::Store) {
+          StoreInst *SI = cast<StoreInst>(&II);
+          MemAddr = SI->getPointerOperand();
+        } else {
+          LoadInst *LI = cast<LoadInst>(&II);
+          MemAddr = LI->getPointerOperand();
+        }
+        // Global variables cannot be aliased with locals.
+        if (dyn_cast<Constant>(MemAddr))
+          break;
+        Value *Base = MemAddr->stripInBoundsConstantOffsets();
+        if (!dyn_cast<AllocaInst>(Base) || Base == AI)
+          return false;
+        break;
+      }
+      default: {
+        IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(&II);
+        if (IntrInst) {
+          if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start ||
+              IntrInst->getIntrinsicID() == Intrinsic::lifetime_end)
+            break;
+          return false;
+        }
+        // Conservatively bail on any other instruction with side effects.
+        if (II.mayHaveSideEffects())
+          return false;
+      }
+      }
+    }
+  }
+
+  return true;
+}
+
+BasicBlock *
+CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) {
+  BasicBlock *SinglePredFromOutlineRegion = nullptr;
+  assert(!Blocks.count(CommonExitBlock) &&
+         "Expect a block outside the region!");
+  for (auto *Pred : predecessors(CommonExitBlock)) {
+    if (!Blocks.count(Pred))
+      continue;
+    if (!SinglePredFromOutlineRegion) {
+      SinglePredFromOutlineRegion = Pred;
+    } else if (SinglePredFromOutlineRegion != Pred) {
+      SinglePredFromOutlineRegion = nullptr;
+      break;
+    }
+  }
+
+  if (SinglePredFromOutlineRegion)
+    return SinglePredFromOutlineRegion;
+
+#ifndef NDEBUG
+  auto getFirstPHI = [](BasicBlock *BB) {
+    BasicBlock::iterator I = BB->begin();
+    PHINode *FirstPhi = nullptr;
+    while (I != BB->end()) {
+      PHINode *Phi = dyn_cast<PHINode>(I);
+      if (!Phi)
+        break;
+      if (!FirstPhi) {
+        FirstPhi = Phi;
+        break;
+      }
+    }
+    return FirstPhi;
+  };
+  // If there are any phi nodes, the single pred either exists or has already
+  // been created before code extraction.
+  assert(!getFirstPHI(CommonExitBlock) && "Phi not expected");
+#endif
+
+  BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock(
+      CommonExitBlock->getFirstNonPHI()->getIterator());
+
+  for (auto *Pred : predecessors(CommonExitBlock)) {
+    if (Blocks.count(Pred))
+      continue;
+    Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock);
+  }
+  // Now add the old exit block to the outline region.
+  Blocks.insert(CommonExitBlock);
+  return CommonExitBlock;
+}
+
+void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands,
+                                BasicBlock *&ExitBlock) const {
+  Function *Func = (*Blocks.begin())->getParent();
+  ExitBlock = getCommonExitBlock(Blocks);
+
+  for (BasicBlock &BB : *Func) {
+    if (Blocks.count(&BB))
+      continue;
+    for (Instruction &II : BB) {
       auto *AI = dyn_cast<AllocaInst>(&II);
       if (!AI)
         continue;
 
-      // Returns true if matching life time markers are found within
-      // the outlined region.
-      auto GetLifeTimeMarkers = [&](Instruction *Addr) {
+      // Find the pair of life time markers for address 'Addr' that are either
+      // defined inside the outline region or can legally be shrinkwrapped into
+      // the outline region. If there are no other untracked uses of the
+      // address, return the pair of markers if found; otherwise return a pair
+      // of nullptr.
+      auto GetLifeTimeMarkers =
+          [&](Instruction *Addr, bool &SinkLifeStart,
+              bool &HoistLifeEnd) -> std::pair<Instruction *, Instruction *> {
         Instruction *LifeStart = nullptr, *LifeEnd = nullptr;
-        for (User *U : Addr->users()) {
-          if (!definedInRegion(Blocks, U))
-            return false;
 
+        for (User *U : Addr->users()) {
           IntrinsicInst *IntrInst = dyn_cast<IntrinsicInst>(U);
           if (IntrInst) {
-            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start)
+            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_start) {
+              // Do not handle the case where AI has multiple start markers.
+              if (LifeStart)
+                return std::make_pair<Instruction *>(nullptr, nullptr);
               LifeStart = IntrInst;
-            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end)
+            }
+            if (IntrInst->getIntrinsicID() == Intrinsic::lifetime_end) {
+              if (LifeEnd)
+                return std::make_pair<Instruction *>(nullptr, nullptr);
               LifeEnd = IntrInst;
+            }
+            continue;
           }
+          // Found an untracked use of the address; bail.
+          if (!definedInRegion(Blocks, U))
+            return std::make_pair<Instruction *>(nullptr, nullptr);
         }
-        return LifeStart && LifeEnd;
+
+        if (!LifeStart || !LifeEnd)
+          return std::make_pair<Instruction *>(nullptr, nullptr);
+
+        SinkLifeStart = !definedInRegion(Blocks, LifeStart);
+        HoistLifeEnd = !definedInRegion(Blocks, LifeEnd);
+        // Do legality check.
+        if ((SinkLifeStart || HoistLifeEnd) &&
+            !isLegalToShrinkwrapLifetimeMarkers(Addr))
+          return std::make_pair<Instruction *>(nullptr, nullptr);
+
+        // Check to see if we have a place to do hoisting, if not, bail.
+        if (HoistLifeEnd && !ExitBlock)
+          return std::make_pair<Instruction *>(nullptr, nullptr);
+
+        return std::make_pair(LifeStart, LifeEnd);
       };
 
-      if (GetLifeTimeMarkers(AI)) {
+      bool SinkLifeStart = false, HoistLifeEnd = false;
+      auto Markers = GetLifeTimeMarkers(AI, SinkLifeStart, HoistLifeEnd);
+
+      if (Markers.first) {
+        if (SinkLifeStart)
+          SinkCands.insert(Markers.first);
         SinkCands.insert(AI);
+        if (HoistLifeEnd)
+          HoistCands.insert(Markers.second);
         continue;
       }
 
-      // Follow the bitcast:
+      // Follow the bitcast.
       Instruction *MarkerAddr = nullptr;
       for (User *U : AI->users()) {
-        if (U->stripPointerCasts() == AI) {
+
+        if (U->stripInBoundsConstantOffsets() == AI) {
+          SinkLifeStart = false;
+          HoistLifeEnd = false;
           Instruction *Bitcast = cast<Instruction>(U);
-          if (GetLifeTimeMarkers(Bitcast)) {
+          Markers = GetLifeTimeMarkers(Bitcast, SinkLifeStart, HoistLifeEnd);
+          if (Markers.first) {
             MarkerAddr = Bitcast;
             continue;
           }
         }
+
+        // Found unknown use of AI.
         if (!definedInRegion(Blocks, U)) {
           MarkerAddr = nullptr;
           break;
         }
       }
+
       if (MarkerAddr) {
+        if (SinkLifeStart)
+          SinkCands.insert(Markers.first);
         if (!definedInRegion(Blocks, MarkerAddr))
           SinkCands.insert(MarkerAddr);
         SinkCands.insert(AI);
+        if (HoistLifeEnd)
+          HoistCands.insert(Markers.second);
       }
     }
   }
@@ -780,7 +958,8 @@
   if (!isEligible())
     return nullptr;
 
-  ValueSet inputs, outputs, SinkingCands;
+  ValueSet inputs, outputs, SinkingCands, HoistingCands;
+  BasicBlock *CommonExit = nullptr;
 
   // Assumption: this is a single-entry code region, and the header is the first
   // block in the region.
@@ -819,7 +998,8 @@
                                                "newFuncRoot");
   newFuncRoot->getInstList().push_back(BranchInst::Create(header));
 
-  findAllocas(SinkingCands);
+  findAllocas(SinkingCands, HoistingCands, CommonExit);
+  assert(HoistingCands.empty() || CommonExit);
 
   // Find inputs to, outputs from the code region.
   findInputsOutputs(inputs, outputs, SinkingCands);
@@ -829,6 +1009,13 @@
     cast<Instruction>(II)->moveBefore(*newFuncRoot,
                                       newFuncRoot->getFirstInsertionPt());
 
+  if (!HoistingCands.empty()) {
+    auto *HoistToBlock = findOrCreateBlockForHoisting(CommonExit);
+    Instruction *TI = HoistToBlock->getTerminator();
+    for (auto *II : HoistingCands)
+      cast<Instruction>(II)->moveBefore(TI);
+  }
+
   // Calculate the exit blocks for the extracted region and the total exit
   //  weights for each of those blocks.
   DenseMap<BasicBlock *, BlockFrequency> ExitWeights;