Perform factorization as a last resort of unsafe fadd/fsub simplification.

Rules include:
  1) x*y +/- x*z => x*(y +/- z)
    (the order of operands doesn't matter)

  2) y/x +/- z/x => (y +/- z)/x 

 The transformation is disabled if the new add/sub expr "y +/- z" is a
denormal/NaN/infinity.

rdar://12911472


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@177088 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index c6d60d6..3c5781c 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -150,7 +150,9 @@
     typedef SmallVector<const FAddend*, 4> AddendVect;
   
     Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
-  
+
+    Value *performFactorization(Instruction *I);
+
     /// Convert given addend to a Value
     Value *createAddendVal(const FAddend &A, bool& NeedNeg);
     
@@ -159,6 +161,7 @@
     Value *createFSub(Value *Opnd0, Value *Opnd1);
     Value *createFAdd(Value *Opnd0, Value *Opnd1);
     Value *createFMul(Value *Opnd0, Value *Opnd1);
+    Value *createFDiv(Value *Opnd0, Value *Opnd1);
     Value *createFNeg(Value *V);
     Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota);
     void createInstPostProc(Instruction *NewInst);
@@ -388,6 +391,78 @@
   return BreakNum;
 }
 
+// Try to perform the following optimization on the input instruction I. Return
+// the simplified expression if successful; otherwise, return 0.
+//
+//   Instruction "I" is                Simplified into
+// -------------------------------------------------------
+//   (x * y) +/- (x * z)               x * (y +/- z)
+//   (y / x) +/- (z / x)               (y +/- z) / x
+//
+Value *FAddCombine::performFactorization(Instruction *I) {
+  assert((I->getOpcode() == Instruction::FAdd ||
+          I->getOpcode() == Instruction::FSub) && "Expect add/sub");
+  
+  Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
+  Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
+  
+  if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
+    return 0;
+
+  bool isMpy = false;
+  if (I0->getOpcode() == Instruction::FMul)
+    isMpy = true;
+  else if (I0->getOpcode() != Instruction::FDiv)
+    return 0;
+
+  Value *Opnd0_0 = I0->getOperand(0);
+  Value *Opnd0_1 = I0->getOperand(1);
+  Value *Opnd1_0 = I1->getOperand(0);
+  Value *Opnd1_1 = I1->getOperand(1);
+
+  //  Input Instr I       Factor   AddSub0  AddSub1 
+  //  ----------------------------------------------
+  // (x*y) +/- (x*z)        x        y         z
+  // (y/x) +/- (z/x)        x        y         z
+  //
+  Value *Factor = 0;
+  Value *AddSub0 = 0, *AddSub1 = 0;
+  
+  if (isMpy) {
+    if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1)
+      Factor = Opnd0_0;
+    else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1)
+      Factor = Opnd0_1;
+
+    if (Factor) {
+      AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0;
+      AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0;
+    }
+  } else if (Opnd0_1 == Opnd1_1) {
+    Factor = Opnd0_1;
+    AddSub0 = Opnd0_0;
+    AddSub1 = Opnd1_0;
+  }
+
+  if (!Factor)
+    return 0;
+
+  // Create expression "NewAddSub = AddSub0 +/- AddSub1"
+  Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
+                      createFAdd(AddSub0, AddSub1) :
+                      createFSub(AddSub0, AddSub1);
+  if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
+    const APFloat &F = CFP->getValueAPF();
+    if (!F.isNormal() || F.isDenormal())
+      return 0;
+  }
+
+  if (isMpy)
+    return createFMul(Factor, NewAddSub);
+ 
+  return createFDiv(NewAddSub, Factor);
+}
+
 Value *FAddCombine::simplify(Instruction *I) {
   assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode");
 
@@ -471,7 +546,8 @@
       return R;
   }
 
-  return 0;
+  // step 6: Try factorization as the last resort.
+  return performFactorization(I);
 }
 
 Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
@@ -627,7 +703,8 @@
 Value *FAddCombine::createFSub
   (Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFSub(Opnd0, Opnd1);
-  createInstPostProc(cast<Instruction>(V));
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    createInstPostProc(I);
   return V;
 }
 
@@ -639,13 +716,22 @@
 Value *FAddCombine::createFAdd
   (Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFAdd(Opnd0, Opnd1);
-  createInstPostProc(cast<Instruction>(V));
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    createInstPostProc(I);
   return V;
 }
 
 Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) {
   Value *V = Builder->CreateFMul(Opnd0, Opnd1);
-  createInstPostProc(cast<Instruction>(V));
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    createInstPostProc(I);
+  return V;
+}
+
+Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) {
+  Value *V = Builder->CreateFDiv(Opnd0, Opnd1);
+  if (Instruction *I = dyn_cast<Instruction>(V))
+    createInstPostProc(I);
   return V;
 }