Extend VPBLENDVB and VPSIGN lowering to work for AVX2.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@144987 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 6a14f22..b45d3f6 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -13859,98 +13859,105 @@
     return R;
 
   EVT VT = N->getValueType(0);
-  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64 && VT != MVT::v2i64)
-    return SDValue();
 
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
 
   // look for psign/blend
-  if (Subtarget->hasSSSE3() || Subtarget->hasAVX()) {
-    if (VT == MVT::v2i64) {
-      // Canonicalize pandn to RHS
-      if (N0.getOpcode() == X86ISD::ANDNP)
-        std::swap(N0, N1);
-      // or (and (m, x), (pandn m, y))
-      if (N0.getOpcode() == ISD::AND && N1.getOpcode() == X86ISD::ANDNP) {
-        SDValue Mask = N1.getOperand(0);
-        SDValue X    = N1.getOperand(1);
-        SDValue Y;
-        if (N0.getOperand(0) == Mask)
-          Y = N0.getOperand(1);
-        if (N0.getOperand(1) == Mask)
-          Y = N0.getOperand(0);
+  if (VT == MVT::v2i64 || VT == MVT::v4i64) {
+    if (!(Subtarget->hasSSSE3() || Subtarget->hasAVX()) ||
+        (VT == MVT::v4i64 && !Subtarget->hasAVX2()))
+      return SDValue();
 
-        // Check to see if the mask appeared in both the AND and ANDNP and
-        if (!Y.getNode())
-          return SDValue();
+    // Canonicalize pandn to RHS
+    if (N0.getOpcode() == X86ISD::ANDNP)
+      std::swap(N0, N1);
+    // or (and (m, y), (pandn m, x))
+    if (N0.getOpcode() == ISD::AND && N1.getOpcode() == X86ISD::ANDNP) {
+      SDValue Mask = N1.getOperand(0);
+      SDValue X    = N1.getOperand(1);
+      SDValue Y;
+      if (N0.getOperand(0) == Mask)
+        Y = N0.getOperand(1);
+      if (N0.getOperand(1) == Mask)
+        Y = N0.getOperand(0);
 
-        // Validate that X, Y, and Mask are BIT_CONVERTS, and see through them.
-        if (Mask.getOpcode() != ISD::BITCAST ||
-            X.getOpcode() != ISD::BITCAST ||
-            Y.getOpcode() != ISD::BITCAST)
-          return SDValue();
+      // Check to see if the mask appeared in both the AND and ANDNP.
+      if (!Y.getNode())
+        return SDValue();
 
-        // Look through mask bitcast.
-        Mask = Mask.getOperand(0);
-        EVT MaskVT = Mask.getValueType();
+      // Validate that X, Y, and Mask are BITCASTs, and see through them.
+      if (Mask.getOpcode() != ISD::BITCAST ||
+          X.getOpcode() != ISD::BITCAST ||
+          Y.getOpcode() != ISD::BITCAST)
+        return SDValue();
 
-        // Validate that the Mask operand is a vector sra node.  The sra node
-        // will be an intrinsic.
-        if (Mask.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
-          return SDValue();
+      // Look through mask bitcast.
+      Mask = Mask.getOperand(0);
+      EVT MaskVT = Mask.getValueType();
 
-        // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
-        // there is no psrai.b
-        switch (cast<ConstantSDNode>(Mask.getOperand(0))->getZExtValue()) {
-        case Intrinsic::x86_sse2_psrai_w:
-        case Intrinsic::x86_sse2_psrai_d:
-          break;
-        default: return SDValue();
-        }
+      // Validate that the Mask operand is a vector sra node.  The sra node
+      // will be an intrinsic.
+      if (Mask.getOpcode() != ISD::INTRINSIC_WO_CHAIN)
+        return SDValue();
 
-        // Check that the SRA is all signbits.
-        SDValue SraC = Mask.getOperand(2);
-        unsigned SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();
-        unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
-        if ((SraAmt + 1) != EltBits)
-          return SDValue();
-
-        DebugLoc DL = N->getDebugLoc();
-
-        // Now we know we at least have a plendvb with the mask val.  See if
-        // we can form a psignb/w/d.
-        // psign = x.type == y.type == mask.type && y = sub(0, x);
-        X = X.getOperand(0);
-        Y = Y.getOperand(0);
-        if (Y.getOpcode() == ISD::SUB && Y.getOperand(1) == X &&
-            ISD::isBuildVectorAllZeros(Y.getOperand(0).getNode()) &&
-            X.getValueType() == MaskVT && X.getValueType() == Y.getValueType()){
-          unsigned Opc = 0;
-          switch (EltBits) {
-          case 8: Opc = X86ISD::PSIGNB; break;
-          case 16: Opc = X86ISD::PSIGNW; break;
-          case 32: Opc = X86ISD::PSIGND; break;
-          default: break;
-          }
-          if (Opc) {
-            SDValue Sign = DAG.getNode(Opc, DL, MaskVT, X, Mask.getOperand(1));
-            return DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, Sign);
-          }
-        }
-        // PBLENDVB only available on SSE 4.1
-        if (!(Subtarget->hasSSE41() || Subtarget->hasAVX()))
-          return SDValue();
-
-        X = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, X);
-        Y = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Y);
-        Mask = DAG.getNode(ISD::BITCAST, DL, MVT::v16i8, Mask);
-        Mask = DAG.getNode(ISD::VSELECT, DL, MVT::v16i8, Mask, X, Y);
-        return DAG.getNode(ISD::BITCAST, DL, MVT::v2i64, Mask);
+      // FIXME: what to do for bytes, since there is a psignb/pblendvb, but
+      // there is no psrai.b
+      switch (cast<ConstantSDNode>(Mask.getOperand(0))->getZExtValue()) {
+      case Intrinsic::x86_sse2_psrai_w:
+      case Intrinsic::x86_sse2_psrai_d:
+      case Intrinsic::x86_avx2_psrai_w:
+      case Intrinsic::x86_avx2_psrai_d:
+        break;
+      default: return SDValue();
       }
+
+      // Check that the SRA is all signbits.
+      SDValue SraC = Mask.getOperand(2);
+      unsigned SraAmt  = cast<ConstantSDNode>(SraC)->getZExtValue();
+      unsigned EltBits = MaskVT.getVectorElementType().getSizeInBits();
+      if ((SraAmt + 1) != EltBits)
+        return SDValue();
+
+      DebugLoc DL = N->getDebugLoc();
+
+      // Now we know we at least have a pblendvb with the mask val.  See if
+      // we can form a psignb/w/d.
+      // psign = x.type == y.type == mask.type && y = sub(0, x);
+      X = X.getOperand(0);
+      Y = Y.getOperand(0);
+      if (Y.getOpcode() == ISD::SUB && Y.getOperand(1) == X &&
+          ISD::isBuildVectorAllZeros(Y.getOperand(0).getNode()) &&
+          X.getValueType() == MaskVT && X.getValueType() == Y.getValueType()){
+        unsigned Opc = 0;
+        switch (EltBits) {
+        case 8: Opc = X86ISD::PSIGNB; break;
+        case 16: Opc = X86ISD::PSIGNW; break;
+        case 32: Opc = X86ISD::PSIGND; break;
+        default: break;
+        }
+        if (Opc) {
+          SDValue Sign = DAG.getNode(Opc, DL, MaskVT, X, Mask.getOperand(1));
+          return DAG.getNode(ISD::BITCAST, DL, VT, Sign);
+        }
+      }
+      // PBLENDVB only available on SSE 4.1
+      if (!(Subtarget->hasSSE41() || Subtarget->hasAVX()))
+        return SDValue();
+
+      EVT BlendVT = (VT == MVT::v4i64) ? MVT::v32i8 : MVT::v16i8;
+
+      X = DAG.getNode(ISD::BITCAST, DL, BlendVT, X);
+      Y = DAG.getNode(ISD::BITCAST, DL, BlendVT, Y);
+      Mask = DAG.getNode(ISD::BITCAST, DL, BlendVT, Mask);
+      Mask = DAG.getNode(ISD::VSELECT, DL, BlendVT, Mask, X, Y);
+      return DAG.getNode(ISD::BITCAST, DL, VT, Mask);
     }
   }
 
+  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
+    return SDValue();
+
   // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c)
   if (N0.getOpcode() == ISD::SRL && N1.getOpcode() == ISD::SHL)
     std::swap(N0, N1);