SLPVectorizer: add support for vectorization of diamond-shaped trees. We now perform a preliminary traversal of the graph to collect values with multiple users and check where those users come from.

llvm-svn: 179414
diff --git a/llvm/lib/Transforms/Vectorize/VecUtils.cpp b/llvm/lib/Transforms/Vectorize/VecUtils.cpp
index 3efaaf6..d8be0ae 100644
--- a/llvm/lib/Transforms/Vectorize/VecUtils.cpp
+++ b/llvm/lib/Transforms/Vectorize/VecUtils.cpp
@@ -6,7 +6,7 @@
 // License. See LICENSE.TXT for details.
 //
 //===----------------------------------------------------------------------===//
-#define DEBUG_TYPE "VecUtils"
+#define DEBUG_TYPE "SLP"
 
 #include "VecUtils.h"
 #include "llvm/ADT/DenseMap.h"
@@ -37,6 +37,10 @@
 
 using namespace llvm;
 
+static const unsigned MinVecRegSize = 128; // In bits; VF = this / store size.
+
+static const unsigned RecursionMaxDepth = 6; // Depth cutoff for tree walks.
+
 namespace llvm {
 
 BoUpSLP::BoUpSLP(BasicBlock *Bb, ScalarEvolution *S, DataLayout *Dl,
@@ -98,9 +102,39 @@
   return ((-Offset) == Sz);
 }
 
+bool BoUpSLP::vectorizeStoreChain(ValueList &Chain, int CostThreshold) {
+  // Vectorize Chain in VF-wide windows; true if anything was vectorized.
+  if (Chain.empty()) return false;
+  Type *StoreTy = cast<StoreInst>(Chain[0])->getValueOperand()->getType();
+  unsigned Sz = DL->getTypeSizeInBits(StoreTy);
+  // Validate Sz before dividing by it so that Sz == 0 cannot divide by zero.
+  if (!isPowerOf2_32(Sz) || MinVecRegSize / Sz < 2) return false;
+  unsigned VF = MinVecRegSize / Sz;
+  bool Changed = false;
+  for (unsigned i = 0, e = Chain.size(); i < e; ++i) {
+    if (i + VF > e) return Changed;
+    DEBUG(dbgs()<<"SLP: Analyzing " << VF << " stores at offset "<< i << "\n");
+    ValueList Operands(&Chain[i], &Chain[i] + VF);
+    int Cost = getTreeCost(Operands);
+    DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n");
+    if (Cost < CostThreshold) {
+      DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n");
+      vectorizeTree(Operands, VF);
+      // Skip the stores just vectorized; the loop's ++i supplies the last +1.
+      i += VF - 1;
+      Changed = true;
+    }
+  }
+  return Changed;
+}
+
 bool BoUpSLP::vectorizeStores(StoreList &Stores, int costThreshold) {
   ValueSet Heads, Tails;
   SmallDenseMap<Value*, Value*> ConsecutiveChain;
+
+  // We may run into multiple chains that merge into a single chain. We mark the
+  // stores that we vectorized so that we don't visit the same store twice.
+  ValueSet VectorizedStores;
   bool Changed = false;
 
   // Do a quadratic search on all of the given stores and find
@@ -123,27 +157,17 @@
     // to vectorize it.
     ValueList Operands;
     Value *I = *it;
-    int MinCost = 0, MinVF = 0;
+    // Collect the chain into a list.
     while (Tails.count(I) || Heads.count(I)) {
+      if (VectorizedStores.count(I)) break;
       Operands.push_back(I);
-      unsigned VF = Operands.size();
-      if (isPowerOf2_32(VF) && VF > 1) {
-        int cost = getTreeRollCost(Operands, 0);
-        DEBUG(dbgs() << "Found cost=" << cost << " for VF=" << VF << "\n");
-        if (cost < MinCost) { MinCost = cost; MinVF = VF; }
-      }
       // Move to the next value in the chain.
       I = ConsecutiveChain[I];
     }
 
-    if (MinCost <= costThreshold && MinVF > 1) {
-      DEBUG(dbgs() << "Decided to vectorize cost=" << MinCost << "\n");
-      vectorizeTree(Operands, MinVF);
-      Stores.clear();
-      // The current numbering is invalid because we added and removed instrs.
-      numberInstructions();
-      Changed = true;
-    }
+    bool Vectorized = vectorizeStoreChain(Operands, costThreshold);
+    if (Vectorized) VectorizedStores.insert(Operands.begin(), Operands.end());
+    Changed |= Vectorized;
   }
 
   return Changed;
@@ -184,8 +208,138 @@
   return 0;
 }
 
