Teach LSR to sink the immediate portion of the common expression back into its uses if it fits in the addressing modes of all the uses.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@65215 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index d18a008..2099cea 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -26,19 +26,19 @@
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/ScalarEvolutionExpander.h"
-#include "llvm/Support/CFG.h"
-#include "llvm/Support/GetElementPtrTypeIterator.h"
+#include "llvm/Transforms/Utils/AddrModeMatcher.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Target/TargetData.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/Support/CFG.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/GetElementPtrTypeIterator.h"
 #include "llvm/Target/TargetLowering.h"
 #include <algorithm>
-#include <set>
 using namespace llvm;
 
 STATISTIC(NumReduced ,    "Number of GEPs strength reduced");
@@ -46,6 +46,7 @@
 STATISTIC(NumVariable,    "Number of PHIs with variable strides");
 STATISTIC(NumEliminated,  "Number of strides eliminated");
 STATISTIC(NumShadow,      "Number of Shadow IVs optimized");
+STATISTIC(NumImmSunk,     "Number of common expr immediates sunk into uses");
 
 static cl::opt<bool> EnableFullLSRMode("enable-full-lsr",
                                        cl::init(false),
@@ -954,21 +955,17 @@
 /// that can fit into the immediate field of instructions in the target.
 /// Accumulate these immediate values into the Imm value.
 static void MoveImmediateValues(const TargetLowering *TLI,
-                                Instruction *User,
+                                const Type *UseTy,
                                 SCEVHandle &Val, SCEVHandle &Imm,
                                 bool isAddress, Loop *L,
                                 ScalarEvolution *SE) {
-  const Type *UseTy = User->getType();
-  if (StoreInst *SI = dyn_cast<StoreInst>(User))
-    UseTy = SI->getOperand(0)->getType();
-
   if (SCEVAddExpr *SAE = dyn_cast<SCEVAddExpr>(Val)) {
     std::vector<SCEVHandle> NewOps;
     NewOps.reserve(SAE->getNumOperands());
     
     for (unsigned i = 0; i != SAE->getNumOperands(); ++i) {
       SCEVHandle NewOp = SAE->getOperand(i);
-      MoveImmediateValues(TLI, User, NewOp, Imm, isAddress, L, SE);
+      MoveImmediateValues(TLI, UseTy, NewOp, Imm, isAddress, L, SE);
       
       if (!NewOp->isLoopInvariant(L)) {
         // If this is a loop-variant expression, it must stay in the immediate
@@ -987,7 +984,7 @@
   } else if (SCEVAddRecExpr *SARE = dyn_cast<SCEVAddRecExpr>(Val)) {
     // Try to pull immediates out of the start value of nested addrec's.
     SCEVHandle Start = SARE->getStart();
-    MoveImmediateValues(TLI, User, Start, Imm, isAddress, L, SE);
+    MoveImmediateValues(TLI, UseTy, Start, Imm, isAddress, L, SE);
     
     if (Start != SARE->getStart()) {
       std::vector<SCEVHandle> Ops(SARE->op_begin(), SARE->op_end());
@@ -1002,7 +999,7 @@
 
       SCEVHandle SubImm = SE->getIntegerSCEV(0, Val->getType());
       SCEVHandle NewOp = SME->getOperand(1);
-      MoveImmediateValues(TLI, User, NewOp, SubImm, isAddress, L, SE);
+      MoveImmediateValues(TLI, UseTy, NewOp, SubImm, isAddress, L, SE);
       
       // If we extracted something out of the subexpressions, see if we can 
       // simplify this!
@@ -1034,6 +1031,16 @@
   // Otherwise, no immediates to move.
 }
 
+static void MoveImmediateValues(const TargetLowering *TLI,
+                                Instruction *User,
+                                SCEVHandle &Val, SCEVHandle &Imm,
+                                bool isAddress, Loop *L,
+                                ScalarEvolution *SE) {
+  const Type *UseTy = User->getType();
+  if (StoreInst *SI = dyn_cast<StoreInst>(User))
+    UseTy = SI->getOperand(0)->getType();
+  MoveImmediateValues(TLI, UseTy, Val, Imm, isAddress, L, SE);
+}
 
 /// SeparateSubExprs - Decompose Expr into all of the subexpressions that are
 /// added together.  This is used to reassociate common addition subexprs
@@ -1450,6 +1457,9 @@
       UsersToProcess[i].Base = 
         SE->getIntegerSCEV(0, UsersToProcess[i].Base->getType());
     } else {
+      // Not all uses are outside the loop.
+      AllUsesAreOutsideLoop = false; 
+
       // Addressing modes can be folded into loads and stores.  Be careful that
       // the store is through the expression, not of the expression though.
       bool isPHI = false;
@@ -1460,9 +1470,6 @@
         ++NumPHI;
       }
 
-      // Not all uses are outside the loop.
-      AllUsesAreOutsideLoop = false; 
-
       if (isAddress)
         HasAddress = true;
      
@@ -1475,12 +1482,12 @@
     }
   }
 
-  // If one of the use if a PHI node and all other uses are addresses, still
+  // If one of the uses is a PHI node and all other uses are addresses, still
   // allow iv reuse. Essentially we are trading one constant multiplication
   // for one fewer iv.
   if (NumPHI > 1)
     AllUsesAreAddresses = false;
-
+    
   // There are no in-loop address uses.
   if (AllUsesAreAddresses && (!HasAddress && !AllUsesAreOutsideLoop))
     AllUsesAreAddresses = false;
@@ -1754,6 +1761,28 @@
                                   "commonbase", PreInsertPt);
 }
 
