[InstCombine] Allow values with multiple users in SimplifyDemandedVectorElts

Summary:
Allow SimplifyDemandedVectorElts to skip its single-use check so that
operands can be simplified when DemandedElts is known to contain the
union of the elements used by all users.
It is the responsibility of the caller of SimplifyDemandedVectorElts to
supply the correct DemandedElts.

Simplify a series of extractelement instructions if only a subset of
elements is used.

Reviewers: reames, arsenm, majnemer, nhaehnle

Reviewed By: nhaehnle

Subscribers: wdng, jvesely, nhaehnle, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D67345

llvm-svn: 375395
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 13c38ca..9c89074 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -253,6 +253,69 @@
   return nullptr;
 }
 
+/// Return the set of elements of vector V demanded by UserInstr, as a bitmask.
+static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
+  unsigned VWidth = V->getType()->getVectorNumElements();
+  // Default for users this function does not analyze (anything other than
+  // extractelement and shufflevector): assume that all elements are needed.
+  APInt UsedElts(APInt::getAllOnesValue(VWidth));
+  // Narrow the mask for the kinds of user that are understood.
+  switch (UserInstr->getOpcode()) {
+  case Instruction::ExtractElement: {
+    ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr);
+    assert(EEI->getVectorOperand() == V);
+    ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand());
+    if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) {
+      UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue());
+    } // else: non-constant or out-of-range index keeps the all-ones default
+    break;
+  }
+  case Instruction::ShuffleVector: {
+    ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr);
+    unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements();
+    // Start empty; set a bit for each mask entry that selects from V.
+    UsedElts = APInt(VWidth, 0);
+    for (unsigned i = 0; i < MaskNumElts; i++) {
+      unsigned MaskVal = Shuffle->getMaskValue(i);
+      if (MaskVal == -1u || MaskVal >= 2 * VWidth)
+        continue; // not an index into either operand; contributes nothing
+      if (Shuffle->getOperand(0) == V && (MaskVal < VWidth))
+        UsedElts.setBit(MaskVal); // [0, VWidth): operand 0
+      if (Shuffle->getOperand(1) == V &&
+          ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth)))
+        UsedElts.setBit(MaskVal - VWidth); // [VWidth, 2*VWidth): operand 1
+    }
+    break;
+  }
+  default:
+    break; // unknown user: keep the all-ones default
+  }
+  return UsedElts;
+}
+
+/// Return the union, over every use of V, of the element masks reported by
+/// findDemandedEltsBySingleUser. A bit stays clear only when no user was
+/// found to demand that element; a non-instruction user makes every element
+/// demanded. V with no uses yields an empty mask.
+static APInt findDemandedEltsByAllUsers(Value *V) {
+  unsigned VWidth = V->getType()->getVectorNumElements();
+  // Start with nothing demanded and accumulate one user at a time.
+  APInt UnionUsedElts(VWidth, 0);
+  for (const Use &U : V->uses()) {
+    if (Instruction *I = dyn_cast<Instruction>(U.getUser())) {
+      UnionUsedElts |= findDemandedEltsBySingleUser(V, I);
+    } else { // non-instruction user: conservatively demand everything
+      UnionUsedElts = APInt::getAllOnesValue(VWidth);
+      break;
+    }
+    // Every element is already demanded; further users cannot add anything.
+    if (UnionUsedElts.isAllOnesValue())
+      break;
+  }
+
+  return UnionUsedElts;
+}
+
 Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
   Value *SrcVec = EI.getVectorOperand();
   Value *Index = EI.getIndexOperand();
@@ -271,19 +334,35 @@
       return nullptr;
 
     // This instruction only demands the single element from the input vector.
-    // If the input vector has a single use, simplify it based on this use
-    // property.
-    if (SrcVec->hasOneUse() && NumElts != 1) {
-      APInt UndefElts(NumElts, 0);
-      APInt DemandedElts(NumElts, 0);
-      DemandedElts.setBit(IndexC->getZExtValue());
-      if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
-                                                UndefElts)) {
-        EI.setOperand(0, V);
-        return &EI;
+    if (NumElts != 1) {
+      // If the input vector has a single use, simplify it based on this use
+      // property.
+      if (SrcVec->hasOneUse()) {
+        APInt UndefElts(NumElts, 0);
+        APInt DemandedElts(NumElts, 0);
+        DemandedElts.setBit(IndexC->getZExtValue());
+        if (Value *V =
+                SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) {
+          EI.setOperand(0, V);
+          return &EI;
+        }
+      } else {
+        // If the input vector has multiple uses, simplify it based on a union
+        // of all elements used.
+        APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
+        if (!DemandedElts.isAllOnesValue()) {
+          APInt UndefElts(NumElts, 0);
+          if (Value *V = SimplifyDemandedVectorElts(
+                  SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
+                  true /* AllowMultipleUsers */)) {
+            if (V != SrcVec) {
+              SrcVec->replaceAllUsesWith(V);
+              return &EI;
+            }
+          }
+        }
       }
     }
-
     if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
       return I;