Simplify more cases of logical ops of masked icmps.
Summary:
For example,
((X & 255) != 0) && ((X & 15) == 8) -> ((X & 15) == 8).
((X & 7) != 0) && ((X & 15) == 8) -> false.
Reviewers: davidxl
Reviewed By: davidxl
Subscribers: llvm-commits
Differential Revision: https://reviews.llvm.org/D43835
llvm-svn: 327450
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 1b84eea..4a5a6d1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -305,17 +305,21 @@
 }
 
 /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E).
-/// Return the set of pattern classes (from MaskedICmpType) that both LHS and
-/// RHS satisfy.
-static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
-                                         Value *&D, Value *&E, ICmpInst *LHS,
-                                         ICmpInst *RHS,
-                                         ICmpInst::Predicate &PredL,
-                                         ICmpInst::Predicate &PredR) {
+/// Return the pattern classes (from MaskedICmpType) for the left hand side and
+/// the right hand side as a pair, or None if the pair doesn't match the pattern.
+/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
+/// and PredR are their predicates, respectively.
+static
+Optional<std::pair<unsigned, unsigned>>
+getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
+                         Value *&D, Value *&E, ICmpInst *LHS,
+                         ICmpInst *RHS,
+                         ICmpInst::Predicate &PredL,
+                         ICmpInst::Predicate &PredR) {
   // vectors are not (yet?) supported. Don't support pointers either.
   if (!LHS->getOperand(0)->getType()->isIntegerTy() ||
       !RHS->getOperand(0)->getType()->isIntegerTy())
-    return 0;
+    return None;
 
   // Here comes the tricky part:
   // LHS might be of the form L11 & L12 == X, X == L21 & L22,
@@ -346,7 +350,7 @@
 
   // Bail if LHS was a icmp that can't be decomposed into an equality.
   if (!ICmpInst::isEquality(PredL))
-    return 0;
+    return None;
 
   Value *R1 = RHS->getOperand(0);
   Value *R2 = RHS->getOperand(1);
@@ -360,7 +364,7 @@
       A = R12;
       D = R11;
     } else {
-      return 0;
+      return None;
     }
     E = R2;
     R1 = nullptr;
@@ -388,7 +392,7 @@
 
   // Bail if RHS was a icmp that can't be decomposed into an equality.
   if (!ICmpInst::isEquality(PredR))
-    return 0;
+    return None;
 
   // Look for ANDs on the right side of the RHS icmp.
   if (!Ok) {
@@ -408,11 +412,11 @@
       E = R1;
       Ok = true;
     } else {
-      return 0;
+      return None;
     }
   }
   if (!Ok)
-    return 0;
+    return None;
 
   if (L11 == A) {
     B = L12;
@@ -430,7 +434,174 @@
 
   unsigned LeftType = getMaskedICmpType(A, B, C, PredL);
   unsigned RightType = getMaskedICmpType(A, D, E, PredR);
-  return LeftType & RightType;
+  return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType));
+}
+
+/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E), where the
+/// left-hand side is of type Mask_NotAllZeros and the right-hand side is of
+/// type BMask_Mixed. Returns a new icmp, RHS, a constant, or nullptr. E.g.,
+/// (icmp (A & 12) != 0) & (icmp (A & 15) == 8) -> (icmp (A & 15) == 8).
+static Value * foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+    Value *A, Value *B, Value *C, Value *D, Value *E,
+    ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    llvm::InstCombiner::BuilderTy &Builder) {
+  // We are given the canonical form:
+  //   (icmp ne (A & B), 0) & (icmp eq (A & D), E),
+  // where D & E == E.
+  //
+  // If IsAnd is false, we get it in negated form:
+  //   (icmp eq (A & B), 0) | (icmp ne (A & D), E) ->
+  //      !((icmp ne (A & B), 0) & (icmp eq (A & D), E)).
+  //
+  // We currently handle only the case where B, C, D, and E are all constants.
+  //
+  ConstantInt *BCst = dyn_cast<ConstantInt>(B);
+  if (!BCst)
+    return nullptr;
+  ConstantInt *CCst = dyn_cast<ConstantInt>(C);
+  if (!CCst)
+    return nullptr;
+  ConstantInt *DCst = dyn_cast<ConstantInt>(D);
+  if (!DCst)
+    return nullptr;
+  ConstantInt *ECst = dyn_cast<ConstantInt>(E);
+  if (!ECst)
+    return nullptr;
+
+  ICmpInst::Predicate NewCC = IsAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+
+  // Update E to the canonical form when D is a power of two and RHS is
+  // canonicalized as,
+  // (icmp ne (A & D), 0) -> (icmp eq (A & D), D) or
+  // (icmp ne (A & D), D) -> (icmp eq (A & D), 0).
+  if (PredR != NewCC)
+    ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst));
+
+  // If B or D is zero, skip: LHS or RHS can then be trivially folded by other
+  // folding rules, and this pattern won't apply any more.
+  if (BCst->getValue() == 0 || DCst->getValue() == 0)
+    return nullptr;
+
+  // If B and D don't intersect, i.e., (B & D) == 0, no folding because we
+  // can't deduce anything from it.
+  // For example,
+  // (icmp ne (A & 12), 0) & (icmp eq (A & 3), 1) -> no folding.
+  if ((BCst->getValue() & DCst->getValue()) == 0)
+    return nullptr;
+
+  // If the following two conditions are met:
+  //
+  // 1. mask B covers only a single bit that's not covered by mask D, that is,
+  // (B & (B ^ D)) is a power of 2 (in other words, B minus the intersection of
+  // B and D has only one bit set) and,
+  //
+  // 2. RHS (and E) indicates that the rest of B's bits are zero in A (A's bits
+  // in the intersection of B and D must be zero), that is, ((B & D) & E) == 0
+  //
+  // then that single bit in B must be one and thus the whole expression can be
+  // folded to
+  //   (A & (B | D)) == (B & (B ^ D)) | E.
+  //
+  // For example,
+  // (icmp ne (A & 12), 0) & (icmp eq (A & 7), 1) -> (icmp eq (A & 15), 9)
+  // (icmp ne (A & 15), 0) & (icmp eq (A & 7), 0) -> (icmp eq (A & 15), 8)
+  if ((((BCst->getValue() & DCst->getValue()) & ECst->getValue()) == 0) &&
+      (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())).isPowerOf2()) {
+    APInt BorD = BCst->getValue() | DCst->getValue();
+    APInt BandBxorDorE = (BCst->getValue() & (BCst->getValue() ^ DCst->getValue())) |
+        ECst->getValue();
+    Value *NewMask = ConstantInt::get(BCst->getType(), BorD);
+    Value *NewMaskedValue = ConstantInt::get(BCst->getType(), BandBxorDorE);
+    Value *NewAnd = Builder.CreateAnd(A, NewMask);
+    return Builder.CreateICmp(NewCC, NewAnd, NewMaskedValue);
+  }
+
+  auto IsSubSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) {
+    return (C1->getValue() & C2->getValue()) == C1->getValue();
+  };
+  auto IsSuperSetOrEqual = [](ConstantInt *C1, ConstantInt *C2) {
+    return (C1->getValue() & C2->getValue()) == C2->getValue();
+  };
+
+  // In the following, we consider only the cases where B is a superset of D, B
+  // is a subset of D, or B == D because otherwise there's at least one bit
+  // covered by B but not D, in which case we can't deduce much from it, so
+  // no folding (aside from the single must-be-one bit case right above.)
+  // For example,
+  // (icmp ne (A & 14), 0) & (icmp eq (A & 3), 1) -> no folding.
+  if (!IsSubSetOrEqual(BCst, DCst) && !IsSuperSetOrEqual(BCst, DCst))
+    return nullptr;
+
+  // At this point, either B is a superset of D, B is a subset of D or B == D.
+
+  // If E is zero and B is a subset of (or equal to) D, LHS and RHS contradict
+  // and the whole expression becomes false (or true if negated); otherwise, no
+  // folding.
+  // For example,
+  // (icmp ne (A & 3), 0) & (icmp eq (A & 7), 0) -> false.
+  // (icmp ne (A & 15), 0) & (icmp eq (A & 3), 0) -> no folding.
+  if (ECst->isZero()) {
+    if (IsSubSetOrEqual(BCst, DCst))
+      return ConstantInt::get(LHS->getType(), !IsAnd);
+    return nullptr;
+  }
+
+  // At this point, B, D, E aren't zero and (B & D) == B, (B & D) == D or B ==
+  // D. If B is a superset of (or equal to) D, since E is not zero, LHS is
+  // subsumed by RHS (RHS implies LHS.) So the whole expression becomes
+  // RHS. For example,
+  // (icmp ne (A & 255), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
+  // (icmp ne (A & 15), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
+  if (IsSuperSetOrEqual(BCst, DCst))
+    return RHS;
+  // Otherwise, B is a subset of D. If B and E have a common bit set,
+  // i.e., (B & E) != 0, then LHS is subsumed by RHS. For example,
+  // (icmp ne (A & 12), 0) & (icmp eq (A & 15), 8) -> (icmp eq (A & 15), 8).
+  assert(IsSubSetOrEqual(BCst, DCst) && "Precondition due to above code");
+  if ((BCst->getValue() & ECst->getValue()) != 0)
+    return RHS;
+  // Otherwise, LHS and RHS contradict and the whole expression becomes false
+  // (or true if negated.) For example,
+  // (icmp ne (A & 7), 0) & (icmp eq (A & 15), 8) -> false.
+  // (icmp ne (A & 6), 0) & (icmp eq (A & 15), 8) -> false.
+  return ConstantInt::get(LHS->getType(), !IsAnd);
+}
+
+/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single
+/// value when the left-hand side and the right-hand side aren't of a common
+/// mask pattern type. Returns nullptr if no fold applies.
+static Value *foldLogOpOfMaskedICmpsAsymmetric(
+    ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
+    Value *A, Value *B, Value *C, Value *D, Value *E,
+    ICmpInst::Predicate PredL, ICmpInst::Predicate PredR,
+    unsigned LHSMask, unsigned RHSMask,
+    llvm::InstCombiner::BuilderTy &Builder) {
+  assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
+         "Expected equality predicates for masked type of icmps.");
+  // Handle Mask_NotAllZeros-BMask_Mixed cases:
+  //   (icmp ne/eq (A & B), C) &/| (icmp eq/ne (A & D), E), or
+  //   (icmp eq/ne (A & B), C) &/| (icmp ne/eq (A & D), E), which gets swapped
+  //   to (icmp ne/eq (A & D), E) &/| (icmp eq/ne (A & B), C).
+  // For the '|' form, the masks are conjugated first so one check covers both.
+  if (!IsAnd) {
+    LHSMask = conjugateICmpMask(LHSMask);
+    RHSMask = conjugateICmpMask(RHSMask);
+  }
+  if ((LHSMask & Mask_NotAllZeros) && (RHSMask & BMask_Mixed)) {
+    if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+            LHS, RHS, IsAnd, A, B, C, D, E,
+            PredL, PredR, Builder)) {
+      return V;
+    }
+  } else if ((LHSMask & BMask_Mixed) && (RHSMask & Mask_NotAllZeros)) {
+    if (Value *V = foldLogOpOfMaskedICmps_NotAllZeros_BMask_Mixed(
+            RHS, LHS, IsAnd, A, D, E, B, C,
+            PredR, PredL, Builder)) {
+      return V;
+    }
+  }
+  return nullptr;
 }
 
 /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
@@ -439,13 +610,24 @@
                                      llvm::InstCombiner::BuilderTy &Builder) {
   Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
   ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
-  unsigned Mask =
+  Optional<std::pair<unsigned, unsigned>> MaskPair =
       getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
-  if (Mask == 0)
+  if (!MaskPair)
     return nullptr;
-
   assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
          "Expected equality predicates for masked type of icmps.");
+  unsigned LHSMask = MaskPair->first;
+  unsigned RHSMask = MaskPair->second;
+  unsigned Mask = LHSMask & RHSMask;
+  if (Mask == 0) {
+    // Even if the two sides don't share a common pattern, check if folding can
+    // still happen.
+    if (Value *V = foldLogOpOfMaskedICmpsAsymmetric(
+            LHS, RHS, IsAnd, A, B, C, D, E, PredL, PredR, LHSMask, RHSMask,
+            Builder))
+      return V;
+    return nullptr;
+  }
 
   // In full generality:
   //     (icmp (A & B) Op C) | (icmp (A & D) Op E)