Refactor reciprocal and reciprocal square root estimate into target-independent functions (part 2).

This is purely refactoring. No functional changes intended. PowerPC is the only target
that is currently using this interface.

The ultimate goal is to allow targets other than PowerPC (certainly X86 and AArch64) to turn this:

z = y / sqrt(x)

into:

z = y * rsqrte(x)

And:

z = y / x

into:

z = y * rcpe(x)

using whatever HW magic they can use. See http://llvm.org/bugs/show_bug.cgi?id=20900 .

There is one hook in TargetLowering to get the target-specific opcode for an estimate instruction
along with the number of refinement steps needed to make the estimate usable.

Differential Revision: http://reviews.llvm.org/D5484

llvm-svn: 218553
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 33e7059..34a0e04 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -276,6 +276,7 @@
     SDValue visitFMA(SDNode *N);
     SDValue visitFDIV(SDNode *N);
     SDValue visitFREM(SDNode *N);
+    SDValue visitFSQRT(SDNode *N);
     SDValue visitFCOPYSIGN(SDNode *N);
     SDValue visitSINT_TO_FP(SDNode *N);
     SDValue visitUINT_TO_FP(SDNode *N);
@@ -326,7 +327,8 @@
     SDValue BuildSDIV(SDNode *N);
     SDValue BuildSDIVPow2(SDNode *N);
     SDValue BuildUDIV(SDNode *N);
-    SDValue BuildRSQRTE(SDNode *N);
+    SDValue BuildReciprocalEstimate(SDValue Op);
+    SDValue BuildRsqrtEstimate(SDValue Op);
     SDValue MatchBSwapHWordLow(SDNode *N, SDValue N0, SDValue N1,
                                bool DemandHighBits = true);
     SDValue MatchBSwapHWord(SDNode *N, SDValue N0, SDValue N1);
@@ -1307,6 +1309,7 @@
   case ISD::FMA:                return visitFMA(N);
   case ISD::FDIV:               return visitFDIV(N);
   case ISD::FREM:               return visitFREM(N);
+  case ISD::FSQRT:              return visitFSQRT(N);
   case ISD::FCOPYSIGN:          return visitFCOPYSIGN(N);
   case ISD::SINT_TO_FP:         return visitSINT_TO_FP(N);
   case ISD::UINT_TO_FP:         return visitUINT_TO_FP(N);
@@ -6976,6 +6979,7 @@
   ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
   ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
   EVT VT = N->getValueType(0);
+  SDLoc DL(N);
   const TargetOptions &Options = DAG.getTarget().Options;
 
   // fold vector ops
@@ -7007,10 +7011,37 @@
         return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0,
                            DAG.getConstantFP(Recip, VT));
     }
+    
     // If this FDIV is part of a reciprocal square root, it may be folded
     // into a target-specific square root estimate instruction.
-    if (SDValue SqrtOp = BuildRSQRTE(N))
-      return SqrtOp;
+    if (N1.getOpcode() == ISD::FSQRT) {
+      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0))) {
+        AddToWorklist(RV.getNode());
+        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
+      }
+    } else if (N1.getOpcode() == ISD::FP_EXTEND &&
+               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0))) {
+        AddToWorklist(RV.getNode());
+        RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
+        AddToWorklist(RV.getNode());
+        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
+      }
+    } else if (N1.getOpcode() == ISD::FP_ROUND &&
+               N1.getOperand(0).getOpcode() == ISD::FSQRT) {
+      if (SDValue RV = BuildRsqrtEstimate(N1.getOperand(0).getOperand(0))) {
+        AddToWorklist(RV.getNode());
+        RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
+        AddToWorklist(RV.getNode());
+        return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
+      }
+    }
+    
+    // Fold into a reciprocal estimate and multiply instead of a real divide.
+    if (SDValue RV = BuildReciprocalEstimate(N1)) {
+      AddToWorklist(RV.getNode());
+      return DAG.getNode(ISD::FMUL, DL, VT, N0, RV);
+    }
   }
 
   // (fdiv (fneg X), (fneg Y)) -> (fdiv X, Y)
@@ -7042,6 +7073,33 @@
   return SDValue();
 }
 
+SDValue DAGCombiner::visitFSQRT(SDNode *N) {
+  if (DAG.getTarget().Options.UnsafeFPMath) {
+    // Compute sqrt(X) as 1/(1/sqrt(X)) by chaining the two estimate builders.
+    if (SDValue RV = BuildRsqrtEstimate(N->getOperand(0))) {
+      AddToWorklist(RV.getNode());
+      RV = BuildReciprocalEstimate(RV);
+      if (RV.getNode()) {
+        // RV is NaN when the input was exactly 0, so compare the input
+        // against 0.0 and select 0.0 as the result in that case.
+        EVT VT = RV.getValueType();
+      
+        SDValue Zero = DAG.getConstantFP(0.0, VT);
+        SDValue ZeroCmp =
+          DAG.getSetCC(SDLoc(N), TLI.getSetCCResultType(*DAG.getContext(), VT),
+                       N->getOperand(0), Zero, ISD::SETEQ);
+        AddToWorklist(ZeroCmp.getNode());
+        AddToWorklist(RV.getNode());
+
+        RV = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT,
+                         SDLoc(N), VT, ZeroCmp, Zero, RV);
+        return RV;
+      }
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitFCOPYSIGN(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -11702,36 +11760,92 @@
   return S;
 }
 
-/// Given an ISD::FDIV node with either a direct or indirect ISD::FSQRT operand,
-/// generate a DAG expression using a reciprocal square root estimate op.
-SDValue DAGCombiner::BuildRSQRTE(SDNode *N) {
+SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op) {
+  if (Level >= AfterLegalizeDAG)
+    return SDValue();
+
   // Expose the DAG combiner to the target combiner implementations.
   TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
-  SDLoc DL(N);
-  EVT VT = N->getValueType(0);
-  SDValue N1 = N->getOperand(1);
 
-  if (N1.getOpcode() == ISD::FSQRT) {
-    if (SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0), DCI)) {
-      AddToWorklist(RV.getNode());
-      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV);
+  unsigned Iterations;
+  if (SDValue Est = TLI.getEstimate(ISD::FDIV, Op, DCI, Iterations)) {
+    // Refine the target's estimate of 1/A (A = Op) with Newton's method,
+    // X_{i+1} = X_i - F(X_i)/F'(X_i), applied to the function
+    //   F(X) = 1/X - A [which has a zero at X = 1/A]
+    //     =>
+    //   X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
+    //     does not require additional intermediate precision]
+    EVT VT = Op.getValueType();
+    SDLoc DL(Op);
+    SDValue FPOne = DAG.getConstantFP(1.0, VT);
+
+    AddToWorklist(Est.getNode());
+
+    // Newton iterations: Est = Est + Est (1 - Arg * Est)
+    for (unsigned i = 0; i < Iterations; ++i) {
+      SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est);
+      AddToWorklist(NewEst.getNode());
+
+      NewEst = DAG.getNode(ISD::FSUB, DL, VT, FPOne, NewEst);
+      AddToWorklist(NewEst.getNode());
+
+      NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst);
+      AddToWorklist(NewEst.getNode());
+
+      Est = DAG.getNode(ISD::FADD, DL, VT, Est, NewEst);
+      AddToWorklist(Est.getNode());
     }
-  } else if (N1.getOpcode() == ISD::FP_EXTEND &&
-             N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-    if (SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0).getOperand(0), DCI)) {
-      DCI.AddToWorklist(RV.getNode());
-      RV = DAG.getNode(ISD::FP_EXTEND, SDLoc(N1), VT, RV);
-      AddToWorklist(RV.getNode());
-      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV);
+
+    return Est;
+  }
+
+  return SDValue();
+}
+
+SDValue DAGCombiner::BuildRsqrtEstimate(SDValue Op) {
+  if (Level >= AfterLegalizeDAG)
+    return SDValue();
+
+  // Expose the DAG combiner to the target combiner implementations.
+  TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
+  unsigned Iterations;
+  if (SDValue Est = TLI.getEstimate(ISD::FSQRT, Op, DCI, Iterations)) {
+    // Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
+    // For the reciprocal sqrt, we need to find the zero of the function:
+    //   F(X) = 1/X^2 - A [which has a zero at X = 1/sqrt(A)]
+    //     =>
+    //   X_{i+1} = X_i (1.5 - A X_i^2 / 2)
+    // As a result, we precompute A/2 prior to the iteration loop.
+    EVT VT = Op.getValueType();
+    SDLoc DL(Op);
+    SDValue FPThreeHalves = DAG.getConstantFP(1.5, VT);
+
+    AddToWorklist(Est.getNode());
+
+    // We now need 0.5 * Arg which we can write as (1.5 * Arg - Arg) so that
+    // this entire sequence requires only one FP constant.
+    SDValue HalfArg = DAG.getNode(ISD::FMUL, DL, VT, FPThreeHalves, Op);
+    AddToWorklist(HalfArg.getNode());
+
+    HalfArg = DAG.getNode(ISD::FSUB, DL, VT, HalfArg, Op);
+    AddToWorklist(HalfArg.getNode());
+
+    // Newton iterations: Est = Est * (1.5 - HalfArg * Est * Est)
+    for (unsigned i = 0; i < Iterations; ++i) {
+      SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Est, Est);
+      AddToWorklist(NewEst.getNode());
+
+      NewEst = DAG.getNode(ISD::FMUL, DL, VT, HalfArg, NewEst);
+      AddToWorklist(NewEst.getNode());
+
+      NewEst = DAG.getNode(ISD::FSUB, DL, VT, FPThreeHalves, NewEst);
+      AddToWorklist(NewEst.getNode());
+
+      Est = DAG.getNode(ISD::FMUL, DL, VT, Est, NewEst);
+      AddToWorklist(Est.getNode());
     }
-  } else if (N1.getOpcode() == ISD::FP_ROUND &&
-             N1.getOperand(0).getOpcode() == ISD::FSQRT) {
-    if (SDValue RV = TLI.BuildRSQRTE(N1.getOperand(0).getOperand(0), DCI)) {
-      DCI.AddToWorklist(RV.getNode());
-      RV = DAG.getNode(ISD::FP_ROUND, SDLoc(N1), VT, RV, N1.getOperand(1));
-      AddToWorklist(RV.getNode());
-      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), RV);
-    }
+
+    return Est;
   }
 
   return SDValue();