WholeProgramDevirt: generate more detailed and accurate remarks.

Summary:
Keep track of all methods for which we have devirtualized at least
one call and then print them sorted alphabetically. This avoids
duplicates and also makes the order deterministic.

Add optimization names to the remarks, so that it's easier to
understand how each method has been devirtualized.

Fix a bug where the wrong methods could be reported for
tryVirtualConstProp.

Reviewers: kcc, mehdi_amini

Differential Revision: https://reviews.llvm.org/D23297

llvm-svn: 278389
diff --git a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
index 640d994..e78665f 100644
--- a/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
+++ b/llvm/lib/Transforms/IPO/WholeProgramDevirt.cpp
@@ -235,15 +235,18 @@
   // of that field for details.
   unsigned *NumUnsafeUses;
 
-  void emitRemark() {
+  void emitRemark(const Twine &OptName, const Twine &TargetName) {
     Function *F = CS.getCaller();
-    emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F,
-                           CS.getInstruction()->getDebugLoc(),
-                           "devirtualized call");
+    emitOptimizationRemark(
+        F->getContext(), DEBUG_TYPE, *F,
+        CS.getInstruction()->getDebugLoc(),
+        OptName + ": devirtualized a call to " + TargetName);
   }
 
-  void replaceAndErase(Value *New) {
-    emitRemark();
+  void replaceAndErase(const Twine &OptName, const Twine &TargetName,
+                       bool RemarksEnabled, Value *New) {
+    if (RemarksEnabled)
+      emitRemark(OptName, TargetName);
     CS->replaceAllUsesWith(New);
     if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
       BranchInst::Create(II->getNormalDest(), CS.getInstruction());
@@ -262,6 +265,8 @@
   PointerType *Int8PtrTy;
   IntegerType *Int32Ty;
 
+  bool RemarksEnabled;
+
   MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;
 
   // This map keeps track of the number of "unsafe" uses of a loaded function
@@ -277,7 +282,10 @@
   DevirtModule(Module &M)
       : M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
         Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
-        Int32Ty(Type::getInt32Ty(M.getContext())) {}
+        Int32Ty(Type::getInt32Ty(M.getContext())),
+        RemarksEnabled(areRemarksEnabled()) {}
+
+  bool areRemarksEnabled();
 
   void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
   void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);
@@ -289,16 +297,16 @@
   tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                             const std::set<TypeMemberInfo> &TypeMemberInfos,
                             uint64_t ByteOffset);
-  bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot,
+  bool trySingleImplDevirt(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            MutableArrayRef<VirtualCallSite> CallSites);
   bool tryEvaluateFunctionsWithArgs(
       MutableArrayRef<VirtualCallTarget> TargetsForSlot,
       ArrayRef<ConstantInt *> Args);
   bool tryUniformRetValOpt(IntegerType *RetType,
-                           ArrayRef<VirtualCallTarget> TargetsForSlot,
+                           MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            MutableArrayRef<VirtualCallSite> CallSites);
   bool tryUniqueRetValOpt(unsigned BitWidth,
-                          ArrayRef<VirtualCallTarget> TargetsForSlot,
+                          MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
   bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                            ArrayRef<VirtualCallSite> CallSites);
@@ -413,7 +421,7 @@
 }
 
 bool DevirtModule::trySingleImplDevirt(
-    ArrayRef<VirtualCallTarget> TargetsForSlot,
+    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
     MutableArrayRef<VirtualCallSite> CallSites) {
   // See if the program contains a single implementation of this virtual
   // function.
@@ -422,9 +430,12 @@
     if (TheFn != Target.Fn)
       return false;
 
+  if (RemarksEnabled)
+    TargetsForSlot[0].WasDevirt = true;
   // If so, update each call site to call that implementation directly.
   for (auto &&VCallSite : CallSites) {
-    VCallSite.emitRemark();
+    if (RemarksEnabled)
+      VCallSite.emitRemark("single-impl", TheFn->getName());
     VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
         TheFn, VCallSite.CS.getCalledValue()->getType()));
     // This use is no longer unsafe.
@@ -462,7 +473,7 @@
 }
 
 bool DevirtModule::tryUniformRetValOpt(
-    IntegerType *RetType, ArrayRef<VirtualCallTarget> TargetsForSlot,
+    IntegerType *RetType, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
     MutableArrayRef<VirtualCallSite> CallSites) {
   // Uniform return value optimization. If all functions return the same
   // constant, replace all calls with that constant.
@@ -473,12 +484,16 @@
 
   auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
   for (auto Call : CallSites)
-    Call.replaceAndErase(TheRetValConst);
+    Call.replaceAndErase("uniform-ret-val", TargetsForSlot[0].Fn->getName(),
+                         RemarksEnabled, TheRetValConst);
+  if (RemarksEnabled)
+    for (auto &&Target : TargetsForSlot)
+      Target.WasDevirt = true;
   return true;
 }
 
 bool DevirtModule::tryUniqueRetValOpt(
-    unsigned BitWidth, ArrayRef<VirtualCallTarget> TargetsForSlot,
+    unsigned BitWidth, MutableArrayRef<VirtualCallTarget> TargetsForSlot,
     MutableArrayRef<VirtualCallSite> CallSites) {
   // IsOne controls whether we look for a 0 or a 1.
   auto tryUniqueRetValOptFor = [&](bool IsOne) {
@@ -502,8 +517,14 @@
       OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset);
       Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                                 Call.VTable, OneAddr);
