[InstCombine] use canEvaluateShiftedShift() to handle the lshr case (NFCI)

Only a couple of logic tweaks are needed to consolidate the shl and lshr cases into a single helper.

This is step 5 of the refactoring to solve PR26760:
https://llvm.org/bugs/show_bug.cgi?id=26760

llvm-svn: 265965
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 6d68a55..e65dc94 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -57,8 +57,6 @@
 
 /// Return true if we can simplify two logical (either left or right) shifts
 /// that have constant shift amounts.
-/// FIXME: This can be extended to handle either a shl or lshr instruction, but
-/// it is currently only valid for a shl.
 static bool canEvaluateShiftedShift(unsigned FirstShiftAmt,
                                     bool IsFirstShiftLeft,
                                     Instruction *SecondShift, InstCombiner &IC,
@@ -69,24 +67,29 @@
     return false;
 
   unsigned SecondShiftAmt = SecondShiftConst->getZExtValue();
+  bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl;
 
-  // We can always fold shl(c1) + shl(c2) -> shl(c1+c2).
-  if (IsFirstShiftLeft)
+  // We can always fold  shl(c1) +  shl(c2) ->  shl(c1+c2).
+  // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2).
+  if (IsFirstShiftLeft == IsSecondShiftLeft)
     return true;
 
-  // We can always fold shr(c) + shl(c) -> and(c2).
-  if (SecondShiftAmt == FirstShiftAmt)
+  // We can always fold lshr(c) +  shl(c) -> and(c2).
+  // We can always fold  shl(c) + lshr(c) -> and(c2).
+  if (FirstShiftAmt == SecondShiftAmt)
     return true;
 
   unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits();
 
   // If the 2nd shift is bigger than the 1st, we can fold:
-  //   shr(c1) + shl(c2) -> shl(c3) + and(c4)
+  //   lshr(c1) +  shl(c2) ->  shl(c3) + and(c4) or
+  //   shl(c1)  + lshr(c2) -> lshr(c3) + and(c4),
   // but it isn't profitable unless we know the and'd out bits are already zero.
   // Also check that the 2nd shift is valid (less than the type width) or we'll
   // crash trying to produce the bit mask for the 'and'.
   if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) {
-    unsigned MaskShift = TypeWidth - SecondShiftAmt;
+    unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt
+                                           : SecondShiftAmt - FirstShiftAmt;
     APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift;
     if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI))
       return true;
@@ -155,33 +158,9 @@
            CanEvaluateShifted(I->getOperand(1), NumBits, isLeftShift, IC, I);
 
   case Instruction::Shl:
+  case Instruction::LShr:
     return canEvaluateShiftedShift(NumBits, isLeftShift, I, IC, CxtI);
 
-  case Instruction::LShr: {
-    // We can often fold the shift into shifts-by-a-constant.
-    CI = dyn_cast<ConstantInt>(I->getOperand(1));
-    if (!CI) return false;
-
-    // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
-    if (!isLeftShift) return true;
-
-    // We can always turn lshr(c)+shl(c) -> and(c2).
-    if (CI->getValue() == NumBits) return true;
-
-    unsigned TypeWidth = I->getType()->getScalarSizeInBits();
-
-    // We can always turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but it isn't
-    // profitable unless we know the and'd out bits are already zero.
-    if (CI->getValue().ult(TypeWidth) && CI->getZExtValue() > NumBits) {
-      unsigned LowBits = CI->getZExtValue() - NumBits;
-      if (IC.MaskedValueIsZero(I->getOperand(0),
-                          APInt::getLowBitsSet(TypeWidth, NumBits) << LowBits,
-                          0, CxtI))
-        return true;
-    }
-
-    return false;
-  }
   case Instruction::Select: {
     SelectInst *SI = cast<SelectInst>(I);
     Value *TrueVal = SI->getTrueValue();