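Teach Reassociate to factor x+x+x -> x*3.  While I'm at it,
fix RemoveDeadBinaryOp to actually do something: its guard required the
value to be both a BinaryOperator and a CmpInst, so it always returned
early, and it never erased the dead instruction anyway.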



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@92368 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp
index fe63381..e0528f7 100644
--- a/lib/Transforms/Scalar/Reassociate.cpp
+++ b/lib/Transforms/Scalar/Reassociate.cpp
@@ -88,7 +88,7 @@
   private:
     void BuildRankMap(Function &F);
     unsigned getRank(Value *V);
-    void ReassociateExpression(BinaryOperator *I);
+    Value *ReassociateExpression(BinaryOperator *I); // Returns I's replacement
     void RewriteExprTree(BinaryOperator *I, SmallVectorImpl<ValueEntry> &Ops,
                          unsigned Idx = 0);
     Value *OptimizeExpression(BinaryOperator *I,
@@ -111,10 +111,13 @@
 
 void Reassociate::RemoveDeadBinaryOp(Value *V) {
   Instruction *Op = dyn_cast<Instruction>(V);
-  if (!Op || !isa<BinaryOperator>(Op) || !isa<CmpInst>(Op) || !Op->use_empty())
+  if (!Op || !isa<BinaryOperator>(Op) || !Op->use_empty())
     return;
   
   Value *LHS = Op->getOperand(0), *RHS = Op->getOperand(1);
+  // Erase Op first so its operands lose this use before we recurse.
+  ValueRankMap.erase(Op);
+  Op->eraseFromParent();
   RemoveDeadBinaryOp(LHS);
-  RemoveDeadBinaryOp(RHS);
+  if (RHS != LHS) RemoveDeadBinaryOp(RHS);  // X op X: X may already be erased.
 }
@@ -602,15 +605,57 @@
 /// is returned, otherwise the Ops list is mutated as necessary.
 Value *Reassociate::OptimizeAdd(Instruction *I,
                                 SmallVectorImpl<ValueEntry> &Ops) {
+  SmallPtrSet<Value*, 8> OperandsSeen;
+  
+Restart:
+  OperandsSeen.clear();
+  
   // Scan the operand lists looking for X and -X pairs.  If we find any, we
-  // can simplify the expression. X+-X == 0.
+  // can simplify the expression. X+-X == 0.  While we're at it, scan for any
+  // duplicates.  We want to canonicalize Y+Y+Y+Z -> 3*Y+Z.
   for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
-    assert(i < Ops.size());
+    Value *TheOp = Ops[i].Op;
+    // Check to see if we've seen this operand before.  If so, we factor all
+    // instances of the operand together.
+    if (!OperandsSeen.insert(TheOp)) {
+      // Rescan the list, removing all instances of this operand from the expr.
+      unsigned NumFound = 0;
+      for (unsigned j = 0, je = Ops.size(); j != je; ++j) {
+        if (Ops[j].Op != TheOp) continue;
+        ++NumFound;
+        Ops.erase(Ops.begin()+j);
+        --j; --je;  // Revisit slot j; unsigned wrap at 0 is undone by ++j.
+      }
+
+      DEBUG(errs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n');
+      ++NumFactor;
+
+      // Replace the NumFound copies of TheOp with a single multiply,
+      // TheOp * NumFound, inserted before I.
+      Value *Mul = ConstantInt::get(cast<IntegerType>(I->getType()), NumFound);
+      Mul = BinaryOperator::CreateMul(TheOp, Mul, "factor", I);
+      
+      // Now that we have inserted a multiply, optimize it. This allows us to
+      // handle cases that require multiple factoring steps, such as this:
+      // (X*2) + (X*2) + (X*2) -> (X*2)*3 -> X*6
+      Mul = ReassociateExpression(cast<BinaryOperator>(Mul));
+      
+      // If every add operand was a duplicate, return the multiply.
+      if (Ops.empty())
+        return Mul;
+      
+      // Otherwise, we had some input that didn't have the dupe, such as
+      // "A + A + B" -> "A*2 + B".  Add the new multiply to the list of
+      // things being added by this operation.
+      Ops.insert(Ops.begin(), ValueEntry(getRank(Mul), Mul));
+      goto Restart;
+    }
+    
     // Check for X and -X in the operand list.
-    if (!BinaryOperator::isNeg(Ops[i].Op))
+    if (!BinaryOperator::isNeg(TheOp))
       continue;
     
-    Value *X = BinaryOperator::getNegArgument(Ops[i].Op);
+    Value *X = BinaryOperator::getNegArgument(TheOp);
     unsigned FoundX = FindInOperandList(Ops, i, X);
     if (FoundX == i)
       continue;
@@ -639,7 +684,6 @@
   
-  // Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4)
-  // where they are actually the same multiply.
-  SmallPtrSet<BinaryOperator*, 4> Multiplies;
+  // Duplicate operands like (X*4)+(X*4) are factored by the duplicate scan
+  // earlier in this function, so a multiply should not appear twice here.
   unsigned MaxOcc = 0;
   Value *MaxOccVal = 0;
   for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
@@ -647,9 +691,6 @@
     if (BOp == 0 || BOp->getOpcode() != Instruction::Mul || !BOp->use_empty())
       continue;
     
-    // If we've already seen this multiply, don't revisit it.
-    if (!Multiplies.insert(BOp)) continue;
-    
     // Compute all of the factors of this added value.
     SmallVector<Value*, 8> Factors;
     FindSingleUseMultiplyFactors(BOp, Factors);
@@ -676,7 +717,7 @@
   
   // If any factor occurred more than one time, we can pull it out.
   if (MaxOcc > 1) {
-    DEBUG(errs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << "\n");
+    DEBUG(errs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n');
     ++NumFactor;
 
     // Create a new instruction that uses the MaxOccVal twice.  If we don't do
@@ -698,13 +739,17 @@
     
     unsigned NumAddedValues = NewMulOps.size();
     Value *V = EmitAddTreeOfValues(I, NewMulOps);
-    Value *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I);
     
-    // Now that we have inserted V and its sole use, optimize it. This allows
-    // us to handle cases that require multiple factoring steps, such as this:
+    // Now that we have inserted the add tree, optimize it. This allows us to
+    // handle cases that require multiple factoring steps, such as this:
     // A*A*B + A*A*C   -->   A*(A*B+A*C)   -->   A*(A*(B+C))
     assert(NumAddedValues > 1 && "Each occurrence should contribute a value");
-    ReassociateExpression(cast<BinaryOperator>(V));
+    V = ReassociateExpression(cast<BinaryOperator>(V));
+
+    // Multiply the reassociated sum V by the common factor MaxOccVal.
+    Value *V2 = BinaryOperator::CreateMul(V, MaxOccVal, "tmp", I);
+
+    // FIXME: Should ReassociateExpression be rerun on the multiply (V2) too?
     
     // If every add operand included the factor (e.g. "A*B + A*C"), then the
     // entire result expression is just the multiply "A*(B+C)".
@@ -852,9 +897,10 @@
   }
 }
 
-void Reassociate::ReassociateExpression(BinaryOperator *I) {
+Value *Reassociate::ReassociateExpression(BinaryOperator *I) {
   
-  // First, walk the expression tree, linearizing the tree, collecting
+  // First, walk the expression tree, linearizing the tree, collecting the
+  // operand information.
   SmallVector<ValueEntry, 8> Ops;
   LinearizeExprTree(I, Ops);
   
@@ -877,7 +923,7 @@
     I->replaceAllUsesWith(V);
     RemoveDeadBinaryOp(I);
     ++NumAnnihil;
-    return;
+    return V;  // All uses of I now refer to V.
   }
   
   // We want to sink immediates as deeply as possible except in the case where
@@ -899,11 +945,13 @@
     // eliminate it.
     I->replaceAllUsesWith(Ops[0].Op);
     RemoveDeadBinaryOp(I);
-  } else {
-    // Now that we ordered and optimized the expressions, splat them back into
-    // the expression tree, removing any unneeded nodes.
-    RewriteExprTree(I, Ops);
+    return Ops[0].Op;  // All uses of I now refer to this operand.
   }
+  
+  // Now that we ordered and optimized the expressions, splat them back into
+  // the expression tree, removing any unneeded nodes.
+  RewriteExprTree(I, Ops);
+  return I;  // The rewritten expression tree is still rooted at I.
 }