[InstCombine] fold vector select of binops with constant ops to 1 binop (PR37806)

This is the simplest case from PR37806:
https://bugs.llvm.org/show_bug.cgi?id=37806

If we have a common variable operand used in a pair of binops with vector constants
that are vector selected together, then we can apply the shuffle to the constant
vectors at compile time and eliminate the shuffle instruction.

This has some tricky parts that are hopefully addressed in the tests and their 
respective comments:

  1. If the shuffle mask contains an undef element, then that lane of the result is 
     undef:
     http://llvm.org/docs/LangRef.html#shufflevector-instruction

     Therefore, we can replace the constant in that lane with an undef value except 
     for div/rem. With div/rem, an undef in the divisor would cause the whole op to 
     be undef. So I'm using the same hack as in D47686 - replace the undefs with '1'.

  2. Intersect the wrapping flags, 'exact', and FMF of the original binops for the
     new binop. There should be no extra poison or fast-math potential in the new
     binop that wasn't possible in the original code.

  3. Disregard other uses. Given that we're eliminating uses (shortening the 
     dependency chain), I think that's always the right IR canonicalization. But 
     I purposely chose the udiv test to demonstrate the scenario where both 
     intermediate values have other uses because that seems likely worse for 
     codegen with an expensive math op. This seems like a very rare possibility to 
     me, so I don't think it requires a backend patch first.

Differential Revision: https://reviews.llvm.org/D48401

llvm-svn: 335283
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index aeac891..d8546c7 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1140,6 +1140,54 @@
   return true;
 }
 
+/// Fold a select-shuffle of 2 same-opcode binops that share a variable operand
+/// and have constant operands into 1 binop. Returns nullptr if nothing folds.
+static Instruction *foldSelectShuffles(ShuffleVectorInst &Shuf) {
+  // Folds under here require the equivalent of a vector select.
+  if (!Shuf.isSelect())
+    return nullptr;
+
+  // TODO: There are potential folds where the opcodes do not match (mul+shl).
+  BinaryOperator *B0, *B1;
+  if (!match(Shuf.getOperand(0), m_BinOp(B0)) ||
+      !match(Shuf.getOperand(1), m_BinOp(B1)) ||
+      B0->getOpcode() != B1->getOpcode())
+    return nullptr;
+
+  // TODO: Fold different variable operands (needs a new shuffle + use checks).
+  Value *X;
+  Constant *C0, *C1;
+  if (!match(B0, m_c_BinOp(m_Value(X), m_Constant(C0))) ||
+      !match(B1, m_c_BinOp(m_Specific(X), m_Constant(C1))))
+    return nullptr;
+
+  // If all operands are constants, let constant folding remove the binops.
+  if (isa<Constant>(X))
+    return nullptr;
+
+  // The commutative matchers above also accept (op X, C0) with (op C1, X).
+  // Merging those is only valid when the operand order does not matter.
+  bool Op0IsConst = isa<Constant>(B0->getOperand(0));
+  if (!B0->isCommutative() && Op0IsConst != isa<Constant>(B1->getOperand(0)))
+    return nullptr;
+
+  // shuffle (op X, C0), (op X, C1), M --> op X, C'
+  // shuffle (op C0, X), (op C1, X), M --> op C', X
+  Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Shuf.getMask());
+
+  // Undef mask elements leave undef lanes in NewC; keep div/rem well-defined.
+  if (B0->isIntDivRem())
+    NewC = getSafeVectorConstantForIntDivRem(NewC);
+
+  // Build the merged binop; flags are intersected from the 2 source binops.
+  BinaryOperator::BinaryOps Opc = B0->getOpcode();
+  Instruction *NewBO = Op0IsConst ? BinaryOperator::Create(Opc, NewC, X) :
+                                    BinaryOperator::Create(Opc, X, NewC);
+  NewBO->copyIRFlags(B0);
+  NewBO->andIRFlags(B1);
+  return NewBO;
+}
+
 Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
   Value *LHS = SVI.getOperand(0);
   Value *RHS = SVI.getOperand(1);
