[mips][msa] Added support for matching splati from normal IR (i.e. not intrinsics)

Updated some of the vshf tests since they now (correctly) emit splati.

llvm-svn: 191511
diff --git a/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.cpp b/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.cpp
index 2612e3b..7884589 100644
--- a/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.cpp
+++ b/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.cpp
@@ -184,6 +184,15 @@
     printOperand(MI, opNum, O);
 }
 
+void MipsInstPrinter::printUnsignedImm8(const MCInst *MI, int opNum,
+                                        raw_ostream &O) {
+  const MCOperand &Op = MI->getOperand(opNum);
+  if (!Op.isImm()) // Non-immediates go through the generic operand printer.
+    return printOperand(MI, opNum, O);
+  // Keep the low 8 bits; widened, presumably so it prints as a number.
+  O << static_cast<unsigned short>(static_cast<unsigned char>(Op.getImm()));
+}
+
 void MipsInstPrinter::
 printMemOperand(const MCInst *MI, int opNum, raw_ostream &O) {
   // Load/Store memory operands -- imm($reg)
diff --git a/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.h b/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.h
index 11590ad..f75ae24 100644
--- a/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.h
+++ b/llvm/lib/Target/Mips/InstPrinter/MipsInstPrinter.h
@@ -93,6 +93,7 @@
 private:
   void printOperand(const MCInst *MI, unsigned OpNo, raw_ostream &O);
   void printUnsignedImm(const MCInst *MI, int opNum, raw_ostream &O);
+  void printUnsignedImm8(const MCInst *MI, int opNum, raw_ostream &O);
   void printMemOperand(const MCInst *MI, int opNum, raw_ostream &O);
   void printMemOperandEA(const MCInst *MI, int opNum, raw_ostream &O);
   void printFCCOperand(const MCInst *MI, int opNum, raw_ostream &O);
diff --git a/llvm/lib/Target/Mips/MSA.txt b/llvm/lib/Target/Mips/MSA.txt
index 398c7af..a4f320a 100644
--- a/llvm/lib/Target/Mips/MSA.txt
+++ b/llvm/lib/Target/Mips/MSA.txt
@@ -28,3 +28,7 @@
 ilvr.d, ilvod.d, pckod.d:
         It is not possible to emit ilvr.d, or pckod.d since ilvod.d covers the
         same shuffle. ilvod.d will be emitted instead.
+
+splati.w:
+        It is not possible to emit splati.w since shf.w covers the same cases.
+        shf.w will be emitted instead.
diff --git a/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp b/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp
index 7766fea..0d239fd 100644
--- a/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp
+++ b/llvm/lib/Target/Mips/MipsISelDAGToDAG.cpp
@@ -104,6 +104,16 @@
   return false;
 }
 
