blob: c43ab7b7238fdf33bb42939b4e5da792fe5ccb46 [file] [log] [blame]
Sjoerd Meijerc89ca552018-06-28 12:55:29 +00001//===- ParallelDSP.cpp - Parallel DSP Pass --------------------------------===//
2//
3// The LLVM Compiler Infrastructure
4//
5// This file is distributed under the University of Illinois Open Source
6// License. See LICENSE.TXT for details.
7//
8//===----------------------------------------------------------------------===//
9//
10/// \file
/// Armv6 introduced instructions to perform 32-bit SIMD operations. The
/// purpose of this pass is to do some IR pattern matching to create ACLE
/// DSP intrinsics, which map onto these 32-bit SIMD operations.
/// This pass runs only when unaligned accesses are supported/enabled.
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000015//
16//===----------------------------------------------------------------------===//
17
Sjoerd Meijerb3e06fa2018-07-06 14:47:09 +000018#include "llvm/ADT/Statistic.h"
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000019#include "llvm/ADT/SmallPtrSet.h"
20#include "llvm/Analysis/AliasAnalysis.h"
21#include "llvm/Analysis/LoopAccessAnalysis.h"
22#include "llvm/Analysis/LoopPass.h"
23#include "llvm/Analysis/LoopInfo.h"
24#include "llvm/IR/Instructions.h"
25#include "llvm/IR/NoFolder.h"
26#include "llvm/Transforms/Scalar.h"
27#include "llvm/Transforms/Utils/BasicBlockUtils.h"
28#include "llvm/Transforms/Utils/LoopUtils.h"
29#include "llvm/Pass.h"
30#include "llvm/PassRegistry.h"
31#include "llvm/PassSupport.h"
32#include "llvm/Support/Debug.h"
33#include "llvm/IR/PatternMatch.h"
34#include "llvm/CodeGen/TargetPassConfig.h"
35#include "ARM.h"
36#include "ARMSubtarget.h"
37
38using namespace llvm;
39using namespace PatternMatch;
40
// Debug category used by LLVM_DEBUG and STATISTIC in this file.
#define DEBUG_TYPE "arm-parallel-dsp"

// Incremented once for every SMLAD/SMLADX/SMLALD/SMLALDX intrinsic call that
// CreateSMLADCall emits.
STATISTIC(NumSMLAD , "Number of smlad instructions generated");

// When set, runOnLoop returns false immediately and leaves the IR untouched.
static cl::opt<bool>
DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false),
                   cl::desc("Disable the ARM Parallel DSP pass"));
