blob: 5475342ecddf2c99f705f3919131cfda37341ea4 [file] [log] [blame]
Justin Bogneref512b92014-01-06 22:27:43 +00001//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Instrumentation-based profile-guided optimization
11//
12//===----------------------------------------------------------------------===//
13
14#include "CodeGenPGO.h"
15#include "CodeGenFunction.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
Justin Bogner529f6dd2014-01-07 03:43:15 +000018#include "llvm/Config/config.h" // for strtoull()/strtoll() define
Justin Bogneref512b92014-01-06 22:27:43 +000019#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/FileSystem.h"
21
22using namespace clang;
23using namespace CodeGen;
24
25static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26 DiagnosticsEngine &Diags = CGM.getDiags();
Alp Toker29cb66b2014-01-26 06:17:37 +000027 unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0");
28 Diags.Report(diagID) << Message;
Justin Bogneref512b92014-01-06 22:27:43 +000029}
30
31PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32 : CGM(CGM) {
33 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34 ReportBadPGOData(CGM, "failed to open pgo data file");
35 return;
36 }
37
38 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39 ReportBadPGOData(CGM, "pgo data file too big");
40 return;
41 }
42
43 // Scan through the data file and map each function to the corresponding
44 // file offset where its counts are stored.
45 const char *BufferStart = DataBuffer->getBufferStart();
46 const char *BufferEnd = DataBuffer->getBufferEnd();
47 const char *CurPtr = BufferStart;
Manman Ren67a28132014-02-05 20:40:15 +000048 uint64_t MaxCount = 0;
Justin Bogneref512b92014-01-06 22:27:43 +000049 while (CurPtr < BufferEnd) {
50 // Read the mangled function name.
51 const char *FuncName = CurPtr;
52 // FIXME: Something will need to be added to distinguish static functions.
53 CurPtr = strchr(CurPtr, ' ');
54 if (!CurPtr) {
55 ReportBadPGOData(CGM, "pgo data file has malformed function entry");
56 return;
57 }
58 StringRef MangledName(FuncName, CurPtr - FuncName);
59
60 // Read the number of counters.
61 char *EndPtr;
62 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
63 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
64 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
65 return;
66 }
67 CurPtr = EndPtr;
68
Manman Ren67a28132014-02-05 20:40:15 +000069 // Read function count.
70 uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
71 if (EndPtr == CurPtr || *EndPtr != '\n') {
72 ReportBadPGOData(CGM, "pgo-data file has bad count value");
73 return;
74 }
Manman Renf1a6a2d2014-02-15 01:29:02 +000075 CurPtr = EndPtr; // Point to '\n'.
Manman Ren67a28132014-02-05 20:40:15 +000076 FunctionCounts[MangledName] = Count;
77 MaxCount = Count > MaxCount ? Count : MaxCount;
78
Justin Bogneref512b92014-01-06 22:27:43 +000079 // There is one line for each counter; skip over those lines.
Manman Ren67a28132014-02-05 20:40:15 +000080 // Since function count is already read, we start the loop from 1.
81 for (unsigned N = 1; N < NumCounters; ++N) {
Justin Bogneref512b92014-01-06 22:27:43 +000082 CurPtr = strchr(++CurPtr, '\n');
83 if (!CurPtr) {
84 ReportBadPGOData(CGM, "pgo data file is missing some counter info");
85 return;
86 }
87 }
88
89 // Skip over the blank line separating functions.
90 CurPtr += 2;
91
92 DataOffsets[MangledName] = FuncName - BufferStart;
93 }
Manman Ren67a28132014-02-05 20:40:15 +000094 MaxFunctionCount = MaxCount;
95}
96
97/// Return true if a function is hot. If we know nothing about the function,
98/// return false.
99bool PGOProfileData::isHotFunction(StringRef MangledName) {
100 llvm::StringMap<uint64_t>::const_iterator CountIter =
101 FunctionCounts.find(MangledName);
102 // If we know nothing about the function, return false.
103 if (CountIter == FunctionCounts.end())
104 return false;
105 // FIXME: functions with >= 30% of the maximal function count are
106 // treated as hot. This number is from preliminary tuning on SPEC.
107 return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount);
108}
109
110/// Return true if a function is cold. If we know nothing about the function,
111/// return false.
112bool PGOProfileData::isColdFunction(StringRef MangledName) {
113 llvm::StringMap<uint64_t>::const_iterator CountIter =
114 FunctionCounts.find(MangledName);
115 // If we know nothing about the function, return false.
116 if (CountIter == FunctionCounts.end())
117 return false;
118 // FIXME: functions with <= 1% of the maximal function count are treated as
119 // cold. This number is from preliminary tuning on SPEC.
120 return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount);
Justin Bogneref512b92014-01-06 22:27:43 +0000121}
122
123bool PGOProfileData::getFunctionCounts(StringRef MangledName,
124 std::vector<uint64_t> &Counts) {
125 // Find the relevant section of the pgo-data file.
126 llvm::StringMap<unsigned>::const_iterator OffsetIter =
127 DataOffsets.find(MangledName);
128 if (OffsetIter == DataOffsets.end())
129 return true;
130 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
131
132 // Skip over the function name.
133 CurPtr = strchr(CurPtr, ' ');
134 assert(CurPtr && "pgo-data has corrupted function entry");
135
136 // Read the number of counters.
137 char *EndPtr;
138 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
139 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
140 "pgo-data file has corrupted number of counters");
141 CurPtr = EndPtr;
142
143 Counts.reserve(NumCounters);
144
145 for (unsigned N = 0; N < NumCounters; ++N) {
146 // Read the count value.
147 uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
148 if (EndPtr == CurPtr || *EndPtr != '\n') {
149 ReportBadPGOData(CGM, "pgo-data file has bad count value");
150 return true;
151 }
152 Counts.push_back(Count);
153 CurPtr = EndPtr + 1;
154 }
155
156 // Make sure the number of counters matches up.
157 if (Counts.size() != NumCounters) {
158 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
159 return true;
160 }
161
162 return false;
163}
164
165void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
166 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
167 return;
168
169 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
170
171 llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
172 llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
173
174 llvm::Function *WriteoutF =
175 CGM.getModule().getFunction("__llvm_pgo_writeout");
176 if (!WriteoutF) {
177 llvm::FunctionType *WriteoutFTy =
178 llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
179 WriteoutF = llvm::Function::Create(WriteoutFTy,
180 llvm::GlobalValue::InternalLinkage,
181 "__llvm_pgo_writeout", &CGM.getModule());
182 }
183 WriteoutF->setUnnamedAddr(true);
184 WriteoutF->addFnAttr(llvm::Attribute::NoInline);
185 if (CGM.getCodeGenOpts().DisableRedZone)
186 WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
187
188 llvm::BasicBlock *BB = WriteoutF->empty() ?
189 llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
190
191 CGBuilderTy PGOBuilder(BB);
192
193 llvm::Instruction *I = BB->getTerminator();
194 if (!I)
195 I = PGOBuilder.CreateRetVoid();
196 PGOBuilder.SetInsertPoint(I);
197
198 llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
199 llvm::Type *Args[] = {
200 Int8PtrTy, // const char *MangledName
201 Int32Ty, // uint32_t NumCounters
202 Int64PtrTy // uint64_t *Counters
203 };
204 llvm::FunctionType *FTy =
205 llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
206 llvm::Constant *EmitFunc =
207 CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
208
209 llvm::Constant *MangledName =
210 CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
211 MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
212 PGOBuilder.CreateCall3(EmitFunc, MangledName,
213 PGOBuilder.getInt32(NumRegionCounters),
214 PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
215}
216
217llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
218 llvm::Function *WriteoutF =
219 CGM.getModule().getFunction("__llvm_pgo_writeout");
220 if (!WriteoutF)
221 return NULL;
222
223 // Create a small bit of code that registers the "__llvm_pgo_writeout" to
224 // be executed at exit.
225 llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
226 if (F)
227 return NULL;
228
229 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
230 llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
231 false);
232 F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
233 "__llvm_pgo_init", &CGM.getModule());
234 F->setUnnamedAddr(true);
235 F->setLinkage(llvm::GlobalValue::InternalLinkage);
236 F->addFnAttr(llvm::Attribute::NoInline);
237 if (CGM.getCodeGenOpts().DisableRedZone)
238 F->addFnAttr(llvm::Attribute::NoRedZone);
239
240 llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
241 CGBuilderTy PGOBuilder(BB);
242
243 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
244 llvm::Type *Params[] = {
245 llvm::PointerType::get(FTy, 0)
246 };
247 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
248
249 // Inialize the environment and register the local writeout function.
250 llvm::Constant *PGOInit =
251 CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
252 PGOBuilder.CreateCall(PGOInit, WriteoutF);
253 PGOBuilder.CreateRetVoid();
254
255 return F;
256}
257
namespace {
  /// A StmtVisitor that fills a map of statements to PGO counters.
  ///
  /// Counter indices are handed out from NextCounter in the order the
  /// statements are visited, so the numbering is determined entirely by the
  /// traversal order implemented by the Visit* methods below. After a walk,
  /// NextCounter holds the total number of counters assigned.
  struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
    /// The next counter value to assign.
    unsigned NextCounter;
    /// The map of statements to counters.
    llvm::DenseMap<const Stmt*, unsigned> *CounterMap;

