[TRE] Improve code motion in TRE: use AA to tell whether a load can be moved before a call that writes to memory.

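Summary: Use alias analysis (AA) to tell whether a load can be moved before a call that writes to memory, instead of conservatively refusing whenever the call may write to any memory at all.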

Reviewers: dberlin, davide, sanjoy, hfinkel

Reviewed By: hfinkel

Subscribers: hfinkel, llvm-commits

Differential Revision: https://reviews.llvm.org/D34115

llvm-svn: 305698
diff --git a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
index 365566c..9397b87c 100644
--- a/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
+++ b/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp
@@ -321,7 +321,7 @@
 /// instruction from after the call to before the call, assuming that all
 /// instructions between the call and this instruction are movable.
 ///
-static bool canMoveAboveCall(Instruction *I, CallInst *CI) {
+static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) {
   // FIXME: We can move load/store/call/free instructions above the call if the
   // call does not mod/ref the memory location being processed.
   if (I->mayHaveSideEffects())  // This also handles volatile loads.
@@ -332,10 +332,10 @@
     if (CI->mayHaveSideEffects()) {
       // Non-volatile loads may be moved above a call with side effects if it
       // does not write to memory and the load provably won't trap.
-      // FIXME: Writes to memory only matter if they may alias the pointer
+      // Writes to memory only matter if they may alias the pointer
       // being loaded from.
       const DataLayout &DL = L->getModule()->getDataLayout();
-      if (CI->mayWriteToMemory() ||
+      if ((AA->getModRefInfo(CI, MemoryLocation::get(L)) & MRI_Mod) ||
           !isSafeToLoadUnconditionally(L->getPointerOperand(),
                                        L->getAlignment(), DL, L))
         return false;
@@ -492,10 +492,11 @@
   return CI;
 }
 
-static bool
-eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry,
-                           bool &TailCallsAreMarkedTail,
-                           SmallVectorImpl<PHINode *> &ArgumentPHIs) {
+static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret,
+                                       BasicBlock *&OldEntry,
+                                       bool &TailCallsAreMarkedTail,
+                                       SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                       AliasAnalysis *AA) {
   // If we are introducing accumulator recursion to eliminate operations after
   // the call instruction that are both associative and commutative, the initial
   // value for the accumulator is placed in this variable.  If this value is set
@@ -515,7 +516,8 @@
   // Check that this is the case now.
   BasicBlock::iterator BBI(CI);
   for (++BBI; &*BBI != Ret; ++BBI) {
-    if (canMoveAboveCall(&*BBI, CI)) continue;
+    if (canMoveAboveCall(&*BBI, CI, AA))
+      continue;
 
     // If we can't move the instruction above the call, it might be because it
     // is an associative and commutative operation that could be transformed
@@ -674,7 +676,8 @@
                                      bool &TailCallsAreMarkedTail,
                                      SmallVectorImpl<PHINode *> &ArgumentPHIs,
                                      bool CannotTailCallElimCallsMarkedTail,
-                                     const TargetTransformInfo *TTI) {
+                                     const TargetTransformInfo *TTI,
+                                     AliasAnalysis *AA) {
   bool Change = false;
 
   // Make sure this block is a trivial return block.
@@ -710,7 +713,7 @@
         BB->eraseFromParent();
 
       eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail,
-                                 ArgumentPHIs);
+                                 ArgumentPHIs, AA);
       ++NumRetDuped;
       Change = true;
     }
@@ -723,16 +726,18 @@
                                   bool &TailCallsAreMarkedTail,
                                   SmallVectorImpl<PHINode *> &ArgumentPHIs,
                                   bool CannotTailCallElimCallsMarkedTail,
-                                  const TargetTransformInfo *TTI) {
+                                  const TargetTransformInfo *TTI,
+                                  AliasAnalysis *AA) {
   CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI);
   if (!CI)
     return false;
 
   return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail,
-                                    ArgumentPHIs);
+                                    ArgumentPHIs, AA);
 }
 
-static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI) {
+static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
+                                   AliasAnalysis *AA) {
   if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
     return false;
 
@@ -767,11 +772,11 @@
     if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
       bool Change =
           processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
-                                ArgumentPHIs, !CanTRETailMarkedCall, TTI);
+                                ArgumentPHIs, !CanTRETailMarkedCall, TTI, AA);
       if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
-        Change =
-            foldReturnAndProcessPred(BB, Ret, OldEntry, TailCallsAreMarkedTail,
-                                     ArgumentPHIs, !CanTRETailMarkedCall, TTI);
+        Change = foldReturnAndProcessPred(BB, Ret, OldEntry,
+                                          TailCallsAreMarkedTail, ArgumentPHIs,
+                                          !CanTRETailMarkedCall, TTI, AA);
       MadeChange |= Change;
     }
   }
@@ -801,6 +806,7 @@
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addRequired<AAResultsWrapperPass>();
     AU.addPreserved<GlobalsAAWrapperPass>();
   }
 
@@ -809,7 +815,8 @@
       return false;
 
     return eliminateTailRecursion(
-        F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F));
+        F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
+        &getAnalysis<AAResultsWrapperPass>().getAAResults());
   }
 };
 }
@@ -830,8 +837,9 @@
                                         FunctionAnalysisManager &AM) {
 
   TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  AliasAnalysis &AA = AM.getResult<AAManager>(F);
 
-  bool Changed = eliminateTailRecursion(F, &TTI);
+  bool Changed = eliminateTailRecursion(F, &TTI, &AA);
 
   if (!Changed)
     return PreservedAnalyses::all();
diff --git a/llvm/test/Transforms/TailCallElim/reorder_load.ll b/llvm/test/Transforms/TailCallElim/reorder_load.ll
index 2f9b692..78621b1 100644
--- a/llvm/test/Transforms/TailCallElim/reorder_load.ll
+++ b/llvm/test/Transforms/TailCallElim/reorder_load.ll
@@ -7,6 +7,7 @@
 ; then eliminate the tail recursion.
 
 
+
 @global = external global i32		; <i32*> [#uses=1]
 @extern_weak_global = extern_weak global i32		; <i32*> [#uses=1]
 
@@ -145,3 +146,29 @@
 	%tmp10 = add i32 %tmp9, %tmp8		; <i32> [#uses=1]
 	ret i32 %tmp10
 }
+
+; This load can be moved above the call because the function call does not write to the memory the load
+; is accessing and the load is safe to speculate.
+define fastcc i32 @raise_load_6(i32* %a_arg, i32 %a_len_arg, i32 %start_arg) nounwind  {
+; CHECK-LABEL: @raise_load_6(
+; CHECK-NOT: call
+; CHECK: load i32, i32*
+; CHECK-NOT: call
+; CHECK: }
+entry:
+  %s = alloca i32
+  store i32 4, i32* %s
+	%tmp2 = icmp sge i32 %start_arg, %a_len_arg		; <i1> [#uses=1]
+	br i1 %tmp2, label %if, label %else
+
+if:		; preds = %entry
+  store i32 1, i32* %a_arg
+	ret i32 0
+
+else:		; preds = %entry
+	%tmp7 = add i32 %start_arg, 1		; <i32> [#uses=1]
+	%tmp8 = call fastcc i32 @raise_load_6(i32* %a_arg, i32 %a_len_arg, i32 %tmp7)		; <i32> [#uses=1]
+	%tmp9 = load i32, i32* %s		; <i32> [#uses=1]
+	%tmp10 = add i32 %tmp9, %tmp8		; <i32> [#uses=1]
+	ret i32 %tmp10
+}