//===-- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ---===//
//===--                              intrinsics                          --===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks, that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//
16
17#include "llvm/Analysis/TargetTransformInfo.h"
18#include "llvm/IR/IRBuilder.h"
19#include "llvm/Target/TargetSubtargetInfo.h"
20
21using namespace llvm;
22
23#define DEBUG_TYPE "scalarize-masked-mem-intrin"
24
namespace {

// Legacy FunctionPass that rewrites masked load/store/gather/scatter
// intrinsics that the target cannot lower natively (queried through
// TargetTransformInfo) into explicit per-element control flow.
class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI; // Target legality oracle; set per function.

public:
  static char ID; // Pass identification, replacement for typeid
  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID), TTI(nullptr) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }
  // Scalarizes every unsupported masked memory intrinsic in F.
  // Returns true if the function was modified.
  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  // Visits each instruction of BB; ModifiedDT is set when scalarization
  // split the block (caller must restart its iteration).
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  // Scalarizes CI if it is a masked memory intrinsic the target rejects.
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};
} // namespace
50
char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

// Factory used by the legacy pass manager to instantiate this pass.
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
58
59// Translate a masked load intrinsic like
60// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
61// <16 x i1> %mask, <16 x i32> %passthru)
62// to a chain of basic blocks, with loading element one-by-one if
63// the appropriate mask bit is set
64//
65// %1 = bitcast i8* %addr to i32*
66// %2 = extractelement <16 x i1> %mask, i32 0
67// %3 = icmp eq i1 %2, true
68// br i1 %3, label %cond.load, label %else
69//
70// cond.load: ; preds = %0
71// %4 = getelementptr i32* %1, i32 0
72// %5 = load i32* %4
73// %6 = insertelement <16 x i32> undef, i32 %5, i32 0
74// br label %else
75//
76// else: ; preds = %0, %cond.load
77// %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
78// %7 = extractelement <16 x i1> %mask, i32 1
79// %8 = icmp eq i1 %7, true
80// br i1 %8, label %cond.load1, label %else2
81//
82// cond.load1: ; preds = %else
83// %9 = getelementptr i32* %1, i32 1
84// %10 = load i32* %9
85// %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
86// br label %else2
87//
88// else2: ; preds = %else, %cond.load1
89// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
90// %12 = extractelement <16 x i1> %mask, i32 2
91// %13 = icmp eq i1 %12, true
92// br i1 %13, label %cond.load4, label %else5
93//
static void scalarizeMaskedLoad(CallInst *CI) {
  // Intrinsic operands: (pointer, alignment, mask, passthru).
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true: a plain vector load is equivalent.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: an element access can only
  // be assumed aligned to min(vector alignment, element size).
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

  // If the mask is a compile-time constant vector, no control flow is needed:
  // load exactly the enabled lanes and select against the passthru.
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // General case: emit one cond.load/else block pair per lane, threading the
  // partially-built result vector through PHI nodes.
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // %to_load = icmp eq i1 %mask_1, true
    // br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch; replace it with a branch
    // conditioned on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Merge the final lane's result and select against the passthru value so
  // disabled lanes read from Src0.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
212
213// Translate a masked store intrinsic, like
214// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
215// <16 x i1> %mask)
216// to a chain of basic blocks, that stores element one-by-one if
217// the appropriate mask bit is set
218//
219// %1 = bitcast i8* %addr to i32*
220// %2 = extractelement <16 x i1> %mask, i32 0
221// %3 = icmp eq i1 %2, true
222// br i1 %3, label %cond.store, label %else
223//
224// cond.store: ; preds = %0
225// %4 = extractelement <16 x i32> %val, i32 0
226// %5 = getelementptr i32* %1, i32 0
227// store i32 %4, i32* %5
228// br label %else
229//
230// else: ; preds = %0, %cond.store
231// %6 = extractelement <16 x i1> %mask, i32 1
232// %7 = icmp eq i1 %6, true
233// br i1 %7, label %cond.store1, label %else2
234//
235// cond.store1: ; preds = %else
236// %8 = extractelement <16 x i32> %val, i32 1
237// %9 = getelementptr i32* %1, i32 1
238// store i32 %8, i32* %9
239// br label %else2
240// . . .
241static void scalarizeMaskedStore(CallInst *CI) {
242 Value *Src = CI->getArgOperand(0);
243 Value *Ptr = CI->getArgOperand(1);
244 Value *Alignment = CI->getArgOperand(2);
245 Value *Mask = CI->getArgOperand(3);
246
247 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
248 VectorType *VecType = dyn_cast<VectorType>(Src->getType());
249 assert(VecType && "Unexpected data type in masked store intrinsic");
250
251 Type *EltTy = VecType->getElementType();
252
253 IRBuilder<> Builder(CI->getContext());
254 Instruction *InsertPt = CI;
255 BasicBlock *IfBlock = CI->getParent();
256 Builder.SetInsertPoint(InsertPt);
257 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
258
259 // Short-cut if the mask is all-true.
260 bool IsAllOnesMask =
261 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
262
263 if (IsAllOnesMask) {
264 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
265 CI->eraseFromParent();
266 return;
267 }
268
269 // Adjust alignment for the scalar instruction.
270 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
271 // Bitcast %addr fron i8* to EltTy*
272 Type *NewPtrType =
273 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
274 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
275 unsigned VectorWidth = VecType->getNumElements();
276
277 if (isa<ConstantVector>(Mask)) {
278 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
279 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
280 continue;
281 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
282 Value *Gep =
283 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
284 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
285 }
286 CI->eraseFromParent();
287 return;
288 }
289
290 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
291
292 // Fill the "else" block, created in the previous iteration
293 //
294 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
295 // %to_store = icmp eq i1 %mask_1, true
296 // br i1 %to_store, label %cond.store, label %else
297 //
298 Value *Predicate =
299 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
300 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
301 ConstantInt::get(Predicate->getType(), 1));
302
303 // Create "cond" block
304 //
305 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
306 // %EltAddr = getelementptr i32* %1, i32 0
307 // %store i32 %OneElt, i32* %EltAddr
308 //
309 BasicBlock *CondBlock =
310 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
311 Builder.SetInsertPoint(InsertPt);
312
313 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
314 Value *Gep =
315 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
316 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
317
318 // Create "else" block, fill it in the next iteration
319 BasicBlock *NewIfBlock =
320 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
321 Builder.SetInsertPoint(InsertPt);
322 Instruction *OldBr = IfBlock->getTerminator();
323 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
324 OldBr->eraseFromParent();
325 IfBlock = NewIfBlock;
326 }
327 CI->eraseFromParent();
328}
329
330// Translate a masked gather intrinsic like
331// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
332// <16 x i1> %Mask, <16 x i32> %Src)
333// to a chain of basic blocks, with loading element one-by-one if
334// the appropriate mask bit is set
335//
336// % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
337// % Mask0 = extractelement <16 x i1> %Mask, i32 0
338// % ToLoad0 = icmp eq i1 % Mask0, true
339// br i1 % ToLoad0, label %cond.load, label %else
340//
341// cond.load:
342// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
343// % Load0 = load i32, i32* % Ptr0, align 4
344// % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
345// br label %else
346//
347// else:
348// %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
349// % Mask1 = extractelement <16 x i1> %Mask, i32 1
350// % ToLoad1 = icmp eq i1 % Mask1, true
351// br i1 % ToLoad1, label %cond.load1, label %else2
352//
353// cond.load1:
354// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
355// % Load1 = load i32, i32* % Ptr1, align 4
356// % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
357// br label %else2
358// . . .
359// % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
360// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI) {
  // Intrinsic operands: (vector-of-pointers, alignment, mask, passthru).
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = dyn_cast<VectorType>(CI->getType());

  assert(VecType && "Unexpected return type of masked load intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants: no control flow is
  // needed, just load each enabled lane and select against the passthru.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
      VResult = Builder.CreateInsertElement(
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // General case: emit one cond.load/else block pair per lane, threading the
  // partially-built result vector through PHI nodes.
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    // %ToLoad1 = icmp eq i1 %Mask1, true
    // br i1 %ToLoad1, label %cond.load, label %else
    //
    if (Idx > 0) {
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToLoad" + Twine(Idx));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                          "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch from splitBasicBlock with a branch
    // conditioned on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Merge the final lane's result and select against the passthru value so
  // disabled lanes read from Src0.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
465
466// Translate a masked scatter intrinsic, like
467// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
468// <16 x i1> %Mask)
469// to a chain of basic blocks, that stores element one-by-one if
470// the appropriate mask bit is set.
471//
472// % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
473// % Mask0 = extractelement <16 x i1> % Mask, i32 0
474// % ToStore0 = icmp eq i1 % Mask0, true
475// br i1 %ToStore0, label %cond.store, label %else
476//
477// cond.store:
478// % Elt0 = extractelement <16 x i32> %Src, i32 0
479// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
480// store i32 %Elt0, i32* % Ptr0, align 4
481// br label %else
482//
483// else:
484// % Mask1 = extractelement <16 x i1> % Mask, i32 1
485// % ToStore1 = icmp eq i1 % Mask1, true
486// br i1 % ToStore1, label %cond.store1, label %else2
487//
488// cond.store1:
489// % Elt1 = extractelement <16 x i32> %Src, i32 1
490// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
491// store i32 % Elt1, i32* % Ptr1, align 4
492// br label %else2
493// . . .
static void scalarizeMaskedScatter(CallInst *CI) {
  // Intrinsic operands: (value, vector-of-pointers, alignment, mask).
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants: no control flow is
  // needed, just store each enabled lane straight-line.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  // General case: emit one cond.store/else block pair per lane.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
    // % ToStore = icmp eq i1 % Mask1, true
    // br i1 % ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    // % Elt1 = extractelement <16 x i32> %Src, i32 1
    // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // %store i32 % Elt1, i32* % Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch from splitBasicBlock with a branch
    // conditioned on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}
569
570bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
571 if (skipFunction(F))
572 return false;
573
574 bool EverMadeChange = false;
575
576 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
577
578 bool MadeChange = true;
579 while (MadeChange) {
580 MadeChange = false;
581 for (Function::iterator I = F.begin(); I != F.end();) {
582 BasicBlock *BB = &*I++;
583 bool ModifiedDTOnIteration = false;
584 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
585
586 // Restart BB iteration if the dominator tree of the Function was changed
587 if (ModifiedDTOnIteration)
588 break;
589 }
590
591 EverMadeChange |= MadeChange;
592 }
593
594 return EverMadeChange;
595}
596
597bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
598 bool MadeChange = false;
599
600 BasicBlock::iterator CurInstIterator = BB.begin();
601 while (CurInstIterator != BB.end()) {
602 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
603 MadeChange |= optimizeCallInst(CI, ModifiedDT);
604 if (ModifiedDT)
605 return true;
606 }
607
608 return MadeChange;
609}
610
611bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
612 bool &ModifiedDT) {
613
614 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
615 if (II) {
616 switch (II->getIntrinsicID()) {
617 default:
618 break;
619 case Intrinsic::masked_load: {
620 // Scalarize unsupported vector masked load
621 if (!TTI->isLegalMaskedLoad(CI->getType())) {
622 scalarizeMaskedLoad(CI);
623 ModifiedDT = true;
624 return true;
625 }
626 return false;
627 }
628 case Intrinsic::masked_store: {
629 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
630 scalarizeMaskedStore(CI);
631 ModifiedDT = true;
632 return true;
633 }
634 return false;
635 }
636 case Intrinsic::masked_gather: {
637 if (!TTI->isLegalMaskedGather(CI->getType())) {
638 scalarizeMaskedGather(CI);
639 ModifiedDT = true;
640 return true;
641 }
642 return false;
643 }
644 case Intrinsic::masked_scatter: {
645 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
646 scalarizeMaskedScatter(CI);
647 ModifiedDT = true;
648 return true;
649 }
650 return false;
651 }
652 }
653 }
654
655 return false;
656}