//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//===                              intrinsics                             ===//
3//
4// The LLVM Compiler Infrastructure
5//
6// This file is distributed under the University of Illinois Open Source
7// License. See LICENSE.TXT for details.
8//
9//===----------------------------------------------------------------------===//
10//
11// This pass replaces masked memory intrinsics - when unsupported by the target
12// - with a chain of basic blocks, that deal with the elements one-by-one if the
13// appropriate mask bit is set.
14//
15//===----------------------------------------------------------------------===//
16
17#include "llvm/Analysis/TargetTransformInfo.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/Target/TargetSubtargetInfo.h"
20
21using namespace llvm;
22
23#define DEBUG_TYPE "scalarize-masked-mem-intrin"
24
25namespace {
26
27class ScalarizeMaskedMemIntrin : public FunctionPass {
28 const TargetTransformInfo *TTI;
29
30public:
31 static char ID; // Pass identification, replacement for typeid
32 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID), TTI(nullptr) {
33 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
34 }
35 bool runOnFunction(Function &F) override;
36
37 StringRef getPassName() const override {
38 return "Scalarize Masked Memory Intrinsics";
39 }
40
41 void getAnalysisUsage(AnalysisUsage &AU) const override {
42 AU.addRequired<TargetTransformInfoWrapperPass>();
43 }
44
45private:
46 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
47 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
48};
49} // namespace
50
// Out-of-line definition of the pass's unique identifier; its address is
// what the legacy pass manager actually keys on.
char ScalarizeMaskedMemIntrin::ID = 0;

