Enhance a number of InstCombine transformations so that they generate
exact/nsw/nuw shifts, and have InstCombine infer these flags when it can
prove that the relevant properties hold for a shift that lacks them.

Also includes a variety of refactorings to use the new PatternMatch logic.
I believe this takes care of a number of related code-quality issues
attached to PR8862.



git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@125267 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 559788b..d1a1fd6 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -19,11 +19,6 @@
 using namespace llvm;
 using namespace PatternMatch;
 
-/// SubOne - Subtract one from a ConstantInt.
-static Constant *SubOne(ConstantInt *C) {
-  return ConstantInt::get(C->getContext(), C->getValue()-1);
-}
-
 /// MultiplyOverflows - True if the multiply can not be expressed in an int
 /// this size.
 static bool MultiplyOverflows(ConstantInt *C1, ConstantInt *C2, bool sign) {
@@ -57,52 +52,39 @@
   if (Value *V = SimplifyUsingDistributiveLaws(I))
     return ReplaceInstUsesWith(I, V);
 
-  // Simplify mul instructions with a constant RHS.
-  if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1C)) {
-
-      // ((X << C1)*C2) == (X * (C2 << C1))
-      if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op0))
-        if (SI->getOpcode() == Instruction::Shl)
-          if (Constant *ShOp = dyn_cast<Constant>(SI->getOperand(1)))
-            return BinaryOperator::CreateMul(SI->getOperand(0),
-                                        ConstantExpr::getShl(CI, ShOp));
-
-      if (CI->isAllOnesValue())              // X * -1 == 0 - X
-        return BinaryOperator::CreateNeg(Op0, I.getName());
-
-      const APInt& Val = cast<ConstantInt>(CI)->getValue();
-      if (Val.isPowerOf2()) {          // Replace X*(2^C) with X << C
-        return BinaryOperator::CreateShl(Op0,
-                 ConstantInt::get(Op0->getType(), Val.logBase2()));
-      }
-    } else if (Op1C->getType()->isVectorTy()) {
-      if (Op1C->isNullValue())
-        return ReplaceInstUsesWith(I, Op1C);
-
-      if (ConstantVector *Op1V = dyn_cast<ConstantVector>(Op1C)) {
-        if (Op1V->isAllOnesValue())              // X * -1 == 0 - X
-          return BinaryOperator::CreateNeg(Op0, I.getName());
-
-        // As above, vector X*splat(1.0) -> X in all defined cases.
-        if (Constant *Splat = Op1V->getSplatValue()) {
-          if (ConstantInt *CI = dyn_cast<ConstantInt>(Splat))
-            if (CI->equalsInt(1))
-              return ReplaceInstUsesWith(I, Op0);
-        }
-      }
+  if (match(Op1, m_AllOnes()))  // X * -1 == 0 - X
+    return BinaryOperator::CreateNeg(Op0, I.getName());
+
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
+    // ((X << C1)*C2) == (X * (C2 << C1))
+    if (BinaryOperator *SI = dyn_cast<BinaryOperator>(Op0))
+      if (SI->getOpcode() == Instruction::Shl)
+        if (Constant *ShOp = dyn_cast<Constant>(SI->getOperand(1)))
+          return BinaryOperator::CreateMul(SI->getOperand(0),
+                                           ConstantExpr::getShl(CI, ShOp));
+
+    const APInt &Val = CI->getValue();
+    if (Val.isPowerOf2()) {          // Replace X*(2^C) with X << C
+      Constant *NewCst = ConstantInt::get(Op0->getType(), Val.logBase2());
+      BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, NewCst);
+      // mul nsw 1, INT_MIN is defined but shl nsw 1, BW-1 is poison.
+      if (I.hasNoSignedWrap() && !Val.isSignBit()) Shl->setHasNoSignedWrap();
+      if (I.hasNoUnsignedWrap()) Shl->setHasNoUnsignedWrap();
+      return Shl;
     }
     
