[InstCombine] Make folding (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 support splat vectors

This also uses decomposeBitTestICmp to decode the compare as a sign-bit test, replacing the hand-written matching of the predicate and compare constant.

Differential Revision: https://reviews.llvm.org/D36781

llvm-svn: 311044
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 4de2902..b5fa2dc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/Analysis/CmpInstAnalysis.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -694,27 +695,31 @@
   // FIXME: Type and constness constraints could be lifted, but we have to
   //        watch code size carefully. We should consider xor instead of
   //        sub/add when we decide to do that.
-  if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) {
-    if (TrueVal->getType() == Ty) {
-      if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) {
-        ConstantInt *C1 = nullptr, *C2 = nullptr;
-        if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) {
-          C1 = dyn_cast<ConstantInt>(TrueVal);
-          C2 = dyn_cast<ConstantInt>(FalseVal);
-        } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) {
-          C1 = dyn_cast<ConstantInt>(FalseVal);
-          C2 = dyn_cast<ConstantInt>(TrueVal);
-        }
-        if (C1 && C2) {
+  if (CmpLHS->getType()->isIntOrIntVectorTy() &&
+      CmpLHS->getType() == TrueVal->getType()) {
+    const APInt *C1, *C2;
+    if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) {
+      ICmpInst::Predicate Pred = ICI->getPredicate();
+      Value *X;
+      APInt Mask;
+      if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask)) {
+        if (Mask.isSignMask()) {
+          assert(X == CmpLHS && "Expected to use the compare input directly");
+          assert(ICmpInst::isEquality(Pred) && "Expected equality predicate");
+
+          if (Pred == ICmpInst::ICMP_NE)
+            std::swap(C1, C2);
+
           // This shift results in either -1 or 0.
-          Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1);
+          Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1);
 
           // Check if we can express the operation with a single or.
-          if (C2->isMinusOne())
-            return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1));
+          if (C2->isAllOnesValue())
+            return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1));
 
-          Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue());
-          return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1));
+          Value *And = Builder.CreateAnd(AShr, *C2 - *C1);
+          return replaceInstUsesWith(SI, Builder.CreateAdd(And,
+                                        ConstantInt::get(And->getType(), *C1)));
         }
       }
     }