AVX-512: added UNPCK instructions and tests for all-zero/all-ones vectors

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@189189 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 6e9ecef..a00f848 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -3866,37 +3866,38 @@
 static bool isUNPCKLMask(ArrayRef<int> Mask, MVT VT,
                          bool HasInt256, bool V2IsSplat = false) {
 
-  if (VT.is512BitVector())
-    return false;
-  assert((VT.is128BitVector() || VT.is256BitVector()) &&
-         "Unsupported vector type for unpckh");
+  assert(VT.getSizeInBits() >= 128 &&
+         "Unsupported vector type for unpckl");
 
   unsigned NumElts = VT.getVectorNumElements();
   if (VT.is256BitVector() && NumElts != 4 && NumElts != 8 &&
       (!HasInt256 || (NumElts != 16 && NumElts != 32)))
     return false;
 
-  // Handle 128 and 256-bit vector lengths. AVX defines UNPCK* to operate
-  // independently on 128-bit lanes.
+  // Only 32- and 64-bit element types are supported for 512-bit vectors.
+  assert((!VT.is512BitVector() || VT.getScalarType().getSizeInBits() >= 32) &&
+         "Unsupported vector type for unpckl");
+
+  // Handle 128, 256 and 512-bit vector lengths. AVX and AVX-512 define
+  // UNPCK* to operate independently on each 128-bit lane.
   unsigned NumLanes = VT.getSizeInBits()/128;
   unsigned NumLaneElts = NumElts/NumLanes;
 
   for (unsigned l = 0; l != NumElts; l += NumLaneElts) {
     for (unsigned i = 0, j = l; i != NumLaneElts; i += 2, ++j) {
       int BitI  = Mask[l+i];
       int BitI1 = Mask[l+i+1];
       if (!isUndefOrEqual(BitI, j))
         return false;
       if (V2IsSplat) {
         if (!isUndefOrEqual(BitI1, NumElts))
           return false;
       } else {
         if (!isUndefOrEqual(BitI1, j + NumElts))
           return false;
       }
     }
   }
 
   return true;
 }
 
