AMDGPU: Form more FMAs if fusion is allowed
Extend the existing fadd/fsub -> fmad combines to emit FMA when FMAD cannot be
used (e.g. denormals enabled), fusion is allowed, and FMA is fast for the type.
llvm-svn: 290311
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
index 1572897..52cc042 100644
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -3871,24 +3871,31 @@
return SDValue();
}
+unsigned SITargetLowering::getFusedOpcode(const SelectionDAG &DAG, EVT VT) const {
+ // Pick the fused multiply-add opcode for VT: FMAD when f32/f16 denormals are
+ // off (v_mad_f32 never supports denormals), else FMA if allowed, else 0.
+ if ((VT == MVT::f32 && !Subtarget->hasFP32Denormals()) ||
+ (VT == MVT::f16 && !Subtarget->hasFP16Denormals()))
+ return ISD::FMAD;
+
+ const TargetOptions &Options = DAG.getTarget().Options;
+ if ((Options.AllowFPOpFusion == FPOpFusion::Fast || Options.UnsafeFPMath) &&
+ isFMAFasterThanFMulAndFAdd(VT)) {
+ return ISD::FMA; // Fusion explicitly permitted and FMA reported as faster.
+ }
+
+ return 0; // No fused opcode applies; callers check for 0 and skip the combine.
+}
+
SDValue SITargetLowering::performFAddCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
if (DCI.getDAGCombineLevel() < AfterLegalizeDAG)
return SDValue();
- EVT VT = N->getValueType(0);
- if (VT == MVT::f64)
- return SDValue();
-
- assert(!VT.isVector());
-
- // Only do this if we are not trying to support denormals. v_mad_f32 does
- // not support denormals ever.
- if ((VT == MVT::f32 && Subtarget->hasFP32Denormals()) ||
- (VT == MVT::f16 && Subtarget->hasFP16Denormals()))
- return SDValue();
-
SelectionDAG &DAG = DCI.DAG;
+ EVT VT = N->getValueType(0);
+ assert(!VT.isVector());
+
SDLoc SL(N);
SDValue LHS = N->getOperand(0);
SDValue RHS = N->getOperand(1);
@@ -3900,8 +3907,11 @@
if (LHS.getOpcode() == ISD::FADD) {
SDValue A = LHS.getOperand(0);
if (A == LHS.getOperand(1)) {
- const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
- return DAG.getNode(ISD::FMAD, SL, VT, Two, A, RHS);
+ unsigned FusedOp = getFusedOpcode(DAG, VT);
+ if (FusedOp != 0) {
+ const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
+ return DAG.getNode(FusedOp, SL, VT, Two, A, RHS);
+ }
}
}
@@ -3909,8 +3919,11 @@
if (RHS.getOpcode() == ISD::FADD) {
SDValue A = RHS.getOperand(0);
if (A == RHS.getOperand(1)) {
- const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
- return DAG.getNode(ISD::FMAD, SL, VT, Two, A, LHS);
+ unsigned FusedOp = getFusedOpcode(DAG, VT);
+ if (FusedOp != 0) {
+ const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
+ return DAG.getNode(FusedOp, SL, VT, Two, A, LHS);
+ }
}
}
@@ -3932,29 +3945,31 @@
//
// Only do this if we are not trying to support denormals. v_mad_f32 does
// not support denormals ever.
- if ((VT == MVT::f32 && !Subtarget->hasFP32Denormals()) ||
- (VT == MVT::f16 && !Subtarget->hasFP16Denormals())) {
- SDValue LHS = N->getOperand(0);
- SDValue RHS = N->getOperand(1);
- if (LHS.getOpcode() == ISD::FADD) {
- // (fsub (fadd a, a), c) -> mad 2.0, a, (fneg c)
-
- SDValue A = LHS.getOperand(0);
- if (A == LHS.getOperand(1)) {
+ SDValue LHS = N->getOperand(0);
+ SDValue RHS = N->getOperand(1);
+ if (LHS.getOpcode() == ISD::FADD) {
+ // (fsub (fadd a, a), c) -> mad 2.0, a, (fneg c)
+ SDValue A = LHS.getOperand(0);
+ if (A == LHS.getOperand(1)) {
+ unsigned FusedOp = getFusedOpcode(DAG, VT);
+ if (FusedOp != 0){
const SDValue Two = DAG.getConstantFP(2.0, SL, VT);
SDValue NegRHS = DAG.getNode(ISD::FNEG, SL, VT, RHS);
- return DAG.getNode(ISD::FMAD, SL, VT, Two, A, NegRHS);
+ return DAG.getNode(FusedOp, SL, VT, Two, A, NegRHS);
}
}
+ }
- if (RHS.getOpcode() == ISD::FADD) {
- // (fsub c, (fadd a, a)) -> mad -2.0, a, c
+ if (RHS.getOpcode() == ISD::FADD) {
+ // (fsub c, (fadd a, a)) -> mad -2.0, a, c
- SDValue A = RHS.getOperand(0);
- if (A == RHS.getOperand(1)) {
+ SDValue A = RHS.getOperand(0);
+ if (A == RHS.getOperand(1)) {
+ unsigned FusedOp = getFusedOpcode(DAG, VT);
+ if (FusedOp != 0){
const SDValue NegTwo = DAG.getConstantFP(-2.0, SL, VT);
- return DAG.getNode(ISD::FMAD, SL, VT, NegTwo, A, LHS);
+ return DAG.getNode(FusedOp, SL, VT, NegTwo, A, LHS);
}
}
}