[DAGCombiner] Add support for TRUNCATE + FP_EXTEND vector constant folding

This patch adds support for vector constant folding of the TRUNCATE and FP_EXTEND instructions and tidies up the handling of the SINT_TO_FP and UINT_TO_FP instructions to match.

It also moves the vector constant folding of the FNEG and FABS instructions over to DAG.getNode(), like the other unary instructions, which makes the DAGCombiner::SimplifyVUnaryOp helper redundant, so it is removed.

Differential Revision: http://reviews.llvm.org/D8593

llvm-svn: 233224
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index dbdde4e..8b3d057 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -251,7 +251,6 @@
     SDValue visitORLike(SDValue N0, SDValue N1, SDNode *LocReference);
     SDValue visitXOR(SDNode *N);
     SDValue SimplifyVBinOp(SDNode *N);
-    SDValue SimplifyVUnaryOp(SDNode *N);
     SDValue visitSHL(SDNode *N);
     SDValue visitSRA(SDNode *N);
     SDValue visitSRL(SDNode *N);
@@ -716,6 +715,22 @@
   return nullptr;
 }
 
+static SDNode *isConstantIntBuildVectorOrConstantInt(SDValue N) {
+  if (isa<ConstantSDNode>(N))
+    return N.getNode();
+  if (ISD::isBuildVectorOfConstantSDNodes(N.getNode()))
+    return N.getNode();
+  return nullptr;
+}
+
+static SDNode *isConstantFPBuildVectorOrConstantFP(SDValue N) {
+  if (isa<ConstantFPSDNode>(N))
+    return N.getNode();
+  if (ISD::isBuildVectorOfConstantFPSDNodes(N.getNode()))
+    return N.getNode();
+  return nullptr;
+}
+
 // \brief Returns the SDNode if it is a constant splat BuildVector or constant
 // int.
 static ConstantSDNode *isConstOrConstSplat(SDValue N) {
@@ -6557,7 +6572,7 @@
   if (N0.getValueType() == N->getValueType(0))
     return N0;
   // fold (truncate c1) -> c1
-  if (isa<ConstantSDNode>(N0))
+  if (isConstantIntBuildVectorOrConstantInt(N0))
     return DAG.getNode(ISD::TRUNCATE, SDLoc(N), VT, N0);
   // fold (truncate (truncate x)) -> (truncate x)
   if (N0.getOpcode() == ISD::TRUNCATE)
@@ -7947,8 +7962,7 @@
   EVT OpVT = N0.getValueType();
 
   // fold (sint_to_fp c1) -> c1fp
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
-  if (N0C &&
+  if (isConstantIntBuildVectorOrConstantInt(N0) &&
       // ...but only if the target supports immediate floating-point values
       (!LegalOperations ||
        TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT)))
@@ -8000,8 +8014,7 @@
   EVT OpVT = N0.getValueType();
 
   // fold (uint_to_fp c1) -> c1fp
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
-  if (N0C &&
+  if (isConstantIntBuildVectorOrConstantInt(N0) &&
       // ...but only if the target supports immediate floating-point values
       (!LegalOperations ||
        TLI.isOperationLegalOrCustom(llvm::ISD::ConstantFP, VT)))
@@ -8159,7 +8172,6 @@
 
 SDValue DAGCombiner::visitFP_EXTEND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
-  ConstantFPSDNode *N0CFP = dyn_cast<ConstantFPSDNode>(N0);
   EVT VT = N->getValueType(0);
 
   // If this is fp_round(fpextend), don't fold it, allow ourselves to be folded.
@@ -8168,7 +8180,7 @@
     return SDValue();
 
   // fold (fp_extend c1fp) -> c1fp
-  if (N0CFP)
+  if (isConstantFPBuildVectorOrConstantFP(N0))
     return DAG.getNode(ISD::FP_EXTEND, SDLoc(N), VT, N0);
 
   // Turn fp_extend(fp_round(X, 1)) -> x since the fp_round doesn't affect the
@@ -8243,14 +8255,9 @@
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
 
-  if (VT.isVector()) {
-    SDValue FoldedVOp = SimplifyVUnaryOp(N);
-    if (FoldedVOp.getNode()) return FoldedVOp;
-  }
-
   // Constant fold FNEG.
-  if (isa<ConstantFPSDNode>(N0))
-    return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N->getOperand(0));
+  if (isConstantFPBuildVectorOrConstantFP(N0))
+    return DAG.getNode(ISD::FNEG, SDLoc(N), VT, N0);
 
   if (isNegatibleForFree(N0, LegalOperations, DAG.getTargetLoweringInfo(),
                          &DAG.getTarget().Options))
@@ -8345,13 +8352,8 @@
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
 
-  if (VT.isVector()) {
-    SDValue FoldedVOp = SimplifyVUnaryOp(N);
-    if (FoldedVOp.getNode()) return FoldedVOp;
-  }
-
   // fold (fabs c1) -> fabs(c1)
-  if (isa<ConstantFPSDNode>(N0))
+  if (isConstantFPBuildVectorOrConstantFP(N0))
     return DAG.getNode(ISD::FABS, SDLoc(N), VT, N0);
 
   // fold (fabs (fabs x)) -> (fabs x)
@@ -12401,38 +12403,6 @@
   return SDValue();
 }
 