-      Call.replaceAndErase(Cmp);
+      Call.replaceAndErase("unique-ret-val", TargetsForSlot[0].Fn->getName(),
+                           RemarksEnabled, Cmp);
     }
+    // Update devirtualization statistics for targets.
+    if (RemarksEnabled)
+      for (auto &&Target : TargetsForSlot)
+        Target.WasDevirt = true;
+
     return true;
   };
 
@@ -611,6 +632,10 @@
       setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                            OffsetBit);
 
+    if (RemarksEnabled)
+      for (auto &&Target : TargetsForSlot)
+        Target.WasDevirt = true;
+
     // Rewrite each call to a load from OffsetByte/OffsetBit.
     for (auto Call : CSByConstantArg.second) {
       IRBuilder<> B(Call.CS.getInstruction());
@@ -620,27 +645,21 @@
         Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
         Value *BitsAndBit = B.CreateAnd(Bits, Bit);
         auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
-        Call.replaceAndErase(IsBitSet);
+        Call.replaceAndErase("virtual-const-prop-1-bit",
+                             TargetsForSlot[0].Fn->getName(),
+                             RemarksEnabled, IsBitSet);
       } else {
         Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
         Value *Val = B.CreateLoad(RetType, ValAddr);
-        Call.replaceAndErase(Val);
+        Call.replaceAndErase("virtual-const-prop",
+                             TargetsForSlot[0].Fn->getName(),
+                             RemarksEnabled, Val);
       }
     }
   }
   return true;
 }
 
-static void emitTargetsRemarks(const std::vector<VirtualCallTarget> &TargetsForSlot) {
-  for (const VirtualCallTarget &Target : TargetsForSlot) {
-    Function *F = Target.Fn;
-    DISubprogram *SP = F->getSubprogram();
-    DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc();
-    emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL,
-                           std::string("devirtualized ") + F->getName().str());
-  }
-}
-
 void DevirtModule::rebuildGlobal(VTableBits &B) {
   if (B.Before.Bytes.empty() && B.After.Bytes.empty())
     return;
@@ -686,6 +705,15 @@
   B.GV->eraseFromParent();
 }
 
+bool DevirtModule::areRemarksEnabled() {
+  const auto &FL = M.getFunctionList();
+  if (FL.empty())
+    return false;
+  const Function &Fn = FL.front();
+  auto DI = DiagnosticInfoOptimizationRemark(DEBUG_TYPE, Fn, DebugLoc(), "");
+  return DI.isEnabled();
+}
+
 void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
                                      Function *AssumeFunc) {
   // Find all virtual calls via a virtual table pointer %p under an assumption
@@ -837,6 +865,7 @@
 
   // For each (type, offset) pair:
   bool DidVirtualConstProp = false;
+  std::map<std::string, Function*> DevirtTargets;
   for (auto &S : CallSlots) {
     // Search each of the members of the type identifier for the virtual
     // function implementation at offset S.first.ByteOffset, and add to
@@ -846,14 +875,25 @@
                                    S.first.ByteOffset))
       continue;
 
-    if (trySingleImplDevirt(TargetsForSlot, S.second)) {
-      emitTargetsRemarks(TargetsForSlot);
-      continue;
-    }
+    if (!trySingleImplDevirt(TargetsForSlot, S.second) &&
+        tryVirtualConstProp(TargetsForSlot, S.second))
+        DidVirtualConstProp = true;
 
-    if (tryVirtualConstProp(TargetsForSlot, S.second)) {
-      emitTargetsRemarks(TargetsForSlot);
-      DidVirtualConstProp = true;
+    // Collect functions devirtualized at least for one call site for stats.
+    if (RemarksEnabled)
+      for (const auto &T : TargetsForSlot)
+        if (T.WasDevirt)
+          DevirtTargets[T.Fn->getName()] = T.Fn;
+  }
+
+  if (RemarksEnabled) {
+    // Generate remarks for each devirtualized function.
+    for (const auto &DT : DevirtTargets) {
+      Function *F = DT.second;
+      DISubprogram *SP = F->getSubprogram();
+      DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc();
+      emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL,
+                             Twine("devirtualized ") + F->getName());
     }
   }