[InstCombine] try to convert x86 movmsk intrinsic to generic IR (PR39927)

call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM

This can create scalar types narrower than 8 bits (when N < 8), as shown in
some of the test diffs, but the backend appears to handle those types
correctly in these patterns. This is the simple part of the fix suggested in:
https://bugs.llvm.org/show_bug.cgi?id=39927

Differential Revision: https://reviews.llvm.org/D55529

llvm-svn: 348862
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index a023b6d..ae158ae 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -736,7 +736,8 @@
   return Builder.CreateInsertElement(Dst, Res, (uint64_t)0);
 }
 
-static Value *simplifyX86movmsk(const IntrinsicInst &II) {
+static Value *simplifyX86movmsk(const IntrinsicInst &II,
+                                InstCombiner::BuilderTy &Builder) {
   Value *Arg = II.getArgOperand(0);
   Type *ResTy = II.getType();
   Type *ArgTy = Arg->getType();
@@ -749,29 +750,46 @@
   if (!ArgTy->isVectorTy())
     return nullptr;
 
-  auto *C = dyn_cast<Constant>(Arg);
-  if (!C)
-    return nullptr;
+  if (auto *C = dyn_cast<Constant>(Arg)) {
+    // Extract signbits of the vector input and pack into integer result.
+    APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
+    for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
+      auto *COp = C->getAggregateElement(I);
+      if (!COp)
+        return nullptr;
+      if (isa<UndefValue>(COp))
+        continue;
 
-  // Extract signbits of the vector input and pack into integer result.
-  APInt Result(ResTy->getPrimitiveSizeInBits(), 0);
-  for (unsigned I = 0, E = ArgTy->getVectorNumElements(); I != E; ++I) {
-    auto *COp = C->getAggregateElement(I);
-    if (!COp)
-      return nullptr;
-    if (isa<UndefValue>(COp))
-      continue;
+      auto *CInt = dyn_cast<ConstantInt>(COp);
+      auto *CFp = dyn_cast<ConstantFP>(COp);
+      if (!CInt && !CFp)
+        return nullptr;
 
-    auto *CInt = dyn_cast<ConstantInt>(COp);
-    auto *CFp = dyn_cast<ConstantFP>(COp);
-    if (!CInt && !CFp)
-      return nullptr;
-
-    if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
-      Result.setBit(I);
+      if ((CInt && CInt->isNegative()) || (CFp && CFp->isNegative()))
+        Result.setBit(I);
+    }
+    return Constant::getIntegerValue(ResTy, Result);
   }
 
-  return Constant::getIntegerValue(ResTy, Result);
+  // Look for a sign-extended boolean source vector as the argument to this
+  // movmsk. If the argument is bitcast, look through that, but make sure the
+  // source of that bitcast is still a vector with the same number of elements.
+  // TODO: We can also convert a bitcast with wider elements, but that requires
+  // duplicating the bool source sign bits to match the number of elements
+  // expected by the movmsk call.
+  Arg = peekThroughBitcast(Arg);
+  Value *X;
+  if (Arg->getType()->isVectorTy() &&
+      Arg->getType()->getVectorNumElements() == ArgTy->getVectorNumElements() &&
+      match(Arg, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+    // call iM movmsk(sext <N x i1> X) --> zext (bitcast <N x i1> X to iN) to iM
+    unsigned NumElts = X->getType()->getVectorNumElements();
+    Type *ScalarTy = Type::getIntNTy(Arg->getContext(), NumElts);
+    Value *BC = Builder.CreateBitCast(X, ScalarTy);
+    return Builder.CreateZExtOrTrunc(BC, ResTy);
+  }
+
+  return nullptr;
 }
 
 static Value *simplifyX86insertps(const IntrinsicInst &II,
@@ -2543,7 +2561,7 @@
   case Intrinsic::x86_avx_movmsk_pd_256:
   case Intrinsic::x86_avx_movmsk_ps_256:
   case Intrinsic::x86_avx2_pmovmskb:
-    if (Value *V = simplifyX86movmsk(*II))
+    if (Value *V = simplifyX86movmsk(*II, Builder))
       return replaceInstUsesWith(*II, V);
     break;