Optimized integer vector multiplication by replacing it with shift/xor/sub where possible. Fixed a bug in SDIV that occurred when the constant operand is not a splat constant vector.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@184931 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index cb9778b..5af4aa0 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -326,7 +326,10 @@
     /// getShiftAmountTy - Returns a type large enough to hold any valid
     /// shift amount - before type legalization these can be huge.
     EVT getShiftAmountTy(EVT LHSTy) {
-      return LegalTypes ? TLI.getShiftAmountTy(LHSTy) : TLI.getPointerTy();
+      assert(LHSTy.isInteger() && "Shift amount is not an integer type!");
+      if (LHSTy.isVector())
+        return LHSTy;
+      return LegalTypes ? TLI.getScalarShiftAmountTy(LHSTy) : TLI.getPointerTy();
     }
 
     /// isTypeLegal - This method returns true if we are running before type
@@ -1762,43 +1765,73 @@
   return SDValue();
 }
 
+/// isConstantSplatVector - Returns true if N is a BUILD_VECTOR node that is a
+/// constant splat no wider than one element; the splat goes in SplatValue.
+static bool isConstantSplatVector(SDNode *N, APInt& SplatValue) {
+  BuildVectorSDNode *C = dyn_cast<BuildVectorSDNode>(N);
+  if (!C)
+    return false;
+
+  APInt SplatUndef;
+  unsigned SplatBitSize;
+  bool HasAnyUndefs;
+  EVT EltVT = N->getValueType(0).getVectorElementType();
+  return (C->isConstantSplat(SplatValue, SplatUndef, SplatBitSize,
+                             HasAnyUndefs) &&
+          EltVT.getSizeInBits() >= SplatBitSize); // splat fits in an element
+}
+
 SDValue DAGCombiner::visitMUL(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
-  ConstantSDNode *N0C = dyn_cast<ConstantSDNode>(N0);
-  ConstantSDNode *N1C = dyn_cast<ConstantSDNode>(N1);
   EVT VT = N0.getValueType();
 
-  // fold vector ops
-  if (VT.isVector()) {
-    SDValue FoldedVOp = SimplifyVBinOp(N);
-    if (FoldedVOp.getNode()) return FoldedVOp;
-  }
-
   // fold (mul x, undef) -> 0
   if (N0.getOpcode() == ISD::UNDEF || N1.getOpcode() == ISD::UNDEF)
     return DAG.getConstant(0, VT);
+
+  bool N0IsConst = false;
+  bool N1IsConst = false;
+  APInt ConstValue0, ConstValue1;
+  // fold vector ops
+  if (VT.isVector()) {
+    SDValue FoldedVOp = SimplifyVBinOp(N);
+    if (FoldedVOp.getNode()) return FoldedVOp;
+
+    N0IsConst = isConstantSplatVector(N0.getNode(), ConstValue0);
+    N1IsConst = isConstantSplatVector(N1.getNode(), ConstValue1);
+  } else {
+    N0IsConst = dyn_cast<ConstantSDNode>(N0) != 0;
+    ConstValue0 = N0IsConst? (dyn_cast<ConstantSDNode>(N0))->getAPIntValue() : APInt();
+    N1IsConst = dyn_cast<ConstantSDNode>(N1) != 0;
+    ConstValue1 = N1IsConst? (dyn_cast<ConstantSDNode>(N1))->getAPIntValue() : APInt();
+  }
+
   // fold (mul c1, c2) -> c1*c2
-  if (N0C && N1C)
-    return DAG.FoldConstantArithmetic(ISD::MUL, VT, N0C, N1C);
+  if (N0IsConst && N1IsConst)
+    return DAG.FoldConstantArithmetic(ISD::MUL, VT, N0.getNode(), N1.getNode());
+
   // canonicalize constant to RHS
-  if (N0C && !N1C)
+  if (N0IsConst && !N1IsConst)
     return DAG.getNode(ISD::MUL, SDLoc(N), VT, N1, N0);
   // fold (mul x, 0) -> 0
-  if (N1C && N1C->isNullValue())
+  if (N1IsConst && ConstValue1 == 0)
     return N1;
+  // fold (mul x, 1) -> x
+  if (N1IsConst && ConstValue1 == 1)
+    return N0;
   // fold (mul x, -1) -> 0-x
-  if (N1C && N1C->isAllOnesValue())
+  if (N1IsConst && ConstValue1.isAllOnesValue())
     return DAG.getNode(ISD::SUB, SDLoc(N), VT,
                        DAG.getConstant(0, VT), N0);
   // fold (mul x, (1 << c)) -> x << c
-  if (N1C && N1C->getAPIntValue().isPowerOf2())
+  if (N1IsConst && ConstValue1.isPowerOf2())
     return DAG.getNode(ISD::SHL, SDLoc(N), VT, N0,
-                       DAG.getConstant(N1C->getAPIntValue().logBase2(),
+                       DAG.getConstant(ConstValue1.logBase2(),
                                        getShiftAmountTy(N0.getValueType())));
   // fold (mul x, -(1 << c)) -> -(x << c) or (-x) << c
-  if (N1C && (-N1C->getAPIntValue()).isPowerOf2()) {
-    unsigned Log2Val = (-N1C->getAPIntValue()).logBase2();
+  if (N1IsConst && (-ConstValue1).isPowerOf2()) {
+    unsigned Log2Val = (-ConstValue1).logBase2();
     // FIXME: If the input is something that is easily negated (e.g. a
     // single-use add), we should put the negate there.
     return DAG.getNode(ISD::SUB, SDLoc(N), VT,
@@ -1807,9 +1840,12 @@
                             DAG.getConstant(Log2Val,
                                       getShiftAmountTy(N0.getValueType()))));
   }
+
+  APInt Val;
   // (mul (shl X, c1), c2) -> (mul X, c2 << c1)
-  if (N1C && N0.getOpcode() == ISD::SHL &&
-      isa<ConstantSDNode>(N0.getOperand(1))) {
+  if (N1IsConst && N0.getOpcode() == ISD::SHL && 
+      (isConstantSplatVector(N0.getOperand(1).getNode(), Val) ||
+                     isa<ConstantSDNode>(N0.getOperand(1)))) {
     SDValue C3 = DAG.getNode(ISD::SHL, SDLoc(N), VT,
                              N1, N0.getOperand(1));
     AddToWorkList(C3.getNode());
@@ -1822,7 +1858,9 @@
   {
     SDValue Sh(0,0), Y(0,0);
     // Check for both (mul (shl X, C), Y)  and  (mul Y, (shl X, C)).
-    if (N0.getOpcode() == ISD::SHL && isa<ConstantSDNode>(N0.getOperand(1)) &&
+    if (N0.getOpcode() == ISD::SHL && 
+        (isConstantSplatVector(N0.getOperand(1).getNode(), Val) ||
+                       isa<ConstantSDNode>(N0.getOperand(1))) &&
         N0.getNode()->hasOneUse()) {
       Sh = N0; Y = N1;
     } else if (N1.getOpcode() == ISD::SHL &&
@@ -1840,8 +1878,9 @@
   }
 
   // fold (mul (add x, c1), c2) -> (add (mul x, c2), c1*c2)
-  if (N1C && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() &&
-      isa<ConstantSDNode>(N0.getOperand(1)))
+  if (N1IsConst && N0.getOpcode() == ISD::ADD && N0.getNode()->hasOneUse() &&
+      (isConstantSplatVector(N0.getOperand(1).getNode(), Val) ||
+                     isa<ConstantSDNode>(N0.getOperand(1))))
     return DAG.getNode(ISD::ADD, SDLoc(N), VT,
                        DAG.getNode(ISD::MUL, SDLoc(N0), VT,
                                    N0.getOperand(0), N1),