Rather than having a ton of patterns for double-shift instructions (e.g. SHLD16rrCL), just perform a custom DAG combine that forms x86-specific DAG nodes so they all match the same pattern. This also makes sure that later DAG combines (e.g. promoting i16 to i32) do not cause isel to miss them.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@102485 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Target/X86/X86ISelLowering.cpp b/lib/Target/X86/X86ISelLowering.cpp
index 58f1d88..c38b678 100644
--- a/lib/Target/X86/X86ISelLowering.cpp
+++ b/lib/Target/X86/X86ISelLowering.cpp
@@ -9595,9 +9595,13 @@
 }
 
 static SDValue PerformOrCombine(SDNode *N, SelectionDAG &DAG,
+                                TargetLowering::DAGCombinerInfo &DCI,
                                 const X86Subtarget *Subtarget) {
+  if (!DCI.isBeforeLegalize())
+    return SDValue();
+
   EVT VT = N->getValueType(0);
-  if (VT != MVT::i64 || !Subtarget->is64Bit())
+  if (VT != MVT::i16 && VT != MVT::i32 && VT != MVT::i64)
     return SDValue();
 
   // fold (or (x << c) | (y >> (64 - c))) ==> (shld64 x, y, c)
@@ -9607,6 +9611,8 @@
     std::swap(N0, N1);
   if (N0.getOpcode() != ISD::SHL || N1.getOpcode() != ISD::SRL)
     return SDValue();
+  if (!N0.hasOneUse() || !N1.hasOneUse())
+    return SDValue();
 
   SDValue ShAmt0 = N0.getOperand(1);
   if (ShAmt0.getValueType() != MVT::i8)
@@ -9629,10 +9635,11 @@
     std::swap(ShAmt0, ShAmt1);
   }
 
+  unsigned Bits = VT.getSizeInBits();
   if (ShAmt1.getOpcode() == ISD::SUB) {
     SDValue Sum = ShAmt1.getOperand(0);
     if (ConstantSDNode *SumC = dyn_cast<ConstantSDNode>(Sum)) {
-      if (SumC->getSExtValue() == 64 &&
+      if (SumC->getSExtValue() == Bits &&
           ShAmt1.getOperand(1) == ShAmt0)
         return DAG.getNode(Opc, DL, VT,
                            Op0, Op1,
@@ -9642,7 +9649,7 @@
   } else if (ConstantSDNode *ShAmt1C = dyn_cast<ConstantSDNode>(ShAmt1)) {
     ConstantSDNode *ShAmt0C = dyn_cast<ConstantSDNode>(ShAmt0);
     if (ShAmt0C &&
-        ShAmt0C->getSExtValue() + ShAmt1C->getSExtValue() == 64)
+        ShAmt0C->getSExtValue() + ShAmt1C->getSExtValue() == Bits)
       return DAG.getNode(Opc, DL, VT,
                          N0.getOperand(0), N1.getOperand(0),
                          DAG.getNode(ISD::TRUNCATE, DL,
@@ -9921,7 +9928,7 @@
   case ISD::SHL:
   case ISD::SRA:
   case ISD::SRL:            return PerformShiftCombine(N, DAG, Subtarget);
-  case ISD::OR:             return PerformOrCombine(N, DAG, Subtarget);
+  case ISD::OR:             return PerformOrCombine(N, DAG, DCI, Subtarget);
   case ISD::STORE:          return PerformSTORECombine(N, DAG, Subtarget);
   case X86ISD::FXOR:
   case X86ISD::FOR:         return PerformFORCombine(N, DAG);