rdar://12329730 (2nd part)

 This change tries to simplify E1 = "(X >> C1) << C2" into:
  - E2 = "X << (C2 - C1)" if C2 > C1, or
  - E2 = "X >> (C1 - C2)" if C1 > C2, or
  - E2 = X if C1 == C2.

 Reviewed by Nadav. Thanks!


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@169182 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 602b203..c65076e 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -16,9 +16,10 @@
 #include "InstCombine.h"
 #include "llvm/DataLayout.h"
 #include "llvm/IntrinsicInst.h"
+#include "llvm/Support/PatternMatch.h"
 
 using namespace llvm;
-
+using namespace llvm::PatternMatch;
 
 /// ShrinkDemandedConstant - Check to see if the specified operand of the 
 /// specified instruction is a constant integer.  If so, check to see if there
@@ -580,6 +581,17 @@
     break;
   case Instruction::Shl:
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      {
+        Value *VarX; ConstantInt *C1;
+        if (match(I->getOperand(0), m_Shr(m_Value(VarX), m_ConstantInt(C1)))) {
+          Instruction *Shr = cast<Instruction>(I->getOperand(0));
+          Value *R = SimplifyShrShlDemandedBits(Shr, I, DemandedMask,
+                                                KnownZero, KnownOne);
+          if (R)
+            return R;
+        }
+      }
+
       uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
       APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));
       
@@ -800,6 +812,78 @@
   return 0;
 }
 
+/// Helper routine of SimplifyDemandedUseBits. It tries to simplify
+/// "E1 = (X shr C1) << C2", where C1 and C2 are constants and shr is lshr or
+/// ashr, into "E2 = X << (C2 - C1)", "E2 = X shr (C1 - C2)" or "E2 = X",
+/// depending on the sign of "C2 - C1".
+///
+/// Suppose E1 and E2 are generally different in bits S={bm, bm+1,
+/// ..., bn}, without considering the specific value X is holding.
+/// This transformation is legal iff one of the following conditions holds:
+///  1) All the bits in S are 0; in this case E1 == E2.
+///  2) We don't care about the bits in S, per the input DemandedMask.
+///  3) A combination of 1) and 2): some bits in S are 0, the rest don't matter.
+///
+/// Currently we only test condition 2). As with SimplifyDemandedUseBits, it
+/// returns NULL if the simplification was not successful.
+Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr,
+  Instruction *Shl, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne) {
+  Value *VarX = Shr->getOperand(0);
+  unsigned BitWidth = VarX->getType()->getIntegerBitWidth();
+
+  // Clamp to BitWidth so an over-wide constant is not silently truncated.
+  unsigned ShlAmt =
+    cast<ConstantInt>(Shl->getOperand(1))->getLimitedValue(BitWidth);
+  unsigned ShrAmt =
+    cast<ConstantInt>(Shr->getOperand(1))->getLimitedValue(BitWidth);
+
+  // Zero amounts are no-ops and amounts >= BitWidth are undefined; bail out
+  // before anything is written and before "ShlAmt-1" below can wrap around.
+  if (ShlAmt == 0 || ShrAmt == 0 || ShlAmt >= BitWidth || ShrAmt >= BitWidth)
+    return 0;
+
+  KnownOne.clearAllBits();
+  KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1);
+  KnownZero &= DemandedMask;
+
+  // All-ones of the full width; a uint64_t -1 only fills the low 64 bits.
+  APInt BitMask1(APInt::getAllOnesValue(BitWidth));
+  APInt BitMask2(APInt::getAllOnesValue(BitWidth));
+
+  bool isLshr = (Shr->getOpcode() == Instruction::LShr);
+  BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) :
+                      (BitMask1.ashr(ShrAmt) << ShlAmt);
+
+  if (ShrAmt <= ShlAmt) {
+    BitMask2 <<= (ShlAmt - ShrAmt);
+  } else {
+    BitMask2 = isLshr ? BitMask2.lshr(ShrAmt - ShlAmt):
+                        BitMask2.ashr(ShrAmt - ShlAmt);
+  }
+
+  // Check if condition-2 (see the comment to this function) is satisfied.
+  if ((BitMask1 & DemandedMask) != (BitMask2 & DemandedMask))
+    return 0;
+  if (ShrAmt == ShlAmt)
+    return VarX;
+  if (!Shr->hasOneUse())
+    return 0;
+
+  BinaryOperator *New;
+  if (ShrAmt < ShlAmt) {
+    Constant *Amt = ConstantInt::get(VarX->getType(), ShlAmt - ShrAmt);
+    New = BinaryOperator::CreateShl(VarX, Amt);
+    BinaryOperator *Orig = cast<BinaryOperator>(Shl);
+    New->setHasNoSignedWrap(Orig->hasNoSignedWrap());
+    New->setHasNoUnsignedWrap(Orig->hasNoUnsignedWrap());
+  } else {
+    // Keep the kind of the original shift: lshr would zero ashr's sign bits.
+    Constant *Amt = ConstantInt::get(VarX->getType(), ShrAmt - ShlAmt);
+    New = isLshr ? BinaryOperator::CreateLShr(VarX, Amt) :
+                   BinaryOperator::CreateAShr(VarX, Amt);
+  }
+  return InsertNewInstWith(New, *Shl);
+}
 
 /// SimplifyDemandedVectorElts - The specified value produces a vector with
 /// any number of elements. DemandedElts contains the set of elements that are