[mips][msa] Mask vectors holding shift amounts

Mask the vectors that hold shift amounts when creating the following nodes:
ISD::SHL, ISD::SRL and ISD::SRA.
The instructions whose lowering uses these nodes, and whose operands are
therefore altered, are sll, srl, sra, bneg, bclr and bset.

For these instructions, the shift amount or bit position specified in each
vector element is interpreted modulo the element size in bits.

The problem shows up when compiling with -O2, where the .w and .d
instructions are not generated but are instead optimized away. In that case,
shift amounts that are negative or greater than or equal to the element size
in bits produce incorrect results during constant folding.

We remedy this by masking the shift-amount / bit-position operands of the
nodes mentioned above before creating them, so that the constant-folded
result is already correct when it is placed into the constant pool.

Patch by Stefan Maksimovic.

Differential Revision: https://reviews.llvm.org/D31331

llvm-svn: 300839
diff --git a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
index 8b04fcb..bf79f0f 100644
--- a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
+++ b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
@@ -3781,6 +3781,96 @@
        ISA_MIPS1_NOT_32R6_64R6;
 }
 
+// Matches a bitconvert of a v4i32 build_vector that splats the value 63 at
+// the element width of the result type (used for the v2i64 masks below).
+def vsplati64_imm_eq_63 : PatLeaf<(bitconvert (v4i32 (build_vector))), [{
+  APInt Imm;
+  SDNode *BV = N->getOperand(0).getNode();
+  EVT EltTy = N->getValueType(0).getVectorElementType();
+
+  return selectVSplat(BV, Imm, EltTy.getSizeInBits()) &&
+         Imm.getBitWidth() == EltTy.getSizeInBits() && Imm == 63;
+}]>;
+
+// i32 immediates equal to (element size in bits - 1) for .b, .h and .w.
+def immi32Cst7  : ImmLeaf<i32, [{return isUInt<32>(Imm) && Imm == 7;}]>;
+def immi32Cst15 : ImmLeaf<i32, [{return isUInt<32>(Imm) && Imm == 15;}]>;
+def immi32Cst31 : ImmLeaf<i32, [{return isUInt<32>(Imm) && Imm == 31;}]>;
+
+// (and $wt, splat(EltBits - 1)): the shift amount / bit position reduced
+// modulo the element size, i.e. the form emitted by truncateVecElts().
+def vsplati8imm7 :   PatFrag<(ops node:$wt),
+                             (and node:$wt, (vsplati8 immi32Cst7))>;
+def vsplati16imm15 : PatFrag<(ops node:$wt),
+                             (and node:$wt, (vsplati16 immi32Cst15))>;
+def vsplati32imm31 : PatFrag<(ops node:$wt),
+                             (and node:$wt, (vsplati32 immi32Cst31))>;
+def vsplati64imm63 : PatFrag<(ops node:$wt),
+                             (and node:$wt, vsplati64_imm_eq_63)>;
+
+// Selects Insn for (Node $ws, (and $wt, Vec)), where Vec is the mask splat.
+// MSA instructions already read each $wt element modulo the element size,
+// so the explicit AND is folded into the instruction.
+class MSAShiftPat<SDNode Node, ValueType VT, MSAInst Insn, dag Vec> :
+  MSAPat<(VT (Node VT:$ws, (VT (and VT:$wt, Vec)))),
+         (VT (Insn VT:$ws, VT:$wt))>;
+
+// Selects Insn for (Node $ws, (shl splat(1), (Frag $wt))), i.e. Node applied
+// to $ws and a per-element single-bit mask at the masked bit position.
+class MSABitPat<SDNode Node, ValueType VT, MSAInst Insn, PatFrag Frag> :
+  MSAPat<(VT (Node VT:$ws, (shl vsplat_imm_eq_1, (Frag VT:$wt)))),
+         (VT (Insn VT:$ws, VT:$wt))>;
+
+// Shift patterns for .b/.h/.w/.d; the .d pattern is written out explicitly
+// and matches its mask with the vsplati64_imm_eq_63 PatLeaf.
+multiclass MSAShiftPats<SDNode Node, string Insn> {
+  def : MSAShiftPat<Node, v16i8, !cast<MSAInst>(Insn#_B),
+                    (vsplati8 immi32Cst7)>;
+  def : MSAShiftPat<Node, v8i16, !cast<MSAInst>(Insn#_H),
+                    (vsplati16 immi32Cst15)>;
+  def : MSAShiftPat<Node, v4i32, !cast<MSAInst>(Insn#_W),
+                    (vsplati32 immi32Cst31)>;
+  def : MSAPat<(v2i64 (Node v2i64:$ws, (v2i64 (and v2i64:$wt,
+                                                   vsplati64_imm_eq_63)))),
+               (v2i64 (!cast<MSAInst>(Insn#_D) v2i64:$ws, v2i64:$wt))>;
+}
+
+// Bit patterns for .b/.h/.w/.d; the .d pattern is written out explicitly
+// and matches its mask with the vsplati64imm63 PatFrag.
+multiclass MSABitPats<SDNode Node, string Insn> {
+  def : MSABitPat<Node, v16i8, !cast<MSAInst>(Insn#_B), vsplati8imm7>;
+  def : MSABitPat<Node, v8i16, !cast<MSAInst>(Insn#_H), vsplati16imm15>;
+  def : MSABitPat<Node, v4i32, !cast<MSAInst>(Insn#_W), vsplati32imm31>;
+  def : MSAPat<(Node v2i64:$ws, (shl (v2i64 vsplati64_imm_eq_1),
+                                     (vsplati64imm63 v2i64:$wt))),
+               (v2i64 (!cast<MSAInst>(Insn#_D) v2i64:$ws, v2i64:$wt))>;
+}
+
+defm : MSAShiftPats<shl, "SLL">;
+defm : MSAShiftPats<srl, "SRL">;
+defm : MSAShiftPats<sra, "SRA">;
+defm : MSABitPats<xor, "BNEG">;
+defm : MSABitPats<or, "BSET">;
+
+// bclr: $ws & ~(1 << ($wt & (EltBits - 1))). The v2i64 pattern matches its
+// all-ones operand as a bitconvert of a v4i32 all-ones vector.
+def : MSAPat<(and v16i8:$ws, (xor (shl vsplat_imm_eq_1,
+                                       (vsplati8imm7 v16i8:$wt)),
+                                  immAllOnesV)),
+             (v16i8 (BCLR_B v16i8:$ws, v16i8:$wt))>;
+def : MSAPat<(and v8i16:$ws, (xor (shl vsplat_imm_eq_1,
+                                       (vsplati16imm15 v8i16:$wt)),
+                                  immAllOnesV)),
+             (v8i16 (BCLR_H v8i16:$ws, v8i16:$wt))>;
+def : MSAPat<(and v4i32:$ws, (xor (shl vsplat_imm_eq_1,
+                                       (vsplati32imm31 v4i32:$wt)),
+                                  immAllOnesV)),
+             (v4i32 (BCLR_W v4i32:$ws, v4i32:$wt))>;
+def : MSAPat<(and v2i64:$ws, (xor (shl (v2i64 vsplati64_imm_eq_1),
+                                       (vsplati64imm63 v2i64:$wt)),
+                                  (bitconvert (v4i32 immAllOnesV)))),
+             (v2i64 (BCLR_D v2i64:$ws, v2i64:$wt))>;
+
 // Vector extraction with fixed index.
 //
 // Extracting 32-bit values on MSA32 should always use COPY_S_W rather than
diff --git a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
index e2da847..bf7f079 100644
--- a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
@@ -1547,11 +1547,24 @@
   return DAG.getNode(Opc, DL, VecTy, Op->getOperand(1), Exp2Imm);
 }
 
+static SDValue truncateVecElts(SDValue Op, SelectionDAG &DAG) {
+  // Reduce operand 2 modulo the element width: Amount & splat(EltBits - 1).
+  SDValue Amount = Op->getOperand(2);
+  EVT VecTy = Op->getValueType(0);
+  SDLoc DL(Op);
+  MVT SplatEltTy = VecTy == MVT::v2i64 ? MVT::i64 : MVT::i32;
+  bool IsBigEndian = !DAG.getSubtarget().getTargetTriple().isLittleEndian();
+  SDValue MaskElt =
+      DAG.getConstant(Amount.getScalarValueSizeInBits() - 1, DL, SplatEltTy);
+  SDValue Mask = getBuildVectorSplat(VecTy, MaskElt, IsBigEndian, DAG);
+  return DAG.getNode(ISD::AND, DL, VecTy, Amount, Mask);
+}
+
 static SDValue lowerMSABitClear(SDValue Op, SelectionDAG &DAG) {
   EVT ResTy = Op->getValueType(0);
   SDLoc DL(Op);
   SDValue One = DAG.getConstant(1, DL, ResTy);
-  SDValue Bit = DAG.getNode(ISD::SHL, DL, ResTy, One, Op->getOperand(2));
+  SDValue Bit = DAG.getNode(ISD::SHL, DL, ResTy, One, truncateVecElts(Op, DAG));
 
   return DAG.getNode(ISD::AND, DL, ResTy, Op->getOperand(1),
                      DAG.getNOT(DL, Bit, ResTy));
@@ -1687,7 +1700,7 @@
 
     return DAG.getNode(ISD::XOR, DL, VecTy, Op->getOperand(1),
                        DAG.getNode(ISD::SHL, DL, VecTy, One,
-                                   Op->getOperand(2)));
+                                   truncateVecElts(Op, DAG)));
   }
   case Intrinsic::mips_bnegi_b:
   case Intrinsic::mips_bnegi_h:
@@ -1723,7 +1736,7 @@
 
     return DAG.getNode(ISD::OR, DL, VecTy, Op->getOperand(1),
                        DAG.getNode(ISD::SHL, DL, VecTy, One,
-                                   Op->getOperand(2)));
+                                   truncateVecElts(Op, DAG)));
   }
   case Intrinsic::mips_bseti_b:
   case Intrinsic::mips_bseti_h:
@@ -2210,7 +2223,7 @@
   case Intrinsic::mips_sll_w:
   case Intrinsic::mips_sll_d:
     return DAG.getNode(ISD::SHL, DL, Op->getValueType(0), Op->getOperand(1),
-                       Op->getOperand(2));
+                       truncateVecElts(Op, DAG));
   case Intrinsic::mips_slli_b:
   case Intrinsic::mips_slli_h:
   case Intrinsic::mips_slli_w:
@@ -2240,7 +2253,7 @@
   case Intrinsic::mips_sra_w:
   case Intrinsic::mips_sra_d:
     return DAG.getNode(ISD::SRA, DL, Op->getValueType(0), Op->getOperand(1),
-                       Op->getOperand(2));
+                       truncateVecElts(Op, DAG));
   case Intrinsic::mips_srai_b:
   case Intrinsic::mips_srai_h:
   case Intrinsic::mips_srai_w:
@@ -2270,7 +2283,7 @@
   case Intrinsic::mips_srl_w:
   case Intrinsic::mips_srl_d:
     return DAG.getNode(ISD::SRL, DL, Op->getValueType(0), Op->getOperand(1),
-                       Op->getOperand(2));
+                       truncateVecElts(Op, DAG));
   case Intrinsic::mips_srli_b:
   case Intrinsic::mips_srli_h:
   case Intrinsic::mips_srli_w: