blob: a9d3c87a0b56f75b53476aa7e23f966fe9a44f95 [file] [log] [blame]
//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target, branching early as soon as a difference is found.
//
//===----------------------------------------------------------------------===//
15
16#include "llvm/ADT/Statistic.h"
17#include "llvm/Analysis/ConstantFolding.h"
18#include "llvm/Analysis/TargetLibraryInfo.h"
19#include "llvm/Analysis/TargetTransformInfo.h"
20#include "llvm/Analysis/ValueTracking.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000021#include "llvm/CodeGen/TargetLowering.h"
Clement Courbet063bed92017-11-03 12:12:27 +000022#include "llvm/CodeGen/TargetPassConfig.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000023#include "llvm/CodeGen/TargetSubtargetInfo.h"
Clement Courbet063bed92017-11-03 12:12:27 +000024#include "llvm/IR/IRBuilder.h"
Clement Courbet063bed92017-11-03 12:12:27 +000025
using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

// Statistics recording how often, and why, memcmp calls were (not) expanded.
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

// How many load/compare pairs to emit per basic block when the memcmp result
// is only compared (not) equal to zero (see getCompareLoadPairs()).
static cl::opt<unsigned> MemCmpNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));
40
41namespace {
42
43
// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  // The shared "res_block" that computes -1/+1 once a difference is found.
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    // Phis collecting the two (possibly zero-extended) loaded values from the
    // load/compare block that branched here, so the result block can decide
    // which source compared lower.
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;          // The memcmp call being expanded.
  ResultBlock ResBlock;
  const uint64_t Size;         // Constant size argument of the memcmp call.
  unsigned MaxLoadSize;        // Largest load size used in the decomposition.
  // Number of distinct load sizes > 1 byte used by the decomposition
  // (incremented once per size in the constructor, not once per load).
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlock;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;        // Block the expansion falls through to.
  PHINode *PhiRes;             // i32 phi in EndBlock holding the final result.
  const bool IsUsedForZeroCmp; // Result only compared (not) equal to zero.
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {32, 1}.
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlaping 4-byte loads than
  // covering the interval with [{4, 0},{2, 4},{1, 6}}.
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    // Index of this load when the source is addressed as an array of
    // LoadSize-byte elements (valid because Offset % LoadSize == 0).
    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  SmallVector<LoadEntry, 8> LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

 public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned NumLoadsPerBlock, const DataLayout &DL);

  unsigned getNumBlocks();
  // Number of load/compare pairs in the decomposition. Zero means "do not
  // expand" (the constructor bails out when MaxNumLoads would be exceeded).
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};
113
// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned NumLoadsPerBlock, const DataLayout &TheDataLayout)
    : CI(CI),
      Size(Size),
      MaxLoadSize(0),
      NumLoadsNonOneByte(0),
      NumLoadsPerBlock(NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp),
      DL(TheDataLayout),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  // NOTE(review): this assumes Options.LoadSizes (in decreasing order)
  // contains at least one size <= Size, e.g. ends with 1 -- otherwise the
  // access after the loop would read past the end. TODO: confirm against
  // the targets' enableMemCmpExpansion() implementations.
  size_t LoadSizeIndex = 0;
  while (LoadSizeIndex < Options.LoadSizes.size() &&
         Options.LoadSizes[LoadSizeIndex] > Size) {
    ++LoadSizeIndex;
  }
  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
  // Compute the decomposition: greedily cover the remaining bytes with the
  // largest usable load size, then move on to smaller sizes for the tail.
  uint64_t CurSize = Size;
  uint64_t Offset = 0;
  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
    assert(LoadSize > 0 && "zero load size");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion for
      // large sizes.
      LoadSequence.clear();
      return;
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1) {
        // Counts size classes (not individual loads); see field comment.
        ++NumLoadsNonOneByte;
      }
      CurSize = CurSize % LoadSize;
    }
    ++LoadSizeIndex;
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}
172
173unsigned MemCmpExpansion::getNumBlocks() {
174 if (IsUsedForZeroCmp)
175 return getNumLoads() / NumLoadsPerBlock +
176 (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0);
177 return getNumLoads();
178}
179
180void MemCmpExpansion::createLoadCmpBlocks() {
181 for (unsigned i = 0; i < getNumBlocks(); i++) {
182 BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
183 EndBlock->getParent(), EndBlock);
184 LoadCmpBlocks.push_back(BB);
185 }
186}
187
188void MemCmpExpansion::createResultBlock() {
189 ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
190 EndBlock->getParent(), EndBlock);
191}
192
// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned GEPIndex) {
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using the GEPIndex.
  if (GEPIndex != 0) {
    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
                                ConstantInt::get(LoadSizeType, GEPIndex));
    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
                                ConstantInt::get(LoadSizeType, GEPIndex));
  }

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // Zero-extend both bytes to i32 and subtract; for single bytes the
  // difference is already a valid negative/zero/positive memcmp result.
  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue to
    // next LoadCmpBlock,
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}
241
/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero. Consumes up to NumLoadsPerBlock entries of LoadSequence
/// starting at LoadIndex (advancing it), and returns an i1 that is true iff
/// any of the compared byte ranges differ.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = CI->getArgOperand(0);
    Value *Source2 = CI->getArgOperand(1);

    // Cast source to LoadSizeType*.
    if (Source1->getType() != LoadSizeType)
      Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
    if (Source2->getType() != LoadSizeType)
      Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

    // Get the base address using a GEP.
    if (CurLoadEntry.Offset != 0) {
      Source1 = Builder.CreateGEP(
          LoadSizeType, Source1,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
      Source2 = Builder.CreateGEP(
          LoadSizeType, Source2,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    }

    // Get a constant or load a value for each source address. Constant
    // folding avoids emitting a load when a source is a known constant.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  // OR adjacent list elements pairwise; an odd leftover element is carried
  // through unchanged.
  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }
    // Any set bit in the accumulated OR means some pair of loads differed.
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
346
// Emit one load/compare block that handles multiple load pairs: compute the
// composite inequality via getCompareLoadPairs(), then branch to the result
// block on a difference or fall through to the next block (or EndBlock).
void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
367
// This function creates the IR instructions for loading and comparing using the
// given LoadSize. It loads the number of bytes specified by LoadSize from each
// source of the memcmp parameters. It then does a subtract to see if there was
// a difference in the loaded values. If a difference is found, it branches
// with an early exit to the ResultBlock for calculating which source was
// larger. Otherwise, it falls through to the either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
// a special case through emitLoadCompareByteBlock. The special handling can
// simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // On little-endian targets, byte-swap so that an unsigned compare of the
  // wide values matches memcmp's byte-wise ordering.
  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
452
453// This function populates the ResultBlock with a sequence to calculate the
454// memcmp result. It compares the two loaded source values and returns -1 if
455// src1 < src2 and 1 if src1 > src2.
456void MemCmpExpansion::emitMemCmpResultBlock() {
457 // Special case: if memcmp result is used in a zero equality, result does not
458 // need to be calculated and can simply return 1.
459 if (IsUsedForZeroCmp) {
460 BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
461 Builder.SetInsertPoint(ResBlock.BB, InsertPt);
462 Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
463 PhiRes->addIncoming(Res, ResBlock.BB);
464 BranchInst *NewBr = BranchInst::Create(EndBlock);
465 Builder.Insert(NewBr);
466 return;
467 }
468 BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
469 Builder.SetInsertPoint(ResBlock.BB, InsertPt);
470
471 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
472 ResBlock.PhiSrc2);
473
474 Value *Res =
475 Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
476 ConstantInt::get(Builder.getInt32Ty(), 1));
477
478 BranchInst *NewBr = BranchInst::Create(EndBlock);
479 Builder.Insert(NewBr);
480 PhiRes->addIncoming(Res, ResBlock.BB);
481}
482
483void MemCmpExpansion::setupResultBlockPHINodes() {
484 Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
485 Builder.SetInsertPoint(ResBlock.BB);
486 // Note: this assumes one load per block.
487 ResBlock.PhiSrc1 =
488 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
489 ResBlock.PhiSrc2 =
490 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
491}
492
// Create the i32 result phi at the very top of EndBlock; every path through
// the expansion feeds its memcmp result into this node.
void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}
497
498Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
499 unsigned LoadIndex = 0;
500 // This loop populates each of the LoadCmpBlocks with the IR sequence to
501 // handle multiple loads per block.
502 for (unsigned I = 0; I < getNumBlocks(); ++I) {
503 emitLoadCompareBlockMultipleLoads(I, LoadIndex);
504 }
505
506 emitMemCmpResultBlock();
507 return PhiRes;
508}
509
/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is required
/// in the general case. The zext'd i1 (0 or 1) is a valid memcmp result here
/// because callers only test it against zero.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}
519
/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block");

  // The whole comparison is done with a single Size-byte load per source.
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // On little-endian targets, byte-swap so the unsigned compare below matches
  // memcmp's byte-wise ordering. A single byte needs no swap.
  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32 result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
566
// This function expands the memcmp call into an inline expansion and returns
// the memcmp result. Dispatches between the single-block fast paths and the
// general multi-block expansion.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // A memcmp with zero-comparison with only one block of load and compare does
  // not need to set up any extra blocks. This case could be handled in the DAG,
  // but since we have all of the machinery to flexibly expand any memcmp here,
  // we choose to handle this case too to avoid fragmented lowering.
  if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    // Split at the call so everything after it lands in "endblock".
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  // TODO: Handle more than one load pair per block in getMemCmpOneBlock().
  if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}
610
611// This function checks to see if an expansion of memcmp can be generated.
612// It checks for constant compare size that is less than the max inline size.
613// If an expansion cannot occur, returns false to leave as a library call.
614// Otherwise, the library call is replaced with a new IR instruction sequence.
615/// We want to transform:
616/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
617/// To:
618/// loadbb:
619/// %0 = bitcast i32* %buffer2 to i8*
620/// %1 = bitcast i32* %buffer1 to i8*
621/// %2 = bitcast i8* %1 to i64*
622/// %3 = bitcast i8* %0 to i64*
623/// %4 = load i64, i64* %2
624/// %5 = load i64, i64* %3
625/// %6 = call i64 @llvm.bswap.i64(i64 %4)
626/// %7 = call i64 @llvm.bswap.i64(i64 %5)
627/// %8 = sub i64 %6, %7
628/// %9 = icmp ne i64 %8, 0
629/// br i1 %9, label %res_block, label %loadbb1
630/// res_block: ; preds = %loadbb2,
631/// %loadbb1, %loadbb
632/// %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
633/// %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
634/// %10 = icmp ult i64 %phi.src1, %phi.src2
635/// %11 = select i1 %10, i32 -1, i32 1
636/// br label %endblock
637/// loadbb1: ; preds = %loadbb
638/// %12 = bitcast i32* %buffer2 to i8*
639/// %13 = bitcast i32* %buffer1 to i8*
640/// %14 = bitcast i8* %13 to i32*
641/// %15 = bitcast i8* %12 to i32*
642/// %16 = getelementptr i32, i32* %14, i32 2
643/// %17 = getelementptr i32, i32* %15, i32 2
644/// %18 = load i32, i32* %16
645/// %19 = load i32, i32* %17
646/// %20 = call i32 @llvm.bswap.i32(i32 %18)
647/// %21 = call i32 @llvm.bswap.i32(i32 %19)
648/// %22 = zext i32 %20 to i64
649/// %23 = zext i32 %21 to i64
650/// %24 = sub i64 %22, %23
651/// %25 = icmp ne i64 %24, 0
652/// br i1 %25, label %res_block, label %loadbb2
653/// loadbb2: ; preds = %loadbb1
654/// %26 = bitcast i32* %buffer2 to i8*
655/// %27 = bitcast i32* %buffer1 to i8*
656/// %28 = bitcast i8* %27 to i16*
657/// %29 = bitcast i8* %26 to i16*
658/// %30 = getelementptr i16, i16* %28, i16 6
659/// %31 = getelementptr i16, i16* %29, i16 6
660/// %32 = load i16, i16* %30
661/// %33 = load i16, i16* %31
662/// %34 = call i16 @llvm.bswap.i16(i16 %32)
663/// %35 = call i16 @llvm.bswap.i16(i16 %33)
664/// %36 = zext i16 %34 to i64
665/// %37 = zext i16 %35 to i64
666/// %38 = sub i64 %36, %37
667/// %39 = icmp ne i64 %38, 0
668/// br i1 %39, label %res_block, label %loadbb3
669/// loadbb3: ; preds = %loadbb2
670/// %40 = bitcast i32* %buffer2 to i8*
671/// %41 = bitcast i32* %buffer1 to i8*
672/// %42 = getelementptr i8, i8* %41, i8 14
673/// %43 = getelementptr i8, i8* %40, i8 14
674/// %44 = load i8, i8* %42
675/// %45 = load i8, i8* %43
676/// %46 = zext i8 %44 to i32
677/// %47 = zext i8 %45 to i32
678/// %48 = sub i32 %46, %47
679/// br label %endblock
680/// endblock: ; preds = %res_block,
681/// %loadbb3
682/// %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
683/// ret i32 %phi.res
// Attempt to expand one memcmp call. Returns false (leaving the library call
// in place) when: optimizing for minimum size, the size argument is not a
// constant, the size is zero, the target declines expansion, or the
// decomposition needs more loads than the target allows. On success the call
// is replaced by the inline expansion and erased.
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->optForMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  // memcmp(a, b, 0) == 0; nothing to expand.
  if (SizeVal == 0) {
    return false;
  }

  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
  if (!Options) return false;

  const unsigned MaxNumLoads =
      TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());

  MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
                            IsUsedForZeroCmp, MemCmpNumLoadsPerBlock, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}
