Re-apply r269081 and r269082 with a fix for MSVC.

llvm-svn: 269094
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index e7c161d..315de83 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -31,6 +31,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/MapVector.h"
+#include "llvm/Analysis/BitSetUtils.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
@@ -231,10 +232,6 @@
       : M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
         Int32Ty(Type::getInt32Ty(M.getContext())) {}
-  void findLoadCallsAtConstantOffset(Metadata *BitSet, Value *Ptr,
-                                     uint64_t Offset, Value *VTable);
-  void findCallsAtConstantOffset(Metadata *BitSet, Value *Ptr, uint64_t Offset,
-                                 Value *VTable);
 
   void buildBitSets(std::vector<VTableBits> &Bits,
                     DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets);
@@ -283,43 +280,6 @@
   return new WholeProgramDevirt;
 }
 
-// Search for virtual calls that call FPtr and add them to CallSlots.
-void DevirtModule::findCallsAtConstantOffset(Metadata *BitSet, Value *FPtr,
-                                             uint64_t Offset, Value *VTable) {
-  for (const Use &U : FPtr->uses()) {
-    Value *User = U.getUser();
-    if (isa<BitCastInst>(User)) {
-      findCallsAtConstantOffset(BitSet, User, Offset, VTable);
-    } else if (auto CI = dyn_cast<CallInst>(User)) {
-      CallSlots[{BitSet, Offset}].push_back({VTable, CI});
-    } else if (auto II = dyn_cast<InvokeInst>(User)) {
-      CallSlots[{BitSet, Offset}].push_back({VTable, II});
-    }
-  }
-}
-
-// Search for virtual calls that load from VPtr and add them to CallSlots.
-void DevirtModule::findLoadCallsAtConstantOffset(Metadata *BitSet, Value *VPtr,
-                                                 uint64_t Offset,
-                                                 Value *VTable) {
-  for (const Use &U : VPtr->uses()) {
-    Value *User = U.getUser();
-    if (isa<BitCastInst>(User)) {
-      findLoadCallsAtConstantOffset(BitSet, User, Offset, VTable);
-    } else if (isa<LoadInst>(User)) {
-      findCallsAtConstantOffset(BitSet, User, Offset, VTable);
-    } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) {
-      // Take into account the GEP offset.
-      if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) {
-        SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end());
-        uint64_t GEPOffset = M.getDataLayout().getIndexedOffsetInType(
-            GEP->getSourceElementType(), Indices);
-        findLoadCallsAtConstantOffset(BitSet, User, Offset + GEPOffset, VTable);
-      }
-    }
-  }
-}
-
 void DevirtModule::buildBitSets(
     std::vector<VTableBits> &Bits,
     DenseMap<Metadata *, std::set<BitSetInfo>> &BitSets) {
@@ -674,22 +634,24 @@
     if (!CI)
       continue;
 
-    // Find llvm.assume intrinsics for this llvm.bitset.test call.
+    // Search for virtual calls based on %p and add them to DevirtCalls.
+    SmallVector<DevirtCallSite, 1> DevirtCalls;
     SmallVector<CallInst *, 1> Assumes;
-    for (const Use &CIU : CI->uses()) {
-      auto AssumeCI = dyn_cast<CallInst>(CIU.getUser());
-      if (AssumeCI && AssumeCI->getCalledValue() == AssumeFunc)
-        Assumes.push_back(AssumeCI);
-    }
+    findDevirtualizableCalls(DevirtCalls, Assumes, CI);
 
-    // If we found any, search for virtual calls based on %p and add them to
-    // CallSlots.
+    // If we found any, add them to CallSlots. Only do this if we haven't seen
+    // the vtable pointer before, as it may have been CSE'd with pointers from
+    // other call sites, and we don't want to process call sites multiple times.
     if (!Assumes.empty()) {
       Metadata *BitSet =
           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
-      if (SeenPtrs.insert(Ptr).second)
-        findLoadCallsAtConstantOffset(BitSet, Ptr, 0, CI->getArgOperand(0));
+      if (SeenPtrs.insert(Ptr).second) {
+        for (const DevirtCallSite &Call : DevirtCalls) {
+          CallSlots[{BitSet, Call.Offset}].push_back(
+              {CI->getArgOperand(0), Call.CS});
+        }
+      }
     }
 
     // We no longer need the assumes or the bitset test.