[InstCombine] Add splat vector constant support to foldICmpAddOpConst.
Differential Revision: https://reviews.llvm.org/D50946
llvm-svn: 340231
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a90a46d..2926e26 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1079,19 +1079,20 @@
       ConstantInt::get(CmpTy, !CmpInst::isTrueWhenEqual(ICI.getPredicate())));
 }
 
-/// Fold "icmp pred (X+CI), X".
-Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI,
+/// Fold "icmp pred (X+C), X".
+Instruction *InstCombiner::foldICmpAddOpConst(Value *X, const APInt &C,
                                               ICmpInst::Predicate Pred) {
   // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
   // so the values can never be equal.  Similarly for all other "or equals"
   // operators.
+  assert(!!C && "C should not be zero!");
 
   // (X+1) <u X        --> X >u (MAXUINT-1)        --> X == 255
   // (X+2) <u X        --> X >u (MAXUINT-2)        --> X > 253
   // (X+MAXUINT) <u X  --> X >u (MAXUINT-MAXUINT)  --> X != 0
   if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) {
-    Value *R =
-      ConstantExpr::getSub(ConstantInt::getAllOnesValue(CI->getType()), CI);
+    Constant *R = ConstantInt::get(X->getType(),
+                                   APInt::getMaxValue(C.getBitWidth()) - C);
     return new ICmpInst(ICmpInst::ICMP_UGT, X, R);
   }
 
@@ -1099,11 +1100,10 @@
   // (X+2) >u X        --> X <u (0-2)        --> X <u 254
   // (X+MAXUINT) >u X  --> X <u (0-MAXUINT)  --> X <u 1  --> X == 0
   if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE)
-    return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantExpr::getNeg(CI));
+    return new ICmpInst(ICmpInst::ICMP_ULT, X,
+                        ConstantInt::get(X->getType(), -C));
 
-  unsigned BitWidth = CI->getType()->getPrimitiveSizeInBits();
-  ConstantInt *SMax = ConstantInt::get(X->getContext(),
-                                       APInt::getSignedMaxValue(BitWidth));
+  APInt SMax = APInt::getSignedMaxValue(C.getBitWidth());
 
   // (X+ 1) <s X       --> X >s (MAXSINT-1)          --> X == 127
   // (X+ 2) <s X       --> X >s (MAXSINT-2)          --> X >s 125
@@ -1112,7 +1112,8 @@
   // (X+ -2) <s X      --> X >s (MAXSINT- -2)        --> X >s 126
   // (X+ -1) <s X      --> X >s (MAXSINT- -1)        --> X != 127
   if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
-    return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantExpr::getSub(SMax, CI));
+    return new ICmpInst(ICmpInst::ICMP_SGT, X,
+                        ConstantInt::get(X->getType(), SMax - C));
 
   // (X+ 1) >s X       --> X <s (MAXSINT-(1-1))       --> X != 127
   // (X+ 2) >s X       --> X <s (MAXSINT-(2-1))       --> X <s 126
@@ -1122,8 +1123,8 @@
   // (X+ -1) >s X      --> X <s (MAXSINT-(-1-1))      --> X == -128
 
   assert(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE);
-  Constant *C = Builder.getInt(CI->getValue() - 1);
-  return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantExpr::getSub(SMax, C));
+  return new ICmpInst(ICmpInst::ICMP_SLT, X,
+                      ConstantInt::get(X->getType(), SMax - (C - 1)));
 }
 
 /// Handle "(icmp eq/ne (ashr/lshr AP2, A), AP1)" ->
@@ -4877,14 +4878,15 @@
           return ExtractValueInst::Create(ACXI, 1);
 
   {
-    Value *X; ConstantInt *Cst;
+    Value *X;
+    const APInt *C;
     // icmp X+Cst, X
-    if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X)
-      return foldICmpAddOpConst(X, Cst, I.getPredicate());
+    if (match(Op0, m_Add(m_Value(X), m_APInt(C))) && Op1 == X)
+      return foldICmpAddOpConst(X, *C, I.getPredicate());
 
     // icmp X, X+Cst
-    if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X)
-      return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate());
+    if (match(Op1, m_Add(m_Value(X), m_APInt(C))) && Op0 == X)
+      return foldICmpAddOpConst(X, *C, I.getSwappedPredicate());
   }
 
   if (I.getType()->isVectorTy())
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 58ef3d4..4d9a53c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -813,7 +813,7 @@
                                             ConstantInt *AndCst = nullptr);
   Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
                                     Constant *RHSC);
