Add function importing info from samplepgo profile to the module summary.

Summary: For SamplePGO, the profile may contain cross-module inline stacks. Because profile annotation must happen when all the hot inline stacks are expanded, we need to pass this info to the module importer so that it can import the proper functions if necessary. This patch implements the feature by emitting cross-module targets as part of the function entry metadata. In the module-summary phase, the metadata is used to build call edges that point to the functions that need to be imported.

Reviewers: mehdi_amini, tejohnson

Reviewed By: tejohnson

Subscribers: davidxl, llvm-commits

Differential Revision: https://reviews.llvm.org/D30053

llvm-svn: 296498
diff --git a/llvm/docs/BranchWeightMetadata.rst b/llvm/docs/BranchWeightMetadata.rst
index 9e61d23..b941d0d 100644
--- a/llvm/docs/BranchWeightMetadata.rst
+++ b/llvm/docs/BranchWeightMetadata.rst
@@ -123,11 +123,11 @@
 optimization, ``MD_prof`` nodes can also be assigned to a function definition.
 The first operand is a string indicating the name of the associated counter.
 
-Currently, one counter is supported: "function_entry_count". This is a 64-bit
-counter that indicates the number of times that this function was invoked (in
-the case of instrumentation-based profiles). In the case of sampling-based
-profiles, this counter is an approximation of how many times the function was
-invoked.
+Currently, one counter is supported: "function_entry_count". The second operand
+is a 64-bit counter that indicates the number of times that this function was
+invoked (in the case of instrumentation-based profiles). In the case of
+sampling-based profiles, this operand is an approximation of how many times
+the function was invoked.
 
 For example, in the code below, the instrumentation for function foo()
 indicates that it was called 2,590 times at runtime.
@@ -138,3 +138,13 @@
     ret i32 0
   }
   !1 = !{!"function_entry_count", i64 2590}
+
+If "function_entry_count" has more than 2 operands, the later operands are
+the GUIDs of the functions that need to be imported by ThinLTO. They are only
+set by sampling-based profiles. They are needed because the sampling-based
+profile was collected on a binary that had already imported and inlined these
+functions, and we need to ensure the IR matches in the ThinLTO backends for
+profile annotation. The reason why we cannot annotate this on the callsite is
+that it can only go down 1 level in the call chain. For the cases where
+foo_in_a_cc()->bar_in_b_cc()->baz_in_c_cc(), we will need to go down 2 levels
+in the call chain to import both bar_in_b_cc and baz_in_c_cc.
diff --git a/llvm/include/llvm/IR/Function.h b/llvm/include/llvm/IR/Function.h
index 0c83916..f9e8fcc 100644
--- a/llvm/include/llvm/IR/Function.h
+++ b/llvm/include/llvm/IR/Function.h
@@ -18,6 +18,7 @@
 #ifndef LLVM_IR_FUNCTION_H
 #define LLVM_IR_FUNCTION_H
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/ilist_node.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/ADT/StringRef.h"
@@ -207,8 +208,11 @@
   /// \brief Set the entry count for this function.
   ///
   /// Entry count is the number of times this function was executed based on
-  /// pgo data.
-  void setEntryCount(uint64_t Count);
+  /// pgo data. \p Imports points to a set of GUIDs that need to be imported
+  /// by the function for sample PGO, to enable the same inlines as the
+  /// profiled optimized binary.
+  void setEntryCount(uint64_t Count,
+                     const DenseSet<GlobalValue::GUID> *Imports = nullptr);
 
   /// \brief Get the entry count for this function.
   ///
@@ -216,6 +220,10 @@
   /// pgo data.
   Optional<uint64_t> getEntryCount() const;
 
+  /// Returns the set of GUIDs that need to be imported into the function for
+  /// sample PGO, to enable the same inlines as the profiled optimized binary.
+  DenseSet<GlobalValue::GUID> getImportGUIDs() const;
+
   /// Set the section prefix for this function.
   void setSectionPrefix(StringRef Prefix);
 
