[InstCombine] PR35354: Convert store(bitcast, load bitcast (select (Cond, &V1, &V2))  --> store (, load (select(Cond, load &V1, load &V2)))

Summary:
If we have code like this:
```
float a, b;
a = std::max(a ,b);
```
it is converted into something like this:
```
%call = call dereferenceable(4) float* @_ZSt3maxIfERKT_S2_S2_(float* nonnull dereferenceable(4) %a.addr, float* nonnull dereferenceable(4) %b.addr)
%1 = bitcast float* %call to i32*
%2 = load i32, i32* %1, align 4
%3 = bitcast float* %a.addr to i32*
store i32 %2, i32* %3, align 4
```
After inlining, this code is converted to the following:
```
%1 = load float, float* %a.addr
%2 = load float, float* %b.addr
%cmp.i = fcmp fast olt float %1, %2
%__b.__a.i = select i1 %cmp.i, float* %a.addr, float* %b.addr
%3 = bitcast float* %__b.__a.i to i32*
%4 = load i32, i32* %3, align 4
%5 = bitcast float* %arrayidx to i32*
store i32 %4, i32* %5, align 4

```
This pattern is not recognized as a minmax pattern.
The patch solves this problem by converting the sequence
```
store (bitcast, (load bitcast (select ((cmp V1, V2), &V1, &V2))))
```
to the sequence
```
store (,load (select((cmp V1, V2), &V1, &V2)))
```
After this, the code is recognized as a minmax pattern.

Reviewers: RKSimon, spatel

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D40304

llvm-svn: 320157
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 5d24023..29e0a09 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -22,9 +22,11 @@
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
 using namespace llvm;
+using namespace PatternMatch;
 
 #define DEBUG_TYPE "instcombine"
 
@@ -561,6 +563,28 @@
   return NewStore;
 }
 
+/// Returns true if \p V, ignoring pointer bitcasts, is a minmax-style select
+/// of two addresses: select ((cmp load V1, load V2), V1, V2).
+static bool isMinMaxWithLoads(Value *V) {
+  assert(V->getType()->isPointerTy() && "Expected pointer type.");
+  // Ignore possible ty* to ixx* bitcast.
+  V = peekThroughBitcast(V);
+  // The select must be of the form select ((cmp load V1, load V2), V1, V2).
+  CmpInst::Predicate Pred;
+  Instruction *L1;
+  Instruction *L2;
+  Value *LHS;
+  Value *RHS;
+  if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)),
+                         m_Value(LHS), m_Value(RHS))))
+    return false;
+  // The two loads may feed the compare in either operand order.
+  return (match(L1, m_Load(m_Specific(LHS))) &&
+          match(L2, m_Load(m_Specific(RHS)))) ||
+         (match(L1, m_Load(m_Specific(RHS))) &&
+          match(L2, m_Load(m_Specific(LHS))));
+}
+
 /// \brief Combine loads to match the type of their uses' value after looking
 /// through intervening bitcasts.
 ///
@@ -598,10 +622,14 @@
   // integers instead of any other type. We only do this when the loaded type
   // is sized and has a size exactly the same as its store size and the store
   // size is a legal integer type.
+  // Do not perform canonicalization if minmax pattern is found (to avoid
+  // infinite loop).
   if (!Ty->isIntegerTy() && Ty->isSized() &&
       DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) &&
       DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) &&
-      !DL.isNonIntegralPointerType(Ty)) {
+      !DL.isNonIntegralPointerType(Ty) &&
+      !isMinMaxWithLoads(
+          peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) {
     if (all_of(LI.users(), [&LI](User *U) {
           auto *SI = dyn_cast<StoreInst>(U);
           return SI && SI->getPointerOperand() != &LI &&
@@ -1298,6 +1326,30 @@
   return false;
 }
 
+/// Rewrites `store (load (bitcast P)), (bitcast Q)` as a load/store of P's
+/// pointee type when P is a minmax select of pointers (see isMinMaxWithLoads).
+/// Returns true if the rewrite was done; the caller then erases \p SI.
+static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC,
+                                                StoreInst &SI) {
+  // The store address must be a bitcast; the original pointer is not needed.
+  if (!match(SI.getPointerOperand(), m_BitCast(m_Value())))
+    return false;
+  // The stored value must be an integer load through a bitcast pointer.
+  Value *LoadAddr;
+  if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr)))))
+    return false;
+  auto *LI = cast<LoadInst>(SI.getValueOperand());
+  if (!LI->getType()->isIntegerTy())
+    return false;
+  if (!isMinMaxWithLoads(LoadAddr))
+    return false;
+
+  LoadInst *NewLI = combineLoadToNewType(
+      IC, *LI, LoadAddr->getType()->getPointerElementType());
+  combineStoreToNewValue(IC, SI, NewLI);
+  return true;
+}
+
 Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
   Value *Val = SI.getOperand(0);
   Value *Ptr = SI.getOperand(1);
@@ -1322,6 +1374,9 @@
   if (unpackStoreToAggregate(*this, SI))
     return eraseInstFromFunction(SI);
 
+  if (removeBitcastsFromLoadStoreOnMinMax(*this, SI))
+    return eraseInstFromFunction(SI);
+
   // Replace GEP indices if possible.
   if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) {
       Worklist.Add(NewGEPI);