[InstCombine] enhance shuffle-of-binops to allow different variable ops (PR37806)
This was discussed in D48401 as another improvement for:
https://bugs.llvm.org/show_bug.cgi?id=37806
If we have 2 different variable values, then we shuffle (select) those lanes, 
shuffle (select) the constants, and then perform the binop. This eliminates a binop.
The new shuffle uses the same shuffle mask as the existing shuffle, so there's no 
danger of creating a difficult shuffle.
All of the earlier constraints still apply, but we also check for extra uses to 
avoid creating more instructions than we'll remove.
Additionally, we're disallowing the fold for div/rem because that could expose a
UB hole.
Differential Revision: https://reviews.llvm.org/D48678
llvm-svn: 335974
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index d4fb086..dd579a6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -1140,7 +1140,8 @@
   return true;
 }
 
-static Instruction *foldSelectShuffles(ShuffleVectorInst &Shuf) {
+static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf,
+                                      InstCombiner::BuilderTy &Builder) {
   // Folds under here require the equivalent of a vector select.
   if (!Shuf.isSelect())
     return nullptr;
@@ -1150,16 +1151,14 @@
       !match(Shuf.getOperand(1), m_BinOp(B1)))
     return nullptr;
 
-  // TODO: Fold the case with different variable operands (requires creating a
-  // new shuffle and checking number of uses).
-  Value *X;
+  Value *X, *Y;
   Constant *C0, *C1;
   bool ConstantsAreOp1;
   if (match(B0, m_BinOp(m_Value(X), m_Constant(C0))) &&
-      match(B1, m_BinOp(m_Specific(X), m_Constant(C1))))
+      match(B1, m_BinOp(m_Value(Y), m_Constant(C1))))
     ConstantsAreOp1 = true;
   else if (match(B0, m_BinOp(m_Constant(C0), m_Value(X))) &&
-           match(B1, m_BinOp(m_Constant(C1), m_Specific(X))))
+           match(B1, m_BinOp(m_Constant(C1), m_Value(Y))))
     ConstantsAreOp1 = false;
   else
     return nullptr;
@@ -1191,9 +1190,36 @@
   // The opcodes must be the same. Use a new name to make that clear.
   BinaryOperator::BinaryOps BOpc = Opc0;
 
-  // Remove a binop and the shuffle by rearranging the constant:
-  // shuffle (op X, C0), (op X, C1), M --> op X, C'
-  // shuffle (op C0, X), (op C1, X), M --> op C', X
+  Value *V;
+  if (X == Y) {
+    // Remove a binop and the shuffle by rearranging the constant:
+    // shuffle (op V, C0), (op V, C1), M --> op V, C'
+    // shuffle (op C0, V), (op C1, V), M --> op C', V
+    V = X;
+  } else if (!Instruction::isIntDivRem(BOpc) &&
+             (B0->hasOneUse() || B1->hasOneUse())) {
+    // If there are 2 different variable operands, we must create a new shuffle
+    // (select) first, so check uses to ensure that we don't end up with more
+    // instructions than we started with.
+    //
+    // Note: In general, we do not create new shuffles in InstCombine because we
+    // do not know if a target can lower an arbitrary shuffle optimally. In this
+    // case, the shuffle uses the existing mask, so there is no additional risk.
+    //
+    // TODO: We are disallowing div/rem because a shuffle with an undef mask
+    // element would propagate an undef value to the div/rem. That's not
+    // safe in general because div/rem allow for undefined behavior. We can
+    // loosen this restriction (eg, check if the mask has no undefs or replace
+    // undef elements).
+
+    // Select the variable vectors first, then perform the binop:
+    // shuffle (op X, C0), (op Y, C1), M --> op (shuffle X, Y, M), C'
+    // shuffle (op C0, X), (op C1, Y), M --> op C', (shuffle X, Y, M)
+    V = Builder.CreateShuffleVector(X, Y, Shuf.getMask());
+  } else {
+    return nullptr;
+  }
+
   Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Shuf.getMask());
 
   // If the shuffle mask contains undef elements, then the new constant
