blob: e73875de709ef3848cd1daec5b7af4a6bed3d968 [file] [log] [blame]
//===--- HexagonLoopIdiomRecognition.cpp ----------------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
10#define DEBUG_TYPE "hexagon-lir"
11
12#include "llvm/ADT/SetVector.h"
13#include "llvm/ADT/SmallSet.h"
14#include "llvm/Analysis/AliasAnalysis.h"
15#include "llvm/Analysis/InstructionSimplify.h"
16#include "llvm/Analysis/LoopPass.h"
17#include "llvm/Analysis/ScalarEvolution.h"
18#include "llvm/Analysis/ScalarEvolutionExpander.h"
19#include "llvm/Analysis/ScalarEvolutionExpressions.h"
20#include "llvm/Analysis/TargetLibraryInfo.h"
21#include "llvm/Analysis/ValueTracking.h"
22#include "llvm/IR/DataLayout.h"
23#include "llvm/IR/Dominators.h"
24#include "llvm/IR/IRBuilder.h"
25#include "llvm/IR/PatternMatch.h"
26#include "llvm/Transforms/Scalar.h"
27#include "llvm/Transforms/Utils/Local.h"
28#include "llvm/Support/Debug.h"
29#include "llvm/Support/raw_ostream.h"
30
31#include <algorithm>
32#include <array>
33
34using namespace llvm;
35
// Command-line knobs controlling which memory-transfer idioms this pass may
// generate and when the transformation is considered profitable.
static cl::opt<bool> DisableMemcpyIdiom("disable-memcpy-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memcpy in loop idiom recognition"));

static cl::opt<bool> DisableMemmoveIdiom("disable-memmove-idiom",
  cl::Hidden, cl::init(false),
  cl::desc("Disable generation of memmove in loop idiom recognition"));

// Byte threshold used by the runtime guard emitted around a generated
// memmove (0 = default behavior; exact use is outside this chunk).
static cl::opt<unsigned> RuntimeMemSizeThreshold("runtime-mem-idiom-threshold",
  cl::Hidden, cl::init(0), cl::desc("Threshold (in bytes) for the runtime "
  "check guarding the memmove."));

// Minimum transfer size required to transform when the loop count is a
// compile-time constant.
static cl::opt<unsigned> CompileTimeMemSizeThreshold(
  "compile-time-mem-idiom-threshold", cl::Hidden, cl::init(64),
  cl::desc("Threshold (in bytes) to perform the transformation, if the "
  "runtime loop count (mem transfer size) is known at compile-time."));

static cl::opt<bool> OnlyNonNestedMemmove("only-nonnested-memmove-idiom",
  cl::Hidden, cl::init(true),
  cl::desc("Only enable generating memmove in non-nested loops"));

// Non-static: referenced from other Hexagon target files.
// NOTE(review): the option is *named* "disable-hexagon-volatile-memcpy" but
// its description says "Enable ..." — one of the two looks inverted; confirm
// the intended polarity before relying on this flag.
cl::opt<bool> HexagonVolatileMemcpy("disable-hexagon-volatile-memcpy",
  cl::Hidden, cl::init(false),
  cl::desc("Enable Hexagon-specific memcpy for volatile destination."));

// Symbol name of the runtime routine used for the volatile-memcpy case
// (presumably called by code outside this chunk).
static const char *HexagonVolatileMemcpyName
  = "hexagon_memcpy_forward_vp4cp4n2";
63
64
namespace llvm {
  // Registration hooks: the initializer is generated by the INITIALIZE_PASS
  // macros below; the factory is presumably invoked from the Hexagon
  // target's pass-pipeline setup (defined elsewhere).
  void initializeHexagonLoopIdiomRecognizePass(PassRegistry&);
  Pass *createHexagonLoopIdiomPass();
}
69
namespace {
  // Legacy-PM loop pass driving Hexagon-specific loop-idiom recognition:
  // polynomial multiplication (see PolynomialMultiplyRecognize below) and
  // the copying-store (memcpy/memmove) idiom handled by the private helpers
  // declared here.
  class HexagonLoopIdiomRecognize : public LoopPass {
  public:
    static char ID;
    explicit HexagonLoopIdiomRecognize() : LoopPass(ID) {
      initializeHexagonLoopIdiomRecognizePass(*PassRegistry::getPassRegistry());
    }
    StringRef getPassName() const override {
      return "Recognize Hexagon-specific loop idioms";
    }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      // Loops must be in simplified and LCSSA form; AA and TLI are preserved.
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequiredID(LoopSimplifyID);
      AU.addRequiredID(LCSSAID);
      AU.addRequired<AAResultsWrapperPass>();
      AU.addPreserved<AAResultsWrapperPass>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addPreserved<TargetLibraryInfoWrapperPass>();
    }

    bool runOnLoop(Loop *L, LPPassManager &LPM) override;

  private:
    // Helpers for the copying-store idiom; their bodies are outside this
    // chunk, so comments here stick to what the signatures show.
    unsigned getStoreSizeInBytes(StoreInst *SI);
    int getSCEVStride(const SCEVAddRecExpr *StoreEv);
    bool isLegalStore(Loop *CurLoop, StoreInst *SI);
    void collectStores(Loop *CurLoop, BasicBlock *BB,
        SmallVectorImpl<StoreInst*> &Stores);
    bool processCopyingStore(Loop *CurLoop, StoreInst *SI, const SCEV *BECount);
    bool coverLoop(Loop *L, SmallVectorImpl<Instruction*> &Insts) const;
    bool runOnLoopBlock(Loop *CurLoop, BasicBlock *BB, const SCEV *BECount,
        SmallVectorImpl<BasicBlock*> &ExitBlocks);
    bool runOnCountableLoop(Loop *L);

    // Cached analysis pointers — presumably filled in by runOnLoop, which is
    // defined outside this chunk.
    AliasAnalysis *AA;
    const DataLayout *DL;
    DominatorTree *DT;
    LoopInfo *LF;
    const TargetLibraryInfo *TLI;
    ScalarEvolution *SE;
    bool HasMemcpy, HasMemmove;  // NOTE(review): set elsewhere — confirm meaning
  };
}
116
char HexagonLoopIdiomRecognize::ID = 0;

// Register the pass and its analysis dependencies with the legacy pass
// manager under the command-line name "hexagon-loop-idiom".
INITIALIZE_PASS_BEGIN(HexagonLoopIdiomRecognize, "hexagon-loop-idiom",
    "Recognize Hexagon-specific loop idioms", false, false)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(HexagonLoopIdiomRecognize, "hexagon-loop-idiom",
    "Recognize Hexagon-specific loop idioms", false, false)
130
131
132//===----------------------------------------------------------------------===//
133//
134// Implementation of PolynomialMultiplyRecognize
135//
136//===----------------------------------------------------------------------===//
137
namespace {
  // Recognizes loops that compute a polynomial multiplication over Z2
  // (carry-less multiply) one bit per iteration, and rewrites them using
  // the Hexagon pmpy intrinsic (see generate()). recognize() is the entry
  // point; the rest are matching/rewriting helpers.
  class PolynomialMultiplyRecognize {
  public:
    explicit PolynomialMultiplyRecognize(Loop *loop, const DataLayout &dl,
          const DominatorTree &dt, const TargetLibraryInfo &tli,
          ScalarEvolution &se)
      : CurLoop(loop), DL(dl), DT(dt), TLI(tli), SE(se) {}

    // Attempt to match and rewrite the loop; true on success.
    bool recognize();
  private:
    typedef SetVector<Value*> ValueSeq;

    // Find the 0-based, +1-per-iteration counting IV of block BB.
    Value *getCountIV(BasicBlock *BB);
    // Def-use cycle search and early/late classification used when turning
    // right-shifting idioms into left-shifting ones.
    bool findCycle(Value *Out, Value *In, ValueSeq &Cycle);
    void classifyCycle(Instruction *DivI, ValueSeq &Cycle, ValueSeq &Early,
          ValueSeq &Late);
    bool classifyInst(Instruction *UseI, ValueSeq &Early, ValueSeq &Late);
    bool commutesWithShift(Instruction *I);
    bool highBitsAreZero(Value *V, unsigned IterCount);
    bool keepsHighBitsZero(Value *V, unsigned IterCount);
    bool isOperandShifted(Instruction *I, Value *Op);
    bool convertShiftsToLeft(BasicBlock *LoopB, BasicBlock *ExitB,
          unsigned IterCount);
    void cleanupLoopBody(BasicBlock *LoopB);

    // Pieces of the matched pmpy pattern, filled in by the match* and
    // scanSelect routines.
    struct ParsedValues {
      ParsedValues() : M(nullptr), P(nullptr), Q(nullptr), R(nullptr),
          X(nullptr), Res(nullptr), IterCount(0), Left(false), Inv(false) {}
      // M: invariant xor mask (inverse form); P: input polynomial;
      // Q: multiplier; R: the recurrence; X: the bit-tested value.
      Value *M, *P, *Q, *R, *X;
      // The select instruction producing the per-iteration result.
      Instruction *Res;
      // Loop trip count (number of product bits computed).
      unsigned IterCount;
      // Left: left-shifting form matched; Inv: inverse-pmpy form.
      bool Left, Inv;
    };

