blob: c1a45be115ec610ed22c43ec94d41001d13640c9 [file] [log] [blame]
Justin Bogneref512b92014-01-06 22:27:43 +00001//===--- CodeGenPGO.cpp - PGO Instrumentation for LLVM CodeGen --*- C++ -*-===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10// Instrumentation-based profile-guided optimization
11//
12//===----------------------------------------------------------------------===//
13
14#include "CodeGenPGO.h"
15#include "CodeGenFunction.h"
16#include "clang/AST/RecursiveASTVisitor.h"
17#include "clang/AST/StmtVisitor.h"
18#include "llvm/IR/MDBuilder.h"
19#include "llvm/Support/FileSystem.h"
20
21using namespace clang;
22using namespace CodeGen;
23
24static void ReportBadPGOData(CodeGenModule &CGM, const char *Message) {
25 DiagnosticsEngine &Diags = CGM.getDiags();
26 unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error, Message);
27 Diags.Report(DiagID);
28}
29
30PGOProfileData::PGOProfileData(CodeGenModule &CGM, std::string Path)
31 : CGM(CGM) {
32 if (llvm::MemoryBuffer::getFile(Path, DataBuffer)) {
33 ReportBadPGOData(CGM, "failed to open pgo data file");
34 return;
35 }
36
37 if (DataBuffer->getBufferSize() > std::numeric_limits<unsigned>::max()) {
38 ReportBadPGOData(CGM, "pgo data file too big");
39 return;
40 }
41
42 // Scan through the data file and map each function to the corresponding
43 // file offset where its counts are stored.
44 const char *BufferStart = DataBuffer->getBufferStart();
45 const char *BufferEnd = DataBuffer->getBufferEnd();
46 const char *CurPtr = BufferStart;
47 while (CurPtr < BufferEnd) {
48 // Read the mangled function name.
49 const char *FuncName = CurPtr;
50 // FIXME: Something will need to be added to distinguish static functions.
51 CurPtr = strchr(CurPtr, ' ');
52 if (!CurPtr) {
53 ReportBadPGOData(CGM, "pgo data file has malformed function entry");
54 return;
55 }
56 StringRef MangledName(FuncName, CurPtr - FuncName);
57
58 // Read the number of counters.
59 char *EndPtr;
60 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
61 if (EndPtr == CurPtr || *EndPtr != '\n' || NumCounters <= 0) {
62 ReportBadPGOData(CGM, "pgo data file has unexpected number of counters");
63 return;
64 }
65 CurPtr = EndPtr;
66
67 // There is one line for each counter; skip over those lines.
68 for (unsigned N = 0; N < NumCounters; ++N) {
69 CurPtr = strchr(++CurPtr, '\n');
70 if (!CurPtr) {
71 ReportBadPGOData(CGM, "pgo data file is missing some counter info");
72 return;
73 }
74 }
75
76 // Skip over the blank line separating functions.
77 CurPtr += 2;
78
79 DataOffsets[MangledName] = FuncName - BufferStart;
80 }
81}
82
83bool PGOProfileData::getFunctionCounts(StringRef MangledName,
84 std::vector<uint64_t> &Counts) {
85 // Find the relevant section of the pgo-data file.
86 llvm::StringMap<unsigned>::const_iterator OffsetIter =
87 DataOffsets.find(MangledName);
88 if (OffsetIter == DataOffsets.end())
89 return true;
90 const char *CurPtr = DataBuffer->getBufferStart() + OffsetIter->getValue();
91
92 // Skip over the function name.
93 CurPtr = strchr(CurPtr, ' ');
94 assert(CurPtr && "pgo-data has corrupted function entry");
95
96 // Read the number of counters.
97 char *EndPtr;
98 unsigned NumCounters = strtol(++CurPtr, &EndPtr, 10);
99 assert(EndPtr != CurPtr && *EndPtr == '\n' && NumCounters > 0 &&
100 "pgo-data file has corrupted number of counters");
101 CurPtr = EndPtr;
102
103 Counts.reserve(NumCounters);
104
105 for (unsigned N = 0; N < NumCounters; ++N) {
106 // Read the count value.
107 uint64_t Count = strtoll(CurPtr, &EndPtr, 10);
108 if (EndPtr == CurPtr || *EndPtr != '\n') {
109 ReportBadPGOData(CGM, "pgo-data file has bad count value");
110 return true;
111 }
112 Counts.push_back(Count);
113 CurPtr = EndPtr + 1;
114 }
115
116 // Make sure the number of counters matches up.
117 if (Counts.size() != NumCounters) {
118 ReportBadPGOData(CGM, "pgo-data file has inconsistent counters");
119 return true;
120 }
121
122 return false;
123}
124
125void CodeGenPGO::emitWriteoutFunction(GlobalDecl &GD) {
126 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
127 return;
128
129 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
130
131 llvm::Type *Int32Ty = llvm::Type::getInt32Ty(Ctx);
132 llvm::Type *Int8PtrTy = llvm::Type::getInt8PtrTy(Ctx);
133
134 llvm::Function *WriteoutF =
135 CGM.getModule().getFunction("__llvm_pgo_writeout");
136 if (!WriteoutF) {
137 llvm::FunctionType *WriteoutFTy =
138 llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
139 WriteoutF = llvm::Function::Create(WriteoutFTy,
140 llvm::GlobalValue::InternalLinkage,
141 "__llvm_pgo_writeout", &CGM.getModule());
142 }
143 WriteoutF->setUnnamedAddr(true);
144 WriteoutF->addFnAttr(llvm::Attribute::NoInline);
145 if (CGM.getCodeGenOpts().DisableRedZone)
146 WriteoutF->addFnAttr(llvm::Attribute::NoRedZone);
147
148 llvm::BasicBlock *BB = WriteoutF->empty() ?
149 llvm::BasicBlock::Create(Ctx, "", WriteoutF) : &WriteoutF->getEntryBlock();
150
151 CGBuilderTy PGOBuilder(BB);
152
153 llvm::Instruction *I = BB->getTerminator();
154 if (!I)
155 I = PGOBuilder.CreateRetVoid();
156 PGOBuilder.SetInsertPoint(I);
157
158 llvm::Type *Int64PtrTy = llvm::Type::getInt64PtrTy(Ctx);
159 llvm::Type *Args[] = {
160 Int8PtrTy, // const char *MangledName
161 Int32Ty, // uint32_t NumCounters
162 Int64PtrTy // uint64_t *Counters
163 };
164 llvm::FunctionType *FTy =
165 llvm::FunctionType::get(PGOBuilder.getVoidTy(), Args, false);
166 llvm::Constant *EmitFunc =
167 CGM.getModule().getOrInsertFunction("llvm_pgo_emit", FTy);
168
169 llvm::Constant *MangledName =
170 CGM.GetAddrOfConstantCString(CGM.getMangledName(GD), "__llvm_pgo_name");
171 MangledName = llvm::ConstantExpr::getBitCast(MangledName, Int8PtrTy);
172 PGOBuilder.CreateCall3(EmitFunc, MangledName,
173 PGOBuilder.getInt32(NumRegionCounters),
174 PGOBuilder.CreateBitCast(RegionCounters, Int64PtrTy));
175}
176
177llvm::Function *CodeGenPGO::emitInitialization(CodeGenModule &CGM) {
178 llvm::Function *WriteoutF =
179 CGM.getModule().getFunction("__llvm_pgo_writeout");
180 if (!WriteoutF)
181 return NULL;
182
183 // Create a small bit of code that registers the "__llvm_pgo_writeout" to
184 // be executed at exit.
185 llvm::Function *F = CGM.getModule().getFunction("__llvm_pgo_init");
186 if (F)
187 return NULL;
188
189 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
190 llvm::FunctionType *FTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx),
191 false);
192 F = llvm::Function::Create(FTy, llvm::GlobalValue::InternalLinkage,
193 "__llvm_pgo_init", &CGM.getModule());
194 F->setUnnamedAddr(true);
195 F->setLinkage(llvm::GlobalValue::InternalLinkage);
196 F->addFnAttr(llvm::Attribute::NoInline);
197 if (CGM.getCodeGenOpts().DisableRedZone)
198 F->addFnAttr(llvm::Attribute::NoRedZone);
199
200 llvm::BasicBlock *BB = llvm::BasicBlock::Create(CGM.getLLVMContext(), "", F);
201 CGBuilderTy PGOBuilder(BB);
202
203 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), false);
204 llvm::Type *Params[] = {
205 llvm::PointerType::get(FTy, 0)
206 };
207 FTy = llvm::FunctionType::get(PGOBuilder.getVoidTy(), Params, false);
208
209 // Inialize the environment and register the local writeout function.
210 llvm::Constant *PGOInit =
211 CGM.getModule().getOrInsertFunction("llvm_pgo_init", FTy);
212 PGOBuilder.CreateCall(PGOInit, WriteoutF);
213 PGOBuilder.CreateRetVoid();
214
215 return F;
216}
217
218namespace {
219 /// A StmtVisitor that fills a map of statements to PGO counters.
220 struct MapRegionCounters : public ConstStmtVisitor<MapRegionCounters> {
221 /// The next counter value to assign.
222 unsigned NextCounter;
223 /// The map of statements to counters.
224 llvm::DenseMap<const Stmt*, unsigned> *CounterMap;
225
226 MapRegionCounters(llvm::DenseMap<const Stmt*, unsigned> *CounterMap) :
227 NextCounter(0), CounterMap(CounterMap) {
228 }
229
230 void VisitChildren(const Stmt *S) {
231 for (Stmt::const_child_range I = S->children(); I; ++I)
232 if (*I)
233 this->Visit(*I);
234 }
235 void VisitStmt(const Stmt *S) { VisitChildren(S); }
236
Justin Bognerea278c32014-01-07 00:20:28 +0000237 /// Assign a counter to track entry to the function body.
Justin Bogneref512b92014-01-06 22:27:43 +0000238 void VisitFunctionDecl(const FunctionDecl *S) {
239 (*CounterMap)[S->getBody()] = NextCounter++;
240 Visit(S->getBody());
241 }
Justin Bognerea278c32014-01-07 00:20:28 +0000242 /// Assign a counter to track the block following a label.
Justin Bogneref512b92014-01-06 22:27:43 +0000243 void VisitLabelStmt(const LabelStmt *S) {
244 (*CounterMap)[S] = NextCounter++;
245 Visit(S->getSubStmt());
246 }
247 /// Assign three counters - one for the body of the loop, one for breaks
248 /// from the loop, and one for continues.
249 ///
250 /// The break and continue counters cover all such statements in this loop,
251 /// and are used in calculations to find the number of times the condition
252 /// and exit of the loop occur. They are needed so we can differentiate
253 /// these statements from non-local exits like return and goto.
254 void VisitWhileStmt(const WhileStmt *S) {
255 (*CounterMap)[S] = NextCounter;
256 NextCounter += 3;
257 Visit(S->getCond());
258 Visit(S->getBody());
259 }
260 /// Assign counters for the body of the loop, and for breaks and
261 /// continues. See VisitWhileStmt.
262 void VisitDoStmt(const DoStmt *S) {
263 (*CounterMap)[S] = NextCounter;
264 NextCounter += 3;
265 Visit(S->getBody());
266 Visit(S->getCond());
267 }
268 /// Assign counters for the body of the loop, and for breaks and
269 /// continues. See VisitWhileStmt.
270 void VisitForStmt(const ForStmt *S) {
271 (*CounterMap)[S] = NextCounter;
272 NextCounter += 3;
273 const Expr *E;
274 if ((E = S->getCond()))
275 Visit(E);
276 Visit(S->getBody());
277 if ((E = S->getInc()))
278 Visit(E);
279 }
280 /// Assign counters for the body of the loop, and for breaks and
281 /// continues. See VisitWhileStmt.
282 void VisitCXXForRangeStmt(const CXXForRangeStmt *S) {
283 (*CounterMap)[S] = NextCounter;
284 NextCounter += 3;
285 const Expr *E;
286 if ((E = S->getCond()))
287 Visit(E);
288 Visit(S->getBody());
289 if ((E = S->getInc()))
290 Visit(E);
291 }
292 /// Assign counters for the body of the loop, and for breaks and
293 /// continues. See VisitWhileStmt.
294 void VisitObjCForCollectionStmt(const ObjCForCollectionStmt *S) {
295 (*CounterMap)[S] = NextCounter;
296 NextCounter += 3;
297 Visit(S->getElement());
298 Visit(S->getBody());
299 }
300 /// Assign a counter for the exit block of the switch statement.
301 void VisitSwitchStmt(const SwitchStmt *S) {
302 (*CounterMap)[S] = NextCounter++;
303 Visit(S->getCond());
304 Visit(S->getBody());
305 }
306 /// Assign a counter for a particular case in a switch. This counts jumps
307 /// from the switch header as well as fallthrough from the case before this
308 /// one.
309 void VisitCaseStmt(const CaseStmt *S) {
310 (*CounterMap)[S] = NextCounter++;
311 Visit(S->getSubStmt());
312 }
313 /// Assign a counter for the default case of a switch statement. The count
314 /// is the number of branches from the loop header to the default, and does
315 /// not include fallthrough from previous cases. If we have multiple
316 /// conditional branch blocks from the switch instruction to the default
317 /// block, as with large GNU case ranges, this is the counter for the last
318 /// edge in that series, rather than the first.
319 void VisitDefaultStmt(const DefaultStmt *S) {
320 (*CounterMap)[S] = NextCounter++;
321 Visit(S->getSubStmt());
322 }
323 /// Assign a counter for the "then" part of an if statement. The count for
324 /// the "else" part, if it exists, will be calculated from this counter.
325 void VisitIfStmt(const IfStmt *S) {
326 (*CounterMap)[S] = NextCounter++;
327 Visit(S->getCond());
328 Visit(S->getThen());
329 if (S->getElse())
330 Visit(S->getElse());
331 }
332 /// Assign a counter for the continuation block of a C++ try statement.
333 void VisitCXXTryStmt(const CXXTryStmt *S) {
334 (*CounterMap)[S] = NextCounter++;
335 Visit(S->getTryBlock());
336 for (unsigned I = 0, E = S->getNumHandlers(); I < E; ++I)
337 Visit(S->getHandler(I));
338 }
339 /// Assign a counter for a catch statement's handler block.
340 void VisitCXXCatchStmt(const CXXCatchStmt *S) {
341 (*CounterMap)[S] = NextCounter++;
342 Visit(S->getHandlerBlock());
343 }
344 /// Assign a counter for the "true" part of a conditional operator. The
345 /// count in the "false" part will be calculated from this counter.
346 void VisitConditionalOperator(const ConditionalOperator *E) {
347 (*CounterMap)[E] = NextCounter++;
348 Visit(E->getCond());
349 Visit(E->getTrueExpr());
350 Visit(E->getFalseExpr());
351 }
352 /// Assign a counter for the right hand side of a logical and operator.
353 void VisitBinLAnd(const BinaryOperator *E) {
354 (*CounterMap)[E] = NextCounter++;
355 Visit(E->getLHS());
356 Visit(E->getRHS());
357 }
358 /// Assign a counter for the right hand side of a logical or operator.
359 void VisitBinLOr(const BinaryOperator *E) {
360 (*CounterMap)[E] = NextCounter++;
361 Visit(E->getLHS());
362 Visit(E->getRHS());
363 }
364 };
365}
366
367void CodeGenPGO::assignRegionCounters(GlobalDecl &GD) {
368 bool InstrumentRegions = CGM.getCodeGenOpts().ProfileInstrGenerate;
369 PGOProfileData *PGOData = CGM.getPGOData();
370 if (!InstrumentRegions && !PGOData)
371 return;
372 const Decl *D = GD.getDecl();
373 if (!D)
374 return;
375 mapRegionCounters(D);
376 if (InstrumentRegions)
377 emitCounterVariables();
378 if (PGOData)
379 loadRegionCounts(GD, PGOData);
380}
381
382void CodeGenPGO::mapRegionCounters(const Decl *D) {
383 RegionCounterMap = new llvm::DenseMap<const Stmt*, unsigned>();
384 MapRegionCounters Walker(RegionCounterMap);
385 if (const FunctionDecl *FD = dyn_cast_or_null<FunctionDecl>(D))
386 Walker.VisitFunctionDecl(FD);
387 NumRegionCounters = Walker.NextCounter;
388}
389
390void CodeGenPGO::emitCounterVariables() {
391 llvm::LLVMContext &Ctx = CGM.getLLVMContext();
392 llvm::ArrayType *CounterTy = llvm::ArrayType::get(llvm::Type::getInt64Ty(Ctx),
393 NumRegionCounters);
394 RegionCounters =
395 new llvm::GlobalVariable(CGM.getModule(), CounterTy, false,
396 llvm::GlobalVariable::PrivateLinkage,
397 llvm::Constant::getNullValue(CounterTy),
398 "__llvm_pgo_ctr");
399}
400
401void CodeGenPGO::emitCounterIncrement(CGBuilderTy &Builder, unsigned Counter) {
402 if (!CGM.getCodeGenOpts().ProfileInstrGenerate)
403 return;
404 llvm::Value *Addr =
405 Builder.CreateConstInBoundsGEP2_64(RegionCounters, 0, Counter);
406 llvm::Value *Count = Builder.CreateLoad(Addr, "pgocount");
407 Count = Builder.CreateAdd(Count, Builder.getInt64(1));
408 Builder.CreateStore(Count, Addr);
409}
410
411void CodeGenPGO::loadRegionCounts(GlobalDecl &GD, PGOProfileData *PGOData) {
412 // For now, ignore the counts from the PGO data file only if the number of
413 // counters does not match. This could be tightened down in the future to
414 // ignore counts when the input changes in various ways, e.g., by comparing a
415 // hash value based on some characteristics of the input.
416 RegionCounts = new std::vector<uint64_t>();
417 if (PGOData->getFunctionCounts(CGM.getMangledName(GD), *RegionCounts) ||
418 RegionCounts->size() != NumRegionCounters) {
419 delete RegionCounts;
420 RegionCounts = 0;
421 }
422}
423
424void CodeGenPGO::destroyRegionCounters() {
425 if (RegionCounterMap != 0)
426 delete RegionCounterMap;
427 if (RegionCounts != 0)
428 delete RegionCounts;
429}
430
431llvm::MDNode *CodeGenPGO::createBranchWeights(uint64_t TrueCount,
432 uint64_t FalseCount) {
433 if (!TrueCount && !FalseCount)
434 return 0;
435
436 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
437 // TODO: need to scale down to 32-bits
438 // According to Laplace's Rule of Succession, it is better to compute the
439 // weight based on the count plus 1.
440 return MDHelper.createBranchWeights(TrueCount + 1, FalseCount + 1);
441}
442
443llvm::MDNode *
444CodeGenPGO::createBranchWeights(ArrayRef<uint64_t> Weights) {
445 llvm::MDBuilder MDHelper(CGM.getLLVMContext());
446 // TODO: need to scale down to 32-bits, instead of just truncating.
447 // According to Laplace's Rule of Succession, it is better to compute the
448 // weight based on the count plus 1.
449 SmallVector<uint32_t, 16> ScaledWeights;
450 ScaledWeights.reserve(Weights.size());
451 for (ArrayRef<uint64_t>::iterator WI = Weights.begin(), WE = Weights.end();
452 WI != WE; ++WI) {
453 ScaledWeights.push_back(*WI + 1);
454 }
455 return MDHelper.createBranchWeights(ScaledWeights);
456}