732
733
734
// Legacy FunctionPass wrapper: walks a function's blocks and expands the
// memcmp library calls it finds.
class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    // The expansion needs target lowering info; bail out when not running
    // inside a codegen pipeline (no TargetPassConfig available).
    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering* TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering* TL);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering* TL,
                  const DataLayout& DL);
};
776
777bool ExpandMemCmpPass::runOnBlock(
778 BasicBlock &BB, const TargetLibraryInfo *TLI,
779 const TargetTransformInfo *TTI, const TargetLowering* TL,
780 const DataLayout& DL) {
781 for (Instruction& I : BB) {
782 CallInst *CI = dyn_cast<CallInst>(&I);
783 if (!CI) {
784 continue;
785 }
786 LibFunc Func;
787 if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
788 Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
789 return true;
790 }
791 }
792 return false;
793}
794
795
796PreservedAnalyses ExpandMemCmpPass::runImpl(
797 Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
798 const TargetLowering* TL) {
799 const DataLayout& DL = F.getParent()->getDataLayout();
800 bool MadeChanges = false;
801 for (auto BBIt = F.begin(); BBIt != F.end();) {
802 if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
803 MadeChanges = true;
804 // If changes were made, restart the function from the beginning, since
805 // the structure of the function was changed.
806 BBIt = F.begin();
807 } else {
808 ++BBIt;
809 }
810 }
811 return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
812}
813
814} // namespace
815
char ExpandMemCmpPass::ID = 0;

// Register the pass and its analysis dependencies with the legacy pass
// manager.
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

// Public factory used by the codegen pipeline to create this pass.
FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}