//===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
/// intrinsics to arm.mve.gather and arm.mve.scatter intrinsics, optimising
/// the code to produce a better final result as we go.
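///
/// For example (illustrative IR, value names are placeholders), a masked
/// gather from a vector of pointers such as
///   %g = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(
///            <4 x i32*> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> undef)
/// can be lowered to a single predicated MVE gather:
///   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.predicated.v4i32.v4p0i32.v4i1(
///            <4 x i32*> %ptrs, i32 0, <4 x i1> %mask)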
//
//===----------------------------------------------------------------------===//

#include "ARM.h"
#include "ARMBaseInstrInfo.h"
#include "ARMSubtarget.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/InitializePasses.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsARM.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "mve-gather-scatter-lowering"

cl::opt<bool> EnableMaskedGatherScatters(
    "enable-arm-maskedgatscat", cl::Hidden, cl::init(false),
    cl::desc("Enable the generation of masked gathers and scatters"));

namespace {

class MVEGatherScatterLowering : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit MVEGatherScatterLowering() : FunctionPass(ID) {
    initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "MVE gather/scatter lowering";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

private:
  // Check this is a valid gather/scatter with correct alignment
  bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize,
                               unsigned Alignment);
  // Check whether Ptr is hidden behind a bitcast and look through it
  void lookThroughBitcast(Value *&Ptr);
  // Check for a getelementptr and deduce base and offsets from it, on success
  // returning the base directly and the offsets indirectly using the Offsets
  // argument
  Value *checkGEP(Value *&Offsets, Type *Ty, Value *Ptr, IRBuilder<> &Builder);
  // Compute the scale of this gather/scatter instruction
  int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize);

  Value *lowerGather(IntrinsicInst *I);
  // Create a gather from a base + vector of offsets
  Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr,
                                     Instruction *&Root, IRBuilder<> &Builder);
  // Create a gather from a vector of pointers
  Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr,
                                   IRBuilder<> &Builder);

  Value *lowerScatter(IntrinsicInst *I);
  // Create a scatter to a base + vector of offsets
  Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets,
                                      IRBuilder<> &Builder);
  // Create a scatter to a vector of pointers
  Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr,
                                    IRBuilder<> &Builder);

  // Check whether these offsets could be moved out of the loop they're in
  bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI);
  // Pushes the given add out of the loop
  void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex);
  // Pushes the given mul out of the loop
  void pushOutMul(PHINode *&Phi, Value *IncrementPerRound,
                  Value *OffsSecondOperand, unsigned LoopIncrement,
                  IRBuilder<> &Builder);
};

} // end anonymous namespace

char MVEGatherScatterLowering::ID = 0;

INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE,
                "MVE gather/scatter lowering pass", false, false)

Pass *llvm::createMVEGatherScatterLoweringPass() {
  return new MVEGatherScatterLowering();
}

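// The shapes accepted below are those of the full 128-bit MVE vector types
// (v4i32, v8i16, v16i8) plus the narrower memory types that a gather/scatter
// extends or truncates (e.g. a v4i16 or v4i8 loaded into a v4i32); the
// alignment must cover the element size, e.g. at least 4 bytes for 32-bit
// elements.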
bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements,
                                                       unsigned ElemSize,
                                                       unsigned Alignment) {
  if (((NumElements == 4 &&
        (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) ||
       (NumElements == 16 && ElemSize == 8)) &&
      ElemSize / 8 <= Alignment)
    return true;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have "
                    << "valid alignment or vector type\n");
  return false;
}

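// For example (illustrative IR), checkGEP decomposes
//   %addrs = getelementptr i32, i32* %base, <4 x i32> %offsets
// into the scalar base %base plus the vector %offsets, possibly looking
// through a zext of the offsets since the hardware can zero extend them
// itself.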
Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, Type *Ty, Value *Ptr,
                                          IRBuilder<> &Builder) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP) {
    LLVM_DEBUG(
        dbgs() << "masked gathers/scatters: no getelementpointer found\n");
    return nullptr;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found."
                    << " Looking at intrinsic for base + vector of offsets\n");
  Value *GEPPtr = GEP->getPointerOperand();
  if (GEPPtr->getType()->isVectorTy()) {
    return nullptr;
  }
  if (GEP->getNumOperands() != 2) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many"
                      << " operands. Expanding.\n");
    return nullptr;
  }
  Offsets = GEP->getOperand(1);
  // Paranoid check whether the number of parallel lanes is the same
  assert(cast<VectorType>(Ty)->getNumElements() ==
         cast<VectorType>(Offsets->getType())->getNumElements());
  // Only <N x i32> offsets can be integrated into an arm gather; any smaller
  // type would have to be sign extended by the gep - and arm gathers can only
  // zero extend. Additionally, the offsets have to originate from a zext of
  // a vector with element types smaller than or equal to the type of the
  // gather we're looking at
  if (Offsets->getType()->getScalarSizeInBits() != 32)
    return nullptr;
  if (ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets))
    Offsets = ZextOffs->getOperand(0);
  else if (!(cast<VectorType>(Offsets->getType())->getNumElements() == 4 &&
             Offsets->getType()->getScalarSizeInBits() == 32))
    return nullptr;

  if (Ty != Offsets->getType()) {
    if ((Ty->getScalarSizeInBits() <
         Offsets->getType()->getScalarSizeInBits())) {
      LLVM_DEBUG(dbgs() << "masked gathers/scatters: no correct offset type."
                        << " Can't create intrinsic.\n");
      return nullptr;
    } else {
      Offsets = Builder.CreateZExt(
          Offsets, VectorType::getInteger(cast<VectorType>(Ty)));
    }
  }
  // If none of the checks failed, return the gep's base pointer
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n");
  return GEPPtr;
}

void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) {
  // Look through bitcast instruction if #elements is the same
  if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) {
    auto *BCTy = cast<VectorType>(BitCast->getType());
    auto *BCSrcTy = cast<VectorType>(BitCast->getOperand(0)->getType());
    if (BCTy->getNumElements() == BCSrcTy->getNumElements()) {
      LLVM_DEBUG(
          dbgs() << "masked gathers/scatters: looking through bitcast\n");
      Ptr = BitCast->getOperand(0);
    }
  }
}

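// Note: the scale returned below is the left shift applied to each offset
// lane, so e.g. gathering i32s through an i32* base can use offsets counted
// in elements (scale 2, i.e. a stride of 4 bytes per offset unit), whereas an
// i8* base always addresses single bytes (scale 0).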
int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize,
                                           unsigned MemoryElemSize) {
  // This can be a 32-bit load/store scaled by 4, a 16-bit load/store scaled
  // by 2, or an 8-bit, 16-bit or 32-bit load/store scaled by 1
  if (GEPElemSize == 32 && MemoryElemSize == 32)
    return 2;
  else if (GEPElemSize == 16 && MemoryElemSize == 16)
    return 1;
  else if (GEPElemSize == 8)
    return 0;
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. Can't "
                    << "create intrinsic\n");
  return -1;
}

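// For example (illustrative IR), a gather with a non-trivial passthru value
//   %g = call <4 x i32> @llvm.masked.gather.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 4, <4 x i1> %mask, <4 x i32> %fallback)
// is lowered to an MVE gather whose result is merged with the fallback:
//   %load = call <4 x i32> @llvm.arm.mve.vldr.gather.base.predicated...(...)
//   %g = select <4 x i1> %mask, <4 x i32> %load, <4 x i32> %fallback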
Value *MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n");

  // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0)
  // Attempt to turn the masked gather in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  auto *Ty = cast<VectorType>(I->getType());
  Value *Ptr = I->getArgOperand(0);
  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(1))->getZExtValue();
  Value *Mask = I->getArgOperand(2);
  Value *PassThru = I->getArgOperand(3);

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;
  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Instruction *Root = I;
  Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder);
  if (!Load)
    Load = tryCreateMaskedGatherBase(I, Ptr, Builder);
  if (!Load)
    return nullptr;

  if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) {
    LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - "
                      << "creating select\n");
    Load = Builder.CreateSelect(Mask, Load, PassThru);
  }

  Root->replaceAllUsesWith(Load);
  Root->eraseFromParent();
  if (Root != I)
    // If this was an extending gather, we need to get rid of the sext/zext
    // as well as of the gather itself
    I->eraseFromParent();

  LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n");
  return Load;
}

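// E.g. (illustrative IR) an unpredicated gather from four i32 pointers maps
// directly onto the base form with a zero immediate offset:
//   %g = call <4 x i32> @llvm.arm.mve.vldr.gather.base.v4i32.v4p0i32(
//            <4 x i32*> %ptrs, i32 0)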
Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I,
                                                           Value *Ptr,
                                                           IRBuilder<> &Builder) {
  using namespace PatternMatch;
  auto *Ty = cast<VectorType>(I->getType());
  LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n");
  if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32)
    // Can't build an intrinsic for this
    return nullptr;
  Value *Mask = I->getArgOperand(2);
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base,
                                   {Ty, Ptr->getType()},
                                   {Ptr, Builder.getInt32(0)});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_base_predicated,
        {Ty, Ptr->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(0), Mask});
}

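// Extending gathers are folded here: e.g. (illustrative IR)
//   %narrow = call <4 x i16> @llvm.masked.gather.v4i16.v4p0i16(...)
//   %wide = zext <4 x i16> %narrow to <4 x i32>
// becomes a single zero-extending gather (a VLDRH.U32), in which case the
// zext - the "Root" - is the instruction that gets replaced, not the call.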
Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset(
    IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) {
  using namespace PatternMatch;

  Type *OriginalTy = I->getType();
  Type *ResultTy = OriginalTy;

  unsigned Unsigned = 1;
  // The size of the gather was already checked in isLegalTypeAndAlignment;
  // if it was not a full vector width, an appropriate extend should follow.
  auto *Extend = Root;
  if (OriginalTy->getPrimitiveSizeInBits() < 128) {
    // Only transform gathers with exactly one use
    if (!I->hasOneUse())
      return nullptr;

    // The correct root to replace is not the CallInst itself, but the
    // instruction which extends it
    Extend = cast<Instruction>(*I->users().begin());
    if (isa<SExtInst>(Extend)) {
      Unsigned = 0;
    } else if (!isa<ZExtInst>(Extend)) {
      LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. "
                        << "Expanding\n");
      return nullptr;
    }
    LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n");
    ResultTy = Extend->getType();
    // The final size of the gather must be a full vector width
    if (ResultTy->getPrimitiveSizeInBits() != 128) {
      LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. "
                        << "Expanding\n");
      return nullptr;
    }
  }

  Value *Offsets;
  Value *BasePtr = checkGEP(Offsets, ResultTy, Ptr, Builder);
  if (!BasePtr)
    return nullptr;

  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      OriginalTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;
  Root = Extend;

  Value *Mask = I->getArgOperand(2);
  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset_predicated,
        {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vldr_gather_offset,
        {ResultTy, BasePtr->getType(), Offsets->getType()},
        {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Builder.getInt32(Unsigned)});
}

Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) {
  using namespace PatternMatch;
  LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n");

  // @llvm.masked.scatter.*(data, ptrs, alignment, mask)
  // Attempt to turn the masked scatter in I into an MVE intrinsic,
  // potentially optimising the addressing modes as we do so.
  Value *Input = I->getArgOperand(0);
  Value *Ptr = I->getArgOperand(1);
  unsigned Alignment = cast<ConstantInt>(I->getArgOperand(2))->getZExtValue();
  auto *Ty = cast<VectorType>(Input->getType());

  if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(),
                               Alignment))
    return nullptr;

  lookThroughBitcast(Ptr);
  assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

  IRBuilder<> Builder(I->getContext());
  Builder.SetInsertPoint(I);
  Builder.SetCurrentDebugLocation(I->getDebugLoc());

  Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder);
  if (!Store)
    Store = tryCreateMaskedScatterBase(I, Ptr, Builder);
  if (!Store)
    return nullptr;

  LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n");
  I->replaceAllUsesWith(Store);
  I->eraseFromParent();
  return Store;
}

Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  auto *Ty = cast<VectorType>(Input->getType());
  // Only QR variants allow truncating
  if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) {
    // Can't build an intrinsic for this
    return nullptr;
  }
  // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask)
  LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n");
  if (match(Mask, m_One()))
    return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base,
                                   {Ptr->getType(), Input->getType()},
                                   {Ptr, Builder.getInt32(0), Input});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_base_predicated,
        {Ptr->getType(), Input->getType(), Mask->getType()},
        {Ptr, Builder.getInt32(0), Input, Mask});
}

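// Truncating scatters are folded here: e.g. (illustrative IR)
//   %narrow = trunc <4 x i32> %data to <4 x i16>
//   call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %narrow, ...)
// can scatter the untruncated %data with a 16-bit memory element size (a
// VSTRH.32), as the instruction only writes the low half of each lane.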
Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset(
    IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) {
  using namespace PatternMatch;
  Value *Input = I->getArgOperand(0);
  Value *Mask = I->getArgOperand(3);
  Type *InputTy = Input->getType();
  Type *MemoryTy = InputTy;
  LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. Storing"
                    << " to base + vector of offsets\n");
  // If the input has been truncated, try to integrate that trunc into the
  // scatter instruction (we don't care about alignment here)
  if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) {
    Value *PreTrunc = Trunc->getOperand(0);
    Type *PreTruncTy = PreTrunc->getType();
    if (PreTruncTy->getPrimitiveSizeInBits() == 128) {
      Input = PreTrunc;
      InputTy = PreTruncTy;
    }
  }
  if (InputTy->getPrimitiveSizeInBits() != 128) {
    LLVM_DEBUG(
        dbgs() << "masked scatters: cannot create scatters for non-standard"
               << " input types. Expanding.\n");
    return nullptr;
  }

  Value *Offsets;
  Value *BasePtr = checkGEP(Offsets, InputTy, Ptr, Builder);
  if (!BasePtr)
    return nullptr;
  int Scale = computeScale(
      BasePtr->getType()->getPointerElementType()->getPrimitiveSizeInBits(),
      MemoryTy->getScalarSizeInBits());
  if (Scale == -1)
    return nullptr;

  if (!match(Mask, m_One()))
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset_predicated,
        {BasePtr->getType(), Offsets->getType(), Input->getType(),
         Mask->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale), Mask});
  else
    return Builder.CreateIntrinsic(
        Intrinsic::arm_mve_vstr_scatter_offset,
        {BasePtr->getType(), Offsets->getType(), Input->getType()},
        {BasePtr, Offsets, Input,
         Builder.getInt32(MemoryTy->getScalarSizeInBits()),
         Builder.getInt32(Scale)});
}

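// For example (illustrative IR), an add of a loop-invariant %inv inside the
// loop,
//   %phi = phi <4 x i32> [ %start, %preheader ], [ %inc, %body ]
//   %offs = add <4 x i32> %phi, %inv
// is folded into the phi's start value in the preheader:
//   %PushedOutAdd = add <4 x i32> %start, %inv
//   %phi = phi <4 x i32> [ %PushedOutAdd, %preheader ], [ %inc, %body ]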
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

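// Similarly (illustrative IR), a mul by a loop-invariant %inv
//   %offs = mul <4 x i32> %phi, %inv
// becomes a pre-multiplied start value and a scaled increment, relying on
// (Phi + IncrementPerRound) * Inv == Phi * Inv + IncrementPerRound * Inv:
//   %PushedOutMul = mul <4 x i32> %start, %inv            ; preheader
//   %Product = mul <4 x i32> %IncrementPerRound, %inv     ; preheader
//   %IncrementPushedOutMul = add <4 x i32> %phi, %Product ; loop body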
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Push the multiplication out of the loop: both the pre-multiplied start
  // value and the scaled per-round increment are computed in the preheader
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment the phi by Product each round instead of redoing the
  // multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Return true if the given intrinsic is a gather or scatter
bool isGatherScatter(IntrinsicInst *IntInst) {
  if (IntInst == nullptr)
    return false;
  unsigned IntrinsicID = IntInst->getIntrinsicID();
  return (IntrinsicID == Intrinsic::masked_gather ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base_wb ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_base_wb_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_offset ||
          IntrinsicID == Intrinsic::arm_mve_vldr_gather_offset_predicated ||
          IntrinsicID == Intrinsic::masked_scatter ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base_wb ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_base_wb_predicated ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_offset ||
          IntrinsicID == Intrinsic::arm_mve_vstr_scatter_offset_predicated);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetic only used by gathers/scatters
bool hasAllGatScatUsers(Instruction *I) {
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

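// For example (illustrative IR), offsets computed as
//   %phi = phi <4 x i32> [ %start, %preheader ], [ %inc, %body ]
//   %offs = add <4 x i32> %phi, <i32 4, i32 4, i32 4, i32 4>
// are rewritten via pushOutAdd/pushOutMul so that the add or mul happens once
// in the preheader and the gather/scatter uses the (possibly duplicated) phi
// directly.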
bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimise\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
  Instruction *Offs = cast<Instruction>(Offsets);
  if (Offs->getOpcode() != Instruction::Add &&
      Offs->getOpcode() != Instruction::Mul)
    return false;
  Loop *L = LI->getLoopFor(BB);
  if (L == nullptr)
    return false;
  if (!Offs->hasOneUse()) {
    if (!hasAllGatScatUsers(Offs))
      return false;
  }

  // Find out which, if any, operand of the instruction is a phi node
  PHINode *Phi;
  int OffsSecondOp;
  if (isa<PHINode>(Offs->getOperand(0))) {
    Phi = cast<PHINode>(Offs->getOperand(0));
    OffsSecondOp = 1;
  } else if (isa<PHINode>(Offs->getOperand(1))) {
    Phi = cast<PHINode>(Offs->getOperand(1));
    OffsSecondOp = 0;
  } else {
    bool Changed = false;
    if (isa<Instruction>(Offs->getOperand(0)) &&
        L->contains(cast<Instruction>(Offs->getOperand(0))))
      Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI);
    if (isa<Instruction>(Offs->getOperand(1)) &&
        L->contains(cast<Instruction>(Offs->getOperand(1))))
      Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI);
    if (!Changed) {
      return false;
    } else {
      if (isa<PHINode>(Offs->getOperand(0))) {
        Phi = cast<PHINode>(Offs->getOperand(0));
        OffsSecondOp = 1;
      } else if (isa<PHINode>(Offs->getOperand(1))) {
        Phi = cast<PHINode>(Offs->getOperand(1));
        OffsSecondOp = 0;
      } else {
        return false;
      }
    }
  }
  // A phi node we want to perform this function on should be from the
  // loop header and have exactly 2 incoming values
  if (Phi->getParent() != L->getHeader() ||
      Phi->getNumIncomingValues() != 2)
    return false;

  // The phi must be an induction variable
  Instruction *Op;
  int IncrementingBlock = -1;

  for (int i = 0; i < 2; i++)
    if ((Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) != nullptr)
      if (Op->getOpcode() == Instruction::Add &&
          (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi))
        IncrementingBlock = i;
  if (IncrementingBlock == -1)
    return false;

  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType())
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (one user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction has users besides our phi, we need
      // to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(
      dbgs() << "masked gathers/scatters: simplified loop variable add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  if (!ST->hasMVEIntegerOps())
    return false;
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;
  LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather)
        Gathers.push_back(II);
      else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter)
        Scatters.push_back(II);
    }
  }

  bool Changed = false;
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    if (isa<GetElementPtrInst>(I->getArgOperand(0)))
      optimiseOffsets(cast<Instruction>(I->getArgOperand(0))->getOperand(1),
                      I->getParent(), &LI);
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;
    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    if (isa<GetElementPtrInst>(I->getArgOperand(1)))
      optimiseOffsets(cast<Instruction>(I->getArgOperand(1))->getOperand(1),
                      I->getParent(), &LI);
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;
    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}