//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
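///
/// A typical candidate is a dot-product style reduction where, in C terms,
/// something like:
///
///   for (int i = 0; i < N; ++i)
///     Acc += (int)A[i] * (int)B[i];   // A and B are arrays of i16 elements
///
/// can have its two 16-bit multiplications performed by a single smlad.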
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/LoopAccessAnalysis.h"
#include "llvm/Analysis/LoopPass.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/NoFolder.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Pass.h"
#include "llvm/PassRegistry.h"
#include "llvm/PassSupport.h"
#include "llvm/Support/Debug.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "ARM.h"
#include "ARMSubtarget.h"

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "arm-parallel-dsp"

STATISTIC(NumSMLAD, "Number of smlad instructions generated");

static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));

namespace {
  struct OpChain;
  struct BinOpChain;
  class Reduction;

  using OpChainList     = SmallVector<std::unique_ptr<BinOpChain>, 8>;
  using ReductionList   = SmallVector<Reduction, 8>;
  using ValueList       = SmallVector<Value*, 8>;
  using MemInstList     = SmallVector<LoadInst*, 8>;
  using PMACPair        = std::pair<BinOpChain*,BinOpChain*>;
  using PMACPairList    = SmallVector<PMACPair, 8>;
  using Instructions    = SmallVector<Instruction*,16>;
  using MemLocList      = SmallVector<MemoryLocation, 4>;

  // 'BinOpChain' holds the multiplication instructions that are candidates
  // for parallel execution.
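  // Each chain records the mul's narrow operands: for a mul of two
  // sign-extended i16 loads, LHS and RHS each hold a (load, sext) pair, so
  // AllValues typically contains four values per multiplication.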
  struct BinOpChain {
    Instruction *Root;
    ValueList AllValues;
    MemInstList Loads;
    MemInstList VecLd;          // List of all load instructions.
    ValueList LHS;              // List of all (narrow) left hand operands.
    ValueList RHS;              // List of all (narrow) right hand operands.
    bool Exchange = false;
    bool ReadOnly = true;

    BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
      Root(I), LHS(lhs), RHS(rhs) {
      for (auto *V : LHS)
        AllValues.push_back(V);
      for (auto *V : RHS)
        AllValues.push_back(V);
    }

    void PopulateLoads() {
      for (auto *V : AllValues) {
        if (auto *Ld = dyn_cast<LoadInst>(V))
          Loads.push_back(Ld);
      }
    }

    unsigned size() const { return AllValues.size(); }

    bool AreSymmetrical(BinOpChain *Other);
  };

  /// Represent a sequence of multiply-accumulate operations with the aim to
  /// perform the multiplications in parallel.
  class Reduction {
    Instruction   *Root = nullptr;
    Value         *Acc = nullptr;
    OpChainList   Muls;
    PMACPairList  MulPairs;
    SmallPtrSet<Instruction*, 4> Adds;

  public:
    Reduction() = delete;

    Reduction(Instruction *Add) : Root(Add) { }

    /// Record an Add instruction that is a part of this reduction.
    void InsertAdd(Instruction *I) { Adds.insert(I); }

    /// Record a BinOpChain, rooted at a Mul instruction, that is a part of
    /// this reduction.
    void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) {
      Muls.push_back(make_unique<BinOpChain>(I, LHS, RHS));
    }

    /// Add the incoming accumulator value; returns true if a value had not
    /// already been added. Returning false signals to the user that this
    /// reduction already has a value to initialise the accumulator.
    bool InsertAcc(Value *V) {
      if (Acc)
        return false;
      Acc = V;
      return true;
    }

    /// Set two BinOpChains, rooted at muls, that can be executed as a single
    /// parallel operation.
    void AddMulPair(BinOpChain *Mul0, BinOpChain *Mul1) {
      MulPairs.push_back(std::make_pair(Mul0, Mul1));
    }

    /// Return true if enough mul operations are found that can be executed in
    /// parallel.
    bool CreateParallelPairs();

    /// Return the add instruction which is the root of the reduction.
    Instruction *getRoot() { return Root; }

    /// Return the incoming value to be accumulated. This may be null.
    Value *getAccumulator() { return Acc; }

    /// Return the set of adds that comprise the reduction.
    SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; }

    /// Return the BinOpChains, rooted at mul instructions, that comprise
    /// the reduction.
    OpChainList &getMuls() { return Muls; }

    /// Return the BinOpChains, rooted at mul instructions, that have been
    /// paired for parallel execution.
    PMACPairList &getMulPairs() { return MulPairs; }

    /// To finalise, replace the uses of the root with the intrinsic call.
    void UpdateRoot(Instruction *SMLAD) {
      Root->replaceAllUsesWith(SMLAD);
    }
  };

  class WidenedLoad {
    LoadInst *NewLd = nullptr;
    SmallVector<LoadInst*, 4> Loads;

  public:
    WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide)
      : NewLd(Wide) {
      for (auto *I : Lds)
        Loads.push_back(I);
    }
    LoadInst *getLoad() {
      return NewLd;
    }
  };

  class ARMParallelDSP : public LoopPass {
    ScalarEvolution   *SE;
    AliasAnalysis     *AA;
    TargetLibraryInfo *TLI;
    DominatorTree     *DT;
    LoopInfo          *LI;
    Loop              *L;
    const DataLayout  *DL;
    Module            *M;
    std::map<LoadInst*, LoadInst*> LoadPairs;
    SmallPtrSet<LoadInst*, 4> OffsetLoads;
    std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads;

    template<unsigned>
    bool IsNarrowSequence(Value *V, ValueList &VL);

    bool RecordMemoryOps(BasicBlock *BB);
    void InsertParallelMACs(Reduction &Reduction);
    bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
    LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                             IntegerType *LoadTy);
    bool CreateParallelPairs(Reduction &R);

    /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
    /// Dual performs two signed 16x16-bit multiplications. It adds the
    /// products to a 32-bit accumulate operand. Optionally, the instruction can
    /// exchange the halfwords of the second operand before performing the
    /// arithmetic.
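    /// For the non-exchanging, 32-bit accumulating form this computes,
    /// roughly:
    ///   Res = Acc + (Op0[15:0] * Op1[15:0]) + (Op0[31:16] * Op1[31:16])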
    bool MatchSMLAD(Loop *L);

  public:
    static char ID;

    ARMParallelDSP() : LoopPass(ID) { }

    bool doInitialization(Loop *L, LPPassManager &LPM) override {
      LoadPairs.clear();
      WideLoads.clear();
      return true;
    }

    void getAnalysisUsage(AnalysisUsage &AU) const override {
      LoopPass::getAnalysisUsage(AU);
      AU.addRequired<AssumptionCacheTracker>();
      AU.addRequired<ScalarEvolutionWrapperPass>();
      AU.addRequired<AAResultsWrapperPass>();
      AU.addRequired<TargetLibraryInfoWrapperPass>();
      AU.addRequired<LoopInfoWrapperPass>();
      AU.addRequired<DominatorTreeWrapperPass>();
      AU.addRequired<TargetPassConfig>();
      AU.addPreserved<LoopInfoWrapperPass>();
      AU.setPreservesCFG();
    }

    bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
      if (DisableParallelDSP)
        return false;
      if (skipLoop(TheLoop))
        return false;

      L = TheLoop;
      SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
      AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
      TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
      DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
      LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
      auto &TPC = getAnalysis<TargetPassConfig>();

      BasicBlock *Header = TheLoop->getHeader();
      if (!Header)
        return false;

      // TODO: We assume the loop header and latch to be the same block.
      // This is not a fundamental restriction, but lifting this would just
      // require more work to do the transformation and then patch up the CFG.
      if (Header != TheLoop->getLoopLatch()) {
        LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!TheLoop->getLoopPreheader())
        InsertPreheaderForLoop(L, DT, LI, nullptr, true);

      Function &F = *Header->getParent();
      M = F.getParent();
      DL = &M->getDataLayout();

      auto &TM = TPC.getTM<TargetMachine>();
      auto *ST = &TM.getSubtarget<ARMSubtarget>(F);

      if (!ST->allowsUnalignedMem()) {
        LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
                             "running pass ARMParallelDSP\n");
        return false;
      }

      if (!ST->hasDSP()) {
        LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
                             "ARMParallelDSP\n");
        return false;
      }

      if (!ST->isLittle()) {
        LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass "
                          << "ARMParallelDSP\n");
        return false;
      }

      LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);

      LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
      LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");

      if (!RecordMemoryOps(Header)) {
        LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
        return false;
      }

      bool Changes = MatchSMLAD(L);
      return Changes;
    }
  };
}

template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
                                  const DataLayout &DL, ScalarEvolution &SE) {
  if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE))
    return true;
  return false;
}

bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
                                        MemInstList &VecMem) {
  if (!Ld0 || !Ld1)
    return false;

  if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
    return false;

  LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n";
             dbgs() << "Ld0:"; Ld0->dump();
             dbgs() << "Ld1:"; Ld1->dump();
            );

  VecMem.clear();
  VecMem.push_back(Ld0);
  VecMem.push_back(Ld1);
  return true;
}

// MaxBitWidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
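//
// A narrow sequence for MaxBitWidth == 16 therefore looks like:
//   %ld = load i16, i16* %addr
//   %sext = sext i16 %ld to i32
// and both the load and the sext are recorded in VL.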
template<unsigned MaxBitWidth>
bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) {
  if (auto *SExt = dyn_cast<SExtInst>(V)) {
    if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
      return false;

    if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) {
      // Check that this load could be paired.
      if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld))
        return false;

      VL.push_back(Ld);
      VL.push_back(SExt);
      return true;
    }
  }
  return false;
}

/// Iterate through the block and record base, offset pairs of loads which can
/// be widened into a single load.
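///
/// For example, given two adjacent i16 loads such as:
///   %ld0 = load i16, i16* %addr
///   %ld1 = load i16, i16* %addr.1   ; the next consecutive i16
/// %ld0 is recorded as the base and %ld1 as its offset load, provided no
/// intervening write could alias them.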
bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) {
  SmallVector<LoadInst*, 8> Loads;
  SmallVector<Instruction*, 8> Writes;

  // Collect loads and instructions that may write to memory. For now we only
  // record loads which are simple, sign-extended and have a single user.
  // TODO: Allow zero-extended loads.
  for (auto &I : *BB) {
    if (I.mayWriteToMemory())
      Writes.push_back(&I);
    auto *Ld = dyn_cast<LoadInst>(&I);
    if (!Ld || !Ld->isSimple() ||
        !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back()))
      continue;
    Loads.push_back(Ld);
  }

  using InstSet = std::set<Instruction*>;
  using DepMap = std::map<Instruction*, InstSet>;
  DepMap RAWDeps;

  // Record any writes that may alias a load.
  const auto Size = LocationSize::unknown();
  for (auto Read : Loads) {
    for (auto Write : Writes) {
      MemoryLocation ReadLoc =
        MemoryLocation(Read->getPointerOperand(), Size);

      if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc),
                                         ModRefInfo::ModRef)))
        continue;
      if (DT->dominates(Write, Read))
        RAWDeps[Read].insert(Write);
    }
  }

  // Check whether there is a write between the two loads which would
  // prevent them from being safely merged.
  auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) {
    LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset;
    LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base;

    if (RAWDeps.count(Dominated)) {
      InstSet &WritesBefore = RAWDeps[Dominated];

      for (auto Before : WritesBefore) {
        // We can't move the second load backward, past a write, to merge
        // with the first load.
        if (DT->dominates(Dominator, Before))
          return false;
      }
    }
    return true;
  };

  // Record base, offset load pairs.
  for (auto *Base : Loads) {
    for (auto *Offset : Loads) {
      if (Base == Offset)
        continue;

      if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) &&
          SafeToPair(Base, Offset)) {
        LoadPairs[Base] = Offset;
        OffsetLoads.insert(Offset);
        break;
      }
    }
  }

  LLVM_DEBUG(if (!LoadPairs.empty()) {
               dbgs() << "Consecutive load pairs:\n";
               for (auto &MapIt : LoadPairs)
                 dbgs() << *MapIt.first << ", " << *MapIt.second << "\n";
             });
  return LoadPairs.size() > 1;
}

// Loop pass that identifies integer add reductions of 16-bit multiplications
// which can be executed in parallel.
// To use SMLAD:
// 1) we first need to find an integer add, then look for this pattern:
//
// acc0 = ...
// ld0 = load i16
// sext0 = sext i16 %ld0 to i32
// ld1 = load i16
// sext1 = sext i16 %ld1 to i32
// mul0 = mul %sext0, %sext1
// ld2 = load i16
// sext2 = sext i16 %ld2 to i32
// ld3 = load i16
// sext3 = sext i16 %ld3 to i32
// mul1 = mul i32 %sext2, %sext3
// add0 = add i32 %mul0, %acc0
// acc1 = add i32 %add0, %mul1
//
// Which can be selected to:
//
// ldr r0
// ldr r1
// smlad r2, r0, r1, r2
//
// If constants are used instead of loads, these will need to be hoisted
// out and into a register.
//
// If loop invariants are used instead of loads, these need to be packed
// before the loop begins.
//
bool ARMParallelDSP::MatchSMLAD(Loop *L) {
  // Search recursively back through the operands to find a tree of values that
  // form a multiply-accumulate chain. The search records the Add and Mul
  // instructions that form the reduction and allows us to find a single value
  // to be used as the initial input to the accumulator.
  std::function<bool(Value*, Reduction&)> Search = [&]
    (Value *V, Reduction &R) -> bool {

    // If we find a non-instruction, try to use it as the initial accumulator
    // value. This may have already been found during the search in which case
    // this function will return false, signaling a search fail.
    auto *I = dyn_cast<Instruction>(V);
    if (!I)
      return R.InsertAcc(V);

    switch (I->getOpcode()) {
    default:
      break;
    case Instruction::PHI:
      // Could be the accumulator value.
      return R.InsertAcc(V);
    case Instruction::Add: {
      // Adds should be adding together two muls, or another add and a mul to
      // be within the mac chain. One of the operands may also be the
      // accumulator value at which point we should stop searching.
      bool ValidLHS = Search(I->getOperand(0), R);
      bool ValidRHS = Search(I->getOperand(1), R);
      if (!ValidLHS && !ValidRHS)
        return false;
      else if (ValidLHS && ValidRHS) {
        R.InsertAdd(I);
        return true;
      } else {
        R.InsertAdd(I);
        return R.InsertAcc(I);
      }
    }
    case Instruction::Mul: {
      Value *MulOp0 = I->getOperand(0);
      Value *MulOp1 = I->getOperand(1);
      if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
        ValueList LHS;
        ValueList RHS;
        if (IsNarrowSequence<16>(MulOp0, LHS) &&
            IsNarrowSequence<16>(MulOp1, RHS)) {
          R.InsertMul(I, LHS, RHS);
          return true;
        }
      }
      return false;
    }
    case Instruction::SExt:
      return Search(I->getOperand(0), R);
    }
    return false;
  };

  bool Changed = false;
  SmallPtrSet<Instruction*, 4> AllAdds;
  BasicBlock *Latch = L->getLoopLatch();

  for (Instruction &I : reverse(*Latch)) {
    if (I.getOpcode() != Instruction::Add)
      continue;

    if (AllAdds.count(&I))
      continue;

    const auto *Ty = I.getType();
    if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
      continue;

    Reduction R(&I);
    if (!Search(&I, R))
      continue;

    if (!CreateParallelPairs(R))
      continue;

    InsertParallelMACs(R);
    Changed = true;
    AllAdds.insert(R.getAdds().begin(), R.getAdds().end());
  }

  return Changed;
}

bool ARMParallelDSP::CreateParallelPairs(Reduction &R) {

  // Not enough mul operations to make a pair.
  if (R.getMuls().size() < 2)
    return false;

  // Check that the muls operate directly upon sign extended loads.
  for (auto &MulChain : R.getMuls()) {
    // A mul has 2 operands, and a narrow op consists of a sext and a load;
    // thus we expect at least 4 items in this operand value list.
    if (MulChain->size() < 4) {
      LLVM_DEBUG(dbgs() << "Operand list too short.\n");
      return false;
    }
    MulChain->PopulateLoads();
    ValueList &LHS = static_cast<BinOpChain*>(MulChain.get())->LHS;
    ValueList &RHS = static_cast<BinOpChain*>(MulChain.get())->RHS;

    // Use +=2 to skip over the expected extend instructions.
    for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
      if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
        return false;
    }
  }

  auto CanPair = [&](Reduction &R, BinOpChain *PMul0, BinOpChain *PMul1) {
    if (!PMul0->AreSymmetrical(PMul1))
      return false;

    // The first elements of each vector should be loads with sexts. If we
    // find that they are two pairs of consecutive loads, then these can be
    // transformed into two wider loads and the users can be replaced with
    // DSP intrinsics.
    for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
      auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
      auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
      auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
      auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);

      if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
        return false;

      LLVM_DEBUG(dbgs() << "Loads:\n"
                 << " - " << *Ld0 << "\n"
                 << " - " << *Ld1 << "\n"
                 << " - " << *Ld2 << "\n"
                 << " - " << *Ld3 << "\n");

      if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
        if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          R.AddMulPair(PMul0, PMul1);
          return true;
        } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
          LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
          LLVM_DEBUG(dbgs() << "    exchanging Ld2 and Ld3\n");
          PMul1->Exchange = true;
          R.AddMulPair(PMul0, PMul1);
          return true;
        }
      } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
                 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
        LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
        LLVM_DEBUG(dbgs() << "    exchanging Ld0 and Ld1\n");
        LLVM_DEBUG(dbgs() << "    and swapping muls\n");
        PMul0->Exchange = true;
        // Only the second operand can be exchanged, so swap the muls.
        R.AddMulPair(PMul1, PMul0);
        return true;
      }
    }
    return false;
  };

  OpChainList &Muls = R.getMuls();
  const unsigned Elems = Muls.size();
  SmallPtrSet<const Instruction*, 4> Paired;
  for (unsigned i = 0; i < Elems; ++i) {
    BinOpChain *PMul0 = static_cast<BinOpChain*>(Muls[i].get());
    if (Paired.count(PMul0->Root))
      continue;

    for (unsigned j = 0; j < Elems; ++j) {
      if (i == j)
        continue;

      BinOpChain *PMul1 = static_cast<BinOpChain*>(Muls[j].get());
      if (Paired.count(PMul1->Root))
        continue;

      const Instruction *Mul0 = PMul0->Root;
      const Instruction *Mul1 = PMul1->Root;
      if (Mul0 == Mul1)
        continue;

      assert(PMul0 != PMul1 && "expected different chains");

      if (CanPair(R, PMul0, PMul1)) {
        Paired.insert(Mul0);
        Paired.insert(Mul1);
        break;
      }
    }
  }
  return !R.getMulPairs().empty();
}


void ARMParallelDSP::InsertParallelMACs(Reduction &R) {

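  // For each paired multiplication, replace the narrow loads with a single
  // wide load per operand and emit a call to the matching DSP intrinsic:
  // smlad/smladx for an i32 accumulator, smlald/smlaldx for an i64 one. With
  // a 32-bit accumulator the generated IR is, roughly:
  //   %acc.next = call i32 @llvm.arm.smlad(i32 %wide0, i32 %wide1, i32 %acc)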
  auto CreateSMLADCall = [&](SmallVectorImpl<LoadInst*> &VecLd0,
                             SmallVectorImpl<LoadInst*> &VecLd1,
                             Value *Acc, bool Exchange,
                             Instruction *InsertAfter) {
    // Replace the reduction chain with an intrinsic call.
    IntegerType *Ty = IntegerType::get(M->getContext(), 32);
    LoadInst *WideLd0 = WideLoads.count(VecLd0[0]) ?
      WideLoads[VecLd0[0]]->getLoad() : CreateWideLoad(VecLd0, Ty);
    LoadInst *WideLd1 = WideLoads.count(VecLd1[0]) ?
      WideLoads[VecLd1[0]]->getLoad() : CreateWideLoad(VecLd1, Ty);

    Value* Args[] = { WideLd0, WideLd1, Acc };
    Function *SMLAD = nullptr;
    if (Exchange)
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
    else
      SMLAD = Acc->getType()->isIntegerTy(32) ?
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
        Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);

    IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
                                ++BasicBlock::iterator(InsertAfter));
    Instruction *Call = Builder.CreateCall(SMLAD, Args);
    NumSMLAD++;
    return Call;
  };

  Instruction *InsertAfter = R.getRoot();
  Value *Acc = R.getAccumulator();
  if (!Acc)
    Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0);

  LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n"
                    << "Acc: " << *Acc << "\n");
  for (auto &Pair : R.getMulPairs()) {
    BinOpChain *PMul0 = Pair.first;
    BinOpChain *PMul1 = Pair.second;
    LLVM_DEBUG(dbgs() << "Muls:\n"
               << "- " << *PMul0->Root << "\n"
               << "- " << *PMul1->Root << "\n");

    Acc = CreateSMLADCall(PMul0->VecLd, PMul1->VecLd, Acc, PMul1->Exchange,
                          InsertAfter);
    InsertAfter = cast<Instruction>(Acc);
  }
  R.UpdateRoot(cast<Instruction>(Acc));
}

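// Widen two sequential i16 loads into a single i32 load and extract the two
// halfwords for the existing sext users. The result is, roughly:
//   %cast   = bitcast i16* %base.addr to i32*
//   %wide   = load i32, i32* %cast, align 2
//   %bottom = trunc i32 %wide to i16    ; replaces the base load
//   %shift  = lshr i32 %wide, 16
//   %top    = trunc i32 %shift to i16   ; replaces the offset load
// The original alignment is preserved on the wide load.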
LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads,
                                         IntegerType *LoadTy) {
  assert(Loads.size() == 2 && "currently only support widening two loads");

  LoadInst *Base = Loads[0];
  LoadInst *Offset = Loads[1];

  Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back());
  Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back());

  assert((BaseSExt && OffsetSExt)
         && "Loads should have a single, extending, user");

  std::function<void(Value*, Value*)> MoveBefore =
    [&](Value *A, Value *B) -> void {
      if (!isa<Instruction>(A) || !isa<Instruction>(B))
        return;

      auto *Source = cast<Instruction>(A);
      auto *Sink = cast<Instruction>(B);

      if (DT->dominates(Source, Sink) ||
          Source->getParent() != Sink->getParent() ||
          isa<PHINode>(Source) || isa<PHINode>(Sink))
        return;

      Source->moveBefore(Sink);
      for (auto &Op : Source->operands())
        MoveBefore(Op, Source);
    };

  // Insert the load at the point of the original dominating load.
  LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset;
  IRBuilder<NoFolder> IRB(DomLoad->getParent(),
                          ++BasicBlock::iterator(DomLoad));

  // Bitcast the pointer to a wider type and create the wide load, while making
  // sure to maintain the original alignment as this prevents ldrd from being
  // generated when it could be illegal due to memory alignment.
  const unsigned AddrSpace = DomLoad->getPointerAddressSpace();
  Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(),
                                    LoadTy->getPointerTo(AddrSpace));
  LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr,
                                             Base->getAlignment());

  // Make sure everything is in the correct order in the basic block.
  MoveBefore(Base->getPointerOperand(), VecPtr);
  MoveBefore(VecPtr, WideLoad);

  // From the wide load, create two values that equal the original two loads.
  // Loads[0] needs trunc while Loads[1] needs a lshr and trunc.
  // TODO: Support big-endian as well.
  Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType());
  BaseSExt->setOperand(0, Bottom);

  IntegerType *OffsetTy = cast<IntegerType>(Offset->getType());
  Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth());
  Value *Top = IRB.CreateLShr(WideLoad, ShiftVal);
  Value *Trunc = IRB.CreateTrunc(Top, OffsetTy);
  OffsetSExt->setOperand(0, Trunc);

  WideLoads.emplace(std::make_pair(Base,
                                   make_unique<WidenedLoad>(Loads, WideLoad)));
  return WideLoad;
}

// Compare the value lists in Other to this chain.
bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
  // Element-by-element comparison of Value lists returning true if they are
  // instructions with the same opcode or constants with the same value.
  auto CompareValueList = [](const ValueList &VL0,
                             const ValueList &VL1) {
    if (VL0.size() != VL1.size()) {
      LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
                        << VL0.size() << " != " << VL1.size() << "\n");
      return false;
    }

    const unsigned Pairs = VL0.size();

    for (unsigned i = 0; i < Pairs; ++i) {
      const Value *V0 = VL0[i];
      const Value *V1 = VL1[i];
      const auto *Inst0 = dyn_cast<Instruction>(V0);
      const auto *Inst1 = dyn_cast<Instruction>(V1);

      if (!Inst0 || !Inst1)
        return false;

      if (Inst0->isSameOperationAs(Inst1))
        continue;

      // Compare the constant values, not the pointers to them.
      const APInt *C0, *C1;
      if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && *C0 == *C1))
        return false;
    }

    return true;
  };

  return CompareValueList(LHS, Other->LHS) &&
         CompareValueList(RHS, Other->RHS);
}

Pass *llvm::createARMParallelDSPPass() {
  return new ARMParallelDSP();
}

char ARMParallelDSP::ID = 0;

INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                      "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                    "Transform loops to use DSP intrinsics", false, false)