Add an explicit insert point argument to SplitBlockAndInsertIfThen.

Currently SplitBlockAndInsertIfThen requires that the branch condition be
an Instruction itself (the block is split right after it), which is very
inconvenient, because the condition is sometimes an Operator, or even a
Constant.

llvm-svn: 197677
diff --git a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
index f683bfb..b4c789c 100644
--- a/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/AddressSanitizer.cpp
@@ -617,7 +617,7 @@
 
     Value *Cmp = IRB.CreateICmpNE(Length,
                                   Constant::getNullValue(Length->getType()));
-    InsertBefore = SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), false);
+    InsertBefore = SplitBlockAndInsertIfThen(Cmp, InsertBefore, false);
   }
 
   instrumentMemIntrinsicParam(MI, Dst, Length, InsertBefore, true);
@@ -780,7 +780,7 @@
 
   if (ClAlwaysSlowPath || (TypeSize < 8 * Granularity)) {
     TerminatorInst *CheckTerm =
-        SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), false);
+        SplitBlockAndInsertIfThen(Cmp, InsertBefore, false);
     assert(dyn_cast<BranchInst>(CheckTerm)->isUnconditional());
     BasicBlock *NextBB = CheckTerm->getSuccessor(0);
     IRB.SetInsertPoint(CheckTerm);
@@ -791,7 +791,7 @@
     BranchInst *NewTerm = BranchInst::Create(CrashBlock, NextBB, Cmp2);
     ReplaceInstWithInst(CheckTerm, NewTerm);
   } else {
-    CrashTerm = SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), true);
+    CrashTerm = SplitBlockAndInsertIfThen(Cmp, InsertBefore, true);
   }
 
   Instruction *Crash = generateCrashCode(
@@ -1188,7 +1188,7 @@
   Load->setAtomic(Monotonic);
   Load->setAlignment(1);
   Value *Cmp = IRB.CreateICmpEQ(Constant::getNullValue(Int8Ty), Load);
-  Instruction *Ins = SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), false);
+  Instruction *Ins = SplitBlockAndInsertIfThen(Cmp, IP, false);
   IRB.SetInsertPoint(Ins);
   // We pass &F to __sanitizer_cov. We could avoid this and rely on
   // GET_CALLER_PC, but having the PC of the first instruction is just nice.
@@ -1448,8 +1448,7 @@
         kAsanOptionDetectUAR, IRB.getInt32Ty());
     Value *Cmp = IRB.CreateICmpNE(IRB.CreateLoad(OptionDetectUAR),
                                   Constant::getNullValue(IRB.getInt32Ty()));
-    Instruction *Term =
-        SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), false);
+    Instruction *Term = SplitBlockAndInsertIfThen(Cmp, InsBefore, false);
     BasicBlock *CmpBlock = cast<Instruction>(Cmp)->getParent();
     IRBuilder<> IRBIf(Term);
     LocalStackBase = IRBIf.CreateCall2(
@@ -1529,8 +1528,7 @@
         //     **SavedFlagPtr(LocalStackBase) = 0
         // FIXME: if LocalStackBase != OrigStackBase don't call poisonRedZones.
         Value *Cmp = IRBRet.CreateICmpNE(LocalStackBase, OrigStackBase);
-        TerminatorInst *PoisonTerm =
-            SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), false);
+        TerminatorInst *PoisonTerm = SplitBlockAndInsertIfThen(Cmp, Ret, false);
         IRBuilder<> IRBPoison(PoisonTerm);
         int ClassSize = kMinStackMallocSize << StackMallocIdx;
         SetShadowToStackAfterReturnInlined(IRBPoison, ShadowBase,
diff --git a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index c539be9..653ad7f 100644
--- a/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -735,10 +735,9 @@
         while (isa<PHINode>(Pos) || isa<AllocaInst>(Pos))
           Pos = Pos->getNextNode();
         IRBuilder<> IRB(Pos);
-        Instruction *NeInst = cast<Instruction>(
-            IRB.CreateICmpNE(*i, DFSF.DFS.ZeroShadow));
+        Value *Ne = IRB.CreateICmpNE(*i, DFSF.DFS.ZeroShadow);
         BranchInst *BI = cast<BranchInst>(SplitBlockAndInsertIfThen(
-            NeInst, /*Unreachable=*/ false, ColdCallWeights));
+            Ne, Pos, /*Unreachable=*/false, ColdCallWeights));
         IRBuilder<> ThenIRB(BI);
         ThenIRB.CreateCall(DFSF.DFS.DFSanNonzeroLabelFn);
       }
@@ -838,10 +837,9 @@
   IRBuilder<> IRB(Pos);
   BasicBlock *Head = Pos->getParent();
   Value *Ne = IRB.CreateICmpNE(V1, V2);
