[InstCombine] Shift amount reassociation in bittest: trunc-of-shl (PR42399)

Summary:
This is a continuation of D63829 / https://bugs.llvm.org/show_bug.cgi?id=42399

I thought the naive pattern would solve my issue, but it did not: the real code
involves a truncation, so more folds are needed. This isn't really the fold
I'm interested in (I need trunc-of-lshr), but I've decided to start with `shl`
because it's simpler.

In this case, no extra legality checks are needed:
https://rise4fun.com/Alive/CAb

We should be careful not to increase the instruction count,
since we need to produce a `zext` because the `and` is done in the wider type.

Reviewers: spatel, nikic, xbolva00

Reviewed By: spatel

Subscribers: hiraditya, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D66057

llvm-svn: 369117
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 88e56e0..babbd9d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -3299,6 +3299,7 @@
 // we should move shifts to the same hand of 'and', i.e. rewrite as
 //   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
 // We are only interested in opposite logical shifts here.
+// One of the shifts can be truncated. For now, it can only be 'shl'.
 // If we can, we want to end up creating 'lshr' shift.
 static Value *
 foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
@@ -3308,18 +3309,37 @@
     return nullptr;
 
   auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
-  auto m_AnyLShr = m_LShr(m_Value(), m_Value());
 
-  // Look for an 'and' of two (opposite) logical shifts.
-  // Pick the single-use shift as XShift.
-  Instruction *XShift, *YShift;
-  if (!match(I.getOperand(0),
-             m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
-                     m_CombineAnd(m_AnyLogicalShift, m_Instruction(YShift)))))
+  // Look for an 'and' of two logical shifts, one of which may be truncated.
+  // We use m_TruncOrSelf() on the RHS to correctly handle commutative case.
+  Instruction *XShift, *MaybeTruncation, *YShift;
+  if (!match(
+          I.getOperand(0),
+          m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)),
+                  m_CombineAnd(m_TruncOrSelf(m_CombineAnd(
+                                   m_AnyLogicalShift, m_Instruction(YShift))),
+                               m_Instruction(MaybeTruncation)))))
     return nullptr;
 
+  Instruction *UntruncatedShift = XShift;
+
+  // We potentially looked past 'trunc', but only when matching YShift,
+  // therefore YShift must have the widest type.
+  Type *WidestTy = YShift->getType();
+  assert(XShift->getType() == I.getOperand(0)->getType() &&
+         "We did not look past any shifts while matching XShift though.");
+  bool HadTrunc = WidestTy != I.getOperand(0)->getType();
+
+  if (HadTrunc) {
+    // We did indeed have a truncation. For now, let's only proceed if the 'shl'
+    // was truncated, since that does not require any extra legality checks.
+    // FIXME: trunc-of-lshr.
+    if (!match(YShift, m_Shl(m_Value(), m_Value())))
+      return nullptr;
+  }
+
   // If YShift is a 'lshr', swap the shifts around.
-  if (match(YShift, m_AnyLShr))
+  if (match(YShift, m_LShr(m_Value(), m_Value())))
     std::swap(XShift, YShift);
 
   // The shifts must be in opposite directions.
@@ -3328,37 +3348,54 @@
     return nullptr; // Do not care about same-direction shifts here.
 
   Value *X, *XShAmt, *Y, *YShAmt;
-  match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt)));
-  match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt)));
+  match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt))));
+  match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt))));
 
   // If one of the values being shifted is a constant, then we will end with
-  // and+icmp, and shift instr will be constant-folded. If they are not,
+  // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not,
   // however, we will need to ensure that we won't increase instruction count.
   if (!isa<Constant>(X) && !isa<Constant>(Y)) {
     // At least one of the hands of the 'and' should be one-use shift.
     if (!match(I.getOperand(0),
                m_c_And(m_OneUse(m_AnyLogicalShift), m_Value())))
       return nullptr;
+    if (HadTrunc) {
+      // Due to the 'trunc', we will need to widen X. For that either the old
+      // 'trunc' or the shift amt in the non-truncated shift should be one-use.
+      if (!MaybeTruncation->hasOneUse() &&
+          !UntruncatedShift->getOperand(1)->hasOneUse())
+        return nullptr;
+    }
   }
 
+  // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case let's bailout now.
+  if (XShAmt->getType() != YShAmt->getType())
+    return nullptr;
+
   // Can we fold (XShAmt+YShAmt) ?
-  Value *NewShAmt = SimplifyAddInst(XShAmt, YShAmt, /*IsNSW=*/false,
-                                    /*IsNUW=*/false, SQ.getWithInstruction(&I));
+  auto *NewShAmt = dyn_cast_or_null<Constant>(
+      SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
+                      /*isNUW=*/false, SQ.getWithInstruction(&I)));
   if (!NewShAmt)
     return nullptr;
   // Is the new shift amount smaller than the bit width?
   // FIXME: could also rely on ConstantRange.
-  unsigned BitWidth = X->getType()->getScalarSizeInBits();
-  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
-                                          APInt(BitWidth, BitWidth))))
+  if (!match(NewShAmt, m_SpecificInt_ICMP(
+                           ICmpInst::Predicate::ICMP_ULT,
+                           APInt(NewShAmt->getType()->getScalarSizeInBits(),
+                                 WidestTy->getScalarSizeInBits()))))
     return nullptr;
-  // All good, we can do this fold. The shift is the same that was for X.
+  // All good, we can do this fold.
+  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy);
+  X = Builder.CreateZExt(X, WidestTy);
+  // The shift is the same that was for X.
   Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
                   ? Builder.CreateLShr(X, NewShAmt)
                   : Builder.CreateShl(X, NewShAmt);
   Value *T1 = Builder.CreateAnd(T0, Y);
   return Builder.CreateICmp(I.getPredicate(), T1,
-                            Constant::getNullValue(X->getType()));
+                            Constant::getNullValue(WidestTy));
 }
 
 /// Try to fold icmp (binop), X or icmp X, (binop).