diff --git a/llvm/include/llvm/IR/MDBuilder.h b/llvm/include/llvm/IR/MDBuilder.h
index bab8728..899976a 100644
--- a/llvm/include/llvm/IR/MDBuilder.h
+++ b/llvm/include/llvm/IR/MDBuilder.h
@@ -15,7 +15,9 @@
 #ifndef LLVM_IR_MDBUILDER_H
 #define LLVM_IR_MDBUILDER_H
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/IR/GlobalValue.h"
 #include "llvm/Support/DataTypes.h"
 #include <utility>
 
@@ -63,8 +65,11 @@
   /// Return metadata specifying that a branch or switch is unpredictable.
   MDNode *createUnpredictable();
 
-  /// Return metadata containing the entry count for a function.
-  MDNode *createFunctionEntryCount(uint64_t Count);
+  /// Return metadata containing the entry \p Count for a function, and the
+  /// GUIDs stored in \p Imports that need to be imported for sample PGO, to
+  /// enable the same inlines as the profiled optimized binary.
+  MDNode *createFunctionEntryCount(uint64_t Count,
+                                   const DenseSet<GlobalValue::GUID> *Imports);
 
   /// Return metadata containing the section prefix for a function.
   MDNode *createFunctionSectionPrefix(StringRef Prefix);
diff --git a/llvm/include/llvm/ProfileData/SampleProf.h b/llvm/include/llvm/ProfileData/SampleProf.h
index b286df3..1ae1cc4 100644
--- a/llvm/include/llvm/ProfileData/SampleProf.h
+++ b/llvm/include/llvm/ProfileData/SampleProf.h
@@ -15,8 +15,11 @@
 #ifndef LLVM_PROFILEDATA_SAMPLEPROF_H_
 #define LLVM_PROFILEDATA_SAMPLEPROF_H_
 
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringMap.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/Module.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorOr.h"
 #include "llvm/Support/raw_ostream.h"
@@ -300,6 +303,20 @@
     return Result;
   }
 
+  /// Walks this profile and, recursively, its callsite profiles, pruning any
+  /// subtree whose total samples do not exceed \p Threshold. Adds to \p S the
+  /// GUID of each visited function that \p M lacks or has no subprogram for.
+  void findImportedFunctions(DenseSet<GlobalValue::GUID> &S, const Module *M,
+                             uint64_t Threshold) const {
+    if (TotalSamples <= Threshold)
+      return;
+    Function *F = M->getFunction(Name);
+    if (!F || !F->getSubprogram())
+      S.insert(Function::getGUID(Name));
+    for (const auto &CS : CallsiteSamples)
+      CS.second.findImportedFunctions(S, M, Threshold);
+  }
+
   /// Set the name of the function.
   void setName(StringRef FunctionName) { Name = FunctionName; }
 
diff --git a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
index 950b898..9bc5503 100644
--- a/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
+++ b/llvm/lib/Analysis/ModuleSummaryAnalysis.cpp
@@ -259,6 +259,11 @@
       }
     }
 
+  // Explicitly add hot edges to enforce importing for designated GUIDs for
+  // sample PGO, to enable the same inlines as the profiled optimized binary.
+  for (auto &I : F.getImportGUIDs())
+    CallGraphEdges[I].updateHotness(CalleeInfo::HotnessType::Hot);
+
   bool NonRenamableLocal = isNonRenamableLocal(F);
   bool NotEligibleForImport =
       NonRenamableLocal || HasInlineAsmMaybeReferencingInternal ||
diff --git a/llvm/lib/IR/Function.cpp b/llvm/lib/IR/Function.cpp
index 0c6b352..7b5d49c 100644
--- a/llvm/lib/IR/Function.cpp
+++ b/llvm/lib/IR/Function.cpp
@@ -1259,9 +1259,10 @@
     setValueSubclassData(getSubclassDataFromValue() & ~(1 << Bit));
 }
 
