[InstCombine] try to reduce shuffle with bitcasted operand
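shuf (bitcast X), undef, Mask --> bitcast X'
Here X' is the value that the equivalent shuffle of X (with the mask scaled
to X's element count) simplifies to.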
The 'inverse shuffles' test (shuf_bitcast_operand) covers a pattern
that appears in the motivating examples from PR35454:
https://bugs.llvm.org/show_bug.cgi?id=35454
(see also D76727)
We can deal with this class of patterns in generic instcombine
because we are not creating any new shuffles, just a bitcast.
Alive2 proof:
http://volta.cs.utah.edu:8080/z/mwDUZf
Differential Revision: https://reviews.llvm.org/D76844
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index 2d72696..1e95f23 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1889,9 +1889,9 @@
Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
Value *LHS = SVI.getOperand(0);
Value *RHS = SVI.getOperand(1);
- if (auto *V =
- SimplifyShuffleVectorInst(LHS, RHS, SVI.getShuffleMask(),
- SVI.getType(), SQ.getWithInstruction(&SVI)))
+ SimplifyQuery ShufQuery = SQ.getWithInstruction(&SVI);
+ if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getShuffleMask(),
+ SVI.getType(), ShufQuery))
return replaceInstUsesWith(SVI, V);
// shuffle x, x, mask --> shuffle x, undef, mask'
@@ -1899,6 +1899,32 @@
unsigned LHSWidth = LHS->getType()->getVectorNumElements();
ArrayRef<int> Mask = SVI.getShuffleMask();
Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
+
+ // Peek through a bitcasted shuffle operand by scaling the mask. If the
+ // simulated shuffle can simplify, then this shuffle is unnecessary:
+ // shuf (bitcast X), undef, Mask --> bitcast X'
+ // TODO: This could be extended to allow length-changing shuffles and/or casts
+ // to narrower elements. The transform might also be obsoleted if we
+ // allowed canonicalization of bitcasted shuffles.
+ Value *X;
+ if (match(LHS, m_BitCast(m_Value(X))) && match(RHS, m_Undef()) &&
+ X->getType()->isVectorTy() && VWidth == LHSWidth &&
+ X->getType()->getVectorNumElements() >= VWidth) {
+ // Create the scaled mask constant.
+ Type *XType = X->getType();
+ unsigned XNumElts = XType->getVectorNumElements();
+ assert(XNumElts % VWidth == 0 && "Unexpected vector bitcast");
+ unsigned ScaleFactor = XNumElts / VWidth;
+ SmallVector<int, 16> ScaledMask;
+ scaleShuffleMask(ScaleFactor, Mask, ScaledMask);
+
+ // If the shuffled source vector simplifies, cast that value to this
+ // shuffle's type.
+ if (auto *V = SimplifyShuffleVectorInst(X, UndefValue::get(XType),
+ ScaledMask, XType, ShufQuery))
+ return BitCastInst::Create(Instruction::BitCast, V, SVI.getType());
+ }
+
if (LHS == RHS) {
assert(!isa<UndefValue>(RHS) && "Shuffle with 2 undef ops not simplified?");
// Remap any references to RHS to use LHS.