[InstCombine] narrow rotate left/right patterns to eliminate zext/trunc (PR34046)

I couldn't find any smaller folds to help the cases in:
https://bugs.llvm.org/show_bug.cgi?id=34046
after:
rL310141

The truncated rotate-by-variable patterns elude all of the existing transforms because
of multiple uses, and because the demanded-bits and known-bits facts they rely on are
not available without the whole pattern. So we need an unfortunately large pattern match.
But by simplifying this pattern in IR, the backend is already able to generate
rolb/rolw/rorb/rorw for x86 using its existing rotate matching logic (although
there is likely an extraneous 'and' of the rotate amount).

Note that rotate-by-constant doesn't have this problem - smaller folds should already 
produce the narrow IR ops.

Differential Revision: https://reviews.llvm.org/D36395

llvm-svn: 310509
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 3139cb33..a3f56bc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -443,13 +443,81 @@
   return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt));
 }
 
+/// Rotate left/right may occur in a wider type than necessary because of type
+/// promotion rules. Try to narrow all of the component instructions.
+/// \returns the narrow 'or' replacing \p Trunc, or nullptr if nothing matched.
+Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) {
+  assert((isa<VectorType>(Trunc.getSrcTy()) ||
+          shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) &&
+         "Don't narrow to an illegal scalar type");
+
+  // The narrow shift amounts are masked with (NarrowWidth - 1) below, which is
+  // a modulo-NarrowWidth reduction only for power-of-2 widths. Vector types
+  // (e.g. <2 x i7>) reach here without the shouldChangeType() legality check.
+  Type *DestTy = Trunc.getType();
+  unsigned NarrowWidth = DestTy->getScalarSizeInBits();
+  if (!isPowerOf2_32(NarrowWidth))
+    return nullptr;
+
+  // First, find an or'd pair of opposite shifts with the same shifted operand:
+  // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1))
+  Value *Or0, *Or1;
+  if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1)))))
+    return nullptr;
+
+  Value *ShVal, *ShAmt0, *ShAmt1;
+  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) ||
+      !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1)))))
+    return nullptr;
+
+  auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode();
+  auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode();
+  if (ShiftOpcode0 == ShiftOpcode1)
+    return nullptr;
+
+  // The shift amounts must add up to the narrow bit width.
+  bool SubIsOnLHS = match(
+      ShAmt0, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))));
+  if (!SubIsOnLHS && !match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth),
+                                                   m_Specific(ShAmt0)))))
+    return nullptr;
+  Value *ShAmt = SubIsOnLHS ? ShAmt1 : ShAmt0;
+
+  // The shifted value must have high zeros in the wide type. Typically, this
+  // will be a zext, but it could also be the result of an 'and' or 'shift'.
+  unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits();
+  APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth);
+  if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc))
+    return nullptr;
+
+  // We have an unnecessarily wide rotate!
+  // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt))
+  // Narrow it down to eliminate the zext/trunc:
+  // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1')
+  Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy);
+  Value *NegShAmt = Builder.CreateNeg(NarrowShAmt);
+
+  // Mask both shift amounts to ensure there's no UB from oversized shifts.
+  Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1);
+  Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC);
+  Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC);
+
+  // Truncate the original value and use narrow ops.
+  Value *X = Builder.CreateTrunc(ShVal, DestTy);
+  Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt;
+  Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt;
+  Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0);
+  Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1);
+  return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1);
+}
+
 /// Try to narrow the width of math or bitwise logic instructions by pulling a
 /// truncate ahead of binary operators.
 /// TODO: Transforms for truncated shifts should be moved into here.
 Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) {
   Type *SrcTy = Trunc.getSrcTy();
   Type *DestTy = Trunc.getType();
-  if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
+  if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
     return nullptr;
 
   BinaryOperator *BinOp;
@@ -485,6 +553,9 @@
   default: break;
   }
 
+  if (Instruction *NarrowOr = narrowRotate(Trunc))
+    return NarrowOr;
+
   return nullptr;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index f550dd5..e6d5d1c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -440,6 +440,7 @@
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
   Instruction *narrowBinOp(TruncInst &Trunc);
+  Instruction *narrowRotate(TruncInst &Trunc);
   Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
 
   /// Determine if a pair of casts can be replaced by a single cast.