// Register the pass with the legacy pass manager.  Nothing appears between
// BEGIN and END because the pass has no hard pass dependencies to record
// (TargetTransformInfoWrapperPass is only requested via getAnalysisUsage).
INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrin, "scalarize-masked-mem-intrin",
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrin, "scalarize-masked-mem-intrin",
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)
58
59FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
60 return new ScalarizeMaskedMemIntrin();
61}
62
63// Translate a masked load intrinsic like
64// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
65// <16 x i1> %mask, <16 x i32> %passthru)
66// to a chain of basic blocks, with loading element one-by-one if
67// the appropriate mask bit is set
68//
69// %1 = bitcast i8* %addr to i32*
70// %2 = extractelement <16 x i1> %mask, i32 0
71// %3 = icmp eq i1 %2, true
72// br i1 %3, label %cond.load, label %else
73//
74// cond.load: ; preds = %0
75// %4 = getelementptr i32* %1, i32 0
76// %5 = load i32* %4
77// %6 = insertelement <16 x i32> undef, i32 %5, i32 0
78// br label %else
79//
80// else: ; preds = %0, %cond.load
81// %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
82// %7 = extractelement <16 x i1> %mask, i32 1
83// %8 = icmp eq i1 %7, true
84// br i1 %8, label %cond.load1, label %else2
85//
86// cond.load1: ; preds = %else
87// %9 = getelementptr i32* %1, i32 1
88// %10 = load i32* %9
89// %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
90// br label %else2
91//
92// else2: ; preds = %else, %cond.load1
93// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
94// %12 = extractelement <16 x i1> %mask, i32 2
95// %13 = icmp eq i1 %12, true
96// br i1 %13, label %cond.load4, label %else5
97//
// Expand one @llvm.masked.load call at its current position.  Either folds
// it away entirely (constant mask) or splits the containing block into a
// per-element cond.load/else chain; in all cases the original call is
// erased.
static void scalarizeMaskedLoad(CallInst *CI) {
  // Operands: base pointer, immediate alignment, i1 mask vector, passthru.
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true: the masked load degenerates to a
  // plain vector load and the passthru value is dead.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: each element load may only
  // claim min(vector alignment, element size).
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy* so element GEPs index in elements.
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector, built up one inserted element at a time.
  Value *VResult = UndefVal;

  // Constant (but not all-ones) mask: no control flow needed — emit an
  // unconditional load for every enabled lane and select against Src0.
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // General case: build the cond.load/else chain sketched in the comment
  // above this function.  Phi/PrevPhi thread the partial result vector
  // through each "else" block.
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // %to_load = icmp eq i1 %mask_1, true
    // br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch; replace it with the
    // conditional branch on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge: phi the last partial result, then select against the
  // passthru so disabled lanes read from Src0.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
216
217// Translate a masked store intrinsic, like
218// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
219// <16 x i1> %mask)
220// to a chain of basic blocks, that stores element one-by-one if
221// the appropriate mask bit is set
222//
223// %1 = bitcast i8* %addr to i32*
224// %2 = extractelement <16 x i1> %mask, i32 0
225// %3 = icmp eq i1 %2, true
226// br i1 %3, label %cond.store, label %else
227//
228// cond.store: ; preds = %0
229// %4 = extractelement <16 x i32> %val, i32 0
230// %5 = getelementptr i32* %1, i32 0
231// store i32 %4, i32* %5
232// br label %else
233//
234// else: ; preds = %0, %cond.store
235// %6 = extractelement <16 x i1> %mask, i32 1
236// %7 = icmp eq i1 %6, true
237// br i1 %7, label %cond.store1, label %else2
238//
239// cond.store1: ; preds = %else
240// %8 = extractelement <16 x i32> %val, i32 1
241// %9 = getelementptr i32* %1, i32 1
242// store i32 %8, i32* %9
243// br label %else2
244// . . .
245static void scalarizeMaskedStore(CallInst *CI) {
246 Value *Src = CI->getArgOperand(0);
247 Value *Ptr = CI->getArgOperand(1);
248 Value *Alignment = CI->getArgOperand(2);
249 Value *Mask = CI->getArgOperand(3);
250
251 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
252 VectorType *VecType = dyn_cast<VectorType>(Src->getType());
253 assert(VecType && "Unexpected data type in masked store intrinsic");
254
255 Type *EltTy = VecType->getElementType();
256
257 IRBuilder<> Builder(CI->getContext());
258 Instruction *InsertPt = CI;
259 BasicBlock *IfBlock = CI->getParent();
260 Builder.SetInsertPoint(InsertPt);
261 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
262
263 // Short-cut if the mask is all-true.
264 bool IsAllOnesMask =
265 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
266
267 if (IsAllOnesMask) {
268 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
269 CI->eraseFromParent();
270 return;
271 }
272
273 // Adjust alignment for the scalar instruction.
274 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
275 // Bitcast %addr fron i8* to EltTy*
276 Type *NewPtrType =
277 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
278 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
279 unsigned VectorWidth = VecType->getNumElements();
280
281 if (isa<ConstantVector>(Mask)) {
282 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
283 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
284 continue;
285 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
286 Value *Gep =
287 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
288 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
289 }
290 CI->eraseFromParent();
291 return;
292 }
293
294 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
295
296 // Fill the "else" block, created in the previous iteration
297 //
298 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
299 // %to_store = icmp eq i1 %mask_1, true
300 // br i1 %to_store, label %cond.store, label %else
301 //
302 Value *Predicate =
303 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
304 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
305 ConstantInt::get(Predicate->getType(), 1));
306
307 // Create "cond" block
308 //
309 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
310 // %EltAddr = getelementptr i32* %1, i32 0
311 // %store i32 %OneElt, i32* %EltAddr
312 //
313 BasicBlock *CondBlock =
314 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
315 Builder.SetInsertPoint(InsertPt);
316
317 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
318 Value *Gep =
319 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
320 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
321
322 // Create "else" block, fill it in the next iteration
323 BasicBlock *NewIfBlock =
324 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
325 Builder.SetInsertPoint(InsertPt);
326 Instruction *OldBr = IfBlock->getTerminator();
327 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
328 OldBr->eraseFromParent();
329 IfBlock = NewIfBlock;
330 }
331 CI->eraseFromParent();
332}
333
334// Translate a masked gather intrinsic like
335// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
336// <16 x i1> %Mask, <16 x i32> %Src)
337// to a chain of basic blocks, with loading element one-by-one if
338// the appropriate mask bit is set
339//
340// % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
341// % Mask0 = extractelement <16 x i1> %Mask, i32 0
342// % ToLoad0 = icmp eq i1 % Mask0, true
343// br i1 % ToLoad0, label %cond.load, label %else
344//
345// cond.load:
346// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
347// % Load0 = load i32, i32* % Ptr0, align 4
348// % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
349// br label %else
350//
351// else:
352// %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
353// % Mask1 = extractelement <16 x i1> %Mask, i32 1
354// % ToLoad1 = icmp eq i1 % Mask1, true
355// br i1 % ToLoad1, label %cond.load1, label %else2
356//
357// cond.load1:
358// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
359// % Load1 = load i32, i32* % Ptr1, align 4
360// % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
361// br label %else2
362// . . .
363// % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
364// ret <16 x i32> %Result
// Expand one @llvm.masked.gather call at its current position.  Either
// folds it to unconditional per-lane loads (constant mask) or splits the
// containing block into a cond.load/else chain; the original call is erased
// in all cases.
static void scalarizeMaskedGather(CallInst *CI) {
  // Operands: vector of pointers, immediate alignment, i1 mask, passthru.
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = dyn_cast<VectorType>(CI->getType());

  assert(VecType && "Unexpected return type of masked load intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  // For a gather the intrinsic's alignment applies to each element access,
  // so it is used unmodified on the scalar loads below.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector, built up one inserted element at a time.
  Value *VResult = UndefVal;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    // No control flow needed: load unconditionally from every enabled
    // lane's pointer, then select against the passthru.
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
      VResult = Builder.CreateInsertElement(
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // General case: build the cond.load/else chain sketched in the comment
  // above this function.  Phi/PrevPhi thread the partial result through
  // each "else" block.
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    // %ToLoad1 = icmp eq i1 %Mask1, true
    // br i1 %ToLoad1, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToLoad" + Twine(Idx));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                          "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch left by splitBasicBlock with the
    // conditional branch on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge: phi the last partial result, then select against the
  // passthru so disabled lanes read from Src0.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
469
470// Translate a masked scatter intrinsic, like
471// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
472// <16 x i1> %Mask)
473// to a chain of basic blocks, that stores element one-by-one if
474// the appropriate mask bit is set.
475//
476// % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
477// % Mask0 = extractelement <16 x i1> % Mask, i32 0
478// % ToStore0 = icmp eq i1 % Mask0, true
479// br i1 %ToStore0, label %cond.store, label %else
480//
481// cond.store:
482// % Elt0 = extractelement <16 x i32> %Src, i32 0
483// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
484// store i32 %Elt0, i32* % Ptr0, align 4
485// br label %else
486//
487// else:
488// % Mask1 = extractelement <16 x i1> % Mask, i32 1
489// % ToStore1 = icmp eq i1 % Mask1, true
490// br i1 % ToStore1, label %cond.store1, label %else2
491//
492// cond.store1:
493// % Elt1 = extractelement <16 x i32> %Src, i32 1
494// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
495// store i32 % Elt1, i32* % Ptr1, align 4
496// br label %else2
497// . . .
// Expand one @llvm.masked.scatter call at its current position.  Either
// folds it to unconditional per-lane stores (constant mask) or splits the
// containing block into a cond.store/else chain; the original call is
// erased in all cases.
static void scalarizeMaskedScatter(CallInst *CI) {
  // Operands: vector of values, vector of pointers, immediate alignment,
  // i1 mask vector.
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // For a scatter the intrinsic's alignment applies to each element access,
  // so it is used unmodified on the scalar stores below.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    // No control flow needed: store unconditionally to every enabled
    // lane's pointer.
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  // General case: build the cond.store/else chain sketched in the comment
  // above this function.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
    // % ToStore = icmp eq i1 % Mask1, true
    // br i1 % ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    // % Elt1 = extractelement <16 x i32> %Src, i32 1
    // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // %store i32 % Elt1, i32* % Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch left by splitBasicBlock with the
    // conditional branch on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}
573
574bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
575 if (skipFunction(F))
576 return false;
577
578 bool EverMadeChange = false;
579
580 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
581
582 bool MadeChange = true;
583 while (MadeChange) {
584 MadeChange = false;
585 for (Function::iterator I = F.begin(); I != F.end();) {
586 BasicBlock *BB = &*I++;
587 bool ModifiedDTOnIteration = false;
588 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
589
590 // Restart BB iteration if the dominator tree of the Function was changed
591 if (ModifiedDTOnIteration)
592 break;
593 }
594
595 EverMadeChange |= MadeChange;
596 }
597
598 return EverMadeChange;
599}
600
601bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
602 bool MadeChange = false;
603
604 BasicBlock::iterator CurInstIterator = BB.begin();
605 while (CurInstIterator != BB.end()) {
606 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
607 MadeChange |= optimizeCallInst(CI, ModifiedDT);
608 if (ModifiedDT)
609 return true;
610 }
611
612 return MadeChange;
613}
614
615bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
616 bool &ModifiedDT) {
617
618 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
619 if (II) {
620 switch (II->getIntrinsicID()) {
621 default:
622 break;
623 case Intrinsic::masked_load: {
624 // Scalarize unsupported vector masked load
625 if (!TTI->isLegalMaskedLoad(CI->getType())) {
626 scalarizeMaskedLoad(CI);
627 ModifiedDT = true;
628 return true;
629 }
630 return false;
631 }
632 case Intrinsic::masked_store: {
633 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
634 scalarizeMaskedStore(CI);
635 ModifiedDT = true;
636 return true;
637 }
638 return false;
639 }
640 case Intrinsic::masked_gather: {
641 if (!TTI->isLegalMaskedGather(CI->getType())) {
642 scalarizeMaskedGather(CI);
643 ModifiedDT = true;
644 return true;
645 }
646 return false;
647 }
648 case Intrinsic::masked_scatter: {
649 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
650 scalarizeMaskedScatter(CI);
651 ModifiedDT = true;
652 return true;
653 }
654 return false;
655 }
656 }
657 }
658
659 return false;
660}