Loop Strength Reduce: Scaling factor cost.

Account for the cost of the scaling factor in Loop Strength Reduce when rating
formulae. The cost is obtained through a target hook.

The default implementation of the hook is: if the addressing mode is legal, the
scaling factor is free.

<rdar://problem/13806271>


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@183045 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index ecc96ae..b107fef 100644
--- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -779,6 +779,9 @@
 // Check if it is legal to fold 2 base registers.
 static bool isLegal2RegAMUse(const TargetTransformInfo &TTI, const LSRUse &LU,
                              const Formula &F);
+// Get the cost of the scaling factor used in F for LU (0 when F.Scale is 0).
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F);
 
 namespace {
 
@@ -792,11 +795,12 @@
   unsigned NumBaseAdds;
   unsigned ImmCost;
   unsigned SetupCost;
+  unsigned ScaleCost;
 
 public:
   Cost()
     : NumRegs(0), AddRecCost(0), NumIVMuls(0), NumBaseAdds(0), ImmCost(0),
-      SetupCost(0) {}
+      SetupCost(0), ScaleCost(0) {}
 
   bool operator<(const Cost &Other) const;
 
@@ -806,9 +810,9 @@
   // Once any of the metrics loses, they must all remain losers.
   bool isValid() {
     return ((NumRegs | AddRecCost | NumIVMuls | NumBaseAdds
-             | ImmCost | SetupCost) != ~0u)
+             | ImmCost | SetupCost | ScaleCost) != ~0u)
       || ((NumRegs & AddRecCost & NumIVMuls & NumBaseAdds
-           & ImmCost & SetupCost) == ~0u);
+           & ImmCost & SetupCost & ScaleCost) == ~0u);
   }
 #endif
 
@@ -947,6 +951,9 @@
     // allows to fold 2 registers.
     NumBaseAdds += NumBaseParts - (1 + isLegal2RegAMUse(TTI, LU, F));
 
+  // Accumulate non-free scaling amounts.
+  ScaleCost += getScalingFactorCost(TTI, LU, F);
+
   // Tally up the non-zero immediates.
   for (SmallVectorImpl<int64_t>::const_iterator I = Offsets.begin(),
        E = Offsets.end(); I != E; ++I) {
@@ -968,6 +975,7 @@
   NumBaseAdds = ~0u;
   ImmCost = ~0u;
   SetupCost = ~0u;
+  ScaleCost = ~0u;
 }
 
 /// operator< - Choose the lower cost.
@@ -980,6 +988,8 @@
     return NumIVMuls < Other.NumIVMuls;
   if (NumBaseAdds != Other.NumBaseAdds)
     return NumBaseAdds < Other.NumBaseAdds;
+  if (ScaleCost != Other.ScaleCost)
+    return ScaleCost < Other.ScaleCost;
   if (ImmCost != Other.ImmCost)
     return ImmCost < Other.ImmCost;
   if (SetupCost != Other.SetupCost)
@@ -996,6 +1006,8 @@
   if (NumBaseAdds != 0)
     OS << ", plus " << NumBaseAdds << " base add"
        << (NumBaseAdds == 1 ? "" : "s");
+  if (ScaleCost != 0)
+    OS << ", plus " << ScaleCost << " scale cost";
   if (ImmCost != 0)
     OS << ", plus " << ImmCost << " imm cost";
   if (SetupCost != 0)
@@ -1396,6 +1408,34 @@
                     F.BaseGV, F.BaseOffset, F.HasBaseReg, 1);
  }
 
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F) {
+  if (!F.Scale)
+    return 0; // F has no scaling factor, so there is nothing to charge.
+  assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
+                    LU.AccessTy, F) && "Illegal formula in use.");
+
+  switch (LU.Kind) {
+  case LSRUse::Address: {
+    int CurScaleCost = TTI.getScalingFactorCost(LU.AccessTy, F.BaseGV,
+                                                F.BaseOffset, F.HasBaseReg,
+                                                F.Scale);
+    assert(CurScaleCost >= 0 && "Legal addressing mode has an illegal cost!");
+    return CurScaleCost;
+  }
+  case LSRUse::ICmpZero:
+    // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg.
+    // Therefore a scale of -1 is free (0); any other non-zero scale costs 1.
+    return F.Scale != -1;
+
+  case LSRUse::Basic:
+  case LSRUse::Special:
+    return 0; // The scale is treated as free for these use kinds.
+  }
+
+  llvm_unreachable("Invalid LSRUse Kind!");
+}
+
 static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
                              LSRUse::KindType Kind, Type *AccessTy,
                              GlobalValue *BaseGV, int64_t BaseOffset,