[InstCombine] Teach foldSelectICmpAndOr to handle vector splats

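foldSelectICmpAndOr was already close to handling vector splats: the remaining work was to relax the type check so it accepts integer vectors (requiring a vector compare for a vector select) and to compare scalar bit widths instead of integer bit widths. While here, pass the ICmpInst pointer in from the caller instead of doing a dyn_cast that can never fail.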

Differential Revision: https://reviews.llvm.org/D37237

llvm-svn: 311960
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 08aba88..c43dad5 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -309,11 +309,13 @@
 /// 1. The icmp predicate is inverted
 /// 2. The select operands are reversed
 /// 3. The magnitude of C2 and C1 are flipped
-static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
+static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal,
                                   Value *FalseVal,
                                   InstCombiner::BuilderTy &Builder) {
-  const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition());
-  if (!IC || !SI.getType()->isIntegerTy())
+  // Only handle selects of integer or integer-vector type. Also, if this is a
+  // vector select, we need a vector compare.
+  if (!TrueVal->getType()->isIntOrIntVectorTy() ||
+      TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy())
     return nullptr;
 
   Value *CmpLHS = IC->getOperand(0);
@@ -367,8 +369,8 @@
 
   bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal);
   bool NeedShift = C1Log != C2Log;
-  bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() !=
-                       V->getType()->getIntegerBitWidth();
+  bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() !=
+                       V->getType()->getScalarSizeInBits();
 
   // Make sure we don't create more instructions than we save.
   Value *Or = OrOnFalseVal ? FalseVal : TrueVal;
@@ -818,7 +820,7 @@
     }
   }
 
-  if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder))
+  if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder))
     return replaceInstUsesWith(SI, V);
 
   if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder))