InstCombine: simplify comparisons to zero of (shl %x, Cst) or (mul %x, Cst)

This simplification is done in two places:
 - using the nsw attribute, when the shl / mul is used by a sign test
 - when the shl / mul is compared for (in)equality to zero

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@177856 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a4e117e..24af2bf 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -139,6 +139,42 @@
   }
 }
 
+/// Returns true if the exploded icmp (pred X, RHS) can be expressed as a
+/// signed comparison of X to zero, updating \p pred to that (signed) predicate.
+static bool isSignTest(ICmpInst::Predicate &pred, const ConstantInt *RHS) {
+  if (!ICmpInst::isSigned(pred))   // Only signed predicates qualify.
+    return false;
+
+  if (RHS->isZero())               // Already a comparison to zero.
+    return true;
+  // Test for -1 before +1: an i1 constant 1 is both isOne() and -1 (signed).
+  if (RHS->isAllOnesValue())
+    switch (pred) {
+    case ICmpInst::ICMP_SLE:       // X <= -1  <=>  X < 0
+      pred = ICmpInst::ICMP_SLT;
+      return true;
+    case ICmpInst::ICMP_SGT:       // X > -1  <=>  X >= 0
+      pred = ICmpInst::ICMP_SGE;
+      return true;
+    default:
+      return false;
+    }
+
+  if (RHS->isOne())                // Genuine +1 (bit width > 1 here).
+    switch (pred) {
+    case ICmpInst::ICMP_SGE:       // X >= 1  <=>  X > 0
+      pred = ICmpInst::ICMP_SGT;
+      return true;
+    case ICmpInst::ICMP_SLT:       // X < 1  <=>  X <= 0
+      pred = ICmpInst::ICMP_SLE;
+      return true;
+    default:
+      return false;
+    }
+
+  return false;                    // Any other constant is not a sign test.
+}
+
 // isHighOnes - Return true if the constant is of the form 1+0+.
 // This is the same as lowones(~X).
 static bool isHighOnes(const ConstantInt *CI) {
@@ -1282,6 +1318,25 @@
     break;
   }
 
+  case Instruction::Mul: {       // (icmp pred (mul X, Val), CI)
+    // Fold a signed sign test of (mul nsw X, C), C != 0, into a sign test of
+    // X. Without signed wrap the product has the sign of X when C > 0 and the
+    // opposite sign when C < 0, and it is zero exactly when X is zero, so a
+    // negative C only needs the predicate swapped (sgt <-> slt, ...).
+    // Equality predicates are not handled here.
+    ConstantInt *Val = dyn_cast<ConstantInt>(LHSI->getOperand(1));
+    ICmpInst::Predicate Pred = ICI.getPredicate();
+    if (Val && !ICI.isEquality() && !Val->isZero() &&
+        cast<BinaryOperator>(LHSI)->hasNoSignedWrap() &&
+        isSignTest(Pred, RHS)) {
+      if (Val->isNegative())
+        Pred = ICmpInst::getSwappedPredicate(Pred);
+      return new ICmpInst(Pred, LHSI->getOperand(0),
+                          Constant::getNullValue(RHS->getType()));
+    }
+    break;
+  }
+
   case Instruction::Shl: {       // (icmp pred (shl X, ShAmt), CI)
     ConstantInt *ShAmt = dyn_cast<ConstantInt>(LHSI->getOperand(1));
     if (!ShAmt) break;
@@ -1313,6 +1368,12 @@
         return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
                             ConstantExpr::getLShr(RHS, ShAmt));
 
+      // Comparing (shl nsw X, C) with 0: an NSW shift only drops copies of
+      // the sign bit, so the result is 0 exactly when X is 0; no AND needed.
+      if (RHSV == 0 && cast<BinaryOperator>(LHSI)->hasNoSignedWrap())
+        return new ICmpInst(ICI.getPredicate(), LHSI->getOperand(0),
+                            Constant::getNullValue(RHS->getType()));
+
       if (LHSI->hasOneUse()) {
         // Otherwise strength reduce the shift into an and.
         uint32_t ShAmtVal = (uint32_t)ShAmt->getLimitedValue(TypeBits);
@@ -1327,6 +1388,15 @@
       }
     }
 
+    // A signed test of (shl nsw X, C) against zero (or an equivalent test
+    // against 1 / -1, as recognized by isSignTest) gives the same answer for
+    // X: an NSW shift preserves both the sign of X and whether it is zero.
+    ICmpInst::Predicate pred = ICI.getPredicate();
+    if (isSignTest(pred, RHS) &&
+        cast<BinaryOperator>(LHSI)->hasNoSignedWrap())
+      return new ICmpInst(pred, LHSI->getOperand(0),
+                          Constant::getNullValue(RHS->getType()));
+
     // Otherwise, if this is a comparison of the sign bit, simplify to and/test.
     bool TrueIfSigned = false;
     if (LHSI->hasOneUse() &&
@@ -1541,6 +1611,19 @@
             return new ICmpInst(pred, X, NegX);
           }
         }
+        break;
+      case Instruction::Mul:
+        // (mul X, C) ==/!= 0  -->  X ==/!= 0 for a constant C != 0, but only
+        // if the mul cannot wrap: in i8, (mul X, 2) == 0 also holds for
+        // X == -128. The trivial case (mul X, 0) is handled by InstSimplify.
+        if (RHSV == 0 &&
+            (BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap())) {
+          ConstantInt *BOC = dyn_cast<ConstantInt>(BO->getOperand(1));
+          if (BOC && !BOC->isZero())
+            return new ICmpInst(ICI.getPredicate(), BO->getOperand(0),
+                                Constant::getNullValue(RHS->getType()));
+        }
+        break;
       default: break;
       }
     } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(LHSI)) {