blob: 8434111d807313d14b6b575e4c6b153055668a28 [file] [log] [blame]
Sjoerd Meijerc89ca552018-06-28 12:55:29 +00001//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2//
Chandler Carruth2946cd72019-01-19 08:50:56 +00003// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Sjoerd Meijerc89ca552018-06-28 12:55:29 +00006//
7//===----------------------------------------------------------------------===//
8//
9/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled, the DSP
/// extension is available and the target is little endian.
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000014//
15//===----------------------------------------------------------------------===//
16
Sjoerd Meijerb3e06fa2018-07-06 14:47:09 +000017#include "llvm/ADT/Statistic.h"
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000018#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/Analysis/AliasAnalysis.h"
20#include "llvm/Analysis/LoopAccessAnalysis.h"
21#include "llvm/Analysis/LoopPass.h"
22#include "llvm/Analysis/LoopInfo.h"
23#include "llvm/IR/Instructions.h"
24#include "llvm/IR/NoFolder.h"
25#include "llvm/Transforms/Scalar.h"
26#include "llvm/Transforms/Utils/BasicBlockUtils.h"
27#include "llvm/Transforms/Utils/LoopUtils.h"
28#include "llvm/Pass.h"
29#include "llvm/PassRegistry.h"
30#include "llvm/PassSupport.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/IR/PatternMatch.h"
33#include "llvm/CodeGen/TargetPassConfig.h"
34#include "ARM.h"
35#include "ARMSubtarget.h"
36
37using namespace llvm;
38using namespace PatternMatch;
39
Sjoerd Meijerb3e06fa2018-07-06 14:47:09 +000040#define DEBUG_TYPE "arm-parallel-dsp"
41
42STATISTIC(NumSMLAD , "Number of smlad instructions generated");
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000043
Sjoerd Meijer3c859b32018-08-14 07:43:49 +000044static cl::opt<bool>
45DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
46 cl::desc("Disable the ARM Parallel DSP pass"));
47
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000048namespace {
Sam Parker89a37992018-07-23 15:25:59 +000049 struct OpChain;
Sam Parker414dd1c2019-07-29 08:41:51 +000050 struct MulCandidate;
Sam Parker85ad78b2019-07-11 07:47:50 +000051 class Reduction;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000052
Sam Parker414dd1c2019-07-29 08:41:51 +000053 using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000054 using ReductionList = SmallVector<Reduction, 8>;
55 using ValueList = SmallVector<Value*, 8>;
Sam Parker4c4ff132019-03-14 11:14:13 +000056 using MemInstList = SmallVector<LoadInst*, 8>;
Sam Parker414dd1c2019-07-29 08:41:51 +000057 using PMACPair = std::pair<MulCandidate*,MulCandidate*>;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000058 using PMACPairList = SmallVector<PMACPair, 8>;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000059
Sam Parker414dd1c2019-07-29 08:41:51 +000060 // 'MulCandidate' holds the multiplication instructions that are candidates
Sam Parker3da59e52019-07-26 14:11:40 +000061 // for parallel execution.
Sam Parker414dd1c2019-07-29 08:41:51 +000062 struct MulCandidate {
Sam Parker89a37992018-07-23 15:25:59 +000063 Instruction *Root;
Sam Parker414dd1c2019-07-29 08:41:51 +000064 MemInstList VecLd; // Container for loads to widen.
65 Value* LHS;
66 Value* RHS;
Sam Parker3da59e52019-07-26 14:11:40 +000067 bool Exchange = false;
Sam Parker89a37992018-07-23 15:25:59 +000068 bool ReadOnly = true;
69
Sam Parker414dd1c2019-07-29 08:41:51 +000070 MulCandidate(Instruction *I, ValueList &lhs, ValueList &rhs) :
71 Root(I), LHS(lhs.front()), RHS(rhs.front()) { }
Sam Parker89a37992018-07-23 15:25:59 +000072
Sam Parker414dd1c2019-07-29 08:41:51 +000073 bool HasTwoLoadInputs() const {
74 return isa<LoadInst>(LHS) && isa<LoadInst>(RHS);
75 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000076 };
77
Sam Parker85ad78b2019-07-11 07:47:50 +000078 /// Represent a sequence of multiply-accumulate operations with the aim to
79 /// perform the multiplications in parallel.
80 class Reduction {
81 Instruction *Root = nullptr;
82 Value *Acc = nullptr;
Sam Parker414dd1c2019-07-29 08:41:51 +000083 MulCandList Muls;
Sam Parker85ad78b2019-07-11 07:47:50 +000084 PMACPairList MulPairs;
85 SmallPtrSet<Instruction*, 4> Adds;
86
87 public:
88 Reduction() = delete;
89
90 Reduction (Instruction *Add) : Root(Add) { }
91
92 /// Record an Add instruction that is a part of the this reduction.
93 void InsertAdd(Instruction *I) { Adds.insert(I); }
94
Sam Parker414dd1c2019-07-29 08:41:51 +000095 /// Record a MulCandidate, rooted at a Mul instruction, that is a part of
Sam Parker85ad78b2019-07-11 07:47:50 +000096 /// this reduction.
97 void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
Sam Parker414dd1c2019-07-29 08:41:51 +000098 Muls.push_back(make_unique<MulCandidate>(I, LHS, RHS));
Sam Parker85ad78b2019-07-11 07:47:50 +000099 }
100
101 /// Add the incoming accumulator value, returns true if a value had not
102 /// already been added. Returning false signals to the user that this
103 /// reduction already has a value to initialise the accumulator.
104 bool InsertAcc(Value *V) {
105 if (Acc)
106 return false;
107 Acc = V;
108 return true;
109 }
110
Sam Parker414dd1c2019-07-29 08:41:51 +0000111 /// Set two MulCandidates, rooted at muls, that can be executed as a single
Sam Parker85ad78b2019-07-11 07:47:50 +0000112 /// parallel operation.
Sam Parker414dd1c2019-07-29 08:41:51 +0000113 void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1) {
Sam Parker85ad78b2019-07-11 07:47:50 +0000114 MulPairs.push_back(std::make_pair(Mul0, Mul1));
115 }
116
117 /// Return true if enough mul operations are found that can be executed in
118 /// parallel.
119 bool CreateParallelPairs();
120
121 /// Return the add instruction which is the root of the reduction.
122 Instruction *getRoot() { return Root; }
123
124 /// Return the incoming value to be accumulated. This maybe null.
125 Value *getAccumulator() { return Acc; }
126
127 /// Return the set of adds that comprise the reduction.
128 SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }
129
Sam Parker414dd1c2019-07-29 08:41:51 +0000130 /// Return the MulCandidate, rooted at mul instruction, that comprise the
Sam Parker85ad78b2019-07-11 07:47:50 +0000131 /// the reduction.
Sam Parker414dd1c2019-07-29 08:41:51 +0000132 MulCandList &getMuls() { return Muls; }
Sam Parker85ad78b2019-07-11 07:47:50 +0000133
Sam Parker414dd1c2019-07-29 08:41:51 +0000134 /// Return the MulCandidate, rooted at mul instructions, that have been
Sam Parker85ad78b2019-07-11 07:47:50 +0000135 /// paired for parallel execution.
136 PMACPairList &getMulPairs() { return MulPairs; }
137
138 /// To finalise, replace the uses of the root with the intrinsic call.
139 void UpdateRoot(Instruction *SMLAD) {
140 Root->replaceAllUsesWith(SMLAD);
141 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000142 };
143
Sam Parker4c4ff132019-03-14 11:14:13 +0000144 class WidenedLoad {
145 LoadInst *NewLd = nullptr;
146 SmallVector<LoadInst*, 4> Loads;
147
148 public:
149 WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
150 : NewLd(Wide) {
151 for (auto *I : Lds)
152 Loads.push_back(I);
153 }
154 LoadInst *getLoad() {
155 return NewLd;
156 }
157 };
158
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000159 class ARMParallelDSP : public LoopPass {
160 ScalarEvolution *SE;
161 AliasAnalysis *AA;
162 TargetLibraryInfo *TLI;
163 DominatorTree *DT;
164 LoopInfo *LI;
165 Loop *L;
166 const DataLayout *DL;
167 Module *M;
Sam Parker453ba912018-11-09 09:18:00 +0000168 std::map<LoadInst*, LoadInst*> LoadPairs;
Sam Parker85ad78b2019-07-11 07:47:50 +0000169 SmallPtrSet<LoadInst*, 4> OffsetLoads;
Sam Parker4c4ff132019-03-14 11:14:13 +0000170 std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000171
Sam Parker85ad78b2019-07-11 07:47:50 +0000172 template<unsigned>
173 bool IsNarrowSequence(Value *V, ValueList &VL);
174
Sam Parkera33e3112019-05-13 09:23:32 +0000175 bool RecordMemoryOps(BasicBlock *BB);
Sam Parker85ad78b2019-07-11 07:47:50 +0000176 void InsertParallelMACs(Reduction &Reduction);
Fangrui Song68169342018-07-03 19:12:27 +0000177 bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
Sam Parkera33e3112019-05-13 09:23:32 +0000178 LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
179 IntegerType *LoadTy);
Sam Parker85ad78b2019-07-11 07:47:50 +0000180 bool CreateParallelPairs(Reduction &R);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000181
182 /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
183 /// Dual performs two signed 16x16-bit multiplications. It adds the
184 /// products to a 32-bit accumulate operand. Optionally, the instruction can
185 /// exchange the halfwords of the second operand before performing the
186 /// arithmetic.
Sam Parker85ad78b2019-07-11 07:47:50 +0000187 bool MatchSMLAD(Loop *L);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000188
189 public:
190 static char ID;
191
192 ARMParallelDSP() : LoopPass(ID) { }
193
Sam Parkera33e3112019-05-13 09:23:32 +0000194 bool doInitialization(Loop *L, LPPassManager &LPM) override {
195 LoadPairs.clear();
196 WideLoads.clear();
197 return true;
198 }
199
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000200 void getAnalysisUsage(AnalysisUsage &AU) const override {
201 LoopPass::getAnalysisUsage(AU);
202 AU.addRequired<AssumptionCacheTracker>();
203 AU.addRequired<ScalarEvolutionWrapperPass>();
204 AU.addRequired<AAResultsWrapperPass>();
205 AU.addRequired<TargetLibraryInfoWrapperPass>();
206 AU.addRequired<LoopInfoWrapperPass>();
207 AU.addRequired<DominatorTreeWrapperPass>();
208 AU.addRequired<TargetPassConfig>();
209 AU.addPreserved<LoopInfoWrapperPass>();
210 AU.setPreservesCFG();
211 }
212
213 bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
Sjoerd Meijer3c859b32018-08-14 07:43:49 +0000214 if (DisableParallelDSP)
215 return false;
Eli Friedmanb27fc952019-07-23 20:48:46 +0000216 if (skipLoop(TheLoop))
217 return false;
218
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000219 L = TheLoop;
220 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
221 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
222 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
223 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
224 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
225 auto &TPC = getAnalysis<TargetPassConfig>();
226
227 BasicBlock *Header = TheLoop->getHeader();
228 if (!Header)
229 return false;
230
231 // TODO: We assume the loop header and latch to be the same block.
232 // This is not a fundamental restriction, but lifting this would just
233 // require more work to do the transformation and then patch up the CFG.
234 if (Header != TheLoop->getLoopLatch()) {
235 LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
236 "running pass ARMParallelDSP\n");
237 return false;
238 }
239
Sam Parker85ad78b2019-07-11 07:47:50 +0000240 if (!TheLoop->getLoopPreheader())
241 InsertPreheaderForLoop(L, DT, LI, nullptr, true);
Sam Parker9e730202019-03-15 10:19:32 +0000242
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000243 Function &F = *Header->getParent();
244 M = F.getParent();
245 DL = &M->getDataLayout();
246
247 auto &TM = TPC.getTM<TargetMachine>();
248 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
249
250 if (!ST->allowsUnalignedMem()) {
251 LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
252 "running pass ARMParallelDSP\n");
253 return false;
254 }
255
256 if (!ST->hasDSP()) {
257 LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
258 "ARMParallelDSP\n");
259 return false;
260 }
261
Sam Parker9e730202019-03-15 10:19:32 +0000262 if (!ST->isLittle()) {
263 LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
Sam Parkera33e3112019-05-13 09:23:32 +0000264 << "ARMParallelDSP\n");
Sam Parker9e730202019-03-15 10:19:32 +0000265 return false;
266 }
267
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000268 LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000269
Sam Parkera023c7a2018-09-12 09:17:44 +0000270 LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
271 LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
Sam Parker453ba912018-11-09 09:18:00 +0000272
Sam Parkera33e3112019-05-13 09:23:32 +0000273 if (!RecordMemoryOps(Header)) {
Sam Parker453ba912018-11-09 09:18:00 +0000274 LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
275 return false;
276 }
277
Sam Parker85ad78b2019-07-11 07:47:50 +0000278 bool Changes = MatchSMLAD(L);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000279 return Changes;
280 }
281 };
282}
283
Sam Parkerffc16812018-07-03 12:44:16 +0000284template<typename MemInst>
285static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
Sam Parker453ba912018-11-09 09:18:00 +0000286 const DataLayout &DL, ScalarEvolution &SE) {
Sam Parker4c4ff132019-03-14 11:14:13 +0000287 if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
Sam Parkerffc16812018-07-03 12:44:16 +0000288 return true;
Sam Parkerffc16812018-07-03 12:44:16 +0000289 return false;
290}
291
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000292bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
Sam Parkerffc16812018-07-03 12:44:16 +0000293 MemInstList &VecMem) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000294 if (!Ld0 || !Ld1)
295 return false;
296
Sam Parker4c4ff132019-03-14 11:14:13 +0000297 if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
298 return false;
299
300 LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000301 dbgs() << "Ld0:"; Ld0->dump();
302 dbgs() << "Ld1:"; Ld1->dump();
303 );
304
Sam Parker453ba912018-11-09 09:18:00 +0000305 VecMem.clear();
306 VecMem.push_back(Ld0);
307 VecMem.push_back(Ld1);
308 return true;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000309}
310
Sam Parker85ad78b2019-07-11 07:47:50 +0000311// MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
312// instructions, which is set to 16. So here we should collect all i8 and i16
313// narrow operations.
314// TODO: we currently only collect i16, and will support i8 later, so that's
315// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
316template<unsigned MaxBitWidth>
317bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
Sam Parker74400652019-07-26 10:57:42 +0000318 if (auto *SExt = dyn_cast<SExtInst>(V)) {
319 if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
Sam Parker85ad78b2019-07-11 07:47:50 +0000320 return false;
321
Sam Parker74400652019-07-26 10:57:42 +0000322 if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
Sam Parker85ad78b2019-07-11 07:47:50 +0000323 // Check that these load could be paired.
324 if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
325 return false;
326
Sam Parker74400652019-07-26 10:57:42 +0000327 VL.push_back(Ld);
328 VL.push_back(SExt);
Sam Parker85ad78b2019-07-11 07:47:50 +0000329 return true;
330 }
331 }
332 return false;
333}
334
Sam Parkera33e3112019-05-13 09:23:32 +0000335/// Iterate through the block and record base, offset pairs of loads which can
336/// be widened into a single load.
337bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
Sam Parker453ba912018-11-09 09:18:00 +0000338 SmallVector<LoadInst*, 8> Loads;
Sam Parkera33e3112019-05-13 09:23:32 +0000339 SmallVector<Instruction*, 8> Writes;
340
341 // Collect loads and instruction that may write to memory. For now we only
342 // record loads which are simple, sign-extended and have a single user.
343 // TODO: Allow zero-extended loads.
Sam Parker4c4ff132019-03-14 11:14:13 +0000344 for (auto &I : *BB) {
Sam Parkera33e3112019-05-13 09:23:32 +0000345 if (I.mayWriteToMemory())
346 Writes.push_back(&I);
Sam Parker453ba912018-11-09 09:18:00 +0000347 auto *Ld = dyn_cast<LoadInst>(&I);
Sam Parker4c4ff132019-03-14 11:14:13 +0000348 if (!Ld || !Ld->isSimple() ||
349 !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
Sam Parker453ba912018-11-09 09:18:00 +0000350 continue;
351 Loads.push_back(Ld);
352 }
353
Sam Parkera33e3112019-05-13 09:23:32 +0000354 using InstSet = std::set<Instruction*>;
355 using DepMap = std::map<Instruction*, InstSet>;
356 DepMap RAWDeps;
357
358 // Record any writes that may alias a load.
359 const auto Size = LocationSize::unknown();
360 for (auto Read : Loads) {
361 for (auto Write : Writes) {
362 MemoryLocation ReadLoc =
363 MemoryLocation(Read->getPointerOperand(), Size);
364
365 if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
366 ModRefInfo::ModRef)))
367 continue;
368 if (DT->dominates(Write, Read))
369 RAWDeps[Read].insert(Write);
370 }
371 }
372
373 // Check whether there's not a write between the two loads which would
374 // prevent them from being safely merged.
375 auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
376 LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
377 LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;
378
379 if (RAWDeps.count(Dominated)) {
380 InstSet &WritesBefore = RAWDeps[Dominated];
381
382 for (auto Before : WritesBefore) {
383
384 // We can't move the second load backward, past a write, to merge
385 // with the first load.
386 if (DT->dominates(Dominator, Before))
387 return false;
388 }
389 }
390 return true;
391 };
392
393 // Record base, offset load pairs.
394 for (auto *Base : Loads) {
395 for (auto *Offset : Loads) {
396 if (Base == Offset)
Sam Parker453ba912018-11-09 09:18:00 +0000397 continue;
398
Sam Parkera33e3112019-05-13 09:23:32 +0000399 if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
400 SafeToPair(Base, Offset)) {
401 LoadPairs[Base] = Offset;
Sam Parker85ad78b2019-07-11 07:47:50 +0000402 OffsetLoads.insert(Offset);
Sam Parker4c4ff132019-03-14 11:14:13 +0000403 break;
Sam Parker453ba912018-11-09 09:18:00 +0000404 }
405 }
406 }
Sam Parker4c4ff132019-03-14 11:14:13 +0000407
408 LLVM_DEBUG(if (!LoadPairs.empty()) {
409 dbgs() << "Consecutive load pairs:\n";
410 for (auto &MapIt : LoadPairs) {
411 LLVM_DEBUG(dbgs() << *MapIt.first << ", "
412 << *MapIt.second << "\n");
413 }
414 });
Sam Parker453ba912018-11-09 09:18:00 +0000415 return LoadPairs.size() > 1;
416}
417
Sam Parker85ad78b2019-07-11 07:47:50 +0000418// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
419// multiplications.
420// To use SMLAD:
421// 1) we first need to find integer add then look for this pattern:
422//
423// acc0 = ...
424// ld0 = load i16
425// sext0 = sext i16 %ld0 to i32
426// ld1 = load i16
427// sext1 = sext i16 %ld1 to i32
428// mul0 = mul %sext0, %sext1
429// ld2 = load i16
430// sext2 = sext i16 %ld2 to i32
431// ld3 = load i16
432// sext3 = sext i16 %ld3 to i32
433// mul1 = mul i32 %sext2, %sext3
434// add0 = add i32 %mul0, %acc0
435// acc1 = add i32 %add0, %mul1
436//
437// Which can be selected to:
438//
439// ldr r0
440// ldr r1
441// smlad r2, r0, r1, r2
442//
443// If constants are used instead of loads, these will need to be hoisted
444// out and into a register.
445//
446// If loop invariants are used instead of loads, these need to be packed
447// before the loop begins.
448//
449bool ARMParallelDSP::MatchSMLAD(Loop *L) {
450 // Search recursively back through the operands to find a tree of values that
451 // form a multiply-accumulate chain. The search records the Add and Mul
452 // instructions that form the reduction and allows us to find a single value
453 // to be used as the initial input to the accumlator.
454 std::function<bool(Value*, Reduction&)> Search = [&]
455 (Value *V, Reduction &R) -> bool {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000456
Sam Parker85ad78b2019-07-11 07:47:50 +0000457 // If we find a non-instruction, try to use it as the initial accumulator
458 // value. This may have already been found during the search in which case
459 // this function will return false, signaling a search fail.
460 auto *I = dyn_cast<Instruction>(V);
461 if (!I)
462 return R.InsertAcc(V);
Sam Parker453ba912018-11-09 09:18:00 +0000463
Sam Parker85ad78b2019-07-11 07:47:50 +0000464 switch (I->getOpcode()) {
465 default:
466 break;
467 case Instruction::PHI:
468 // Could be the accumulator value.
469 return R.InsertAcc(V);
470 case Instruction::Add: {
471 // Adds should be adding together two muls, or another add and a mul to
472 // be within the mac chain. One of the operands may also be the
473 // accumulator value at which point we should stop searching.
474 bool ValidLHS = Search(I->getOperand(0), R);
475 bool ValidRHS = Search(I->getOperand(1), R);
476 if (!ValidLHS && !ValidLHS)
477 return false;
478 else if (ValidLHS && ValidRHS) {
479 R.InsertAdd(I);
480 return true;
481 } else {
482 R.InsertAdd(I);
483 return R.InsertAcc(I);
484 }
485 }
486 case Instruction::Mul: {
487 Value *MulOp0 = I->getOperand(0);
488 Value *MulOp1 = I->getOperand(1);
489 if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
490 ValueList LHS;
491 ValueList RHS;
492 if (IsNarrowSequence<16>(MulOp0, LHS) &&
493 IsNarrowSequence<16>(MulOp1, RHS)) {
494 R.InsertMul(I, LHS, RHS);
495 return true;
496 }
497 }
498 return false;
499 }
500 case Instruction::SExt:
501 return Search(I->getOperand(0), R);
502 }
503 return false;
504 };
505
506 bool Changed = false;
507 SmallPtrSet<Instruction*, 4> AllAdds;
508 BasicBlock *Latch = L->getLoopLatch();
509
510 for (Instruction &I : reverse(*Latch)) {
511 if (I.getOpcode() != Instruction::Add)
512 continue;
513
514 if (AllAdds.count(&I))
515 continue;
516
517 const auto *Ty = I.getType();
518 if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
519 continue;
520
521 Reduction R(&I);
522 if (!Search(&I, R))
523 continue;
524
525 if (!CreateParallelPairs(R))
526 continue;
527
528 InsertParallelMACs(R);
529 Changed = true;
530 AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
531 }
532
533 return Changed;
534}
535
536bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {
537
538 // Not enough mul operations to make a pair.
539 if (R.getMuls().size() < 2)
540 return false;
541
542 // Check that the muls operate directly upon sign extended loads.
Sam Parker414dd1c2019-07-29 08:41:51 +0000543 for (auto &MulCand : R.getMuls()) {
544 if (!MulCand->HasTwoLoadInputs())
Sam Parker85ad78b2019-07-11 07:47:50 +0000545 return false;
Sam Parker85ad78b2019-07-11 07:47:50 +0000546 }
547
Sam Parker414dd1c2019-07-29 08:41:51 +0000548 auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) {
Sam Parker453ba912018-11-09 09:18:00 +0000549 // The first elements of each vector should be loads with sexts. If we
550 // find that its two pairs of consecutive loads, then these can be
551 // transformed into two wider loads and the users can be replaced with
552 // DSP intrinsics.
Sam Parker414dd1c2019-07-29 08:41:51 +0000553 auto Ld0 = static_cast<LoadInst*>(PMul0->LHS);
554 auto Ld1 = static_cast<LoadInst*>(PMul1->LHS);
555 auto Ld2 = static_cast<LoadInst*>(PMul0->RHS);
556 auto Ld3 = static_cast<LoadInst*>(PMul1->RHS);
Sam Parker453ba912018-11-09 09:18:00 +0000557
Sam Parker414dd1c2019-07-29 08:41:51 +0000558 LLVM_DEBUG(dbgs() << "Loads:\n"
559 << " - " << *Ld0 << "\n"
560 << " - " << *Ld1 << "\n"
561 << " - " << *Ld2 << "\n"
562 << " - " << *Ld3 << "\n");
Sam Parker453ba912018-11-09 09:18:00 +0000563
Sam Parker414dd1c2019-07-29 08:41:51 +0000564 if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
565 if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
Sam Parker453ba912018-11-09 09:18:00 +0000566 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
Sam Parker414dd1c2019-07-29 08:41:51 +0000567 R.AddMulPair(PMul0, PMul1);
568 return true;
569 } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
570 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
571 LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
572 PMul1->Exchange = true;
573 R.AddMulPair(PMul0, PMul1);
Sam Parker453ba912018-11-09 09:18:00 +0000574 return true;
575 }
Sam Parker414dd1c2019-07-29 08:41:51 +0000576 } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
577 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
578 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
579 LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
580 LLVM_DEBUG(dbgs() << " and swapping muls\n");
581 PMul0->Exchange = true;
582 // Only the second operand can be exchanged, so swap the muls.
583 R.AddMulPair(PMul1, PMul0);
584 return true;
Sam Parker453ba912018-11-09 09:18:00 +0000585 }
586 return false;
587 };
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000588
Sam Parker414dd1c2019-07-29 08:41:51 +0000589 MulCandList &Muls = R.getMuls();
Sam Parker85ad78b2019-07-11 07:47:50 +0000590 const unsigned Elems = Muls.size();
Sam Parkera023c7a2018-09-12 09:17:44 +0000591 SmallPtrSet<const Instruction*, 4> Paired;
592 for (unsigned i = 0; i < Elems; ++i) {
Sam Parker414dd1c2019-07-29 08:41:51 +0000593 MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get());
Sam Parkera023c7a2018-09-12 09:17:44 +0000594 if (Paired.count(PMul0->Root))
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000595 continue;
596
Sam Parkera023c7a2018-09-12 09:17:44 +0000597 for (unsigned j = 0; j < Elems; ++j) {
598 if (i == j)
599 continue;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000600
Sam Parker414dd1c2019-07-29 08:41:51 +0000601 MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get());
Sam Parkera023c7a2018-09-12 09:17:44 +0000602 if (Paired.count(PMul1->Root))
603 continue;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000604
Sam Parkera023c7a2018-09-12 09:17:44 +0000605 const Instruction *Mul0 = PMul0->Root;
606 const Instruction *Mul1 = PMul1->Root;
607 if (Mul0 == Mul1)
608 continue;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000609
Sam Parkera023c7a2018-09-12 09:17:44 +0000610 assert(PMul0 != PMul1 && "expected different chains");
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000611
Sam Parker85ad78b2019-07-11 07:47:50 +0000612 if (CanPair(R, PMul0, PMul1)) {
Sam Parkera023c7a2018-09-12 09:17:44 +0000613 Paired.insert(Mul0);
614 Paired.insert(Mul1);
615 break;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000616 }
617 }
618 }
Sam Parker85ad78b2019-07-11 07:47:50 +0000619 return !R.getMulPairs().empty();
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000620}
621
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000622
Sam Parker85ad78b2019-07-11 07:47:50 +0000623void ARMParallelDSP::InsertParallelMACs(Reduction &R) {
624
625 auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
626 SmallVectorImpl<LoadInst*> &VecLd1,
627 Value *Acc, bool Exchange,
628 Instruction *InsertAfter) {
629 // Replace the reduction chain with an intrinsic call
630 IntegerType *Ty = IntegerType::get(M->getContext(), 32);
631 LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
632 WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
633 LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
634 WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);
635
636 Value* Args[] = { WideLd0, WideLd1, Acc };
637 Function *SMLAD = nullptr;
638 if (Exchange)
639 SMLAD = Acc->getType()->isIntegerTy(32) ?
640 Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
641 Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
642 else
643 SMLAD = Acc->getType()->isIntegerTy(32) ?
644 Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
645 Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
646
647 IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
648 ++BasicBlock::iterator(InsertAfter));
649 Instruction *Call = Builder.CreateCall(SMLAD, Args);
650 NumSMLAD++;
651 return Call;
652 };
653
654 Instruction *InsertAfter = R.getRoot();
655 Value *Acc = R.getAccumulator();
656 if (!Acc)
657 Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);
658
659 LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
660 << "Acc: " << *Acc << "\n");
661 for (auto &Pair : R.getMulPairs()) {
Sam Parker414dd1c2019-07-29 08:41:51 +0000662 MulCandidate *PMul0 = Pair.first;
663 MulCandidate *PMul1 = Pair.second;
Sam Parker85ad78b2019-07-11 07:47:50 +0000664 LLVM_DEBUG(dbgs() << "Muls:\n"
Sam Parkera33e3112019-05-13 09:23:32 +0000665 << "- " << *PMul0->Root << "\n"
666 << "- " << *PMul1->Root << "\n");
Sam Parkera023c7a2018-09-12 09:17:44 +0000667
Sam Parker4c4ff132019-03-14 11:14:13 +0000668 Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
669 InsertAfter);
Sam Parker85ad78b2019-07-11 07:47:50 +0000670 InsertAfter = cast<Instruction>(Acc);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000671 }
Sam Parker85ad78b2019-07-11 07:47:50 +0000672 R.UpdateRoot(cast<Instruction>(Acc));
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000673}
674
Sam Parkera33e3112019-05-13 09:23:32 +0000675LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
676 IntegerType *LoadTy) {
Sam Parker4c4ff132019-03-14 11:14:13 +0000677 assert(Loads.size() == 2 && "currently only support widening two loads");
Sam Parkera33e3112019-05-13 09:23:32 +0000678
679 LoadInst *Base = Loads[0];
680 LoadInst *Offset = Loads[1];
681
682 Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
683 Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());
684
685 assert((BaseSExt && OffsetSExt)
686 && "Loads should have a single, extending, user");
687
688 std::function<void(Value*, Value*)> MoveBefore =
689 [&](Value *A, Value *B) -> void {
690 if (!isa<Instruction>(A) || !isa<Instruction>(B))
691 return;
692
693 auto *Source = cast<Instruction>(A);
694 auto *Sink = cast<Instruction>(B);
695
696 if (DT->dominates(Source, Sink) ||
697 Source->getParent() != Sink->getParent() ||
698 isa<PHINode>(Source) || isa<PHINode>(Sink))
699 return;
700
701 Source->moveBefore(Sink);
Sam Parkeraeb21b92019-07-24 09:38:39 +0000702 for (auto &Op : Source->operands())
703 MoveBefore(Op, Source);
Sam Parkera33e3112019-05-13 09:23:32 +0000704 };
705
706 // Insert the load at the point of the original dominating load.
707 LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
708 IRBuilder<NoFolder> IRB(DomLoad->getParent(),
709 ++BasicBlock::iterator(DomLoad));
710
711 // Bitcast the pointer to a wider type and create the wide load, while making
712 // sure to maintain the original alignment as this prevents ldrd from being
713 // generated when it could be illegal due to memory alignment.
714 const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
715 Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
Eli Friedmanb09c7782018-10-18 19:34:30 +0000716 LoadTy->getPointerTo(AddrSpace));
Sam Parker4c4ff132019-03-14 11:14:13 +0000717 LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
Sam Parkera33e3112019-05-13 09:23:32 +0000718 Base->getAlignment());
Sam Parker4c4ff132019-03-14 11:14:13 +0000719
Sam Parkera33e3112019-05-13 09:23:32 +0000720 // Make sure everything is in the correct order in the basic block.
721 MoveBefore(Base->getPointerOperand(), VecPtr);
722 MoveBefore(VecPtr, WideLoad);
Sam Parker4c4ff132019-03-14 11:14:13 +0000723
724 // From the wide load, create two values that equal the original two loads.
Sam Parkera33e3112019-05-13 09:23:32 +0000725 // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
726 // TODO: Support big-endian as well.
727 Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
728 BaseSExt->setOperand(0, Bottom);
Sam Parker4c4ff132019-03-14 11:14:13 +0000729
Sam Parkera33e3112019-05-13 09:23:32 +0000730 IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
731 Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
Sam Parker4c4ff132019-03-14 11:14:13 +0000732 Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
Sam Parkera33e3112019-05-13 09:23:32 +0000733 Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
734 OffsetSExt->setOperand(0, Trunc);
Sam Parker4c4ff132019-03-14 11:14:13 +0000735
Sam Parkera33e3112019-05-13 09:23:32 +0000736 WideLoads.emplace(std::make_pair(Base,
Sam Parker4c4ff132019-03-14 11:14:13 +0000737 make_unique<WidenedLoad>(Loads, WideLoad)));
738 return WideLoad;
Eli Friedmanb09c7782018-10-18 19:34:30 +0000739}
740
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000741Pass *llvm::createARMParallelDSPPass() {
742 return new ARMParallelDSP();
743}
744
745char ARMParallelDSP::ID = 0;
746
Sjoerd Meijerb3e06fa2018-07-06 14:47:09 +0000747INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
Simon Pilgrimc09b5e32018-06-28 18:37:16 +0000748 "Transform loops to use DSP intrinsics", false, false)
Sjoerd Meijerb3e06fa2018-07-06 14:47:09 +0000749INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
Simon Pilgrimc09b5e32018-06-28 18:37:16 +0000750 "Transform loops to use DSP intrinsics", false, false)