-  Instruction *NeInst = dyn_cast<Instruction>(Ne);
-  if (NeInst) {
+  if (Ne) {
     BranchInst *BI = cast<BranchInst>(SplitBlockAndInsertIfThen(
-        NeInst, /*Unreachable=*/ false, ColdCallWeights));
+        Ne, Pos, /*Unreachable=*/false, ColdCallWeights));
     IRBuilder<> ThenIRB(BI);
     CallInst *Call = ThenIRB.CreateCall2(DFSanUnionFn, V1, V2);
     Call->addAttribute(AttributeSet::ReturnIndex, Attribute::ZExt);
diff --git a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
index f2e1ab7..8a52a44 100644
--- a/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
+++ b/llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp
@@ -565,8 +565,7 @@
           Value *Cmp = IRB.CreateICmpNE(ConvertedShadow,
               getCleanShadow(ConvertedShadow), "_mscmp");
           Instruction *CheckTerm =
-            SplitBlockAndInsertIfThen(cast<Instruction>(Cmp), false,
-                                      MS.OriginStoreWeights);
+              SplitBlockAndInsertIfThen(Cmp, &I, false, MS.OriginStoreWeights);
           IRBuilder<> IRBNew(CheckTerm);
           IRBNew.CreateAlignedStore(getOrigin(Val), getOriginPtr(Addr, IRBNew),
                                     Alignment);
@@ -588,10 +587,9 @@
         continue;
       Value *Cmp = IRB.CreateICmpNE(ConvertedShadow,
                                     getCleanShadow(ConvertedShadow), "_mscmp");
-      Instruction *CheckTerm =
-        SplitBlockAndInsertIfThen(cast<Instruction>(Cmp),
-                                  /* Unreachable */ !ClKeepGoing,
-                                  MS.ColdCallWeights);
+      Instruction *CheckTerm = SplitBlockAndInsertIfThen(
+          Cmp, OrigIns,
+          /* Unreachable */ !ClKeepGoing, MS.ColdCallWeights);
 
       IRB.SetInsertPoint(CheckTerm);
       if (MS.TrackOrigins) {
@@ -629,7 +627,7 @@
             IRB.CreatePHI(Fn0->getType(), 2, "msandr.indirect_target");
 
         Instruction *CheckTerm = SplitBlockAndInsertIfThen(
-            cast<Instruction>(NotInThisModule),
+            NotInThisModule, NewFnPhi,
             /* Unreachable */ false, MS.ColdCallWeights);
 
         IRB.SetInsertPoint(CheckTerm);
diff --git a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
index 12de9ee..17bd115 100644
--- a/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
+++ b/llvm/lib/Transforms/Utils/BasicBlockUtils.cpp
@@ -630,28 +630,29 @@
 }
 
 /// SplitBlockAndInsertIfThen - Split the containing block at the
-/// specified instruction - everything before and including Cmp stays
-/// in the old basic block, and everything after Cmp is moved to a
+/// specified instruction - everything before SplitBefore stays in the old
+/// basic block, and SplitBefore and everything after it are moved to a
 /// new block. The two blocks are connected by a conditional branch
-/// (with value of Cmp being the condition).
+/// (with value of Cond being the condition).
 /// Before:
 ///   Head
-///   Cmp
+///   SplitBefore
 ///   Tail
 /// After:
 ///   Head
-///   Cmp
-///   if (Cmp)
+///   if (Cond)
 ///     ThenBlock
+///   SplitBefore
 ///   Tail
 ///
 /// If Unreachable is true, then ThenBlock ends with
 /// UnreachableInst, otherwise it branches to Tail.
 /// Returns the NewBasicBlock's terminator.
 
-TerminatorInst *llvm::SplitBlockAndInsertIfThen(Instruction *Cmp,
-    bool Unreachable, MDNode *BranchWeights) {
-  Instruction *SplitBefore = Cmp->getNextNode();
+TerminatorInst *llvm::SplitBlockAndInsertIfThen(Value *Cond,
+                                                Instruction *SplitBefore,
+                                                bool Unreachable,
+                                                MDNode *BranchWeights) {
   BasicBlock *Head = SplitBefore->getParent();
   BasicBlock *Tail = Head->splitBasicBlock(SplitBefore);
   TerminatorInst *HeadOldTerm = Head->getTerminator();
@@ -663,7 +664,7 @@
   else
     CheckTerm = BranchInst::Create(Tail, ThenBlock);
   BranchInst *HeadNewTerm =
-    BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/Tail, Cmp);
+    BranchInst::Create(/*ifTrue*/ThenBlock, /*ifFalse*/Tail, Cond);
   HeadNewTerm->setMetadata(LLVMContext::MD_prof, BranchWeights);
   ReplaceInstWithInst(HeadOldTerm, HeadNewTerm);
   return CheckTerm;