+bool MipsDAGToDAGISel::selectVSplatUimm1(SDValue N, SDValue &Imm) const {
+  llvm_unreachable("Unimplemented function."); // Stub; not implemented here.
+  return false; // Follows the unreachable marker, as in selectVSplatUimm3.
+}
+
+bool MipsDAGToDAGISel::selectVSplatUimm2(SDValue N, SDValue &Imm) const {
+  llvm_unreachable("Unimplemented function."); // Stub; not implemented here.
+  return false; // Follows the unreachable marker, as in selectVSplatUimm3.
+}
+
 bool MipsDAGToDAGISel::selectVSplatUimm3(SDValue N, SDValue &Imm) const {
   llvm_unreachable("Unimplemented function.");
   return false;
diff --git a/llvm/lib/Target/Mips/MipsISelDAGToDAG.h b/llvm/lib/Target/Mips/MipsISelDAGToDAG.h
index 208701e..e5695c4 100644
--- a/llvm/lib/Target/Mips/MipsISelDAGToDAG.h
+++ b/llvm/lib/Target/Mips/MipsISelDAGToDAG.h
@@ -78,6 +78,10 @@
 
   /// \brief Select constant vector splats.
   virtual bool selectVSplat(SDNode *N, APInt &Imm) const;
+  /// \brief Select constant vector splats whose value fits in a uimm1.
+  virtual bool selectVSplatUimm1(SDValue N, SDValue &Imm) const;
+  /// \brief Select constant vector splats whose value fits in a uimm2.
+  virtual bool selectVSplatUimm2(SDValue N, SDValue &Imm) const;
   /// \brief Select constant vector splats whose value fits in a uimm3.
   virtual bool selectVSplatUimm3(SDValue N, SDValue &Imm) const;
   /// \brief Select constant vector splats whose value fits in a uimm4.
diff --git a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
index 76fcdcc..7ed2f84 100644
--- a/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
+++ b/llvm/lib/Target/Mips/MipsMSAInstrInfo.td
@@ -77,6 +77,14 @@
 
 def simm10 : Operand<i32>;
 
+def vsplat_uimm1 : Operand<vAny> { // Operand type used by vsplati64_uimm1.
+  let PrintMethod = "printUnsignedImm8"; // Prints the low 8 bits, unsigned.
+}
+
+def vsplat_uimm2 : Operand<vAny> { // Operand type used by vsplati32_uimm2.
+  let PrintMethod = "printUnsignedImm8"; // Prints the low 8 bits, unsigned.
+}
+
 def vsplat_uimm3 : Operand<vAny> {
   let PrintMethod = "printUnsignedImm";
 }
@@ -222,6 +230,10 @@
                                          "selectVSplatUimm3",
                                          [build_vector, bitconvert]>;
 
+def vsplati8_uimm4 : SplatComplexPattern<vsplat_uimm4, v16i8, 1, // splati.b
+                                         "selectVSplatUimm4", // C++ selector
+                                         [build_vector, bitconvert]>;
+
 def vsplati8_uimm5 : SplatComplexPattern<vsplat_uimm5, v16i8, 1,
                                          "selectVSplatUimm5",
                                          [build_vector, bitconvert]>;
@@ -234,6 +246,10 @@
                                          "selectVSplatSimm5",
                                          [build_vector, bitconvert]>;
 
+def vsplati16_uimm3 : SplatComplexPattern<vsplat_uimm3, v8i16, 1, // splati.h
+                                          "selectVSplatUimm3", // C++ selector
+                                          [build_vector, bitconvert]>;
+
 def vsplati16_uimm4 : SplatComplexPattern<vsplat_uimm4, v8i16, 1,
                                           "selectVSplatUimm4",
                                           [build_vector, bitconvert]>;
@@ -246,6 +262,10 @@
                                           "selectVSplatSimm5",
                                           [build_vector, bitconvert]>;
 
+def vsplati32_uimm2 : SplatComplexPattern<vsplat_uimm2, v4i32, 1, // splati.w
+                                          "selectVSplatUimm2", // C++ selector
+                                          [build_vector, bitconvert]>;
+
 def vsplati32_uimm5 : SplatComplexPattern<vsplat_uimm5, v4i32, 1,
                                           "selectVSplatUimm5",
                                           [build_vector, bitconvert]>;
@@ -254,6 +274,10 @@
                                           "selectVSplatSimm5",
                                           [build_vector, bitconvert]>;
 
+def vsplati64_uimm1 : SplatComplexPattern<vsplat_uimm1, v2i64, 1, // splati.d
+                                          "selectVSplatUimm1", // C++ selector
+                                          [build_vector, bitconvert]>;
+
 def vsplati64_uimm5 : SplatComplexPattern<vsplat_uimm5, v2i64, 1,
                                           "selectVSplatUimm5",
                                           [build_vector, bitconvert]>;
@@ -1235,6 +1259,18 @@
   InstrItinClass Itinerary = itin;
 }
 
+class MSA_ELM_SPLAT_DESC_BASE<string instr_asm, SplatComplexPattern SplatImm,
+                              RegisterClass RCWD, // Shared by SPLATI_*_DESC.
+                              RegisterClass RCWS = RCWD,
+                              InstrItinClass itin = NoItinerary> {
+  dag OutOperandList = (outs RCWD:$wd);
+  dag InOperandList = (ins RCWS:$ws, SplatImm.OpClass:$u3); // imm named $u3
+  string AsmString = !strconcat(instr_asm, "\t$wd, $ws[$u3]");
+  list<dag> Pattern = [(set RCWD:$wd, (MipsVSHF SplatImm:$u3, RCWS:$ws,
+                                                RCWS:$ws))]; // $ws used twice
+  InstrItinClass Itinerary = itin;
+}
+
 class MSA_VEC_PSEUDO_BASE<SDPatternOperator OpNode, RegisterClass RCWD,
                           RegisterClass RCWS = RCWD,
                           RegisterClass RCWT = RCWD> :
@@ -2172,14 +2208,14 @@
 class SPLAT_D_DESC : MSA_3R_DESC_BASE<"splat.d", int_mips_splat_d, MSA128DOpnd,
                                       MSA128DOpnd, GPR32Opnd>;
 
