[InstCombine] fix and enhance udiv/urem narrowing

There are 3 small independent changes here:

  1. Account for multiple uses in the pattern matching: avoid the transform if it increases the instruction count.
  2. Add a missing fold for the case where the numerator is a constant (Alive proof: http://rise4fun.com/Alive/E2p).
  3. Enable all folds for vector types.

One more potential change remains: use shouldChangeType() to avoid narrowing to an illegal integer type.

Differential Revision: https://reviews.llvm.org/D36988

llvm-svn: 311726
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index e3a5022..c99f757 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -969,19 +969,6 @@
   return nullptr;
 }
 
-/// dyn_castZExtVal - Checks if V is a zext or constant that can
-/// be truncated to Ty without losing bits.
-static Value *dyn_castZExtVal(Value *V, Type *Ty) {
-  if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) {
-    if (Z->getSrcTy() == Ty)
-      return Z->getOperand(0);
-  } else if (ConstantInt *C = dyn_cast<ConstantInt>(V)) {
-    if (C->getValue().getActiveBits() <= cast<IntegerType>(Ty)->getBitWidth())
-      return ConstantExpr::getTrunc(C, Ty);
-  }
-  return nullptr;
-}
-
 namespace {
 const unsigned MaxDepth = 6;
 typedef Instruction *(*FoldUDivOperandCb)(Value *Op0, Value *Op1,
@@ -1095,6 +1082,43 @@
   return 0;
 }
 
+/// Narrow a udiv/urem whose operands are zexts (or a zext and a constant that
+/// fits the narrow type) by sinking the zext below the math; else return null.
+static Instruction *narrowUDivURem(BinaryOperator &I,
+                                   InstCombiner::BuilderTy &Builder) {
+  Instruction::BinaryOps Opcode = I.getOpcode();
+  Value *N = I.getOperand(0);
+  Value *D = I.getOperand(1);
+  Type *Ty = I.getType();
+  Value *X, *Y;
+  if (match(N, m_ZExt(m_Value(X))) && match(D, m_ZExt(m_Value(Y))) &&
+      X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) {
+    // udiv (zext X), (zext Y) --> zext (udiv X, Y)
+    // urem (zext X), (zext Y) --> zext (urem X, Y)
+    Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y);
+    return new ZExtInst(NarrowOp, Ty);
+  }
+
+  Constant *C;
+  if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) ||
+      (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) {
+    // If the constant is the same in the smaller type, use the narrow version.
+    Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
+    if (ConstantExpr::getZExt(TruncC, Ty) != C)
+      return nullptr;
+
+    // udiv/urem (zext X), C --> zext (udiv/urem X, C')
+    // udiv/urem C, (zext X) --> zext (udiv/urem C', X)
+    // C is the constant operand, so C == D iff N is the zext (N != D since the
+    // zext has one use). isa<Constant>(D) would misfire on a zext constexpr D.
+    Value *NarrowOp = C == D ? Builder.CreateBinOp(Opcode, X, TruncC)
+                             : Builder.CreateBinOp(Opcode, TruncC, X);
+    return new ZExtInst(NarrowOp, Ty);
+  }
+
+  return nullptr;
+}
+
 Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
 
@@ -1127,12 +1151,8 @@
     }
   }
 
-  // (zext A) udiv (zext B) --> zext (A udiv B)
-  if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0))
-    if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy()))
-      return new ZExtInst(
-          Builder.CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()),
-          I.getType());
+  if (Instruction *NarrowDiv = narrowUDivURem(I, Builder))
+    return NarrowDiv;
 
   // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...))))
   SmallVector<UDivFoldAction, 6> UDivActions;
@@ -1477,11 +1497,8 @@
   if (Instruction *common = commonIRemTransforms(I))
     return common;
 
-  // (zext A) urem (zext B) --> zext (A urem B)
-  if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0))
-    if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy()))
-      return new ZExtInst(Builder.CreateURem(ZOp0->getOperand(0), ZOp1),
-                          I.getType());
+  if (Instruction *NarrowRem = narrowUDivURem(I, Builder))
+    return NarrowRem;
 
   // X urem Y -> X and Y-1, where Y is a power of 2,
   if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) {