Re-apply r124518 with fix. Watch out for invalidated iterator.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@124526 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/TailDuplication.cpp b/lib/CodeGen/TailDuplication.cpp
index ce4b1be..15aed34 100644
--- a/lib/CodeGen/TailDuplication.cpp
+++ b/lib/CodeGen/TailDuplication.cpp
@@ -465,9 +465,12 @@
     MaxDuplicateCount = TailDuplicateSize;
 
   if (PreRegAlloc) {
-      // Pre-regalloc tail duplication hurts compile time and doesn't help
-      // much except for indirect branches.
-    if (TailBB->empty() || !TailBB->back().getDesc().isIndirectBranch())
+    if (TailBB->empty())
+      return false;
+    const TargetInstrDesc &TID = TailBB->back().getDesc();
+    // Pre-regalloc tail duplication hurts compile time and doesn't help
+    // much except for indirect branches and returns.
+    if (!TID.isIndirectBranch() && !TID.isReturn())
       return false;
     // If the target has hardware branch prediction that can handle indirect
     // branches, duplicating them can often make them predictable when there
@@ -502,7 +505,7 @@
   }
   // Heuristically, don't tail-duplicate calls if it would expand code size,
   // as it's less likely to be worth the extra cost.
-  if (InstrCount > 1 && HasCall)
+  if (InstrCount > 1 && (PreRegAlloc && HasCall))
     return false;
 
   DEBUG(dbgs() << "\n*** Tail-duplicating BB#" << TailBB->getNumber() << '\n');
diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 23514fd..61f2308 100644
--- a/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -52,11 +52,13 @@
 
 #define DEBUG_TYPE "tailcallelim"
 #include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Constants.h"
 #include "llvm/DerivedTypes.h"
 #include "llvm/Function.h"
 #include "llvm/Instructions.h"
+#include "llvm/IntrinsicInst.h"
 #include "llvm/Pass.h"
 #include "llvm/Analysis/CaptureTracking.h"
 #include "llvm/Analysis/InlineCost.h"
@@ -64,7 +66,9 @@
 #include "llvm/Analysis/Loads.h"
 #include "llvm/Support/CallSite.h"
 #include "llvm/Support/CFG.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
 using namespace llvm;
 
 STATISTIC(NumEliminated, "Number of tail calls removed");
@@ -80,6 +84,18 @@
     virtual bool runOnFunction(Function &F);
 
   private:
+    CallInst *FindTRECandidate(Instruction *I,
+                               bool CannotTailCallElimCallsMarkedTail);
+    bool EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
+                                    BasicBlock *&OldEntry,
+                                    bool &TailCallsAreMarkedTail,
+                                    SmallVector<PHINode*, 8> &ArgumentPHIs,
+                                    bool CannotTailCallElimCallsMarkedTail);
+    bool FoldReturnAndProcessPred(BasicBlock *BB,
+                                  ReturnInst *Ret, BasicBlock *&OldEntry,
+                                  bool &TailCallsAreMarkedTail,
+                                  SmallVector<PHINode*, 8> &ArgumentPHIs,
+                                  bool CannotTailCallElimCallsMarkedTail);
     bool ProcessReturningBlock(ReturnInst *RI, BasicBlock *&OldEntry,
                                bool &TailCallsAreMarkedTail,
                                SmallVector<PHINode*, 8> &ArgumentPHIs,
@@ -136,7 +152,6 @@
   bool TailCallsAreMarkedTail = false;
   SmallVector<PHINode*, 8> ArgumentPHIs;
   bool MadeChange = false;
-
   bool FunctionContainsEscapingAllocas = false;
 
   // CannotTCETailMarkedCall - If true, we cannot perform TCE on tail calls
@@ -163,10 +178,17 @@
     return false;
 
   // Second pass, change any tail calls to loops.
-  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB)
-    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator()))
-      MadeChange |= ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
+  for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) {
+    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
+      bool Change = ProcessReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
                                           ArgumentPHIs,CannotTCETailMarkedCall);
+      if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
+        Change = FoldReturnAndProcessPred(BB, Ret, OldEntry,
+                                          TailCallsAreMarkedTail, ArgumentPHIs,
+                                          CannotTCETailMarkedCall);
+      MadeChange |= Change;
+    }
+  }
 
   // If we eliminated any tail recursions, it's possible that we inserted some
   // silly PHI nodes which just merge an initial value (the incoming operand)
@@ -325,41 +347,47 @@
   return getCommonReturnValue(cast<ReturnInst>(I->use_back()), CI);
 }
 
-bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
-                                         bool &TailCallsAreMarkedTail,
-                                         SmallVector<PHINode*, 8> &ArgumentPHIs,
-                                       bool CannotTailCallElimCallsMarkedTail) {
-  BasicBlock *BB = Ret->getParent();
+static Instruction *FirstNonDbg(BasicBlock::iterator I) {
+  while (isa<DbgInfoIntrinsic>(I))
+    ++I;
+  return &*I;
+}
+
+CallInst*
+TailCallElim::FindTRECandidate(Instruction *TI,
+                               bool CannotTailCallElimCallsMarkedTail) {
+  BasicBlock *BB = TI->getParent();
   Function *F = BB->getParent();
 
-  if (&BB->front() == Ret) // Make sure there is something before the ret...
-    return false;
+  if (&BB->front() == TI) // Make sure there is something before the terminator.
+    return 0;
   
   // Scan backwards from the return, checking to see if there is a tail call in
   // this block.  If so, set CI to it.
-  CallInst *CI;
-  BasicBlock::iterator BBI = Ret;
-  while (1) {
+  CallInst *CI = 0;
+  BasicBlock::iterator BBI = TI;
+  while (true) {
     CI = dyn_cast<CallInst>(BBI);
     if (CI && CI->getCalledFunction() == F)
       break;
 
     if (BBI == BB->begin())
-      return false;          // Didn't find a potential tail call.
+      return 0;          // Didn't find a potential tail call.
     --BBI;
   }
 
   // If this call is marked as a tail call, and if there are dynamic allocas in
   // the function, we cannot perform this optimization.
   if (CI->isTailCall() && CannotTailCallElimCallsMarkedTail)
-    return false;
+    return 0;
 
   // As a special case, detect code like this:
   //   double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call
   // and disable this xform in this case, because the code generator will
   // lower the call to fabs into inline code.
   if (BB == &F->getEntryBlock() && 
-      &BB->front() == CI && &*++BB->begin() == Ret &&
+      FirstNonDbg(BB->front()) == CI &&
+      FirstNonDbg(llvm::next(BB->begin())) == TI &&
       callIsSmall(F)) {
     // A single-block function with just a call and a return. Check that
     // the arguments match.
@@ -370,9 +398,17 @@
     for (; I != E && FI != FE; ++I, ++FI)
       if (*I != &*FI) break;
     if (I == E && FI == FE)
-      return false;
+      return 0;
   }
 
+  return CI;
+}
+
+bool TailCallElim::EliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
+                                       BasicBlock *&OldEntry,
+                                       bool &TailCallsAreMarkedTail,
+                                       SmallVector<PHINode*, 8> &ArgumentPHIs,
+                                       bool CannotTailCallElimCallsMarkedTail) {
   // If we are introducing accumulator recursion to eliminate operations after
   // the call instruction that are both associative and commutative, the initial
   // value for the accumulator is placed in this variable.  If this value is set
@@ -390,7 +426,8 @@
   // tail call if all of the instructions between the call and the return are
   // movable to above the call itself, leaving the call next to the return.
   // Check that this is the case now.
-  for (BBI = CI, ++BBI; &*BBI != Ret; ++BBI) {
+  BasicBlock::iterator BBI = CI;
+  for (++BBI; &*BBI != Ret; ++BBI) {
     if (CanMoveAboveCall(BBI, CI)) continue;
     
     // If we can't move the instruction above the call, it might be because it
@@ -427,6 +464,9 @@
       return false;
   }
 
+  BasicBlock *BB = Ret->getParent();
+  Function *F = BB->getParent();
+
   // OK! We can transform this tail call.  If this is the first one found,
   // create the new entry block, allowing us to branch back to the old entry.
   if (OldEntry == 0) {
@@ -536,3 +576,52 @@
   ++NumEliminated;
   return true;
 }
+
+bool TailCallElim::FoldReturnAndProcessPred(BasicBlock *BB,
+                                       ReturnInst *Ret, BasicBlock *&OldEntry,
+                                       bool &TailCallsAreMarkedTail,
+                                       SmallVector<PHINode*, 8> &ArgumentPHIs,
+                                       bool CannotTailCallElimCallsMarkedTail) {
+  bool Change = false;
+
+  // If the return block contains nothing but the return and PHIs,
+  // there might be an opportunity to duplicate the return in its
+  // predecessors and perform TRE there. Look for predecessors that end
+  // in an unconditional branch and recursive call(s).
+  SmallVector<BranchInst*, 8> UncondBranchPreds;
+  for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {
+    BasicBlock *Pred = *PI;
+    TerminatorInst *PTI = Pred->getTerminator();
+    if (BranchInst *BI = dyn_cast<BranchInst>(PTI))
+      if (BI->isUnconditional())
+        UncondBranchPreds.push_back(BI);
+  }
+
+  while (!UncondBranchPreds.empty()) {
+    BranchInst *BI = UncondBranchPreds.pop_back_val();
+    BasicBlock *Pred = BI->getParent();
+    if (CallInst *CI = FindTRECandidate(BI, CannotTailCallElimCallsMarkedTail)){
+      DEBUG(dbgs() << "FOLDING: " << *BB
+            << "INTO UNCOND BRANCH PRED: " << *Pred);
+      EliminateRecursiveTailCall(CI, FoldReturnIntoUncondBranch(Ret, BB, Pred),
+                                 OldEntry, TailCallsAreMarkedTail, ArgumentPHIs,
+                                 CannotTailCallElimCallsMarkedTail);
+      Change = true;
+    }
+  }
+
+  return Change;
+}
+
+bool TailCallElim::ProcessReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
+                                         bool &TailCallsAreMarkedTail,
+                                         SmallVector<PHINode*, 8> &ArgumentPHIs,
+                                       bool CannotTailCallElimCallsMarkedTail) {
+  CallInst *CI = FindTRECandidate(Ret, CannotTailCallElimCallsMarkedTail);
+  if (!CI)
+    return false;
+
+  return EliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail,
+                                    ArgumentPHIs,
+                                    CannotTailCallElimCallsMarkedTail);
+}
diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp
index 5d3af37..acaea19 100644
--- a/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -509,7 +509,32 @@
       // Go up one level.
       InStack.erase(VisitStack.pop_back_val().first);
     }
