[MIPS] Add support to match more patterns for DINS instruction

This patch adds support for recognizing additional patterns that match the
DINS instruction.

Differential Revision: https://reviews.llvm.org/D31465

llvm-svn: 303537
diff --git a/llvm/lib/Target/Mips/MipsISelLowering.cpp b/llvm/lib/Target/Mips/MipsISelLowering.cpp
index 78bae69..3641a70 100644
--- a/llvm/lib/Target/Mips/MipsISelLowering.cpp
+++ b/llvm/lib/Target/Mips/MipsISelLowering.cpp
@@ -795,7 +795,7 @@
 
   SDValue And0 = N->getOperand(0), And1 = N->getOperand(1);
   uint64_t SMPos0, SMSize0, SMPos1, SMSize1;
-  ConstantSDNode *CN;
+  ConstantSDNode *CN, *CN1;
 
   // See if Op's first operand matches (and $src1 , mask0).
   if (And0.getOpcode() != ISD::AND)
@@ -806,37 +806,74 @@
     return SDValue();
 
   // See if Op's second operand matches (and (shl $src, pos), mask1).
-  if (And1.getOpcode() != ISD::AND)
+  if (And1.getOpcode() == ISD::AND &&
+      And1.getOperand(0).getOpcode() == ISD::SHL) {
+
+    if (!(CN = dyn_cast<ConstantSDNode>(And1.getOperand(1))) ||
+        !isShiftedMask(CN->getZExtValue(), SMPos1, SMSize1))
+      return SDValue();
+
+      // The shift masks must have the same position and size.
+      if (SMPos0 != SMPos1 || SMSize0 != SMSize1)
+        return SDValue();
+
+      SDValue Shl = And1.getOperand(0);
+
+      if (!(CN = dyn_cast<ConstantSDNode>(Shl.getOperand(1))))
+        return SDValue();
+
+      unsigned Shamt = CN->getZExtValue();
+
+      // Return if the shift amount and the first bit position of mask are not the
+      // same.
+      EVT ValTy = N->getValueType(0);
+      if ((Shamt != SMPos0) || (SMPos0 + SMSize0 > ValTy.getSizeInBits()))
+        return SDValue();
+
+      SDLoc DL(N);
+      return DAG.getNode(MipsISD::Ins, DL, ValTy, Shl.getOperand(0),
+                         DAG.getConstant(SMPos0, DL, MVT::i32),
+                         DAG.getConstant(SMSize0, DL, MVT::i32),
+                         And0.getOperand(0));
+  } else {
+    // Pattern match DINS.
+    //  $dst = or (and $src, mask0), mask1
+    //  where mask0 = ((1 << SMSize0) -1) << SMPos0
+    //  => dins $dst, $src, pos, size
+    if (~CN->getSExtValue() == ((((int64_t)1 << SMSize0) - 1) << SMPos0) &&
+        ((SMSize0 + SMPos0 <= 64 && Subtarget.hasMips64r2()) ||
+         (SMSize0 + SMPos0 <= 32))) {
+      // Check if AND instruction has constant as argument
+      bool isConstCase = And1.getOpcode() != ISD::AND;
+      if (And1.getOpcode() == ISD::AND) {
+        if (!(CN1 = dyn_cast<ConstantSDNode>(And1->getOperand(1))))
+          return SDValue();
+      } else {
+        if (!(CN1 = dyn_cast<ConstantSDNode>(N->getOperand(1))))
+          return SDValue();
+      }
+      SDLoc DL(N);
+      EVT ValTy = N->getOperand(0)->getValueType(0);
+      SDValue Const1;
+      SDValue SrlX;
+      if (!isConstCase) {
+        Const1 = DAG.getConstant(SMPos0, DL, MVT::i32);
+        SrlX = DAG.getNode(ISD::SRL, DL, And1->getValueType(0), And1, Const1);
+      }
+      return DAG.getNode(
+          MipsISD::Ins, DL, N->getValueType(0),
+          isConstCase
+              ? DAG.getConstant(CN1->getSExtValue() >> SMPos0, DL, ValTy)
+              : SrlX,
+          DAG.getConstant(SMPos0, DL, MVT::i32),
+          DAG.getConstant(ValTy.getSizeInBits() / 8 < 8 ? SMSize0 & 31
+                                                        : SMSize0,
+                          DL, MVT::i32),
+          And0->getOperand(0));
+
+    }
     return SDValue();
-
-  if (!(CN = dyn_cast<ConstantSDNode>(And1.getOperand(1))) ||
-      !isShiftedMask(CN->getZExtValue(), SMPos1, SMSize1))
-    return SDValue();
-
-  // The shift masks must have the same position and size.
-  if (SMPos0 != SMPos1 || SMSize0 != SMSize1)
-    return SDValue();
-
-  SDValue Shl = And1.getOperand(0);
-  if (Shl.getOpcode() != ISD::SHL)
-    return SDValue();
-
-  if (!(CN = dyn_cast<ConstantSDNode>(Shl.getOperand(1))))
-    return SDValue();
-
-  unsigned Shamt = CN->getZExtValue();
-
-  // Return if the shift amount and the first bit position of mask are not the
-  // same.
-  EVT ValTy = N->getValueType(0);
-  if ((Shamt != SMPos0) || (SMPos0 + SMSize0 > ValTy.getSizeInBits()))
-    return SDValue();
-
-  SDLoc DL(N);
-  return DAG.getNode(MipsISD::Ins, DL, ValTy, Shl.getOperand(0),
-                     DAG.getConstant(SMPos0, DL, MVT::i32),
-                     DAG.getConstant(SMSize0, DL, MVT::i32),
-                     And0.getOperand(0));
+  }
 }
 
 static SDValue performADDCombine(SDNode *N, SelectionDAG &DAG,