[InstCombine] Fix PR35618: Instcombine hangs on single minmax load bitcast.

Summary:
If we have the pattern `store (load (bitcast (select (cmp (V1, V2), &V1,
&V2))), bitcast)`, but the load is also used by other instructions,
InstCombiner loops forever. This patch adds a check that all users of
the load instruction are stores, and then replaces all uses of the load
instruction with a new load of the new type.

Reviewers: RKSimon, spatel, majnemer

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D41072

llvm-svn: 320510
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 01fc152..5e4d32d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -1339,25 +1339,39 @@
 /// Converts store (bitcast (load (bitcast (select ...)))) to
 /// store (load (select ...)), where select is minmax:
 /// select ((cmp load V1, load V2), V1, V2).
-bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, StoreInst &SI) {
+static Instruction *removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
+                                                        StoreInst &SI) {
   // bitcast?
-  Value *StoreAddr;
-  if (!match(SI.getPointerOperand(), m_BitCast(m_Value(StoreAddr))))
-    return false;
+  if (!match(SI.getPointerOperand(), m_BitCast(m_Value())))
+    return nullptr;
   // load? integer?
   Value *LoadAddr;
   if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr)))))
-    return false;
+    return nullptr;
   auto *LI = cast<LoadInst>(SI.getValueOperand());
   if (!LI->getType()->isIntegerTy())
-    return false;
+    return nullptr;
   if (!isMinMaxWithLoads(LoadAddr))
-    return false;
+    return nullptr;
 
+  if (!all_of(LI->users(), [LI, LoadAddr](User *U) {
+        auto *SI = dyn_cast<StoreInst>(U);
+        return SI && SI->getPointerOperand() != LI &&
+               peekThroughBitcast(SI->getPointerOperand()) != LoadAddr &&
+               !SI->getPointerOperand()->isSwiftError();
+      }))
+    return nullptr;
+
+  IC.Builder.SetInsertPoint(LI);
   LoadInst *NewLI = combineLoadToNewType(
       IC, *LI, LoadAddr->getType()->getPointerElementType());
-  combineStoreToNewValue(IC, SI, NewLI);
-  return true;
+  // Replace all the stores with stores of the newly loaded value.
+  for (auto *UI : LI->users()) {
+    auto *USI = cast<StoreInst>(UI);
+    IC.Builder.SetInsertPoint(USI);
+    combineStoreToNewValue(IC, *USI, NewLI);
+  }
+  return LI;
 }
 
 Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
@@ -1384,8 +1398,12 @@
   if (unpackStoreToAggregate(*this, SI))
     return eraseInstFromFunction(SI);
 
-  if (removeBitcastsFromLoadStoreOnMinMax(*this, SI))
-    return eraseInstFromFunction(SI);
+  if (Instruction *I = removeBitcastsFromLoadStoreOnMinMax(*this, SI)) {
+    for (auto *UI : I->users())
+      eraseInstFromFunction(*cast<Instruction>(UI));
+    eraseInstFromFunction(*I);
+    return nullptr;
+  }
 
   // Replace GEP indices if possible.
   if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) {