[InstCombine] Combine A->B->A BitCast


This patch enhances InstCombine to handle the following case:

        A  ->  B    bitcast
        PHI
        B  ->  A    bitcast

llvm-svn: 263734
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 6d0e882..826037d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -12,6 +12,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "InstCombineInternal.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/PatternMatch.h"
@@ -1786,6 +1787,106 @@
   return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand());
 }
 
+/// Fold a B->A bitcast \p CI fed by the PHI node \p PN, handling the case
+///
+///     A  ->  B    cast
+///     PHI
+///     B  ->  A    cast
+///
+/// Every non-PHI value flowing into the PHI web must be a constant, a simple
+/// single-use load or an A->B bitcast; otherwise nullptr is returned before
+/// any IR is changed. On success all related PHI nodes get a twin of type A
+/// and the uses of \p CI are replaced by the twin of \p PN.
+Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) {
+  Type *SrcTy = CI.getOperand(0)->getType(); // Type B
+  Type *DestTy = CI.getType();               // Type A
+
+  SmallVector<PHINode *, 4> PhiWorklist;
+  SmallSetVector<PHINode *, 4> OldPhiNodes;
+
+  // Find all of the A->B casts and PHI nodes.
+  // We need to inspect all related PHI nodes, but PHIs can be cyclic, so
+  // OldPhiNodes tracks every known PHI node: a PHI is only pushed onto
+  // PhiWorklist if inserting it into OldPhiNodes reports it as new.
+  PhiWorklist.push_back(PN);
+  OldPhiNodes.insert(PN);
+  while (!PhiWorklist.empty()) {
+    auto *OldPN = PhiWorklist.pop_back_val();
+    for (Value *IncValue : OldPN->incoming_values()) {
+      if (isa<Constant>(IncValue))
+        continue;
+
+      if (auto *LI = dyn_cast<LoadInst>(IncValue)) {
+        if (LI->hasOneUse() && LI->isSimple())
+          continue;
+        // If a LoadInst has more than one use, changing the type of loaded
+        // value may create another bitcast.
+        return nullptr;
+      }
+
+      if (auto *PNode = dyn_cast<PHINode>(IncValue)) {
+        if (OldPhiNodes.insert(PNode))
+          PhiWorklist.push_back(PNode);
+        continue;
+      }
+
+      // We can't handle other instructions.
+      auto *BCI = dyn_cast<BitCastInst>(IncValue);
+      if (!BCI)
+        return nullptr;
+
+      // Verify it's an A->B cast.
+      if (BCI->getOperand(0)->getType() != DestTy || BCI->getType() != SrcTy)
+        return nullptr;
+    }
+  }
+
+  // For each old PHI node, create a corresponding new PHI node with a type A.
+  SmallDenseMap<PHINode *, PHINode *> NewPNodes;
+  for (auto *OldPN : OldPhiNodes) {
+    Builder->SetInsertPoint(OldPN);
+    PHINode *NewPN = Builder->CreatePHI(DestTy, OldPN->getNumOperands());
+    NewPNodes[OldPN] = NewPN;
+  }
+
+  // Fill in the operands of new PHI nodes.
+  for (auto *OldPN : OldPhiNodes) {
+    PHINode *NewPN = NewPNodes[OldPN];
+    for (unsigned j = 0, e = OldPN->getNumOperands(); j != e; ++j) {
+      Value *V = OldPN->getOperand(j);
+      Value *NewV = nullptr;
+      if (auto *C = dyn_cast<Constant>(V)) {
+        NewV = Builder->CreateBitCast(C, DestTy);
+      } else if (auto *LI = dyn_cast<LoadInst>(V)) {
+        Builder->SetInsertPoint(OldPN->getIncomingBlock(j)->getTerminator());
+        NewV = Builder->CreateBitCast(LI, DestTy);
+        Worklist.Add(LI);
+      } else if (auto *BCI = dyn_cast<BitCastInst>(V)) {
+        NewV = BCI->getOperand(0);
+      } else if (auto *PrevPN = dyn_cast<PHINode>(V)) {
+        NewV = NewPNodes[PrevPN];
+      }
+      assert(NewV && "Unhandled incoming value kind");
+      NewPN->addIncoming(NewV, OldPN->getIncomingBlock(j));
+    }
+  }
+
+  // If there is a store with type B, change it to type A. Snapshot the users
+  // first: setOperand() removes the use from PN's use list, so rewriting a
+  // store while iterating PN->users() directly is not safe.
+  SmallVector<User *, 4> PNUsers(PN->user_begin(), PN->user_end());
+  for (User *U : PNUsers) {
+    auto *SI = dyn_cast<StoreInst>(U);
+    if (SI && SI->isSimple() && SI->getOperand(0) == PN) {
+      Builder->SetInsertPoint(SI);
+      SI->setOperand(0, Builder->CreateBitCast(NewPNodes[PN], SrcTy));
+      Worklist.Add(SI);
+    }
+  }
+
+  return replaceInstUsesWith(CI, NewPNodes[PN]);
+}
+
 Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
   // If the operands are integer typed then apply the integer transforms,
   // otherwise just apply the common ones.
@@ -1902,6 +2003,11 @@
     }
   }
 
+  // Handle the A->B->A cast, and there is an intervening PHI node.
+  if (PHINode *PN = dyn_cast<PHINode>(Src))
+    if (Instruction *I = optimizeBitCastFromPhi(CI, PN))
+      return I;
+
   if (Instruction *I = canonicalizeBitCastExtElt(CI, *this, DL))
     return I;
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 3bb50ba..b1e3881 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -394,6 +394,7 @@
   Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
   Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask);
   Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
+  Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
 
 public:
   /// \brief Inserts an instruction \p New before instruction \p Old