[InstCombine] allow zext(bool) + C --> select bool, C+1, C for vector types

The backend should be prepared to handle this transform for vector types after:
https://reviews.llvm.org/rL311731

llvm-svn: 315701
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index bcd60bc..e12bff0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -953,6 +953,19 @@
 static Instruction *foldAddWithConstant(BinaryOperator &Add,
                                         InstCombiner::BuilderTy &Builder) {
   Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
+  Constant *Op1C;
+  if (!match(Op1, m_Constant(Op1C)))
+    return nullptr;
+
+  Value *X;
+  Type *Ty = Add.getType();
+  if (match(Op0, m_ZExt(m_Value(X))) &&
+      X->getType()->getScalarSizeInBits() == 1) {
+    // zext(bool) + C -> bool ? C + 1 : C
+    Constant *One = ConstantInt::get(Ty, 1);
+    return SelectInst::Create(X, ConstantExpr::getAdd(Op1C, One), Op1);
+  }
+
   const APInt *C;
   if (!match(Op1, m_APInt(C)))
     return nullptr;
@@ -968,12 +981,9 @@
     return BinaryOperator::CreateXor(Op0, Op1);
   }
 
-  Value *X;
-  const APInt *C2;
-  Type *Ty = Add.getType();
-
   // Is this add the last step in a convoluted sext?
   // add(zext(xor i16 X, -32768), -32768) --> sext X
+  const APInt *C2;
   if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
       C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
     return CastInst::Create(Instruction::SExt, X, Ty);
@@ -1031,13 +1041,8 @@
     return X;
 
   // FIXME: This should be moved into the above helper function to allow these
-  // transforms for splat vectors.
+  // transforms for general constant or constant splat vectors.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
-    // zext(bool) + C -> bool ? C + 1 : C
-    if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
-      if (ZI->getSrcTy()->isIntegerTy(1))
-        return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI);
-
     Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr;
     if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) {
       uint32_t TySizeBits = I.getType()->getScalarSizeInBits();