blob: 1fb116e9b4805f24a5fdb27353593ed9959bb191 [file] [log] [blame]
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks, that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
16
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000017#include "llvm/ADT/Twine.h"
Ayman Musac5490e52017-05-15 11:30:54 +000018#include "llvm/Analysis/TargetTransformInfo.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000019#include "llvm/CodeGen/TargetSubtargetInfo.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000020#include "llvm/IR/BasicBlock.h"
21#include "llvm/IR/Constant.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
24#include "llvm/IR/Function.h"
Ayman Musac5490e52017-05-15 11:30:54 +000025#include "llvm/IR/IRBuilder.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000026#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
28#include "llvm/IR/Instructions.h"
Reid Kleckner0e8c4bb2017-09-07 23:27:44 +000029#include "llvm/IR/IntrinsicInst.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000030#include "llvm/IR/Intrinsics.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
33#include "llvm/Pass.h"
34#include "llvm/Support/Casting.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000035#include <algorithm>
36#include <cassert>
Ayman Musac5490e52017-05-15 11:30:54 +000037
38using namespace llvm;
39
40#define DEBUG_TYPE "scalarize-masked-mem-intrin"
41
42namespace {
43
44class ScalarizeMaskedMemIntrin : public FunctionPass {
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000045 const TargetTransformInfo *TTI = nullptr;
Ayman Musac5490e52017-05-15 11:30:54 +000046
47public:
48 static char ID; // Pass identification, replacement for typeid
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000049
50 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
Ayman Musac5490e52017-05-15 11:30:54 +000051 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
52 }
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000053
Ayman Musac5490e52017-05-15 11:30:54 +000054 bool runOnFunction(Function &F) override;
55
56 StringRef getPassName() const override {
57 return "Scalarize Masked Memory Intrinsics";
58 }
59
60 void getAnalysisUsage(AnalysisUsage &AU) const override {
61 AU.addRequired<TargetTransformInfoWrapperPass>();
62 }
63
64private:
65 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67};
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000068
69} // end anonymous namespace
Ayman Musac5490e52017-05-15 11:30:54 +000070
71char ScalarizeMaskedMemIntrin::ID = 0;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000072
Matthias Braun1527baa2017-05-25 21:26:32 +000073INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74 "Scalarize unsupported masked memory intrinsics", false, false)
Ayman Musac5490e52017-05-15 11:30:54 +000075
76FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
77 return new ScalarizeMaskedMemIntrin();
78}
79
Craig Topper8b4f0e12018-09-27 22:31:42 +000080static bool isConstantIntVector(Value *Mask) {
81 Constant *C = dyn_cast<Constant>(Mask);
82 if (!C)
83 return false;
84
85 unsigned NumElts = Mask->getType()->getVectorNumElements();
86 for (unsigned i = 0; i != NumElts; ++i) {
87 Constant *CElt = C->getAggregateElement(i);
88 if (!CElt || !isa<ConstantInt>(CElt))
89 return false;
90 }
91
92 return true;
93}
94
Ayman Musac5490e52017-05-15 11:30:54 +000095// Translate a masked load intrinsic like
96// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
97// <16 x i1> %mask, <16 x i32> %passthru)
98// to a chain of basic blocks, with loading element one-by-one if
99// the appropriate mask bit is set
100//
101// %1 = bitcast i8* %addr to i32*
102// %2 = extractelement <16 x i1> %mask, i32 0
Craig Topper49dad8b2018-09-27 21:28:39 +0000103// br i1 %2, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000104//
105// cond.load: ; preds = %0
Craig Topper49dad8b2018-09-27 21:28:39 +0000106// %3 = getelementptr i32* %1, i32 0
107// %4 = load i32* %3
Craig Topper7d234d62018-09-27 21:28:52 +0000108// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
Ayman Musac5490e52017-05-15 11:30:54 +0000109// br label %else
110//
111// else: ; preds = %0, %cond.load
Craig Topper49dad8b2018-09-27 21:28:39 +0000112// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
113// %6 = extractelement <16 x i1> %mask, i32 1
114// br i1 %6, label %cond.load1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000115//
116// cond.load1: ; preds = %else
Craig Topper49dad8b2018-09-27 21:28:39 +0000117// %7 = getelementptr i32* %1, i32 1
118// %8 = load i32* %7
119// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
Ayman Musac5490e52017-05-15 11:30:54 +0000120// br label %else2
121//
122// else2: ; preds = %else, %cond.load1
Craig Topper49dad8b2018-09-27 21:28:39 +0000123// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
124// %10 = extractelement <16 x i1> %mask, i32 2
125// br i1 %10, label %cond.load4, label %else5
Ayman Musac5490e52017-05-15 11:30:54 +0000126//
127static void scalarizeMaskedLoad(CallInst *CI) {
128 Value *Ptr = CI->getArgOperand(0);
129 Value *Alignment = CI->getArgOperand(1);
130 Value *Mask = CI->getArgOperand(2);
131 Value *Src0 = CI->getArgOperand(3);
132
133 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
Craig Topper10ec0212018-09-27 22:31:40 +0000134 VectorType *VecType = cast<VectorType>(CI->getType());
Ayman Musac5490e52017-05-15 11:30:54 +0000135
Craig Topper10ec0212018-09-27 22:31:40 +0000136 Type *EltTy = VecType->getElementType();
Ayman Musac5490e52017-05-15 11:30:54 +0000137
138 IRBuilder<> Builder(CI->getContext());
139 Instruction *InsertPt = CI;
140 BasicBlock *IfBlock = CI->getParent();
141 BasicBlock *CondBlock = nullptr;
142 BasicBlock *PrevIfBlock = CI->getParent();
143
144 Builder.SetInsertPoint(InsertPt);
145 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
146
147 // Short-cut if the mask is all-true.
Craig Topperdfe460d2018-09-27 21:28:41 +0000148 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Ayman Musac5490e52017-05-15 11:30:54 +0000149 Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
150 CI->replaceAllUsesWith(NewI);
151 CI->eraseFromParent();
152 return;
153 }
154
155 // Adjust alignment for the scalar instruction.
Craig Topperbb50c382018-09-28 03:35:37 +0000156 AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
Ayman Musac5490e52017-05-15 11:30:54 +0000157 // Bitcast %addr fron i8* to EltTy*
158 Type *NewPtrType =
159 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
160 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
161 unsigned VectorWidth = VecType->getNumElements();
162
Ayman Musac5490e52017-05-15 11:30:54 +0000163 // The result vector
Craig Topper7d234d62018-09-27 21:28:52 +0000164 Value *VResult = Src0;
Ayman Musac5490e52017-05-15 11:30:54 +0000165
Craig Topper8b4f0e12018-09-27 22:31:42 +0000166 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000167 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000168 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000169 continue;
170 Value *Gep =
171 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
172 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
173 VResult =
174 Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
175 }
Craig Topper7d234d62018-09-27 21:28:52 +0000176 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000177 CI->eraseFromParent();
178 return;
179 }
180
Ayman Musac5490e52017-05-15 11:30:54 +0000181 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000182 // Fill the "else" block, created in the previous iteration
183 //
184 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
185 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000186 // br i1 %mask_1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000187 //
Ayman Musac5490e52017-05-15 11:30:54 +0000188
189 Value *Predicate =
190 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000191
192 // Create "cond" block
193 //
194 // %EltAddr = getelementptr i32* %1, i32 0
195 // %Elt = load i32* %EltAddr
196 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
197 //
198 CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
199 Builder.SetInsertPoint(InsertPt);
200
201 Value *Gep =
202 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
203 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
Craig Topper7d234d62018-09-27 21:28:52 +0000204 Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
205 Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000206
207 // Create "else" block, fill it in the next iteration
208 BasicBlock *NewIfBlock =
209 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
210 Builder.SetInsertPoint(InsertPt);
211 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000212 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000213 OldBr->eraseFromParent();
214 PrevIfBlock = IfBlock;
215 IfBlock = NewIfBlock;
Craig Topper7d234d62018-09-27 21:28:52 +0000216
217 // Create the phi to join the new and previous value.
218 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
219 Phi->addIncoming(NewVResult, CondBlock);
220 Phi->addIncoming(VResult, PrevIfBlock);
221 VResult = Phi;
Ayman Musac5490e52017-05-15 11:30:54 +0000222 }
223
Craig Topper7d234d62018-09-27 21:28:52 +0000224 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000225 CI->eraseFromParent();
226}
227
228// Translate a masked store intrinsic, like
229// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
230// <16 x i1> %mask)
231// to a chain of basic blocks, that stores element one-by-one if
232// the appropriate mask bit is set
233//
234// %1 = bitcast i8* %addr to i32*
235// %2 = extractelement <16 x i1> %mask, i32 0
Craig Topper49dad8b2018-09-27 21:28:39 +0000236// br i1 %2, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000237//
238// cond.store: ; preds = %0
Craig Topper49dad8b2018-09-27 21:28:39 +0000239// %3 = extractelement <16 x i32> %val, i32 0
240// %4 = getelementptr i32* %1, i32 0
241// store i32 %3, i32* %4
Ayman Musac5490e52017-05-15 11:30:54 +0000242// br label %else
243//
244// else: ; preds = %0, %cond.store
Craig Topper49dad8b2018-09-27 21:28:39 +0000245// %5 = extractelement <16 x i1> %mask, i32 1
246// br i1 %5, label %cond.store1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000247//
248// cond.store1: ; preds = %else
Craig Topper49dad8b2018-09-27 21:28:39 +0000249// %6 = extractelement <16 x i32> %val, i32 1
250// %7 = getelementptr i32* %1, i32 1
251// store i32 %6, i32* %7
Ayman Musac5490e52017-05-15 11:30:54 +0000252// br label %else2
253// . . .
254static void scalarizeMaskedStore(CallInst *CI) {
255 Value *Src = CI->getArgOperand(0);
256 Value *Ptr = CI->getArgOperand(1);
257 Value *Alignment = CI->getArgOperand(2);
258 Value *Mask = CI->getArgOperand(3);
259
260 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
Craig Topper10ec0212018-09-27 22:31:40 +0000261 VectorType *VecType = cast<VectorType>(Src->getType());
Ayman Musac5490e52017-05-15 11:30:54 +0000262
263 Type *EltTy = VecType->getElementType();
264
265 IRBuilder<> Builder(CI->getContext());
266 Instruction *InsertPt = CI;
267 BasicBlock *IfBlock = CI->getParent();
268 Builder.SetInsertPoint(InsertPt);
269 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
270
271 // Short-cut if the mask is all-true.
Craig Topperdfe460d2018-09-27 21:28:41 +0000272 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Ayman Musac5490e52017-05-15 11:30:54 +0000273 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
274 CI->eraseFromParent();
275 return;
276 }
277
278 // Adjust alignment for the scalar instruction.
Craig Topperbb50c382018-09-28 03:35:37 +0000279 AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
Ayman Musac5490e52017-05-15 11:30:54 +0000280 // Bitcast %addr fron i8* to EltTy*
281 Type *NewPtrType =
282 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
283 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
284 unsigned VectorWidth = VecType->getNumElements();
285
Craig Topper8b4f0e12018-09-27 22:31:42 +0000286 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000287 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000288 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000289 continue;
290 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
291 Value *Gep =
292 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
293 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
294 }
295 CI->eraseFromParent();
296 return;
297 }
298
299 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000300 // Fill the "else" block, created in the previous iteration
301 //
302 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000303 // br i1 %mask_1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000304 //
305 Value *Predicate =
306 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000307
308 // Create "cond" block
309 //
310 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
311 // %EltAddr = getelementptr i32* %1, i32 0
312 // %store i32 %OneElt, i32* %EltAddr
313 //
314 BasicBlock *CondBlock =
315 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
316 Builder.SetInsertPoint(InsertPt);
317
318 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
319 Value *Gep =
320 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
321 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
322
323 // Create "else" block, fill it in the next iteration
324 BasicBlock *NewIfBlock =
325 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
326 Builder.SetInsertPoint(InsertPt);
327 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000328 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000329 OldBr->eraseFromParent();
330 IfBlock = NewIfBlock;
331 }
332 CI->eraseFromParent();
333}
334
335// Translate a masked gather intrinsic like
336// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
337// <16 x i1> %Mask, <16 x i32> %Src)
338// to a chain of basic blocks, with loading element one-by-one if
339// the appropriate mask bit is set
340//
Craig Topper49dad8b2018-09-27 21:28:39 +0000341// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
342// %Mask0 = extractelement <16 x i1> %Mask, i32 0
343// br i1 %Mask0, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000344//
345// cond.load:
Craig Topper49dad8b2018-09-27 21:28:39 +0000346// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
347// %Load0 = load i32, i32* %Ptr0, align 4
348// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
Ayman Musac5490e52017-05-15 11:30:54 +0000349// br label %else
350//
351// else:
Craig Topper49dad8b2018-09-27 21:28:39 +0000352// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
353// %Mask1 = extractelement <16 x i1> %Mask, i32 1
354// br i1 %Mask1, label %cond.load1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000355//
356// cond.load1:
Craig Topper49dad8b2018-09-27 21:28:39 +0000357// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
358// %Load1 = load i32, i32* %Ptr1, align 4
359// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
Ayman Musac5490e52017-05-15 11:30:54 +0000360// br label %else2
361// . . .
Craig Topper49dad8b2018-09-27 21:28:39 +0000362// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
Ayman Musac5490e52017-05-15 11:30:54 +0000363// ret <16 x i32> %Result
364static void scalarizeMaskedGather(CallInst *CI) {
365 Value *Ptrs = CI->getArgOperand(0);
366 Value *Alignment = CI->getArgOperand(1);
367 Value *Mask = CI->getArgOperand(2);
368 Value *Src0 = CI->getArgOperand(3);
369
Craig Topper10ec0212018-09-27 22:31:40 +0000370 VectorType *VecType = cast<VectorType>(CI->getType());
Ayman Musac5490e52017-05-15 11:30:54 +0000371
372 IRBuilder<> Builder(CI->getContext());
373 Instruction *InsertPt = CI;
374 BasicBlock *IfBlock = CI->getParent();
375 BasicBlock *CondBlock = nullptr;
376 BasicBlock *PrevIfBlock = CI->getParent();
377 Builder.SetInsertPoint(InsertPt);
378 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
379
380 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
381
Ayman Musac5490e52017-05-15 11:30:54 +0000382 // The result vector
Craig Topper6911bfe2018-09-27 21:28:59 +0000383 Value *VResult = Src0;
Ayman Musac5490e52017-05-15 11:30:54 +0000384 unsigned VectorWidth = VecType->getNumElements();
385
386 // Shorten the way if the mask is a vector of constants.
Craig Topper8b4f0e12018-09-27 22:31:42 +0000387 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000388 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000389 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000390 continue;
391 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
392 "Ptr" + Twine(Idx));
393 LoadInst *Load =
394 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
395 VResult = Builder.CreateInsertElement(
396 VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
397 }
Craig Topper6911bfe2018-09-27 21:28:59 +0000398 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000399 CI->eraseFromParent();
400 return;
401 }
402
Ayman Musac5490e52017-05-15 11:30:54 +0000403 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000404 // Fill the "else" block, created in the previous iteration
405 //
406 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
Craig Topper04236812018-09-27 18:01:48 +0000407 // br i1 %Mask1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000408 //
Ayman Musac5490e52017-05-15 11:30:54 +0000409
410 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
411 "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000412
413 // Create "cond" block
414 //
415 // %EltAddr = getelementptr i32* %1, i32 0
416 // %Elt = load i32* %EltAddr
417 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
418 //
419 CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
420 Builder.SetInsertPoint(InsertPt);
421
422 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
423 "Ptr" + Twine(Idx));
424 LoadInst *Load =
425 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
Craig Topper6911bfe2018-09-27 21:28:59 +0000426 Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
427 Builder.getInt32(Idx),
428 "Res" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000429
430 // Create "else" block, fill it in the next iteration
431 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
432 Builder.SetInsertPoint(InsertPt);
433 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000434 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000435 OldBr->eraseFromParent();
436 PrevIfBlock = IfBlock;
437 IfBlock = NewIfBlock;
Craig Topper6911bfe2018-09-27 21:28:59 +0000438
439 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
440 Phi->addIncoming(NewVResult, CondBlock);
441 Phi->addIncoming(VResult, PrevIfBlock);
442 VResult = Phi;
Ayman Musac5490e52017-05-15 11:30:54 +0000443 }
444
Craig Topper6911bfe2018-09-27 21:28:59 +0000445 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000446 CI->eraseFromParent();
447}
448
449// Translate a masked scatter intrinsic, like
450// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
451// <16 x i1> %Mask)
452// to a chain of basic blocks, that stores element one-by-one if
453// the appropriate mask bit is set.
454//
Craig Topper49dad8b2018-09-27 21:28:39 +0000455// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
456// %Mask0 = extractelement <16 x i1> %Mask, i32 0
457// br i1 %Mask0, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000458//
459// cond.store:
Craig Topper49dad8b2018-09-27 21:28:39 +0000460// %Elt0 = extractelement <16 x i32> %Src, i32 0
461// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
462// store i32 %Elt0, i32* %Ptr0, align 4
Ayman Musac5490e52017-05-15 11:30:54 +0000463// br label %else
464//
465// else:
Craig Topper49dad8b2018-09-27 21:28:39 +0000466// %Mask1 = extractelement <16 x i1> %Mask, i32 1
467// br i1 %Mask1, label %cond.store1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000468//
469// cond.store1:
Craig Topper49dad8b2018-09-27 21:28:39 +0000470// %Elt1 = extractelement <16 x i32> %Src, i32 1
471// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
472// store i32 %Elt1, i32* %Ptr1, align 4
Ayman Musac5490e52017-05-15 11:30:54 +0000473// br label %else2
474// . . .
475static void scalarizeMaskedScatter(CallInst *CI) {
476 Value *Src = CI->getArgOperand(0);
477 Value *Ptrs = CI->getArgOperand(1);
478 Value *Alignment = CI->getArgOperand(2);
479 Value *Mask = CI->getArgOperand(3);
480
481 assert(isa<VectorType>(Src->getType()) &&
482 "Unexpected data type in masked scatter intrinsic");
483 assert(isa<VectorType>(Ptrs->getType()) &&
484 isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
485 "Vector of pointers is expected in masked scatter intrinsic");
486
487 IRBuilder<> Builder(CI->getContext());
488 Instruction *InsertPt = CI;
489 BasicBlock *IfBlock = CI->getParent();
490 Builder.SetInsertPoint(InsertPt);
491 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
492
493 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
494 unsigned VectorWidth = Src->getType()->getVectorNumElements();
495
496 // Shorten the way if the mask is a vector of constants.
Craig Topper8b4f0e12018-09-27 22:31:42 +0000497 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000498 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000499 if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000500 continue;
501 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
502 "Elt" + Twine(Idx));
503 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
504 "Ptr" + Twine(Idx));
505 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
506 }
507 CI->eraseFromParent();
508 return;
509 }
Craig Topperdfe460d2018-09-27 21:28:41 +0000510
Ayman Musac5490e52017-05-15 11:30:54 +0000511 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
512 // Fill the "else" block, created in the previous iteration
513 //
Craig Topper04236812018-09-27 18:01:48 +0000514 // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
515 // br i1 %Mask1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000516 //
517 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
518 "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000519
520 // Create "cond" block
521 //
Craig Topper49dad8b2018-09-27 21:28:39 +0000522 // %Elt1 = extractelement <16 x i32> %Src, i32 1
523 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
524 // %store i32 %Elt1, i32* %Ptr1
Ayman Musac5490e52017-05-15 11:30:54 +0000525 //
526 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
527 Builder.SetInsertPoint(InsertPt);
528
529 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
530 "Elt" + Twine(Idx));
531 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
532 "Ptr" + Twine(Idx));
533 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
534
535 // Create "else" block, fill it in the next iteration
536 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
537 Builder.SetInsertPoint(InsertPt);
538 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000539 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000540 OldBr->eraseFromParent();
541 IfBlock = NewIfBlock;
542 }
543 CI->eraseFromParent();
544}
545
546bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
Ayman Musac5490e52017-05-15 11:30:54 +0000547 bool EverMadeChange = false;
548
549 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
550
551 bool MadeChange = true;
552 while (MadeChange) {
553 MadeChange = false;
554 for (Function::iterator I = F.begin(); I != F.end();) {
555 BasicBlock *BB = &*I++;
556 bool ModifiedDTOnIteration = false;
557 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
558
559 // Restart BB iteration if the dominator tree of the Function was changed
560 if (ModifiedDTOnIteration)
561 break;
562 }
563
564 EverMadeChange |= MadeChange;
565 }
566
567 return EverMadeChange;
568}
569
570bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
571 bool MadeChange = false;
572
573 BasicBlock::iterator CurInstIterator = BB.begin();
574 while (CurInstIterator != BB.end()) {
575 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
576 MadeChange |= optimizeCallInst(CI, ModifiedDT);
577 if (ModifiedDT)
578 return true;
579 }
580
581 return MadeChange;
582}
583
584bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
585 bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000586 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
587 if (II) {
588 switch (II->getIntrinsicID()) {
589 default:
590 break;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000591 case Intrinsic::masked_load:
Ayman Musac5490e52017-05-15 11:30:54 +0000592 // Scalarize unsupported vector masked load
593 if (!TTI->isLegalMaskedLoad(CI->getType())) {
594 scalarizeMaskedLoad(CI);
595 ModifiedDT = true;
596 return true;
597 }
598 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000599 case Intrinsic::masked_store:
Ayman Musac5490e52017-05-15 11:30:54 +0000600 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
601 scalarizeMaskedStore(CI);
602 ModifiedDT = true;
603 return true;
604 }
605 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000606 case Intrinsic::masked_gather:
Ayman Musac5490e52017-05-15 11:30:54 +0000607 if (!TTI->isLegalMaskedGather(CI->getType())) {
608 scalarizeMaskedGather(CI);
609 ModifiedDT = true;
610 return true;
611 }
612 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000613 case Intrinsic::masked_scatter:
Ayman Musac5490e52017-05-15 11:30:54 +0000614 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
615 scalarizeMaskedScatter(CI);
616 ModifiedDT = true;
617 return true;
618 }
619 return false;
620 }
Ayman Musac5490e52017-05-15 11:30:54 +0000621 }
622
623 return false;
624}