@@ -3904,33 +3913,34 @@
 /// specifies a shuffle of elements that is suitable for input to UNPCKH.
 static bool isUNPCKHMask(ArrayRef<int> Mask, MVT VT,
                          bool HasInt256, bool V2IsSplat = false) {
   unsigned NumElts = VT.getVectorNumElements();
 
-  if (VT.is512BitVector())
-    return false;
-  assert((VT.is128BitVector() || VT.is256BitVector()) &&
+  assert(VT.getSizeInBits() >= 128 &&
          "Unsupported vector type for unpckh");
 
   if (VT.is256BitVector() && NumElts != 4 && NumElts != 8 &&
       (!HasInt256 || (NumElts != 16 && NumElts != 32)))
     return false;
 
-  // Handle 128 and 256-bit vector lengths. AVX defines UNPCK* to operate
-  // independently on 128-bit lanes.
+  // Only 32- and 64-bit element types are supported for 512-bit vectors.
+  assert((!VT.is512BitVector() || VT.getScalarType().getSizeInBits() >= 32) &&
+         "Unsupported vector type for unpckh");
+
+  // Handle 128, 256 and 512-bit vector lengths. AVX and AVX-512 define
+  // UNPCK* to operate independently on each 128-bit lane.
   unsigned NumLanes = VT.getSizeInBits()/128;
   unsigned NumLaneElts = NumElts/NumLanes;
 
   for (unsigned l = 0; l != NumElts; l += NumLaneElts) {
     for (unsigned i = 0, j = l+NumLaneElts/2; i != NumLaneElts; i += 2, ++j) {
       int BitI  = Mask[l+i];
       int BitI1 = Mask[l+i+1];
       if (!isUndefOrEqual(BitI, j))
         return false;
       if (V2IsSplat) {
-        if (isUndefOrEqual(BitI1, NumElts))
+        if (!isUndefOrEqual(BitI1, NumElts))
           return false;
       } else {
         if (!isUndefOrEqual(BitI1, j+NumElts))
           return false;
       }
     }
@@ -4336,7 +4354,7 @@
 static unsigned getShuffleSHUFImmediate(ShuffleVectorSDNode *N) {
   MVT VT = N->getSimpleValueType(0);
 
-  assert((VT.is128BitVector() || VT.is256BitVector()) &&
+  assert(VT.getSizeInBits() >= 128 &&
          "Unsupported vector type for PSHUF/SHUFP");
 
   // Handle 128 and 256-bit vector lengths. AVX defines PSHUF/SHUFP to operate
@@ -4345,10 +4363,10 @@
   unsigned NumLanes = VT.getSizeInBits()/128;
   unsigned NumLaneElts = NumElts/NumLanes;
 
-  assert((NumLaneElts == 2 || NumLaneElts == 4) &&
-         "Only supports 2 or 4 elements per lane");
+  assert((NumLaneElts == 2 || NumLaneElts == 4 || NumLaneElts == 8) &&
+         "Only supports 2, 4 or 8 elements per lane");
 
-  unsigned Shift = (NumLaneElts == 4) ? 1 : 0;
+  unsigned Shift = (NumLaneElts < 4) ? 0 : 1;
   unsigned Mask = 0;
   for (unsigned i = 0; i != NumElts; ++i) {
     int Elt = N->getMaskElt(i);
@@ -4680,6 +4698,12 @@
       Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v8f32, Ops,
                         array_lengthof(Ops));
     }
+  } else if (VT.is512BitVector()) { // AVX-512: a v16i32 of zeros.
+    SDValue Cst = DAG.getTargetConstant(0, MVT::i32);
+    SDValue Ops[] = { Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst,
+                      Cst, Cst, Cst, Cst, Cst, Cst, Cst, Cst };
+    Vec = DAG.getNode(ISD::BUILD_VECTOR, dl, MVT::v16i32, Ops,
+                      array_lengthof(Ops));
   } else
     llvm_unreachable("Unexpected vector type");
 
@@ -5674,10 +5697,19 @@
                       DAG.getIntPtrConstant(0));
   }
 
-  if (!isSplatVector(Op.getNode()))
-    llvm_unreachable("Unsupported predicate operation");
-
+  // Splat vector (with undefs). Operand 0 may itself be undef, so the splat
+  // value is taken from the first defined operand.
   SDValue In = Op.getOperand(0);
+  for (unsigned i = 1, e = Op.getNumOperands(); i != e; ++i) {
+    SDValue Elt = Op.getOperand(i);
+    if (Elt.getOpcode() == ISD::UNDEF)
+      continue;
+    if (In.getOpcode() == ISD::UNDEF)
+      In = Elt;
+    else if (Elt != In)
+      llvm_unreachable("Unsupported predicate operation");
+  }
+
   SDValue EFLAGS, X86CC;
   if (In.getOpcode() == ISD::SETCC) {
     SDValue Op0 = In.getOperand(0);
@@ -5759,7 +5785,8 @@
     if (VT == MVT::v4i32 || (VT == MVT::v8i32 && Subtarget->hasInt256()))
       return Op;
 
-    return getOnesVector(VT, Subtarget->hasInt256(), DAG, dl);
+    if (!VT.is512BitVector()) // 512-bit: fall through to the code below.
+      return getOnesVector(VT, Subtarget->hasInt256(), DAG, dl);
   }
 
   SDValue Broadcast = LowerVectorBroadcast(Op, Subtarget, DAG);
@@ -5841,7 +5868,7 @@
 
       if (ExtVT == MVT::i32 || ExtVT == MVT::f32 || ExtVT == MVT::f64 ||
           (ExtVT == MVT::i64 && Subtarget->is64Bit())) {
-        if (VT.is256BitVector()) {
+        if (VT.is256BitVector() || VT.is512BitVector()) {
           SDValue ZeroVec = getZeroVector(VT, Subtarget, DAG, dl);
           return DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, VT, ZeroVec,
                              Item, DAG.getIntPtrConstant(0));
diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td
index 0cb946d..17be5df 100644
--- a/lib/Target/X86/X86InstrAVX512.td
+++ b/lib/Target/X86/X86InstrAVX512.td
@@ -1561,6 +1561,71 @@
           (VPMULUDQZrr VR512:$src1, VR512:$src2)>;
 
 //===----------------------------------------------------------------------===//
+// AVX-512 - Unpack Instructions
+//===----------------------------------------------------------------------===//
+
+// FP unpacks: a register form and a memory form; the memory form
+// bitconverts the value loaded through mem_frag before applying OpNode.
+multiclass avx512_unpack_fp<bits<8> opc, SDNode OpNode, ValueType vt,
+                            PatFrag mem_frag, RegisterClass RC,
+                            X86MemOperand x86memop, string asm,
+                            Domain d> {
+  def rr : AVX512PI<opc, MRMSrcReg,
+                    (outs RC:$dst), (ins RC:$src1, RC:$src2),
+                    asm, [(set RC:$dst,
+                           (vt (OpNode RC:$src1, RC:$src2)))],
+                    d>, EVEX_4V, TB;
+  def rm : AVX512PI<opc, MRMSrcMem,
+                    (outs RC:$dst), (ins RC:$src1, x86memop:$src2),
+                    asm, [(set RC:$dst,
+                           (vt (OpNode RC:$src1,
+                                (bitconvert (mem_frag addr:$src2)))))],
+                    d>, EVEX_4V, TB;
+}
+
+defm VUNPCKHPSZ: avx512_unpack_fp<0x15, X86Unpckh, v16f32, memopv8f64,
+      VR512, f512mem, "vunpckhps\t{$src2, $src1, $dst|$dst, $src1, $src2}",
+      SSEPackedSingle>, EVEX_V512, EVEX_CD8<32, CD8VF>;
+defm VUNPCKHPDZ: avx512_unpack_fp<0x15, X86Unpckh, v8f64, memopv8f64,
+      VR512, f512mem, "vunpckhpd\t{$src2, $src1, $dst|$dst, $src1, $src2}",
+      SSEPackedDouble>, OpSize, EVEX_V512, VEX_W, EVEX_CD8<64, CD8VF>;
+defm VUNPCKLPSZ: avx512_unpack_fp<0x14, X86Unpckl, v16f32, memopv8f64,
+      VR512, f512mem, "vunpcklps\t{$src2, $src1, $dst|$dst, $src1, $src2}",
+      SSEPackedSingle>, EVEX_V512, EVEX_CD8<32, CD8VF>;
+defm VUNPCKLPDZ: avx512_unpack_fp<0x14, X86Unpckl, v8f64, memopv8f64,
+      VR512, f512mem, "vunpcklpd\t{$src2, $src1, $dst|$dst, $src1, $src2}",
+      SSEPackedDouble>, OpSize, EVEX_V512, VEX_W, EVEX_CD8<64, CD8VF>;
+
+// Integer unpacks: register and memory forms using the AVX512BI encoding.
+multiclass avx512_unpack_int<bits<8> opc, string OpcodeStr, SDNode OpNode,
+                             ValueType OpVT, RegisterClass RC,
+                             PatFrag memop_frag, X86MemOperand x86memop> {
+  def rr : AVX512BI<opc, MRMSrcReg, (outs RC:$dst),
+       (ins RC:$src1, RC:$src2),
+       !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+       [(set RC:$dst, (OpVT (OpNode (OpVT RC:$src1), (OpVT RC:$src2))))],
+       IIC_SSE_UNPCK>, EVEX_4V;
+  def rm : AVX512BI<opc, MRMSrcMem, (outs RC:$dst),
+       (ins RC:$src1, x86memop:$src2),
+       !strconcat(OpcodeStr, "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+       [(set RC:$dst, (OpVT (OpNode (OpVT RC:$src1),
+                                     (bitconvert (memop_frag addr:$src2)))))],
+       IIC_SSE_UNPCK>, EVEX_4V;
+}
+defm VPUNPCKLDQZ  : avx512_unpack_int<0x62, "vpunpckldq", X86Unpckl, v16i32,
+                                VR512, memopv16i32, i512mem>, EVEX_V512,
+                                EVEX_CD8<32, CD8VF>;
+defm VPUNPCKLQDQZ : avx512_unpack_int<0x6C, "vpunpcklqdq", X86Unpckl, v8i64,
+                                VR512, memopv8i64, i512mem>, EVEX_V512,
+                                VEX_W, EVEX_CD8<64, CD8VF>;
+defm VPUNPCKHDQZ  : avx512_unpack_int<0x6A, "vpunpckhdq", X86Unpckh, v16i32,
+                                VR512, memopv16i32, i512mem>, EVEX_V512,
+                                EVEX_CD8<32, CD8VF>;
+defm VPUNPCKHQDQZ : avx512_unpack_int<0x6D, "vpunpckhqdq", X86Unpckh, v8i64,
+                                VR512, memopv8i64, i512mem>, EVEX_V512,
+                                VEX_W, EVEX_CD8<64, CD8VF>;
+
+//===----------------------------------------------------------------------===//
 // AVX-512  Logical Instructions
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/Target/X86/X86InstrInfo.cpp b/lib/Target/X86/X86InstrInfo.cpp
index 71df2bb..c4c090b 100644
--- a/lib/Target/X86/X86InstrInfo.cpp
+++ b/lib/Target/X86/X86InstrInfo.cpp
@@ -3781,6 +3781,8 @@
   case X86::AVX_SET0:
     assert(HasAVX && "AVX not supported");
     return Expand2AddrUndef(MIB, get(X86::VXORPSYrr));
+  case X86::AVX512_512_SET0:
+    return Expand2AddrUndef(MIB, get(X86::VPXORDZrr));
   case X86::V_SETALLONES:
     return Expand2AddrUndef(MIB, get(HasAVX ? X86::VPCMPEQDrr : X86::PCMPEQDrr));
   case X86::AVX2_SETALLONES:
@@ -3788,6 +3790,12 @@
   case X86::TEST8ri_NOREX:
     MI->setDesc(get(X86::TEST8ri));
     return true;
+  // KSET0W expands to KXORW; KSET1B and KSET1W both expand to KXNORW.
+  case X86::KSET0W:
+    return Expand2AddrUndef(MIB, get(X86::KXORWrr));
+  case X86::KSET1B:
+  case X86::KSET1W:
+    return Expand2AddrUndef(MIB, get(X86::KXNORWrr));
   }
   return false;
 }