-  Instruction *foldICmpAddOpConst(Value *X, ConstantInt *CI,
+  Instruction *foldICmpAddOpConst(Value *X, const APInt &C,
                                   ICmpInst::Predicate Pred);
   Instruction *foldICmpWithCastAndCast(ICmpInst &ICI);
 
diff --git a/llvm/test/Transforms/InstCombine/icmp.ll b/llvm/test/Transforms/InstCombine/icmp.ll
index 64f060b..b3e8bd0 100644
--- a/llvm/test/Transforms/InstCombine/icmp.ll
+++ b/llvm/test/Transforms/InstCombine/icmp.ll
@@ -101,8 +101,7 @@
 
 define <2 x i1> @test7_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @test7_vec(
-; CHECK-NEXT:    [[A:%.*]] = add <2 x i32> [[X:%.*]], <i32 -1, i32 -1>
-; CHECK-NEXT:    [[B:%.*]] = icmp ult <2 x i32> [[A]], [[X]]
+; CHECK-NEXT:    [[B:%.*]] = icmp ne <2 x i32> [[X:%.*]], zeroinitializer
 ; CHECK-NEXT:    ret <2 x i1> [[B]]
 ;
   %a = add <2 x i32> %x, <i32 -1, i32 -1>
@@ -140,8 +139,7 @@
 
 define <2 x i1> @test9_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @test9_vec(
-; CHECK-NEXT:    [[A:%.*]] = add <2 x i32> [[X:%.*]], <i32 -2, i32 -2>
-; CHECK-NEXT:    [[B:%.*]] = icmp ult <2 x i32> [[A]], [[X]]
+; CHECK-NEXT:    [[B:%.*]] = icmp ugt <2 x i32> [[X:%.*]], <i32 1, i32 1>
 ; CHECK-NEXT:    ret <2 x i1> [[B]]
 ;
   %a = add <2 x i32> %x, <i32 -2, i32 -2>
@@ -149,6 +147,26 @@
   ret <2 x i1> %b
 }
 
+define i1 @test9b(i32 %x) {
+; CHECK-LABEL: @test9b(
+; CHECK-NEXT:    [[B:%.*]] = icmp ult i32 [[X:%.*]], 2
+; CHECK-NEXT:    ret i1 [[B]]
+;
+  %a = add i32 %x, -2
+  %b = icmp ugt i32 %a, %x
+  ret i1 %b
+}
+
+define <2 x i1> @test9b_vec(<2 x i32> %x) {
+; CHECK-LABEL: @test9b_vec(
+; CHECK-NEXT:    [[B:%.*]] = icmp ult <2 x i32> [[X:%.*]], <i32 2, i32 2>
+; CHECK-NEXT:    ret <2 x i1> [[B]]
+;
+  %a = add <2 x i32> %x, <i32 -2, i32 -2>
+  %b = icmp ugt <2 x i32> %a, %x
+  ret <2 x i1> %b
+}
+
 define i1 @test10(i32 %x) {
 ; CHECK-LABEL: @test10(
 ; CHECK-NEXT:    [[B:%.*]] = icmp ne i32 %x, -2147483648
@@ -161,8 +179,7 @@
 
 define <2 x i1> @test10_vec(<2 x i32> %x) {
 ; CHECK-LABEL: @test10_vec(
-; CHECK-NEXT:    [[A:%.*]] = add <2 x i32> [[X:%.*]], <i32 -1, i32 -1>
-; CHECK-NEXT:    [[B:%.*]] = icmp slt <2 x i32> [[A]], [[X]]
+; CHECK-NEXT:    [[B:%.*]] = icmp ne <2 x i32> [[X:%.*]], <i32 -2147483648, i32 -2147483648>
 ; CHECK-NEXT:    ret <2 x i1> [[B]]
 ;
   %a = add <2 x i32> %x, <i32 -1, i32 -1>
@@ -170,6 +187,26 @@
   ret <2 x i1> %b
 }
 
+define i1 @test10b(i32 %x) {
+; CHECK-LABEL: @test10b(
+; CHECK-NEXT:    [[B:%.*]] = icmp eq i32 [[X:%.*]], -2147483648
+; CHECK-NEXT:    ret i1 [[B]]
+;
+  %a = add i32 %x, -1
+  %b = icmp sgt i32 %a, %x
+  ret i1 %b
+}
+
+define <2 x i1> @test10b_vec(<2 x i32> %x) {
+; CHECK-LABEL: @test10b_vec(
+; CHECK-NEXT:    [[B:%.*]] = icmp eq <2 x i32> [[X:%.*]], <i32 -2147483648, i32 -2147483648>
+; CHECK-NEXT:    ret <2 x i1> [[B]]
+;
+  %a = add <2 x i32> %x, <i32 -1, i32 -1>
+  %b = icmp sgt <2 x i32> %a, %x
+  ret <2 x i1> %b
+}
+
 define i1 @test11(i32 %x) {
 ; CHECK-LABEL: @test11(
 ; CHECK-NEXT:    ret i1 true