    /// Start numbering at zero and record assignments in \p CounterMap, which
    /// is stored by pointer and written through by the Visit* methods.
    MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
      NextCounter(0), CounterMap(CounterMap) {
    }

    /// Visit every non-null child of \p S.
    void VisitChildren(const Stmt *S) {
      for (Stmt::const_child_range I = S->children(); I; ++I)
        if (*I)
         this->Visit(*I);
    }
    /// Generic handler: assigns no counter to \p S itself and only walks its
    /// children.
    void VisitStmt(const Stmt *S) { VisitChildren(S); }

    /// Assign a counter to track entry to the function body.
    void VisitFunctionDecl(const FunctionDecl *S) {
      (*CounterMap)[S->getBody()] = NextCounter++;
      Visit(S->getBody());
    }
    /// Assign a counter to track the block following a label.
    void VisitLabelStmt(const LabelStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getSubStmt());
    }
    /// Assign three counters - one for the body of the loop, one for breaks
    /// from the loop, and one for continues.
    ///
    /// The break and continue counters cover all such statements in this loop,
    /// and are used in calculations to find the number of times the condition
    /// and exit of the loop occur. They are needed so we can differentiate
    /// these statements from non-local exits like return and goto.
    void VisitWhileStmt(const WhileStmt *S) {
      // The map stores the first of the three consecutive indices.
      (*CounterMap)[S] = NextCounter;
      NextCounter += 3;
      Visit(S->getCond());
      Visit(S->getBody());
    }
    /// Assign counters for the body of the loop, and for breaks and
    /// continues. See VisitWhileStmt.
    void VisitDoStmt(const DoStmt *S) {
      (*CounterMap)[S] = NextCounter;
      NextCounter += 3;
      // The body is visited before the condition here.
      Visit(S->getBody());
      Visit(S->getCond());
    }
    /// Assign counters for the body of the loop, and for breaks and
    /// continues. See VisitWhileStmt.
    void VisitForStmt(const ForStmt *S) {
      (*CounterMap)[S] = NextCounter;
      NextCounter += 3;
      // Condition and increment are only visited when present.
      // NOTE(review): the init statement is not visited here.
      const Expr *E;
      if ((E = S->getCond()))
        Visit(E);
      Visit(S->getBody());
      if ((E = S->getInc()))
        Visit(E);
    }
    /// Assign counters for the body of the loop, and for breaks and
    /// continues. See VisitWhileStmt.
    void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
      (*CounterMap)[S] = NextCounter;
      NextCounter += 3;
      // Only the condition, body and increment are visited, each of the
      // expressions only when present.
      const Expr *E;
      if ((E = S->getCond()))
        Visit(E);
      Visit(S->getBody());
      if ((E = S->getInc()))
        Visit(E);
    }
    /// Assign counters for the body of the loop, and for breaks and
    /// continues. See VisitWhileStmt.
    void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
      (*CounterMap)[S] = NextCounter;
      NextCounter += 3;
      Visit(S->getElement());
      Visit(S->getBody());
    }
    /// Assign a counter for the exit block of the switch statement.
    void VisitSwitchStmt(const SwitchStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getCond());
      Visit(S->getBody());
    }
    /// Assign a counter for a particular case in a switch. This counts jumps
    /// from the switch header as well as fallthrough from the case before this
    /// one.
    void VisitCaseStmt(const CaseStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getSubStmt());
    }
    /// Assign a counter for the default case of a switch statement. The count
    /// is the number of branches from the switch header to the default, and
    /// does not include fallthrough from previous cases. If we have multiple
    /// conditional branch blocks from the switch instruction to the default
    /// block, as with large GNU case ranges, this is the counter for the last
    /// edge in that series, rather than the first.
    void VisitDefaultStmt(const DefaultStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getSubStmt());
    }
    /// Assign a counter for the "then" part of an if statement. The count for
    /// the "else" part, if it exists, will be calculated from this counter.
    void VisitIfStmt(const IfStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getCond());
      Visit(S->getThen());
      if (S->getElse())
        Visit(S->getElse());
    }
    /// Assign a counter for the continuation block of a C++ try statement.
    void VisitCXXTryStmt(const CXXTryStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getTryBlock());
      // Handlers are visited in index order after the try block.
      for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
        Visit(S->getHandler(I));
    }
    /// Assign a counter for a catch statement's handler block.
    void VisitCXXCatchStmt(const CXXCatchStmt *S) {
      (*CounterMap)[S] = NextCounter++;
      Visit(S->getHandlerBlock());
    }
    /// Assign a counter for the "true" part of a conditional operator. The
    /// count in the "false" part will be calculated from this counter.
    void VisitConditionalOperator(const ConditionalOperator *E) {
      (*CounterMap)[E] = NextCounter++;
      Visit(E->getCond());
      Visit(E->getTrueExpr());
      Visit(E->getFalseExpr());
    }
    /// Assign a counter for the right hand side of a logical and operator.
    void VisitBinLAnd(const BinaryOperator *E) {
      (*CounterMap)[E] = NextCounter++;
      Visit(E->getLHS());
      Visit(E->getRHS());
    }
    /// Assign a counter for the right hand side of a logical or operator.
    void VisitBinLOr(const BinaryOperator *E) {
      (*CounterMap)[E] = NextCounter++;
      Visit(E->getLHS());
      Visit(E->getRHS());
    }
  };
}
406
407void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
408 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
409 PGOProfileData *PGOData = CGM.getPGOData();
410 if (!InstrumentRegions && !PGOData)
411 return;
412 const Decl *D = GD.getDecl();
413 if (!D)
414 return;
415 mapRegionCounters(D);
416 if (InstrumentRegions)
417 emitCounterVariables();
418 if (PGOData)
419 loadRegionCounts(GD, PGOData);
420}
421
422void CodeGenPGO::mapRegionCounters(const Decl *D) {
423 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
424 MapRegionCounters Walker(RegionCounterMap);
425 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
426 Walker.VisitFunctionDecl(FD);
427 NumRegionCounters = Walker.NextCounter;
428}
429
430void CodeGenPGO::emitCounterVariables() {
431 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
432 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
433 NumRegionCounters);
434 RegionCounters =
435 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
436 llvm::GlobalVariable::PrivateLinkage,
437 llvm::Constant::getNullValue(CounterTy),
438 "__llvm_pgo_ctr");
439}
440
441void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
442 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
443 return;
444 llvm::Value *Addr =
445 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
446 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
447 Count = Builder.CreateAdd(Count, Builder.getInt64(1));
448 Builder.CreateStore(Count, Addr);
449}
450
451void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
452 // For now, ignore the counts from the PGO data file only if the number of
453 // counters does not match. This could be tightened down in the future to
454 // ignore counts when the input changes in various ways, e.g., by comparing a
455 // hash value based on some characteristics of the input.
456 RegionCounts = new std::vector<uint64_t>();
457 if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
458 RegionCounts->size() != NumRegionCounters) {
459 delete RegionCounts;
460 RegionCounts = 0;
461 }
462}
463
464void CodeGenPGO::destroyRegionCounters() {
465 if (RegionCounterMap != 0)
466 delete RegionCounterMap;
467 if (RegionCounts != 0)
468 delete RegionCounts;
469}
470
471llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
472 uint64_t FalseCount) {
473 if (!TrueCount && !FalseCount)
474 return 0;
475
476 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
477 // TODO: need to scale down to 32-bits
478 // According to Laplace's Rule of Succession, it is better to compute the
479 // weight based on the count plus 1.
480 return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
481}
482
483llvm::MDNode *
484CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
485 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
486 // TODO: need to scale down to 32-bits, instead of just truncating.
487 // According to Laplace's Rule of Succession, it is better to compute the
488 // weight based on the count plus 1.
489 SmallVector<uint32_t, 16> ScaledWeights;
490 ScaledWeights.reserve(Weights.size());
491 for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
492 WI != WE; ++WI) {
493 ScaledWeights.push_back(*WI + 1);
494 }
495 return MDHelper.createBranchWeights(ScaledWeights);
496}