-    if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0))
-      if (Op0I->getOpcode() == Instruction::Add && Op0I->hasOneUse() &&
-          isa<ConstantInt>(Op0I->getOperand(1)) && isa<ConstantInt>(Op1C)) {
-        // Canonicalize (X+C1)*C2 -> X*C2+C1*C2.
-        Value *Add = Builder->CreateMul(Op0I->getOperand(0), Op1C, "tmp");
-        Value *C1C2 = Builder->CreateMul(Op1C, Op0I->getOperand(1));
-        return BinaryOperator::CreateAdd(Add, C1C2);
-        
+    // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
+    { Value *X; ConstantInt *C1;
+      if (Op0->hasOneUse() &&
+          match(Op0, m_Add(m_Value(X), m_ConstantInt(C1)))) {
+        Value *Add = Builder->CreateMul(X, CI, "tmp");
+        return BinaryOperator::CreateAdd(Add, Builder->CreateMul(C1, CI));
       }
-
+    }
+  }
+
+  // Simplify mul instructions with a constant RHS.
+  if (isa<Constant>(Op1)) {
     // Try to fold constant mul into select arguments.
     if (SelectInst *SI = dyn_cast<SelectInst>(Op0))
       if (Instruction *R = FoldOpIntoSelect(I, SI))
@@ -324,9 +306,8 @@
           if (MultiplyOverflows(RHS, LHSRHS,
                                 I.getOpcode()==Instruction::SDiv))
             return ReplaceInstUsesWith(I, Constant::getNullValue(I.getType()));
-          else 
-            return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0),
-                                      ConstantExpr::getMul(RHS, LHSRHS));
+          return BinaryOperator::Create(I.getOpcode(), LHS->getOperand(0),
+                                        ConstantExpr::getMul(RHS, LHSRHS));
         }
 
     if (!RHS->isZero()) { // avoid X udiv 0
@@ -365,54 +346,50 @@
     // X udiv 2^C -> X >> C
     // Check to see if this is an unsigned division with an exact power of 2,
     // if so, convert to a right shift.
-    if (C->getValue().isPowerOf2())  // 0 not included in isPowerOf2
-      return BinaryOperator::CreateLShr(Op0, 
+    if (C->getValue().isPowerOf2()) { // 0 not included in isPowerOf2
+      BinaryOperator *LShr =
+        BinaryOperator::CreateLShr(Op0, 
             ConstantInt::get(Op0->getType(), C->getValue().logBase2()));
+      if (I.isExact()) LShr->setIsExact();  // exact udiv -> exact lshr
+      return LShr;
+    }
 
     // X udiv C, where C >= signbit
     if (C->getValue().isNegative()) {
-      Value *IC = Builder->CreateICmpULT( Op0, C);
+      Value *IC = Builder->CreateICmpULT(Op0, C);
       return SelectInst::Create(IC, Constant::getNullValue(I.getType()),
                                 ConstantInt::get(I.getType(), 1));
     }
   }
 
   // X udiv (C1 << N), where C1 is "1<<C2"  -->  X >> (N+C2)
-  if (BinaryOperator *RHSI = dyn_cast<BinaryOperator>(I.getOperand(1))) {
-    if (RHSI->getOpcode() == Instruction::Shl &&
-        isa<ConstantInt>(RHSI->getOperand(0))) {
-      const APInt& C1 = cast<ConstantInt>(RHSI->getOperand(0))->getValue();
-      if (C1.isPowerOf2()) {
-        Value *N = RHSI->getOperand(1);
-        const Type *NTy = N->getType();
-        if (uint32_t C2 = C1.logBase2())
-          N = Builder->CreateAdd(N, ConstantInt::get(NTy, C2), "tmp");
-        return BinaryOperator::CreateLShr(Op0, N);
-      }
+  { const APInt *CI; Value *N;
+    if (match(Op1, m_Shl(m_Power2(CI), m_Value(N)))) {
+      if (*CI != 1)  // log2(1) == 0, so N needs no adjustment
+        N = Builder->CreateAdd(N, ConstantInt::get(I.getType(), CI->logBase2()),
+                               "tmp");
+      if (I.isExact())
+        return BinaryOperator::CreateExactLShr(Op0, N);
+      return BinaryOperator::CreateLShr(Op0, N);
     }
   }
   
   // udiv X, (Select Cond, C1, C2) --> Select Cond, (shr X, C1), (shr X, C2)
   // where C1&C2 are powers of two.
-  if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) 
-    if (ConstantInt *STO = dyn_cast<ConstantInt>(SI->getOperand(1)))
-      if (ConstantInt *SFO = dyn_cast<ConstantInt>(SI->getOperand(2)))  {
-        const APInt &TVA = STO->getValue(), &FVA = SFO->getValue();
-        if (TVA.isPowerOf2() && FVA.isPowerOf2()) {
-          // Compute the shift amounts
-          uint32_t TSA = TVA.logBase2(), FSA = FVA.logBase2();
-          // Construct the "on true" case of the select
-          Constant *TC = ConstantInt::get(Op0->getType(), TSA);
-          Value *TSI = Builder->CreateLShr(Op0, TC, SI->getName()+".t");
+  { Value *Cond; const APInt *C1, *C2;
+    if (match(Op1, m_Select(m_Value(Cond), m_Power2(C1), m_Power2(C2)))) {
+      // "On true" case of the select: X >> log2(C1), exact if the udiv is.
+      Value *TSI = Builder->CreateLShr(Op0, C1->logBase2(), Op1->getName()+".t",
+                                       I.isExact());
  
-          // Construct the "on false" case of the select
-          Constant *FC = ConstantInt::get(Op0->getType(), FSA); 
-          Value *FSI = Builder->CreateLShr(Op0, FC, SI->getName()+".f");
-
-          // construct the select instruction and return it.
-          return SelectInst::Create(SI->getOperand(0), TSI, FSI, SI->getName());
-        }
-      }
+      // "On false" case of the select: X >> log2(C2), exact if the udiv is.
+      Value *FSI = Builder->CreateLShr(Op0, C2->logBase2(), Op1->getName()+".f",
+                                       I.isExact());
+      
+      // Construct the select instruction and return it.
+      return SelectInst::Create(Cond, TSI, FSI);
+    }
+  }
   return 0;
 }
 
@@ -431,20 +408,17 @@
     if (RHS->isAllOnesValue())
       return BinaryOperator::CreateNeg(Op0);
 
-    // sdiv X, C  -->  ashr X, log2(C)
-    if (cast<SDivOperator>(&I)->isExact() &&
-        RHS->getValue().isNonNegative() &&
+    // sdiv exact X, 2^K  -->  ashr exact X, K  (only when 2^K is positive)
+    if (I.isExact() && RHS->getValue().isNonNegative() &&
         RHS->getValue().isPowerOf2()) {
       Value *ShAmt = llvm::ConstantInt::get(RHS->getType(),
                                             RHS->getValue().exactLogBase2());
-      return BinaryOperator::CreateAShr(Op0, ShAmt, I.getName());
+      return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName());
     }
 
     // -X/C  -->  X/-C  provided the negation doesn't overflow.
     if (SubOperator *Sub = dyn_cast<SubOperator>(Op0))
-      if (isa<Constant>(Sub->getOperand(0)) &&
-          cast<Constant>(Sub->getOperand(0))->isNullValue() &&
-          Sub->hasNoSignedWrap())
+      if (match(Sub->getOperand(0), m_Zero()) && Sub->hasNoSignedWrap())
         return BinaryOperator::CreateSDiv(Sub->getOperand(1),
                                           ConstantExpr::getNeg(RHS));
   }
