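Generalize the InstCombine fold of a cast of an integer equality comparison,
which previously required both operands to be known to be 0 or 1, so that it
works on equality comparisons between any two integers whose bits are all
known, and identical, except for a single unknown bit (the same n-th bit in
each).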


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@90646 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp
index d12ad81..2b4b66b 100644
--- a/lib/Transforms/Scalar/InstructionCombining.cpp
+++ b/lib/Transforms/Scalar/InstructionCombining.cpp
@@ -8585,25 +8585,36 @@
   if (ICI->isEquality() && CI.getType() == ICI->getOperand(0)->getType()) {
     if (const IntegerType *ITy = dyn_cast<IntegerType>(CI.getType())) {
       uint32_t BitWidth = ITy->getBitWidth();
-      if (BitWidth > 1) {
-        Value *LHS = ICI->getOperand(0);
-        Value *RHS = ICI->getOperand(1);
+      Value *LHS = ICI->getOperand(0);
+      Value *RHS = ICI->getOperand(1);
 
-        APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0);
-        APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0);
-        APInt TypeMask(APInt::getHighBitsSet(BitWidth, BitWidth-1));
-        ComputeMaskedBits(LHS, TypeMask, KnownZeroLHS, KnownOneLHS);
-        ComputeMaskedBits(RHS, TypeMask, KnownZeroRHS, KnownOneRHS);
+      APInt KnownZeroLHS(BitWidth, 0), KnownOneLHS(BitWidth, 0);
+      APInt KnownZeroRHS(BitWidth, 0), KnownOneRHS(BitWidth, 0);
+      APInt TypeMask(APInt::getAllOnesValue(BitWidth));
+      ComputeMaskedBits(LHS, TypeMask, KnownZeroLHS, KnownOneLHS);
+      ComputeMaskedBits(RHS, TypeMask, KnownZeroRHS, KnownOneRHS);
 
-        if (KnownZeroLHS.countLeadingOnes() == BitWidth-1 &&
-            KnownZeroRHS.countLeadingOnes() == BitWidth-1) {
+      if (KnownZeroLHS == KnownZeroRHS && KnownOneLHS == KnownOneRHS) {
+        APInt KnownBits = KnownZeroLHS | KnownOneLHS;
+        APInt UnknownBit = ~KnownBits;
+        if (UnknownBit.countPopulation() == 1) {
           if (!DoXform) return ICI;
 
-          Value *Xor = Builder->CreateXor(LHS, RHS);
+          Value *Result = Builder->CreateXor(LHS, RHS);
+
+          // Mask off any bits that are set and won't be shifted away.
+          if (KnownOneLHS.uge(UnknownBit))
+            Result = Builder->CreateAnd(Result,
+                                        ConstantInt::get(ITy, UnknownBit));
+
+          // Shift the bit we're testing down to the lsb.
+          Result = Builder->CreateLShr(
+               Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros()));
+
           if (ICI->getPredicate() == ICmpInst::ICMP_EQ)
-            Xor = Builder->CreateXor(Xor, ConstantInt::get(ITy, 1));
-          Xor->takeName(ICI);
-          return ReplaceInstUsesWith(CI, Xor);
+            Result = Builder->CreateXor(Result, ConstantInt::get(ITy, 1));
+          Result->takeName(ICI);
+          return ReplaceInstUsesWith(CI, Result);
         }
       }
     }