Simplify operands of masked stores and scatters based on demanded elements

If we know a lane is not being stored, we don't need to compute the value for that lane. This could be improved by using the undef element result to further prune the mask, but I want to separate that into its own change since it's relatively likely to expose other problems.
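
For example (an illustrative sketch rather than a test from this patch; the value
names are invented), if the constant mask disables lane 1, the value operand no
longer demands that lane, so SimplifyDemandedVectorElts can strip the
insertelement feeding it:

  %v = insertelement <2 x double> %x, double %expensive, i32 1
  call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %v, <2 x double>* %p,
                                             i32 8, <2 x i1> <i1 true, i1 false>)

becomes

  call void @llvm.masked.store.v2f64.p0v2f64(<2 x double> %x, <2 x double>* %p,
                                             i32 8, <2 x i1> <i1 true, i1 false>)

and the computation of %expensive can then be dead-code eliminated if it has no
other uses.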

Differential Revision: https://reviews.llvm.org/D57247

llvm-svn: 356590
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 359f617..1646c0f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -25,6 +25,7 @@
 #include "llvm/Analysis/MemoryBuiltins.h"
 #include "llvm/Transforms/Utils/Local.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constant.h"
@@ -1175,6 +1176,20 @@
   return true;
 }
 
+/// Given a mask vector <Y x i1>, return an APInt (of bitwidth Y) with a bit
+/// set for each lane which may be active.  TODO: This is a lot like known
+/// bits, but for vectors.  Is there something we can common this with?
+static APInt possiblyDemandedEltsInMask(Value *Mask) {
+
+  const unsigned VWidth = cast<VectorType>(Mask->getType())->getNumElements();
+  APInt DemandedElts = APInt::getAllOnesValue(VWidth);
+  if (auto *CV = dyn_cast<ConstantVector>(Mask))
+    for (unsigned i = 0; i < VWidth; i++)
+      if (CV->getAggregateElement(i)->isNullValue())
+        DemandedElts.clearBit(i);
+  return DemandedElts;
+}
+
 // TODO, Obvious Missing Transforms:
 // * Dereferenceable address -> speculative load/select
 // * Narrow width by halfs excluding zero/undef lanes
@@ -1196,14 +1211,14 @@
 // * SimplifyDemandedVectorElts
 // * Single constant active lane -> store
 // * Narrow width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedStore(IntrinsicInst &II, InstCombiner &IC) {
+Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
   if (!ConstMask)
     return nullptr;
 
   // If the mask is all zeros, this instruction does nothing.
   if (ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+    return eraseInstFromFunction(II);
 
   // If the mask is all ones, this is a plain vector store of the 1st argument.
   if (ConstMask->isAllOnesValue()) {
@@ -1212,6 +1227,15 @@
     return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
   }
 
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+
   return nullptr;
 }
 
@@ -1268,11 +1292,28 @@
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-static Instruction *simplifyMaskedScatter(IntrinsicInst &II, InstCombiner &IC) {
-  // If the mask is all zeros, a scatter does nothing.
+Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
-  if (ConstMask && ConstMask->isNullValue())
-    return IC.eraseInstFromFunction(II);
+  if (!ConstMask)
+    return nullptr;
+
+  // If the mask is all zeros, a scatter does nothing.
+  if (ConstMask->isNullValue())
+    return eraseInstFromFunction(II);
+
+  // Use masked off lanes to simplify operands via SimplifyDemandedVectorElts
+  APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask);
+  APInt UndefElts(DemandedElts.getBitWidth(), 0);
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(0, V);
+    return &II;
+  }
+  if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1),
+                                            DemandedElts, UndefElts)) {
+    II.setOperand(1, V);
+    return &II;
+  }
 
   return nullptr;
 }
@@ -1972,11 +2013,11 @@
       return replaceInstUsesWith(CI, SimplifiedMaskedOp);
     break;
   case Intrinsic::masked_store:
-    return simplifyMaskedStore(*II, *this);
+    return simplifyMaskedStore(*II);
   case Intrinsic::masked_gather:
     return simplifyMaskedGather(*II, *this);
   case Intrinsic::masked_scatter:
-    return simplifyMaskedScatter(*II, *this);
+    return simplifyMaskedScatter(*II);
   case Intrinsic::launder_invariant_group:
   case Intrinsic::strip_invariant_group:
     if (auto *SkippedBarrier = simplifyInvariantGroupIntrinsic(*II, *this))