X86: fold SSE2/AVX2 logical shift by immediate amount into zero vector when possible

Patch by Andrea Di Biagio


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@186165 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 6284dd7..95ca6c3 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -16321,6 +16321,38 @@
   return SDValue();
 }
 
+/// \brief Fold a vector logical shift by a constant splat amount into an
+/// all-zeros vector when that amount is greater than or equal to the size
+/// in bits of a vector element. Returns an empty SDValue otherwise.
+static SDValue performShiftToAllZeros(SDNode *N, SelectionDAG &DAG,
+                                      const X86Subtarget *Subtarget) {
+  EVT VT = N->getValueType(0);
+
+  // Handle the 128-bit types always; the 256-bit ones only with hasInt256().
+  bool Is128 = VT == MVT::v2i64 || VT == MVT::v4i32 || VT == MVT::v8i16;
+  bool Is256 = VT == MVT::v4i64 || VT == MVT::v8i32 || VT == MVT::v16i16;
+  if (!Is128 && !(Is256 && Subtarget->hasInt256()))
+    return SDValue();
+
+  // The shift amount must be a splat whose first operand is a constant.
+  SDValue Amt = N->getOperand(1);
+  if (!isSplatVector(Amt.getNode()))
+    return SDValue();
+  ConstantSDNode *C = dyn_cast<ConstantSDNode>(Amt->getOperand(0));
+  if (!C)
+    return SDValue();
+
+  // SSE2/AVX2 logical shifts always produce a vector of 0s when the shift
+  // amount is greater than or equal to the element size. The constant shift
+  // amount will be encoded as an 8-bit immediate, so compare its low 8 bits.
+  unsigned EltBits = VT.getVectorElementType().getSizeInBits();
+  if (!C->getAPIntValue().trunc(8).uge(EltBits))
+    return SDValue();
+
+  SDLoc DL(N);
+  return getZeroVector(VT, Subtarget, DAG, DL);
+}
+
 /// PerformShiftCombine - Combine shifts.
 static SDValue PerformShiftCombine(SDNode* N, SelectionDAG &DAG,
                                    TargetLowering::DAGCombinerInfo &DCI,
@@ -16330,6 +16362,12 @@
     if (V.getNode()) return V;
   }
 
+  if (N->getOpcode() != ISD::SRA) {
+    // Try to fold this logical shift into a zero vector.
+    SDValue V = performShiftToAllZeros(N, DAG, Subtarget);
+    if (V.getNode()) return V;
+  }
+
   return SDValue();
 }