-int BoUpSLP::getTreeRollCost(ValueList &VL, unsigned Depth) {
-  if (Depth == 6) return max_cost;
+int BoUpSLP::getTreeCost(ValueList &VL) {
+  // Returns the cost of vectorizing the tree rooted at VL. First drop the state
+  // left behind by any previously analyzed tree.
+  MemBarrierIgnoreList.clear();
+  LaneMap.clear();
+  MultiUserVals.clear();
+  MustScalarize.clear();
+
+  // Scan the tree and find which value is used by which lane, and which values
+  // must be scalarized.
+  getTreeUses_rec(VL, 0);
+
+  // Check that instructions with multiple users can be vectorized. Mark unsafe
+  // instructions.
+  for (ValueSet::iterator it = MultiUserVals.begin(),
+       e = MultiUserVals.end(); it != e; ++it) {
+    // Every user must be inside the tree, and all users must share one lane.
+    int Lane = -1;
+    for (Value::use_iterator I = (*it)->use_begin(), E = (*it)->use_end();
+         I != E; ++I) {
+      if (LaneMap.find(*I) == LaneMap.end()) {
+        MustScalarize.insert(*it);
+        DEBUG(dbgs()<<"SLP: Adding " << **it <<
+              " to MustScalarize because of an out of tree usage.\n");
+        break;
+      }
+      int UserLane = LaneMap[*I];
+      if (Lane == -1) Lane = UserLane;
+      if (Lane != UserLane) {
+        MustScalarize.insert(*it);
+        DEBUG(dbgs()<<"SLP: Adding " << **it <<
+              " to MustScalarize because multiple lanes use it: "
+              << Lane << " and " << UserLane << ".\n");
+        break;
+      }
+    }
+  }
+
+  // Now calculate the cost of vectorizing the tree.
+  return getTreeCost_rec(VL, 0);
+}
+
+void BoUpSLP::getTreeUses_rec(ValueList &VL, unsigned Depth) {
+  // Walk the tree rooted at VL, recording each value's lane in LaneMap and
+  // collecting instructions with more than one use in MultiUserVals.
+  if (Depth == RecursionMaxDepth) return;
+
+  // Don't handle vectors.
+  if (VL[0]->getType()->isVectorTy()) return;
+  if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
+    if (SI->getValueOperand()->getType()->isVectorTy()) return;
+
+  // Check if all of the operands are constants.
+  bool AllConst = true;
+  bool AllSameScalar = true;
+  for (unsigned i = 0, e = VL.size(); i < e; ++i) {
+    AllConst &= isa<Constant>(VL[i]);
+    AllSameScalar &= (VL[0] == VL[i]);
+    Instruction *I = dyn_cast<Instruction>(VL[i]);
+    // If one of the instructions is out of this BB, we need to scalarize all.
+    if (I && I->getParent() != BB) return;
+  }
+
+  // If all of the operands are identical or constant we have a simple solution.
+  if (AllConst || AllSameScalar) return;
+
+  // Scalarize unknown structures.
+  Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
+  if (!VL0) return;
+
+  unsigned Opcode = VL0->getOpcode();
+  for (unsigned i = 0, e = VL.size(); i < e; ++i) {
+    Instruction *I = dyn_cast<Instruction>(VL[i]);
+    // If not all of the instructions are identical then we have to scalarize.
+    if (!I || Opcode != I->getOpcode()) return;
+  }
+
+  // Mark instructions with multiple users.
+  for (unsigned i = 0, e = VL.size(); i < e; ++i) {
+    Instruction *I = dyn_cast<Instruction>(VL[i]);
+    // getTreeCost later checks that all of its users are inside our tree.
+    if (I && I->getNumUses() > 1) MultiUserVals.insert(I);
+  }
+
+  for (int i = 0, e = VL.size(); i < e; ++i) {
+    // Check that the instruction is only used within one lane.
+    if (LaneMap.count(VL[i]) && LaneMap[VL[i]] != i) return;
+    // Mark this instruction as 'seen' and remember the lane.
+    LaneMap[VL[i]] = i;
+  }
+
+  switch (Opcode) {
+    case Instruction::Add:
+    case Instruction::FAdd:
+    case Instruction::Sub:
+    case Instruction::FSub:
+    case Instruction::Mul:
+    case Instruction::FMul:
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::FDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+    case Instruction::FRem:
+    case Instruction::Shl:
+    case Instruction::LShr:
+    case Instruction::AShr:
+    case Instruction::And:
+    case Instruction::Or:
+    case Instruction::Xor: {
+      for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
+        ValueList Operands;
+        // Prepare the operand vector.
+        for (unsigned j = 0; j < VL.size(); ++j)
+          Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
+        getTreeUses_rec(Operands, Depth+1);
+      }
+      return; // Do not fall through into the Store case.
+    }
+    case Instruction::Store: {
+      ValueList Operands;
+      for (unsigned j = 0; j < VL.size(); ++j)
+        Operands.push_back(cast<Instruction>(VL[j])->getOperand(0));
+      getTreeUses_rec(Operands, Depth+1);
+      return;
+    }
+    default:
+    return;
+  }
+}
+
+int BoUpSLP::getTreeCost_rec(ValueList &VL, unsigned Depth) {
   Type *ScalarTy = VL[0]->getType();
 
   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
@@ -193,9 +347,10 @@
 
   /// Don't mess with vectors.
   if (ScalarTy->isVectorTy()) return max_cost;
-
   VectorType *VecTy = VectorType::get(ScalarTy, VL.size());
 
+  if (Depth == RecursionMaxDepth) return getScalarizationCost(VecTy);
+
   // Check if all of the operands are constants.
   bool AllConst = true;
   bool AllSameScalar = true;
@@ -204,8 +359,8 @@
     AllSameScalar &= (VL[0] == VL[i]);
     // Must have a single use.
     Instruction *I = dyn_cast<Instruction>(VL[i]);
-    // Need to scalarize instructions with multiple users or from other BBs.
-    if (I && ((I->getNumUses() > 1) || (I->getParent() != BB)))
+    // This instruction is outside the basic block or if it is a known hazard.
+    if (MustScalarize.count(VL[i]) || (I && I->getParent() != BB))
       return getScalarizationCost(VecTy);
   }
 
