[DAGCombiner] Call SimplifyDemandedVectorElts from EXTRACT_VECTOR_ELT

If we are only extracting vector elements via EXTRACT_VECTOR_ELT(s), we may be able to use SimplifyDemandedVectorElts to avoid unnecessary vector ops.

Differential Revision: https://reviews.llvm.org/D49262

llvm-svn: 337258
diff --git a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
index e69afd8..302c788 100644
--- a/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
+++ b/llvm/lib/Target/SystemZ/SystemZISelLowering.cpp
@@ -3893,20 +3893,34 @@
   return nullptr;
 }
 
-// Convert the mask of the given VECTOR_SHUFFLE into a byte-level mask,
+// Convert the mask of the given shuffle op into a byte-level mask,
 // as if it had type vNi8.
-static void getVPermMask(ShuffleVectorSDNode *VSN,
+static bool getVPermMask(SDValue ShuffleOp,
                          SmallVectorImpl<int> &Bytes) {
-  EVT VT = VSN->getValueType(0);
+  EVT VT = ShuffleOp.getValueType();
   unsigned NumElements = VT.getVectorNumElements();
   unsigned BytesPerElement = VT.getVectorElementType().getStoreSize();
-  Bytes.resize(NumElements * BytesPerElement, -1);
-  for (unsigned I = 0; I < NumElements; ++I) {
-    int Index = VSN->getMaskElt(I);
-    if (Index >= 0)
+
+  if (auto *VSN = dyn_cast<ShuffleVectorSDNode>(ShuffleOp)) {
+    Bytes.resize(NumElements * BytesPerElement, -1);
+    for (unsigned I = 0; I < NumElements; ++I) {
+      int Index = VSN->getMaskElt(I);
+      if (Index >= 0)
+        for (unsigned J = 0; J < BytesPerElement; ++J)
+          Bytes[I * BytesPerElement + J] = Index * BytesPerElement + J;
+    }
+    return true;
+  }
+  if (SystemZISD::SPLAT == ShuffleOp.getOpcode() &&
+      isa<ConstantSDNode>(ShuffleOp.getOperand(1))) {
+    unsigned Index = ShuffleOp.getConstantOperandVal(1);
+    Bytes.resize(NumElements * BytesPerElement, -1);
+    for (unsigned I = 0; I < NumElements; ++I)
       for (unsigned J = 0; J < BytesPerElement; ++J)
         Bytes[I * BytesPerElement + J] = Index * BytesPerElement + J;
+    return true;
   }
+  return false;
 }
 
 // Bytes is a VPERM-like permute vector, except that -1 is used for
@@ -4075,7 +4089,8 @@
       // See whether the bytes we need come from a contiguous part of one
       // operand.
       SmallVector<int, SystemZ::VectorBytes> OpBytes;
-      getVPermMask(cast<ShuffleVectorSDNode>(Op), OpBytes);
+      if (!getVPermMask(Op, OpBytes))
+        break;
       int NewByte;
       if (!getShuffleInput(OpBytes, Byte, BytesPerElement, NewByte))
         break;
@@ -5109,13 +5124,14 @@
     if (Opcode == ISD::BITCAST)
       // Look through bitcasts.
       Op = Op.getOperand(0);
-    else if (Opcode == ISD::VECTOR_SHUFFLE &&
+    else if ((Opcode == ISD::VECTOR_SHUFFLE || Opcode == SystemZISD::SPLAT) &&
              canTreatAsByteVector(Op.getValueType())) {
       // Get a VPERM-like permute mask and see whether the bytes covered
       // by the extracted element are a contiguous sequence from one
       // source operand.
       SmallVector<int, SystemZ::VectorBytes> Bytes;
-      getVPermMask(cast<ShuffleVectorSDNode>(Op), Bytes);
+      if (!getVPermMask(Op, Bytes))
+        break;
       int First;
       if (!getShuffleInput(Bytes, Index * BytesPerElement,
                            BytesPerElement, First))