Teach SimplifyDemandedBits that exact shifts demand the bits they
are shifting out, since they require those bits to be zero.  Similarly,
a shl with the NUW/NSW flags demands the high bits that it shifts out
(NSW demands one additional high bit).


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125263 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 48184c0..bda8cea 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -576,8 +576,16 @@
     break;
   case Instruction::Shl:
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth);
+      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
       APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt));
+      
+      // If the shift is NUW/NSW, then it does demand the high bits.
+      ShlOperator *IOp = cast<ShlOperator>(I);
+      if (IOp->hasNoSignedWrap())
+        DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
+      else if (IOp->hasNoUnsignedWrap())
+        DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
+      
       if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, 
                                KnownZero, KnownOne, Depth+1))
         return I;
@@ -592,10 +600,16 @@
   case Instruction::LShr:
     // For a logical shift right
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth);
+      uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
       
       // Unsigned shift right.
       APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
+      
+      // If the shift is exact, then it does demand the low bits (and knows that
+      // they are zero).
+      if (cast<LShrOperator>(I)->isExact())
+        DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+      
       if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
                                KnownZero, KnownOne, Depth+1))
         return I;
@@ -627,7 +641,7 @@
       return I->getOperand(0);
     
     if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth);
+      uint32_t ShiftAmt = SA->getLimitedValue(BitWidth-1);
       
       // Signed shift right.
       APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
@@ -635,6 +649,12 @@
       // demanded.
       if (DemandedMask.countLeadingZeros() <= ShiftAmt)
         DemandedMaskIn.setBit(BitWidth-1);
+      
+      // If the shift is exact, then it does demand the low bits (and knows that
+      // they are zero).
+      if (cast<AShrOperator>(I)->isExact())
+        DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+      
       if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn,
                                KnownZero, KnownOne, Depth+1))
         return I;