@@ -1150,6 +1198,9 @@
           LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI)))
     return replaceInstUsesWith(SVI, V);
 
+  if (Instruction *I = foldSelectShuffles(SVI))
+    return I;
+
   bool MadeChange = false;
   unsigned VWidth = SVI.getType()->getVectorNumElements();
 
diff --git a/llvm/test/Transforms/InstCombine/shuffle_select.ll b/llvm/test/Transforms/InstCombine/shuffle_select.ll
index aeff808..bba5124 100644
--- a/llvm/test/Transforms/InstCombine/shuffle_select.ll
+++ b/llvm/test/Transforms/InstCombine/shuffle_select.ll
@@ -6,9 +6,7 @@
 
 define <4 x i32> @add(<4 x i32> %v0) {
 ; CHECK-LABEL: @add(
-; CHECK-NEXT:    [[T1:%.*]] = add <4 x i32> [[V0:%.*]], <i32 1, i32 undef, i32 3, i32 undef>
-; CHECK-NEXT:    [[T2:%.*]] = add <4 x i32> [[V0]], <i32 undef, i32 6, i32 undef, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = add <4 x i32> [[V0:%.*]], <i32 1, i32 6, i32 3, i32 8>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = add <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -21,9 +19,7 @@
 
 define <4 x i32> @sub(<4 x i32> %v0) {
 ; CHECK-LABEL: @sub(
-; CHECK-NEXT:    [[T1:%.*]] = sub <4 x i32> <i32 1, i32 2, i32 3, i32 undef>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = sub <4 x i32> <i32 undef, i32 undef, i32 undef, i32 8>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 1, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = sub <4 x i32> <i32 1, i32 2, i32 3, i32 8>, [[V0:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = sub <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %v0
@@ -37,9 +33,7 @@
 
 define <4 x i32> @mul(<4 x i32> %v0) {
 ; CHECK-LABEL: @mul(
-; CHECK-NEXT:    [[T1:%.*]] = mul <4 x i32> [[V0:%.*]], <i32 undef, i32 undef, i32 3, i32 undef>
-; CHECK-NEXT:    [[T2:%.*]] = mul <4 x i32> [[V0]], <i32 undef, i32 6, i32 undef, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 undef, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = mul <4 x i32> [[V0:%.*]], <i32 undef, i32 6, i32 3, i32 8>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = mul <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -52,9 +46,7 @@
 
 define <4 x i32> @shl(<4 x i32> %v0) {
 ; CHECK-LABEL: @shl(
-; CHECK-NEXT:    [[T1:%.*]] = shl nuw <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3, i32 4>
-; CHECK-NEXT:    [[T2:%.*]] = shl nuw <4 x i32> [[V0]], <i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 undef, i32 5, i32 2, i32 undef>
+; CHECK-NEXT:    [[T3:%.*]] = shl nuw <4 x i32> [[V0:%.*]], <i32 undef, i32 6, i32 3, i32 undef>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = shl nuw <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -67,9 +59,7 @@
 
 define <4 x i32> @lshr(<4 x i32> %v0) {
 ; CHECK-LABEL: @lshr(
-; CHECK-NEXT:    [[T1:%.*]] = lshr exact <4 x i32> <i32 1, i32 2, i32 3, i32 4>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = lshr <4 x i32> <i32 5, i32 6, i32 7, i32 8>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = lshr <4 x i32> <i32 5, i32 6, i32 3, i32 8>, [[V0:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = lshr exact <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %v0
@@ -82,9 +72,7 @@
 
 define <3 x i32> @ashr(<3 x i32> %v0) {
 ; CHECK-LABEL: @ashr(
-; CHECK-NEXT:    [[T1:%.*]] = ashr <3 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[T2:%.*]] = ashr <3 x i32> [[V0]], <i32 4, i32 5, i32 6>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <3 x i32> [[T1]], <3 x i32> [[T2]], <3 x i32> <i32 3, i32 1, i32 2>
+; CHECK-NEXT:    [[T3:%.*]] = ashr <3 x i32> [[V0:%.*]], <i32 4, i32 2, i32 3>
 ; CHECK-NEXT:    ret <3 x i32> [[T3]]
 ;
   %t1 = ashr <3 x i32> %v0, <i32 1, i32 2, i32 3>
@@ -95,9 +83,7 @@
 
 define <3 x i42> @and(<3 x i42> %v0) {
 ; CHECK-LABEL: @and(
-; CHECK-NEXT:    [[T1:%.*]] = and <3 x i42> [[V0:%.*]], <i42 1, i42 undef, i42 undef>
-; CHECK-NEXT:    [[T2:%.*]] = and <3 x i42> [[V0]], <i42 undef, i42 5, i42 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <3 x i42> [[T1]], <3 x i42> [[T2]], <3 x i32> <i32 0, i32 4, i32 undef>
+; CHECK-NEXT:    [[T3:%.*]] = and <3 x i42> [[V0:%.*]], <i42 1, i42 5, i42 undef>
 ; CHECK-NEXT:    ret <3 x i42> [[T3]]
 ;
   %t1 = and <3 x i42> %v0, <i42 1, i42 2, i42 3>
@@ -113,8 +99,7 @@
 define <4 x i32> @or(<4 x i32> %v0) {
 ; CHECK-LABEL: @or(
 ; CHECK-NEXT:    [[T1:%.*]] = or <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3, i32 4>
-; CHECK-NEXT:    [[T2:%.*]] = or <4 x i32> [[V0]], <i32 5, i32 6, i32 undef, i32 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = or <4 x i32> [[V0]], <i32 5, i32 6, i32 3, i32 4>
 ; CHECK-NEXT:    call void @use_v4i32(<4 x i32> [[T1]])
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
@@ -127,9 +112,8 @@
 
 define <4 x i32> @xor(<4 x i32> %v0) {
 ; CHECK-LABEL: @xor(
-; CHECK-NEXT:    [[T1:%.*]] = xor <4 x i32> [[V0:%.*]], <i32 1, i32 undef, i32 3, i32 4>
-; CHECK-NEXT:    [[T2:%.*]] = xor <4 x i32> [[V0]], <i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T2:%.*]] = xor <4 x i32> [[V0:%.*]], <i32 5, i32 6, i32 7, i32 8>
+; CHECK-NEXT:    [[T3:%.*]] = xor <4 x i32> [[V0]], <i32 1, i32 6, i32 3, i32 4>
 ; CHECK-NEXT:    call void @use_v4i32(<4 x i32> [[T2]])
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
@@ -144,7 +128,7 @@
 ; CHECK-LABEL: @udiv(
 ; CHECK-NEXT:    [[T1:%.*]] = udiv <4 x i32> <i32 1, i32 2, i32 3, i32 4>, [[V0:%.*]]
 ; CHECK-NEXT:    [[T2:%.*]] = udiv <4 x i32> <i32 5, i32 6, i32 7, i32 8>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 1, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = udiv <4 x i32> <i32 1, i32 2, i32 3, i32 8>, [[V0]]
 ; CHECK-NEXT:    call void @use_v4i32(<4 x i32> [[T1]])
 ; CHECK-NEXT:    call void @use_v4i32(<4 x i32> [[T2]])
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
@@ -161,9 +145,7 @@
 
 define <4 x i32> @sdiv(<4 x i32> %v0) {
 ; CHECK-LABEL: @sdiv(
-; CHECK-NEXT:    [[T1:%.*]] = sdiv <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3, i32 4>
-; CHECK-NEXT:    [[T2:%.*]] = sdiv <4 x i32> [[V0]], <i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 undef, i32 1, i32 6, i32 undef>
+; CHECK-NEXT:    [[T3:%.*]] = sdiv <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 7, i32 1>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = sdiv <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -174,9 +156,7 @@
 
 define <4 x i32> @urem(<4 x i32> %v0) {
 ; CHECK-LABEL: @urem(
-; CHECK-NEXT:    [[T1:%.*]] = urem <4 x i32> <i32 1, i32 2, i32 3, i32 4>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = urem <4 x i32> <i32 5, i32 6, i32 7, i32 8>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 1, i32 6, i32 undef>
+; CHECK-NEXT:    [[T3:%.*]] = urem <4 x i32> <i32 1, i32 2, i32 7, i32 1>, [[V0:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = urem <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %v0
@@ -187,9 +167,7 @@
 
 define <4 x i32> @srem(<4 x i32> %v0) {
 ; CHECK-LABEL: @srem(
-; CHECK-NEXT:    [[T1:%.*]] = srem <4 x i32> <i32 1, i32 2, i32 3, i32 4>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = srem <4 x i32> <i32 5, i32 6, i32 7, i32 8>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 1, i32 6, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = srem <4 x i32> <i32 1, i32 2, i32 7, i32 4>, [[V0:%.*]]
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = srem <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %v0
@@ -202,9 +180,7 @@
 
 define <4 x float> @fadd(<4 x float> %v0) {
 ; CHECK-LABEL: @fadd(
-; CHECK-NEXT:    [[T1:%.*]] = fadd <4 x float> [[V0:%.*]], <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00>
-; CHECK-NEXT:    [[T2:%.*]] = fadd <4 x float> [[V0]], <float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x float> [[T1]], <4 x float> [[T2]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fadd <4 x float> [[V0:%.*]], <float 1.000000e+00, float 2.000000e+00, float 7.000000e+00, float 8.000000e+00>
 ; CHECK-NEXT:    ret <4 x float> [[T3]]
 ;
   %t1 = fadd <4 x float> %v0, <float 1.0, float 2.0, float 3.0, float 4.0>
@@ -215,9 +191,7 @@
 
 define <4 x double> @fsub(<4 x double> %v0) {
 ; CHECK-LABEL: @fsub(
-; CHECK-NEXT:    [[T1:%.*]] = fsub <4 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00, double 4.000000e+00>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = fsub <4 x double> <double 5.000000e+00, double 6.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x double> [[T1]], <4 x double> [[T2]], <4 x i32> <i32 undef, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fsub <4 x double> <double undef, double 2.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[V0:%.*]]
 ; CHECK-NEXT:    ret <4 x double> [[T3]]
 ;
   %t1 = fsub <4 x double> <double 1.0, double 2.0, double 3.0, double 4.0>, %v0
@@ -230,9 +204,7 @@
 
 define <4 x float> @fmul(<4 x float> %v0) {
 ; CHECK-LABEL: @fmul(
-; CHECK-NEXT:    [[T1:%.*]] = fmul nnan ninf <4 x float> [[V0:%.*]], <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00>
-; CHECK-NEXT:    [[T2:%.*]] = fmul nnan ninf <4 x float> [[V0]], <float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x float> [[T1]], <4 x float> [[T2]], <4 x i32> <i32 0, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fmul nnan ninf <4 x float> [[V0:%.*]], <float 1.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00>
 ; CHECK-NEXT:    ret <4 x float> [[T3]]
 ;
   %t1 = fmul nnan ninf <4 x float> %v0, <float 1.0, float 2.0, float 3.0, float 4.0>
@@ -243,9 +215,7 @@
 
 define <4 x double> @fdiv(<4 x double> %v0) {
 ; CHECK-LABEL: @fdiv(
-; CHECK-NEXT:    [[T1:%.*]] = fdiv fast <4 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00, double 4.000000e+00>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = fdiv nnan arcp <4 x double> <double 5.000000e+00, double 6.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[V0]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x double> [[T1]], <4 x double> [[T2]], <4 x i32> <i32 undef, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fdiv nnan arcp <4 x double> <double undef, double 2.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[V0:%.*]]
 ; CHECK-NEXT:    ret <4 x double> [[T3]]
 ;
   %t1 = fdiv fast <4 x double> <double 1.0, double 2.0, double 3.0, double 4.0>, %v0