@@ -239,7 +394,7 @@
       if (VL[i] == Last) continue;
       Value *Barrier = isUnsafeToSink(cast<Instruction>(VL[i]), Last);
       if (Barrier) {
-        DEBUG(dbgs() << "LR: Can't sink " << *VL[i] << "\n down to " <<
+        DEBUG(dbgs() << "SLP: Can't sink " << *VL[i] << "\n down to " <<
               *Last << "\n because of " << *Barrier << "\n");
         return max_cost;
       }
@@ -265,20 +420,22 @@
   case Instruction::And:
   case Instruction::Or:
   case Instruction::Xor: {
-    ValueList Operands;
     int Cost = 0;
     // Calculate the cost of all of the operands.
     for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) {
+      ValueList Operands;
       // Prepare the operand vector.
       for (unsigned j = 0; j < VL.size(); ++j)
         Operands.push_back(cast<Instruction>(VL[j])->getOperand(i));
-      Cost += getTreeRollCost(Operands, Depth+1);
-      Operands.clear();
+
+      Cost += getTreeCost_rec(Operands, Depth+1);
+      if (Cost >= max_cost) return max_cost;
     }
 
     // Calculate the cost of this instruction.
     int ScalarCost = VecTy->getNumElements() *
       TTI->getArithmeticInstrCost(Opcode, ScalarTy);
+
     int VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy);
     Cost += (VecCost - ScalarCost);
     return Cost;
@@ -308,8 +465,7 @@
       MemBarrierIgnoreList.insert(VL[j]);
     }
 
