InstSimplify: A shuffle of a splat is always the splat itself

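Summary:
Fold:
 shuffle (splat-shuffle), undef, M --> splat-shuffle

The fold is applied only when the shuffle's result type is the same as its
input vector type. It also applies when the splat-shuffle is the second
operand, provided the mask selects no elements from the first operand.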

Reviewers: spatel, RKSimon, craig.topper

Reviewed By: RKSimon

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D31527

llvm-svn: 299990
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 807bbd2..0e522cb 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -4085,8 +4085,9 @@
 static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,
                                         Type *RetTy, const Query &Q,
                                         unsigned MaxRecurse) {
+  Type *InVecTy = Op0->getType();
   unsigned MaskNumElts = Mask->getType()->getVectorNumElements();
-  unsigned InVecNumElts = Op0->getType()->getVectorNumElements();
+  unsigned InVecNumElts = InVecTy->getVectorNumElements();
 
   auto *Op0Const = dyn_cast<Constant>(Op0);
   auto *Op1Const = dyn_cast<Constant>(Op1);
@@ -4108,11 +4109,22 @@
       MaskSelects1 = true;
   }
   if (!MaskSelects0 && Op1Const)
-    return ConstantFoldShuffleVectorInstruction(UndefValue::get(Op0->getType()),
+    return ConstantFoldShuffleVectorInstruction(UndefValue::get(InVecTy),
                                                 Op1Const, Mask);
   if (!MaskSelects1 && Op0Const)
-    return ConstantFoldShuffleVectorInstruction(
-        Op0Const, UndefValue::get(Op0->getType()), Mask);
+    return ConstantFoldShuffleVectorInstruction(Op0Const,
+                                                UndefValue::get(InVecTy), Mask);
+
+  // A shuffle of a splat is always the splat itself. Legal if the shuffle's
+  // value type is the same as the input vectors' type.
+  if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
+    if (!MaskSelects1 && RetTy == InVecTy &&
+        OpShuf->getMask()->getSplatValue())
+      return Op0;
+  if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op1))
+    if (!MaskSelects0 && RetTy == InVecTy &&
+        OpShuf->getMask()->getSplatValue())
+      return Op1;
 
   return nullptr;
 }