[WPD] Fix incorrect devirtualization after indirect call promotion

Summary:
Add a dominance check to ensure that the possibly devirtualizable
call is actually dominated by the type test/checked load intrinsic being
analyzed. With PGO, after indirect call promotion is performed during
the compile step, followed by inlining, we may have a type test in the
promoted and inlined sequence that allows an indirect call in that
sequence to be devirtualized. That indirect call (inserted by inlining
after promotion) will share the same vtable pointer as the fallback
indirect call that cannot be devirtualized.

Before this patch the code was incorrectly devirtualizing the fallback
indirect call.

See the new test and the example described there for more details.

Reviewers: pcc, vitalybuka

Subscribers: mehdi_amini, Prazek, eraman, steven_wu, dexonsmith, llvm-commits

Differential Revision: https://reviews.llvm.org/D52514

llvm-svn: 343226
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 9e080bb..b8f68d4 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -58,6 +58,7 @@
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DebugLoc.h"
 #include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/GlobalVariable.h"
@@ -406,6 +407,7 @@
 struct DevirtModule {
   Module &M;
   function_ref<AAResults &(Function &)> AARGetter;
+  function_ref<DominatorTree &(Function &)> LookupDomTree;
 
   ModuleSummaryIndex *ExportSummary;
   const ModuleSummaryIndex *ImportSummary;
@@ -433,10 +435,12 @@
 
   DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter,
                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
+               function_ref<DominatorTree &(Function &)> LookupDomTree,
                ModuleSummaryIndex *ExportSummary,
                const ModuleSummaryIndex *ImportSummary)
-      : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary),
-        ImportSummary(ImportSummary), Int8Ty(Type::getInt8Ty(M.getContext())),
+      : M(M), AARGetter(AARGetter), LookupDomTree(LookupDomTree),
+        ExportSummary(ExportSummary), ImportSummary(ImportSummary),
+        Int8Ty(Type::getInt8Ty(M.getContext())),
         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
         Int32Ty(Type::getInt32Ty(M.getContext())),
         Int64Ty(Type::getInt64Ty(M.getContext())),
@@ -533,9 +537,10 @@
 
   // Lower the module using the action and summary passed as command line
   // arguments. For testing purposes only.
-  static bool runForTesting(
-      Module &M, function_ref<AAResults &(Function &)> AARGetter,
-      function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter);
+  static bool
+  runForTesting(Module &M, function_ref<AAResults &(Function &)> AARGetter,
+                function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
+                function_ref<DominatorTree &(Function &)> LookupDomTree);
 };
 
 struct WholeProgramDevirt : public ModulePass {
@@ -572,17 +577,23 @@
       return *ORE;
     };
 
-    if (UseCommandLine)
-      return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter);
+    auto LookupDomTree = [this](Function &F) -> DominatorTree & {
+      return this->getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
+    };
 
-    return DevirtModule(M, LegacyAARGetter(*this), OREGetter, ExportSummary,
-                        ImportSummary)
+    if (UseCommandLine)
+      return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter,
+                                         LookupDomTree);
+
+    return DevirtModule(M, LegacyAARGetter(*this), OREGetter, LookupDomTree,
+                        ExportSummary, ImportSummary)
         .run();
   }
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.addRequired<AssumptionCacheTracker>();
     AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
   }
 };
 
@@ -592,6 +603,7 @@
                       "Whole program devirtualization", false, false)
 INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_END(WholeProgramDevirt, "wholeprogramdevirt",
                     "Whole program devirtualization", false, false)
 char WholeProgramDevirt::ID = 0;
@@ -611,7 +623,11 @@
   auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & {
     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
   };
-  if (!DevirtModule(M, AARGetter, OREGetter, ExportSummary, ImportSummary)
+  auto LookupDomTree = [&FAM](Function &F) -> DominatorTree & {
+    return FAM.getResult<DominatorTreeAnalysis>(F);
+  };
+  if (!DevirtModule(M, AARGetter, OREGetter, LookupDomTree, ExportSummary,
+                    ImportSummary)
            .run())
     return PreservedAnalyses::all();
   return PreservedAnalyses::none();
@@ -619,7 +635,8 @@
 
 bool DevirtModule::runForTesting(
     Module &M, function_ref<AAResults &(Function &)> AARGetter,
-    function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) {
+    function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter,
+    function_ref<DominatorTree &(Function &)> LookupDomTree) {
   ModuleSummaryIndex Summary(/*HaveGVs=*/false);
 
   // Handle the command-line summary arguments. This code is for testing
@@ -637,7 +654,7 @@
 
   bool Changed =
       DevirtModule(
-          M, AARGetter, OREGetter,
+          M, AARGetter, OREGetter, LookupDomTree,
           ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr,
           ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr)
           .run();
@@ -1342,7 +1359,7 @@
   // points to a member of the type identifier %md. Group calls by (type ID,
   // offset) pair (effectively the identity of the virtual function) and store
   // to CallSlots.
-  DenseSet<Value *> SeenPtrs;
+  DenseSet<CallSite> SeenCallSites;
   for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
        I != E;) {
     auto CI = dyn_cast<CallInst>(I->getUser());
@@ -1353,19 +1370,22 @@
     // Search for virtual calls based on %p and add them to DevirtCalls.
     SmallVector<DevirtCallSite, 1> DevirtCalls;
     SmallVector<CallInst *, 1> Assumes;
-    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);
+    auto &DT = LookupDomTree(*CI->getFunction());
+    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT);
 
-    // If we found any, add them to CallSlots. Only do this if we haven't seen
-    // the vtable pointer before, as it may have been CSE'd with pointers from
-    // other call sites, and we don't want to process call sites multiple times.
+    // If we found any, add them to CallSlots.
     if (!Assumes.empty()) {
       Metadata *TypeId =
           cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
       Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
-      if (SeenPtrs.insert(Ptr).second) {
-        for (DevirtCallSite Call : DevirtCalls) {
+      for (DevirtCallSite Call : DevirtCalls) {
+        // Only add this CallSite if we haven't seen it before. The vtable
+        // pointer may have been CSE'd with pointers from other call sites,
+        // and we don't want to process call sites multiple times. We can't
+        // just skip the vtable Ptr if it has been seen before, however, since
+        // it may be shared by type tests that dominate different calls.
+        if (SeenCallSites.insert(Call.CS).second)
           CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr);
-        }
       }
     }
 
@@ -1399,8 +1419,9 @@
     SmallVector<Instruction *, 1> LoadedPtrs;
     SmallVector<Instruction *, 1> Preds;
     bool HasNonCallUses = false;
+    auto &DT = LookupDomTree(*CI->getFunction());
     findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
-                                               HasNonCallUses, CI);
+                                               HasNonCallUses, CI, DT);
 
     // Start by generating "pessimistic" code that explicitly loads the function
     // pointer from the vtable and performs the type check. If possible, we will