-    int TotalCost =  StoreCost + getTreeRollCost(Operands, Depth + 1);
-    MemBarrierIgnoreList.clear();
+    int TotalCost = StoreCost + getTreeCost_rec(Operands, Depth + 1);
     return TotalCost;
   }
   default:
@@ -334,6 +490,15 @@
 }
 
 Value *BoUpSLP::vectorizeTree(ValueList &VL, int VF) {
+  // Vectorize the tree rooted at VL, then reset the state this invalidated.
+  Value *V = vectorizeTree_rec(VL, VF);
+  // Instructions were moved, so renumber them before any further analysis.
+  numberInstructions();
+  MustScalarize.clear();
+  return V;
+}
+
+Value *BoUpSLP::vectorizeTree_rec(ValueList &VL, int VF) {
   Type *ScalarTy = VL[0]->getType();
   if (StoreInst *SI = dyn_cast<StoreInst>(VL[0]))
     ScalarTy = SI->getValueOperand()->getType();
@@ -345,19 +510,21 @@
   for (unsigned i = 0, e = VF; i < e; ++i) {
     AllConst &= !!dyn_cast<Constant>(VL[i]);
     AllSameScalar &= (VL[0] == VL[i]);
-    // Must have a single use.
+    // The instruction must be in the same BB, and it must be vectorizable.
     Instruction *I = dyn_cast<Instruction>(VL[i]);
-    if (I && (I->getNumUses() > 1 || I->getParent() != BB))
+    if (MustScalarize.count(VL[i]) || (I && I->getParent() != BB))
       return Scalarize(VL, VecTy);
   }
 
-  // Is this a simple vector constant.
+  // Check that this is a simple vector constant.
   if (AllConst || AllSameScalar) return Scalarize(VL, VecTy);
 
   // Scalarize unknown structures.
   Instruction *VL0 = dyn_cast<Instruction>(VL[0]);
   if (!VL0) return Scalarize(VL, VecTy);
 
+  if (VectorizedValues.count(VL0)) return VectorizedValues[VL0];
+
   unsigned Opcode = VL0->getOpcode();
   for (unsigned i = 0, e = VF; i < e; ++i) {
     Instruction *I = dyn_cast<Instruction>(VL[i]);
@@ -390,11 +557,13 @@
       LHSVL.push_back(cast<Instruction>(VL[i])->getOperand(1));
     }
 
-    Value *RHS = vectorizeTree(RHSVL, VF);
-    Value *LHS = vectorizeTree(LHSVL, VF);
+    Value *RHS = vectorizeTree_rec(RHSVL, VF);
+    Value *LHS = vectorizeTree_rec(LHSVL, VF);
     IRBuilder<> Builder(GetLastInstr(VL, VF));
     BinaryOperator *BinOp = dyn_cast<BinaryOperator>(VL0);
-    return Builder.CreateBinOp(BinOp->getOpcode(), RHS,LHS);
+    Value *V = Builder.CreateBinOp(BinOp->getOpcode(), RHS,LHS);
+    VectorizedValues[VL0] = V;
+    return V;
   }
   case Instruction::Load: {
     LoadInst *LI = dyn_cast<LoadInst>(VL0);
@@ -410,6 +579,7 @@
                                           VecTy->getPointerTo());
     LI = Builder.CreateLoad(VecPtr);
     LI->setAlignment(Alignment);
+    VectorizedValues[VL0] = LI;
     return LI;
   }
   case Instruction::Store: {
@@ -420,7 +590,7 @@
     for (int i = 0; i < VF; ++i)
       ValueOp.push_back(cast<StoreInst>(VL[i])->getValueOperand());
 
-    Value *VecValue = vectorizeTree(ValueOp, VF);
+    Value *VecValue = vectorizeTree_rec(ValueOp, VF);
 
     IRBuilder<> Builder(GetLastInstr(VL, VF));
     Value *VecPtr = Builder.CreateBitCast(SI->getPointerOperand(),
@@ -432,7 +602,9 @@
     return 0;
   }
   default:
-    return Scalarize(VL, VecTy);
+    Value *S = Scalarize(VL, VecTy);
+    VectorizedValues[VL0] = S;
+    return S;
   }
 }