Always compute all the bits in ComputeMaskedBits.
This allows us to keep passing reduced masks to SimplifyDemandedBits, while
still knowing about all the bits if SimplifyDemandedBits fails. This allows
instcombine to simplify cases like the one in the included testcase.

git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@154011 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index a71cf82..e036687 100644
--- a/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -1452,16 +1452,14 @@
   if (VT.isInteger() && !VT.isVector()) {
     APInt LHSZero, LHSOne;
     APInt RHSZero, RHSOne;
-    APInt Mask = APInt::getAllOnesValue(VT.getScalarType().getSizeInBits());
-    DAG.ComputeMaskedBits(N0, Mask, LHSZero, LHSOne);
+    DAG.ComputeMaskedBits(N0, LHSZero, LHSOne);
 
     if (LHSZero.getBoolValue()) {
-      DAG.ComputeMaskedBits(N1, Mask, RHSZero, RHSOne);
+      DAG.ComputeMaskedBits(N1, RHSZero, RHSOne);
 
       // If all possibly-set bits on the LHS are clear on the RHS, return an OR.
       // If all possibly-set bits on the RHS are clear on the LHS, return an OR.
-      if ((RHSZero & (~LHSZero & Mask)) == (~LHSZero & Mask) ||
-          (LHSZero & (~RHSZero & Mask)) == (~RHSZero & Mask))
+      if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero)
         return DAG.getNode(ISD::OR, N->getDebugLoc(), VT, N0, N1);
     }
   }
@@ -1547,16 +1545,14 @@
   // fold (addc a, b) -> (or a, b), CARRY_FALSE iff a and b share no bits.
   APInt LHSZero, LHSOne;
   APInt RHSZero, RHSOne;
