blob: 38bb4c00aa389681eba9dc8b2ff9bed6e96812c9 [file] [log] [blame]
Justin Bogneref512b92014-01-06 22:27:43 +00001//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Instrumentation-based profile-guided optimization
11//
12//===----------------------------------------------------------------------===//
13
14#include "CodeGenPGO.h"
15#include "CodeGenFunction.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
Justin Bogner529f6dd2014-01-07 03:43:15 +000018#include "llvm/Config/config.h" // for strtoull()/strtoll() define
Justin Bogneref512b92014-01-06 22:27:43 +000019#include "llvm/IR/MDBuilder.h"
20#include "llvm/Support/FileSystem.h"
21
22using namespace clang;
23using namespace CodeGen;
24
25static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
26 DiagnosticsEngine &Diags = CGM.getDiags();
Alp Toker29cb66b2014-01-26 06:17:37 +000027 unsigned diagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, "%0");
28 Diags.Report(diagID) << Message;
Justin Bogneref512b92014-01-06 22:27:43 +000029}
30
31PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
32 : CGM(CGM) {
33 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
34 ReportBadPGOData(CGM, "failed to open pgo data file");
35 return;
36 }
37
38 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
39 ReportBadPGOData(CGM, "pgo data file too big");
40 return;
41 }
42
43 // Scan through the data file and map each function to the corresponding
44 // file offset where its counts are stored.
45 const char *BufferStart = DataBuffer->getBufferStart();
46 const char *BufferEnd = DataBuffer->getBufferEnd();
47 const char *CurPtr = BufferStart;
Manman Ren67a28132014-02-05 20:40:15 +000048 uint64_t MaxCount = 0;
Justin Bogneref512b92014-01-06 22:27:43 +000049 while (CurPtr < BufferEnd) {
50 // Read the mangled function name.
51 const char *FuncName = CurPtr;
52 // FIXME: Something will need to be added to distinguish static functions.
53 CurPtr = strchr(CurPtr, ' ');
54 if (!CurPtr) {
55 ReportBadPGOData(CGM, "pgo data file has malformed function entry");
56 return;
57 }
58 StringRef MangledName(FuncName, CurPtr - FuncName);
59
60 // Read the number of counters.
61 char *EndPtr;
62 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
63 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
64 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
65 return;
66 }
67 CurPtr = EndPtr;
68
Manman Ren67a28132014-02-05 20:40:15 +000069 // Read function count.
70 uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
71 if (EndPtr == CurPtr || *EndPtr != '\n') {
72 ReportBadPGOData(CGM, "pgo-data file has bad count value");
73 return;
74 }
Manman Renf1a6a2d2014-02-15 01:29:02 +000075 CurPtr = EndPtr; // Point to '\n'.
Manman Ren67a28132014-02-05 20:40:15 +000076 FunctionCounts[MangledName] = Count;
77 MaxCount = Count > MaxCount ? Count : MaxCount;
78
Justin Bogneref512b92014-01-06 22:27:43 +000079 // There is one line for each counter; skip over those lines.
Manman Ren67a28132014-02-05 20:40:15 +000080 // Since function count is already read, we start the loop from 1.
81 for (unsigned N = 1; N < NumCounters; ++N) {
Justin Bogneref512b92014-01-06 22:27:43 +000082 CurPtr = strchr(++CurPtr, '\n');
83 if (!CurPtr) {
84 ReportBadPGOData(CGM, "pgo data file is missing some counter info");
85 return;
86 }
87 }
88
89 // Skip over the blank line separating functions.
90 CurPtr += 2;
91
92 DataOffsets[MangledName] = FuncName - BufferStart;
93 }
Manman Ren67a28132014-02-05 20:40:15 +000094 MaxFunctionCount = MaxCount;
95}
96
97/// Return true if a function is hot. If we know nothing about the function,
98/// return false.
99bool PGOProfileData::isHotFunction(StringRef MangledName) {
100 llvm::StringMap<uint64_t>::const_iterator CountIter =
101 FunctionCounts.find(MangledName);
102 // If we know nothing about the function, return false.
103 if (CountIter == FunctionCounts.end())
104 return false;
105 // FIXME: functions with >= 30% of the maximal function count are
106 // treated as hot. This number is from preliminary tuning on SPEC.
107 return CountIter->getValue() >= (uint64_t)(0.3 * (double)MaxFunctionCount);
108}
109
110/// Return true if a function is cold. If we know nothing about the function,
111/// return false.
112bool PGOProfileData::isColdFunction(StringRef MangledName) {
113 llvm::StringMap<uint64_t>::const_iterator CountIter =
114 FunctionCounts.find(MangledName);
115 // If we know nothing about the function, return false.
116 if (CountIter == FunctionCounts.end())
117 return false;
118 // FIXME: functions with <= 1% of the maximal function count are treated as
119 // cold. This number is from preliminary tuning on SPEC.
120 return CountIter->getValue() <= (uint64_t)(0.01 * (double)MaxFunctionCount);
Justin Bogneref512b92014-01-06 22:27:43 +0000121}
122
123bool PGOProfileData::getFunctionCounts(StringRef MangledName,
124 std::vector<uint64_t> &Counts) {
125 // Find the relevant section of the pgo-data file.
126 llvm::StringMap<unsigned>::const_iterator OffsetIter =
127 DataOffsets.find(MangledName);
128 if (OffsetIter == DataOffsets.end())
129 return true;
130 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
131
132 // Skip over the function name.
133 CurPtr = strchr(CurPtr, ' ');
134 assert(CurPtr && "pgo-data has corrupted function entry");
135
136 // Read the number of counters.
137 char *EndPtr;
138 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
139 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
140 "pgo-data file has corrupted number of counters");
141 CurPtr = EndPtr;
142
143 Counts.reserve(NumCounters);
144
145 for (unsigned N = 0; N < NumCounters; ++N) {
146 // Read the count value.
147 uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
148 if (EndPtr == CurPtr || *EndPtr != '\n') {
149 ReportBadPGOData(CGM, "pgo-data file has bad count value");
150 return true;
151 }
152 Counts.push_back(Count);
153 CurPtr = EndPtr + 1;
154 }
155
156 // Make sure the number of counters matches up.
157 if (Counts.size() != NumCounters) {
158 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
159 return true;
160 }
161
162 return false;
163}
164
165void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
166 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
167 return;
168
169 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
170
171 llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
172 llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
173
174 llvm::Function *WriteoutF =
175 CGM.getModule().getFunction("__llvm_pgo_writeout");
176 if (!WriteoutF) {
177 llvm::FunctionType *WriteoutFTy =
178 llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
179 WriteoutF = llvm::Function::Create(WriteoutFTy,
180 llvm::GlobalValue::InternalLinkage,
181 "__llvm_pgo_writeout", &CGM.getModule());
182 }
183 WriteoutF->setUnnamedAddr(true);
184 WriteoutF->addFnAttr(llvm::Attribute::NoInline);
185 if (CGM.getCodeGenOpts().DisableRedZone)
186 WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
187
188 llvm::BasicBlock *BB = WriteoutF->empty() ?
189 llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
190
191 CGBuilderTy PGOBuilder(BB);
192
193 llvm::Instruction *I = BB->getTerminator();
194 if (!I)
195 I = PGOBuilder.CreateRetVoid();
196 PGOBuilder.SetInsertPoint(I);
197
198 llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
199 llvm::Type *Args[] = {
200 Int8PtrTy, // const char *MangledName
201 Int32Ty, // uint32_t NumCounters
202 Int64PtrTy // uint64_t *Counters
203 };
204 llvm::FunctionType *FTy =
205 llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
206 llvm::Constant *EmitFunc =
207 CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
208
209 llvm::Constant *MangledName =
210 CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
211 MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
212 PGOBuilder.CreateCall3(EmitFunc, MangledName,
213 PGOBuilder.getInt32(NumRegionCounters),
214 PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
215}
216
217llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
218 llvm::Function *WriteoutF =
219 CGM.getModule().getFunction("__llvm_pgo_writeout");
220 if (!WriteoutF)
221 return NULL;
222
223 // Create a small bit of code that registers the "__llvm_pgo_writeout" to
224 // be executed at exit.
225 llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
226 if (F)
227 return NULL;
228
229 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
230 llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
231 false);
232 F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
233 "__llvm_pgo_init", &CGM.getModule());
234 F->setUnnamedAddr(true);
235 F->setLinkage(llvm::GlobalValue::InternalLinkage);
236 F->addFnAttr(llvm::Attribute::NoInline);
237 if (CGM.getCodeGenOpts().DisableRedZone)
238 F->addFnAttr(llvm::Attribute::NoRedZone);
239
240 llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
241 CGBuilderTy PGOBuilder(BB);
242
243 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
244 llvm::Type *Params[] = {
245 llvm::PointerType::get(FTy, 0)
246 };
247 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
248
249 // Inialize the environment and register the local writeout function.
250 llvm::Constant *PGOInit =
251 CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
252 PGOBuilder.CreateCall(PGOInit, WriteoutF);
253 PGOBuilder.CreateRetVoid();
254
255 return F;
256}
257
258namespace {
259 /// A StmtVisitor that fills a map of statements to PGO counters.
260 struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
261 /// The next counter value to assign.
262 unsigned NextCounter;
263 /// The map of statements to counters.
264 llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
265
266 MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
267 NextCounter(0), CounterMap(CounterMap) {
268 }
269
270 void VisitChildren(const Stmt *S) {
271 for (Stmt::const_child_range I = S->children(); I; ++I)
272 if (*I)
273 this->Visit(*I);
274 }
275 void VisitStmt(const Stmt *S) { VisitChildren(S); }
276
Justin Bognerea278c32014-01-07 00:20:28 +0000277 /// Assign a counter to track entry to the function body.
Justin Bogneref512b92014-01-06 22:27:43 +0000278 void VisitFunctionDecl(const FunctionDecl *S) {
279 (*CounterMap)[S->getBody()] = NextCounter++;
280 Visit(S->getBody());
281 }
Justin Bognerea278c32014-01-07 00:20:28 +0000282 /// Assign a counter to track the block following a label.
Justin Bogneref512b92014-01-06 22:27:43 +0000283 void VisitLabelStmt(const LabelStmt *S) {
284 (*CounterMap)[S] = NextCounter++;
285 Visit(S->getSubStmt());
286 }
Bob Wilsonbf854f02014-02-17 19:21:09 +0000287 /// Assign a counter for the body of a while loop.
Justin Bogneref512b92014-01-06 22:27:43 +0000288 void VisitWhileStmt(const WhileStmt *S) {
Bob Wilsonbf854f02014-02-17 19:21:09 +0000289 (*CounterMap)[S] = NextCounter++;
Justin Bogneref512b92014-01-06 22:27:43 +0000290 Visit(S->getCond());
291 Visit(S->getBody());
292 }
Bob Wilsonbf854f02014-02-17 19:21:09 +0000293 /// Assign a counter for the body of a do-while loop.
Justin Bogneref512b92014-01-06 22:27:43 +0000294 void VisitDoStmt(const DoStmt *S) {
Bob Wilsonbf854f02014-02-17 19:21:09 +0000295 (*CounterMap)[S] = NextCounter++;
Justin Bogneref512b92014-01-06 22:27:43 +0000296 Visit(S->getBody());
297 Visit(S->getCond());
298 }
Bob Wilsonbf854f02014-02-17 19:21:09 +0000299 /// Assign a counter for the body of a for loop.
Justin Bogneref512b92014-01-06 22:27:43 +0000300 void VisitForStmt(const ForStmt *S) {
Bob Wilsonbf854f02014-02-17 19:21:09 +0000301 (*CounterMap)[S] = NextCounter++;
302 if (S->getInit())
303 Visit(S->getInit());
Justin Bogneref512b92014-01-06 22:27:43 +0000304 const Expr *E;
305 if ((E = S->getCond()))
306 Visit(E);
Justin Bogneref512b92014-01-06 22:27:43 +0000307 if ((E = S->getInc()))
308 Visit(E);
Bob Wilsonbf854f02014-02-17 19:21:09 +0000309 Visit(S->getBody());
Justin Bogneref512b92014-01-06 22:27:43 +0000310 }
Bob Wilsonbf854f02014-02-17 19:21:09 +0000311 /// Assign a counter for the body of a for-range loop.
Justin Bogneref512b92014-01-06 22:27:43 +0000312 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
Bob Wilsonbf854f02014-02-17 19:21:09 +0000313 (*CounterMap)[S] = NextCounter++;
314 Visit(S->getRangeStmt());
315 Visit(S->getBeginEndStmt());
316 Visit(S->getCond());
317 Visit(S->getLoopVarStmt());
Justin Bogneref512b92014-01-06 22:27:43 +0000318 Visit(S->getBody());
Bob Wilsonbf854f02014-02-17 19:21:09 +0000319 Visit(S->getInc());
Justin Bogneref512b92014-01-06 22:27:43 +0000320 }
Bob Wilsonbf854f02014-02-17 19:21:09 +0000321 /// Assign a counter for the body of a for-collection loop.
Justin Bogneref512b92014-01-06 22:27:43 +0000322 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
Bob Wilsonbf854f02014-02-17 19:21:09 +0000323 (*CounterMap)[S] = NextCounter++;
Justin Bogneref512b92014-01-06 22:27:43 +0000324 Visit(S->getElement());
325 Visit(S->getBody());
326 }
327 /// Assign a counter for the exit block of the switch statement.
328 void VisitSwitchStmt(const SwitchStmt *S) {
329 (*CounterMap)[S] = NextCounter++;
330 Visit(S->getCond());
331 Visit(S->getBody());
332 }
333 /// Assign a counter for a particular case in a switch. This counts jumps
334 /// from the switch header as well as fallthrough from the case before this
335 /// one.
336 void VisitCaseStmt(const CaseStmt *S) {
337 (*CounterMap)[S] = NextCounter++;
338 Visit(S->getSubStmt());
339 }
340 /// Assign a counter for the default case of a switch statement. The count
341 /// is the number of branches from the loop header to the default, and does
342 /// not include fallthrough from previous cases. If we have multiple
343 /// conditional branch blocks from the switch instruction to the default
344 /// block, as with large GNU case ranges, this is the counter for the last
345 /// edge in that series, rather than the first.
346 void VisitDefaultStmt(const DefaultStmt *S) {
347 (*CounterMap)[S] = NextCounter++;
348 Visit(S->getSubStmt());
349 }
350 /// Assign a counter for the "then" part of an if statement. The count for
351 /// the "else" part, if it exists, will be calculated from this counter.
352 void VisitIfStmt(const IfStmt *S) {
353 (*CounterMap)[S] = NextCounter++;
354 Visit(S->getCond());
355 Visit(S->getThen());
356 if (S->getElse())
357 Visit(S->getElse());
358 }
359 /// Assign a counter for the continuation block of a C++ try statement.
360 void VisitCXXTryStmt(const CXXTryStmt *S) {
361 (*CounterMap)[S] = NextCounter++;
362 Visit(S->getTryBlock());
363 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
364 Visit(S->getHandler(I));
365 }
366 /// Assign a counter for a catch statement's handler block.
367 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
368 (*CounterMap)[S] = NextCounter++;
369 Visit(S->getHandlerBlock());
370 }
371 /// Assign a counter for the "true" part of a conditional operator. The
372 /// count in the "false" part will be calculated from this counter.
373 void VisitConditionalOperator(const ConditionalOperator *E) {
374 (*CounterMap)[E] = NextCounter++;
375 Visit(E->getCond());
376 Visit(E->getTrueExpr());
377 Visit(E->getFalseExpr());
378 }
379 /// Assign a counter for the right hand side of a logical and operator.
380 void VisitBinLAnd(const BinaryOperator *E) {
381 (*CounterMap)[E] = NextCounter++;
382 Visit(E->getLHS());
383 Visit(E->getRHS());
384 }
385 /// Assign a counter for the right hand side of a logical or operator.
386 void VisitBinLOr(const BinaryOperator *E) {
387 (*CounterMap)[E] = NextCounter++;
388 Visit(E->getLHS());
389 Visit(E->getRHS());
390 }
391 };
Bob Wilsonbf854f02014-02-17 19:21:09 +0000392
393 /// A StmtVisitor that propagates the raw counts through the AST and
394 /// records the count at statements where the value may change.
395 struct ComputeRegionCounts : public ConstStmtVisitor<ComputeRegionCounts> {
396 /// PGO state.
397 CodeGenPGO &PGO;
398
399 /// A flag that is set when the current count should be recorded on the
400 /// next statement, such as at the exit of a loop.
401 bool RecordNextStmtCount;
402
403 /// The map of statements to count values.
404 llvm::DenseMap<const Stmt*, uint64_t> *CountMap;
405
406 /// BreakContinueStack - Keep counts of breaks and continues inside loops.
407 struct BreakContinue {
408 uint64_t BreakCount;
409 uint64_t ContinueCount;
410 BreakContinue() : BreakCount(0), ContinueCount(0) {}
411 };
412 SmallVector<BreakContinue, 8> BreakContinueStack;
413
414 ComputeRegionCounts(llvm::DenseMap<const Stmt*, uint64_t> *CountMap,
415 CodeGenPGO &PGO) :
416 PGO(PGO), RecordNextStmtCount(false), CountMap(CountMap) {
417 }
418
419 void RecordStmtCount(const Stmt *S) {
420 if (RecordNextStmtCount) {
421 (*CountMap)[S] = PGO.getCurrentRegionCount();
422 RecordNextStmtCount = false;
423 }
424 }
425
426 void VisitStmt(const Stmt *S) {
427 RecordStmtCount(S);
428 for (Stmt::const_child_range I = S->children(); I; ++I) {
429 if (*I)
430 this->Visit(*I);
431 }
432 }
433
434 void VisitFunctionDecl(const FunctionDecl *S) {
435 RegionCounter Cnt(PGO, S->getBody());
436 Cnt.beginRegion();
437 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
438 Visit(S->getBody());
439 }
440
441 void VisitReturnStmt(const ReturnStmt *S) {
442 RecordStmtCount(S);
443 if (S->getRetValue())
444 Visit(S->getRetValue());
445 PGO.setCurrentRegionUnreachable();
446 RecordNextStmtCount = true;
447 }
448
449 void VisitGotoStmt(const GotoStmt *S) {
450 RecordStmtCount(S);
451 PGO.setCurrentRegionUnreachable();
452 RecordNextStmtCount = true;
453 }
454
455 void VisitLabelStmt(const LabelStmt *S) {
456 RecordNextStmtCount = false;
457 RegionCounter Cnt(PGO, S);
458 Cnt.beginRegion();
459 (*CountMap)[S] = PGO.getCurrentRegionCount();
460 Visit(S->getSubStmt());
461 }
462
463 void VisitBreakStmt(const BreakStmt *S) {
464 RecordStmtCount(S);
465 assert(!BreakContinueStack.empty() && "break not in a loop or switch!");
466 BreakContinueStack.back().BreakCount += PGO.getCurrentRegionCount();
467 PGO.setCurrentRegionUnreachable();
468 RecordNextStmtCount = true;
469 }
470
471 void VisitContinueStmt(const ContinueStmt *S) {
472 RecordStmtCount(S);
473 assert(!BreakContinueStack.empty() && "continue stmt not in a loop!");
474 BreakContinueStack.back().ContinueCount += PGO.getCurrentRegionCount();
475 PGO.setCurrentRegionUnreachable();
476 RecordNextStmtCount = true;
477 }
478
479 void VisitWhileStmt(const WhileStmt *S) {
480 RecordStmtCount(S);
481 RegionCounter Cnt(PGO, S);
482 BreakContinueStack.push_back(BreakContinue());
483 // Visit the body region first so the break/continue adjustments can be
484 // included when visiting the condition.
485 Cnt.beginRegion();
486 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
487 Visit(S->getBody());
488 Cnt.adjustForControlFlow();
489
490 // ...then go back and propagate counts through the condition. The count
491 // at the start of the condition is the sum of the incoming edges,
492 // the backedge from the end of the loop body, and the edges from
493 // continue statements.
494 BreakContinue BC = BreakContinueStack.pop_back_val();
495 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
496 Cnt.getAdjustedCount() + BC.ContinueCount);
497 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
498 Visit(S->getCond());
499 Cnt.adjustForControlFlow();
500 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
501 RecordNextStmtCount = true;
502 }
503
504 void VisitDoStmt(const DoStmt *S) {
505 RecordStmtCount(S);
506 RegionCounter Cnt(PGO, S);
507 BreakContinueStack.push_back(BreakContinue());
508 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
509 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
510 Visit(S->getBody());
511 Cnt.adjustForControlFlow();
512
513 BreakContinue BC = BreakContinueStack.pop_back_val();
514 // The count at the start of the condition is equal to the count at the
515 // end of the body. The adjusted count does not include either the
516 // fall-through count coming into the loop or the continue count, so add
517 // both of those separately. This is coincidentally the same equation as
518 // with while loops but for different reasons.
519 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
520 Cnt.getAdjustedCount() + BC.ContinueCount);
521 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
522 Visit(S->getCond());
523 Cnt.adjustForControlFlow();
524 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
525 RecordNextStmtCount = true;
526 }
527
528 void VisitForStmt(const ForStmt *S) {
529 RecordStmtCount(S);
530 if (S->getInit())
531 Visit(S->getInit());
532 RegionCounter Cnt(PGO, S);
533 BreakContinueStack.push_back(BreakContinue());
534 // Visit the body region first. (This is basically the same as a while
535 // loop; see further comments in VisitWhileStmt.)
536 Cnt.beginRegion();
537 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
538 Visit(S->getBody());
539 Cnt.adjustForControlFlow();
540
541 // The increment is essentially part of the body but it needs to include
542 // the count for all the continue statements.
543 if (S->getInc()) {
544 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
545 BreakContinueStack.back().ContinueCount);
546 (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
547 Visit(S->getInc());
548 Cnt.adjustForControlFlow();
549 }
550
551 BreakContinue BC = BreakContinueStack.pop_back_val();
552
553 // ...then go back and propagate counts through the condition.
554 if (S->getCond()) {
555 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
556 Cnt.getAdjustedCount() +
557 BC.ContinueCount);
558 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
559 Visit(S->getCond());
560 Cnt.adjustForControlFlow();
561 }
562 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
563 RecordNextStmtCount = true;
564 }
565
566 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
567 RecordStmtCount(S);
568 Visit(S->getRangeStmt());
569 Visit(S->getBeginEndStmt());
570 RegionCounter Cnt(PGO, S);
571 BreakContinueStack.push_back(BreakContinue());
572 // Visit the body region first. (This is basically the same as a while
573 // loop; see further comments in VisitWhileStmt.)
574 Cnt.beginRegion();
575 (*CountMap)[S->getLoopVarStmt()] = PGO.getCurrentRegionCount();
576 Visit(S->getLoopVarStmt());
577 Visit(S->getBody());
578 Cnt.adjustForControlFlow();
579
580 // The increment is essentially part of the body but it needs to include
581 // the count for all the continue statements.
582 Cnt.setCurrentRegionCount(PGO.getCurrentRegionCount() +
583 BreakContinueStack.back().ContinueCount);
584 (*CountMap)[S->getInc()] = PGO.getCurrentRegionCount();
585 Visit(S->getInc());
586 Cnt.adjustForControlFlow();
587
588 BreakContinue BC = BreakContinueStack.pop_back_val();
589
590 // ...then go back and propagate counts through the condition.
591 Cnt.setCurrentRegionCount(Cnt.getParentCount() +
592 Cnt.getAdjustedCount() +
593 BC.ContinueCount);
594 (*CountMap)[S->getCond()] = PGO.getCurrentRegionCount();
595 Visit(S->getCond());
596 Cnt.adjustForControlFlow();
597 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
598 RecordNextStmtCount = true;
599 }
600
601 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
602 RecordStmtCount(S);
603 Visit(S->getElement());
604 RegionCounter Cnt(PGO, S);
605 BreakContinueStack.push_back(BreakContinue());
606 Cnt.beginRegion();
607 (*CountMap)[S->getBody()] = PGO.getCurrentRegionCount();
608 Visit(S->getBody());
609 BreakContinue BC = BreakContinueStack.pop_back_val();
610 Cnt.adjustForControlFlow();
611 Cnt.applyAdjustmentsToRegion(BC.BreakCount + BC.ContinueCount);
612 RecordNextStmtCount = true;
613 }
614
615 void VisitSwitchStmt(const SwitchStmt *S) {
616 RecordStmtCount(S);
617 Visit(S->getCond());
618 PGO.setCurrentRegionUnreachable();
619 BreakContinueStack.push_back(BreakContinue());
620 Visit(S->getBody());
621 // If the switch is inside a loop, add the continue counts.
622 BreakContinue BC = BreakContinueStack.pop_back_val();
623 if (!BreakContinueStack.empty())
624 BreakContinueStack.back().ContinueCount += BC.ContinueCount;
625 RegionCounter ExitCnt(PGO, S);
626 ExitCnt.beginRegion();
627 RecordNextStmtCount = true;
628 }
629
630 void VisitCaseStmt(const CaseStmt *S) {
631 RecordNextStmtCount = false;
632 RegionCounter Cnt(PGO, S);
633 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
634 (*CountMap)[S] = Cnt.getCount();
635 RecordNextStmtCount = true;
636 Visit(S->getSubStmt());
637 }
638
639 void VisitDefaultStmt(const DefaultStmt *S) {
640 RecordNextStmtCount = false;
641 RegionCounter Cnt(PGO, S);
642 Cnt.beginRegion(/*AddIncomingFallThrough=*/true);
643 (*CountMap)[S] = Cnt.getCount();
644 RecordNextStmtCount = true;
645 Visit(S->getSubStmt());
646 }
647
648 void VisitIfStmt(const IfStmt *S) {
649 RecordStmtCount(S);
650 RegionCounter Cnt(PGO, S);
651 Visit(S->getCond());
652
653 Cnt.beginRegion();
654 (*CountMap)[S->getThen()] = PGO.getCurrentRegionCount();
655 Visit(S->getThen());
656 Cnt.adjustForControlFlow();
657
658 if (S->getElse()) {
659 Cnt.beginElseRegion();
660 (*CountMap)[S->getElse()] = PGO.getCurrentRegionCount();
661 Visit(S->getElse());
662 Cnt.adjustForControlFlow();
663 }
664 Cnt.applyAdjustmentsToRegion(0);
665 RecordNextStmtCount = true;
666 }
667
668 void VisitCXXTryStmt(const CXXTryStmt *S) {
669 RecordStmtCount(S);
670 Visit(S->getTryBlock());
671 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
672 Visit(S->getHandler(I));
673 RegionCounter Cnt(PGO, S);
674 Cnt.beginRegion();
675 RecordNextStmtCount = true;
676 }
677
678 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
679 RecordNextStmtCount = false;
680 RegionCounter Cnt(PGO, S);
681 Cnt.beginRegion();
682 (*CountMap)[S] = PGO.getCurrentRegionCount();
683 Visit(S->getHandlerBlock());
684 }
685
686 void VisitConditionalOperator(const ConditionalOperator *E) {
687 RecordStmtCount(E);
688 RegionCounter Cnt(PGO, E);
689 Visit(E->getCond());
690
691 Cnt.beginRegion();
692 (*CountMap)[E->getTrueExpr()] = PGO.getCurrentRegionCount();
693 Visit(E->getTrueExpr());
694 Cnt.adjustForControlFlow();
695
696 Cnt.beginElseRegion();
697 (*CountMap)[E->getFalseExpr()] = PGO.getCurrentRegionCount();
698 Visit(E->getFalseExpr());
699 Cnt.adjustForControlFlow();
700
701 Cnt.applyAdjustmentsToRegion(0);
702 RecordNextStmtCount = true;
703 }
704
705 void VisitBinLAnd(const BinaryOperator *E) {
706 RecordStmtCount(E);
707 RegionCounter Cnt(PGO, E);
708 Visit(E->getLHS());
709 Cnt.beginRegion();
710 (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
711 Visit(E->getRHS());
712 Cnt.adjustForControlFlow();
713 Cnt.applyAdjustmentsToRegion(0);
714 RecordNextStmtCount = true;
715 }
716
717 void VisitBinLOr(const BinaryOperator *E) {
718 RecordStmtCount(E);
719 RegionCounter Cnt(PGO, E);
720 Visit(E->getLHS());
721 Cnt.beginRegion();
722 (*CountMap)[E->getRHS()] = PGO.getCurrentRegionCount();
723 Visit(E->getRHS());
724 Cnt.adjustForControlFlow();
725 Cnt.applyAdjustmentsToRegion(0);
726 RecordNextStmtCount = true;
727 }
728 };
Justin Bogneref512b92014-01-06 22:27:43 +0000729}
730
731void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
732 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
733 PGOProfileData *PGOData = CGM.getPGOData();
734 if (!InstrumentRegions && !PGOData)
735 return;
736 const Decl *D = GD.getDecl();
737 if (!D)
738 return;
739 mapRegionCounters(D);
740 if (InstrumentRegions)
741 emitCounterVariables();
Bob Wilsonbf854f02014-02-17 19:21:09 +0000742 if (PGOData) {
Justin Bogneref512b92014-01-06 22:27:43 +0000743 loadRegionCounts(GD, PGOData);
Bob Wilsonbf854f02014-02-17 19:21:09 +0000744 computeRegionCounts(D);
745 }
Justin Bogneref512b92014-01-06 22:27:43 +0000746}
747
748void CodeGenPGO::mapRegionCounters(const Decl *D) {
749 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
750 MapRegionCounters Walker(RegionCounterMap);
751 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
752 Walker.VisitFunctionDecl(FD);
753 NumRegionCounters = Walker.NextCounter;
754}
755
Bob Wilsonbf854f02014-02-17 19:21:09 +0000756void CodeGenPGO::computeRegionCounts(const Decl *D) {
757 StmtCountMap = new llvm::DenseMap<const Stmt*, uint64_t>();
758 ComputeRegionCounts Walker(StmtCountMap, *this);
759 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
760 Walker.VisitFunctionDecl(FD);
761}
762
Justin Bogneref512b92014-01-06 22:27:43 +0000763void CodeGenPGO::emitCounterVariables() {
764 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
765 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
766 NumRegionCounters);
767 RegionCounters =
768 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
769 llvm::GlobalVariable::PrivateLinkage,
770 llvm::Constant::getNullValue(CounterTy),
771 "__llvm_pgo_ctr");
772}
773
774void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
Bob Wilson749ebc72014-03-06 04:55:28 +0000775 if (!RegionCounters)
Justin Bogneref512b92014-01-06 22:27:43 +0000776 return;
777 llvm::Value *Addr =
778 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
779 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
780 Count = Builder.CreateAdd(Count, Builder.getInt64(1));
781 Builder.CreateStore(Count, Addr);
782}
783
784void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
785 // For now, ignore the counts from the PGO data file only if the number of
786 // counters does not match. This could be tightened down in the future to
787 // ignore counts when the input changes in various ways, e.g., by comparing a
788 // hash value based on some characteristics of the input.
789 RegionCounts = new std::vector<uint64_t>();
790 if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
791 RegionCounts->size() != NumRegionCounters) {
792 delete RegionCounts;
793 RegionCounts = 0;
794 }
795}
796
797void CodeGenPGO::destroyRegionCounters() {
798 if (RegionCounterMap != 0)
799 delete RegionCounterMap;
Bob Wilsonbf854f02014-02-17 19:21:09 +0000800 if (StmtCountMap != 0)
801 delete StmtCountMap;
Justin Bogneref512b92014-01-06 22:27:43 +0000802 if (RegionCounts != 0)
803 delete RegionCounts;
804}
805
806llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
807 uint64_t FalseCount) {
808 if (!TrueCount && !FalseCount)
809 return 0;
810
811 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
812 // TODO: need to scale down to 32-bits
813 // According to Laplace's Rule of Succession, it is better to compute the
814 // weight based on the count plus 1.
815 return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
816}
817
Bob Wilson95a27b02014-02-17 19:20:59 +0000818llvm::MDNode *CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
Justin Bogneref512b92014-01-06 22:27:43 +0000819 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
820 // TODO: need to scale down to 32-bits, instead of just truncating.
821 // According to Laplace's Rule of Succession, it is better to compute the
822 // weight based on the count plus 1.
823 SmallVector<uint32_t, 16> ScaledWeights;
824 ScaledWeights.reserve(Weights.size());
825 for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
826 WI != WE; ++WI) {
827 ScaledWeights.push_back(*WI + 1);
828 }
829 return MDHelper.createBranchWeights(ScaledWeights);
830}
Bob Wilsonbf854f02014-02-17 19:21:09 +0000831
832llvm::MDNode *CodeGenPGO::createLoopWeights(const Stmt *Cond,
833 RegionCounter &Cnt) {
834 if (!haveRegionCounts())
835 return 0;
836 uint64_t LoopCount = Cnt.getCount();
837 uint64_t CondCount = 0;
838 bool Found = getStmtCount(Cond, CondCount);
839 assert(Found && "missing expected loop condition count");
840 (void)Found;
841 if (CondCount == 0)
842 return 0;
843 return createBranchWeights(LoopCount,
844 std::max(CondCount, LoopCount) - LoopCount);
845}