Don't try to simplify urem and srem using arithmetic rules that do not hold
under modular arithmetic (i.e. when values can overflow). Fixes PR1933.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@47987 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/InstructionCombining.cpp b/lib/Transforms/Scalar/InstructionCombining.cpp
index 1000ba6..8e99dcc 100644
--- a/lib/Transforms/Scalar/InstructionCombining.cpp
+++ b/lib/Transforms/Scalar/InstructionCombining.cpp
@@ -834,6 +834,49 @@
       return;
     }
     break;
+  case Instruction::SRem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isPowerOf2() || (-RA).isPowerOf2()) {
+        APInt LowBits = RA.isStrictlyPositive() ? ((RA - 1) | RA) : ~RA;
+        APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+        ComputeMaskedBits(I->getOperand(0), Mask2,KnownZero2,KnownOne2,Depth+1);
+
+        // The sign of a remainder is equal to the sign of the first
+        // operand (zero being positive).
+        if (KnownZero2[BitWidth-1] || ((KnownZero2 & LowBits) == LowBits))
+          KnownZero2 |= ~LowBits;
+        else if (KnownOne2[BitWidth-1])
+          KnownOne2 |= ~LowBits;
+
+        KnownZero |= KnownZero2 & Mask;
+        KnownOne |= KnownOne2 & Mask;
+
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); 
+      }
+    }
+    break;
+  case Instruction::URem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isStrictlyPositive() && RA.isPowerOf2()) {
+        APInt LowBits = (RA - 1) | RA;
+        APInt Mask2 = LowBits & Mask;
+        KnownZero |= ~LowBits & Mask;
+        ComputeMaskedBits(I->getOperand(0), Mask2, KnownZero, KnownOne,Depth+1);
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+      }
+    } else {
+      // Since the result is strictly less than RHS, any leading zero bits
+      // in RHS must also exist in the result.
+      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+      ComputeMaskedBits(I->getOperand(1), AllOnes, KnownZero2, KnownOne2, Depth+1);
+
+      uint32_t Leaders = KnownZero2.countLeadingOnes();
+      KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & Mask;
+      assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?");
+    }
+    break;
   }
 }
 
@@ -1418,6 +1461,52 @@
       }
     }
     break;
+  case Instruction::SRem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isPowerOf2() || (-RA).isPowerOf2()) {
+        APInt LowBits = RA.isStrictlyPositive() ? (RA - 1) | RA : ~RA;
+        APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
+        if (SimplifyDemandedBits(I->getOperand(0), Mask2,
+                                 LHSKnownZero, LHSKnownOne, Depth+1))
+          return true;
+
+        if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
+          LHSKnownZero |= ~LowBits;
+        else if (LHSKnownOne[BitWidth-1])
+          LHSKnownOne |= ~LowBits;
+
+        KnownZero |= LHSKnownZero & DemandedMask;
+        KnownOne |= LHSKnownOne & DemandedMask;
+
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); 
+      }
+    }
+    break;
+  case Instruction::URem:
+    if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) {
+      APInt RA = Rem->getValue();
+      if (RA.isPowerOf2()) {
+        APInt LowBits = (RA - 1) | RA;
+        APInt Mask2 = LowBits & DemandedMask;
+        KnownZero |= ~LowBits & DemandedMask;
+        if (SimplifyDemandedBits(I->getOperand(0), Mask2,
+                                 KnownZero, KnownOne, Depth+1))
+          return true;
+
+        assert((KnownZero & KnownOne) == 0&&"Bits known to be one AND zero?"); 
+      }
+    } else {
+      APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
+      APInt AllOnes = APInt::getAllOnesValue(BitWidth);
+      if (SimplifyDemandedBits(I->getOperand(1), AllOnes,
+                               KnownZero2, KnownOne2, Depth+1))
+        return true;
+
+      uint32_t Leaders = KnownZero2.countLeadingOnes();
+      KnownZero |= APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
+    }
+    break;
   }
   
   // If the client is only demanding bits that we know, return the known
@@ -2780,46 +2869,6 @@
   return commonDivTransforms(I);
 }
 
-/// GetFactor - If we can prove that the specified value is at least a multiple
-/// of some factor, return that factor.
-static Constant *GetFactor(Value *V) {
-  if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
-    return CI;
-  
-  // Unless we can be tricky, we know this is a multiple of 1.
-  Constant *Result = ConstantInt::get(V->getType(), 1);
-  
-  Instruction *I = dyn_cast<Instruction>(V);
-  if (!I) return Result;
-  
-  if (I->getOpcode() == Instruction::Mul) {
-    // Handle multiplies by a constant, etc.
-    return ConstantExpr::getMul(GetFactor(I->getOperand(0)),
-                                GetFactor(I->getOperand(1)));
-  } else if (I->getOpcode() == Instruction::Shl) {
-    // (X<<C) -> X * (1 << C)
-    if (Constant *ShRHS = dyn_cast<Constant>(I->getOperand(1))) {
-      ShRHS = ConstantExpr::getShl(Result, ShRHS);
-      return ConstantExpr::getMul(GetFactor(I->getOperand(0)), ShRHS);
-    }
-  } else if (I->getOpcode() == Instruction::And) {
-    if (ConstantInt *RHS = dyn_cast<ConstantInt>(I->getOperand(1))) {
-      // X & 0xFFF0 is known to be a multiple of 16.
-      uint32_t Zeros = RHS->getValue().countTrailingZeros();
-      if (Zeros != V->getType()->getPrimitiveSizeInBits())// don't shift by "32"
-        return ConstantExpr::getShl(Result, 
-                                    ConstantInt::get(Result->getType(), Zeros));
-    }
-  } else if (CastInst *CI = dyn_cast<CastInst>(I)) {
-    // Only handle int->int casts.
-    if (!CI->isIntegerCast())
-      return Result;
-    Value *Op = CI->getOperand(0);
-    return ConstantExpr::getCast(CI->getOpcode(), GetFactor(Op), V->getType());
-  }    
-  return Result;
-}
-
 /// This function implements the transforms on rem instructions that work
 /// regardless of the kind of rem instruction it is (urem, srem, or frem). It 
 /// is used by the visitors to those instructions.
@@ -2901,9 +2950,13 @@
         if (Instruction *NV = FoldOpIntoPhi(I))
           return NV;
       }
-      // (X * C1) % C2 --> 0  iff  C1 % C2 == 0
-      if (ConstantExpr::getSRem(GetFactor(Op0I), RHS)->isNullValue())
-        return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+
+      // See if we can fold away this rem instruction.
+      uint32_t BitWidth = cast<IntegerType>(I.getType())->getBitWidth();
+      APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
+      if (SimplifyDemandedBits(&I, APInt::getAllOnesValue(BitWidth),
+                               KnownZero, KnownOne))
+        return &I;
     }
   }