[InstCombine] New opportunities for FoldAndOfICmp and FoldXorOfICmp

A number of new patterns for simplifying and/xor of icmp:

(icmp ne %x, 0) ^ (icmp ne %y, 0) => icmp ne %x, %y if the following is true:
1- (%x = and %a, %mask) and (%y = and %b, %mask)
2- %mask is a power of 2.

(icmp eq %x, 0) & (icmp ne %y, 0) => icmp ult %x, %y if the following is true:
1- (%x = and %a, %mask1) and (%y = and %b, %mask2)
1- Let %t be the smallest power of 2 where %mask1 & %t != 0. Then for any
   %s that is a power of 2 and %s & %mask2 != 0, we must have %s < %t
   (strictly smaller, so that any nonzero %x is always >= every possible %y).
For example if %mask1 = 24 and %mask2 = 16, setting %s = 16 and %t = 8
violates condition (2) above. So this optimization cannot be applied.

llvm-svn: 289813
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index d4bd78b..e1e060b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -733,6 +733,44 @@
   return nullptr;
 }
 
+namespace {
+/// Describes an icmp-against-zero of a group of bits selected by a mask.
+struct BitGroupCheck {
+  // True if the icmp checks that at least one bit in the group is set
+  // (ICMP_NE); false if it checks that all bits are clear (ICMP_EQ).
+  bool CheckIfSet {false};
+  // The mask identifying the bit group, or null if no pattern matched.
+  const APInt *Mask {nullptr};
+};
+} // end anonymous namespace
+
+/// For an ICMP where RHS is zero, check whether the ICMP is equivalent to
+/// comparing a group of bits in an integer value against zero.
+static BitGroupCheck isAnyBitSet(Value *LHS, ICmpInst::Predicate CC) {
+  BitGroupCheck BGC;
+
+  // Only handle a genuine "and" instruction, not a constant expression.
+  auto *Inst = dyn_cast<Instruction>(LHS);
+  if (!Inst || Inst->getOpcode() != Instruction::And)
+    return BGC;
+
+  // TODO: Currently this does not work for vectors (ConstantInt is scalar).
+  ConstantInt *Mask;
+  if (!match(LHS, m_And(m_Value(), m_ConstantInt(Mask))))
+    return BGC;
+
+  // RHS is known zero: we check whether a group of bits are all zero (EQ)
+  // or at least one of them is set (NE).
+  if (CC == ICmpInst::ICMP_EQ)
+    BGC.CheckIfSet = false;
+  else if (CC == ICmpInst::ICMP_NE)
+    BGC.CheckIfSet = true;
+  else
+    return BGC;
+  BGC.Mask = &Mask->getValue();
+  return BGC;
+}
+
 /// Try to fold a signed range checked with lower bound 0 to an unsigned icmp.
 /// Example: (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
 /// If \p Inverted is true then the check is for the inverted range, e.g.
@@ -789,6 +827,32 @@
   return Builder->CreateICmp(NewPred, Input, RangeEnd);
 }
 
+/// Fold (icmp)^(icmp) if possible.
+Value *InstCombiner::FoldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+  Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
+  // TODO: The code below does not work for vectors (ConstantInt is scalar).
+  auto *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
+  auto *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
+  if (!LHSCst || !RHSCst)
+    return nullptr;
+  ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+
+  // E.g. (icmp ne %x, 0) ^ (icmp ne %y, 0) => icmp ne %x, %y if the following
+  // conditions hold:
+  // 1- (%x = and %a, %mask) and (%y = and %b, %mask)
+  // 2- %mask is a power of 2.
+  // Then %x and %y can each only be 0 or %mask, so the xor is equivalent to
+  // %x != %y.  (Uniqued constants: LHSCst == RHSCst implies the same type.)
+  if (RHSCst->isZero() && LHSCst == RHSCst) {
+    BitGroupCheck BGC1 = isAnyBitSet(Val, LHSCC);
+    BitGroupCheck BGC2 = isAnyBitSet(Val2, RHSCC);
+    if (BGC1.Mask && BGC2.Mask && BGC1.CheckIfSet == BGC2.CheckIfSet &&
+        *BGC1.Mask == *BGC2.Mask && BGC1.Mask->isPowerOf2())
+      return Builder->CreateICmp(ICmpInst::ICMP_NE, Val2, Val);
+  }
+  return nullptr;
+}
+
 /// Fold (icmp)&(icmp) if possible.
 Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
   ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
@@ -871,6 +935,29 @@
     }
   }
 
+  // E.g. (icmp eq %x, 0) & (icmp ne %y, 0) => icmp ult %x, %y if:
+  // 1- (%x = and %a, %mask1) and (%y = and %b, %mask2)
+  // 2- The highest set bit of %mask2 is strictly below the lowest set bit
+  //    of %mask1, so a nonzero %x always compares uge every possible %y.
+  //    (A non-strict bound would miscompile: %mask1 = 4, %mask2 = 6 allows
+  //    %x = 4 ult %y = 6 even though %x != 0.)
+  if (RHSCst->isZero() && LHSCst == RHSCst) {
+    BitGroupCheck BGC1 = isAnyBitSet(Val, LHSCC);
+    BitGroupCheck BGC2 = isAnyBitSet(Val2, RHSCC);
+
+    if (BGC1.Mask && BGC2.Mask && (BGC1.CheckIfSet != BGC2.CheckIfSet)) {
+      // CTZ(EqMask) >= BitWidth - CLZ(NeMask) <=> strict bit dominance.
+      if (!BGC1.CheckIfSet &&
+          BGC1.Mask->countTrailingZeros() >=
+          BGC2.Mask->getBitWidth() - BGC2.Mask->countLeadingZeros())
+        return Builder->CreateICmp(ICmpInst::ICMP_ULT, Val, Val2);
+      else if (!BGC2.CheckIfSet &&
+          BGC2.Mask->countTrailingZeros() >=
+          BGC1.Mask->getBitWidth() - BGC1.Mask->countLeadingZeros())
+        return Builder->CreateICmp(ICmpInst::ICMP_ULT, Val2, Val);
+    }
+  }
+
   // From here on, we only handle:
   //    (icmp1 A, C1) & (icmp2 A, C2) --> something simpler.
   if (Val != Val2) return nullptr;
@@ -2704,9 +2791,16 @@
       match(Op1, m_Not(m_Specific(A))))
     return BinaryOperator::CreateNot(Builder->CreateAnd(A, B));
 
-  // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
   if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
     if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) {
 
+      // E.g. for xor (icmp eq %A, 0), (icmp eq %B, 0): if we know %A and %B
+      // are each either 0 or the same power of 2 (say 8), then the xor
+      // simplifies to (icmp ne %A, %B).
+      if (Value *Res = FoldXorOfICmps(LHS, RHS))
+        return replaceInstUsesWith(I, Res);
+
+      // (icmp1 A, B) ^ (icmp2 A, B) --> (icmp3 A, B)
       if (PredicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) {
         if (LHS->getOperand(0) == RHS->getOperand(1) &&
             LHS->getOperand(1) == RHS->getOperand(0))
@@ -2721,6 +2815,7 @@
                                                Builder));
         }
       }
+    }
 
   if (Instruction *CastedXor = foldCastedBitwiseLogic(I))
     return CastedXor;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 8b71352..24ba412 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -239,6 +239,7 @@
   Instruction *visitFDiv(BinaryOperator &I);
   Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted);
   Value *FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS);
+  Value *FoldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS);
   Value *FoldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS);
   Instruction *visitAnd(BinaryOperator &I);
   Value *FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI);