//===- WholeProgramDevirt.cpp - Whole program virtual call optimization ---===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass implements whole program optimization of virtual calls in cases
// where we know (via !type metadata) that the list of callees is fixed. This
// includes the following:
// - Single implementation devirtualization: if a virtual call has a single
//   possible callee, replace all calls with a direct call to that callee.
// - Virtual constant propagation: if the virtual function's return type is an
//   integer <=64 bits and all possible callees are readnone, for each class
//   and each list of constant arguments: evaluate the function, store the
//   return value alongside the virtual table, and rewrite each virtual call
//   as a load from the virtual table.
// - Uniform return value optimization: if the conditions for virtual constant
//   propagation hold and each function returns the same constant value,
//   replace each virtual call with that constant.
// - Unique return value optimization for i1 return values: if the conditions
//   for virtual constant propagation hold and a single vtable's function
//   returns 0, or a single vtable's function returns 1, replace each virtual
//   call with a comparison of the vptr against that vtable's address.
//
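// For example (an illustrative sketch, with made-up class names): if B is the
// only class in the program that overrides A::f, then in
//
//   A *p = ...;
//   p->f();
//
// the call has a single possible callee, B::f, and single implementation
// devirtualization replaces the indirect call through A's vtable with a
// direct call to B::f.
//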
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/IPO/WholeProgramDevirt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DiagnosticInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/Pass.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/Utils/Evaluator.h"
#include "llvm/Transforms/Utils/Local.h"

#include <set>

using namespace llvm;
using namespace wholeprogramdevirt;

#define DEBUG_TYPE "wholeprogramdevirt"

// Find the minimum offset that we may store a value of size Size bits at. If
// IsAfter is set, look for an offset after the object, otherwise look for an
// offset before the object.
uint64_t
wholeprogramdevirt::findLowestOffset(ArrayRef<VirtualCallTarget> Targets,
                                     bool IsAfter, uint64_t Size) {
  // Find a minimum offset taking into account only vtable sizes.
  uint64_t MinByte = 0;
  for (const VirtualCallTarget &Target : Targets) {
    if (IsAfter)
      MinByte = std::max(MinByte, Target.minAfterBytes());
    else
      MinByte = std::max(MinByte, Target.minBeforeBytes());
  }

  // Build a vector of arrays of bytes covering, for each target, a slice of
  // the used region (see AccumBitVector::BytesUsed in
  // llvm/Transforms/IPO/WholeProgramDevirt.h) starting at MinByte.
  // Effectively, this aligns the used regions to start at MinByte.
  //
  // In this example, A, B and C are vtables, # is a byte already allocated for
  // a virtual function pointer, AAAA... (etc.) are the used regions for the
  // vtables and Offset(X) is the value computed for the Offset variable below
  // for X.
  //
  //                    Offset(A)
  //                    |       |
  //                            |MinByte
  // A: ################AAAAAAAA|AAAAAAAA
  // B: ########BBBBBBBBBBBBBBBB|BBBB
  // C: ########################|CCCCCCCCCCCCCCCC
  //            |   Offset(B)   |
  //
  // This code produces the slices of A, B and C that appear after the divider
  // at MinByte.
  std::vector<ArrayRef<uint8_t>> Used;
  for (const VirtualCallTarget &Target : Targets) {
    ArrayRef<uint8_t> VTUsed = IsAfter ? Target.TM->Bits->After.BytesUsed
                                       : Target.TM->Bits->Before.BytesUsed;
    uint64_t Offset = IsAfter ? MinByte - Target.minAfterBytes()
                              : MinByte - Target.minBeforeBytes();

    // Disregard used regions that are smaller than Offset. These are
    // effectively all-free regions that do not need to be checked.
    if (VTUsed.size() > Offset)
      Used.push_back(VTUsed.slice(Offset));
  }

  if (Size == 1) {
    // Find a free bit in each member of Used.
    for (unsigned I = 0;; ++I) {
      uint8_t BitsUsed = 0;
      for (auto &&B : Used)
        if (I < B.size())
          BitsUsed |= B[I];
      if (BitsUsed != 0xff)
        return (MinByte + I) * 8 +
               countTrailingZeros(uint8_t(~BitsUsed), ZB_Undefined);
    }
  } else {
    // Find a free (Size/8) byte region in each member of Used.
    // FIXME: see if alignment helps.
    for (unsigned I = 0;; ++I) {
      for (auto &&B : Used) {
        unsigned Byte = 0;
        while ((I + Byte) < B.size() && Byte < (Size / 8)) {
          if (B[I + Byte])
            goto NextI;
          ++Byte;
        }
      }
      return (MinByte + I) * 8;
    NextI:;
    }
  }
}

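// Record, for each target, that a return value of BitWidth bits will be
// stored in the region allocated before the vtable, and compute the byte and
// bit offsets of that value. Worked example (illustrative numbers): with
// AllocBefore = 10 and BitWidth = 1, OffsetByte = -(10 / 8 + 1) = -2 and
// OffsetBit = 10 % 8 = 2.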
void wholeprogramdevirt::setBeforeReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocBefore,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = -(AllocBefore / 8 + 1);
  else
    OffsetByte = -((AllocBefore + 7) / 8 + (BitWidth + 7) / 8);
  OffsetBit = AllocBefore % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setBeforeBit(AllocBefore);
    else
      Target.setBeforeBytes(AllocBefore, (BitWidth + 7) / 8);
  }
}

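// As setBeforeReturnValues, but for return values stored in the region
// allocated after the vtable. Worked example (illustrative numbers): with
// AllocAfter = 10 and BitWidth = 1, OffsetByte = 10 / 8 = 1 and
// OffsetBit = 10 % 8 = 2.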
void wholeprogramdevirt::setAfterReturnValues(
    MutableArrayRef<VirtualCallTarget> Targets, uint64_t AllocAfter,
    unsigned BitWidth, int64_t &OffsetByte, uint64_t &OffsetBit) {
  if (BitWidth == 1)
    OffsetByte = AllocAfter / 8;
  else
    OffsetByte = (AllocAfter + 7) / 8;
  OffsetBit = AllocAfter % 8;

  for (VirtualCallTarget &Target : Targets) {
    if (BitWidth == 1)
      Target.setAfterBit(AllocAfter);
    else
      Target.setAfterBytes(AllocAfter, (BitWidth + 7) / 8);
  }
}

VirtualCallTarget::VirtualCallTarget(Function *Fn, const TypeMemberInfo *TM)
    : Fn(Fn), TM(TM),
      IsBigEndian(Fn->getParent()->getDataLayout().isBigEndian()) {}

namespace {

// A slot in a set of virtual tables. The TypeID identifies the set of virtual
// tables, and the ByteOffset is the offset in bytes from the address point to
// the virtual function pointer.
struct VTableSlot {
  Metadata *TypeID;
  uint64_t ByteOffset;
};

}

namespace llvm {

template <> struct DenseMapInfo<VTableSlot> {
  static VTableSlot getEmptyKey() {
    return {DenseMapInfo<Metadata *>::getEmptyKey(),
            DenseMapInfo<uint64_t>::getEmptyKey()};
  }
  static VTableSlot getTombstoneKey() {
    return {DenseMapInfo<Metadata *>::getTombstoneKey(),
            DenseMapInfo<uint64_t>::getTombstoneKey()};
  }
  static unsigned getHashValue(const VTableSlot &I) {
    return DenseMapInfo<Metadata *>::getHashValue(I.TypeID) ^
           DenseMapInfo<uint64_t>::getHashValue(I.ByteOffset);
  }
  static bool isEqual(const VTableSlot &LHS,
                      const VTableSlot &RHS) {
    return LHS.TypeID == RHS.TypeID && LHS.ByteOffset == RHS.ByteOffset;
  }
};

}

namespace {

// A virtual call site. VTable is the loaded virtual table pointer, and CS is
// the indirect virtual call.
struct VirtualCallSite {
  Value *VTable;
  CallSite CS;

  // If non-null, this field points to the associated unsafe use count stored
  // in the DevirtModule::NumUnsafeUsesForTypeTest map below. See the
  // description of that field for details.
  unsigned *NumUnsafeUses;

  void emitRemark() {
    Function *F = CS.getCaller();
    emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F,
                           CS.getInstruction()->getDebugLoc(),
                           "devirtualized call");
  }

  void replaceAndErase(Value *New) {
    emitRemark();
    CS->replaceAllUsesWith(New);
    if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) {
      BranchInst::Create(II->getNormalDest(), CS.getInstruction());
      II->getUnwindDest()->removePredecessor(II->getParent());
    }
    CS->eraseFromParent();
    // This use is no longer unsafe.
    if (NumUnsafeUses)
      --*NumUnsafeUses;
  }
};

struct DevirtModule {
  Module &M;
  IntegerType *Int8Ty;
  PointerType *Int8PtrTy;
  IntegerType *Int32Ty;

  MapVector<VTableSlot, std::vector<VirtualCallSite>> CallSlots;

  // This map keeps track of the number of "unsafe" uses of a loaded function
  // pointer. The key is the associated llvm.type.test intrinsic call generated
  // by this pass. An unsafe use is one that calls the loaded function pointer
  // directly. Every time we eliminate an unsafe use (for example, by
  // devirtualizing it or by applying virtual constant propagation), we
  // decrement the value stored in this map. If a value reaches zero, we can
  // eliminate the type check by RAUWing the associated llvm.type.test call
  // with true.
  std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest;

  DevirtModule(Module &M)
      : M(M), Int8Ty(Type::getInt8Ty(M.getContext())),
        Int8PtrTy(Type::getInt8PtrTy(M.getContext())),
        Int32Ty(Type::getInt32Ty(M.getContext())) {}

  void scanTypeTestUsers(Function *TypeTestFunc, Function *AssumeFunc);
  void scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc);

  void buildTypeIdentifierMap(
      std::vector<VTableBits> &Bits,
      DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap);
  bool
  tryFindVirtualCallTargets(std::vector<VirtualCallTarget> &TargetsForSlot,
                            const std::set<TypeMemberInfo> &TypeMemberInfos,
                            uint64_t ByteOffset);
  bool trySingleImplDevirt(ArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
  bool tryEvaluateFunctionsWithArgs(
      MutableArrayRef<VirtualCallTarget> TargetsForSlot,
      ArrayRef<ConstantInt *> Args);
  bool tryUniformRetValOpt(IntegerType *RetType,
                           ArrayRef<VirtualCallTarget> TargetsForSlot,
                           MutableArrayRef<VirtualCallSite> CallSites);
  bool tryUniqueRetValOpt(unsigned BitWidth,
                          ArrayRef<VirtualCallTarget> TargetsForSlot,
                          MutableArrayRef<VirtualCallSite> CallSites);
  bool tryVirtualConstProp(MutableArrayRef<VirtualCallTarget> TargetsForSlot,
                           ArrayRef<VirtualCallSite> CallSites);

  void rebuildGlobal(VTableBits &B);

  bool run();
};

struct WholeProgramDevirt : public ModulePass {
  static char ID;
  WholeProgramDevirt() : ModulePass(ID) {
    initializeWholeProgramDevirtPass(*PassRegistry::getPassRegistry());
  }
  bool runOnModule(Module &M) {
    if (skipModule(M))
      return false;

    return DevirtModule(M).run();
  }
};

} // anonymous namespace

INITIALIZE_PASS(WholeProgramDevirt, "wholeprogramdevirt",
                "Whole program devirtualization", false, false)
char WholeProgramDevirt::ID = 0;

ModulePass *llvm::createWholeProgramDevirtPass() {
  return new WholeProgramDevirt;
}

PreservedAnalyses WholeProgramDevirtPass::run(Module &M,
                                              ModuleAnalysisManager &) {
  if (!DevirtModule(M).run())
    return PreservedAnalyses::all();
  return PreservedAnalyses::none();
}

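// Build a map from type identifiers to the vtable globals (and byte offsets
// within them) that are associated with those identifiers via !type metadata.
// For example (an illustrative sketch, with made-up names), a vtable global
// annotated as
//
//   @_ZTV1A = constant ... , !type !0
//   !0 = !{i64 16, !"_ZTS1A"}
//
// contributes an entry mapping the type identifier !"_ZTS1A" to this global
// at byte offset 16.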
void DevirtModule::buildTypeIdentifierMap(
    std::vector<VTableBits> &Bits,
    DenseMap<Metadata *, std::set<TypeMemberInfo>> &TypeIdMap) {
  DenseMap<GlobalVariable *, VTableBits *> GVToBits;
  Bits.reserve(M.getGlobalList().size());
  SmallVector<MDNode *, 2> Types;
  for (GlobalVariable &GV : M.globals()) {
    Types.clear();
    GV.getMetadata(LLVMContext::MD_type, Types);
    if (Types.empty())
      continue;

    VTableBits *&BitsPtr = GVToBits[&GV];
    if (!BitsPtr) {
      Bits.emplace_back();
      Bits.back().GV = &GV;
      Bits.back().ObjectSize =
          M.getDataLayout().getTypeAllocSize(GV.getInitializer()->getType());
      BitsPtr = &Bits.back();
    }

    for (MDNode *Type : Types) {
      auto TypeID = Type->getOperand(1).get();

      uint64_t Offset =
          cast<ConstantInt>(
              cast<ConstantAsMetadata>(Type->getOperand(0))->getValue())
              ->getZExtValue();

      TypeIdMap[TypeID].insert({BitsPtr, Offset});
    }
  }
}

bool DevirtModule::tryFindVirtualCallTargets(
    std::vector<VirtualCallTarget> &TargetsForSlot,
    const std::set<TypeMemberInfo> &TypeMemberInfos, uint64_t ByteOffset) {
  for (const TypeMemberInfo &TM : TypeMemberInfos) {
    if (!TM.Bits->GV->isConstant())
      return false;

    auto Init = dyn_cast<ConstantArray>(TM.Bits->GV->getInitializer());
    if (!Init)
      return false;
    ArrayType *VTableTy = Init->getType();

    uint64_t ElemSize =
        M.getDataLayout().getTypeAllocSize(VTableTy->getElementType());
    uint64_t GlobalSlotOffset = TM.Offset + ByteOffset;
    if (GlobalSlotOffset % ElemSize != 0)
      return false;

    unsigned Op = GlobalSlotOffset / ElemSize;
    if (Op >= Init->getNumOperands())
      return false;

    auto Fn = dyn_cast<Function>(Init->getOperand(Op)->stripPointerCasts());
    if (!Fn)
      return false;

    // We can disregard __cxa_pure_virtual as a possible call target, as
    // calls to pure virtuals are UB.
    if (Fn->getName() == "__cxa_pure_virtual")
      continue;

    TargetsForSlot.push_back({Fn, &TM});
  }

  // Give up if we couldn't find any targets.
  return !TargetsForSlot.empty();
}

bool DevirtModule::trySingleImplDevirt(
    ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // See if the program contains a single implementation of this virtual
  // function.
  Function *TheFn = TargetsForSlot[0].Fn;
  for (auto &&Target : TargetsForSlot)
    if (TheFn != Target.Fn)
      return false;

  // If so, update each call site to call that implementation directly.
  for (auto &&VCallSite : CallSites) {
    VCallSite.emitRemark();
    VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast(
        TheFn, VCallSite.CS.getCalledValue()->getType()));
    // This use is no longer unsafe.
    if (VCallSite.NumUnsafeUses)
      --*VCallSite.NumUnsafeUses;
  }
  return true;
}

bool DevirtModule::tryEvaluateFunctionsWithArgs(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<ConstantInt *> Args) {
  // Evaluate each function and store the result in each target's RetVal
  // field.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (Target.Fn->arg_size() != Args.size() + 1)
      return false;
    for (unsigned I = 0; I != Args.size(); ++I)
      if (Target.Fn->getFunctionType()->getParamType(I + 1) !=
          Args[I]->getType())
        return false;

    Evaluator Eval(M.getDataLayout(), nullptr);
    SmallVector<Constant *, 2> EvalArgs;
    EvalArgs.push_back(
        Constant::getNullValue(Target.Fn->getFunctionType()->getParamType(0)));
    EvalArgs.insert(EvalArgs.end(), Args.begin(), Args.end());
    Constant *RetVal;
    if (!Eval.EvaluateFunction(Target.Fn, RetVal, EvalArgs) ||
        !isa<ConstantInt>(RetVal))
      return false;
    Target.RetVal = cast<ConstantInt>(RetVal)->getZExtValue();
  }
  return true;
}

bool DevirtModule::tryUniformRetValOpt(
    IntegerType *RetType, ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // Uniform return value optimization. If all functions return the same
  // constant, replace all calls with that constant.
  uint64_t TheRetVal = TargetsForSlot[0].RetVal;
  for (const VirtualCallTarget &Target : TargetsForSlot)
    if (Target.RetVal != TheRetVal)
      return false;

  auto TheRetValConst = ConstantInt::get(RetType, TheRetVal);
  for (auto Call : CallSites)
    Call.replaceAndErase(TheRetValConst);
  return true;
}

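// Unique return value optimization for i1 return values: if exactly one of
// the possible callees returns 1 (or exactly one returns 0), each call can be
// replaced with a comparison of the loaded vtable pointer against that one
// vtable's address.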
bool DevirtModule::tryUniqueRetValOpt(
    unsigned BitWidth, ArrayRef<VirtualCallTarget> TargetsForSlot,
    MutableArrayRef<VirtualCallSite> CallSites) {
  // IsOne controls whether we look for a 0 or a 1.
  auto tryUniqueRetValOptFor = [&](bool IsOne) {
    const TypeMemberInfo *UniqueMember = nullptr;
    for (const VirtualCallTarget &Target : TargetsForSlot) {
      if (Target.RetVal == (IsOne ? 1 : 0)) {
        if (UniqueMember)
          return false;
        UniqueMember = Target.TM;
      }
    }

    // We should have found a unique member or bailed out by now. We already
    // checked for a uniform return value in tryUniformRetValOpt.
    assert(UniqueMember);

    // Replace each call with the comparison.
    for (auto &&Call : CallSites) {
      IRBuilder<> B(Call.CS.getInstruction());
      Value *OneAddr = B.CreateBitCast(UniqueMember->Bits->GV, Int8PtrTy);
      OneAddr = B.CreateConstGEP1_64(OneAddr, UniqueMember->Offset);
      Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                                Call.VTable, OneAddr);
      Call.replaceAndErase(Cmp);
    }
    return true;
  };

  if (BitWidth == 1) {
    if (tryUniqueRetValOptFor(true))
      return true;
    if (tryUniqueRetValOptFor(false))
      return true;
  }
  return false;
}

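// Virtual constant propagation for one virtual call slot. As an illustrative
// example, if every implementation of a readnone accessor such as
//
//   virtual unsigned getKind() const { return /* per-class constant */; }
//
// can be evaluated to a constant for a given list of constant arguments, the
// constant for each class is stored alongside its vtable and each call is
// rewritten as a load at a fixed offset from the vtable pointer (or handled
// by the uniform/unique return value helpers above).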
bool DevirtModule::tryVirtualConstProp(
    MutableArrayRef<VirtualCallTarget> TargetsForSlot,
    ArrayRef<VirtualCallSite> CallSites) {
  // This only works if the function returns an integer.
  auto RetType = dyn_cast<IntegerType>(TargetsForSlot[0].Fn->getReturnType());
  if (!RetType)
    return false;
  unsigned BitWidth = RetType->getBitWidth();
  if (BitWidth > 64)
    return false;

  // Make sure that each function does not access memory, takes at least one
  // argument, does not use its first argument (which we assume is 'this'),
  // and has the same return type.
  for (VirtualCallTarget &Target : TargetsForSlot) {
    if (!Target.Fn->doesNotAccessMemory() || Target.Fn->arg_empty() ||
        !Target.Fn->arg_begin()->use_empty() ||
        Target.Fn->getReturnType() != RetType)
      return false;
  }

  // Group call sites by the list of constant arguments they pass.
  // The comparator ensures deterministic ordering.
  struct ByAPIntValue {
    bool operator()(const std::vector<ConstantInt *> &A,
                    const std::vector<ConstantInt *> &B) const {
      return std::lexicographical_compare(
          A.begin(), A.end(), B.begin(), B.end(),
          [](ConstantInt *AI, ConstantInt *BI) {
            return AI->getValue().ult(BI->getValue());
          });
    }
  };
  std::map<std::vector<ConstantInt *>, std::vector<VirtualCallSite>,
           ByAPIntValue>
      VCallSitesByConstantArg;
  for (auto &&VCallSite : CallSites) {
    std::vector<ConstantInt *> Args;
    if (VCallSite.CS.getType() != RetType)
      continue;
    for (auto &&Arg :
         make_range(VCallSite.CS.arg_begin() + 1, VCallSite.CS.arg_end())) {
      if (!isa<ConstantInt>(Arg))
        break;
      Args.push_back(cast<ConstantInt>(&Arg));
    }
    if (Args.size() + 1 != VCallSite.CS.arg_size())
      continue;

    VCallSitesByConstantArg[Args].push_back(VCallSite);
  }

  for (auto &&CSByConstantArg : VCallSitesByConstantArg) {
    if (!tryEvaluateFunctionsWithArgs(TargetsForSlot, CSByConstantArg.first))
      continue;

    if (tryUniformRetValOpt(RetType, TargetsForSlot, CSByConstantArg.second))
      continue;

    if (tryUniqueRetValOpt(BitWidth, TargetsForSlot, CSByConstantArg.second))
      continue;

    // Find an allocation offset in bits in all vtables associated with the
    // type.
    uint64_t AllocBefore =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/false, BitWidth);
    uint64_t AllocAfter =
        findLowestOffset(TargetsForSlot, /*IsAfter=*/true, BitWidth);

    // Calculate the total amount of padding needed to store a value at both
    // ends of the object.
    uint64_t TotalPaddingBefore = 0, TotalPaddingAfter = 0;
    for (auto &&Target : TargetsForSlot) {
      TotalPaddingBefore += std::max<int64_t>(
          (AllocBefore + 7) / 8 - Target.allocatedBeforeBytes() - 1, 0);
      TotalPaddingAfter += std::max<int64_t>(
          (AllocAfter + 7) / 8 - Target.allocatedAfterBytes() - 1, 0);
    }

    // If the amount of padding is too large, give up.
    // FIXME: do something smarter here.
    if (std::min(TotalPaddingBefore, TotalPaddingAfter) > 128)
      continue;

    // Calculate the offset to the value as a (possibly negative) byte offset
    // and (if applicable) a bit offset, and store the values in the targets.
    int64_t OffsetByte;
    uint64_t OffsetBit;
    if (TotalPaddingBefore <= TotalPaddingAfter)
      setBeforeReturnValues(TargetsForSlot, AllocBefore, BitWidth, OffsetByte,
                            OffsetBit);
    else
      setAfterReturnValues(TargetsForSlot, AllocAfter, BitWidth, OffsetByte,
                           OffsetBit);

    // Rewrite each call to a load from OffsetByte/OffsetBit.
    for (auto Call : CSByConstantArg.second) {
      IRBuilder<> B(Call.CS.getInstruction());
      Value *Addr = B.CreateConstGEP1_64(Call.VTable, OffsetByte);
      if (BitWidth == 1) {
        Value *Bits = B.CreateLoad(Addr);
        Value *Bit = ConstantInt::get(Int8Ty, 1ULL << OffsetBit);
        Value *BitsAndBit = B.CreateAnd(Bits, Bit);
        auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0));
        Call.replaceAndErase(IsBitSet);
      } else {
        Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo());
        Value *Val = B.CreateLoad(RetType, ValAddr);
        Call.replaceAndErase(Val);
      }
    }
  }
  return true;
}

static void
emitTargetsRemarks(const std::vector<VirtualCallTarget> &TargetsForSlot) {
  for (const VirtualCallTarget &Target : TargetsForSlot) {
    Function *F = Target.Fn;
    DISubprogram *SP = F->getSubprogram();
    DebugLoc DL = SP ? DebugLoc::get(SP->getScopeLine(), 0, SP) : DebugLoc();
    emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, DL,
                           std::string("devirtualized ") + F->getName().str());
  }
}

void DevirtModule::rebuildGlobal(VTableBits &B) {
  if (B.Before.Bytes.empty() && B.After.Bytes.empty())
    return;

  // Align each byte array to pointer width.
  unsigned PointerSize = M.getDataLayout().getPointerSize();
  B.Before.Bytes.resize(alignTo(B.Before.Bytes.size(), PointerSize));
  B.After.Bytes.resize(alignTo(B.After.Bytes.size(), PointerSize));

  // Before was stored in reverse order; flip it now.
  for (size_t I = 0, Size = B.Before.Bytes.size(); I != Size / 2; ++I)
    std::swap(B.Before.Bytes[I], B.Before.Bytes[Size - 1 - I]);

  // Build an anonymous global containing the before bytes, followed by the
  // original initializer, followed by the after bytes.
  auto NewInit = ConstantStruct::getAnon(
      {ConstantDataArray::get(M.getContext(), B.Before.Bytes),
       B.GV->getInitializer(),
       ConstantDataArray::get(M.getContext(), B.After.Bytes)});
  auto NewGV =
      new GlobalVariable(M, NewInit->getType(), B.GV->isConstant(),
                         GlobalVariable::PrivateLinkage, NewInit, "", B.GV);
  NewGV->setSection(B.GV->getSection());
  NewGV->setComdat(B.GV->getComdat());

  // Copy the original vtable's metadata to the anonymous global, adjusting
  // offsets as required.
  NewGV->copyMetadata(B.GV, B.Before.Bytes.size());

  // Build an alias named after the original global, pointing at the second
  // element (the original initializer).
  auto Alias = GlobalAlias::create(
      B.GV->getInitializer()->getType(), 0, B.GV->getLinkage(), "",
      ConstantExpr::getGetElementPtr(
          NewInit->getType(), NewGV,
          ArrayRef<Constant *>{ConstantInt::get(Int32Ty, 0),
                               ConstantInt::get(Int32Ty, 1)}),
      &M);
  Alias->setVisibility(B.GV->getVisibility());
  Alias->takeName(B.GV);

  B.GV->replaceAllUsesWith(Alias);
  B.GV->eraseFromParent();
}

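// Scan users of llvm.type.test. The pattern this looks through is roughly the
// following (an illustrative IR sketch; !"_ZTS1A" is a made-up type
// identifier):
//
//   %vtable = load i8*, i8** %vtableptr
//   %p = call i1 @llvm.type.test(i8* %vtable, metadata !"_ZTS1A")
//   call void @llvm.assume(i1 %p)
//   ; ... an indirect call through a function pointer loaded from %vtable
//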
void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc,
                                     Function *AssumeFunc) {
  // Find all virtual calls via a virtual table pointer %p under an assumption
  // of the form llvm.assume(llvm.type.test(%p, %md)). This indicates that %p
  // points to a member of the type identifier %md. Group calls by (type ID,
  // offset) pair (effectively the identity of the virtual function) and store
  // to CallSlots.
  DenseSet<Value *> SeenPtrs;
  for (auto I = TypeTestFunc->use_begin(), E = TypeTestFunc->use_end();
       I != E;) {
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    // Search for virtual calls based on %p and add them to DevirtCalls.
    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<CallInst *, 1> Assumes;
    findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI);

    // If we found any, add them to CallSlots. Only do this if we haven't seen
    // the vtable pointer before, as it may have been CSE'd with pointers from
    // other call sites, and we don't want to process call sites multiple
    // times.
    if (!Assumes.empty()) {
      Metadata *TypeId =
          cast<MetadataAsValue>(CI->getArgOperand(1))->getMetadata();
      Value *Ptr = CI->getArgOperand(0)->stripPointerCasts();
      if (SeenPtrs.insert(Ptr).second) {
        for (DevirtCallSite Call : DevirtCalls) {
          CallSlots[{TypeId, Call.Offset}].push_back(
              {CI->getArgOperand(0), Call.CS, nullptr});
        }
      }
    }

    // We no longer need the assumes or the type test.
    for (auto Assume : Assumes)
      Assume->eraseFromParent();
    // We can't use RecursivelyDeleteTriviallyDeadInstructions here because we
    // may use the vtable argument later.
    if (CI->use_empty())
      CI->eraseFromParent();
  }
}

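// Scan users of llvm.type.checked.load. The intrinsic takes a vtable pointer,
// a byte offset and a type identifier, and returns the loaded function
// pointer together with an i1 indicating whether the type check passed, e.g.
// (an illustrative IR sketch with a made-up type identifier):
//
//   %pair = call {i8*, i1} @llvm.type.checked.load(i8* %vtable, i32 0,
//                                                  metadata !"_ZTS1A")
//   %fptr = extractvalue {i8*, i1} %pair, 0
//   %ok = extractvalue {i8*, i1} %pair, 1
//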
void DevirtModule::scanTypeCheckedLoadUsers(Function *TypeCheckedLoadFunc) {
  Function *TypeTestFunc = Intrinsic::getDeclaration(&M, Intrinsic::type_test);

  for (auto I = TypeCheckedLoadFunc->use_begin(),
            E = TypeCheckedLoadFunc->use_end();
       I != E;) {
    auto CI = dyn_cast<CallInst>(I->getUser());
    ++I;
    if (!CI)
      continue;

    Value *Ptr = CI->getArgOperand(0);
    Value *Offset = CI->getArgOperand(1);
    Value *TypeIdValue = CI->getArgOperand(2);
    Metadata *TypeId = cast<MetadataAsValue>(TypeIdValue)->getMetadata();

    SmallVector<DevirtCallSite, 1> DevirtCalls;
    SmallVector<Instruction *, 1> LoadedPtrs;
    SmallVector<Instruction *, 1> Preds;
    bool HasNonCallUses = false;
    findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds,
                                               HasNonCallUses, CI);

    // Start by generating "pessimistic" code that explicitly loads the
    // function pointer from the vtable and performs the type check. If
    // possible, we will eliminate the load and the type check later.

    // If possible, only generate the load at the point where it is used.
    // This helps avoid unnecessary spills.
    IRBuilder<> LoadB(
        (LoadedPtrs.size() == 1 && !HasNonCallUses) ? LoadedPtrs[0] : CI);
    Value *GEP = LoadB.CreateGEP(Int8Ty, Ptr, Offset);
    Value *GEPPtr = LoadB.CreateBitCast(GEP, PointerType::getUnqual(Int8PtrTy));
    Value *LoadedValue = LoadB.CreateLoad(Int8PtrTy, GEPPtr);

    for (Instruction *LoadedPtr : LoadedPtrs) {
      LoadedPtr->replaceAllUsesWith(LoadedValue);
      LoadedPtr->eraseFromParent();
    }

    // Likewise for the type test.
    IRBuilder<> CallB((Preds.size() == 1 && !HasNonCallUses) ? Preds[0] : CI);
    CallInst *TypeTestCall = CallB.CreateCall(TypeTestFunc, {Ptr, TypeIdValue});

    for (Instruction *Pred : Preds) {
      Pred->replaceAllUsesWith(TypeTestCall);
      Pred->eraseFromParent();
    }

    // We have already erased any extractvalue instructions that refer to the
    // intrinsic call, but the intrinsic may have other non-extractvalue uses
    // (although this is unlikely). In that case, explicitly build a pair and
    // RAUW it.
    if (!CI->use_empty()) {
      Value *Pair = UndefValue::get(CI->getType());
      IRBuilder<> B(CI);
      Pair = B.CreateInsertValue(Pair, LoadedValue, {0});
      Pair = B.CreateInsertValue(Pair, TypeTestCall, {1});
      CI->replaceAllUsesWith(Pair);
    }

    // The number of unsafe uses is initially the number of uses.
    auto &NumUnsafeUses = NumUnsafeUsesForTypeTest[TypeTestCall];
    NumUnsafeUses = DevirtCalls.size();

    // If the function pointer has a non-call user, we cannot eliminate the
    // type check, as one of those users may eventually call the pointer.
    // Increment the unsafe use count to make sure it cannot reach zero.
    if (HasNonCallUses)
      ++NumUnsafeUses;
    for (DevirtCallSite Call : DevirtCalls) {
      CallSlots[{TypeId, Call.Offset}].push_back(
          {Ptr, Call.CS, &NumUnsafeUses});
    }

    CI->eraseFromParent();
  }
}

bool DevirtModule::run() {
  Function *TypeTestFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_test));
  Function *TypeCheckedLoadFunc =
      M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
  Function *AssumeFunc = M.getFunction(Intrinsic::getName(Intrinsic::assume));

  if ((!TypeTestFunc || TypeTestFunc->use_empty() || !AssumeFunc ||
       AssumeFunc->use_empty()) &&
      (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
    return false;

  if (TypeTestFunc && AssumeFunc)
    scanTypeTestUsers(TypeTestFunc, AssumeFunc);

  if (TypeCheckedLoadFunc)
    scanTypeCheckedLoadUsers(TypeCheckedLoadFunc);

  // Rebuild type metadata into a map for easy lookup.
  std::vector<VTableBits> Bits;
  DenseMap<Metadata *, std::set<TypeMemberInfo>> TypeIdMap;
  buildTypeIdentifierMap(Bits, TypeIdMap);
  if (TypeIdMap.empty())
    return true;

  // For each (type, offset) pair:
  bool DidVirtualConstProp = false;
  for (auto &S : CallSlots) {
    // Search each of the members of the type identifier for the virtual
    // function implementation at offset S.first.ByteOffset, and add to
    // TargetsForSlot.
    std::vector<VirtualCallTarget> TargetsForSlot;
    if (!tryFindVirtualCallTargets(TargetsForSlot, TypeIdMap[S.first.TypeID],
                                   S.first.ByteOffset))
      continue;

    if (trySingleImplDevirt(TargetsForSlot, S.second)) {
      emitTargetsRemarks(TargetsForSlot);
      continue;
    }

    if (tryVirtualConstProp(TargetsForSlot, S.second)) {
      emitTargetsRemarks(TargetsForSlot);
      DidVirtualConstProp = true;
    }
  }

  // If we were able to eliminate all unsafe uses for a type checked load,
  // eliminate the type test by replacing it with true.
  if (TypeCheckedLoadFunc) {
    auto True = ConstantInt::getTrue(M.getContext());
    for (auto &&U : NumUnsafeUsesForTypeTest) {
      if (U.second == 0) {
        U.first->replaceAllUsesWith(True);
        U.first->eraseFromParent();
      }
    }
  }

  // Rebuild each global we touched as part of virtual constant propagation to
  // include the before and after bytes.
  if (DidVirtualConstProp)
    for (VTableBits &B : Bits)
      rebuildGlobal(B);

  return true;
}