-  APInt Mask = APInt::getAllOnesValue(VT.getScalarType().getSizeInBits());
-  DAG.ComputeMaskedBits(N0, Mask, LHSZero, LHSOne);
+  DAG.ComputeMaskedBits(N0, LHSZero, LHSOne);
 
   if (LHSZero.getBoolValue()) {
-    DAG.ComputeMaskedBits(N1, Mask, RHSZero, RHSOne);
+    DAG.ComputeMaskedBits(N1, RHSZero, RHSOne);
 
     // If all possibly-set bits on the LHS are clear on the RHS, return an OR.
     // If all possibly-set bits on the RHS are clear on the LHS, return an OR.
-    if ((RHSZero & (~LHSZero & Mask)) == (~LHSZero & Mask) ||
-        (LHSZero & (~RHSZero & Mask)) == (~RHSZero & Mask))
+    if ((RHSZero & ~LHSZero) == ~LHSZero || (LHSZero & ~RHSZero) == ~RHSZero)
       return CombineTo(N, DAG.getNode(ISD::OR, N->getDebugLoc(), VT, N0, N1),
                        DAG.getNode(ISD::CARRY_FALSE,
                                    N->getDebugLoc(), MVT::Glue));
@@ -3835,8 +3831,7 @@
   if (N1C && N0.getOpcode() == ISD::CTLZ &&
       N1C->getAPIntValue() == Log2_32(VT.getSizeInBits())) {
     APInt KnownZero, KnownOne;
-    APInt Mask = APInt::getAllOnesValue(VT.getScalarType().getSizeInBits());
-    DAG.ComputeMaskedBits(N0.getOperand(0), Mask, KnownZero, KnownOne);
+    DAG.ComputeMaskedBits(N0.getOperand(0), KnownZero, KnownOne);
 
     // If any of the input bits are KnownOne, then the input couldn't be all
     // zeros, thus the result of the srl will always be zero.
@@ -3844,7 +3839,7 @@
 
     // If all of the bits input the to ctlz node are known to be zero, then
     // the result of the ctlz is "32" and the result of the shift is one.
-    APInt UnknownBits = ~KnownZero & Mask;
+    APInt UnknownBits = ~KnownZero;
     if (UnknownBits == 0) return DAG.getConstant(1, VT);
 
     // Otherwise, check to see if there is exactly one bit input to the ctlz.
@@ -4439,8 +4434,8 @@
                           std::min(Op.getValueSizeInBits(),
                                    VT.getSizeInBits()));
     APInt KnownZero, KnownOne;
-    DAG.ComputeMaskedBits(Op, TruncatedBits, KnownZero, KnownOne);
-    if (TruncatedBits == KnownZero) {
+    DAG.ComputeMaskedBits(Op, KnownZero, KnownOne);
+    if (TruncatedBits == (KnownZero & TruncatedBits)) {
       if (VT.bitsGT(Op.getValueType()))
         return DAG.getNode(ISD::ZERO_EXTEND, N->getDebugLoc(), VT, Op);
       if (VT.bitsLT(Op.getValueType()))
diff --git a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 41506d1..95ddb1e 100644
--- a/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -1362,7 +1362,7 @@
 
   APInt HighBitMask = APInt::getHighBitsSet(ShBits, ShBits - Log2_32(NVTBits));
   APInt KnownZero, KnownOne;
-  DAG.ComputeMaskedBits(N->getOperand(1), HighBitMask, KnownZero, KnownOne);
+  DAG.ComputeMaskedBits(N->getOperand(1), KnownZero, KnownOne);
 
   // If we don't know anything about the high bits, exit.
   if (((KnownZero|KnownOne) & HighBitMask) == 0)
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
index 48f8f77..305eee8 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -1627,7 +1627,7 @@
 bool SelectionDAG::MaskedValueIsZero(SDValue Op, const APInt &Mask,
                                      unsigned Depth) const {
   APInt KnownZero, KnownOne;
-  ComputeMaskedBits(Op, Mask, KnownZero, KnownOne, Depth);
+  ComputeMaskedBits(Op, KnownZero, KnownOne, Depth);
   assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
   return (KnownZero & Mask) == Mask;
 }
@@ -1636,15 +1636,12 @@
 /// known to be either zero or one and return them in the KnownZero/KnownOne
 /// bitsets.  This code only analyzes bits in Mask, in order to short-circuit
 /// processing.
-void SelectionDAG::ComputeMaskedBits(SDValue Op, const APInt &Mask,
-                                     APInt &KnownZero, APInt &KnownOne,
-                                     unsigned Depth) const {
-  unsigned BitWidth = Mask.getBitWidth();
-  assert(BitWidth == Op.getValueType().getScalarType().getSizeInBits() &&
-         "Mask size mismatches value type size!");
+void SelectionDAG::ComputeMaskedBits(SDValue Op, APInt &KnownZero,
+                                     APInt &KnownOne, unsigned Depth) const {
+  unsigned BitWidth = Op.getValueType().getScalarType().getSizeInBits();
 
   KnownZero = KnownOne = APInt(BitWidth, 0);   // Don't know anything.
-  if (Depth == 6 || Mask == 0)
+  if (Depth == 6)
     return;  // Limit search depth.
 
   APInt KnownZero2, KnownOne2;
@@ -1652,14 +1649,13 @@
   switch (Op.getOpcode()) {
   case ISD::Constant:
     // We know all of the bits for a constant!
-    KnownOne = cast<ConstantSDNode>(Op)->getAPIntValue() & Mask;
-    KnownZero = ~KnownOne & Mask;
+    KnownOne = cast<ConstantSDNode>(Op)->getAPIntValue();
+    KnownZero = ~KnownOne;
     return;
   case ISD::AND:
     // If either the LHS or the RHS are Zero, the result is zero.
-    ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
-    ComputeMaskedBits(Op.getOperand(0), Mask & ~KnownZero,
-                      KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
 
@@ -1669,9 +1665,8 @@
     KnownZero |= KnownZero2;
     return;
   case ISD::OR:
-    ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
-    ComputeMaskedBits(Op.getOperand(0), Mask & ~KnownOne,
-                      KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
 
@@ -1681,8 +1676,8 @@
     KnownOne |= KnownOne2;
     return;
   case ISD::XOR: {
-    ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
-    ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
 
@@ -1694,9 +1689,8 @@
     return;
   }
   case ISD::MUL: {
-    APInt Mask2 = APInt::getAllOnesValue(BitWidth);
-    ComputeMaskedBits(Op.getOperand(1), Mask2, KnownZero, KnownOne, Depth+1);
-    ComputeMaskedBits(Op.getOperand(0), Mask2, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
 
@@ -1715,33 +1709,29 @@
     LeadZ = std::min(LeadZ, BitWidth);
     KnownZero = APInt::getLowBitsSet(BitWidth, TrailZ) |
                 APInt::getHighBitsSet(BitWidth, LeadZ);
-    KnownZero &= Mask;
     return;
   }
   case ISD::UDIV: {
     // For the purposes of computing leading zeros we can conservatively
     // treat a udiv as a logical right shift by the power of 2 known to
     // be less than the denominator.
-    APInt AllOnes = APInt::getAllOnesValue(BitWidth);
-    ComputeMaskedBits(Op.getOperand(0),
-                      AllOnes, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero2, KnownOne2, Depth+1);
     unsigned LeadZ = KnownZero2.countLeadingOnes();
 
     KnownOne2.clearAllBits();
     KnownZero2.clearAllBits();
-    ComputeMaskedBits(Op.getOperand(1),
-                      AllOnes, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero2, KnownOne2, Depth+1);
     unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros();
     if (RHSUnknownLeadingOnes != BitWidth)
       LeadZ = std::min(BitWidth,
                        LeadZ + BitWidth - RHSUnknownLeadingOnes - 1);
 
-    KnownZero = APInt::getHighBitsSet(BitWidth, LeadZ) & Mask;
+    KnownZero = APInt::getHighBitsSet(BitWidth, LeadZ);
     return;
   }
   case ISD::SELECT:
-    ComputeMaskedBits(Op.getOperand(2), Mask, KnownZero, KnownOne, Depth+1);
-    ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(2), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
 
@@ -1750,8 +1740,8 @@
     KnownZero &= KnownZero2;
     return;
   case ISD::SELECT_CC:
-    ComputeMaskedBits(Op.getOperand(3), Mask, KnownZero, KnownOne, Depth+1);
-    ComputeMaskedBits(Op.getOperand(2), Mask, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(3), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(2), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
 
@@ -1783,8 +1773,7 @@
       if (ShAmt >= BitWidth)
         return;
 
-      ComputeMaskedBits(Op.getOperand(0), Mask.lshr(ShAmt),
-                        KnownZero, KnownOne, Depth+1);
+      ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
       assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
       KnownZero <<= ShAmt;
       KnownOne  <<= ShAmt;
@@ -1801,13 +1790,12 @@
       if (ShAmt >= BitWidth)
         return;
 
-      ComputeMaskedBits(Op.getOperand(0), (Mask << ShAmt),
-                        KnownZero, KnownOne, Depth+1);
+      ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
       assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
       KnownZero = KnownZero.lshr(ShAmt);
       KnownOne  = KnownOne.lshr(ShAmt);
 
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt) & Mask;
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
       KnownZero |= HighBits;  // High bits known zero.
     }
     return;
@@ -1819,15 +1807,11 @@
       if (ShAmt >= BitWidth)
         return;
 
-      APInt InDemandedMask = (Mask << ShAmt);
       // If any of the demanded bits are produced by the sign extension, we also
       // demand the input sign bit.
-      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt) & Mask;
-      if (HighBits.getBoolValue())
-        InDemandedMask |= APInt::getSignBit(BitWidth);
+      APInt HighBits = APInt::getHighBitsSet(BitWidth, ShAmt);
 
-      ComputeMaskedBits(Op.getOperand(0), InDemandedMask, KnownZero, KnownOne,
-                        Depth+1);
+      ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
       assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
       KnownZero = KnownZero.lshr(ShAmt);
       KnownOne  = KnownOne.lshr(ShAmt);
@@ -1849,10 +1833,10 @@
 
     // Sign extension.  Compute the demanded bits in the result that are not
     // present in the input.
-    APInt NewBits = APInt::getHighBitsSet(BitWidth, BitWidth - EBits) & Mask;
+    APInt NewBits = APInt::getHighBitsSet(BitWidth, BitWidth - EBits);
 
     APInt InSignBit = APInt::getSignBit(EBits);
-    APInt InputDemandedBits = Mask & APInt::getLowBitsSet(BitWidth, EBits);
+    APInt InputDemandedBits = APInt::getLowBitsSet(BitWidth, EBits);
 
     // If the sign extended bits are demanded, we know that the sign
     // bit is demanded.
@@ -1860,8 +1844,9 @@
     if (NewBits.getBoolValue())
       InputDemandedBits |= InSignBit;
 
-    ComputeMaskedBits(Op.getOperand(0), InputDemandedBits,
-                      KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
+    KnownOne &= InputDemandedBits;
+    KnownZero &= InputDemandedBits;
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
 
     // If the sign bit of the input is known set or clear, then we know the
@@ -1893,20 +1878,19 @@
     if (ISD::isZEXTLoad(Op.getNode())) {
       EVT VT = LD->getMemoryVT();
       unsigned MemBits = VT.getScalarType().getSizeInBits();
-      KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits) & Mask;
+      KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - MemBits);
     } else if (const MDNode *Ranges = LD->getRanges()) {
-      computeMaskedBitsLoad(*Ranges, Mask, KnownZero);
+      computeMaskedBitsLoad(*Ranges, KnownZero);
     }
     return;
   }
   case ISD::ZERO_EXTEND: {
     EVT InVT = Op.getOperand(0).getValueType();
     unsigned InBits = InVT.getScalarType().getSizeInBits();
-    APInt NewBits   = APInt::getHighBitsSet(BitWidth, BitWidth - InBits) & Mask;
-    APInt InMask    = Mask.trunc(InBits);
+    APInt NewBits   = APInt::getHighBitsSet(BitWidth, BitWidth - InBits);
     KnownZero = KnownZero.trunc(InBits);
     KnownOne = KnownOne.trunc(InBits);
-    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
     KnownZero = KnownZero.zext(BitWidth);
     KnownOne = KnownOne.zext(BitWidth);
     KnownZero |= NewBits;
@@ -1916,17 +1900,11 @@
     EVT InVT = Op.getOperand(0).getValueType();
     unsigned InBits = InVT.getScalarType().getSizeInBits();
     APInt InSignBit = APInt::getSignBit(InBits);
-    APInt NewBits   = APInt::getHighBitsSet(BitWidth, BitWidth - InBits) & Mask;
-    APInt InMask = Mask.trunc(InBits);
-
-    // If any of the sign extended bits are demanded, we know that the sign
-    // bit is demanded. Temporarily set this bit in the mask for our callee.
-    if (NewBits.getBoolValue())
-      InMask |= InSignBit;
+    APInt NewBits   = APInt::getHighBitsSet(BitWidth, BitWidth - InBits);
 
     KnownZero = KnownZero.trunc(InBits);
     KnownOne = KnownOne.trunc(InBits);
-    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
 
     // Note if the sign bit is known to be zero or one.
     bool SignBitKnownZero = KnownZero.isNegative();
@@ -1934,13 +1912,6 @@
     assert(!(SignBitKnownZero && SignBitKnownOne) &&
            "Sign bit can't be known to be both zero and one!");
 
-    // If the sign bit wasn't actually demanded by our caller, we don't
-    // want it set in the KnownZero and KnownOne result values. Reset the
-    // mask and reapply it to the result values.
-    InMask = Mask.trunc(InBits);
-    KnownZero &= InMask;
-    KnownOne  &= InMask;
-
     KnownZero = KnownZero.zext(BitWidth);
     KnownOne = KnownOne.zext(BitWidth);
 
@@ -1954,10 +1925,9 @@
   case ISD::ANY_EXTEND: {
     EVT InVT = Op.getOperand(0).getValueType();
     unsigned InBits = InVT.getScalarType().getSizeInBits();
-    APInt InMask = Mask.trunc(InBits);
     KnownZero = KnownZero.trunc(InBits);
     KnownOne = KnownOne.trunc(InBits);
-    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
     KnownZero = KnownZero.zext(BitWidth);
     KnownOne = KnownOne.zext(BitWidth);
     return;
@@ -1965,10 +1935,9 @@
   case ISD::TRUNCATE: {
     EVT InVT = Op.getOperand(0).getValueType();
     unsigned InBits = InVT.getScalarType().getSizeInBits();
-    APInt InMask = Mask.zext(InBits);
     KnownZero = KnownZero.zext(InBits);
     KnownOne = KnownOne.zext(InBits);
-    ComputeMaskedBits(Op.getOperand(0), InMask, KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
     assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?");
     KnownZero = KnownZero.trunc(BitWidth);
     KnownOne = KnownOne.trunc(BitWidth);
@@ -1977,9 +1946,8 @@
   case ISD::AssertZext: {
     EVT VT = cast<VTSDNode>(Op.getOperand(1))->getVT();
     APInt InMask = APInt::getLowBitsSet(BitWidth, VT.getSizeInBits());
-    ComputeMaskedBits(Op.getOperand(0), Mask & InMask, KnownZero,
-                      KnownOne, Depth+1);
-    KnownZero |= (~InMask) & Mask;
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
+    KnownZero |= (~InMask);
     return;
   }
   case ISD::FGETSIGN:
@@ -1996,8 +1964,7 @@
         unsigned NLZ = (CLHS->getAPIntValue()+1).countLeadingZeros();
         // NLZ can't be BitWidth with no sign bit
         APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1);
-        ComputeMaskedBits(Op.getOperand(1), MaskV, KnownZero2, KnownOne2,
-                          Depth+1);
+        ComputeMaskedBits(Op.getOperand(1), KnownZero2, KnownOne2, Depth+1);
 
         // If all of the MaskV bits are known to be zero, then we know the
         // output top bits are zero, because we now know that the output is
@@ -2005,7 +1972,7 @@
         if ((KnownZero2 & MaskV) == MaskV) {
           unsigned NLZ2 = CLHS->getAPIntValue().countLeadingZeros();
           // Top bits known zero.
-          KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2) & Mask;
+          KnownZero = APInt::getHighBitsSet(BitWidth, NLZ2);
         }
       }
     }
@@ -2016,13 +1983,11 @@
     // Output known-0 bits are known if clear or set in both the low clear bits
     // common to both LHS & RHS.  For example, 8+(X<<3) is known to have the
     // low 3 bits clear.
-    APInt Mask2 = APInt::getLowBitsSet(BitWidth,
-                                       BitWidth - Mask.countLeadingZeros());
-    ComputeMaskedBits(Op.getOperand(0), Mask2, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
     unsigned KnownZeroOut = KnownZero2.countTrailingOnes();
 
-    ComputeMaskedBits(Op.getOperand(1), Mask2, KnownZero2, KnownOne2, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero2, KnownOne2, Depth+1);
     assert((KnownZero2 & KnownOne2) == 0 && "Bits known to be one AND zero?");
     KnownZeroOut = std::min(KnownZeroOut,
                             KnownZero2.countTrailingOnes());
@@ -2046,7 +2011,7 @@
       if (RA.isPowerOf2()) {
         APInt LowBits = RA - 1;
         APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
-        ComputeMaskedBits(Op.getOperand(0), Mask2,KnownZero2,KnownOne2,Depth+1);
+        ComputeMaskedBits(Op.getOperand(0), KnownZero2,KnownOne2,Depth+1);
 
         // The low bits of the first operand are unchanged by the srem.
         KnownZero = KnownZero2 & LowBits;
@@ -2061,10 +2026,6 @@
         // the upper bits are all one.
         if (KnownOne2[BitWidth-1] && ((KnownOne2 & LowBits) != 0))
           KnownOne |= ~LowBits;
-
-        KnownZero &= Mask;
-        KnownOne &= Mask;
-
         assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
       }
     }
@@ -2074,9 +2035,8 @@
       const APInt &RA = Rem->getAPIntValue();
       if (RA.isPowerOf2()) {
         APInt LowBits = (RA - 1);
-        APInt Mask2 = LowBits & Mask;
-        KnownZero |= ~LowBits & Mask;
-        ComputeMaskedBits(Op.getOperand(0), Mask2, KnownZero, KnownOne,Depth+1);
+        KnownZero |= ~LowBits;
+        ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne,Depth+1);
         assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
         break;
       }
@@ -2084,16 +2044,13 @@
 
     // Since the result is less than or equal to either operand, any leading
     // zero bits in either operand must also exist in the result.
-    APInt AllOnes = APInt::getAllOnesValue(BitWidth);
-    ComputeMaskedBits(Op.getOperand(0), AllOnes, KnownZero, KnownOne,
-                      Depth+1);
-    ComputeMaskedBits(Op.getOperand(1), AllOnes, KnownZero2, KnownOne2,
-                      Depth+1);
+    ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
+    ComputeMaskedBits(Op.getOperand(1), KnownZero2, KnownOne2, Depth+1);
 
     uint32_t Leaders = std::max(KnownZero.countLeadingOnes(),
                                 KnownZero2.countLeadingOnes());
     KnownOne.clearAllBits();
-    KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & Mask;
+    KnownZero = APInt::getHighBitsSet(BitWidth, Leaders);
     return;
   }
   case ISD::FrameIndex:
@@ -2113,8 +2070,7 @@
   case ISD::INTRINSIC_W_CHAIN:
   case ISD::INTRINSIC_VOID:
     // Allow the target to implement this method for its nodes.
-    TLI.computeMaskedBitsForTargetNode(Op, Mask, KnownZero, KnownOne, *this,
-                                       Depth);
+    TLI.computeMaskedBitsForTargetNode(Op, KnownZero, KnownOne, *this, Depth);
     return;
   }
 }
@@ -2238,12 +2194,11 @@
     if (ConstantSDNode *CRHS = dyn_cast<ConstantSDNode>(Op.getOperand(1)))
       if (CRHS->isAllOnesValue()) {
         APInt KnownZero, KnownOne;
-        APInt Mask = APInt::getAllOnesValue(VTBits);
-        ComputeMaskedBits(Op.getOperand(0), Mask, KnownZero, KnownOne, Depth+1);
+        ComputeMaskedBits(Op.getOperand(0), KnownZero, KnownOne, Depth+1);
 
         // If the input is known to be 0 or 1, the output is 0/-1, which is all
         // sign bits set.
-        if ((KnownZero | APInt(VTBits, 1)) == Mask)
+        if ((KnownZero | APInt(VTBits, 1)).isAllOnesValue())
           return VTBits;
 
         // If we are subtracting one from a positive number, there is no carry
@@ -2264,11 +2219,10 @@
     if (ConstantSDNode *CLHS = dyn_cast<ConstantSDNode>(Op.getOperand(0)))
       if (CLHS->isNullValue()) {
         APInt KnownZero, KnownOne;
-        APInt Mask = APInt::getAllOnesValue(VTBits);
-        ComputeMaskedBits(Op.getOperand(1), Mask, KnownZero, KnownOne, Depth+1);
+        ComputeMaskedBits(Op.getOperand(1), KnownZero, KnownOne, Depth+1);
         // If the input is known to be 0 or 1, the output is 0/-1, which is all
         // sign bits set.
-        if ((KnownZero | APInt(VTBits, 1)) == Mask)
+        if ((KnownZero | APInt(VTBits, 1)).isAllOnesValue())
           return VTBits;
 
         // If the input is known to be positive (the sign bit is known clear),
@@ -2317,9 +2271,9 @@
   // Finally, if we can prove that the top bits of the result are 0's or 1's,
   // use this information.
   APInt KnownZero, KnownOne;
-  APInt Mask = APInt::getAllOnesValue(VTBits);
-  ComputeMaskedBits(Op, Mask, KnownZero, KnownOne, Depth);
+  ComputeMaskedBits(Op, KnownZero, KnownOne, Depth);
 
+  APInt Mask;
   if (KnownZero.isNegative()) {        // sign bit is 0
     Mask = KnownZero;
   } else if (KnownOne.isNegative()) {  // sign bit is 1;
@@ -6040,10 +5994,9 @@
   int64_t GVOffset = 0;
   if (TLI.isGAPlusOffset(Ptr.getNode(), GV, GVOffset)) {
     unsigned PtrWidth = TLI.getPointerTy().getSizeInBits();
-    APInt AllOnes = APInt::getAllOnesValue(PtrWidth);
     APInt KnownZero(PtrWidth, 0), KnownOne(PtrWidth, 0);
-    llvm::ComputeMaskedBits(const_cast<GlobalValue*>(GV), AllOnes,
-                            KnownZero, KnownOne, TLI.getTargetData());
+    llvm::ComputeMaskedBits(const_cast<GlobalValue*>(GV), KnownZero, KnownOne,
+                            TLI.getTargetData());
     unsigned AlignBits = KnownZero.countTrailingOnes();
     unsigned Align = AlignBits ? 1 << std::min(31U, AlignBits) : 0;
     if (Align)
diff --git a/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp b/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
index 8aabc02..605509b 100644
--- a/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
+++ b/lib/CodeGen/SelectionDAG/SelectionDAGISel.cpp
@@ -508,7 +508,6 @@
 
   Worklist.push_back(CurDAG->getRoot().getNode());
 
-  APInt Mask;
   APInt KnownZero;
   APInt KnownOne;
 
@@ -539,8 +538,7 @@
       continue;
 
     unsigned NumSignBits = CurDAG->ComputeNumSignBits(Src);
-    Mask = APInt::getAllOnesValue(SrcVT.getSizeInBits());
-    CurDAG->ComputeMaskedBits(Src, Mask, KnownZero, KnownOne);
+    CurDAG->ComputeMaskedBits(Src, KnownZero, KnownOne);
     FuncInfo->AddLiveOutRegInfo(DestReg, NumSignBits, KnownZero, KnownOne);
   } while (!Worklist.empty());
 }
@@ -1444,7 +1442,7 @@
   APInt NeededMask = DesiredMask & ~ActualMask;
 
   APInt KnownZero, KnownOne;
-  CurDAG->ComputeMaskedBits(LHS, NeededMask, KnownZero, KnownOne);
+  CurDAG->ComputeMaskedBits(LHS, KnownZero, KnownOne);
 
   // If all the missing bits in the or are already known to be set, match!
   if ((NeededMask & KnownOne) == NeededMask)
diff --git a/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
index e4a0016..eefb9e8 100644
--- a/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -1244,7 +1244,7 @@
     if (Depth != 0) {
       // If not at the root, Just compute the KnownZero/KnownOne bits to
       // simplify things downstream.
-      TLO.DAG.ComputeMaskedBits(Op, DemandedMask, KnownZero, KnownOne, Depth);
+      TLO.DAG.ComputeMaskedBits(Op, KnownZero, KnownOne, Depth);
       return false;
     }
     // If this is the root being simplified, allow it to have multiple uses,
@@ -1263,8 +1263,8 @@
   switch (Op.getOpcode()) {
   case ISD::Constant:
     // We know all of the bits for a constant!
-    KnownOne = cast<ConstantSDNode>(Op)->getAPIntValue() & NewMask;
-    KnownZero = ~KnownOne & NewMask;
+    KnownOne = cast<ConstantSDNode>(Op)->getAPIntValue();
+    KnownZero = ~KnownOne;
     return false;   // Don't fall through, will infinitely loop.
   case ISD::AND:
     // If the RHS is a constant, check to see if the LHS would be zero without
@@ -1274,8 +1274,7 @@
     if (ConstantSDNode *RHSC = dyn_cast<ConstantSDNode>(Op.getOperand(1))) {
       APInt LHSZero, LHSOne;
       // Do not increment Depth here; that can cause an infinite loop.
-      TLO.DAG.ComputeMaskedBits(Op.getOperand(0), NewMask,
-                                LHSZero, LHSOne, Depth);
+      TLO.DAG.ComputeMaskedBits(Op.getOperand(0), LHSZero, LHSOne, Depth);
       // If the LHS already has zeros where RHSC does, this and is dead.
       if ((LHSZero & NewMask) == (~RHSC->getAPIntValue() & NewMask))
         return TLO.CombineTo(Op, Op.getOperand(0));
@@ -1725,11 +1724,11 @@
 
     // If the sign bit is known one, the top bits match.
     if (KnownOne.intersects(InSignBit)) {
-      KnownOne  |= NewBits;
-      KnownZero &= ~NewBits;
+      KnownOne |= NewBits;
+      assert((KnownZero & NewBits) == 0);
     } else {   // Otherwise, top bits aren't known.
-      KnownOne  &= ~NewBits;
-      KnownZero &= ~NewBits;
+      assert((KnownOne & NewBits) == 0);
+      assert((KnownZero & NewBits) == 0);
     }
     break;
   }
@@ -1863,7 +1862,7 @@
   // FALL THROUGH
   default:
     // Just use ComputeMaskedBits to compute output bits.
-    TLO.DAG.ComputeMaskedBits(Op, NewMask, KnownZero, KnownOne, Depth);
+    TLO.DAG.ComputeMaskedBits(Op, KnownZero, KnownOne, Depth);
     break;
   }
 
@@ -1879,7 +1878,6 @@
 /// in Mask are known to be either zero or one and return them in the
 /// KnownZero/KnownOne bitsets.
 void TargetLowering::computeMaskedBitsForTargetNode(const SDValue Op,
-                                                    const APInt &Mask,
                                                     APInt &KnownZero,
                                                     APInt &KnownOne,
                                                     const SelectionDAG &DAG,
@@ -1890,7 +1888,7 @@
           Op.getOpcode() == ISD::INTRINSIC_VOID) &&
          "Should use MaskedValueIsZero if you don't know whether Op"
          " is a target node!");
-  KnownZero = KnownOne = APInt(Mask.getBitWidth(), 0);
+  KnownZero = KnownOne = APInt(KnownOne.getBitWidth(), 0);
 }
 
 /// ComputeNumSignBitsForTargetNode - This method can be implemented by
@@ -1934,9 +1932,8 @@
   // Fall back to ComputeMaskedBits to catch other known cases.
   EVT OpVT = Val.getValueType();
   unsigned BitWidth = OpVT.getScalarType().getSizeInBits();
-  APInt Mask = APInt::getAllOnesValue(BitWidth);
   APInt KnownZero, KnownOne;
-  DAG.ComputeMaskedBits(Val, Mask, KnownZero, KnownOne);
+  DAG.ComputeMaskedBits(Val, KnownZero, KnownOne);
   return (KnownZero.countPopulation() == BitWidth - 1) &&
          (KnownOne.countPopulation() == 1);
 }