-  } while (!VisitStack.empty());
-  
-  
+  } while (!VisitStack.empty()); 
+}
+
+/// FoldReturnIntoUncondBranch - This method duplicates the specified return
+/// instruction into a predecessor which ends in an unconditional branch. If
+/// the return instruction returns a value defined by a PHI, propagate the
+/// right value into the return. It returns the new return instruction in the
+/// predecessor.
+ReturnInst *llvm::FoldReturnIntoUncondBranch(ReturnInst *RI, BasicBlock *BB,
+                                             BasicBlock *Pred) {
+  Instruction *UncondBranch = Pred->getTerminator();
+  // Clone the return and add it to the end of the predecessor.
+  Instruction *NewRet = RI->clone();
+  Pred->getInstList().push_back(NewRet);
+      
+  // If the return instruction returns a value, and if the value was a
+  // PHI node in "BB", propagate the right value into the return.
+  for (User::op_iterator i = NewRet->op_begin(), e = NewRet->op_end();
+       i != e; ++i)
+    if (PHINode *PN = dyn_cast<PHINode>(*i))
+      if (PN->getParent() == BB)
+        *i = PN->getIncomingValueForBlock(Pred);
+      
+  // Update any PHI nodes in the returning block to realize that we no
+  // longer branch to them.
+  BB->removePredecessor(Pred);
+  UncondBranch->eraseFromParent();
+  return cast<ReturnInst>(NewRet);
 }
diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp
index f6d7d76..b9432c2 100644
--- a/lib/Transforms/Utils/SimplifyCFG.cpp
+++ b/lib/Transforms/Utils/SimplifyCFG.cpp
@@ -28,6 +28,7 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/CFG.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ConstantRange.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
@@ -36,6 +37,10 @@
 #include <map>
 using namespace llvm;
 
+static cl::opt<bool>
+DupRet("simplifycfg-dup-ret", cl::Hidden, cl::init(false),
+       cl::desc("Duplicate return instructions into unconditional branches"));
+
 STATISTIC(NumSpeculations, "Number of speculative executed instructions");
 
 namespace {
@@ -2027,28 +2032,12 @@
   }
   
   // If we found some, do the transformation!
-  if (!UncondBranchPreds.empty()) {
+  if (!UncondBranchPreds.empty() && DupRet) {
     while (!UncondBranchPreds.empty()) {
       BasicBlock *Pred = UncondBranchPreds.pop_back_val();
       DEBUG(dbgs() << "FOLDING: " << *BB
             << "INTO UNCOND BRANCH PRED: " << *Pred);
-      Instruction *UncondBranch = Pred->getTerminator();
-      // Clone the return and add it to the end of the predecessor.
-      Instruction *NewRet = RI->clone();
-      Pred->getInstList().push_back(NewRet);
-      
-      // If the return instruction returns a value, and if the value was a
-      // PHI node in "BB", propagate the right value into the return.
-      for (User::op_iterator i = NewRet->op_begin(), e = NewRet->op_end();
-           i != e; ++i)
-        if (PHINode *PN = dyn_cast<PHINode>(*i))
-          if (PN->getParent() == BB)
-            *i = PN->getIncomingValueForBlock(Pred);
-      
-      // Update any PHI nodes in the returning block to realize that we no
-      // longer branch to them.
-      BB->removePredecessor(Pred);
-      UncondBranch->eraseFromParent();
+      (void)FoldReturnIntoUncondBranch(RI, BB, Pred);
     }
     
     // If we eliminated all predecessors of the block, delete the block now.