Make SimplifyDemandedVectorElts simplify vectors with multiple
users, and teach it about shufflevector instructions.
Also, fix a subtle bug in SimplifyDemandedVectorElts'
insertelement code.
This is a patch that was originally written by Eli Friedman,
with some fixes and cleanup by me.
llvm-svn: 55995
diff --git a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp
index 6b25f52..3ec97dd 100644
--- a/llvm/lib/Transforms/Scalar/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/Scalar/InstructionCombining.cpp
@@ -1355,8 +1355,7 @@
   unsigned VWidth = cast<VectorType>(V->getType())->getNumElements();
   assert(VWidth <= 64 && "Vector too wide to analyze!");
   uint64_t EltMask = ~0ULL >> (64-VWidth);
-  assert(DemandedElts != EltMask && (DemandedElts & ~EltMask) == 0 &&
-         "Invalid DemandedElts!");
+  assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
 
   if (isa<UndefValue>(V)) {
     // If the entire vector is undefined, just return this info.
@@ -1400,14 +1399,23 @@
     return ConstantVector::get(Elts);
   }
   
-  if (!V->hasOneUse()) {    // Other users may use these bits.
-    if (Depth != 0) {       // Not at the root.
+  // Limit search depth.
+  if (Depth == 10)
+    return false;
+
+  // If multiple users are using the root value, proceed with
+  // simplification conservatively assuming that all elements
+  // are needed.
+  if (!V->hasOneUse()) {
+    // Quit if we find multiple users of a non-root value though.
+    // They'll be handled when it's their turn to be visited by
+    // the main instcombine process.
+    if (Depth != 0)
       // TODO: Just compute the UndefElts information recursively.
       return false;
-    }
-    return false;
-  } else if (Depth == 10) {        // Limit search depth.
-    return false;
+
+    // Conservatively assume that all elements are needed.
+    DemandedElts = EltMask;
   }
   
   Instruction *I = dyn_cast<Instruction>(V);
@@ -1446,7 +1454,65 @@
     if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
 
     // The inserted element is defined.
-    UndefElts |= 1ULL << IdxNo;
+    UndefElts &= ~(1ULL << IdxNo);
+    break;
+  }
+  case Instruction::ShuffleVector: {
+    ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
+    uint64_t LeftDemanded = 0, RightDemanded = 0;
+    for (unsigned i = 0; i < VWidth; i++) {
+      if (DemandedElts & (1ULL << i)) {
+        unsigned MaskVal = Shuffle->getMaskValue(i);
+        if (MaskVal != -1u) {
+          assert(MaskVal < VWidth * 2 &&
+                 "shufflevector mask index out of range!");
+          if (MaskVal < VWidth)
+            LeftDemanded |= 1ULL << MaskVal;
+          else
+            RightDemanded |= 1ULL << (MaskVal - VWidth);
+        }
+      }
+    }
+
+    TmpV = SimplifyDemandedVectorElts(I->getOperand(0), LeftDemanded,
+                                      UndefElts2, Depth+1);
+    if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; }
+
+    uint64_t UndefElts3;
+    TmpV = SimplifyDemandedVectorElts(I->getOperand(1), RightDemanded,
+                                      UndefElts3, Depth+1);
+    if (TmpV) { I->setOperand(1, TmpV); MadeChange = true; }
+
+    bool NewUndefElts = false;
+    for (unsigned i = 0; i < VWidth; i++) {
+      unsigned MaskVal = Shuffle->getMaskValue(i);
+      if (MaskVal == -1) {
+        uint64_t NewBit = 1ULL << i;
+        UndefElts |= NewBit;
+      } else if (MaskVal < VWidth) {
+        uint64_t NewBit = ((UndefElts2 >> MaskVal) & 1) << i;
+        NewUndefElts |= NewBit;
+        UndefElts |= NewBit;
+      } else {
+        uint64_t NewBit = ((UndefElts3 >> (MaskVal - VWidth)) & 1) << i;
+        NewUndefElts |= NewBit;
+        UndefElts |= NewBit;
+      }
+    }
+
+    if (NewUndefElts) {
+      // Add additional discovered undefs.
+      std::vector<Constant*> Elts;
+      for (unsigned i = 0; i < VWidth; ++i) {
+        if (UndefElts & (1ULL << i))
+          Elts.push_back(UndefValue::get(Type::Int32Ty));
+        else
+          Elts.push_back(ConstantInt::get(Type::Int32Ty,
+                                          Shuffle->getMaskValue(i)));
+      }
+      I->setOperand(2, ConstantVector::get(Elts));
+      MadeChange = true;
+    }
     break;
   }
   case Instruction::BitCast: {
@@ -11224,31 +11290,13 @@
   // Undefined shuffle mask -> undefined value.
   if (isa<UndefValue>(SVI.getOperand(2)))
     return ReplaceInstUsesWith(SVI, UndefValue::get(SVI.getType()));
-  
-  // If we have shuffle(x, undef, mask) and any elements of mask refer to
-  // the undef, change them to undefs.
-  if (isa<UndefValue>(SVI.getOperand(1))) {
-    // Scan to see if there are any references to the RHS.  If so, replace them
-    // with undef element refs and set MadeChange to true.
-    for (unsigned i = 0, e = Mask.size(); i != e; ++i) {
-      if (Mask[i] >= e && Mask[i] != 2*e) {
-        Mask[i] = 2*e;
-        MadeChange = true;
-      }
-    }
-    
-    if (MadeChange) {
-      // Remap any references to RHS to use LHS.
-      std::vector<Constant*> Elts;
-      for (unsigned i = 0, e = Mask.size(); i != e; ++i) {
-        if (Mask[i] == 2*e)
-          Elts.push_back(UndefValue::get(Type::Int32Ty));
-        else
-          Elts.push_back(ConstantInt::get(Type::Int32Ty, Mask[i]));
-      }
-      SVI.setOperand(2, ConstantVector::get(Elts));
-    }
-  }
+
+  uint64_t UndefElts;
+  unsigned VWidth = cast<VectorType>(SVI.getType())->getNumElements();
+  uint64_t AllOnesEltMask = ~0ULL >> (64-VWidth);
+  if (VWidth <= 64 &&
+      SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts))
+    MadeChange = true;
   
   // Canonicalize shuffle(x    ,x,mask) -> shuffle(x, undef,mask')
   // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask').
