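Transform (x << (y&31)) -> (x << y). This takes advantage of the fact that x86 shift instructions mask their shift-count operand to 0-31 (0-63 for 64-bit operands on x86-64), so an explicit AND with that mask is redundant. Supporting this, DAGCombiner now folds (shift x, (trunc (and y, c))) -> (shift x, (and (trunc y), c)) for SHL, SRA and SRL when c fits in the truncated type, so that the AND mask sits directly on the shift amount.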


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@55558 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index d0a3610..bb9dde0 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -2310,6 +2310,26 @@
   if (DAG.MaskedValueIsZero(SDValue(N, 0),
                             APInt::getAllOnesValue(VT.getSizeInBits())))
     return DAG.getConstant(0, VT);
+  // fold (shl x, (trunc (and y, c))) -> (shl x, (and (trunc y), c))
+  // iff (trunc c) == c
+  if (N1.getOpcode() == ISD::TRUNCATE &&
+      N1.getOperand(0).getOpcode() == ISD::AND) {
+    SDValue N101 = N1.getOperand(0).getOperand(1);
+    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N101);
+    if (N101C) {
+      MVT TruncVT = N1.getValueType();
+      unsigned TruncBitSize = TruncVT.getSizeInBits();
+      APInt ShAmt = N101C->getAPIntValue();
+      if (ShAmt.trunc(TruncBitSize).getZExtValue() == N101C->getValue()) {
+        SDValue N100 = N1.getOperand(0).getOperand(0);
+        return DAG.getNode(ISD::SHL, VT, N0,
+                           DAG.getNode(ISD::AND, TruncVT,
+                                  DAG.getNode(ISD::TRUNCATE, TruncVT, N100),
+                                  DAG.getConstant(N101C->getValue(), TruncVT)));
+      }
+    }
+  }
+
   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
   // fold (shl (shl x, c1), c2) -> 0 or (shl x, c1+c2)
@@ -2421,6 +2441,26 @@
     }
   }
   
+  // fold (sra x, (trunc (and y, c))) -> (sra x, (and (trunc y), c))
+  // iff (trunc c) == c
+  if (N1.getOpcode() == ISD::TRUNCATE &&
+      N1.getOperand(0).getOpcode() == ISD::AND) {
+    SDValue N101 = N1.getOperand(0).getOperand(1);
+    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N101);
+    if (N101C) {
+      MVT TruncVT = N1.getValueType();
+      unsigned TruncBitSize = TruncVT.getSizeInBits();
+      APInt ShAmt = N101C->getAPIntValue();
+      if (ShAmt.trunc(TruncBitSize).getZExtValue() == N101C->getValue()) {
+        SDValue N100 = N1.getOperand(0).getOperand(0);
+        return DAG.getNode(ISD::SRA, VT, N0,
+                           DAG.getNode(ISD::AND, TruncVT,
+                                  DAG.getNode(ISD::TRUNCATE, TruncVT, N100),
+                                  DAG.getConstant(N101C->getValue(), TruncVT)));
+      }
+    }
+  }
+
   // Simplify, based on bits shifted out of the LHS. 
   if (N1C && SimplifyDemandedBits(SDValue(N, 0)))
     return SDValue(N, 0);
@@ -2520,6 +2560,26 @@
       return DAG.getNode(ISD::XOR, VT, Op, DAG.getConstant(1, VT));
     }
   }
+
+  // fold (srl x, (trunc (and y, c))) -> (srl x, (and (trunc y), c))
+  // iff (trunc c) == c
+  if (N1.getOpcode() == ISD::TRUNCATE &&
+      N1.getOperand(0).getOpcode() == ISD::AND) {
+    SDValue N101 = N1.getOperand(0).getOperand(1);
+    ConstantSDNode *N101C = dyn_cast<ConstantSDNode>(N101);
+    if (N101C) {
+      MVT TruncVT = N1.getValueType();
+      unsigned TruncBitSize = TruncVT.getSizeInBits();
+      APInt ShAmt = N101C->getAPIntValue();
+      if (ShAmt.trunc(TruncBitSize).getZExtValue() == N101C->getValue()) {
+        SDValue N100 = N1.getOperand(0).getOperand(0);
+        return DAG.getNode(ISD::SRL, VT, N0,
+                           DAG.getNode(ISD::AND, TruncVT,
+                                  DAG.getNode(ISD::TRUNCATE, TruncVT, N100),
+                                  DAG.getConstant(N101C->getValue(), TruncVT)));
+      }
+    }
+  }
   
   // fold operands of srl based on knowledge that the low bits are not
   // demanded.