[ScalarizeMaskedMemIntrin] Bitcast the mask to the scalar domain and use scalar bit tests for the branches for expandload/compressstore.

Same as what was done for gather/scatter/load/store in r367489.
Expandload/compressstore were delayed due to the lack of constant-mask
handling, which has since been fixed.

llvm-svn: 367738
diff --git a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
index 71bd0fe..5155826 100644
--- a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
@@ -634,6 +634,14 @@
     return;
   }
 
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
@@ -642,8 +650,14 @@
     //  br i1 %mask_1, label %cond.load, label %else
     //
 
-    Value *Predicate =
-        Builder.CreateExtractElement(Mask, Idx);
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+    }
 
     // Create "cond" block
     //
@@ -728,13 +742,28 @@
     return;
   }
 
+  // If the mask is not v1i1, use scalar bit test operations. This generates
+  // better results on X86 at least.
+  Value *SclrMask;
+  if (VectorWidth != 1) {
+    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
+    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
+  }
+
   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
     // Fill the "else" block, created in the previous iteration
     //
     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
     //  br i1 %mask_1, label %cond.store, label %else
     //
-    Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+    Value *Predicate;
+    if (VectorWidth != 1) {
+      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
+      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
+                                       Builder.getIntN(VectorWidth, 0));
+    } else {
+      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
+    }
 
     // Create "cond" block
     //