Model ashr(shl(x, n), m) as mul(x, 2^(n-m)) when n > m

Given the case below:

  %y = shl %x, n
  %z = ashr %y, m

When n = m, SCEV models it as sext(trunc(x)). This patch handles the
case where n > m by using sext(mul(trunc(x), 2^(n-m))) as the SCEV
expression.

llvm-svn: 298631
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index c820464..5863406 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -5356,28 +5356,55 @@
     break;
 
     case Instruction::AShr:
-      // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression.
-      if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS))
-        if (Operator *L = dyn_cast<Operator>(BO->LHS))
-          if (L->getOpcode() == Instruction::Shl &&
-              L->getOperand(1) == BO->RHS) {
-            uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType());
+      // AShr X, C, where C is a constant.
+      ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
+      if (!CI)
+        break;
 
-            // If the shift count is not less than the bitwidth, the result of
-            // the shift is undefined. Don't try to analyze it, because the
-            // resolution chosen here may differ from the resolution chosen in
-            // other parts of the compiler.
-            if (CI->getValue().uge(BitWidth))
-              break;
+      Type *OuterTy = BO->LHS->getType();
+      uint64_t BitWidth = getTypeSizeInBits(OuterTy);
+      // If the shift count is not less than the bitwidth, the result of
+      // the shift is undefined. Don't try to analyze it, because the
+      // resolution chosen here may differ from the resolution chosen in
+      // other parts of the compiler.
+      if (CI->getValue().uge(BitWidth))
+        break;
 
-            uint64_t Amt = BitWidth - CI->getZExtValue();
-            if (Amt == BitWidth)
-              return getSCEV(L->getOperand(0)); // shift by zero --> noop
+      if (CI->isNullValue())
+        return getSCEV(BO->LHS); // shift by zero --> noop
+
+      uint64_t AShrAmt = CI->getZExtValue();
+      Type *TruncTy = IntegerType::get(getContext(), BitWidth - AShrAmt);
+
+      Operator *L = dyn_cast<Operator>(BO->LHS);
+      if (L && L->getOpcode() == Instruction::Shl) {
+        // X = Shl A, n
+        // Y = AShr X, m
+        // Both n and m are constant.
+
+        const SCEV *ShlOp0SCEV = getSCEV(L->getOperand(0));
+        if (L->getOperand(1) == BO->RHS)
+          // For a two-shift sext-inreg, i.e. n = m,
+          // use sext(trunc(x)) as the SCEV expression.
+          return getSignExtendExpr(
+              getTruncateExpr(ShlOp0SCEV, TruncTy), OuterTy);
+
+        ConstantInt *ShlAmtCI = dyn_cast<ConstantInt>(L->getOperand(1));
+        if (ShlAmtCI && ShlAmtCI->getValue().ult(BitWidth)) {
+          uint64_t ShlAmt = ShlAmtCI->getZExtValue();
+          if (ShlAmt > AShrAmt) {
+            // When n > m, use sext(mul(trunc(x), 2^(n-m))) as the SCEV
+            // expression. We already checked that ShlAmt < BitWidth, so
+            // the multiplier, 1 << (ShlAmt - AShrAmt), fits into TruncTy as
+            // ShlAmt - AShrAmt < BitWidth - AShrAmt (the width of TruncTy).
+            APInt Mul = APInt::getOneBitSet(BitWidth - AShrAmt,
+                                            ShlAmt - AShrAmt);
             return getSignExtendExpr(
-                getTruncateExpr(getSCEV(L->getOperand(0)),
-                                IntegerType::get(getContext(), Amt)),
-                BO->LHS->getType());
+                getMulExpr(getTruncateExpr(ShlOp0SCEV, TruncTy),
+                getConstant(Mul)), OuterTy);
           }
+        }
+      }
       break;
     }
   }