-void Function::setEntryCount(uint64_t Count) {
+void Function::setEntryCount(uint64_t Count,
+                             const DenseSet<GlobalValue::GUID> *S) {
   MDBuilder MDB(getContext());
-  setMetadata(LLVMContext::MD_prof, MDB.createFunctionEntryCount(Count));
+  setMetadata(LLVMContext::MD_prof, MDB.createFunctionEntryCount(Count, S));
 }
 
 Optional<uint64_t> Function::getEntryCount() const {
@@ -1278,6 +1279,18 @@
   return None;
 }
 
+DenseSet<GlobalValue::GUID> Function::getImportGUIDs() const {
+  // Operands 2.. of "function_entry_count" hold the GUIDs; op 0 may be null.
+  DenseSet<GlobalValue::GUID> R;
+  if (MDNode *MD = getMetadata(LLVMContext::MD_prof))
+    if (MDString *MDS = dyn_cast_or_null<MDString>(MD->getOperand(0)))
+      if (MDS->getString().equals("function_entry_count"))
+        for (unsigned i = 2; i < MD->getNumOperands(); i++)
+          R.insert(mdconst::extract<ConstantInt>(MD->getOperand(i))
+                       ->getZExtValue());
+  return R;
+}
+
 void Function::setSectionPrefix(StringRef Prefix) {
   MDBuilder MDB(getContext());
   setMetadata(LLVMContext::MD_section_prefix,
diff --git a/llvm/lib/IR/MDBuilder.cpp b/llvm/lib/IR/MDBuilder.cpp
index f4bfd59..b9c4f48 100644
--- a/llvm/lib/IR/MDBuilder.cpp
+++ b/llvm/lib/IR/MDBuilder.cpp
@@ -56,11 +56,16 @@
   return MDNode::get(Context, None);
 }
 
-MDNode *MDBuilder::createFunctionEntryCount(uint64_t Count) {
+MDNode *MDBuilder::createFunctionEntryCount(
+    uint64_t Count, const DenseSet<GlobalValue::GUID> *Imports) {
   Type *Int64Ty = Type::getInt64Ty(Context);
-  return MDNode::get(Context,
-                     {createString("function_entry_count"),
-                      createConstant(ConstantInt::get(Int64Ty, Count))});
+  SmallVector<Metadata *, 8> Ops;
+  Ops.push_back(createString("function_entry_count"));
+  Ops.push_back(createConstant(ConstantInt::get(Int64Ty, Count)));
+  if (Imports)
+    for (auto ID : *Imports)
+      Ops.push_back(createConstant(ConstantInt::get(Int64Ty, ID)));
+  return MDNode::get(Context, Ops);
 }
 
 MDNode *MDBuilder::createFunctionSectionPrefix(StringRef Prefix) {
diff --git a/llvm/lib/IR/Verifier.cpp b/llvm/lib/IR/Verifier.cpp
index 5b3ec41..9001735 100644
--- a/llvm/lib/IR/Verifier.cpp
+++ b/llvm/lib/IR/Verifier.cpp
@@ -1653,8 +1653,8 @@
   for (const auto &Pair : MDs) {
     if (Pair.first == LLVMContext::MD_prof) {
       MDNode *MD = Pair.second;
-      Assert(MD->getNumOperands() == 2,
-             "!prof annotations should have exactly 2 operands", MD);
+      Assert(MD->getNumOperands() >= 2,
+             "!prof annotations should have no less than 2 operands", MD);
 
       // Check first operand.
       Assert(MD->getOperand(0) != nullptr, "first operand should not be null",
diff --git a/llvm/lib/Transforms/IPO/SampleProfile.cpp b/llvm/lib/Transforms/IPO/SampleProfile.cpp
index ad057d5..a9f5b03 100644
--- a/llvm/lib/Transforms/IPO/SampleProfile.cpp
+++ b/llvm/lib/Transforms/IPO/SampleProfile.cpp
@@ -163,7 +163,8 @@
   ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB);
   const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const;
   const FunctionSamples *findFunctionSamples(const Instruction &I) const;
-  bool inlineHotFunctions(Function &F);
+  bool inlineHotFunctions(Function &F,
+                          DenseSet<GlobalValue::GUID> &ImportGUIDs);
   void printEdgeWeight(raw_ostream &OS, Edge E);
   void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const;
   void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB);
@@ -604,9 +605,12 @@
 /// it to direct call. Each indirect call is limited with a single target.
 ///
 /// \param F function to perform iterative inlining.
+/// \param ImportGUIDs a set to be updated to include all GUIDs that come
+///     from a different module but were inlined in the profiled binary.
 ///
 /// \returns True if there is any inline happened.
-bool SampleProfileLoader::inlineHotFunctions(Function &F) {
+bool SampleProfileLoader::inlineHotFunctions(
+    Function &F, DenseSet<GlobalValue::GUID> &ImportGUIDs) {
   DenseSet<Instruction *> PromotedInsns;
   bool Changed = false;
   LLVMContext &Ctx = F.getContext();
@@ -655,8 +659,12 @@
           continue;
         }
       }
-      if (!CalledFunction || !CalledFunction->getSubprogram())
+      if (!CalledFunction || !CalledFunction->getSubprogram()) {
+        findCalleeFunctionSamples(*I)->findImportedFunctions(
+            ImportGUIDs, F.getParent(),
+            Samples->getTotalSamples() * SampleProfileHotThreshold / 100);
         continue;
+      }
       DebugLoc DLoc = I->getDebugLoc();
       uint64_t NumSamples = findCalleeFunctionSamples(*I)->getTotalSamples();
       if (InlineFunction(CallSite(DI), IFI)) {
@@ -1041,10 +1049,6 @@
   bool Changed = true;
   unsigned I = 0;
 
-  // Add an entry count to the function using the samples gathered
-  // at the function entry.
-  F.setEntryCount(Samples->getHeadSamples() + 1);
-
   // If BB weight is larger than its corresponding loop's header BB weight,
   // use the BB weight to replace the loop header BB weight.
   for (auto &BI : F) {
@@ -1273,12 +1277,19 @@
   DEBUG(dbgs() << "Line number for the first instruction in " << F.getName()
                << ": " << getFunctionLoc(F) << "\n");
 
-  Changed |= inlineHotFunctions(F);
+  DenseSet<GlobalValue::GUID> ImportGUIDs;
+  Changed |= inlineHotFunctions(F, ImportGUIDs);
 
   // Compute basic block weights.
   Changed |= computeBlockWeights(F);
 
   if (Changed) {
+    // Add an entry count to the function using the samples gathered at the
+    // function entry. Also set the GUIDs that come from a different
+    // module but were inlined in the profiled binary. This aims at making
+    // the IR match the profiled binary before annotation.
+    F.setEntryCount(Samples->getHeadSamples() + 1, &ImportGUIDs);
+
     // Compute dominance and loop info needed for propagation.
     computeDominanceAndLoopInfo(F);
 
diff --git a/llvm/test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll b/llvm/test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll
index 9e6e72c..a9f65c9 100644
--- a/llvm/test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll
+++ b/llvm/test/Bitcode/thinlto-function-summary-callgraph-profile-summary.ll
@@ -10,7 +10,7 @@
 ; CHECK-NEXT:    <VERSION
 ; See if the call to func is registered, using the expected callsite count
 ; and profile count, with value id matching the subsequent value symbol table.
-; CHECK-NEXT:    <PERMODULE_PROFILE {{.*}} op4=[[HOT1:.*]] op5=3 op6=[[COLD:.*]] op7=1 op8=[[HOT2:.*]] op9=3 op10=[[NONE1:.*]] op11=2 op12=[[HOT3:.*]] op13=3 op14=[[NONE2:.*]] op15=2 op16=[[NONE3:.*]] op17=2/>
+; CHECK-NEXT:    <PERMODULE_PROFILE {{.*}} op4=[[HOT1:.*]] op5=3 op6=[[COLD:.*]] op7=1 op8=[[HOT2:.*]] op9=3 op10=[[NONE1:.*]] op11=2 op12=[[HOT3:.*]] op13=3 op14=[[NONE2:.*]] op15=2 op16=[[NONE3:.*]] op17=2 op18=[[LEGACY:.*]] op19=3/>
 ; CHECK-NEXT:  </GLOBALVAL_SUMMARY_BLOCK>
 ; CHECK-LABEL:  <VALUE_SYMTAB
 ; CHECK-NEXT:       <FNENTRY {{.*}} record string = 'hot_function
@@ -21,6 +21,7 @@
 ; CHECK-DAG:        <ENTRY abbrevid=6 op0=[[HOT1]] {{.*}} record string = 'hot1'
 ; CHECK-DAG:        <ENTRY abbrevid=6 op0=[[HOT2]] {{.*}} record string = 'hot2'
 ; CHECK-DAG:        <ENTRY abbrevid=6 op0=[[HOT3]] {{.*}} record string = 'hot3'
+; CHECK-DAG:        <COMBINED_ENTRY abbrevid=11 op0=[[LEGACY]] op1=123/>
 ; CHECK-LABEL:  </VALUE_SYMTAB>
 
 ; COMBINED:       <GLOBALVAL_SUMMARY_BLOCK
@@ -80,7 +81,7 @@
 
 
 !llvm.module.flags = !{!1}
-!20 = !{!"function_entry_count", i64 110}
+!20 = !{!"function_entry_count", i64 110, i64 123}
 
 !1 = !{i32 1, !"ProfileSummary", !2}
 !2 = !{!3, !4, !5, !6, !7, !8, !9, !10}
diff --git a/llvm/test/Transforms/SampleProfile/Inputs/import.prof b/llvm/test/Transforms/SampleProfile/Inputs/import.prof
new file mode 100644
index 0000000..efadc0c
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/Inputs/import.prof
@@ -0,0 +1,4 @@
+main:10000:0
+ 3: foo:1000
+  3: bar:200
+   4: baz:10
diff --git a/llvm/test/Transforms/SampleProfile/import.ll b/llvm/test/Transforms/SampleProfile/import.ll
new file mode 100644
index 0000000..1ee45fb
--- /dev/null
+++ b/llvm/test/Transforms/SampleProfile/import.ll
@@ -0,0 +1,31 @@
+; RUN: opt < %s -sample-profile -sample-profile-file=%S/Inputs/import.prof -S | FileCheck %s
+
+; Tests whether the functions in the inline stack are added to the
+; function_entry_count metadata.
+
+declare void @foo()
+
+define void @main() !dbg !7 {
+  call void @foo(), !dbg !18
+  ret void
+}
+
+; GUIDs of foo and bar should be included in the metadata to make sure hot
+; inline stacks are imported.
+; CHECK: !{!"function_entry_count", i64 1, i64 6699318081062747564, i64 -2012135647395072713}
+
+!llvm.dbg.cu = !{!0}
+!llvm.module.flags = !{!8, !9}
+!llvm.ident = !{!10}
+
+!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus, producer: "clang version 3.5 ", isOptimized: false, emissionKind: NoDebug, file: !1, enums: !2, retainedTypes: !2, globals: !2, imports: !2)
+!1 = !DIFile(filename: "calls.cc", directory: ".")
+!2 = !{}
+!6 = !DISubroutineType(types: !2)
+!7 = distinct !DISubprogram(name: "main", line: 7, isLocal: false, isDefinition: true, virtualIndex: 6, flags: DIFlagPrototyped, isOptimized: false, unit: !0, scopeLine: 7, file: !1, scope: !1, type: !6, variables: !2)
+!8 = !{i32 2, !"Dwarf Version", i32 4}
+!9 = !{i32 1, !"Debug Info Version", i32 3}
+!10 = !{!"clang version 3.5 "}
+!15 = !DILexicalBlockFile(discriminator: 1, file: !1, scope: !7)
+!17 = distinct !DILexicalBlock(line: 10, column: 0, file: !1, scope: !7)
+!18 = !DILocation(line: 10, scope: !17)
diff --git a/llvm/test/Verifier/function-metadata-bad.ll b/llvm/test/Verifier/function-metadata-bad.ll
index 9e7ba22..b3bd3c2 100644
--- a/llvm/test/Verifier/function-metadata-bad.ll
+++ b/llvm/test/Verifier/function-metadata-bad.ll
@@ -14,7 +14,7 @@
 }
 
 !1 = !{!"function_entry_count"}
-; CHECK-NEXT: !prof annotations should have exactly 2 operands
+; CHECK-NEXT: !prof annotations should have no less than 2 operands
 ; CHECK-NEXT: !1 = !{!"function_entry_count"}
 
 
diff --git a/llvm/test/Verifier/metadata-function-prof.ll b/llvm/test/Verifier/metadata-function-prof.ll
index d84a7fe..70548b1 100644
--- a/llvm/test/Verifier/metadata-function-prof.ll
+++ b/llvm/test/Verifier/metadata-function-prof.ll
@@ -12,4 +12,4 @@
   unreachable
 }
 
-!0 = !{}
+!0 = !{!"function_entry_count", i64 100}