//=== ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ===//
//===                            intrinsics                            ===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deals with the elements one by one if
// the appropriate mask bit is set.
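// The pass currently handles @llvm.masked.load, @llvm.masked.store,
// @llvm.masked.gather and @llvm.masked.scatter.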
//
//===----------------------------------------------------------------------===//

#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Target/TargetSubtargetInfo.h"

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI;

public:
  static char ID; // Pass identification, replacement for typeid
  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID), TTI(nullptr) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }
  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};
} // namespace

char ScalarizeMaskedMemIntrin::ID = 0;
INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}

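// A minimal usage sketch (illustrative only, not part of this file): a codegen
// pipeline could schedule this pass before instruction selection, e.g. from a
// hypothetical TargetPassConfig subclass; the actual pipeline position depends
// on the target setup.
//
//   void MyTargetPassConfig::addIRPasses() {          // hypothetical subclass
//     addPass(createScalarizeMaskedMemIntrinPass());  // defined above
//     TargetPassConfig::addIRPasses();
//   }
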
// Translate a masked load intrinsic like
//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// into a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set:
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   %3 = icmp eq i1 %2, true
//   br i1 %3, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//   %4 = getelementptr i32* %1, i32 0
//   %5 = load i32* %4
//   %6 = insertelement <16 x i32> undef, i32 %5, i32 0
//   br label %else
//
// else:                                             ; preds = %0, %cond.load
//   %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
//   %7 = extractelement <16 x i1> %mask, i32 1
//   %8 = icmp eq i1 %7, true
//   br i1 %8, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//   %9 = getelementptr i32* %1, i32 1
//   %10 = load i32* %9
//   %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
//   br label %else2
//
// else2:                                            ; preds = %else, %cond.load1
//   %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
//   %12 = extractelement <16 x i1> %mask, i32 2
//   %13 = icmp eq i1 %12, true
//   br i1 %13, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

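  // If the mask is a compile-time constant vector, no control flow is needed:
  // load only the lanes whose mask bit is set and merge them with the
  // pass-through value via a select.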
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    //   %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //   %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //   %to_load = icmp eq i1 %mask_1, true
    //   br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    //   %EltAddr = getelementptr i32* %1, i32 0
    //   %Elt = load i32* %EltAddr
    //   VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
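    // splitBasicBlock() terminated IfBlock with an unconditional branch to
    // CondBlock; replace it with a conditional branch on the mask bit so the
    // load is performed only when the bit is set.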
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}

// Translate a masked store intrinsic, like
//   void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                           <16 x i1> %mask)
// into a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set:
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   %3 = icmp eq i1 %2, true
//   br i1 %3, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %4 = extractelement <16 x i32> %val, i32 0
//   %5 = getelementptr i32* %1, i32 0
//   store i32 %4, i32* %5
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %6 = extractelement <16 x i1> %mask, i32 1
//   %7 = icmp eq i1 %6, true
//   br i1 %7, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %8 = extractelement <16 x i32> %val, i32 1
//   %9 = getelementptr i32* %1, i32 1
//   store i32 %8, i32* %9
//   br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(Src->getType());
  assert(VecType && "Unexpected data type in masked store intrinsic");

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: an element access cannot
  // claim more alignment than the minimum of the vector alignment and the
  // element size.
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

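  // If the mask is a compile-time constant vector, emit a plain scalar store
  // for each lane whose mask bit is set; no control flow is needed.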
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    //   %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //   %to_store = icmp eq i1 %mask_1, true
    //   br i1 %to_store, label %cond.store, label %else
    //
    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    //   %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //   %EltAddr = getelementptr i32* %1, i32 0
    //   store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    Builder.CreateAlignedStore(OneElt, Gep, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
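    // As in scalarizeMaskedLoad(): turn the fallthrough branch left by
    // splitBasicBlock() into a conditional branch on the current mask bit.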
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}

// Translate a masked gather intrinsic like
//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// into a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set:
//
//   %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//   %Mask0 = extractelement <16 x i1> %Mask, i32 0
//   %ToLoad0 = icmp eq i1 %Mask0, true
//   br i1 %ToLoad0, label %cond.load, label %else
//
// cond.load:
//   %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//   %Load0 = load i32, i32* %Ptr0, align 4
//   %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//   br label %else
//
// else:
//   %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//   %Mask1 = extractelement <16 x i1> %Mask, i32 1
//   %ToLoad1 = icmp eq i1 %Mask1, true
//   br i1 %ToLoad1, label %cond.load1, label %else2
//
// cond.load1:
//   %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//   %Load1 = load i32, i32* %Ptr1, align 4
//   %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//   br label %else2
//   . . .
//   %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
//   ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = dyn_cast<VectorType>(CI->getType());

  assert(VecType && "Unexpected return type of masked gather intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants: load only the lanes whose
  // mask bit is set and select against the pass-through value.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
      VResult = Builder.CreateInsertElement(
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    //   %Mask1 = extractelement <16 x i1> %Mask, i32 1
    //   %ToLoad1 = icmp eq i1 %Mask1, true
    //   br i1 %ToLoad1, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToLoad" + Twine(Idx));

    // Create "cond" block
    //
    //   %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 Idx
    //   %Load1 = load i32, i32* %Ptr1
    //   VResult = insertelement <16 x i32> VResult, i32 %Load1, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                          "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}

// Translate a masked scatter intrinsic, like
//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                    <16 x i1> %Mask)
// into a chain of basic blocks, storing the elements one by one if the
// appropriate mask bit is set.
//
//   %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//   %Mask0 = extractelement <16 x i1> %Mask, i32 0
//   %ToStore0 = icmp eq i1 %Mask0, true
//   br i1 %ToStore0, label %cond.store, label %else
//
// cond.store:
//   %Elt0 = extractelement <16 x i32> %Src, i32 0
//   %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//   store i32 %Elt0, i32* %Ptr0, align 4
//   br label %else
//
// else:
//   %Mask1 = extractelement <16 x i1> %Mask, i32 1
//   %ToStore1 = icmp eq i1 %Mask1, true
//   br i1 %ToStore1, label %cond.store1, label %else2
//
// cond.store1:
//   %Elt1 = extractelement <16 x i32> %Src, i32 1
//   %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//   store i32 %Elt1, i32* %Ptr1, align 4
//   br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Short-cut if the mask is a vector of constants: store only the lanes
  // whose mask bit is set.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //   %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    //   %ToStore = icmp eq i1 %Mask1, true
    //   br i1 %ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    //   %Elt1 = extractelement <16 x i32> %Src, i32 1
    //   %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //   store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}

bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

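  // Iterate to a fixed point: scalarizing an intrinsic may split its basic
  // block, so whenever that happens we restart the scan of the function.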
  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load
      if (!TTI->isLegalMaskedLoad(CI->getType())) {
        scalarizeMaskedLoad(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    case Intrinsic::masked_store: {
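      // Scalarize unsupported vector masked store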
      if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
        scalarizeMaskedStore(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    case Intrinsic::masked_gather: {
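      // Scalarize unsupported vector masked gather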
      if (!TTI->isLegalMaskedGather(CI->getType())) {
        scalarizeMaskedGather(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    case Intrinsic::masked_scatter: {
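      // Scalarize unsupported vector masked scatter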
      if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
        scalarizeMaskedScatter(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    }
  }

  return false;
}