48
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000049namespace {
Sam Parker89a37992018-07-23 15:25:59 +000050 struct OpChain;
51 struct BinOpChain;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000052 struct Reduction;
53
Fangrui Song58407ca2018-07-23 17:43:21 +000054 using OpChainList = SmallVector<std::unique_ptr<OpChain>, 8>;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000055 using ReductionList = SmallVector<Reduction, 8>;
56 using ValueList = SmallVector<Value*, 8>;
Sam Parkerffc16812018-07-03 12:44:16 +000057 using MemInstList = SmallVector<Instruction*, 8>;
Sam Parker89a37992018-07-23 15:25:59 +000058 using PMACPair = std::pair<BinOpChain*,BinOpChain*>;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000059 using PMACPairList = SmallVector<PMACPair, 8>;
60 using Instructions = SmallVector<Instruction*,16>;
61 using MemLocList = SmallVector<MemoryLocation, 4>;
62
Sam Parker89a37992018-07-23 15:25:59 +000063 struct OpChain {
64 Instruction *Root;
65 ValueList AllValues;
Eli Friedmanb09c7782018-10-18 19:34:30 +000066 MemInstList VecLd; // List of all load instructions.
Sam Parker89a37992018-07-23 15:25:59 +000067 MemLocList MemLocs; // All memory locations read by this tree.
68 bool ReadOnly = true;
69
70 OpChain(Instruction *I, ValueList &vl) : Root(I), AllValues(vl) { }
Jordan Rupprechte5daf612018-07-23 17:38:05 +000071 virtual ~OpChain() = default;
Sam Parker89a37992018-07-23 15:25:59 +000072
73 void SetMemoryLocations() {
George Burgess IV6ef80022018-10-10 21:28:44 +000074 const auto Size = LocationSize::unknown();
Sam Parker89a37992018-07-23 15:25:59 +000075 for (auto *V : AllValues) {
76 if (auto *I = dyn_cast<Instruction>(V)) {
77 if (I->mayWriteToMemory())
78 ReadOnly = false;
Eli Friedmanb09c7782018-10-18 19:34:30 +000079 if (auto *Ld = dyn_cast<LoadInst>(V))
Sam Parker89a37992018-07-23 15:25:59 +000080 MemLocs.push_back(MemoryLocation(Ld->getPointerOperand(), Size));
81 }
82 }
83 }
84
85 unsigned size() const { return AllValues.size(); }
86 };
87
88 // 'BinOpChain' and 'Reduction' are just some bookkeeping data structures.
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000089 // 'Reduction' contains the phi-node and accumulator statement from where we
Sam Parker89a37992018-07-23 15:25:59 +000090 // start pattern matching, and 'BinOpChain' the multiplication
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000091 // instructions that are candidates for parallel execution.
Sam Parker89a37992018-07-23 15:25:59 +000092 struct BinOpChain : public OpChain {
93 ValueList LHS; // List of all (narrow) left hand operands.
94 ValueList RHS; // List of all (narrow) right hand operands.
Sam Parkera023c7a2018-09-12 09:17:44 +000095 bool Exchange = false;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +000096
Sam Parker89a37992018-07-23 15:25:59 +000097 BinOpChain(Instruction *I, ValueList &lhs, ValueList &rhs) :
98 OpChain(I, lhs), LHS(lhs), RHS(rhs) {
99 for (auto *V : RHS)
100 AllValues.push_back(V);
101 }
Sam Parker453ba912018-11-09 09:18:00 +0000102
103 bool AreSymmetrical(BinOpChain *Other);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000104 };
105
106 struct Reduction {
107 PHINode *Phi; // The Phi-node from where we start
108 // pattern matching.
109 Instruction *AccIntAdd; // The accumulating integer add statement,
110 // i.e, the reduction statement.
Sam Parker89a37992018-07-23 15:25:59 +0000111 OpChainList MACCandidates; // The MAC candidates associated with
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000112 // this reduction statement.
Sam Parker453ba912018-11-09 09:18:00 +0000113 PMACPairList PMACPairs;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000114 Reduction (PHINode *P, Instruction *Acc) : Phi(P), AccIntAdd(Acc) { };
115 };
116
117 class ARMParallelDSP : public LoopPass {
118 ScalarEvolution *SE;
119 AliasAnalysis *AA;
120 TargetLibraryInfo *TLI;
121 DominatorTree *DT;
122 LoopInfo *LI;
123 Loop *L;
124 const DataLayout *DL;
125 Module *M;
Sam Parker453ba912018-11-09 09:18:00 +0000126 std::map<LoadInst*, LoadInst*> LoadPairs;
127 std::map<LoadInst*, SmallVector<LoadInst*, 4>> SequentialLoads;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000128
Sam Parker453ba912018-11-09 09:18:00 +0000129 bool RecordSequentialLoads(BasicBlock *Header);
130 bool InsertParallelMACs(Reduction &Reduction);
Fangrui Song68169342018-07-03 19:12:27 +0000131 bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem);
Sam Parker453ba912018-11-09 09:18:00 +0000132 void CreateParallelMACPairs(Reduction &R);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000133 Instruction *CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
Sam Parkera023c7a2018-09-12 09:17:44 +0000134 Instruction *Acc, bool Exchange,
135 Instruction *InsertAfter);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000136
137 /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate
138 /// Dual performs two signed 16x16-bit multiplications. It adds the
139 /// products to a 32-bit accumulate operand. Optionally, the instruction can
140 /// exchange the halfwords of the second operand before performing the
141 /// arithmetic.
142 bool MatchSMLAD(Function &F);
143
144 public:
145 static char ID;
146
147 ARMParallelDSP() : LoopPass(ID) { }
148
149 void getAnalysisUsage(AnalysisUsage &AU) const override {
150 LoopPass::getAnalysisUsage(AU);
151 AU.addRequired<AssumptionCacheTracker>();
152 AU.addRequired<ScalarEvolutionWrapperPass>();
153 AU.addRequired<AAResultsWrapperPass>();
154 AU.addRequired<TargetLibraryInfoWrapperPass>();
155 AU.addRequired<LoopInfoWrapperPass>();
156 AU.addRequired<DominatorTreeWrapperPass>();
157 AU.addRequired<TargetPassConfig>();
158 AU.addPreserved<LoopInfoWrapperPass>();
159 AU.setPreservesCFG();
160 }
161
162 bool runOnLoop(Loop *TheLoop, LPPassManager &) override {
Sjoerd Meijer3c859b32018-08-14 07:43:49 +0000163 if (DisableParallelDSP)
164 return false;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000165 L = TheLoop;
166 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
167 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
168 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
169 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
171 auto &TPC = getAnalysis<TargetPassConfig>();
172
173 BasicBlock *Header = TheLoop->getHeader();
174 if (!Header)
175 return false;
176
177 // TODO: We assume the loop header and latch to be the same block.
178 // This is not a fundamental restriction, but lifting this would just
179 // require more work to do the transformation and then patch up the CFG.
180 if (Header != TheLoop->getLoopLatch()) {
181 LLVM_DEBUG(dbgs() << "The loop header is not the loop latch: not "
182 "running pass ARMParallelDSP\n");
183 return false;
184 }
185
186 Function &F = *Header->getParent();
187 M = F.getParent();
188 DL = &M->getDataLayout();
189
190 auto &TM = TPC.getTM<TargetMachine>();
191 auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
192
193 if (!ST->allowsUnalignedMem()) {
194 LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not "
195 "running pass ARMParallelDSP\n");
196 return false;
197 }
198
199 if (!ST->hasDSP()) {
200 LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass "
201 "ARMParallelDSP\n");
202 return false;
203 }
204
205 LoopAccessInfo LAI(L, SE, TLI, AA, DT, LI);
206 bool Changes = false;
207
Sam Parkera023c7a2018-09-12 09:17:44 +0000208 LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n");
209 LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n");
Sam Parker453ba912018-11-09 09:18:00 +0000210
211 if (!RecordSequentialLoads(Header)) {
212 LLVM_DEBUG(dbgs() << " - No sequential loads found.\n");
213 return false;
214 }
215
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000216 Changes = MatchSMLAD(F);
217 return Changes;
218 }
219 };
220}
221
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000222// MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
223// instructions, which is set to 16. So here we should collect all i8 and i16
224// narrow operations.
225// TODO: we currently only collect i16, and will support i8 later, so that's
226// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
227template<unsigned MaxBitWidth>
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000228static bool IsNarrowSequence(Value *V, ValueList &VL) {
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000229 LLVM_DEBUG(dbgs() << "Is narrow sequence? "; V->dump());
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000230 ConstantInt *CInt;
231
232 if (match(V, m_ConstantInt(CInt))) {
233 // TODO: if a constant is used, it needs to fit within the bit width.
234 return false;
235 }
236
237 auto *I = dyn_cast<Instruction>(V);
238 if (!I)
239 return false;
240
241 Value *Val, *LHS, *RHS;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000242 if (match(V, m_Trunc(m_Value(Val)))) {
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000243 if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
244 return IsNarrowSequence<MaxBitWidth>(Val, VL);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000245 } else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
246 // TODO: we need to implement sadd16/sadd8 for this, which enables to
247 // also do the rewrite for smlad8.ll, but it is unsupported for now.
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000248 LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
249 return false;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000250 } else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000251 if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) {
252 LLVM_DEBUG(dbgs() << "No, wrong SrcTy size: " <<
253 cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() << "\n");
254 return false;
255 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000256
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000257 if (match(Val, m_Load(m_Value()))) {
258 LLVM_DEBUG(dbgs() << "Yes, found narrow Load:\t"; Val->dump());
259 VL.push_back(Val);
260 VL.push_back(I);
261 return true;
262 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000263 }
Sjoerd Meijer27be58b2018-07-05 08:21:40 +0000264 LLVM_DEBUG(dbgs() << "No, unsupported Op:\t"; I->dump());
265 return false;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000266}
267
Sam Parkerffc16812018-07-03 12:44:16 +0000268template<typename MemInst>
269static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
Sam Parker453ba912018-11-09 09:18:00 +0000270 const DataLayout &DL, ScalarEvolution &SE) {
Sam Parkerffc16812018-07-03 12:44:16 +0000271 if (!MemOp0->isSimple() || !MemOp1->isSimple()) {
272 LLVM_DEBUG(dbgs() << "No, not touching volatile access\n");
273 return false;
274 }
275 if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) {
Sam Parkerffc16812018-07-03 12:44:16 +0000276 LLVM_DEBUG(dbgs() << "OK: accesses are consecutive.\n");
277 return true;
278 }
279 LLVM_DEBUG(dbgs() << "No, accesses aren't consecutive.\n");
280 return false;
281}
282
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000283bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1,
Sam Parkerffc16812018-07-03 12:44:16 +0000284 MemInstList &VecMem) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000285 if (!Ld0 || !Ld1)
286 return false;
287
288 LLVM_DEBUG(dbgs() << "Are consecutive loads:\n";
289 dbgs() << "Ld0:"; Ld0->dump();
290 dbgs() << "Ld1:"; Ld1->dump();
291 );
292
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000293 if (!Ld0->hasOneUse() || !Ld1->hasOneUse()) {
294 LLVM_DEBUG(dbgs() << "No, load has more than one use.\n");
295 return false;
296 }
Sam Parkerffc16812018-07-03 12:44:16 +0000297
Sam Parker453ba912018-11-09 09:18:00 +0000298 if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1)
299 return false;
300
301 VecMem.clear();
302 VecMem.push_back(Ld0);
303 VecMem.push_back(Ld1);
304 return true;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000305}
306
Sam Parker453ba912018-11-09 09:18:00 +0000307/// Iterate through the block and record base, offset pairs of loads as well as
308/// maximal sequences of sequential loads.
309bool ARMParallelDSP::RecordSequentialLoads(BasicBlock *Header) {
310 SmallVector<LoadInst*, 8> Loads;
311 for (auto &I : *Header) {
312 auto *Ld = dyn_cast<LoadInst>(&I);
313 if (!Ld)
314 continue;
315 Loads.push_back(Ld);
316 }
317
318 std::map<LoadInst*, LoadInst*> BaseLoads;
319
320 for (auto *Ld0 : Loads) {
321 for (auto *Ld1 : Loads) {
322 if (Ld0 == Ld1)
323 continue;
324
325 if (AreSequentialAccesses<LoadInst>(Ld0, Ld1, *DL, *SE)) {
326 LoadPairs[Ld0] = Ld1;
327 if (BaseLoads.count(Ld0)) {
328 LoadInst *Base = BaseLoads[Ld0];
329 BaseLoads[Ld1] = Base;
330 SequentialLoads[Base].push_back(Ld1);
331 } else {
332 BaseLoads[Ld1] = Ld0;
333 SequentialLoads[Ld0].push_back(Ld1);
334 }
335 }
336 }
337 }
338 return LoadPairs.size() > 1;
339}
340
341void ARMParallelDSP::CreateParallelMACPairs(Reduction &R) {
342 OpChainList &Candidates = R.MACCandidates;
343 PMACPairList &PMACPairs = R.PMACPairs;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000344 const unsigned Elems = Candidates.size();
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000345
346 if (Elems < 2)
Sam Parker453ba912018-11-09 09:18:00 +0000347 return;
348
349 auto CanPair = [&](BinOpChain *PMul0, BinOpChain *PMul1) {
350 if (!PMul0->AreSymmetrical(PMul1))
351 return false;
352
353 // The first elements of each vector should be loads with sexts. If we
354 // find that its two pairs of consecutive loads, then these can be
355 // transformed into two wider loads and the users can be replaced with
356 // DSP intrinsics.
357 for (unsigned x = 0; x < PMul0->LHS.size(); x += 2) {
358 auto *Ld0 = dyn_cast<LoadInst>(PMul0->LHS[x]);
359 auto *Ld1 = dyn_cast<LoadInst>(PMul1->LHS[x]);
360 auto *Ld2 = dyn_cast<LoadInst>(PMul0->RHS[x]);
361 auto *Ld3 = dyn_cast<LoadInst>(PMul1->RHS[x]);
362
363 if (!Ld0 || !Ld1 || !Ld2 || !Ld3)
364 return false;
365
366 LLVM_DEBUG(dbgs() << "Looking at operands " << x << ":\n"
367 << "\t Ld0: " << *Ld0 << "\n"
368 << "\t Ld1: " << *Ld1 << "\n"
369 << "and operands " << x + 2 << ":\n"
370 << "\t Ld2: " << *Ld2 << "\n"
371 << "\t Ld3: " << *Ld3 << "\n");
372
373 if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) {
374 if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
375 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
376 PMACPairs.push_back(std::make_pair(PMul0, PMul1));
377 return true;
378 } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) {
379 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
380 LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n");
381 PMul1->Exchange = true;
382 PMACPairs.push_back(std::make_pair(PMul0, PMul1));
383 return true;
384 }
385 } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) &&
386 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) {
387 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n");
388 LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n");
389 LLVM_DEBUG(dbgs() << " and swapping muls\n");
390 PMul0->Exchange = true;
391 // Only the second operand can be exchanged, so swap the muls.
392 PMACPairs.push_back(std::make_pair(PMul1, PMul0));
393 return true;
394 }
395 }
396 return false;
397 };
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000398
Sam Parkera023c7a2018-09-12 09:17:44 +0000399 SmallPtrSet<const Instruction*, 4> Paired;
400 for (unsigned i = 0; i < Elems; ++i) {
Fangrui Song58407ca2018-07-23 17:43:21 +0000401 BinOpChain *PMul0 = static_cast<BinOpChain*>(Candidates[i].get());
Sam Parkera023c7a2018-09-12 09:17:44 +0000402 if (Paired.count(PMul0->Root))
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000403 continue;
404
Sam Parkera023c7a2018-09-12 09:17:44 +0000405 for (unsigned j = 0; j < Elems; ++j) {
406 if (i == j)
407 continue;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000408
Sam Parkera023c7a2018-09-12 09:17:44 +0000409 BinOpChain *PMul1 = static_cast<BinOpChain*>(Candidates[j].get());
410 if (Paired.count(PMul1->Root))
411 continue;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000412
Sam Parkera023c7a2018-09-12 09:17:44 +0000413 const Instruction *Mul0 = PMul0->Root;
414 const Instruction *Mul1 = PMul1->Root;
415 if (Mul0 == Mul1)
416 continue;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000417
Sam Parkera023c7a2018-09-12 09:17:44 +0000418 assert(PMul0 != PMul1 && "expected different chains");
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000419
Sam Parkera023c7a2018-09-12 09:17:44 +0000420 LLVM_DEBUG(dbgs() << "\nCheck parallel muls:\n";
421 dbgs() << "- "; Mul0->dump();
422 dbgs() << "- "; Mul1->dump());
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000423
Sam Parkera023c7a2018-09-12 09:17:44 +0000424 LLVM_DEBUG(dbgs() << "OK: mul operands list match:\n");
Sam Parker453ba912018-11-09 09:18:00 +0000425 if (CanPair(PMul0, PMul1)) {
Sam Parkera023c7a2018-09-12 09:17:44 +0000426 Paired.insert(Mul0);
427 Paired.insert(Mul1);
428 break;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000429 }
430 }
431 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000432}
433
Sam Parker453ba912018-11-09 09:18:00 +0000434bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000435 Instruction *Acc = Reduction.Phi;
436 Instruction *InsertAfter = Reduction.AccIntAdd;
437
Sam Parker453ba912018-11-09 09:18:00 +0000438 for (auto &Pair : Reduction.PMACPairs) {
Sam Parkera023c7a2018-09-12 09:17:44 +0000439 BinOpChain *PMul0 = Pair.first;
440 BinOpChain *PMul1 = Pair.second;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000441 LLVM_DEBUG(dbgs() << "Found parallel MACs!!\n";
Sam Parkera023c7a2018-09-12 09:17:44 +0000442 dbgs() << "- "; PMul0->Root->dump();
443 dbgs() << "- "; PMul1->Root->dump());
444
445 auto *VecLd0 = cast<LoadInst>(PMul0->VecLd[0]);
446 auto *VecLd1 = cast<LoadInst>(PMul1->VecLd[0]);
447 Acc = CreateSMLADCall(VecLd0, VecLd1, Acc, PMul1->Exchange, InsertAfter);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000448 InsertAfter = Acc;
449 }
450
451 if (Acc != Reduction.Phi) {
452 LLVM_DEBUG(dbgs() << "Replace Accumulate: "; Acc->dump());
453 Reduction.AccIntAdd->replaceAllUsesWith(Acc);
454 return true;
455 }
456 return false;
457}
458
Sam Parker89a37992018-07-23 15:25:59 +0000459static void MatchReductions(Function &F, Loop *TheLoop, BasicBlock *Header,
460 ReductionList &Reductions) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000461 RecurrenceDescriptor RecDesc;
462 const bool HasFnNoNaNAttr =
463 F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true";
464 const BasicBlock *Latch = TheLoop->getLoopLatch();
465
466 // We need a preheader as getIncomingValueForBlock assumes there is one.
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000467 if (!TheLoop->getLoopPreheader()) {
468 LLVM_DEBUG(dbgs() << "No preheader found, bailing out\n");
Sam Parker89a37992018-07-23 15:25:59 +0000469 return;
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000470 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000471
472 for (PHINode &Phi : Header->phis()) {
473 const auto *Ty = Phi.getType();
Sam Parker01db2982018-09-11 14:01:22 +0000474 if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64))
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000475 continue;
476
477 const bool IsReduction =
478 RecurrenceDescriptor::AddReductionVar(&Phi,
479 RecurrenceDescriptor::RK_IntegerAdd,
480 TheLoop, HasFnNoNaNAttr, RecDesc);
481 if (!IsReduction)
482 continue;
483
484 Instruction *Acc = dyn_cast<Instruction>(Phi.getIncomingValueForBlock(Latch));
485 if (!Acc)
486 continue;
487
488 Reductions.push_back(Reduction(&Phi, Acc));
489 }
490
491 LLVM_DEBUG(
492 dbgs() << "\nAccumulating integer additions (reductions) found:\n";
Sam Parker89a37992018-07-23 15:25:59 +0000493 for (auto &R : Reductions) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000494 dbgs() << "- "; R.Phi->dump();
495 dbgs() << "-> "; R.AccIntAdd->dump();
496 }
497 );
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000498}
499
Eli Friedmanb09c7782018-10-18 19:34:30 +0000500static void AddMACCandidate(OpChainList &Candidates,
Sam Parker01db2982018-09-11 14:01:22 +0000501 Instruction *Mul,
502 Value *MulOp0, Value *MulOp1) {
Eli Friedmanb09c7782018-10-18 19:34:30 +0000503 LLVM_DEBUG(dbgs() << "OK, found acc mul:\t"; Mul->dump());
Sam Parker01db2982018-09-11 14:01:22 +0000504 assert(Mul->getOpcode() == Instruction::Mul &&
505 "expected mul instruction");
Sam Parker89a37992018-07-23 15:25:59 +0000506 ValueList LHS;
507 ValueList RHS;
508 if (IsNarrowSequence<16>(MulOp0, LHS) &&
509 IsNarrowSequence<16>(MulOp1, RHS)) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000510 LLVM_DEBUG(dbgs() << "OK, found narrow mul: "; Mul->dump());
Fangrui Song58407ca2018-07-23 17:43:21 +0000511 Candidates.push_back(make_unique<BinOpChain>(Mul, LHS, RHS));
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000512 }
513}
514
Sam Parker89a37992018-07-23 15:25:59 +0000515static void MatchParallelMACSequences(Reduction &R,
516 OpChainList &Candidates) {
Sam Parkera023c7a2018-09-12 09:17:44 +0000517 Instruction *Acc = R.AccIntAdd;
518 LLVM_DEBUG(dbgs() << "\n- Analysing:\t" << *Acc);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000519
Sam Parkera023c7a2018-09-12 09:17:44 +0000520 // Returns false to signal the search should be stopped.
521 std::function<bool(Value*)> Match =
522 [&Candidates, &Match](Value *V) -> bool {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000523
Sam Parkera023c7a2018-09-12 09:17:44 +0000524 auto *I = dyn_cast<Instruction>(V);
Sam Parker11879112018-09-12 09:58:56 +0000525 if (!I)
Sam Parkera023c7a2018-09-12 09:17:44 +0000526 return false;
Sam Parker01db2982018-09-11 14:01:22 +0000527
Sam Parkera023c7a2018-09-12 09:17:44 +0000528 Value *MulOp0, *MulOp1;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000529
Sam Parkera023c7a2018-09-12 09:17:44 +0000530 switch (I->getOpcode()) {
531 case Instruction::Add:
532 if (Match(I->getOperand(0)) || (Match(I->getOperand(1))))
533 return true;
534 break;
535 case Instruction::Mul:
536 if (match (I, (m_Mul(m_Value(MulOp0), m_Value(MulOp1))))) {
Eli Friedmanb09c7782018-10-18 19:34:30 +0000537 AddMACCandidate(Candidates, I, MulOp0, MulOp1);
Sam Parkera023c7a2018-09-12 09:17:44 +0000538 return false;
539 }
540 break;
541 case Instruction::SExt:
542 if (match (I, (m_SExt(m_Mul(m_Value(MulOp0), m_Value(MulOp1)))))) {
543 Instruction *Mul = cast<Instruction>(I->getOperand(0));
Eli Friedmanb09c7782018-10-18 19:34:30 +0000544 AddMACCandidate(Candidates, Mul, MulOp0, MulOp1);
Sam Parkera023c7a2018-09-12 09:17:44 +0000545 return false;
546 }
547 break;
548 }
549 return false;
550 };
551
552 while (Match (Acc));
553 LLVM_DEBUG(dbgs() << "Finished matching MAC sequences, found "
554 << Candidates.size() << " candidates.\n");
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000555}
556
557// Collects all instructions that are not part of the MAC chains, which is the
558// set of instructions that can potentially alias with the MAC operands.
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000559static void AliasCandidates(BasicBlock *Header, Instructions &Reads,
560 Instructions &Writes) {
561 for (auto &I : *Header) {
562 if (I.mayReadFromMemory())
563 Reads.push_back(&I);
564 if (I.mayWriteToMemory())
565 Writes.push_back(&I);
566 }
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000567}
568
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000569// Check whether statements in the basic block that write to memory alias with
570// the memory locations accessed by the MAC-chains.
571// TODO: we need the read statements when we accept more complicated chains.
572static bool AreAliased(AliasAnalysis *AA, Instructions &Reads,
Eli Friedmanb09c7782018-10-18 19:34:30 +0000573 Instructions &Writes, OpChainList &MACCandidates) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000574 LLVM_DEBUG(dbgs() << "Alias checks:\n");
Eli Friedmanb09c7782018-10-18 19:34:30 +0000575 for (auto &MAC : MACCandidates) {
576 LLVM_DEBUG(dbgs() << "mul: "; MAC->Root->dump());
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000577
578 // At the moment, we allow only simple chains that only consist of reads,
579 // accumulate their result with an integer add, and thus that don't write
580 // memory, and simply bail if they do.
Eli Friedmanb09c7782018-10-18 19:34:30 +0000581 if (!MAC->ReadOnly)
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000582 return true;
583
584 // Now for all writes in the basic block, check that they don't alias with
585 // the memory locations accessed by our MAC-chain:
586 for (auto *I : Writes) {
587 LLVM_DEBUG(dbgs() << "- "; I->dump());
Eli Friedmanb09c7782018-10-18 19:34:30 +0000588 assert(MAC->MemLocs.size() >= 2 && "expecting at least 2 memlocs");
589 for (auto &MemLoc : MAC->MemLocs) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000590 if (isModOrRefSet(intersectModRef(AA->getModRefInfo(I, MemLoc),
591 ModRefInfo::ModRef))) {
592 LLVM_DEBUG(dbgs() << "Yes, aliases found\n");
593 return true;
594 }
595 }
596 }
597 }
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000598
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000599 LLVM_DEBUG(dbgs() << "OK: no aliases found!\n");
600 return false;
601}
602
Eli Friedmanb09c7782018-10-18 19:34:30 +0000603static bool CheckMACMemory(OpChainList &Candidates) {
Fangrui Song58407ca2018-07-23 17:43:21 +0000604 for (auto &C : Candidates) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000605 // A mul has 2 operands, and a narrow op consist of sext and a load; thus
606 // we expect at least 4 items in this operand value list.
Sam Parker89a37992018-07-23 15:25:59 +0000607 if (C->size() < 4) {
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000608 LLVM_DEBUG(dbgs() << "Operand list too short.\n");
609 return false;
610 }
Eli Friedmanb09c7782018-10-18 19:34:30 +0000611 C->SetMemoryLocations();
Fangrui Song58407ca2018-07-23 17:43:21 +0000612 ValueList &LHS = static_cast<BinOpChain*>(C.get())->LHS;
613 ValueList &RHS = static_cast<BinOpChain*>(C.get())->RHS;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000614
Sam Parker89a37992018-07-23 15:25:59 +0000615 // Use +=2 to skip over the expected extend instructions.
616 for (unsigned i = 0, e = LHS.size(); i < e; i += 2) {
617 if (!isa<LoadInst>(LHS[i]) || !isa<LoadInst>(RHS[i]))
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000618 return false;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000619 }
620 }
621 return true;
622}
623
624// Loop Pass that needs to identify integer add/sub reductions of 16-bit vector
625// multiplications.
626// To use SMLAD:
627// 1) we first need to find integer add reduction PHIs,
628// 2) then from the PHI, look for this pattern:
629//
630// acc0 = phi i32 [0, %entry], [%acc1, %loop.body]
631// ld0 = load i16
632// sext0 = sext i16 %ld0 to i32
633// ld1 = load i16
634// sext1 = sext i16 %ld1 to i32
635// mul0 = mul %sext0, %sext1
636// ld2 = load i16
637// sext2 = sext i16 %ld2 to i32
638// ld3 = load i16
639// sext3 = sext i16 %ld3 to i32
640// mul1 = mul i32 %sext2, %sext3
641// add0 = add i32 %mul0, %acc0
642// acc1 = add i32 %add0, %mul1
643//
644// Which can be selected to:
645//
646// ldr.h r0
647// ldr.h r1
648// smlad r2, r0, r1, r2
649//
650// If constants are used instead of loads, these will need to be hoisted
651// out and into a register.
652//
653// If loop invariants are used instead of loads, these need to be packed
654// before the loop begins.
655//
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000656bool ARMParallelDSP::MatchSMLAD(Function &F) {
657 BasicBlock *Header = L->getHeader();
658 LLVM_DEBUG(dbgs() << "= Matching SMLAD =\n";
659 dbgs() << "Header block:\n"; Header->dump();
660 dbgs() << "Loop info:\n\n"; L->dump());
661
Eli Friedmanb09c7782018-10-18 19:34:30 +0000662 bool Changed = false;
Sam Parker89a37992018-07-23 15:25:59 +0000663 ReductionList Reductions;
664 MatchReductions(F, L, Header, Reductions);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000665
666 for (auto &R : Reductions) {
Sam Parker89a37992018-07-23 15:25:59 +0000667 OpChainList MACCandidates;
668 MatchParallelMACSequences(R, MACCandidates);
Eli Friedmanb09c7782018-10-18 19:34:30 +0000669 if (!CheckMACMemory(MACCandidates))
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000670 continue;
Sam Parker89a37992018-07-23 15:25:59 +0000671
Fangrui Song58407ca2018-07-23 17:43:21 +0000672 R.MACCandidates = std::move(MACCandidates);
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000673
674 LLVM_DEBUG(dbgs() << "MAC candidates:\n";
675 for (auto &M : R.MACCandidates)
Sam Parker89a37992018-07-23 15:25:59 +0000676 M->Root->dump();
Sjoerd Meijer53449da2018-07-11 12:36:25 +0000677 dbgs() << "\n";);
678 }
679
680 // Collect all instructions that may read or write memory. Our alias
681 // analysis checks bail out if any of these instructions aliases with an
682 // instruction from the MAC-chain.
683 Instructions Reads, Writes;
684 AliasCandidates(Header, Reads, Writes);
685
686 for (auto &R : Reductions) {
687 if (AreAliased(AA, Reads, Writes, R.MACCandidates))
688 return false;
Sam Parker453ba912018-11-09 09:18:00 +0000689 CreateParallelMACPairs(R);
690 Changed |= InsertParallelMACs(R);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000691 }
692
693 LLVM_DEBUG(if (Changed) dbgs() << "Header block:\n"; Header->dump(););
694 return Changed;
695}
696
Eli Friedmanb09c7782018-10-18 19:34:30 +0000697static LoadInst *CreateLoadIns(IRBuilder<NoFolder> &IRB, LoadInst &BaseLoad,
698 const Type *LoadTy) {
699 const unsigned AddrSpace = BaseLoad.getPointerAddressSpace();
700
701 Value *VecPtr = IRB.CreateBitCast(BaseLoad.getPointerOperand(),
702 LoadTy->getPointerTo(AddrSpace));
703 return IRB.CreateAlignedLoad(VecPtr, BaseLoad.getAlignment());
704}
705
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000706Instruction *ARMParallelDSP::CreateSMLADCall(LoadInst *VecLd0, LoadInst *VecLd1,
Sam Parkera023c7a2018-09-12 09:17:44 +0000707 Instruction *Acc, bool Exchange,
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000708 Instruction *InsertAfter) {
Sam Parkera023c7a2018-09-12 09:17:44 +0000709 LLVM_DEBUG(dbgs() << "Create SMLAD intrinsic using:\n"
710 << "- " << *VecLd0 << "\n"
711 << "- " << *VecLd1 << "\n"
712 << "- " << *Acc << "\n"
713 << "Exchange: " << Exchange << "\n");
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000714
715 IRBuilder<NoFolder> Builder(InsertAfter->getParent(),
716 ++BasicBlock::iterator(InsertAfter));
717
718 // Replace the reduction chain with an intrinsic call
Sam Parker01db2982018-09-11 14:01:22 +0000719 const Type *Ty = IntegerType::get(M->getContext(), 32);
Eli Friedmanb09c7782018-10-18 19:34:30 +0000720 LoadInst *NewLd0 = CreateLoadIns(Builder, VecLd0[0], Ty);
721 LoadInst *NewLd1 = CreateLoadIns(Builder, VecLd1[0], Ty);
Sam Parkera023c7a2018-09-12 09:17:44 +0000722 Value* Args[] = { NewLd0, NewLd1, Acc };
723 Function *SMLAD = nullptr;
724 if (Exchange)
725 SMLAD = Acc->getType()->isIntegerTy(32) ?
726 Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) :
727 Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx);
728 else
729 SMLAD = Acc->getType()->isIntegerTy(32) ?
730 Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) :
731 Intrinsic::getDeclaration(M, Intrinsic::arm_smlald);
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000732 CallInst *Call = Builder.CreateCall(SMLAD, Args);
Sjoerd Meijerb3e06fa2018-07-06 14:47:09 +0000733 NumSMLAD++;
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000734 return Call;
735}
736
Sam Parker453ba912018-11-09 09:18:00 +0000737// Compare the value lists in Other to this chain.
738bool BinOpChain::AreSymmetrical(BinOpChain *Other) {
739 // Element-by-element comparison of Value lists returning true if they are
740 // instructions with the same opcode or constants with the same value.
741 auto CompareValueList = [](const ValueList &VL0,
742 const ValueList &VL1) {
743 if (VL0.size() != VL1.size()) {
744 LLVM_DEBUG(dbgs() << "Muls are mismatching operand list lengths: "
745 << VL0.size() << " != " << VL1.size() << "\n");
746 return false;
747 }
748
749 const unsigned Pairs = VL0.size();
750 LLVM_DEBUG(dbgs() << "Number of operand pairs: " << Pairs << "\n");
751
752 for (unsigned i = 0; i < Pairs; ++i) {
753 const Value *V0 = VL0[i];
754 const Value *V1 = VL1[i];
755 const auto *Inst0 = dyn_cast<Instruction>(V0);
756 const auto *Inst1 = dyn_cast<Instruction>(V1);
757
758 LLVM_DEBUG(dbgs() << "Pair " << i << ":\n";
759 dbgs() << "mul1: "; V0->dump();
760 dbgs() << "mul2: "; V1->dump());
761
762 if (!Inst0 || !Inst1)
763 return false;
764
765 if (Inst0->isSameOperationAs(Inst1)) {
766 LLVM_DEBUG(dbgs() << "OK: same operation found!\n");
767 continue;
768 }
769
770 const APInt *C0, *C1;
771 if (!(match(V0, m_APInt(C0)) && match(V1, m_APInt(C1)) && C0 == C1))
772 return false;
773 }
774
775 LLVM_DEBUG(dbgs() << "OK: found symmetrical operand lists.\n");
776 return true;
777 };
778
779 return CompareValueList(LHS, Other->LHS) &&
780 CompareValueList(RHS, Other->RHS);
781}
782
Sjoerd Meijerc89ca552018-06-28 12:55:29 +0000783Pass *llvm::createARMParallelDSPPass() {
784 return new ARMParallelDSP();
785}
786
// Pass identifier; the legacy pass manager identifies the pass by the
// address of this variable, not its value.
char ARMParallelDSP::ID = 0;

// Register the pass under the command-line name "arm-parallel-dsp".
// NOTE(review): the analyses required in getAnalysisUsage are not declared
// with INITIALIZE_PASS_DEPENDENCY here — confirm whether that is intended.
INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)
INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp",
                "Transform loops to use DSP intrinsics", false, false)