@@ -1202,8 +1228,8 @@
   if (Instruction::isIntDivRem(BOpc))
     NewC = getSafeVectorConstantForIntDivRem(NewC);
 
-  Instruction *NewBO = ConstantsAreOp1 ? BinaryOperator::Create(BOpc, X, NewC) :
-                                         BinaryOperator::Create(BOpc, NewC, X);
+  Instruction *NewBO = ConstantsAreOp1 ? BinaryOperator::Create(BOpc, V, NewC) :
+                                         BinaryOperator::Create(BOpc, NewC, V);
 
   // Flags are intersected from the 2 source binops.
   NewBO->copyIRFlags(B0);
@@ -1223,7 +1249,7 @@
           LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI)))
     return replaceInstUsesWith(SVI, V);
 
-  if (Instruction *I = foldSelectShuffles(SVI))
+  if (Instruction *I = foldSelectShuffle(SVI, Builder))
     return I;
 
   bool MadeChange = false;
diff --git a/llvm/test/Transforms/InstCombine/shuffle_select.ll b/llvm/test/Transforms/InstCombine/shuffle_select.ll
index ecc1d92..a19dd62 100644
--- a/llvm/test/Transforms/InstCombine/shuffle_select.ll
+++ b/llvm/test/Transforms/InstCombine/shuffle_select.ll
@@ -241,9 +241,8 @@
 
 define <4 x i32> @add_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @add_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = add <4 x i32> [[V0:%.*]], <i32 1, i32 undef, i32 3, i32 undef>
-; CHECK-NEXT:    [[T2:%.*]] = add <4 x i32> [[V1:%.*]], <i32 undef, i32 6, i32 undef, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V0:%.*]], <4 x i32> [[V1:%.*]], <4 x i32> <i32 0, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = add <4 x i32> [[TMP1]], <i32 1, i32 6, i32 3, i32 8>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = add <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -256,9 +255,8 @@
 
 define <4 x i32> @sub_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @sub_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = sub <4 x i32> <i32 1, i32 2, i32 3, i32 undef>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = sub <4 x i32> <i32 undef, i32 undef, i32 undef, i32 8>, [[V1:%.*]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 0, i32 1, i32 2, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V0:%.*]], <4 x i32> [[V1:%.*]], <4 x i32> <i32 0, i32 1, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = sub <4 x i32> <i32 1, i32 2, i32 3, i32 8>, [[TMP1]]
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = sub <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %v0
@@ -272,9 +270,8 @@
 
 define <4 x i32> @mul_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @mul_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = mul <4 x i32> [[V0:%.*]], <i32 undef, i32 undef, i32 3, i32 undef>
-; CHECK-NEXT:    [[T2:%.*]] = mul <4 x i32> [[V1:%.*]], <i32 undef, i32 6, i32 undef, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 undef, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V0:%.*]], <4 x i32> [[V1:%.*]], <4 x i32> <i32 undef, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = mul <4 x i32> [[TMP1]], <i32 undef, i32 6, i32 3, i32 8>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = mul <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -287,9 +284,8 @@
 
 define <4 x i32> @shl_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @shl_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = shl nsw <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3, i32 4>
-; CHECK-NEXT:    [[T2:%.*]] = shl nsw <4 x i32> [[V1:%.*]], <i32 5, i32 6, i32 7, i32 8>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 undef, i32 5, i32 2, i32 undef>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V0:%.*]], <4 x i32> [[V1:%.*]], <4 x i32> <i32 undef, i32 5, i32 2, i32 undef>
+; CHECK-NEXT:    [[T3:%.*]] = shl nsw <4 x i32> [[TMP1]], <i32 undef, i32 6, i32 3, i32 undef>
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = shl nsw <4 x i32> %v0, <i32 1, i32 2, i32 3, i32 4>
@@ -302,9 +298,8 @@
 
 define <4 x i32> @lshr_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @lshr_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = lshr <4 x i32> <i32 1, i32 2, i32 3, i32 4>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = lshr exact <4 x i32> <i32 5, i32 6, i32 7, i32 8>, [[V1:%.*]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V0:%.*]], <4 x i32> [[V1:%.*]], <4 x i32> <i32 4, i32 5, i32 2, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = lshr <4 x i32> <i32 5, i32 6, i32 3, i32 8>, [[TMP1]]
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
   %t1 = lshr <4 x i32> <i32 1, i32 2, i32 3, i32 4>, %v0