@@ -458,9 +432,8 @@
         // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
         return BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
       }
-      ConstantInt *ShiftedInt;
-      if (match(Op1, m_Shl(m_ConstantInt(ShiftedInt), m_Value())) &&
-          ShiftedInt->getValue().isPowerOf2()) {
+      
+      if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
         // X sdiv (1 << Y) -> X udiv (1 << Y) ( -> X u>> Y)
         // Safe because the only negative value (1 << Y) can take on is
         // INT_MIN, and X sdiv INT_MIN == X udiv INT_MIN == 0 if X doesn't have
@@ -555,43 +528,30 @@
   if (Instruction *common = commonIRemTransforms(I))
     return common;
   
-  if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
-    // X urem C^2 -> X and C
-    // Check to see if this is an unsigned remainder with an exact power of 2,
-    // if so, convert to a bitwise and.
-    if (ConstantInt *C = dyn_cast<ConstantInt>(RHS))
-      if (C->getValue().isPowerOf2())
-        return BinaryOperator::CreateAnd(Op0, SubOne(C));
+  // X urem C -> X and C-1, when C is a power of 2
+  { const APInt *C;
+    if (match(Op1, m_Power2(C)))
+      return BinaryOperator::CreateAnd(Op0,
+                                       ConstantInt::get(I.getType(), *C-1));
   }
 
-  if (Instruction *RHSI = dyn_cast<Instruction>(I.getOperand(1))) {
-    // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)  
-    if (RHSI->getOpcode() == Instruction::Shl &&
-        isa<ConstantInt>(RHSI->getOperand(0))) {
-      if (cast<ConstantInt>(RHSI->getOperand(0))->getValue().isPowerOf2()) {
-        Constant *N1 = Constant::getAllOnesValue(I.getType());
-        Value *Add = Builder->CreateAdd(RHSI, N1, "tmp");
-        return BinaryOperator::CreateAnd(Op0, Add);
-      }
+  // Turn A % (C << N), where C is 2^k, into A & ((C << N)-1)
+  if (match(Op1, m_Shl(m_Power2(), m_Value()))) {
+    Constant *N1 = Constant::getAllOnesValue(I.getType());
+    Value *Add = Builder->CreateAdd(Op1, N1, "tmp");
+    return BinaryOperator::CreateAnd(Op0, Add);
+  }
+
+  // urem X, (select Cond, C1, C2) -->
+  //    select Cond, (and X, C1-1), (and X, C2-1)
+  // when C1&C2 are powers of two.
+  { Value *Cond; const APInt *C1, *C2;
+    if (match(Op1, m_Select(m_Value(Cond), m_Power2(C1), m_Power2(C2)))) {
+      Value *TrueAnd = Builder->CreateAnd(Op0, *C1-1, Op1->getName()+".t");
+      Value *FalseAnd = Builder->CreateAnd(Op0, *C2-1, Op1->getName()+".f");
+      return SelectInst::Create(Cond, TrueAnd, FalseAnd);
     }
   }
-
-  // urem X, (select Cond, 2^C1, 2^C2) --> select Cond, (and X, C1), (and X, C2)
-  // where C1&C2 are powers of two.
-  if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) {
-    if (ConstantInt *STO = dyn_cast<ConstantInt>(SI->getOperand(1)))
-      if (ConstantInt *SFO = dyn_cast<ConstantInt>(SI->getOperand(2))) {
-        // STO == 0 and SFO == 0 handled above.
-        if ((STO->getValue().isPowerOf2()) && 
-            (SFO->getValue().isPowerOf2())) {
-          Value *TrueAnd = Builder->CreateAnd(Op0, SubOne(STO),
-                                              SI->getName()+".t");
-          Value *FalseAnd = Builder->CreateAnd(Op0, SubOne(SFO),
-                                               SI->getName()+".f");
-          return SelectInst::Create(SI->getOperand(0), TrueAnd, FalseAnd);
-        }
-      }
-  }
   
   return 0;
 }