AVX-512 set: added mask operations and BUILD_VECTOR lowering for i1 vector types.
Added intrinsics and tests.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@187717 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 1d5c6e5..90326cb 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -5606,6 +5606,111 @@
   return NV;
 }
 
+// Lower BUILD_VECTOR for AVX-512 predicate (mask) vectors: v8i1 and v16i1.
+//
+// - All-zeros / all-ones vectors are rebuilt from i1 target constants.
+// - Other all-constant vectors (undef lanes count as 0) are packed into an
+//   i16 immediate, bitcast to v16i1 and narrowed with EXTRACT_SUBVECTOR.
+// - A splat of one boolean is turned into EFLAGS plus an X86 condition code
+//   and materialized with a CMOV that selects 0 or -1.
+// Anything else is unsupported.
+//
+// BUILD_VECTOR operands may be wider than the i1 element type, in which case
+// they are implicitly truncated, so only bit 0 of each operand is significant.
+SDValue
+X86TargetLowering::LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  assert((VT.getVectorElementType() == MVT::i1) && (VT.getSizeInBits() <= 16) &&
+         "Unexpected type in LowerBUILD_VECTORvXi1!");
+
+  SDLoc dl(Op);
+  unsigned NumElts = VT.getVectorNumElements();
+
+  // All-zeros / all-ones: one i1 target constant per lane.
+  bool AllZeros = ISD::isBuildVectorAllZeros(Op.getNode());
+  if (AllZeros || ISD::isBuildVectorAllOnes(Op.getNode())) {
+    SDValue Cst = DAG.getTargetConstant(AllZeros ? 0 : 1, MVT::i1);
+    SmallVector<SDValue, 16> Ops(NumElts, Cst);
+    return DAG.getNode(ISD::BUILD_VECTOR, dl, VT, &Ops[0], Ops.size());
+  }
+
+  // Pack an all-constant vector into an immediate, one bit per lane.
+  bool AllConstants = true;
+  uint64_t Immediate = 0;
+  for (unsigned idx = 0, e = Op.getNumOperands(); idx < e; ++idx) {
+    SDValue In = Op.getOperand(idx);
+    if (In.getOpcode() == ISD::UNDEF)
+      continue;
+    ConstantSDNode *InC = dyn_cast<ConstantSDNode>(In);
+    if (!InC) {
+      AllConstants = false;
+      break;
+    }
+    // Only bit 0 survives the implicit truncation to i1.
+    Immediate |= (InC->getZExtValue() & 1) << idx;
+  }
+
+  if (AllConstants) {
+    SDValue FullMask = DAG.getNode(ISD::BITCAST, dl, MVT::v16i1,
+                                   DAG.getConstant(Immediate, MVT::i16));
+    return DAG.getNode(ISD::EXTRACT_SUBVECTOR, dl, VT, FullMask,
+                       DAG.getIntPtrConstant(0));
+  }
+
+  if (!isSplatVector(Op.getNode()))
+    llvm_unreachable("Unsupported predicate operation");
+
+  // Splat of a single boolean: produce EFLAGS and an X86 condition code.
+  SDValue In = Op.getOperand(0);
+  SDValue EFLAGS, X86CC;
+  if (In.getOpcode() == ISD::SETCC) {
+    SDValue Op0 = In.getOperand(0);
+    SDValue Op1 = In.getOperand(1);
+    ISD::CondCode CC = cast<CondCodeSDNode>(In.getOperand(2))->get();
+    bool isFP = Op1.getValueType().isFloatingPoint();
+    unsigned X86CCVal = TranslateX86CC(CC, isFP, Op0, Op1, DAG);
+
+    assert(X86CCVal != X86::COND_INVALID && "Unsupported predicate operation");
+
+    X86CC = DAG.getConstant(X86CCVal, MVT::i8);
+    EFLAGS = EmitCmp(Op0, Op1, X86CCVal, DAG);
+    EFLAGS = ConvertCmpIfNecessary(EFLAGS, DAG);
+  } else if (In.getOpcode() == X86ISD::SETCC) {
+    // Already an X86 SETCC: reuse its condition code and flags operands.
+    X86CC = In.getOperand(0);
+    EFLAGS = In.getOperand(1);
+  } else {
+    // Any other boolean: test bit 0.
+    //   EFLAGS = TEST (In & 1)
+    //   res    = (ZF == 0) ? allOnes : allZero    ### CMOVNE -1, %res
+    MVT InVT = In.getValueType().getSimpleVT();
+    SDValue Bit1 = DAG.getNode(ISD::AND, dl, InVT, In,
+                               DAG.getConstant(1, InVT));
+    EFLAGS = EmitTest(Bit1, X86::COND_NE, DAG);
+    X86CC = DAG.getConstant(X86::COND_NE, MVT::i8);
+  }
+
+  // X86ISD::CMOV operands are (value if false, value if true, cond, flags).
+  // x86 has no 8-bit CMOV, so v8i1 selects in i32 and truncates to i8.
+  if (VT == MVT::v16i1) {
+    SDValue Cst1 = DAG.getConstant(-1, MVT::i16);
+    SDValue Cst0 = DAG.getConstant(0, MVT::i16);
+    SDValue CmovOp = DAG.getNode(X86ISD::CMOV, dl, MVT::i16,
+                                 Cst0, Cst1, X86CC, EFLAGS);
+    return DAG.getNode(ISD::BITCAST, dl, VT, CmovOp);
+  }
+
+  if (VT == MVT::v8i1) {
+    SDValue Cst1 = DAG.getConstant(-1, MVT::i32);
+    SDValue Cst0 = DAG.getConstant(0, MVT::i32);
+    SDValue CmovOp = DAG.getNode(X86ISD::CMOV, dl, MVT::i32,
+                                 Cst0, Cst1, X86CC, EFLAGS);
+    CmovOp = DAG.getNode(ISD::TRUNCATE, dl, MVT::i8, CmovOp);
+    return DAG.getNode(ISD::BITCAST, dl, VT, CmovOp);
+  }
+  llvm_unreachable("Unsupported predicate operation");
+}
+
 SDValue
 X86TargetLowering::LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const {
   SDLoc dl(Op);
@@ -5614,6 +5719,10 @@
   MVT ExtVT = VT.getVectorElementType();
   unsigned NumElems = Op.getNumOperands();
 
+  // Generate vectors for predicate vectors.
+  if (VT.getScalarType() == MVT::i1 && Subtarget->hasAVX512())
+    return LowerBUILD_VECTORvXi1(Op, DAG);
+
   // Vectors containing all zeros can be matched by pxor and xorps later
   if (ISD::isBuildVectorAllZeros(Op.getNode())) {
     // Canonicalize this to <4 x i32> to 1) ensure the zero vectors are CSE'd
diff --git a/lib/Target/X86/X86ISelLowering.h b/lib/Target/X86/X86ISelLowering.h
index 03765c1..e09104a 100644
--- a/lib/Target/X86/X86ISelLowering.h
+++ b/lib/Target/X86/X86ISelLowering.h
@@ -294,6 +294,10 @@
       // TESTP - Vector packed fp sign bitwise comparisons
       TESTP,
 
+      // KORTEST/KTEST - OR/AND test of two mask registers; sets EFLAGS
+      KORTEST,
+      KTEST,
+
       // Several flavors of instructions with vector shuffle behaviors.
       PALIGNR,
       PSHUFD,
@@ -826,6 +830,7 @@
     SDValue LowerAsSplatVectorLoad(SDValue SrcOp, EVT VT, SDLoc dl,
                                    SelectionDAG &DAG) const;
     SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
+    SDValue LowerBUILD_VECTORvXi1(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
     SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
diff --git a/lib/Target/X86/X86InstrAVX512.td b/lib/Target/X86/X86InstrAVX512.td
index db90341..18ccdc3 100644
--- a/lib/Target/X86/X86InstrAVX512.td
+++ b/lib/Target/X86/X86InstrAVX512.td
@@ -346,3 +346,279 @@
       "vextractps{z}\t{$src2, $src1, $dst|$dst, $src1, $src2}",
       [(store (extractelt (bc_v4i32 (v4f32 VR128X:$src1)), imm:$src2),
                           addr:$dst)]>, EVEX;
+
+// Mask register copy, including
+// - copy between mask registers
+// - load/store mask registers
+// - copy from GPR to mask register and vice versa
+//
+// Only the load form (km) has a selection pattern here; mask stores are
+// matched by the separate Pat<(store ...)> definitions further down.
+multiclass avx512_mask_mov<bits<8> opc_kk, bits<8> opc_km, bits<8> opc_mk,
+                         string OpcodeStr, RegisterClass KRC,
+                         ValueType vt, X86MemOperand x86memop> {
+  let neverHasSideEffects = 1 in {
+    def kk : I<opc_kk, MRMSrcReg, (outs KRC:$dst), (ins KRC:$src),
+               !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"), []>;
+    let mayLoad = 1 in
+    def km : I<opc_km, MRMSrcMem, (outs KRC:$dst), (ins x86memop:$src),
+               !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
+               [(set KRC:$dst, (vt (load addr:$src)))]>;
+    let mayStore = 1 in
+    def mk : I<opc_mk, MRMDestMem, (outs), (ins x86memop:$dst, KRC:$src),
+               !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"), []>;
+  }
+}
+
+// Moves between a mask register and a general-purpose register:
+// kr is GPR -> mask, rk is mask -> GPR.
+multiclass avx512_mask_mov_gpr<bits<8> opc_kr, bits<8> opc_rk,
+                             string OpcodeStr,
+                             RegisterClass KRC, RegisterClass GRC> {
+  let neverHasSideEffects = 1 in {
+    def kr : I<opc_kr, MRMSrcReg, (outs KRC:$dst), (ins GRC:$src),
+               !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"), []>;
+    def rk : I<opc_rk, MRMSrcReg, (outs GRC:$dst), (ins KRC:$src),
+               !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"), []>;
+  }
+}
+
+// KMOVW: 16-bit mask moves to/from mask registers, memory and GR32.
+let Predicates = [HasAVX512] in {
+  defm KMOVW : avx512_mask_mov<0x90, 0x90, 0x91, "kmovw", VK16, v16i1, i16mem>,
+               VEX, TB;
+  defm KMOVW : avx512_mask_mov_gpr<0x92, 0x93, "kmovw", VK16, GR32>,
+               VEX, TB;
+}
+
+let Predicates = [HasAVX512] in {
+  // GR16 from/to 16-bit mask
+  def : Pat<(v16i1 (bitconvert (i16 GR16:$src))),
+            (KMOVWkr (SUBREG_TO_REG (i32 0), GR16:$src, sub_16bit))>;
+  def : Pat<(i16 (bitconvert (v16i1 VK16:$src))),
+            (EXTRACT_SUBREG (KMOVWrk VK16:$src), sub_16bit)>;
+
+  // Store kreg in memory
+  def : Pat<(store (v16i1 VK16:$src), addr:$dst),
+            (KMOVWmk addr:$dst, VK16:$src)>;
+
+  // A v8i1 value occupies a single byte in memory, but KMOVW m16 writes two
+  // bytes. Move the mask to a GPR and store only its low 8 bits.
+  def : Pat<(store (v8i1 VK8:$src), addr:$dst),
+            (MOV8mr addr:$dst,
+             (EXTRACT_SUBREG (KMOVWrk (COPY_TO_REGCLASS VK8:$src, VK16)),
+              sub_8bit))>;
+}
+
+// With AVX-512 only, 8-bit mask is promoted to 16-bit mask.
+let Predicates = [HasAVX512] in {
+  // GR from/to 8-bit mask without native support
+  def : Pat<(v8i1 (bitconvert (i8 GR8:$src))),
+            (COPY_TO_REGCLASS
+              (KMOVWkr (SUBREG_TO_REG (i32 0), GR8:$src, sub_8bit)),
+              VK8)>;
+  def : Pat<(i8 (bitconvert (v8i1 VK8:$src))),
+            (EXTRACT_SUBREG
+              (KMOVWrk (COPY_TO_REGCLASS VK8:$src, VK16)),
+              sub_8bit)>;
+}
+
+// Mask unary operation
+// - KNOT
+multiclass avx512_mask_unop<bits<8> opc, string OpcodeStr,
+                         RegisterClass KRC, SDPatternOperator OpNode> {
+  let Predicates = [HasAVX512] in
+    def rr : I<opc, MRMSrcReg, (outs KRC:$dst), (ins KRC:$src),
+               !strconcat(OpcodeStr, "\t{$src, $dst|$dst, $src}"),
+               [(set KRC:$dst, (OpNode KRC:$src))]>;
+}
+
+// Instantiates the 16-bit ("w") form on VK16.
+multiclass avx512_mask_unop_w<bits<8> opc, string OpcodeStr,
+                               SDPatternOperator OpNode> {
+  defm W : avx512_mask_unop<opc, !strconcat(OpcodeStr, "w"), VK16, OpNode>,
+                          VEX, TB;
+}
+
+defm KNOT : avx512_mask_unop_w<0x44, "knot", not>;
+
+// With AVX-512, 8-bit mask is promoted to 16-bit mask.
+let Predicates = [HasAVX512] in {
+  def : Pat<(xor VK16:$src1, (v16i1 immAllOnesV)), (KNOTWrr VK16:$src1)>;
+  def : Pat<(xor VK8:$src1, (v8i1 immAllOnesV)),
+            (COPY_TO_REGCLASS
+              (KNOTWrr (COPY_TO_REGCLASS VK8:$src1, VK16)), VK8)>;
+  def : Pat<(not VK8:$src),
+            (COPY_TO_REGCLASS
+              (KNOTWrr (COPY_TO_REGCLASS VK8:$src, VK16)), VK8)>;
+}
+
+// Mask binary operation
+// - KADD, KAND, KANDN, KOR, KXNOR, KXOR
+multiclass avx512_mask_binop<bits<8> opc, string OpcodeStr,
+                           RegisterClass KRC, SDPatternOperator OpNode> {
+  let Predicates = [HasAVX512] in
+    def rr : I<opc, MRMSrcReg, (outs KRC:$dst), (ins KRC:$src1, KRC:$src2),
+               !strconcat(OpcodeStr,
+                          "\t{$src2, $src1, $dst|$dst, $src1, $src2}"),
+               [(set KRC:$dst, (OpNode KRC:$src1, KRC:$src2))]>;
+}
+
+// Instantiates the 16-bit ("w") form on VK16.
+multiclass avx512_mask_binop_w<bits<8> opc, string OpcodeStr,
+                             SDPatternOperator OpNode> {
+  defm W : avx512_mask_binop<opc, !strconcat(OpcodeStr, "w"), VK16, OpNode>,
+                           VEX_4V, VEX_L, TB;
+}
+
+// andn(a, b) = ~a & b,  xnor(a, b) = ~(a ^ b)
+def andn : PatFrag<(ops node:$i0, node:$i1), (and (not node:$i0), node:$i1)>;
+def xnor : PatFrag<(ops node:$i0, node:$i1), (not (xor node:$i0, node:$i1))>;
+
+// KADDW adds the two masks as 16-bit integers, so carries cross lanes; that
+// is not ISD::ADD on i1 lanes (addition modulo 2). KADD therefore gets no
+// generic pattern and is only reachable through its intrinsic.
+let isCommutable = 1 in {
+  defm KADD  : avx512_mask_binop_w<0x4a, "kadd",  null_frag>;
+  defm KAND  : avx512_mask_binop_w<0x41, "kand",  and>;
+  // ~a & b: operand order matters.
+  let isCommutable = 0 in
+  defm KANDN : avx512_mask_binop_w<0x42, "kandn", andn>;
+  defm KOR   : avx512_mask_binop_w<0x45, "kor",   or>;
+  defm KXNOR : avx512_mask_binop_w<0x46, "kxnor", xnor>;
+  defm KXOR  : avx512_mask_binop_w<0x47, "kxor",  xor>;
+}
+
+// Map the int_x86_<name>_v16i1 intrinsics onto the VK16 instructions.
+multiclass avx512_mask_binop_int<string IntName, string InstName> {
+  let Predicates = [HasAVX512] in
+    def : Pat<(!cast<Intrinsic>("int_x86_"##IntName##"_v16i1")
+                VK16:$src1, VK16:$src2),
+              (!cast<Instruction>(InstName##"Wrr") VK16:$src1, VK16:$src2)>;
+}
+
+defm : avx512_mask_binop_int<"kadd",  "KADD">;
+defm : avx512_mask_binop_int<"kand",  "KAND">;
+defm : avx512_mask_binop_int<"kandn", "KANDN">;
+defm : avx512_mask_binop_int<"kor",   "KOR">;
+defm : avx512_mask_binop_int<"kxnor", "KXNOR">;
+defm : avx512_mask_binop_int<"kxor",  "KXOR">;
+
+// With AVX-512, 8-bit mask is promoted to 16-bit mask.
+multiclass avx512_binop_pat<SDPatternOperator OpNode, Instruction Inst> {
+  let Predicates = [HasAVX512] in
+    def : Pat<(OpNode VK8:$src1, VK8:$src2),
+              (COPY_TO_REGCLASS
+                (Inst (COPY_TO_REGCLASS VK8:$src1, VK16),
+                      (COPY_TO_REGCLASS VK8:$src2, VK16)), VK8)>;
+}
+
+defm : avx512_binop_pat<and,  KANDWrr>;
+defm : avx512_binop_pat<andn, KANDNWrr>;
+defm : avx512_binop_pat<or,   KORWrr>;
+defm : avx512_binop_pat<xnor, KXNORWrr>;
+defm : avx512_binop_pat<xor,  KXORWrr>;
+
+// Lane-wise i1 addition is XOR.
+let Predicates = [HasAVX512] in
+  def : Pat<(add VK16:$src1, VK16:$src2), (KXORWrr VK16:$src1, VK16:$src2)>;
+defm : avx512_binop_pat<add,  KXORWrr>;
+
+// Mask unpacking
+multiclass avx512_mask_unpck<bits<8> opc, string OpcodeStr,
+                           RegisterClass KRC1, RegisterClass KRC2> {
+  let Predicates = [HasAVX512] in
+    def rr : I<opc, MRMSrcReg, (outs KRC1:$dst), (ins KRC2:$src1, KRC2:$src2),
+               !strconcat(OpcodeStr,
+                          "\t{$src2, $src1, $dst|$dst, $src1, $src2}"), []>;
+}
+
+// KUNPCKBW: two VK8 sources, one VK16 result.
+multiclass avx512_mask_unpck_bw<bits<8> opc, string OpcodeStr> {
+  defm BW : avx512_mask_unpck<opc, !strconcat(OpcodeStr, "bw"), VK16, VK8>,
+                            VEX_4V, VEX_L, OpSize, TB;
+}
+
+defm KUNPCK : avx512_mask_unpck_bw<0x4b, "kunpck">;
+
+multiclass avx512_mask_unpck_int<string IntName, string InstName> {
+  let Predicates = [HasAVX512] in
+    def : Pat<(!cast<Intrinsic>("int_x86_"##IntName##"_v16i1")
+                VK8:$src1, VK8:$src2),
+              (!cast<Instruction>(InstName##"BWrr") VK8:$src1, VK8:$src2)>;
+}
+
+defm : avx512_mask_unpck_int<"kunpck", "KUNPCK">;
+
+// Mask bit testing: these instructions only write EFLAGS.
+multiclass avx512_mask_testop<bits<8> opc, string OpcodeStr, RegisterClass KRC,
+                            SDNode OpNode> {
+  let Predicates = [HasAVX512], Defs = [EFLAGS] in
+    def rr : I<opc, MRMSrcReg, (outs), (ins KRC:$src1, KRC:$src2),
+               !strconcat(OpcodeStr, "\t{$src2, $src1|$src1, $src2}"),
+               [(set EFLAGS, (OpNode KRC:$src1, KRC:$src2))]>;
+}
+
+multiclass avx512_mask_testop_w<bits<8> opc, string OpcodeStr, SDNode OpNode> {
+  defm W : avx512_mask_testop<opc, !strconcat(OpcodeStr, "w"), VK16, OpNode>,
+                            VEX, TB;
+}
+
+defm KORTEST : avx512_mask_testop_w<0x98, "kortest", X86kortest>;
+defm KTEST   : avx512_mask_testop_w<0x99, "ktest", X86ktest>;
+
+// Mask shift by an 8-bit immediate.
+// NOTE(review): generic shl/srl on a vXi1 type shift each lane, whereas
+// KSHIFT moves bits across lanes; confirm which DAG form this should match.
+multiclass avx512_mask_shiftop<bits<8> opc, string OpcodeStr, RegisterClass KRC,
+                             SDNode OpNode> {
+  let Predicates = [HasAVX512] in
+    def ri : Ii8<opc, MRMSrcReg, (outs KRC:$dst), (ins KRC:$src, i8imm:$imm),
+                 !strconcat(OpcodeStr,
+                            "\t{$imm, $src, $dst|$dst, $src, $imm}"),
+                            [(set KRC:$dst, (OpNode KRC:$src, (i8 imm:$imm)))]>;
+}
+
+// opc1 is the opcode of the W form; opc2 is not referenced yet.
+multiclass avx512_mask_shiftop_w<bits<8> opc1, bits<8> opc2, string OpcodeStr,
+                               SDNode OpNode> {
+  defm W : avx512_mask_shiftop<opc1, !strconcat(OpcodeStr, "w"), VK16, OpNode>,
+                             VEX, OpSize, TA, VEX_W;
+}
+
+defm KSHIFTL : avx512_mask_shiftop_w<0x32, 0x33, "kshiftl", shl>;
+defm KSHIFTR : avx512_mask_shiftop_w<0x30, 0x31, "kshiftr", srl>;
+
+// Mask setting all 0s or 1s, as rematerializable pseudos.
+multiclass avx512_mask_setop<RegisterClass KRC, ValueType VT, PatFrag Val> {
+  let Predicates = [HasAVX512] in
+    let isReMaterializable = 1, isAsCheapAsAMove = 1, isPseudo = 1 in
+      def #NAME# : I<0, Pseudo, (outs KRC:$dst), (ins), "",
+                     [(set KRC:$dst, (VT Val))]>;
+}
+
+multiclass avx512_mask_setop_w<PatFrag Val> {
+  defm B : avx512_mask_setop<VK8,  v8i1, Val>;
+  defm W : avx512_mask_setop<VK16, v16i1, Val>;
+}
+
+defm KSET0 : avx512_mask_setop_w<immAllZerosV>;
+defm KSET1 : avx512_mask_setop_w<immAllOnesV>;
+
+// With AVX-512 only, 8-bit mask is promoted to 16-bit mask.
+let Predicates = [HasAVX512] in {
+  def : Pat<(v8i1 immAllZerosV), (COPY_TO_REGCLASS (KSET0W), VK8)>;
+  def : Pat<(v8i1 immAllOnesV),  (COPY_TO_REGCLASS (KSET1W), VK8)>;
+}
+
+// Moving between v16i1 and its v8i1 halves.
+let Predicates = [HasAVX512] in {
+  def : Pat<(v8i1 (extract_subvector (v16i1 VK16:$src), (iPTR 0))),
+            (v8i1 (COPY_TO_REGCLASS VK16:$src, VK8))>;
+
+  def : Pat<(v16i1 (insert_subvector undef, (v8i1 VK8:$src), (iPTR 0))),
+            (v16i1 (COPY_TO_REGCLASS VK8:$src, VK16))>;
+
+  def : Pat<(v8i1 (extract_subvector (v16i1 VK16:$src), (iPTR 8))),
+            (v8i1 (COPY_TO_REGCLASS (KSHIFTRWri VK16:$src, (i8 8)), VK8))>;
+}
diff --git a/lib/Target/X86/X86InstrFragmentsSIMD.td b/lib/Target/X86/X86InstrFragmentsSIMD.td
index 4aa8777..db53af0 100644
--- a/lib/Target/X86/X86InstrFragmentsSIMD.td
+++ b/lib/Target/X86/X86InstrFragmentsSIMD.td
@@ -138,6 +138,8 @@
 def X86subus   : SDNode<"X86ISD::SUBUS", SDTIntBinOp>;
 def X86ptest   : SDNode<"X86ISD::PTEST", SDTX86CmpPTest>;
 def X86testp   : SDNode<"X86ISD::TESTP", SDTX86CmpPTest>;
+def X86kortest : SDNode<"X86ISD::KORTEST", SDTX86CmpPTest>; // mask OR test
+def X86ktest   : SDNode<"X86ISD::KTEST", SDTX86CmpPTest>;   // mask AND test
 
 def X86pmuludq : SDNode<"X86ISD::PMULUDQ",
                         SDTypeProfile<1, 2, [SDTCisVec<0>, SDTCisVec<1>,