@@ -317,9 +312,8 @@
 
 define <3 x i32> @ashr_2_vars(<3 x i32> %v0, <3 x i32> %v1) {
 ; CHECK-LABEL: @ashr_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = ashr <3 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3>
-; CHECK-NEXT:    [[T2:%.*]] = ashr <3 x i32> [[V1:%.*]], <i32 4, i32 5, i32 6>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <3 x i32> [[T1]], <3 x i32> [[T2]], <3 x i32> <i32 3, i32 1, i32 2>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x i32> [[V0:%.*]], <3 x i32> [[V1:%.*]], <3 x i32> <i32 3, i32 1, i32 2>
+; CHECK-NEXT:    [[T3:%.*]] = ashr <3 x i32> [[TMP1]], <i32 4, i32 2, i32 3>
 ; CHECK-NEXT:    ret <3 x i32> [[T3]]
 ;
   %t1 = ashr <3 x i32> %v0, <i32 1, i32 2, i32 3>
@@ -330,9 +324,8 @@
 
 define <3 x i42> @and_2_vars(<3 x i42> %v0, <3 x i42> %v1) {
 ; CHECK-LABEL: @and_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = and <3 x i42> [[V0:%.*]], <i42 1, i42 undef, i42 undef>
-; CHECK-NEXT:    [[T2:%.*]] = and <3 x i42> [[V1:%.*]], <i42 undef, i42 5, i42 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <3 x i42> [[T1]], <3 x i42> [[T2]], <3 x i32> <i32 0, i32 4, i32 undef>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <3 x i42> [[V0:%.*]], <3 x i42> [[V1:%.*]], <3 x i32> <i32 0, i32 4, i32 undef>
+; CHECK-NEXT:    [[T3:%.*]] = and <3 x i42> [[TMP1]], <i42 1, i42 5, i42 undef>
 ; CHECK-NEXT:    ret <3 x i42> [[T3]]
 ;
   %t1 = and <3 x i42> %v0, <i42 1, i42 2, i42 3>
@@ -346,8 +339,8 @@
 define <4 x i32> @or_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @or_2_vars(
 ; CHECK-NEXT:    [[T1:%.*]] = or <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3, i32 4>
-; CHECK-NEXT:    [[T2:%.*]] = or <4 x i32> [[V1:%.*]], <i32 5, i32 6, i32 undef, i32 undef>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x i32> [[T1]], <4 x i32> [[T2]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x i32> [[V0]], <4 x i32> [[V1:%.*]], <4 x i32> <i32 4, i32 5, i32 2, i32 3>
+; CHECK-NEXT:    [[T3:%.*]] = or <4 x i32> [[TMP1]], <i32 5, i32 6, i32 3, i32 4>
 ; CHECK-NEXT:    call void @use_v4i32(<4 x i32> [[T1]])
 ; CHECK-NEXT:    ret <4 x i32> [[T3]]
 ;
@@ -392,6 +385,8 @@
   ret <4 x i32> %t3
 }
 
+; TODO: If the shuffle has no undefs, it's safe to shuffle the variables first.
+
 define <4 x i32> @sdiv_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @sdiv_2_vars(
 ; CHECK-NEXT:    [[T1:%.*]] = sdiv <4 x i32> [[V0:%.*]], <i32 1, i32 2, i32 3, i32 4>
@@ -405,6 +400,8 @@
   ret <4 x i32> %t3
 }
 
+; TODO: If the shuffle has no undefs, it's safe to shuffle the variables first.
+
 define <4 x i32> @urem_2_vars(<4 x i32> %v0, <4 x i32> %v1) {
 ; CHECK-LABEL: @urem_2_vars(
 ; CHECK-NEXT:    [[T1:%.*]] = urem <4 x i32> <i32 1, i32 2, i32 3, i32 4>, [[V0:%.*]]
@@ -435,9 +432,8 @@
 
 define <4 x float> @fadd_2_vars(<4 x float> %v0, <4 x float> %v1) {
 ; CHECK-LABEL: @fadd_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = fadd <4 x float> [[V0:%.*]], <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00>
-; CHECK-NEXT:    [[T2:%.*]] = fadd <4 x float> [[V1:%.*]], <float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x float> [[T1]], <4 x float> [[T2]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[V0:%.*]], <4 x float> [[V1:%.*]], <4 x i32> <i32 0, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fadd <4 x float> [[TMP1]], <float 1.000000e+00, float 2.000000e+00, float 7.000000e+00, float 8.000000e+00>
 ; CHECK-NEXT:    ret <4 x float> [[T3]]
 ;
   %t1 = fadd <4 x float> %v0, <float 1.0, float 2.0, float 3.0, float 4.0>
@@ -448,9 +444,8 @@
 
 define <4 x double> @fsub_2_vars(<4 x double> %v0, <4 x double> %v1) {
 ; CHECK-LABEL: @fsub_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = fsub <4 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00, double 4.000000e+00>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = fsub <4 x double> <double 5.000000e+00, double 6.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[V1:%.*]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x double> [[T1]], <4 x double> [[T2]], <4 x i32> <i32 undef, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[V0:%.*]], <4 x double> [[V1:%.*]], <4 x i32> <i32 undef, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fsub <4 x double> <double undef, double 2.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[TMP1]]
 ; CHECK-NEXT:    ret <4 x double> [[T3]]
 ;
   %t1 = fsub <4 x double> <double 1.0, double 2.0, double 3.0, double 4.0>, %v0
@@ -463,9 +458,8 @@
 
 define <4 x float> @fmul_2_vars(<4 x float> %v0, <4 x float> %v1) {
 ; CHECK-LABEL: @fmul_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = fmul reassoc nsz <4 x float> [[V0:%.*]], <float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00>
-; CHECK-NEXT:    [[T2:%.*]] = fmul reassoc nsz <4 x float> [[V1:%.*]], <float 5.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00>
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x float> [[T1]], <4 x float> [[T2]], <4 x i32> <i32 0, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x float> [[V0:%.*]], <4 x float> [[V1:%.*]], <4 x i32> <i32 0, i32 5, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = fmul reassoc nsz <4 x float> [[TMP1]], <float 1.000000e+00, float 6.000000e+00, float 7.000000e+00, float 8.000000e+00>
 ; CHECK-NEXT:    ret <4 x float> [[T3]]
 ;
   %t1 = fmul reassoc nsz <4 x float> %v0, <float 1.0, float 2.0, float 3.0, float 4.0>
@@ -476,9 +470,8 @@
 
 define <4 x double> @frem_2_vars(<4 x double> %v0, <4 x double> %v1) {
 ; CHECK-LABEL: @frem_2_vars(
-; CHECK-NEXT:    [[T1:%.*]] = frem nnan ninf <4 x double> <double 1.000000e+00, double 2.000000e+00, double 3.000000e+00, double 4.000000e+00>, [[V0:%.*]]
-; CHECK-NEXT:    [[T2:%.*]] = frem nnan arcp <4 x double> <double 5.000000e+00, double 6.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[V1:%.*]]
-; CHECK-NEXT:    [[T3:%.*]] = shufflevector <4 x double> [[T1]], <4 x double> [[T2]], <4 x i32> <i32 undef, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x double> [[V0:%.*]], <4 x double> [[V1:%.*]], <4 x i32> <i32 undef, i32 1, i32 6, i32 7>
+; CHECK-NEXT:    [[T3:%.*]] = frem nnan <4 x double> <double undef, double 2.000000e+00, double 7.000000e+00, double 8.000000e+00>, [[TMP1]]
 ; CHECK-NEXT:    ret <4 x double> [[T3]]
 ;
   %t1 = frem nnan ninf <4 x double> <double 1.0, double 2.0, double 3.0, double 4.0>, %v0