blob: 262299c5f3bbedc024de5ad4436cadf24a6e2b47 [file] [log] [blame]
Easwaran Ramanbdf20262018-01-09 19:39:35 +00001//===--- SyntheticCountsUtils.cpp - synthetic counts propagation utils ---===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// This file defines utilities for propagating synthetic counts.
11//
12//===----------------------------------------------------------------------===//
13
14#include "llvm/Analysis/SyntheticCountsUtils.h"
15#include "llvm/ADT/DenseSet.h"
16#include "llvm/ADT/SCCIterator.h"
17#include "llvm/ADT/SmallPtrSet.h"
18#include "llvm/Analysis/CallGraph.h"
19#include "llvm/IR/CallSite.h"
20#include "llvm/IR/Function.h"
21#include "llvm/IR/InstIterator.h"
22#include "llvm/IR/Instructions.h"
23
24using namespace llvm;
25
26// Given a set of functions in an SCC, propagate entry counts to functions
27// called by the SCC.
28static void
29propagateFromSCC(const SmallPtrSetImpl<Function *> &SCCFunctions,
30 function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
31 function_ref<uint64_t(Function *F)> GetCount,
32 function_ref<void(Function *F, uint64_t)> AddToCount) {
33
34 SmallVector<CallSite, 16> CallSites;
35
36 // Gather all callsites in the SCC.
37 auto GatherCallSites = [&]() {
38 for (auto *F : SCCFunctions) {
39 assert(F && !F->isDeclaration());
40 for (auto &I : instructions(F)) {
41 if (auto CS = CallSite(&I)) {
42 CallSites.push_back(CS);
43 }
44 }
45 }
46 };
47
48 GatherCallSites();
49
50 // Partition callsites so that the callsites that call functions in the same
51 // SCC come first.
52 auto Mid = partition(CallSites, [&](CallSite &CS) {
53 auto *Callee = CS.getCalledFunction();
54 if (Callee)
55 return SCCFunctions.count(Callee);
56 // FIXME: Use the !callees metadata to propagate counts through indirect
57 // calls.
58 return 0U;
59 });
60
61 // For functions in the same SCC, update the counts in two steps:
62 // 1. Compute the additional count for each function by propagating the counts
63 // along all incoming edges to the function that originate from the same SCC
64 // and summing them up.
65 // 2. Add the additional counts to the functions in the SCC.
66 // This ensures that the order of
67 // traversal of functions within the SCC doesn't change the final result.
68
69 DenseMap<Function *, uint64_t> AdditionalCounts;
70 for (auto It = CallSites.begin(); It != Mid; It++) {
71 auto &CS = *It;
72 auto RelFreq = GetCallSiteRelFreq(CS);
73 Function *Callee = CS.getCalledFunction();
74 Function *Caller = CS.getCaller();
75 RelFreq *= Scaled64(GetCount(Caller), 0);
76 uint64_t AdditionalCount = RelFreq.toInt<uint64_t>();
77 AdditionalCounts[Callee] += AdditionalCount;
78 }
79
80 // Update the counts for the functions in the SCC.
81 for (auto &Entry : AdditionalCounts)
82 AddToCount(Entry.first, Entry.second);
83
84 // Now update the counts for functions not in SCC.
85 for (auto It = Mid; It != CallSites.end(); It++) {
86 auto &CS = *It;
87 auto Weight = GetCallSiteRelFreq(CS);
88 Function *Callee = CS.getCalledFunction();
89 Function *Caller = CS.getCaller();
90 Weight *= Scaled64(GetCount(Caller), 0);
91 AddToCount(Callee, Weight.toInt<uint64_t>());
92 }
93}
94
95/// Propgate synthetic entry counts on a callgraph.
96///
97/// This performs a reverse post-order traversal of the callgraph SCC. For each
98/// SCC, it first propagates the entry counts to the functions within the SCC
99/// through call edges and updates them in one shot. Then the entry counts are
100/// propagated to functions outside the SCC.
101void llvm::propagateSyntheticCounts(
102 const CallGraph &CG, function_ref<Scaled64(CallSite CS)> GetCallSiteRelFreq,
103 function_ref<uint64_t(Function *F)> GetCount,
104 function_ref<void(Function *F, uint64_t)> AddToCount) {
105
106 SmallVector<SmallPtrSet<Function *, 8>, 16> SCCs;
107 for (auto I = scc_begin(&CG); !I.isAtEnd(); ++I) {
108 auto SCC = *I;
109
110 SmallPtrSet<Function *, 8> SCCFunctions;
111 for (auto *Node : SCC) {
112 Function *F = Node->getFunction();
113 if (F && !F->isDeclaration()) {
114 SCCFunctions.insert(F);
115 }
116 }
117 SCCs.push_back(SCCFunctions);
118 }
119
120 for (auto &SCCFunctions : reverse(SCCs))
121 propagateFromSCC(SCCFunctions, GetCallSiteRelFreq, GetCount, AddToCount);
122}