-/// Visit a binary vector operation, like FABS/FNEG.
-SDValue DAGCombiner::SimplifyVUnaryOp(SDNode *N) {
-  assert(N->getValueType(0).isVector() &&
-         "SimplifyVUnaryOp only works on vectors!");
-
-  SDValue N0 = N->getOperand(0);
-
-  if (N0.getOpcode() != ISD::BUILD_VECTOR)
-    return SDValue();
-
-  // Operand is a BUILD_VECTOR node, see if we can constant fold it.
-  SmallVector<SDValue, 8> Ops;
-  for (unsigned i = 0, e = N0.getNumOperands(); i != e; ++i) {
-    SDValue Op = N0.getOperand(i);
-    if (Op.getOpcode() != ISD::UNDEF &&
-        Op.getOpcode() != ISD::ConstantFP)
-      break;
-    EVT EltVT = Op.getValueType();
-    SDValue FoldOp = DAG.getNode(N->getOpcode(), SDLoc(N0), EltVT, Op);
-    if (FoldOp.getOpcode() != ISD::UNDEF &&
-        FoldOp.getOpcode() != ISD::ConstantFP)
-      break;
-    Ops.push_back(FoldOp);
-    AddToWorklist(FoldOp.getNode());
-  }
-
-  if (Ops.size() != N0.getNumOperands())
-    return SDValue();
-
-  return DAG.getNode(ISD::BUILD_VECTOR, SDLoc(N), N0.getValueType(), Ops);
-}
-
 SDValue DAGCombiner::SimplifySelect(SDLoc DL, SDValue N0,
                                     SDValue N1, SDValue N2){
   assert(N0.getOpcode() ==ISD::SETCC && "First argument must be a SetCC node!");
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index eef07d9..b52f648 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -196,6 +196,22 @@
   return true;
 }
 
+/// \brief Return true if the specified node is a BUILD_VECTOR node of
+/// all ConstantFPSDNode or undef.
+bool ISD::isBuildVectorOfConstantFPSDNodes(const SDNode *N) {
+  if (N->getOpcode() != ISD::BUILD_VECTOR)
+    return false;
+
+  for (unsigned i = 0, e = N->getNumOperands(); i != e; ++i) {
+    SDValue Op = N->getOperand(i);
+    if (Op.getOpcode() == ISD::UNDEF)
+      continue;
+    if (!isa<ConstantFPSDNode>(Op))
+      return false;
+  }
+  return true;
+}
+
 /// isScalarToVector - Return true if the specified node is a
 /// ISD::SCALAR_TO_VECTOR node or a BUILD_VECTOR node where only the low
 /// element is not an undef.
@@ -2827,7 +2843,7 @@
     }
   }
 
-  // Constant fold unary operations with a vector integer operand.
+  // Constant fold unary operations with a vector integer or float operand.
   if (BuildVectorSDNode *BV = dyn_cast<BuildVectorSDNode>(Operand.getNode())) {
     if (BV->isConstant()) {
       switch (Opcode) {
@@ -2835,6 +2851,10 @@
         // FIXME: Entirely reasonable to perform folding of other unary
         // operations here as the need arises.
         break;
+      case ISD::FNEG:
+      case ISD::FABS:
+      case ISD::FP_EXTEND:
+      case ISD::TRUNCATE:
       case ISD::UINT_TO_FP:
       case ISD::SINT_TO_FP: {
         // Let the above scalar folding handle the folding of each element.
@@ -2842,9 +2862,14 @@
         for (int i = 0, e = VT.getVectorNumElements(); i != e; ++i) {
           SDValue OpN = BV->getOperand(i);
           OpN = getNode(Opcode, DL, VT.getVectorElementType(), OpN);
+          if (OpN.getOpcode() != ISD::UNDEF &&
+              OpN.getOpcode() != ISD::Constant &&
+              OpN.getOpcode() != ISD::ConstantFP)
+            break;
           Ops.push_back(OpN);
         }
-        return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
+        if (Ops.size() == VT.getVectorNumElements())
+          return getNode(ISD::BUILD_VECTOR, DL, VT, Ops);
       }
       }
     }