SLPVectorizer: Add support for trees that don't start at binary operators, and add the cost of extracting values from the roots of the tree.

llvm-svn: 179475
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index 2f55a00..d94b2b2 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -85,14 +85,16 @@
     return true;
   }
 
-  bool tryToVectorizePair(BinaryOperator *A, BinaryOperator *B,  BoUpSLP &R) {
+  bool tryToVectorizePair(Value *A, Value *B,  BoUpSLP &R) {
     if (!A || !B) return false;
     BoUpSLP::ValueList VL;
     VL.push_back(A);
     VL.push_back(B);
     int Cost = R.getTreeCost(VL);
-    DEBUG(dbgs()<<"SLP: Cost of pair:" << Cost << ".\n");
-    if (Cost >= -SLPCostThreshold) return false;
+    int ExtrCost = R.getScalarizationCost(VL);
+    DEBUG(dbgs()<<"SLP: Cost of pair:" << Cost <<
+                  " Cost of extract:" << ExtrCost << ".\n");
+    if ((Cost+ExtrCost) >= -SLPCostThreshold) return false;
     DEBUG(dbgs()<<"SLP: Vectorizing pair.\n");
     R.vectorizeArith(VL);
     return true;
@@ -100,11 +102,12 @@
 
   bool tryToVectorizeCandidate(BinaryOperator *V,  BoUpSLP &R) {
     if (!V) return false;
+    // Try to vectorize V.
+    if (tryToVectorizePair(V->getOperand(0), V->getOperand(1), R))
+      return true;
+
     BinaryOperator *A = dyn_cast<BinaryOperator>(V->getOperand(0));
     BinaryOperator *B = dyn_cast<BinaryOperator>(V->getOperand(1));
-    // Try to vectorize V.
-    if (tryToVectorizePair(A, B, R)) return true;
-
     // Try to skip B.
     if (B && B->hasOneUse()) {
       BinaryOperator *B0 = dyn_cast<BinaryOperator>(B->getOperand(0));
diff --git a/llvm/lib/Transforms/Vectorize/VecUtils.cpp b/llvm/lib/Transforms/Vectorize/VecUtils.cpp
index 4d075c5..584f3d9 100644
--- a/llvm/lib/Transforms/Vectorize/VecUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VecUtils.cpp
@@ -173,6 +173,16 @@
   return Changed;
 }
 
+int BoUpSLP::getScalarizationCost(ValueList &VL) {
+  if (VL.empty()) return 0;
+  Type *ScalarTy = VL[0]->getType();
+  if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
+    ScalarTy = SI->getValueOperand()->getType();
+  // Vector-typed roots can't be widened further; VectorType::get asserts.
+  if (!VectorType::isValidElementType(ScalarTy)) return 0;
+  return getScalarizationCost(VectorType::get(ScalarTy, VL.size()));
+}
+
 int BoUpSLP::getScalarizationCost(Type *Ty) {
   int Cost = 0;
   for (unsigned i = 0, e = cast<VectorType>(Ty)->getNumElements(); i < e; ++i)
diff --git a/llvm/lib/Transforms/Vectorize/VecUtils.h b/llvm/lib/Transforms/Vectorize/VecUtils.h
index f865236..edebcb3 100644
--- a/llvm/lib/Transforms/Vectorize/VecUtils.h
+++ b/llvm/lib/Transforms/Vectorize/VecUtils.h
@@ -61,6 +61,11 @@
   /// A negative number means that this is profitable.
   int getTreeCost(ValueList &VL);
 
+  /// \returns the scalarization cost for this ValueList. Assuming that this
+  /// subtree gets vectorized, we may need to extract the values from the
+  /// roots. This method calculates the cost of extracting the values.
+  int getScalarizationCost(ValueList &VL);
+
   /// \brief Attempts to order and vectorize a sequence of stores. This
   /// function does a quadratic scan of the given stores.
   /// \returns true if the basic block was modified.
@@ -118,7 +123,7 @@
   /// by multiple lanes, or by users outside the tree.
   /// NOTICE: The vectorization methods also use this set.
   ValueSet MustScalarize;
-  
+
   // Contains a list of values that are used outside the current tree. This
   // set must be reset between runs.
   ValueSet MultiUserVals;