//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising
/// the code to produce a better final result as we go.
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  // Check this is a valid gather with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               unsigned Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Ptr,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder);

  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scatter lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

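// A summary of the legality check below (a restatement of the condition, not
// an exhaustive ISA reference): gathers/scatters are formed for four lanes of
// 8, 16 or 32 bits, eight lanes of 8 or 16 bits, or sixteen lanes of 8 bits,
// provided the elements are at least naturally aligned in memory.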
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       unsigned Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      ElemSize / 8 <= Alignment)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type\n");
  return false;
}

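// As a sketch of what checkGEP accepts (the exact IR naturally varies), a
// typical base-plus-offsets address computation looks like
//   %ext = zext <4 x i16> %offs to <4 x i32>
//   %ptrs = getelementptr i32, i32* %base, <4 x i32> %ext
// from which it returns %base, reporting the offsets through the Offsets
// argument while looking through the zero-extension.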
Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
                                          IRBuilder<> &Builder) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  if (GEPPtr->getType()->isVectorTy()) {
    return nullptr;
  }
  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  // Paranoid check whether the number of parallel lanes is the same
  assert(cast<VectorType>(Ty)->getNumElements() ==
         cast<VectorType>(Offsets->getType())->getNumElements());
  // Only <N x i32> offsets can be integrated into an arm gather, any smaller
  // type would have to be sign extended by the gep - and arm gathers can only
  // zero extend. Additionally, the offsets do have to originate from a zext of
  // a vector with element types smaller than or equal to the type of the
  // gather we're looking at
  if (Offsets->getType()->getScalarSizeInBits() != 32)
    return nullptr;
  if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
    Offsets = ZextOffs->getOperand(0);
  else if (!(cast<VectorType>(Offsets->getType())->getNumElements() == 4 &&
             Offsets->getType()->getScalarSizeInBits() == 32))
    return nullptr;

  if (Ty != Offsets->getType()) {
    if ((Ty->getScalarSizeInBits() <
         Offsets->getType()->getScalarSizeInBits())) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
                        << " Can't create intrinsic.\n");
      return nullptr;
    } else {
      Offsets = Builder.CreateZExt(
          Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

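// A sketch of the case handled below (illustrative types): given
//   %cast = bitcast <4 x i8*> %ptrs to <4 x i32*>
// the pass addresses through %ptrs directly, since the lane count (and hence
// the per-lane addressing) is unchanged by the bitcast.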
void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<VectorType>(BitCast->getType());
    auto *BCSrcTy = cast<VectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

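// The value computed below is the left-shift applied to each offset, i.e. the
// log2 of the byte scale: for example, an i32 access through an i32* gep
// returns 2 (offsets scaled by 4 bytes), while any i8-based gep returns 0.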
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by
  // 2, or an 8bit, 16bit or 32bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

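// A sketch of the overall gather transformation (the intrinsic name mangling
// shown here is illustrative rather than exact): an unmasked, unextended
//   %r = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 4, <4 x i1> <i1 1, i1 1, i1 1, i1 1>,
//            <4 x i32> undef)
// becomes roughly
//   %r = call <4 x i32> @llvm.arm.mve.vldr.gather.base.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 0)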
Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  auto *Ty = cast<VectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

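// Base gathers take the whole vector of pointers in a single Q register,
// which is only possible for four 32-bit lanes; everything else has to use
// the base-plus-offsets form or be expanded.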
Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder) {
  using namespace PatternMatch;
  auto *Ty = cast<VectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(0)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(0), Mask});
}

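// A sketch of the extending-gather case handled below (illustrative names):
//   %g = call <4 x i16> @llvm.masked.gather.v4i16.v4p0i16(...)
//   %x = sext <4 x i16> %g to <4 x i32>
// Both instructions fold into one sign-extending offset gather; Root is
// moved to the extend so the caller replaces and erases both of them.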
Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  Value *BasePtr = checkGEP(Offsets, ResultTy, Ptr, Builder);
  if (!BasePtr)
    return nullptr;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

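// The scatter path mirrors lowerGather. As a rough sketch (mangling again
// illustrative), a call such as
//   call void @llvm.masked.scatter.v4i32.v4p0i32(<4 x i32> %data,
//            <4 x i32*> %ptrs, i32 4, <4 x i1> %mask)
// is turned into an arm.mve.vstr.scatter.base(.predicated) call, or into the
// base-plus-offsets variant when %ptrs comes from a suitable gep.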
Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(2))->getZExtValue();
  auto *Ty = cast<VectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->replaceAllUsesWith(Store);
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  auto *Ty = cast<VectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(0), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(0), Input, Mask});
}

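// A sketch of the trunc folding performed below (illustrative names): for
//   %t = trunc <4 x i32> %val to <4 x i16>
//   call void @llvm.masked.scatter...(<4 x i16> %t, <4 x i16*> %ptrs, ...)
// the truncation is absorbed into a truncating scatter of %val, since the
// instruction only writes the low MemoryTy bits of each lane anyway.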
Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  Value *BasePtr = checkGEP(Offsets, InputTy, Ptr, Builder);
  if (!BasePtr)
    return nullptr;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

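// A sketch of what pushOutAdd achieves (illustrative names): with a counter
//   %phi  = phi <4 x i32> [ %start, %entry ], [ %inc, %loop ]
//   %offs = add <4 x i32> %phi, %const
// the invariant add is folded into the start value, so the phi begins at
// %start + %const and %offs disappears from the loop body.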
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces movs)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

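// pushOutMul relies on the identity (start + n*inc) * c == start*c +
// n*(inc*c): the multiplication by the invariant operand is folded into the
// phi's start value, and the per-iteration increment becomes inc*c.
// Schematically (illustrative names),
//   %phi  = phi [ %start, %entry ], [ %inc, %loop ]
//   %offs = mul %phi, %c
// turns into a phi over %start * %c that is incremented by %inc * %c.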
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new multiplication of the start value outside of the loop, and
  // a scaled increment by which the loop variable can be advanced
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment the phi by Product instead of performing the multiplication
  // inside the loop
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Return true if the given intrinsic is a gather or scatter
bool isGatherScatter(IntrinsicInst *IntInst) {
  if (IntInst == nullptr)
    return false;
  unsigned IntrinsicID = IntInst->getIntrinsicID();
  return (IntrinsicID == Intrinsic::masked_gather ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base_wb ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base_wb_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_offset ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_offset_predicated ||
          IntrinsicID == Intrinsic::masked_scatter ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base_wb ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base_wb_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_offset ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_offset_predicated);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0))
    return false;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U)))
      continue;
    unsigned OpCode = cast<Instruction>(U)->getOpcode();
    if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
        hasAllGatScatUsers(cast<Instruction>(U)))
      continue;
    return false;
  }
  return true;
}

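// A sketch of the loops optimiseOffsets targets (names are illustrative):
//   vector.body:
//     %phi  = phi <4 x i32> [ %start, %ph ], [ %inc, %vector.body ]
//     %offs = add <4 x i32> %phi, <i32 8, i32 8, i32 8, i32 8>
//     ... gather/scatter addressing through %offs ...
//     %inc  = add <4 x i32> %phi, <i32 4, i32 4, i32 4, i32 4>
// The invariant add (or mul) is pushed into the phi's start value (and, for
// muls, its increment), so the loop body addresses with the phi directly.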
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction
  // is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header, and shouldn't have more than 2 incoming values
  if (Phi->getParent() != L->getHeader() || Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  Instruction *Op;
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;
  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
        Gathers.push_back(II);
      else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter)
        Scatters.push_back(II);
    }
  }

  bool Changed = false;
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    if (isa<GetElementPtrInst>(I->getArgOperand(0)))
      optimiseOffsets(cast<Instruction>(I->getArgOperand(0))->getOperand(1),
                      I->getParent(), &LI);
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;
    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    if (isa<GetElementPtrInst>(I->getArgOperand(1)))
      optimiseOffsets(cast<Instruction>(I->getArgOperand(1))->getOperand(1),
                      I->getParent(), &LI);
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;
    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}