Pattern match an i1 setne of a boolean-valued operand against 0 as a truncate.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@154322 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 1d135ef..dbc1ca9 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4408,6 +4408,44 @@
   return SDValue();
 }
 
+// isTruncateOf - If N is a truncate of some other value, or an i1 (setne X, 0)
+// where every bit of X above bit 0 is known zero (so the setcc equals bit 0
+// of X), return true, record that value in Op and which of Op's bits are zero
+// in KnownZero, saving the caller a duplicate ComputeMaskedBits call.
+static bool isTruncateOf(SelectionDAG &DAG, SDValue N, SDValue &Op,
+                         APInt &KnownZero) {
+  APInt KnownOne;
+  if (N->getOpcode() == ISD::TRUNCATE) {
+    Op = N->getOperand(0);
+    DAG.ComputeMaskedBits(Op, KnownZero, KnownOne);
+    return true;
+  }
+
+  if (N->getOpcode() != ISD::SETCC || N->getValueType(0) != MVT::i1 ||
+      cast<CondCodeSDNode>(N->getOperand(2))->get() != ISD::SETNE)
+    return false;
+
+  SDValue Op0 = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  assert(Op0.getValueType() == Op1.getValueType());
+
+  // isNullValue() handles any width; getZExtValue() asserts past 64 bits.
+  ConstantSDNode *COp0 = dyn_cast<ConstantSDNode>(Op0);
+  ConstantSDNode *COp1 = dyn_cast<ConstantSDNode>(Op1);
+  if (COp0 && COp0->isNullValue())
+    Op = Op1;
+  else if (COp1 && COp1->isNullValue())
+    Op = Op0;
+  else
+    return false;
+
+  DAG.ComputeMaskedBits(Op, KnownZero, KnownOne);
+  // Only a value whose bits above bit 0 are all known zero acts as a truncate.
+  if (!(KnownZero | APInt(Op.getValueSizeInBits(), 1)).isAllOnesValue())
+    return false;
+  return true;
+}
+
 SDValue DAGCombiner::visitZERO_EXTEND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   EVT VT = N->getValueType(0);
@@ -4425,15 +4463,16 @@
   //      (zext (truncate x)) -> (truncate x)
   // This is valid when the truncated bits of x are already zero.
   // FIXME: We should extend this to work for vectors too.
-  if (N0.getOpcode() == ISD::TRUNCATE && !VT.isVector()) {
-    SDValue Op = N0.getOperand(0);
-    APInt TruncatedBits
-      = APInt::getBitsSet(Op.getValueSizeInBits(),
-                          N0.getValueSizeInBits(),
-                          std::min(Op.getValueSizeInBits(),
-                                   VT.getSizeInBits()));
-    APInt KnownZero, KnownOne;
-    DAG.ComputeMaskedBits(Op, KnownZero, KnownOne);
+  SDValue Op;
+  APInt KnownZero;
+  if (!VT.isVector() && isTruncateOf(DAG, N0, Op, KnownZero)) {
+    APInt TruncatedBits =
+      (Op.getValueSizeInBits() == N0.getValueSizeInBits()) ?
+      APInt(Op.getValueSizeInBits(), 0) :
+      APInt::getBitsSet(Op.getValueSizeInBits(),
+                        N0.getValueSizeInBits(),
+                        std::min(Op.getValueSizeInBits(),
+                                 VT.getSizeInBits()));
     if (TruncatedBits == (KnownZero & TruncatedBits)) {
       if (VT.bitsGT(Op.getValueType()))
         return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT, Op);