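Simplify some logic in ComputeMaskedBits. And change ComputeMaskedBits
to pass the mask APInt by const reference, not by value. Cases that
previously modified Mask in place now pass a temporary expression or
work on a local copy (InMask) instead.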


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@47096 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index de5c411..3ec60cd 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1130,7 +1130,7 @@
 /// known to be either zero or one and return them in the KnownZero/KnownOne
 /// bitsets.  This code only analyzes bits in Mask, in order to short-circuit
 /// processing.
-void SelectionDAG::ComputeMaskedBits(SDOperand Op, APInt Mask, 
+void SelectionDAG::ComputeMaskedBits(SDOperand Op, const APInt &Mask, 
                                      APInt &KnownZero, APInt &KnownOne,
                                      unsigned Depth) const {
   unsigned BitWidth = Mask.getBitWidth();
@@ -1153,8 +1153,8 @@
   case ISD::AND:
     // If either the LHS or the RHS are Zero, the result is zero.
     ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
-    Mask &= ~KnownZero;
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), Mask & ~KnownZero,
+                      KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); 
 
@@ -1165,8 +1165,8 @@
     return;
   case ISD::OR:
     ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
-    Mask &= ~KnownOne;
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), Mask & ~KnownOne,
+                      KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?"); 
     
@@ -1271,21 +1271,19 @@
     return;
   case ISD::SIGN_EXTEND_INREG: {
     MVT::ValueType EVT = cast<VTSDNode>(Op.getOperand(1))->getVT();
+    unsigned EBits = MVT::getSizeInBits(EVT);
     
     // Sign extension.  Compute the demanded bits in the result that are not 
     // present in the input.
-    APInt NewBits = ~APInt::getLowBitsSet(BitWidth,
-                                          MVT::getSizeInBits(EVT)) & Mask;
+    APInt NewBits = APInt::getHighBitsSet(BitWidth, BitWidth - EBits) & Mask;
 
-    APInt InSignBit = APInt::getSignBit(MVT::getSizeInBits(EVT));
-    APInt InputDemandedBits =
-      Mask & APInt::getLowBitsSet(BitWidth,
-                                  MVT::getSizeInBits(EVT));
+    APInt InSignBit = APInt::getSignBit(EBits);
+    APInt InputDemandedBits = Mask & APInt::getLowBitsSet(BitWidth, EBits);
     
     // If the sign extended bits are demanded, we know that the sign
     // bit is demanded.
     InSignBit.zext(BitWidth);
-    if (!!NewBits)
+    if (NewBits.getBoolValue())
       InputDemandedBits |= InSignBit;
     
     ComputeMaskedBits(Op.getOperand(0), InputDemandedBits,
@@ -1318,19 +1316,20 @@
     if (ISD::isZEXTLoad(Op.Val)) {
       LoadSDNode *LD = cast<LoadSDNode>(Op);
       MVT::ValueType VT = LD->getMemoryVT();
-      KnownZero |= ~APInt::getLowBitsSet(BitWidth, MVT::getSizeInBits(VT)) & Mask;
+      unsigned MemBits = MVT::getSizeInBits(VT);
+      KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits) & Mask;
     }
     return;
   }
   case ISD::ZERO_EXTEND: {
     MVT::ValueType InVT = Op.getOperand(0).getValueType();
     unsigned InBits = MVT::getSizeInBits(InVT);
-    APInt InMask    = APInt::getLowBitsSet(BitWidth, InBits);
-    APInt NewBits = (~InMask) & Mask;
-    Mask.trunc(InBits);
+    APInt NewBits   = APInt::getHighBitsSet(BitWidth, BitWidth - InBits) & Mask;
+    APInt InMask    = Mask;
+    InMask.trunc(InBits);
     KnownZero.trunc(InBits);
     KnownOne.trunc(InBits);
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
     KnownZero.zext(BitWidth);
     KnownOne.zext(BitWidth);
     KnownZero |= NewBits;
@@ -1339,43 +1338,52 @@
   case ISD::SIGN_EXTEND: {
     MVT::ValueType InVT = Op.getOperand(0).getValueType();
     unsigned InBits = MVT::getSizeInBits(InVT);
-    APInt InMask    = APInt::getLowBitsSet(BitWidth, InBits);
     APInt InSignBit = APInt::getSignBit(InBits);
-    APInt NewBits   = (~InMask) & Mask;
+    APInt NewBits   = APInt::getHighBitsSet(BitWidth, BitWidth - InBits) & Mask;
+    APInt InMask = Mask;
+    InMask.trunc(InBits);
 
     // If any of the sign extended bits are demanded, we know that the sign
-    // bit is demanded.
-    InSignBit.zext(BitWidth);
-    if (!!(NewBits & Mask))
-      Mask |= InSignBit;
+    // bit is demanded. Temporarily set this bit in the mask for our callee.
+    if (NewBits.getBoolValue())
+      InMask |= InSignBit;
 
-    Mask.trunc(InBits);
     KnownZero.trunc(InBits);
     KnownOne.trunc(InBits);
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
+
+    // Note if the sign bit is known to be zero or one.
+    bool SignBitKnownZero = KnownZero.isNegative();
+    bool SignBitKnownOne  = KnownOne.isNegative();
+    assert(!(SignBitKnownZero && SignBitKnownOne) &&
+           "Sign bit can't be known to be both zero and one!");
+
+    // If the sign bit wasn't actually demanded by our caller, we don't
+    // want it set in the KnownZero and KnownOne result values. Reset the
+    // mask and reapply it to the result values.
+    InMask = Mask;
+    InMask.trunc(InBits);
+    KnownZero &= InMask;
+    KnownOne  &= InMask;
+
     KnownZero.zext(BitWidth);
     KnownOne.zext(BitWidth);
 
-    // If the sign bit is known zero or one, the  top bits match.
-    if (!!(KnownZero & InSignBit)) {
+    // If the sign bit is known zero or one, the top bits match.
+    if (SignBitKnownZero)
       KnownZero |= NewBits;
-      KnownOne  &= ~NewBits;
-    } else if (!!(KnownOne & InSignBit)) {
+    else if (SignBitKnownOne)
       KnownOne  |= NewBits;
-      KnownZero &= ~NewBits;
-    } else {   // Otherwise, top bits aren't known.
-      KnownOne  &= ~NewBits;
-      KnownZero &= ~NewBits;
-    }
     return;
   }
   case ISD::ANY_EXTEND: {
     MVT::ValueType InVT = Op.getOperand(0).getValueType();
     unsigned InBits = MVT::getSizeInBits(InVT);
-    Mask.trunc(InBits);
+    APInt InMask = Mask;
+    InMask.trunc(InBits);
     KnownZero.trunc(InBits);
     KnownOne.trunc(InBits);
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
     KnownZero.zext(BitWidth);
     KnownOne.zext(BitWidth);
     return;
@@ -1383,10 +1391,11 @@
   case ISD::TRUNCATE: {
     MVT::ValueType InVT = Op.getOperand(0).getValueType();
     unsigned InBits = MVT::getSizeInBits(InVT);
-    Mask.zext(InBits);
+    APInt InMask = Mask;
+    InMask.zext(InBits);
     KnownZero.zext(InBits);
     KnownOne.zext(InBits);
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); 
     KnownZero.trunc(BitWidth);
     KnownOne.trunc(BitWidth);
@@ -1415,8 +1424,8 @@
     // Output known-0 bits are known if clear or set in both the low clear bits
     // common to both LHS & RHS.  For example, 8+(X<<3) is known to have the
     // low 3 bits clear.
-    unsigned KnownZeroOut = std::min((~KnownZero).countTrailingZeros(), 
-                                     (~KnownZero2).countTrailingZeros());
+    unsigned KnownZeroOut = std::min(KnownZero.countTrailingOnes(), 
+                                     KnownZero2.countTrailingOnes());
     
     KnownZero = APInt::getLowBitsSet(BitWidth, KnownZeroOut);
     KnownOne = APInt(BitWidth, 0);
@@ -1431,7 +1440,7 @@
     // positive if we can prove that X is >= 0 and < 16.
 
     // sign bit clear
-    if (!(CLHS->getAPIntValue() & APInt::getSignBit(BitWidth))) {
+    if (CLHS->getAPIntValue().isNonNegative()) {
       unsigned NLZ = (CLHS->getAPIntValue()+1).countLeadingZeros();
       // NLZ can't be BitWidth with no sign bit
       APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ);