[InstCombine] Allow values with multiple users in SimplifyDemandedVectorElts
Summary:
Allow SimplifyDemandedVectorElts to skip its single-use check so that
operands can still be simplified when DemandedElts is known to contain
the union of the elements used by all users.
It is the responsibility of the caller of SimplifyDemandedVectorElts to
supply a correct DemandedElts.
Simplify a series of extractelement instructions if only a subset of
elements is used.
Reviewers: reames, arsenm, majnemer, nhaehnle
Reviewed By: nhaehnle
Subscribers: wdng, jvesely, nhaehnle, llvm-commits
Tags: #llvm
Differential Revision: https://reviews.llvm.org/D67345
llvm-svn: 375395
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 13c38ca..9c89074 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -253,6 +253,69 @@
return nullptr;
}
+/// Compute which elements of the vector value V are read by UserInstr.
+///
+/// Only two kinds of user are analysed: an extractelement whose index is a
+/// constant smaller than the element count of V, and a shufflevector. Every
+/// other user is conservatively reported as reading all elements.
+///
+/// \param V          The vector value whose elements are being queried.
+/// \param UserInstr  An instruction that uses V.
+/// \returns A bitmask with one bit per element of V; a set bit means the
+///          element may be read by UserInstr.
+static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) {
+  const unsigned NumElts = V->getType()->getVectorNumElements();
+
+  if (auto *Extract = dyn_cast<ExtractElementInst>(UserInstr)) {
+    assert(Extract->getVectorOperand() == V);
+    auto *IdxC = dyn_cast<ConstantInt>(Extract->getIndexOperand());
+    // A non-constant or out-of-range index gives no information, so every
+    // element stays demanded.
+    if (!IdxC || !IdxC->getValue().ult(NumElts))
+      return APInt::getAllOnesValue(NumElts);
+    return APInt::getOneBitSet(NumElts, IdxC->getZExtValue());
+  }
+
+  if (auto *Shuf = dyn_cast<ShuffleVectorInst>(UserInstr)) {
+    // V may appear as the first operand, the second operand, or both.
+    const bool IsFirstOp = Shuf->getOperand(0) == V;
+    const bool IsSecondOp = Shuf->getOperand(1) == V;
+    const unsigned MaskLen = UserInstr->getType()->getVectorNumElements();
+
+    APInt Demanded(NumElts, 0);
+    for (unsigned Lane = 0; Lane != MaskLen; ++Lane) {
+      const unsigned Sel = Shuf->getMaskValue(Lane);
+      // Ignore -1 entries (presumably undef lanes) and any selector that
+      // falls outside both operands.
+      if (Sel == -1u || Sel >= 2 * NumElts)
+        continue;
+      // Selectors in [0, NumElts) pick from the first operand, selectors in
+      // [NumElts, 2 * NumElts) pick from the second.
+      if (Sel < NumElts && IsFirstOp)
+        Demanded.setBit(Sel);
+      else if (Sel >= NumElts && IsSecondOp)
+        Demanded.setBit(Sel - NumElts);
+    }
+    return Demanded;
+  }
+
+  // Unknown kind of user: assume it needs everything.
+  return APInt::getAllOnesValue(NumElts);
+}
+
+/// Compute the union of the elements of V that its users may read.
+///
+/// Each instruction user is queried through findDemandedEltsBySingleUser and
+/// the results are OR-ed together. A bit is left clear only when no user was
+/// found to demand that element. A user that is not an instruction makes the
+/// whole vector demanded.
+///
+/// \param V  The vector value whose uses are inspected.
+/// \returns A bitmask with one bit per element of V.
+static APInt findDemandedEltsByAllUsers(Value *V) {
+  const unsigned NumElts = V->getType()->getVectorNumElements();
+
+  APInt Demanded(NumElts, 0);
+  for (const Use &TheUse : V->uses()) {
+    auto *UserInst = dyn_cast<Instruction>(TheUse.getUser());
+    // Non-instruction users are not analysed; treat them as reading
+    // every element.
+    if (!UserInst)
+      return APInt::getAllOnesValue(NumElts);
+
+    Demanded |= findDemandedEltsBySingleUser(V, UserInst);
+
+    // Once every element is demanded, further users cannot change the answer.
+    if (Demanded.isAllOnesValue())
+      return Demanded;
+  }
+
+  return Demanded;
+}
+
Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
Value *SrcVec = EI.getVectorOperand();
Value *Index = EI.getIndexOperand();
@@ -271,19 +334,35 @@
return nullptr;
// This instruction only demands the single element from the input vector.
- // If the input vector has a single use, simplify it based on this use
- // property.
- if (SrcVec->hasOneUse() && NumElts != 1) {
- APInt UndefElts(NumElts, 0);
- APInt DemandedElts(NumElts, 0);
- DemandedElts.setBit(IndexC->getZExtValue());
- if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts,
- UndefElts)) {
- EI.setOperand(0, V);
- return &EI;
+ if (NumElts != 1) {
+ // If the input vector has a single use, simplify it based on this use
+ // property.
+ if (SrcVec->hasOneUse()) {
+ APInt UndefElts(NumElts, 0);
+ APInt DemandedElts(NumElts, 0);
+ DemandedElts.setBit(IndexC->getZExtValue());
+ if (Value *V =
+ SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) {
+ EI.setOperand(0, V);
+ return &EI;
+ }
+ } else {
+ // If the input vector has multiple uses, simplify it based on a union
+ // of all elements used.
+ APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec);
+ if (!DemandedElts.isAllOnesValue()) {
+ APInt UndefElts(NumElts, 0);
+ if (Value *V = SimplifyDemandedVectorElts(
+ SrcVec, DemandedElts, UndefElts, 0 /* Depth */,
+ true /* AllowMultipleUsers */)) {
+ if (V != SrcVec) {
+ SrcVec->replaceAllUsesWith(V);
+ return &EI;
+ }
+ }
+ }
}
}
-
if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian()))
return I;