Fix fmul combines with constant splat vectors

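Folds such as (fmul x, 2.0) -> (fadd x, x) previously only matched scalar
floating-point constants; they now also apply when the constant operand is a
splat build vector.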

llvm-svn: 215820
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index e9a38b1..f6dc938 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -701,6 +701,27 @@
   return nullptr;
 }
 
+/// \brief Returns the ConstantFPSDNode if \p N is a floating-point constant or
+/// a BuildVector splat of one (with no undef elements); otherwise nullptr.
+static ConstantFPSDNode *isConstOrConstSplatFP(SDValue N) {
+  if (ConstantFPSDNode *CN = dyn_cast<ConstantFPSDNode>(N))
+    return CN;
+
+  if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(N)) {
+    BitVector UndefElements;
+    ConstantFPSDNode *CN = BV->getConstantFPSplatNode(&UndefElements);
+
+    // BuildVectors can truncate their operands, so only accept a splat whose
+    // constant type equals the vector's scalar type. FIXME: We blindly ignore
+    // splats which include undef, which is overly pessimistic.
+    if (CN && UndefElements.none() &&
+        CN->getValueType(0) == N.getValueType().getScalarType())
+      return CN;
+  }
+
+  return nullptr;
+}
+
 SDValue DAGCombiner::ReassociateOps(unsigned Opc, SDLoc DL,
                                     SDValue N0, SDValue N1) {
   EVT VT = N0.getValueType();
@@ -6830,8 +6851,8 @@
 SDValue DAGCombiner::visitFMUL(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
-  ConstantFPSDNode *N1CFP = dyn_cast<ConstantFPSDNode>(N1);
+  ConstantFPSDNode *N0CFP = isConstOrConstSplatFP(N0);
+  ConstantFPSDNode *N1CFP = isConstOrConstSplatFP(N1);
   EVT VT = N->getValueType(0);
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
 
@@ -6851,13 +6872,10 @@
   if (DAG.getTarget().Options.UnsafeFPMath &&
       N1CFP && N1CFP->getValueAPF().isZero())
     return N1;
-  // fold (fmul A, 0) -> 0, vector edition.
-  if (DAG.getTarget().Options.UnsafeFPMath &&
-      ISD::isBuildVectorAllZeros(N1.getNode()))
-    return N1;
   // fold (fmul A, 1.0) -> A
   if (N1CFP && N1CFP->isExactlyValue(1.0))
     return N0;
+
   // fold (fmul X, 2.0) -> (fadd X, X)
   if (N1CFP && N1CFP->isExactlyValue(+2.0))
     return DAG.getNode(ISD::FADD, SDLoc(N), VT, N0, N0);
@@ -6883,10 +6901,11 @@
   // If allowed, fold (fmul (fmul x, c1), c2) -> (fmul x, (fmul c1, c2))
   if (DAG.getTarget().Options.UnsafeFPMath &&
       N1CFP && N0.getOpcode() == ISD::FMUL &&
-      N0.getNode()->hasOneUse() && isa<ConstantFPSDNode>(N0.getOperand(1)))
+      N0.getNode()->hasOneUse() && isConstOrConstSplatFP(N0.getOperand(1))) {
     return DAG.getNode(ISD::FMUL, SDLoc(N), VT, N0.getOperand(0),
                        DAG.getNode(ISD::FMUL, SDLoc(N), VT,
                                    N0.getOperand(1), N1));
+  }
 
   return SDValue();
 }