[InstCombine] Bypass high bit extract before variable sign-extension (PR43523)

https://rise4fun.com/Alive/8BY - valid for lshr+trunc+variable sext
https://rise4fun.com/Alive/7jk - the variable sext can be redundant

https://rise4fun.com/Alive/Qslu - 'exact'-ness of first shift can be preserved

https://rise4fun.com/Alive/IF63 - without trunc we could view this as
                                  a more general "drop redundant mask before right-shift",
                                  but let's handle it here for now
https://rise4fun.com/Alive/iip - likewise, without trunc, variable sext can be redundant.

There are more patterns for sure - e.g. we can have 'lshr' as the final shift,
but that might be best handled by some more generic transform, e.g.
"drop redundant masking before right-shift" (PR42456)

I'm singling out this sext patch because you can only extract
high bits with `*shr` (unlike abstract bit masking),
and I *know* this fold is wanted by existing code.

I don't believe there is much to review here,
so I'm going to opt into post-commit review.

https://bugs.llvm.org/show_bug.cgi?id=43523

llvm-svn: 373542
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 6730994..dcdbee1 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -351,6 +351,8 @@
   Instruction *visitOr(BinaryOperator &I);
   Instruction *visitXor(BinaryOperator &I);
   Instruction *visitShl(BinaryOperator &I);
+  Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract(
+      BinaryOperator &OldAShr); // OldAShr must be an ashr; null if no fold.
   Instruction *visitAShr(BinaryOperator &I);
   Instruction *visitLShr(BinaryOperator &I);
   Instruction *commonShiftTransforms(BinaryOperator &I);
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index bc4affb..9d96ddc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -1039,6 +1039,75 @@
   return nullptr;
 }
 
+/// Fold ashr(shl(trunc?(Shr(X, bw(X)-NBits)), bw-NBits), bw-NBits) into
+/// trunc?(ashr(X, bw(X)-NBits)); Shr is lshr/ashr, bw = the ashr's width.
+Instruction *
+InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
+    BinaryOperator &OldAShr) {
+  assert(OldAShr.getOpcode() == Instruction::AShr &&
+         "Must be called with arithmetic right-shift instruction only.");
+
+  // Check that constant C is a splat of the element-wise bitwidth of V. That
+  // width must fit in C's type, as APInt(CBits, VBits) silently truncates.
+  auto BitWidthSplat = [](Constant *C, Value *V) {
+    unsigned CBits = C->getType()->getScalarSizeInBits();
+    unsigned VBits = V->getType()->getScalarSizeInBits();
+    return isUIntN(CBits, VBits) &&
+           match(C, m_SpecificInt_ICMP(ICmpInst::ICMP_EQ, APInt(CBits, VBits)));
+  };
+
+  // It should look like variable-length sign-extension on the outside:
+  //   (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
+  Value *NBits;
+  Instruction *MaybeTrunc;
+  Constant *C1, *C2;
+  if (!match(&OldAShr,
+             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
+                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
+                                             m_ZExtOrSelf(m_Value(NBits))))),
+                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
+                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
+      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
+    return nullptr;
+
+  // There may or may not be a truncation between the inner and outer shifts.
+  Instruction *HighBitExtract;
+  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
+  bool HadTrunc = MaybeTrunc != HighBitExtract;
+
+  // And finally, the innermost part of the pattern must be a right-shift.
+  Value *X, *NumLowBitsToSkip;
+  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
+    return nullptr;
+
+  // Said right-shift must extract high NBits bits - C0 must be its bitwidth.
+  Constant *C0;
+  if (!match(NumLowBitsToSkip,
+             m_ZExtOrSelf(
+                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
+      !BitWidthSplat(C0, HighBitExtract))
+    return nullptr;
+
+  // NBits is shared by all shifts, so if the innermost shift is an ashr too,
+  // the outer two shifts are redundant. A truncation, if any, is kept.
+  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
+    return replaceInstUsesWith(OldAShr, MaybeTrunc);
+
+  // With a trunc we emit two instructions, so require a one-use operand.
+  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+    return nullptr;
+
+  // Bypass the two inner shifts: ashr the innermost shift's operands directly.
+  Instruction *NewAShr =
+      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
+  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
+  if (!HadTrunc)
+    return NewAShr;
+
+  Builder.Insert(NewAShr);
+  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
+}
+
 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
   if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                   SQ.getWithInstruction(&I)))
@@ -1113,6 +1182,9 @@
     }
   }
 
+  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
+    return R;
+
   // See if we can turn a signed shr into an unsigned shr.
   if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
     return BinaryOperator::CreateLShr(Op0, Op1);