    bool matchLeftShift(SelectInst *SelI, Value *CIV, ParsedValues &PV);
    bool matchRightShift(SelectInst *SelI, ParsedValues &PV);
    bool scanSelect(SelectInst *SI, BasicBlock *LoopB, BasicBlock *PrehB,
          Value *CIV, ParsedValues &PV, bool PreScan);
    // Multiplicative inverse of QP in Z2[x] mod x^32.
    unsigned getInverseMxN(unsigned QP);
    // Emit the replacement (intrinsic-based) computation at At.
    Value *generate(BasicBlock::iterator At, ParsedValues &PV);

    Loop *CurLoop;
    const DataLayout &DL;
    const DominatorTree &DT;
    const TargetLibraryInfo &TLI;
    ScalarEvolution &SE;
  };
}
186
187
188Value *PolynomialMultiplyRecognize::getCountIV(BasicBlock *BB) {
189 pred_iterator PI = pred_begin(BB), PE = pred_end(BB);
190 if (std::distance(PI, PE) != 2)
191 return nullptr;
192 BasicBlock *PB = (*PI == BB) ? *std::next(PI) : *PI;
193
194 for (auto I = BB->begin(), E = BB->end(); I != E && isa<PHINode>(I); ++I) {
195 auto *PN = cast<PHINode>(I);
196 Value *InitV = PN->getIncomingValueForBlock(PB);
197 if (!isa<ConstantInt>(InitV) || !cast<ConstantInt>(InitV)->isZero())
198 continue;
199 Value *IterV = PN->getIncomingValueForBlock(BB);
200 if (!isa<BinaryOperator>(IterV))
201 continue;
202 auto *BO = dyn_cast<BinaryOperator>(IterV);
203 if (BO->getOpcode() != Instruction::Add)
204 continue;
205 Value *IncV = nullptr;
206 if (BO->getOperand(0) == PN)
207 IncV = BO->getOperand(1);
208 else if (BO->getOperand(1) == PN)
209 IncV = BO->getOperand(0);
210 if (IncV == nullptr)
211 continue;
212
213 if (auto *T = dyn_cast<ConstantInt>(IncV))
214 if (T->getZExtValue() == 1)
215 return PN;
216 }
217 return nullptr;
218}
219
220
221static void replaceAllUsesOfWithIn(Value *I, Value *J, BasicBlock *BB) {
222 for (auto UI = I->user_begin(), UE = I->user_end(); UI != UE;) {
223 Use &TheUse = UI.getUse();
224 ++UI;
225 if (auto *II = dyn_cast<Instruction>(TheUse.getUser()))
226 if (BB == II->getParent())
227 II->replaceUsesOfWith(I, J);
228 }
229}
230
231
// Match one step of a left-shifting polynomial multiply expressed as a
// select — a conditional xor controlled by bit i of some value X:
//   R' = select (X & (1 << i)) ? R ^ (Q << i) : R
// CIV is the loop's counting IV (the "i" above). On success, records X, Q
// and R in PV and sets PV.Left; PV.P/PV.Inv are determined later by
// scanSelect.
bool PolynomialMultiplyRecognize::matchLeftShift(SelectInst *SelI,
      Value *CIV, ParsedValues &PV) {
  // Match the following:
  //   select (X & (1 << i)) != 0 ? R ^ (Q << i) : R
  //   select (X & (1 << i)) == 0 ? R : R ^ (Q << i)
  // The condition may also check for equality with the masked value, i.e
  //   select (X & (1 << i)) == (1 << i) ? R ^ (Q << i) : R
  //   select (X & (1 << i)) != (1 << i) ? R : R ^ (Q << i);

  Value *CondV = SelI->getCondition();
  Value *TrueV = SelI->getTrueValue();
  Value *FalseV = SelI->getFalseValue();

  using namespace PatternMatch;

  CmpInst::Predicate P;
  Value *A = nullptr, *B = nullptr, *C = nullptr;

  // The compare must be an ==/!= of an 'and' against some value C, with
  // either operand order.
  if (!match(CondV, m_ICmp(P, m_And(m_Value(A), m_Value(B)), m_Value(C))) &&
      !match(CondV, m_ICmp(P, m_Value(C), m_And(m_Value(A), m_Value(B)))))
    return false;
  if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
    return false;
  // Matched: select (A & B) == C ? ... : ...
  //          select (A & B) != C ? ... : ...

  Value *X = nullptr, *Sh1 = nullptr;
  // Check (A & B) for (X & (1 << i)):
  if (match(A, m_Shl(m_One(), m_Specific(CIV)))) {
    Sh1 = A;
    X = B;
  } else if (match(B, m_Shl(m_One(), m_Specific(CIV)))) {
    Sh1 = B;
    X = A;
  } else {
    // TODO: Could also check for an induction variable containing single
    // bit shifted left by 1 in each iteration.
    return false;
  }

  bool TrueIfZero;

  // Check C against the possible values for comparison: 0 and (1 << i):
  if (match(C, m_Zero()))
    TrueIfZero = (P == CmpInst::ICMP_EQ);
  else if (C == Sh1)
    TrueIfZero = (P == CmpInst::ICMP_NE);
  else
    return false;

  // So far, matched:
  //   select (X & (1 << i)) ? ... : ...
  // including variations of the check against zero/non-zero value.

  Value *ShouldSameV = nullptr, *ShouldXoredV = nullptr;
  if (TrueIfZero) {
    ShouldSameV = TrueV;
    ShouldXoredV = FalseV;
  } else {
    ShouldSameV = FalseV;
    ShouldXoredV = TrueV;
  }

  Value *Q = nullptr, *R = nullptr, *Y = nullptr, *Z = nullptr;
  Value *T = nullptr;
  if (match(ShouldXoredV, m_Xor(m_Value(Y), m_Value(Z)))) {
    // Matched: select +++ ? ... : Y ^ Z
    //          select +++ ? Y ^ Z : ...
    // where +++ denotes previously checked matches.
    if (ShouldSameV == Y)
      T = Z;
    else if (ShouldSameV == Z)
      T = Y;
    else
      return false;
    R = ShouldSameV;
    // Matched: select +++ ? R : R ^ T
    //          select +++ ? R ^ T : R
    // depending on TrueIfZero.

  } else if (match(ShouldSameV, m_Zero())) {
    // Degenerate form: the "unchanged" arm is the constant 0 and the xor
    // with R happens in the select's single user instead.
    // Matched: select +++ ? 0 : ...
    //          select +++ ? ... : 0
    if (!SelI->hasOneUse())
      return false;
    T = ShouldXoredV;
    // Matched: select +++ ? 0 : T
    //          select +++ ? T : 0

    Value *U = *SelI->user_begin();
    if (!match(U, m_Xor(m_Specific(SelI), m_Value(R))) &&
        !match(U, m_Xor(m_Value(R), m_Specific(SelI))))
      return false;
    // Matched: xor (select +++ ? 0 : T), R
    //          xor (select +++ ? T : 0), R
  } else
    return false;

  // The xor input value T is isolated into its own match so that it could
  // be checked against an induction variable containing a shifted bit
  // (todo).
  // For now, check against (Q << i).
  if (!match(T, m_Shl(m_Value(Q), m_Specific(CIV))) &&
      !match(T, m_Shl(m_ZExt(m_Value(Q)), m_ZExt(m_Specific(CIV)))))
    return false;
  // Matched: select +++ ? R : R ^ (Q << i)
  //          select +++ ? R ^ (Q << i) : R

  PV.X = X;
  PV.Q = Q;
  PV.R = R;
  PV.Left = true;
  return true;
}
346
347
// Match one step of a right-shifting polynomial multiply expressed as a
// select testing the low bit of X:
//   R' = select (X & 1) ? (R >> 1) ^ Q : (R >> 1)
// On success, records X, Q and R in PV and clears PV.Left.
bool PolynomialMultiplyRecognize::matchRightShift(SelectInst *SelI,
      ParsedValues &PV) {
  // Match the following:
  //   select (X & 1) != 0 ? (R >> 1) ^ Q : (R >> 1)
  //   select (X & 1) == 0 ? (R >> 1) : (R >> 1) ^ Q
  // The condition may also check for equality with the masked value, i.e
  //   select (X & 1) == 1 ? (R >> 1) ^ Q : (R >> 1)
  //   select (X & 1) != 1 ? (R >> 1) : (R >> 1) ^ Q

  Value *CondV = SelI->getCondition();
  Value *TrueV = SelI->getTrueValue();
  Value *FalseV = SelI->getFalseValue();

  using namespace PatternMatch;

  Value *C = nullptr;
  CmpInst::Predicate P;
  bool TrueIfZero;

  // First classify the compare: against 0 or against 1, with either
  // predicate and operand order.
  if (match(CondV, m_ICmp(P, m_Value(C), m_Zero())) ||
      match(CondV, m_ICmp(P, m_Zero(), m_Value(C)))) {
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
      return false;
    // Matched: select C == 0 ? ... : ...
    //          select C != 0 ? ... : ...
    TrueIfZero = (P == CmpInst::ICMP_EQ);
  } else if (match(CondV, m_ICmp(P, m_Value(C), m_One())) ||
             match(CondV, m_ICmp(P, m_One(), m_Value(C)))) {
    if (P != CmpInst::ICMP_EQ && P != CmpInst::ICMP_NE)
      return false;
    // Matched: select C == 1 ? ... : ...
    //          select C != 1 ? ... : ...
    TrueIfZero = (P == CmpInst::ICMP_NE);
  } else
    return false;

  Value *X = nullptr;
  if (!match(C, m_And(m_Value(X), m_One())) &&
      !match(C, m_And(m_One(), m_Value(X))))
    return false;
  // Matched: select (X & 1) == +++ ? ... : ...
  //          select (X & 1) != +++ ? ... : ...

  Value *R = nullptr, *Q = nullptr;
  if (TrueIfZero) {
    // The select's condition is true if the tested bit is 0.
    // TrueV must be the shift, FalseV must be the xor.
    if (!match(TrueV, m_LShr(m_Value(R), m_One())))
      return false;
    // Matched: select +++ ? (R >> 1) : ...
    if (!match(FalseV, m_Xor(m_Specific(TrueV), m_Value(Q))) &&
        !match(FalseV, m_Xor(m_Value(Q), m_Specific(TrueV))))
      return false;
    // Matched: select +++ ? (R >> 1) : (R >> 1) ^ Q
    // with commuting ^.
  } else {
    // The select's condition is true if the tested bit is 1.
    // TrueV must be the xor, FalseV must be the shift.
    if (!match(FalseV, m_LShr(m_Value(R), m_One())))
      return false;
    // Matched: select +++ ? ... : (R >> 1)
    if (!match(TrueV, m_Xor(m_Specific(FalseV), m_Value(Q))) &&
        !match(TrueV, m_Xor(m_Value(Q), m_Specific(FalseV))))
      return false;
    // Matched: select +++ ? (R >> 1) ^ Q : (R >> 1)
    // with commuting ^.
  }

  PV.X = X;
  PV.Q = Q;
  PV.R = R;
  PV.Left = false;
  return true;
}
422
423
// Check whether SelI is the root select of a pmpy pattern in loop block
// LoopB with preheader PrehB and counting IV CIV. With PreScan set, only
// the cheap structural match is performed; otherwise the full pattern —
// including the select feeding back into the recurrence phi — is verified
// and PV is filled in (Res, P, M, Inv). Returns true on a match.
bool PolynomialMultiplyRecognize::scanSelect(SelectInst *SelI,
      BasicBlock *LoopB, BasicBlock *PrehB, Value *CIV, ParsedValues &PV,
      bool PreScan) {
  using namespace PatternMatch;

  // The basic pattern for R = P.Q is:
  //   for i = 0..31
  //     R = phi (0, R')
  //     if (P & (1 << i))        ; test-bit(P, i)
  //       R' = R ^ (Q << i)
  //
  // Similarly, the basic pattern for R = (P/Q).Q - P
  //   for i = 0..31
  //     R = phi(P, R')
  //     if (R & (1 << i))
  //       R' = R ^ (Q << i)

  // There exist idioms, where instead of Q being shifted left, P is shifted
  // right. This produces a result that is shifted right by 32 bits (the
  // non-shifted result is 64-bit).
  //
  // For R = P.Q, this would be:
  //   for i = 0..31
  //     R = phi (0, R')
  //     if ((P >> i) & 1)
  //       R' = (R >> 1) ^ Q      ; R is cycled through the loop, so it must
  //     else                     ; be shifted by 1, not i.
  //       R' = R >> 1
  //
  // And for the inverse:
  //   for i = 0..31
  //     R = phi (P, R')
  //     if (R & 1)
  //       R' = (R >> 1) ^ Q
  //     else
  //       R' = R >> 1

  // The left-shifting idioms share the same pattern:
  //   select (X & (1 << i)) ? R ^ (Q << i) : R
  // Similarly for right-shifting idioms:
  //   select (X & 1) ? (R >> 1) ^ Q

  if (matchLeftShift(SelI, CIV, PV)) {
    // If this is a pre-scan, getting this far is sufficient.
    if (PreScan)
      return true;

    // Need to make sure that the SelI goes back into R.
    auto *RPhi = dyn_cast<PHINode>(PV.R);
    if (!RPhi)
      return false;
    if (SelI != RPhi->getIncomingValueForBlock(LoopB))
      return false;
    PV.Res = SelI;

    // If X is loop invariant, it must be the input polynomial, and the
    // idiom is the basic polynomial multiply.
    if (CurLoop->isLoopInvariant(PV.X)) {
      PV.P = PV.X;
      PV.Inv = false;
    } else {
      // X is not loop invariant. If X == R, this is the inverse pmpy.
      // Otherwise, check for an xor with an invariant value. If the
      // variable argument to the xor is R, then this is still a valid
      // inverse pmpy.
      PV.Inv = true;
      if (PV.X != PV.R) {
        Value *Var = nullptr, *Inv = nullptr, *X1 = nullptr, *X2 = nullptr;
        if (!match(PV.X, m_Xor(m_Value(X1), m_Value(X2))))
          return false;
        // Whichever xor operand is defined outside the loop block is the
        // invariant; the other must be the recurrence R.
        auto *I1 = dyn_cast<Instruction>(X1);
        auto *I2 = dyn_cast<Instruction>(X2);
        if (!I1 || I1->getParent() != LoopB) {
          Var = X2;
          Inv = X1;
        } else if (!I2 || I2->getParent() != LoopB) {
          Var = X1;
          Inv = X2;
        } else
          return false;
        if (Var != PV.R)
          return false;
        PV.M = Inv;
      }
      // The input polynomial P still needs to be determined. It will be
      // the entry value of R.
      Value *EntryP = RPhi->getIncomingValueForBlock(PrehB);
      PV.P = EntryP;
    }

    return true;
  }

  if (matchRightShift(SelI, PV)) {
    // If this is an inverse pattern, the Q polynomial must be known at
    // compile time.
    if (PV.Inv && !isa<ConstantInt>(PV.Q))
      return false;
    if (PreScan)
      return true;
    // There is no exact matching of right-shift pmpy.
    return false;
  }

  return false;
}
530
531
// Depth-first search for a def-use cycle: a chain of users starting at Out
// and reaching In, confined to Out's basic block. On success the visited
// instructions on the path are accumulated in Cycle (In itself is not
// inserted here). At most one phi node may appear on the path, so the
// cycle cannot span more than one loop iteration.
bool PolynomialMultiplyRecognize::findCycle(Value *Out, Value *In,
      ValueSeq &Cycle) {
  // Out = ..., In, ...
  if (Out == In)
    return true;

  auto *BB = cast<Instruction>(Out)->getParent();
  bool HadPhi = false;

  for (auto U : Out->users()) {
    auto *I = dyn_cast<Instruction>(&*U);
    // Only follow users inside Out's block.
    if (I == nullptr || I->getParent() != BB)
      continue;
    // Make sure that there are no multi-iteration cycles, e.g.
    //   p1 = phi(p2)
    //   p2 = phi(p1)
    // The cycle p1->p2->p1 would span two loop iterations.
    // Check that there is only one phi in the cycle.
    bool IsPhi = isa<PHINode>(I);
    if (IsPhi && HadPhi)
      return false;
    HadPhi |= IsPhi;
    if (Cycle.count(I))
      return false;
    Cycle.insert(I);
    // If the path through I reaches In, keep I in Cycle and stop
    // exploring; otherwise backtrack.
    if (findCycle(I, In, Cycle))
      break;
    Cycle.remove(I);
  }
  // A successful search left its path in Cycle; an exhausted search left
  // it empty.
  return !Cycle.empty();
}
563
564
// Partition the values of Cycle into the Early and Late sets, relative to
// the divider instruction DivI (the lshr-by-1 that anchors the cycle).
void PolynomialMultiplyRecognize::classifyCycle(Instruction *DivI,
      ValueSeq &Cycle, ValueSeq &Early, ValueSeq &Late) {
  // All the values in the cycle that are between the phi node and the
  // divider instruction will be classified as "early", all other values
  // will be "late".

  bool IsE = true;
  unsigned I, N = Cycle.size();
  // Scan for the first occurrence of either DivI or a phi node. DivI is
  // known to be in Cycle (the caller inserts it), so the loop always
  // breaks with I < N.
  for (I = 0; I < N; ++I) {
    Value *V = Cycle[I];
    if (DivI == V)
      IsE = false;
    else if (!isa<PHINode>(V))
      continue;
    // Stop if found either.
    break;
  }
  // "I" is the index of either DivI or the phi node, whichever was first.
  // "E" is "false" or "true" respectively.
  ValueSeq &First = !IsE ? Early : Late;
  for (unsigned J = 0; J < I; ++J)
    First.insert(Cycle[J]);

  // Values from the found marker up to the next marker (DivI or phi) go
  // into the opposite set.
  ValueSeq &Second = IsE ? Early : Late;
  Second.insert(Cycle[I]);
  for (++I; I < N; ++I) {
    Value *V = Cycle[I];
    if (DivI == V || isa<PHINode>(V))
      break;
    Second.insert(V);
  }

  // Whatever remains after the second marker belongs with the first set.
  for (; I < N; ++I)
    First.insert(Cycle[I]);
}
600
601
602bool PolynomialMultiplyRecognize::classifyInst(Instruction *UseI,
603 ValueSeq &Early, ValueSeq &Late) {
604 // Select is an exception, since the condition value does not have to be
605 // classified in the same way as the true/false values. The true/false
606 // values do have to be both early or both late.
607 if (UseI->getOpcode() == Instruction::Select) {
608 Value *TV = UseI->getOperand(1), *FV = UseI->getOperand(2);
609 if (Early.count(TV) || Early.count(FV)) {
610 if (Late.count(TV) || Late.count(FV))
611 return false;
612 Early.insert(UseI);
613 } else if (Late.count(TV) || Late.count(FV)) {
614 if (Early.count(TV) || Early.count(FV))
615 return false;
616 Late.insert(UseI);
617 }
618 return true;
619 }
620
621 // Not sure what would be the example of this, but the code below relies
622 // on having at least one operand.
623 if (UseI->getNumOperands() == 0)
624 return true;
625
626 bool AE = true, AL = true;
627 for (auto &I : UseI->operands()) {
628 if (Early.count(&*I))
629 AL = false;
630 else if (Late.count(&*I))
631 AE = false;
632 }
633 // If the operands appear "all early" and "all late" at the same time,
634 // then it means that none of them are actually classified as either.
635 // This is harmless.
636 if (AE && AL)
637 return true;
638 // Conversely, if they are neither "all early" nor "all late", then
639 // we have a mixture of early and late operands that is not a known
640 // exception.
641 if (!AE && !AL)
642 return false;
643
644 // Check that we have covered the two special cases.
645 assert(AE != AL);
646
647 if (AE)
648 Early.insert(UseI);
649 else
650 Late.insert(UseI);
651 return true;
652}
653
654
655bool PolynomialMultiplyRecognize::commutesWithShift(Instruction *I) {
656 switch (I->getOpcode()) {
657 case Instruction::And:
658 case Instruction::Or:
659 case Instruction::Xor:
660 case Instruction::LShr:
661 case Instruction::Shl:
662 case Instruction::Select:
663 case Instruction::ICmp:
664 case Instruction::PHI:
665 break;
666 default:
667 return false;
668 }
669 return true;
670}
671
672
673bool PolynomialMultiplyRecognize::highBitsAreZero(Value *V,
674 unsigned IterCount) {
675 auto *T = dyn_cast<IntegerType>(V->getType());
676 if (!T)
677 return false;
678
679 unsigned BW = T->getBitWidth();
680 APInt K0(BW, 0), K1(BW, 0);
681 computeKnownBits(V, K0, K1, DL);
682 return K0.countLeadingOnes() >= IterCount;
683}
684
685
686bool PolynomialMultiplyRecognize::keepsHighBitsZero(Value *V,
687 unsigned IterCount) {
688 // Assume that all inputs to the value have the high bits zero.
689 // Check if the value itself preserves the zeros in the high bits.
690 if (auto *C = dyn_cast<ConstantInt>(V))
691 return C->getValue().countLeadingZeros() >= IterCount;
692
693 if (auto *I = dyn_cast<Instruction>(V)) {
694 switch (I->getOpcode()) {
695 case Instruction::And:
696 case Instruction::Or:
697 case Instruction::Xor:
698 case Instruction::LShr:
699 case Instruction::Select:
700 case Instruction::ICmp:
701 case Instruction::PHI:
702 return true;
703 }
704 }
705
706 return false;
707}
708
709
710bool PolynomialMultiplyRecognize::isOperandShifted(Instruction *I, Value *Op) {
711 unsigned Opc = I->getOpcode();
712 if (Opc == Instruction::Shl || Opc == Instruction::LShr)
713 return Op != I->getOperand(1);
714 return true;
715}
716
717
// Rewrite a right-shifting pmpy loop body (LoopB, exiting to ExitB) into
// the equivalent left-shifting form so that the left-shift matcher can
// recognize it. Returns true if the rewrite was performed. The approach:
// find the lshr-by-1 recurrences, classify all affected values as "early"
// or "late" within an iteration, verify the high bits stay zero, then
// replace each right shift by left-shifting its non-cycled operands by the
// counting IV, and finally compensate the loop-exit values.
bool PolynomialMultiplyRecognize::convertShiftsToLeft(BasicBlock *LoopB,
      BasicBlock *ExitB, unsigned IterCount) {
  Value *CIV = getCountIV(LoopB);
  if (CIV == nullptr)
    return false;
  auto *CIVTy = dyn_cast<IntegerType>(CIV->getType());
  if (CIVTy == nullptr)
    return false;

  ValueSeq RShifts;
  ValueSeq Early, Late, Cycled;

  // Find all value cycles that contain logical right shifts by 1.
  for (Instruction &I : *LoopB) {
    using namespace PatternMatch;
    Value *V = nullptr;
    if (!match(&I, m_LShr(m_Value(V), m_One())))
      continue;
    ValueSeq C;
    if (!findCycle(&I, V, C))
      continue;

    // Found a cycle.
    C.insert(&I);
    classifyCycle(&I, C, Early, Late);
    Cycled.insert(C.begin(), C.end());
    RShifts.insert(&I);
  }

  // Find the set of all values affected by the shift cycles, i.e. all
  // cycled values, and (recursively) all their users.
  // Note: Users grows while iterating; the index loop intentionally
  // re-reads size() so newly added users are processed too.
  ValueSeq Users(Cycled.begin(), Cycled.end());
  for (unsigned i = 0; i < Users.size(); ++i) {
    Value *V = Users[i];
    if (!isa<IntegerType>(V->getType()))
      return false;
    auto *R = cast<Instruction>(V);
    // If the instruction does not commute with shifts, the loop cannot
    // be unshifted.
    if (!commutesWithShift(R))
      return false;
    for (auto I = R->user_begin(), E = R->user_end(); I != E; ++I) {
      auto *T = cast<Instruction>(*I);
      // Skip users from outside of the loop. They will be handled later.
      // Also, skip the right-shifts and phi nodes, since they mix early
      // and late values.
      if (T->getParent() != LoopB || RShifts.count(T) || isa<PHINode>(T))
        continue;

      Users.insert(T);
      if (!classifyInst(T, Early, Late))
        return false;
    }
  }

  if (Users.size() == 0)
    return false;

  // Verify that high bits remain zero.
  ValueSeq Internal(Users.begin(), Users.end());
  ValueSeq Inputs;
  for (unsigned i = 0; i < Internal.size(); ++i) {
    auto *R = dyn_cast<Instruction>(Internal[i]);
    if (!R)
      continue;
    for (Value *Op : R->operands()) {
      auto *T = dyn_cast<Instruction>(Op);
      if (T && T->getParent() != LoopB)
        Inputs.insert(Op);
      else
        Internal.insert(Op);
    }
  }
  // Inputs (defined outside the loop) must provably have zero high bits;
  // internal values must preserve that property.
  for (Value *V : Inputs)
    if (!highBitsAreZero(V, IterCount))
      return false;
  for (Value *V : Internal)
    if (!keepsHighBitsZero(V, IterCount))
      return false;

  // Finally, the work can be done. Unshift each user.
  IRBuilder<> IRB(LoopB);
  std::map<Value*,Value*> ShiftMap;
  typedef std::map<std::pair<Value*,Type*>,Value*> CastMapType;
  CastMapType CastMap;

  // Memoized zero-extension helper: the shifted value and the shift amount
  // must have matching integer types.
  auto upcast = [] (CastMapType &CM, IRBuilder<> &IRB, Value *V,
        IntegerType *Ty) -> Value* {
    auto H = CM.find(std::make_pair(V, Ty));
    if (H != CM.end())
      return H->second;
    Value *CV = IRB.CreateIntCast(V, Ty, false);
    CM.insert(std::make_pair(std::make_pair(V, Ty), CV));
    return CV;
  };

  for (auto I = LoopB->begin(), E = LoopB->end(); I != E; ++I) {
    if (isa<PHINode>(I) || !Users.count(&*I))
      continue;
    using namespace PatternMatch;
    // Match lshr x, 1.
    Value *V = nullptr;
    if (match(&*I, m_LShr(m_Value(V), m_One()))) {
      // The right shift itself simply disappears: its in-loop users take
      // the unshifted value directly.
      replaceAllUsesOfWithIn(&*I, V, LoopB);
      continue;
    }
    // For each non-cycled operand, replace it with the corresponding
    // value shifted left.
    for (auto &J : I->operands()) {
      Value *Op = J.get();
      if (!isOperandShifted(&*I, Op))
        continue;
      if (Users.count(Op))
        continue;
      // Skip shifting zeros.
      if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero())
        continue;
      // Check if we have already generated a shift for this value.
      auto F = ShiftMap.find(Op);
      Value *W = (F != ShiftMap.end()) ? F->second : nullptr;
      if (W == nullptr) {
        IRB.SetInsertPoint(&*I);
        // First, the shift amount will be CIV or CIV+1, depending on
        // whether the value is early or late. Instead of creating CIV+1,
        // do a single shift of the value.
        Value *ShAmt = CIV, *ShVal = Op;
        auto *VTy = cast<IntegerType>(ShVal->getType());
        auto *ATy = cast<IntegerType>(ShAmt->getType());
        if (Late.count(&*I))
          ShVal = IRB.CreateShl(Op, ConstantInt::get(VTy, 1));
        // Second, the types of the shifted value and the shift amount
        // must match.
        if (VTy != ATy) {
          if (VTy->getBitWidth() < ATy->getBitWidth())
            ShVal = upcast(CastMap, IRB, ShVal, ATy);
          else
            ShAmt = upcast(CastMap, IRB, ShAmt, VTy);
        }
        // Ready to generate the shift and memoize it.
        W = IRB.CreateShl(ShVal, ShAmt);
        ShiftMap.insert(std::make_pair(Op, W));
      }
      I->replaceUsesOfWith(Op, W);
    }
  }

  // Update the users outside of the loop to account for having left
  // shifts. They would normally be shifted right in the loop, so shift
  // them right after the loop exit.
  // Take advantage of the loop-closed SSA form, which has all the post-
  // loop values in phi nodes.
  IRB.SetInsertPoint(ExitB, ExitB->getFirstInsertionPt());
  for (auto P = ExitB->begin(), Q = ExitB->end(); P != Q; ++P) {
    if (!isa<PHINode>(P))
      break;
    auto *PN = cast<PHINode>(P);
    Value *U = PN->getIncomingValueForBlock(LoopB);
    if (!Users.count(U))
      continue;
    Value *S = IRB.CreateLShr(PN, ConstantInt::get(PN->getType(), IterCount));
    PN->replaceAllUsesWith(S);
    // The above RAUW will create
    //   S = lshr S, IterCount
    // so we need to fix it back into
    //   S = lshr PN, IterCount
    cast<User>(S)->replaceUsesOfWith(S, PN);
  }

  return true;
}
888
889
// Fold any instruction in LoopB that InstructionSimplify can reduce, then
// delete everything in the block that became trivially dead.
void PolynomialMultiplyRecognize::cleanupLoopBody(BasicBlock *LoopB) {
  for (auto &I : *LoopB)
    if (Value *SV = SimplifyInstruction(&I, DL, &TLI, &DT))
      I.replaceAllUsesWith(SV);

  // Capture the next iterator before deleting: the deletion can erase I
  // itself, which would invalidate a post-increment of the loop iterator.
  for (auto I = LoopB->begin(), N = I; I != LoopB->end(); I = N) {
    N = std::next(I);
    RecursivelyDeleteTriviallyDeadInstructions(&*I, &TLI);
  }
}
900
901
902unsigned PolynomialMultiplyRecognize::getInverseMxN(unsigned QP) {
903 // Arrays of coefficients of Q and the inverse, C.
904 // Q[i] = coefficient at x^i.
905 std::array<char,32> Q, C;
906
907 for (unsigned i = 0; i < 32; ++i) {
908 Q[i] = QP & 1;
909 QP >>= 1;
910 }
911 assert(Q[0] == 1);
912
913 // Find C, such that
914 // (Q[n]*x^n + ... + Q[1]*x + Q[0]) * (C[n]*x^n + ... + C[1]*x + C[0]) = 1
915 //
916 // For it to have a solution, Q[0] must be 1. Since this is Z2[x], the
917 // operations * and + are & and ^ respectively.
918 //
919 // Find C[i] recursively, by comparing i-th coefficient in the product
920 // with 0 (or 1 for i=0).
921 //
922 // C[0] = 1, since C[0] = Q[0], and Q[0] = 1.
923 C[0] = 1;
924 for (unsigned i = 1; i < 32; ++i) {
925 // Solve for C[i] in:
926 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i]Q[0] = 0
927 // This is equivalent to
928 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] ^ C[i] = 0
929 // which is
930 // C[0]Q[i] ^ C[1]Q[i-1] ^ ... ^ C[i-1]Q[1] = C[i]
931 unsigned T = 0;
932 for (unsigned j = 0; j < i; ++j)
933 T = T ^ (C[j] & Q[i-j]);
934 C[i] = T;
935 }
936
937 unsigned QV = 0;
938 for (unsigned i = 0; i < 32; ++i)
939 if (C[i])
940 QV |= (1 << i);
941
942 return QV;
943}
944
945
946Value *PolynomialMultiplyRecognize::generate(BasicBlock::iterator At,
947 ParsedValues &PV) {
948 IRBuilder<> B(&*At);
949 Module *M = At->getParent()->getParent()->getParent();
950 Value *PMF = Intrinsic::getDeclaration(M, Intrinsic::hexagon_M4_pmpyw);
951
952 Value *P = PV.P, *Q = PV.Q, *P0 = P;
953 unsigned IC = PV.IterCount;
954
955 if (PV.M != nullptr)
956 P0 = P = B.CreateXor(P, PV.M);
957
958 // Create a bit mask to clear the high bits beyond IterCount.
959 auto *BMI = ConstantInt::get(P->getType(), APInt::getLowBitsSet(32, IC));
960
961 if (PV.IterCount != 32)
962 P = B.CreateAnd(P, BMI);
963
964 if (PV.Inv) {
965 auto *QI = dyn_cast<ConstantInt>(PV.Q);
966 assert(QI && QI->getBitWidth() <= 32);
967
968 // Again, clearing bits beyond IterCount.
969 unsigned M = (1 << PV.IterCount) - 1;
970 unsigned Tmp = (QI->getZExtValue() | 1) & M;
971 unsigned QV = getInverseMxN(Tmp) & M;
972 auto *QVI = ConstantInt::get(QI->getType(), QV);
973 P = B.CreateCall(PMF, {P, QVI});
974 P = B.CreateTrunc(P, QI->getType());
975 if (IC != 32)
976 P = B.CreateAnd(P, BMI);
977 }
978
979 Value *R = B.CreateCall(PMF, {P, Q});
980
981 if (PV.M != nullptr)
982 R = B.CreateXor(R, B.CreateIntCast(P0, R->getType(), false));
983
984 return R;
985}
986
987
// Top-level driver of the polynomial-multiply recognizer. Returns true if
// the loop was recognized as a pmpy idiom and its result was replaced by
// code generated in the preheader (see generate()).
bool PolynomialMultiplyRecognize::recognize() {
  // Restrictions:
  // - The loop must consist of a single block.
  // - The iteration count must be known at compile-time.
  // - The loop must have an induction variable starting from 0, and
  //   incremented in each iteration of the loop.
  BasicBlock *LoopB = CurLoop->getHeader();
  // Single-block loop: header and latch must coincide.
  if (LoopB != CurLoop->getLoopLatch())
    return false;
  BasicBlock *ExitB = CurLoop->getExitBlock();
  if (ExitB == nullptr)
    return false;
  BasicBlock *EntryB = CurLoop->getLoopPreheader();
  if (EntryB == nullptr)
    return false;

  // Iteration count = backedge-taken count + 1.
  // NOTE(review): if the count is computable but not a SCEVConstant,
  // IterCount remains 0 here; presumably scanSelect rejects that case
  // when matching against PV.IterCount -- confirm.
  unsigned IterCount = 0;
  const SCEV *CT = SE.getBackedgeTakenCount(CurLoop);
  if (isa<SCEVCouldNotCompute>(CT))
    return false;
  if (auto *CV = dyn_cast<SCEVConstant>(CT))
    IterCount = CV->getValue()->getZExtValue() + 1;

  Value *CIV = getCountIV(LoopB);
  ParsedValues PV;
  PV.IterCount = IterCount;

  // Test function to see if a given select instruction is a part of the
  // pmpy pattern. The argument PreScan set to "true" indicates that only
  // a preliminary scan is needed, "false" indicates an exact match.
  // Note: scanSelect fills in PV as a side effect.
  auto CouldBePmpy = [this, LoopB, EntryB, CIV, &PV] (bool PreScan)
        -> std::function<bool (Instruction &I)> {
    return [this, LoopB, EntryB, CIV, &PV, PreScan] (Instruction &I) -> bool {
      if (auto *SelI = dyn_cast<SelectInst>(&I))
        return scanSelect(SelI, LoopB, EntryB, CIV, PV, PreScan);
      return false;
    };
  };
  // Preliminary scan: find any select that could start the pattern.
  auto PreF = std::find_if(LoopB->begin(), LoopB->end(), CouldBePmpy(true));
  if (PreF == LoopB->end())
    return false;

  // The pattern was found in right-shift form: rewrite the loop to the
  // canonical left-shift form before the exact scan.
  if (!PV.Left) {
    convertShiftsToLeft(LoopB, ExitB, IterCount);
    cleanupLoopBody(LoopB);
  }

  // Exact scan: this time scanSelect must match the full pattern.
  auto PostF = std::find_if(LoopB->begin(), LoopB->end(), CouldBePmpy(false));
  if (PostF == LoopB->end())
    return false;

  DEBUG({
    StringRef PP = (PV.M ? "(P+M)" : "P");
    if (!PV.Inv)
      dbgs() << "Found pmpy idiom: R = " << PP << ".Q\n";
    else
      dbgs() << "Found inverse pmpy idiom: R = (" << PP << "/Q).Q) + "
             << PP << "\n";
    dbgs() << "  Res:" << *PV.Res << "\n  P:" << *PV.P << "\n";
    if (PV.M)
      dbgs() << "  M:" << *PV.M << "\n";
    dbgs() << "  Q:" << *PV.Q << "\n";
    dbgs() << "  Iteration count:" << PV.IterCount << "\n";
  });

  // Generate the intrinsic-based replacement in the preheader and redirect
  // all uses of the loop's result to it. The now-dead loop body is left
  // for later cleanup passes.
  BasicBlock::iterator At(EntryB->getTerminator());
  Value *PM = generate(At, PV);
  if (PM == nullptr)
    return false;

  if (PM->getType() != PV.Res->getType())
    PM = IRBuilder<>(&*At).CreateIntCast(PM, PV.Res->getType(), false);

  PV.Res->replaceAllUsesWith(PM);
  PV.Res->eraseFromParent();
  return true;
}
1065
1066
1067unsigned HexagonLoopIdiomRecognize::getStoreSizeInBytes(StoreInst *SI) {
1068 uint64_t SizeInBits = DL->getTypeSizeInBits(SI->getValueOperand()->getType());
1069 assert(((SizeInBits & 7) || (SizeInBits >> 32) == 0) &&
1070 "Don't overflow unsigned.");
1071 return (unsigned)SizeInBits >> 3;
1072}
1073
1074
1075int HexagonLoopIdiomRecognize::getSCEVStride(const SCEVAddRecExpr *S) {
1076 if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(S->getOperand(1)))
1077 return SC->getAPInt().getSExtValue();
1078 return 0;
1079}
1080
1081
1082bool HexagonLoopIdiomRecognize::isLegalStore(Loop *CurLoop, StoreInst *SI) {
1083 bool IsVolatile = false;
1084 if (SI->isVolatile() && HexagonVolatileMemcpy)
1085 IsVolatile = true;
1086 else if (!SI->isSimple())
1087 return false;
1088
1089 Value *StoredVal = SI->getValueOperand();
1090 Value *StorePtr = SI->getPointerOperand();
1091
1092 // Reject stores that are so large that they overflow an unsigned.
1093 uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType());
1094 if ((SizeInBits & 7) || (SizeInBits >> 32) != 0)
1095 return false;
1096
1097 // See if the pointer expression is an AddRec like {base,+,1} on the current
1098 // loop, which indicates a strided store. If we have something else, it's a
1099 // random store we can't handle.
1100 auto *StoreEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
1101 if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine())
1102 return false;
1103
1104 // Check to see if the stride matches the size of the store. If so, then we
1105 // know that every byte is touched in the loop.
1106 int Stride = getSCEVStride(StoreEv);
1107 if (Stride == 0)
1108 return false;
1109 unsigned StoreSize = getStoreSizeInBytes(SI);
1110 if (StoreSize != unsigned(std::abs(Stride)))
1111 return false;
1112
1113 // The store must be feeding a non-volatile load.
1114 LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
1115 if (!LI || !LI->isSimple())
1116 return false;
1117
1118 // See if the pointer expression is an AddRec like {base,+,1} on the current
1119 // loop, which indicates a strided load. If we have something else, it's a
1120 // random load we can't handle.
1121 Value *LoadPtr = LI->getPointerOperand();
1122 auto *LoadEv = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LoadPtr));
1123 if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine())
1124 return false;
1125
1126 // The store and load must share the same stride.
1127 if (StoreEv->getOperand(1) != LoadEv->getOperand(1))
1128 return false;
1129
1130 // Success. This store can be converted into a memcpy.
1131 return true;
1132}
1133
1134
1135/// mayLoopAccessLocation - Return true if the specified loop might access the
1136/// specified pointer location, which is a loop-strided access. The 'Access'
1137/// argument specifies what the verboten forms of access are (read or write).
1138static bool
1139mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L,
1140 const SCEV *BECount, unsigned StoreSize,
1141 AliasAnalysis &AA,
1142 SmallPtrSetImpl<Instruction *> &Ignored) {
1143 // Get the location that may be stored across the loop. Since the access
1144 // is strided positively through memory, we say that the modified location
1145 // starts at the pointer and has infinite size.
1146 uint64_t AccessSize = MemoryLocation::UnknownSize;
1147
1148 // If the loop iterates a fixed number of times, we can refine the access
1149 // size to be exactly the size of the memset, which is (BECount+1)*StoreSize
1150 if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
1151 AccessSize = (BECst->getValue()->getZExtValue() + 1) * StoreSize;
1152
1153 // TODO: For this to be really effective, we have to dive into the pointer
1154 // operand in the store. Store to &A[i] of 100 will always return may alias
1155 // with store of &A[100], we need to StoreLoc to be "A" with size of 100,
1156 // which will then no-alias a store to &A[100].
1157 MemoryLocation StoreLoc(Ptr, AccessSize);
1158
1159 for (auto *B : L->blocks())
1160 for (auto &I : *B)
1161 if (Ignored.count(&I) == 0 && (AA.getModRefInfo(&I, StoreLoc) & Access))
1162 return true;
1163
1164 return false;
1165}
1166
1167
1168void HexagonLoopIdiomRecognize::collectStores(Loop *CurLoop, BasicBlock *BB,
1169 SmallVectorImpl<StoreInst*> &Stores) {
1170 Stores.clear();
1171 for (Instruction &I : *BB)
1172 if (StoreInst *SI = dyn_cast<StoreInst>(&I))
1173 if (isLegalStore(CurLoop, SI))
1174 Stores.push_back(SI);
1175}
1176
1177
// Try to replace the copying loop around SI (a strided store of a strided
// load, as vetted by isLegalStore) with a call to memcpy, memmove, or the
// Hexagon-specific volatile memcpy. For the memmove/volatile cases a
// runtime guard is built around the call; the loop itself is left in place
// (behind the guard) or its store is deleted (memcpy case). Returns true
// if a call was emitted.
bool HexagonLoopIdiomRecognize::processCopyingStore(Loop *CurLoop,
      StoreInst *SI, const SCEV *BECount) {
  // NOTE(review): '&&' binds tighter than '||', so the message string is
  // folded into the right operand. The condition still evaluates as
  // intended (the literal is non-null), but the assert deserves explicit
  // parentheses around the whole condition.
  assert(SI->isSimple() || (SI->isVolatile() && HexagonVolatileMemcpy) &&
         "Expected only non-volatile stores, or Hexagon-specific memcpy"
         "to volatile destination.");

  Value *StorePtr = SI->getPointerOperand();
  auto *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr));
  unsigned Stride = getSCEVStride(StoreEv);
  unsigned StoreSize = getStoreSizeInBytes(SI);
  if (Stride != StoreSize)
    return false;

  // See if the pointer expression is an AddRec like {base,+,1} on the current
  // loop, which indicates a strided load. If we have something else, it's a
  // random load we can't handle.
  LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand());
  auto *LoadEv = cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand()));

  // The trip count of the loop and the base pointer of the addrec SCEV is
  // guaranteed to be loop invariant, which means that it should dominate the
  // header. This allows us to insert code for it in the preheader.
  BasicBlock *Preheader = CurLoop->getLoopPreheader();
  Instruction *ExpPt = Preheader->getTerminator();
  IRBuilder<> Builder(ExpPt);
  SCEVExpander Expander(*SE, *DL, "hexagon-loop-idiom");

  Type *IntPtrTy = Builder.getIntPtrTy(*DL, SI->getPointerAddressSpace());

  // Okay, we have a strided store "p[i]" of a loaded value. We can turn
  // this into a memcpy/memmove in the loop preheader now if we want. However,
  // this would be unsafe to do if there is anything else in the loop that may
  // read or write the memory region we're storing to. For memcpy, this
  // includes the load that feeds the stores. Check for an alias by generating
  // the base address and checking everything.
  Value *StoreBasePtr = Expander.expandCodeFor(StoreEv->getStart(),
      Builder.getInt8PtrTy(SI->getPointerAddressSpace()), ExpPt);
  Value *LoadBasePtr = nullptr;

  bool Overlap = false;
  bool DestVolatile = SI->isVolatile();
  Type *BECountTy = BECount->getType();

  if (DestVolatile) {
    // The trip count must fit in i32, since it is the type of the "num_words"
    // argument to hexagon_memcpy_forward_vp4cp4n2.
    if (StoreSize != 4 || DL->getTypeSizeInBits(BECountTy) > 32) {
// Shared failure path: every later bail-out jumps here via goto. It
// deletes any base-pointer expansion code emitted into the preheader so
// the function leaves no trace when it fails.
CleanupAndExit:
      // If we generated new code for the base pointer, clean up.
      Expander.clear();
      if (StoreBasePtr && (LoadBasePtr != StoreBasePtr)) {
        RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI);
        StoreBasePtr = nullptr;
      }
      if (LoadBasePtr) {
        RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI);
        LoadBasePtr = nullptr;
      }
      return false;
    }
  }

  // Does anything in the loop (other than the store itself) read or write
  // the destination region? If only the load conflicts, the copy regions
  // may overlap and a memmove (with runtime checks) is needed instead of
  // a memcpy.
  SmallPtrSet<Instruction*, 2> Ignore1;
  Ignore1.insert(SI);
  if (mayLoopAccessLocation(StoreBasePtr, MRI_ModRef, CurLoop, BECount,
                            StoreSize, *AA, Ignore1)) {
    // Check if the load is the offending instruction.
    Ignore1.insert(LI);
    if (mayLoopAccessLocation(StoreBasePtr, MRI_ModRef, CurLoop, BECount,
                              StoreSize, *AA, Ignore1)) {
      // Still bad. Nothing we can do.
      goto CleanupAndExit;
    }
    // It worked with the load ignored.
    Overlap = true;
  }

  if (!Overlap) {
    if (DisableMemcpyIdiom || !HasMemcpy)
      goto CleanupAndExit;
  } else {
    // Don't generate memmove if this function will be inlined. This is
    // because the caller will undergo this transformation after inlining.
    Function *Func = CurLoop->getHeader()->getParent();
    if (Func->hasFnAttribute(Attribute::AlwaysInline))
      goto CleanupAndExit;

    // In case of a memmove, the call to memmove will be executed instead
    // of the loop, so we need to make sure that there is nothing else in
    // the loop than the load, store and instructions that these two depend
    // on.
    SmallVector<Instruction*,2> Insts;
    Insts.push_back(SI);
    Insts.push_back(LI);
    if (!coverLoop(CurLoop, Insts))
      goto CleanupAndExit;

    if (DisableMemmoveIdiom || !HasMemmove)
      goto CleanupAndExit;
    bool IsNested = CurLoop->getParentLoop() != 0;
    if (IsNested && OnlyNonNestedMemmove)
      goto CleanupAndExit;
  }

  // For a memcpy, we have to make sure that the input array is not being
  // mutated by the loop.
  LoadBasePtr = Expander.expandCodeFor(LoadEv->getStart(),
      Builder.getInt8PtrTy(LI->getPointerAddressSpace()), ExpPt);

  SmallPtrSet<Instruction*, 2> Ignore2;
  Ignore2.insert(SI);
  if (mayLoopAccessLocation(LoadBasePtr, MRI_Mod, CurLoop, BECount, StoreSize,
                            *AA, Ignore2))
    goto CleanupAndExit;

  // Check the stride.
  bool StridePos = getSCEVStride(LoadEv) >= 0;

  // Currently, the volatile memcpy only emulates traversing memory forward.
  if (!StridePos && DestVolatile)
    goto CleanupAndExit;

  // Overlapping regions or a volatile destination both require guarding
  // the emitted call with runtime checks.
  bool RuntimeCheck = (Overlap || DestVolatile);

  BasicBlock *ExitB;
  if (RuntimeCheck) {
    // The runtime check needs a single exit block.
    SmallVector<BasicBlock*, 8> ExitBlocks;
    CurLoop->getUniqueExitBlocks(ExitBlocks);
    if (ExitBlocks.size() != 1)
      goto CleanupAndExit;
    ExitB = ExitBlocks[0];
  }

  // The # stored bytes is (BECount+1)*Size. Expand the trip count out to
  // pointer size if it isn't already.
  LLVMContext &Ctx = SI->getContext();
  BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy);
  unsigned Alignment = std::min(SI->getAlignment(), LI->getAlignment());
  DebugLoc DLoc = SI->getDebugLoc();

  const SCEV *NumBytesS =
      SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW);
  if (StoreSize != 1)
    NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize),
                               SCEV::FlagNUW);
  Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, ExpPt);
  if (Instruction *In = dyn_cast<Instruction>(NumBytes))
    if (Value *Simp = SimplifyInstruction(In, *DL, TLI, DT))
      NumBytes = Simp;

  CallInst *NewCall;

  if (RuntimeCheck) {
    // Skip the transformation when the byte count is a known-small
    // constant: the runtime check overhead would not pay off.
    unsigned Threshold = RuntimeMemSizeThreshold;
    if (ConstantInt *CI = dyn_cast<ConstantInt>(NumBytes)) {
      uint64_t C = CI->getZExtValue();
      if (Threshold != 0 && C < Threshold)
        goto CleanupAndExit;
      if (C < CompileTimeMemSizeThreshold)
        goto CleanupAndExit;
    }

    BasicBlock *Header = CurLoop->getHeader();
    Function *Func = Header->getParent();
    Loop *ParentL = LF->getLoopFor(Preheader);
    StringRef HeaderName = Header->getName();

    // Create a new (empty) preheader, and update the PHI nodes in the
    // header to use the new preheader.
    BasicBlock *NewPreheader = BasicBlock::Create(Ctx, HeaderName+".rtli.ph",
                                                  Func, Header);
    if (ParentL)
      ParentL->addBasicBlockToLoop(NewPreheader, *LF);
    IRBuilder<>(NewPreheader).CreateBr(Header);
    for (auto &In : *Header) {
      PHINode *PN = dyn_cast<PHINode>(&In);
      if (!PN)
        break;
      int bx = PN->getBasicBlockIndex(Preheader);
      if (bx >= 0)
        PN->setIncomingBlock(bx, NewPreheader);
    }
    // Keep the dominator tree in sync with the CFG edit.
    DT->addNewBlock(NewPreheader, Preheader);
    DT->changeImmediateDominator(Header, NewPreheader);

    // Check for safe conditions to execute memmove.
    // If stride is positive, copying things from higher to lower addresses
    // is equivalent to memmove. For negative stride, it's the other way
    // around. Copying forward in memory with positive stride may not be
    // same as memmove since we may be copying values that we just stored
    // in some previous iteration.
    Value *LA = Builder.CreatePtrToInt(LoadBasePtr, IntPtrTy);
    Value *SA = Builder.CreatePtrToInt(StoreBasePtr, IntPtrTy);
    Value *LowA = StridePos ? SA : LA;
    Value *HighA = StridePos ? LA : SA;
    Value *CmpA = Builder.CreateICmpULT(LowA, HighA);
    Value *Cond = CmpA;

    // Check for distance between pointers.
    Value *Dist = Builder.CreateSub(HighA, LowA);
    Value *CmpD = Builder.CreateICmpSLT(NumBytes, Dist);
    Value *CmpEither = Builder.CreateOr(Cond, CmpD);
    Cond = CmpEither;

    if (Threshold != 0) {
      Type *Ty = NumBytes->getType();
      Value *Thr = ConstantInt::get(Ty, Threshold);
      Value *CmpB = Builder.CreateICmpULT(Thr, NumBytes);
      Value *CmpBoth = Builder.CreateAnd(Cond, CmpB);
      Cond = CmpBoth;
    }
    // The conditional block that will hold the memmove/volatile-memcpy
    // call; the old preheader branches either here or to the (new, empty)
    // loop preheader based on Cond.
    BasicBlock *MemmoveB = BasicBlock::Create(Ctx, Header->getName()+".rtli",
                                              Func, NewPreheader);
    if (ParentL)
      ParentL->addBasicBlockToLoop(MemmoveB, *LF);
    Instruction *OldT = Preheader->getTerminator();
    Builder.CreateCondBr(Cond, MemmoveB, NewPreheader);
    OldT->eraseFromParent();
    Preheader->setName(Preheader->getName()+".old");
    DT->addNewBlock(MemmoveB, Preheader);
    // Find the new immediate dominator of the exit block.
    BasicBlock *ExitD = Preheader;
    for (auto PI = pred_begin(ExitB), PE = pred_end(ExitB); PI != PE; ++PI) {
      BasicBlock *PB = *PI;
      ExitD = DT->findNearestCommonDominator(ExitD, PB);
      if (!ExitD)
        break;
    }
    // If the prior immediate dominator of ExitB was dominated by the
    // old preheader, then the old preheader becomes the new immediate
    // dominator. Otherwise don't change anything (because the newly
    // added blocks are dominated by the old preheader).
    if (ExitD && DT->dominates(Preheader, ExitD)) {
      DomTreeNode *BN = DT->getNode(ExitB);
      DomTreeNode *DN = DT->getNode(ExitD);
      BN->setIDom(DN);
    }

    // Add a call to memmove to the conditional block.
    IRBuilder<> CondBuilder(MemmoveB);
    CondBuilder.CreateBr(ExitB);
    CondBuilder.SetInsertPoint(MemmoveB->getTerminator());

    if (DestVolatile) {
      // Emit a call to the Hexagon-specific runtime routine instead of a
      // plain memmove; it takes i32* operands and a word count.
      Type *Int32Ty = Type::getInt32Ty(Ctx);
      Type *Int32PtrTy = Type::getInt32PtrTy(Ctx);
      Type *VoidTy = Type::getVoidTy(Ctx);
      Module *M = Func->getParent();
      Constant *CF = M->getOrInsertFunction(HexagonVolatileMemcpyName, VoidTy,
                                            Int32PtrTy, Int32PtrTy, Int32Ty,
                                            nullptr);
      Function *Fn = cast<Function>(CF);
      Fn->setLinkage(Function::ExternalLinkage);

      // num_words = BECount (as i32) + 1.
      const SCEV *OneS = SE->getConstant(Int32Ty, 1);
      const SCEV *BECount32 = SE->getTruncateOrZeroExtend(BECount, Int32Ty);
      const SCEV *NumWordsS = SE->getAddExpr(BECount32, OneS, SCEV::FlagNUW);
      Value *NumWords = Expander.expandCodeFor(NumWordsS, Int32Ty,
                                               MemmoveB->getTerminator());
      if (Instruction *In = dyn_cast<Instruction>(NumWords))
        if (Value *Simp = SimplifyInstruction(In, *DL, TLI, DT))
          NumWords = Simp;

      Value *Op0 = (StoreBasePtr->getType() == Int32PtrTy)
                      ? StoreBasePtr
                      : CondBuilder.CreateBitCast(StoreBasePtr, Int32PtrTy);
      Value *Op1 = (LoadBasePtr->getType() == Int32PtrTy)
                      ? LoadBasePtr
                      : CondBuilder.CreateBitCast(LoadBasePtr, Int32PtrTy);
      NewCall = CondBuilder.CreateCall(Fn, {Op0, Op1, NumWords});
    } else {
      NewCall = CondBuilder.CreateMemMove(StoreBasePtr, LoadBasePtr,
                                          NumBytes, Alignment);
    }
  } else {
    NewCall = Builder.CreateMemCpy(StoreBasePtr, LoadBasePtr,
                                   NumBytes, Alignment);
    // Okay, the memcpy has been formed. Zap the original store and
    // anything that feeds into it.
    RecursivelyDeleteTriviallyDeadInstructions(SI, TLI);
  }

  NewCall->setDebugLoc(DLoc);

  DEBUG(dbgs() << " Formed " << (Overlap ? "memmove: " : "memcpy: ")
               << *NewCall << "\n"
               << " from load ptr=" << *LoadEv << " at: " << *LI << "\n"
               << " from store ptr=" << *StoreEv << " at: " << *SI << "\n");

  return true;
}
1470
1471
1472// \brief Check if the instructions in Insts, together with their dependencies
1473// cover the loop in the sense that the loop could be safely eliminated once
1474// the instructions in Insts are removed.
1475bool HexagonLoopIdiomRecognize::coverLoop(Loop *L,
1476 SmallVectorImpl<Instruction*> &Insts) const {
1477 SmallSet<BasicBlock*,8> LoopBlocks;
1478 for (auto *B : L->blocks())
1479 LoopBlocks.insert(B);
1480
1481 SetVector<Instruction*> Worklist(Insts.begin(), Insts.end());
1482
1483 // Collect all instructions from the loop that the instructions in Insts
1484 // depend on (plus their dependencies, etc.). These instructions will
1485 // constitute the expression trees that feed those in Insts, but the trees
1486 // will be limited only to instructions contained in the loop.
1487 for (unsigned i = 0; i < Worklist.size(); ++i) {
1488 Instruction *In = Worklist[i];
1489 for (auto I = In->op_begin(), E = In->op_end(); I != E; ++I) {
1490 Instruction *OpI = dyn_cast<Instruction>(I);
1491 if (!OpI)
1492 continue;
1493 BasicBlock *PB = OpI->getParent();
1494 if (!LoopBlocks.count(PB))
1495 continue;
1496 Worklist.insert(OpI);
1497 }
1498 }
1499
1500 // Scan all instructions in the loop, if any of them have a user outside
1501 // of the loop, or outside of the expressions collected above, then either
1502 // the loop has a side-effect visible outside of it, or there are
1503 // instructions in it that are not involved in the original set Insts.
1504 for (auto *B : L->blocks()) {
1505 for (auto &In : *B) {
1506 if (isa<BranchInst>(In) || isa<DbgInfoIntrinsic>(In))
1507 continue;
1508 if (!Worklist.count(&In) && In.mayHaveSideEffects())
1509 return false;
1510 for (const auto &K : In.users()) {
1511 Instruction *UseI = dyn_cast<Instruction>(K);
1512 if (!UseI)
1513 continue;
1514 BasicBlock *UseB = UseI->getParent();
1515 if (LF->getLoopFor(UseB) != L)
1516 return false;
1517 }
1518 }
1519 }
1520
1521 return true;
1522}
1523
1524/// runOnLoopBlock - Process the specified block, which lives in a counted loop
1525/// with the specified backedge count. This block is known to be in the current
1526/// loop and not in any subloops.
1527bool HexagonLoopIdiomRecognize::runOnLoopBlock(Loop *CurLoop, BasicBlock *BB,
1528 const SCEV *BECount, SmallVectorImpl<BasicBlock*> &ExitBlocks) {
1529 // We can only promote stores in this block if they are unconditionally
1530 // executed in the loop. For a block to be unconditionally executed, it has
1531 // to dominate all the exit blocks of the loop. Verify this now.
1532 auto DominatedByBB = [this,BB] (BasicBlock *EB) -> bool {
1533 return DT->dominates(BB, EB);
1534 };
1535 if (!std::all_of(ExitBlocks.begin(), ExitBlocks.end(), DominatedByBB))
1536 return false;
1537
1538 bool MadeChange = false;
1539 // Look for store instructions, which may be optimized to memset/memcpy.
1540 SmallVector<StoreInst*,8> Stores;
1541 collectStores(CurLoop, BB, Stores);
1542
1543 // Optimize the store into a memcpy, if it feeds an similarly strided load.
1544 for (auto &SI : Stores)
1545 MadeChange |= processCopyingStore(CurLoop, SI, BECount);
1546
1547 return MadeChange;
1548}
1549
1550
1551bool HexagonLoopIdiomRecognize::runOnCountableLoop(Loop *L) {
1552 PolynomialMultiplyRecognize PMR(L, *DL, *DT, *TLI, *SE);
1553 if (PMR.recognize())
1554 return true;
1555
1556 if (!HasMemcpy && !HasMemmove)
1557 return false;
1558
1559 const SCEV *BECount = SE->getBackedgeTakenCount(L);
1560 assert(!isa<SCEVCouldNotCompute>(BECount) &&
1561 "runOnCountableLoop() called on a loop without a predictable"
1562 "backedge-taken count");
1563
1564 SmallVector<BasicBlock *, 8> ExitBlocks;
1565 L->getUniqueExitBlocks(ExitBlocks);
1566
1567 bool Changed = false;
1568
1569 // Scan all the blocks in the loop that are not in subloops.
1570 for (auto *BB : L->getBlocks()) {
1571 // Ignore blocks in subloops.
1572 if (LF->getLoopFor(BB) != L)
1573 continue;
1574 Changed |= runOnLoopBlock(L, BB, BECount, ExitBlocks);
1575 }
1576
1577 return Changed;
1578}
1579
1580
1581bool HexagonLoopIdiomRecognize::runOnLoop(Loop *L, LPPassManager &LPM) {
1582 const Module &M = *L->getHeader()->getParent()->getParent();
1583 if (Triple(M.getTargetTriple()).getArch() != Triple::hexagon)
1584 return false;
1585
1586 if (skipLoop(L))
1587 return false;
1588
1589 // If the loop could not be converted to canonical form, it must have an
1590 // indirectbr in it, just give up.
1591 if (!L->getLoopPreheader())
1592 return false;
1593
1594 // Disable loop idiom recognition if the function's name is a common idiom.
1595 StringRef Name = L->getHeader()->getParent()->getName();
1596 if (Name == "memset" || Name == "memcpy" || Name == "memmove")
1597 return false;
1598
1599 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
1600 DL = &L->getHeader()->getModule()->getDataLayout();
1601 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
1602 LF = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
1603 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
1604 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
1605
1606 HasMemcpy = TLI->has(LibFunc_memcpy);
1607 HasMemmove = TLI->has(LibFunc_memmove);
1608
1609 if (SE->hasLoopInvariantBackedgeTakenCount(L))
1610 return runOnCountableLoop(L);
1611 return false;
1612}
1613
1614
// Factory function used by the Hexagon target to add this pass to the
// pass pipeline.
Pass *llvm::createHexagonLoopIdiomPass() {
  return new HexagonLoopIdiomRecognize();
}
1618