diff --git a/llvm/test/Transforms/InstCombine/pr2645.ll b/llvm/test/Transforms/InstCombine/pr2645.ll
new file mode 100644
index 0000000..04cc185
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/pr2645.ll
@@ -0,0 +1,33 @@
+; RUN: llvm-as < %s | opt -instcombine | llvm-dis | grep {insertelement <4 x float> undef}
+
+; Instcombine should be able to prove that none of the
+; insertelement's first operand's elements are needed.
+
+define internal void @""(i8*) {
+; <label>:1
+        bitcast i8* %0 to i32*          ; <i32*>:2 [#uses=1]
+        load i32* %2, align 1           ; <i32>:3 [#uses=1]
+        getelementptr i8* %0, i32 4             ; <i8*>:4 [#uses=1]
+        bitcast i8* %4 to i32*          ; <i32*>:5 [#uses=1]
+        load i32* %5, align 1           ; <i32>:6 [#uses=1]
+        br label %7
+
+; <label>:7             ; preds = %9, %1
+        %.01 = phi <4 x float> [ undef, %1 ], [ %12, %9 ]               ; <<4 x float>> [#uses=1]
+        %.0 = phi i32 [ %3, %1 ], [ %15, %9 ]           ; <i32> [#uses=3]
+        icmp slt i32 %.0, %6            ; <i1>:8 [#uses=1]
+        br i1 %8, label %9, label %16
+
+; <label>:9             ; preds = %7
+        sitofp i32 %.0 to float         ; <float>:10 [#uses=1]
+        insertelement <4 x float> %.01, float %10, i32 0                ; <<4 x float>>:11 [#uses=1]
+        shufflevector <4 x float> %11, <4 x float> undef, <4 x i32> zeroinitializer             ; <<4 x float>>:12 [#uses=2]
+        getelementptr i8* %0, i32 48            ; <i8*>:13 [#uses=1]
+        bitcast i8* %13 to <4 x float>*         ; <<4 x float>*>:14 [#uses=1]
+        store <4 x float> %12, <4 x float>* %14, align 16
+        add i32 %.0, 2          ; <i32>:15 [#uses=1]
+        br label %7
+
+; <label>:16            ; preds = %7
+        ret void
+}