+static bool IsImmFoldedIntoAddrMode(GlobalValue *GV, int64_t Offset,
+                                    const Type *ReplacedTy,
+                                   std::vector<BasedUser> &UsersToProcess,
+                                   const TargetLowering *TLI) {
+  SmallVector<Instruction*, 16> AddrModeInsts;
+  for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i) {
+    if (UsersToProcess[i].isUseOfPostIncrementedValue)
+      continue;
+    ExtAddrMode AddrMode =
+      AddressingModeMatcher::Match(UsersToProcess[i].OperandValToReplace,
+                                   ReplacedTy, UsersToProcess[i].Inst,
+                                   AddrModeInsts, *TLI);
+    if (GV && GV != AddrMode.BaseGV)
+      return false;
+    if (Offset && !AddrMode.BaseOffs)
+      // FIXME: How to accurately check that its immediate offset is folded?
+      return false;
+    AddrModeInsts.clear();
+  }
+  return true;
+}
+
 /// StrengthReduceStridedIVUsers - Strength reduce all of the users of a single
 /// stride of IV.  All of the users may have different starting values, and this
 /// may not be the only stride (we know it is if isOnlyStride is true).
@@ -1797,6 +1826,41 @@
 
   const Type *ReplacedTy = CommonExprs->getType();
 
+  // If all uses are addresses, consider sinking the immediate part of the
+  // common expression back into uses if they can fit in the immediate fields.
+  if (HaveCommonExprs && AllUsesAreAddresses) {
+    SCEVHandle NewCommon = CommonExprs;
+    SCEVHandle Imm = SE->getIntegerSCEV(0, ReplacedTy);
+    MoveImmediateValues(TLI, ReplacedTy, NewCommon, Imm, true, L, SE);
+    if (!Imm->isZero()) {
+      bool DoSink = true;
+
+      // If the immediate part of the common expression is a GV, check if it's
+      // possible to fold it into the target addressing mode.
+      GlobalValue *GV = 0;
+      if (SCEVUnknown *SU = dyn_cast<SCEVUnknown>(Imm)) {
+        if (ConstantExpr *CE = dyn_cast<ConstantExpr>(SU->getValue()))
+          if (CE->getOpcode() == Instruction::PtrToInt)
+            GV = dyn_cast<GlobalValue>(CE->getOperand(0));
+      }
+      int64_t Offset = 0;
+      if (SCEVConstant *SC = dyn_cast<SCEVConstant>(Imm))
+        Offset = SC->getValue()->getSExtValue();
+      if (GV || Offset)
+        DoSink = IsImmFoldedIntoAddrMode(GV, Offset, ReplacedTy,
+                                         UsersToProcess, TLI);
+
+      if (DoSink) {
+        DOUT << "  Sinking " << *Imm << " back down into uses\n";
+        for (unsigned i = 0, e = UsersToProcess.size(); i != e; ++i)
+          UsersToProcess[i].Imm = SE->getAddExpr(UsersToProcess[i].Imm, Imm);
+        CommonExprs = NewCommon;
+        HaveCommonExprs = !CommonExprs->isZero();
+        ++NumImmSunk;
+      }
+    }
+  }
+
   // Now that we know what we need to do, insert the PHI node itself.
   //
   DOUT << "LSR: Examining IVs of TYPE " << *ReplacedTy << " of STRIDE "
@@ -2556,7 +2620,8 @@
     bool HasOneStride = IVUsesByStride.size() == 1;
 
 #ifndef NDEBUG
-    DOUT << "\nLSR on ";
+    DOUT << "\nLSR on \"" << L->getHeader()->getParent()->getNameStart()
+         << "\" ";
     DEBUG(L->dump());
 #endif