-class SPLATI_B_DESC : MSA_BIT_B_DESC_BASE<"splati.b", int_mips_splati_b,
-                                          MSA128B>;
-class SPLATI_H_DESC : MSA_BIT_H_DESC_BASE<"splati.h", int_mips_splati_h,
-                                          MSA128H>;
-class SPLATI_W_DESC : MSA_BIT_W_DESC_BASE<"splati.w", int_mips_splati_w,
-                                          MSA128W>;
-class SPLATI_D_DESC : MSA_BIT_D_DESC_BASE<"splati.d", int_mips_splati_d,
-                                          MSA128D>;
+class SPLATI_B_DESC : MSA_ELM_SPLAT_DESC_BASE<"splati.b", vsplati8_uimm4,
+                                              MSA128B>; // v16i8, uimm4 splat
+class SPLATI_H_DESC : MSA_ELM_SPLAT_DESC_BASE<"splati.h", vsplati16_uimm3,
+                                              MSA128H>; // v8i16, uimm3 splat
+class SPLATI_W_DESC : MSA_ELM_SPLAT_DESC_BASE<"splati.w", vsplati32_uimm2,
+                                              MSA128W>; // v4i32, uimm2 splat
+class SPLATI_D_DESC : MSA_ELM_SPLAT_DESC_BASE<"splati.d", vsplati64_uimm1,
+                                              MSA128D>; // v2i64, uimm1 splat
 
 class SRA_B_DESC : MSA_3R_DESC_BASE<"sra.b", sra, MSA128BOpnd>;
 class SRA_H_DESC : MSA_3R_DESC_BASE<"sra.h", sra, MSA128HOpnd>;
diff --git a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
index 8a9481c..6277b6b 100644
--- a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
+++ b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.cpp
@@ -451,6 +451,16 @@
 
 // Select constant vector splats.
 bool MipsSEDAGToDAGISel::
+selectVSplatUimm1(SDValue N, SDValue &Imm) const {
+  return selectVSplatCommon(N, Imm, false, 1); // Signed=false, ImmBitSize=1
+}
+
+bool MipsSEDAGToDAGISel::
+selectVSplatUimm2(SDValue N, SDValue &Imm) const {
+  return selectVSplatCommon(N, Imm, false, 2); // Signed=false, ImmBitSize=2
+}
+
+bool MipsSEDAGToDAGISel::
 selectVSplatUimm3(SDValue N, SDValue &Imm) const {
   return selectVSplatCommon(N, Imm, false, 3);
 }
diff --git a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h
index fe0da12..759d3af 100644
--- a/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h
+++ b/llvm/lib/Target/Mips/MipsSEISelDAGToDAG.h
@@ -63,6 +63,10 @@
   /// \brief Select constant vector splats whose value fits in a given integer.
   virtual bool selectVSplatCommon(SDValue N, SDValue &Imm, bool Signed,
                                   unsigned ImmBitSize) const;
+  /// \brief Select constant vector splats whose value fits in a uimm1.
+  virtual bool selectVSplatUimm1(SDValue N, SDValue &Imm) const;
+  /// \brief Select constant vector splats whose value fits in a uimm2.
+  virtual bool selectVSplatUimm2(SDValue N, SDValue &Imm) const;
   /// \brief Select constant vector splats whose value fits in a uimm3.
   virtual bool selectVSplatUimm3(SDValue N, SDValue &Imm) const;
   /// \brief Select constant vector splats whose value fits in a uimm4.
diff --git a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
index 51682f0..84db5ce 100644
--- a/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsSEISelLowering.cpp
@@ -1470,6 +1470,13 @@
   case Intrinsic::mips_slli_d:
     return DAG.getNode(ISD::SHL, DL, Op->getValueType(0),
                        Op->getOperand(1), lowerMSASplatImm(Op, 2, DAG));
+  case Intrinsic::mips_splati_b:
+  case Intrinsic::mips_splati_h:
+  case Intrinsic::mips_splati_w:
+  case Intrinsic::mips_splati_d:
+    return DAG.getNode(MipsISD::VSHF, DL, Op->getValueType(0),
+                       lowerMSASplatImm(Op, 2, DAG), Op->getOperand(1),
+                       Op->getOperand(1));
   case Intrinsic::mips_sra_b:
   case Intrinsic::mips_sra_h:
   case Intrinsic::mips_sra_w: