[DAG Combiner] Fix the native computation of the Newton series for reciprocals
The generic infrastructure to compute the Newton series for reciprocal and
reciprocal square root was conceived to allow a target to compute the series
itself. However, the original code did not properly handle the case in which
a target does so and leaves no refinement steps for the generic code to
perform. This patch addresses these issues to allow a target to compute the
series on its own.
Differential revision: https://reviews.llvm.org/D22975
llvm-svn: 286523
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 6c0f435..63da116 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -14928,6 +14928,12 @@
return S;
}
+/// Newton iteration for a function F(X): X_{i+1} = X_i - F(X_i)/F'(X_i)
+/// For the reciprocal, we need to find the zero of the function:
+/// F(X) = A X - 1 [which has a zero at X = 1/A]
+/// =>
+/// X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
+/// does not require additional intermediate precision]
SDValue DAGCombiner::BuildReciprocalEstimate(SDValue Op, SDNodeFlags *Flags) {
if (Level >= AfterLegalizeDAG)
return SDValue();
@@ -14947,19 +14953,13 @@
// refinement steps.
int Iterations = TLI.getDivRefinementSteps(VT, MF);
if (SDValue Est = TLI.getRecipEstimate(Op, DAG, Enabled, Iterations)) {
+ AddToWorklist(Est.getNode());
+
if (Iterations) {
- // Newton iteration for a function: F(X) is X_{i+1} = X_i - F(X_i)/F'(X_i)
- // For the reciprocal, we need to find the zero of the function:
- // F(X) = A X - 1 [which has a zero at X = 1/A]
- // =>
- // X_{i+1} = X_i (2 - A X_i) = X_i + X_i (1 - A X_i) [this second form
- // does not require additional intermediate precision]
EVT VT = Op.getValueType();
SDLoc DL(Op);
SDValue FPOne = DAG.getConstantFP(1.0, DL, VT);
- AddToWorklist(Est.getNode());
-
// Newton iterations: Est = Est + Est (1 - Arg * Est)
for (int i = 0; i < Iterations; ++i) {
SDValue NewEst = DAG.getNode(ISD::FMUL, DL, VT, Op, Est, Flags);
@@ -15100,12 +15100,30 @@
bool UseOneConstNR = false;
if (SDValue Est =
- TLI.getRsqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR)) {
+ TLI.getSqrtEstimate(Op, DAG, Enabled, Iterations, UseOneConstNR,
+ Reciprocal)) {
AddToWorklist(Est.getNode());
+
if (Iterations) {
Est = UseOneConstNR
- ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
- : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+ ? buildSqrtNROneConst(Op, Est, Iterations, Flags, Reciprocal)
+ : buildSqrtNRTwoConst(Op, Est, Iterations, Flags, Reciprocal);
+
+ if (!Reciprocal) {
+ // Unfortunately, Est is now NaN if the input was exactly 0.0.
+ // Select out this case and force the answer to 0.0.
+ EVT VT = Op.getValueType();
+ SDLoc DL(Op);
+
+ SDValue FPZero = DAG.getConstantFP(0.0, DL, VT);
+ EVT CCVT = getSetCCResultType(VT);
+ SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, FPZero, ISD::SETEQ);
+ AddToWorklist(ZeroCmp.getNode());
+
+ Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT,
+ ZeroCmp, FPZero, Est);
+ AddToWorklist(Est.getNode());
+ }
}
return Est;
}
@@ -15118,23 +15136,7 @@
}
SDValue DAGCombiner::buildSqrtEstimate(SDValue Op, SDNodeFlags *Flags) {
- SDValue Est = buildSqrtEstimateImpl(Op, Flags, false);
- if (!Est)
- return SDValue();
-
- // Unfortunately, Est is now NaN if the input was exactly 0.
- // Select out this case and force the answer to 0.
- EVT VT = Est.getValueType();
- SDLoc DL(Op);
- SDValue Zero = DAG.getConstantFP(0.0, DL, VT);
- EVT CCVT = getSetCCResultType(VT);
- SDValue ZeroCmp = DAG.getSetCC(DL, CCVT, Op, Zero, ISD::SETEQ);
- AddToWorklist(ZeroCmp.getNode());
-
- Est = DAG.getNode(VT.isVector() ? ISD::VSELECT : ISD::SELECT, DL, VT, ZeroCmp,
- Zero, Est);
- AddToWorklist(Est.getNode());
- return Est;
+ return buildSqrtEstimateImpl(Op, Flags, false);
}
/// Return true if base is a frame index, which is known not to alias with
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 5bea916..a9d2e1f 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -4644,10 +4644,11 @@
return SDValue();
}
-SDValue AArch64TargetLowering::getRsqrtEstimate(SDValue Operand,
- SelectionDAG &DAG, int Enabled,
- int &ExtraSteps,
- bool &UseOneConst) const {
+SDValue AArch64TargetLowering::getSqrtEstimate(SDValue Operand,
+ SelectionDAG &DAG, int Enabled,
+ int &ExtraSteps,
+ bool &UseOneConst,
+ bool Reciprocal) const {
if (Enabled == ReciprocalEstimate::Enabled ||
(Enabled == ReciprocalEstimate::Unspecified && Subtarget->useRSqrt()))
if (SDValue Estimate = getEstimate(Subtarget, AArch64ISD::FRSQRTE, Operand,
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index c9ae7cd..7b317d6 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -534,8 +534,9 @@
SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
std::vector<SDNode *> *Created) const override;
- SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
- int &ExtraSteps, bool &UseOneConst) const override;
+ SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+ int &ExtraSteps, bool &UseOneConst,
+ bool Reciprocal) const override;
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
int &ExtraSteps) const override;
unsigned combineRepeatedFPDivisors() const override;
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
index 5a87148..a4a5de1 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp
@@ -2978,10 +2978,11 @@
return nullptr;
}
-SDValue AMDGPUTargetLowering::getRsqrtEstimate(SDValue Operand,
- SelectionDAG &DAG, int Enabled,
- int &RefinementSteps,
- bool &UseOneConstNR) const {
+SDValue AMDGPUTargetLowering::getSqrtEstimate(SDValue Operand,
+ SelectionDAG &DAG, int Enabled,
+ int &RefinementSteps,
+ bool &UseOneConstNR,
+ bool Reciprocal) const {
EVT VT = Operand.getValueType();
if (VT == MVT::f32) {
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
index 07d2db8..6c6fc2e 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h
@@ -172,9 +172,9 @@
bool isFsqrtCheap(SDValue Operand, SelectionDAG &DAG) const override {
return true;
}
- SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
- int &RefinementSteps,
- bool &UseOneConstNR) const override;
+ SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+ int &RefinementSteps, bool &UseOneConstNR,
+ bool Reciprocal) const override;
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
int &RefinementSteps) const override;
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
index d54c76e..22f7110 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.cpp
@@ -9637,9 +9637,10 @@
return RefinementSteps;
}
-SDValue PPCTargetLowering::getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG,
- int Enabled, int &RefinementSteps,
- bool &UseOneConstNR) const {
+SDValue PPCTargetLowering::getSqrtEstimate(SDValue Operand, SelectionDAG &DAG,
+ int Enabled, int &RefinementSteps,
+ bool &UseOneConstNR,
+ bool Reciprocal) const {
EVT VT = Operand.getValueType();
if ((VT == MVT::f32 && Subtarget.hasFRSQRTES()) ||
(VT == MVT::f64 && Subtarget.hasFRSQRTE()) ||
diff --git a/llvm/lib/Target/PowerPC/PPCISelLowering.h b/llvm/lib/Target/PowerPC/PPCISelLowering.h
index 2944e99..689a2e6 100644
--- a/llvm/lib/Target/PowerPC/PPCISelLowering.h
+++ b/llvm/lib/Target/PowerPC/PPCISelLowering.h
@@ -968,9 +968,9 @@
SDValue DAGCombineTruncBoolExt(SDNode *N, DAGCombinerInfo &DCI) const;
SDValue combineFPToIntToFP(SDNode *N, DAGCombinerInfo &DCI) const;
- SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
- int &RefinementSteps,
- bool &UseOneConstNR) const override;
+ SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+ int &RefinementSteps, bool &UseOneConstNR,
+ bool Reciprocal) const override;
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
int &RefinementSteps) const override;
unsigned combineRepeatedFPDivisors() const override;
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index fceaecf..5f4a29d 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -15296,10 +15296,11 @@
/// The minimum architected relative accuracy is 2^-12. We need one
/// Newton-Raphson step to have a good float result (24 bits of precision).
-SDValue X86TargetLowering::getRsqrtEstimate(SDValue Op,
- SelectionDAG &DAG, int Enabled,
- int &RefinementSteps,
- bool &UseOneConstNR) const {
+SDValue X86TargetLowering::getSqrtEstimate(SDValue Op,
+ SelectionDAG &DAG, int Enabled,
+ int &RefinementSteps,
+ bool &UseOneConstNR,
+ bool Reciprocal) const {
EVT VT = Op.getValueType();
// SSE1 has rsqrtss and rsqrtps. AVX adds a 256-bit variant for rsqrtps.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h
index dabef9d..b5903e8 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.h
+++ b/llvm/lib/Target/X86/X86ISelLowering.h
@@ -1268,9 +1268,9 @@
bool isFsqrtCheap(SDValue Operand, SelectionDAG &DAG) const override;
/// Use rsqrt* to speed up sqrt calculations.
- SDValue getRsqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
- int &RefinementSteps,
- bool &UseOneConstNR) const override;
+ SDValue getSqrtEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,
+ int &RefinementSteps, bool &UseOneConstNR,
+ bool Reciprocal) const override;
/// Use rcp* to speed up fdiv calculations.
SDValue getRecipEstimate(SDValue Operand, SelectionDAG &DAG, int Enabled,