blob: cef413f9d410bf92e4fd1bd05da5a420a42a71b4 [file] [log] [blame]
Eugene Zelenkofa57bd02017-09-27 23:26:01 +00001//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
// intrinsics
Ayman Musac5490e52017-05-15 11:30:54 +00003//
4// The LLVM Compiler Infrastructure
5//
6// This file is distributed under the University of Illinois Open Source
7// License. See LICENSE.TXT for details.
8//
9//===----------------------------------------------------------------------===//
10//
11// This pass replaces masked memory intrinsics - when unsupported by the target
12// - with a chain of basic blocks, that deal with the elements one-by-one if the
13// appropriate mask bit is set.
14//
15//===----------------------------------------------------------------------===//
16
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000017#include "llvm/ADT/Twine.h"
Ayman Musac5490e52017-05-15 11:30:54 +000018#include "llvm/Analysis/TargetTransformInfo.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000019#include "llvm/CodeGen/TargetSubtargetInfo.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000020#include "llvm/IR/BasicBlock.h"
21#include "llvm/IR/Constant.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
24#include "llvm/IR/Function.h"
Ayman Musac5490e52017-05-15 11:30:54 +000025#include "llvm/IR/IRBuilder.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000026#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
28#include "llvm/IR/Instructions.h"
Reid Kleckner0e8c4bb2017-09-07 23:27:44 +000029#include "llvm/IR/IntrinsicInst.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000030#include "llvm/IR/Intrinsics.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
33#include "llvm/Pass.h"
34#include "llvm/Support/Casting.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000035#include <algorithm>
36#include <cassert>
Ayman Musac5490e52017-05-15 11:30:54 +000037
38using namespace llvm;
39
40#define DEBUG_TYPE "scalarize-masked-mem-intrin"
41
42namespace {
43
// Legacy function pass that replaces masked load/store/gather/scatter
// intrinsics the target cannot lower natively (queried through
// TargetTransformInfo) with equivalent chains of conditional basic blocks.
class ScalarizeMaskedMemIntrin : public FunctionPass {
  // Set in runOnFunction from TargetTransformInfoWrapperPass; not owned.
  const TargetTransformInfo *TTI = nullptr;

public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  // Scalarizes eligible calls in BB; sets ModifiedDT when the CFG changed.
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  // Scalarizes CI if it is an unsupported masked memory intrinsic.
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000068
69} // end anonymous namespace
Ayman Musac5490e52017-05-15 11:30:54 +000070
char ScalarizeMaskedMemIntrin::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

// Public factory used by pass pipelines to construct this legacy pass.
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
79
80// Translate a masked load intrinsic like
81// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
82// <16 x i1> %mask, <16 x i32> %passthru)
83// to a chain of basic blocks, with loading element one-by-one if
84// the appropriate mask bit is set
85//
86// %1 = bitcast i8* %addr to i32*
87// %2 = extractelement <16 x i1> %mask, i32 0
88// %3 = icmp eq i1 %2, true
89// br i1 %3, label %cond.load, label %else
90//
91// cond.load: ; preds = %0
92// %4 = getelementptr i32* %1, i32 0
93// %5 = load i32* %4
94// %6 = insertelement <16 x i32> undef, i32 %5, i32 0
95// br label %else
96//
97// else: ; preds = %0, %cond.load
98// %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
99// %7 = extractelement <16 x i1> %mask, i32 1
100// %8 = icmp eq i1 %7, true
101// br i1 %8, label %cond.load1, label %else2
102//
103// cond.load1: ; preds = %else
104// %9 = getelementptr i32* %1, i32 1
105// %10 = load i32* %9
106// %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
107// br label %else2
108//
109// else2: ; preds = %else, %cond.load1
110// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
111// %12 = extractelement <16 x i1> %mask, i32 2
112// %13 = icmp eq i1 %12, true
113// br i1 %13, label %cond.load4, label %else5
114//
// Scalarize one @llvm.masked.load call (expansion shape shown in the block
// comment above).  CI is replaced and erased; callers must not touch it
// afterwards.
static void scalarizeMaskedLoad(CallInst *CI) {
  // Intrinsic operands: (ptr, alignment, mask, passthru).
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  // The alignment operand of the intrinsic is a constant by contract.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true: a single wide load is equivalent.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: a per-element load may not
  // claim more alignment than the element size beyond the base alignment.
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

  // If the mask is a constant vector, no control flow is needed: load just
  // the enabled lanes and blend with the passthru via a select.
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  // General case: one cond.load/else block pair per vector lane.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // %to_load = icmp eq i1 %mask_1, true
    // br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      // Merge the previous lane's loaded-or-skipped result.
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left IfBlock ending in an unconditional branch to
    // CondBlock; replace it with a conditional branch on this lane's bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge of the last lane's two incoming values, then blend with the
  // passthru operand for the disabled lanes.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
232
233// Translate a masked store intrinsic, like
234// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
235// <16 x i1> %mask)
236// to a chain of basic blocks, that stores element one-by-one if
237// the appropriate mask bit is set
238//
239// %1 = bitcast i8* %addr to i32*
240// %2 = extractelement <16 x i1> %mask, i32 0
241// %3 = icmp eq i1 %2, true
242// br i1 %3, label %cond.store, label %else
243//
244// cond.store: ; preds = %0
245// %4 = extractelement <16 x i32> %val, i32 0
246// %5 = getelementptr i32* %1, i32 0
247// store i32 %4, i32* %5
248// br label %else
249//
250// else: ; preds = %0, %cond.store
251// %6 = extractelement <16 x i1> %mask, i32 1
252// %7 = icmp eq i1 %6, true
253// br i1 %7, label %cond.store1, label %else2
254//
255// cond.store1: ; preds = %else
256// %8 = extractelement <16 x i32> %val, i32 1
257// %9 = getelementptr i32* %1, i32 1
258// store i32 %8, i32* %9
259// br label %else2
260// . . .
261static void scalarizeMaskedStore(CallInst *CI) {
262 Value *Src = CI->getArgOperand(0);
263 Value *Ptr = CI->getArgOperand(1);
264 Value *Alignment = CI->getArgOperand(2);
265 Value *Mask = CI->getArgOperand(3);
266
267 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
268 VectorType *VecType = dyn_cast<VectorType>(Src->getType());
269 assert(VecType && "Unexpected data type in masked store intrinsic");
270
271 Type *EltTy = VecType->getElementType();
272
273 IRBuilder<> Builder(CI->getContext());
274 Instruction *InsertPt = CI;
275 BasicBlock *IfBlock = CI->getParent();
276 Builder.SetInsertPoint(InsertPt);
277 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
278
279 // Short-cut if the mask is all-true.
280 bool IsAllOnesMask =
281 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
282
283 if (IsAllOnesMask) {
284 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
285 CI->eraseFromParent();
286 return;
287 }
288
289 // Adjust alignment for the scalar instruction.
290 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
291 // Bitcast %addr fron i8* to EltTy*
292 Type *NewPtrType =
293 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
294 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
295 unsigned VectorWidth = VecType->getNumElements();
296
297 if (isa<ConstantVector>(Mask)) {
298 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
299 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
300 continue;
301 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
302 Value *Gep =
303 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
304 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
305 }
306 CI->eraseFromParent();
307 return;
308 }
309
310 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000311 // Fill the "else" block, created in the previous iteration
312 //
313 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
314 // %to_store = icmp eq i1 %mask_1, true
315 // br i1 %to_store, label %cond.store, label %else
316 //
317 Value *Predicate =
318 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
319 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
320 ConstantInt::get(Predicate->getType(), 1));
321
322 // Create "cond" block
323 //
324 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
325 // %EltAddr = getelementptr i32* %1, i32 0
326 // %store i32 %OneElt, i32* %EltAddr
327 //
328 BasicBlock *CondBlock =
329 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
330 Builder.SetInsertPoint(InsertPt);
331
332 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
333 Value *Gep =
334 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
335 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
336
337 // Create "else" block, fill it in the next iteration
338 BasicBlock *NewIfBlock =
339 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
340 Builder.SetInsertPoint(InsertPt);
341 Instruction *OldBr = IfBlock->getTerminator();
342 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
343 OldBr->eraseFromParent();
344 IfBlock = NewIfBlock;
345 }
346 CI->eraseFromParent();
347}
348
349// Translate a masked gather intrinsic like
350// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
351// <16 x i1> %Mask, <16 x i32> %Src)
352// to a chain of basic blocks, with loading element one-by-one if
353// the appropriate mask bit is set
354//
355// % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
356// % Mask0 = extractelement <16 x i1> %Mask, i32 0
357// % ToLoad0 = icmp eq i1 % Mask0, true
358// br i1 % ToLoad0, label %cond.load, label %else
359//
360// cond.load:
361// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
362// % Load0 = load i32, i32* % Ptr0, align 4
363// % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
364// br label %else
365//
366// else:
367// %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
368// % Mask1 = extractelement <16 x i1> %Mask, i32 1
369// % ToLoad1 = icmp eq i1 % Mask1, true
370// br i1 % ToLoad1, label %cond.load1, label %else2
371//
372// cond.load1:
373// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
374// % Load1 = load i32, i32* % Ptr1, align 4
375// % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
376// br label %else2
377// . . .
378// % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
379// ret <16 x i32> %Result
// Scalarize one @llvm.masked.gather call (expansion shape shown in the block
// comment above).  CI is replaced and erased; callers must not touch it
// afterwards.
static void scalarizeMaskedGather(CallInst *CI) {
  // Intrinsic operands: (vector-of-pointers, alignment, mask, passthru).
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  VectorType *VecType = dyn_cast<VectorType>(CI->getType());

  assert(VecType && "Unexpected return type of masked load intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  // For gather, the alignment operand applies to each scalar lane access.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  // Constant mask: load just the enabled lanes through their pointers, with
  // no control flow, and blend with the passthru via a select.
  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
      VResult = Builder.CreateInsertElement(
          VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  // General case: one cond.load/else block pair per vector lane.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = extractelement <16 x i1> %Mask, i32 1
    // %ToLoad1 = icmp eq i1 %Mask1, true
    // br i1 %ToLoad1, label %cond.load, label %else
    //
    if (Idx > 0) {
      // Merge the previous lane's loaded-or-skipped result.
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToLoad" + Twine(Idx));

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
                                          "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left IfBlock ending in an unconditional branch to
    // CondBlock; replace it with a conditional branch on this lane's bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge of the last lane's two incoming values, then blend with the
  // passthru operand for the disabled lanes.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
483
484// Translate a masked scatter intrinsic, like
485// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
486// <16 x i1> %Mask)
487// to a chain of basic blocks, that stores element one-by-one if
488// the appropriate mask bit is set.
489//
490// % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
491// % Mask0 = extractelement <16 x i1> % Mask, i32 0
492// % ToStore0 = icmp eq i1 % Mask0, true
493// br i1 %ToStore0, label %cond.store, label %else
494//
495// cond.store:
496// % Elt0 = extractelement <16 x i32> %Src, i32 0
497// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
498// store i32 %Elt0, i32* % Ptr0, align 4
499// br label %else
500//
501// else:
502// % Mask1 = extractelement <16 x i1> % Mask, i32 1
503// % ToStore1 = icmp eq i1 % Mask1, true
504// br i1 % ToStore1, label %cond.store1, label %else2
505//
506// cond.store1:
507// % Elt1 = extractelement <16 x i32> %Src, i32 1
508// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
509// store i32 % Elt1, i32* % Ptr1, align 4
510// br label %else2
511// . . .
// Scalarize one @llvm.masked.scatter call (expansion shape shown in the block
// comment above) into a chain of conditional scalar stores through the
// per-lane pointers.  CI is erased; callers must not touch it afterwards.
static void scalarizeMaskedScatter(CallInst *CI) {
  // Intrinsic operands: (value, vector-of-pointers, alignment, mask).
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // For scatter, the alignment operand applies to each scalar lane access.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  // Constant mask: store just the enabled lanes, with no control flow.
  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  // General case: one cond.store/else block pair per vector lane.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
    // % ToStore = icmp eq i1 % Mask1, true
    // br i1 % ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    // % Elt1 = extractelement <16 x i32> %Src, i32 1
    // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // %store i32 % Elt1, i32* % Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left IfBlock ending in an unconditional branch to
    // CondBlock; replace it with a conditional branch on this lane's bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}
587
588bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
589 if (skipFunction(F))
590 return false;
591
592 bool EverMadeChange = false;
593
594 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
595
596 bool MadeChange = true;
597 while (MadeChange) {
598 MadeChange = false;
599 for (Function::iterator I = F.begin(); I != F.end();) {
600 BasicBlock *BB = &*I++;
601 bool ModifiedDTOnIteration = false;
602 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
603
604 // Restart BB iteration if the dominator tree of the Function was changed
605 if (ModifiedDTOnIteration)
606 break;
607 }
608
609 EverMadeChange |= MadeChange;
610 }
611
612 return EverMadeChange;
613}
614
615bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
616 bool MadeChange = false;
617
618 BasicBlock::iterator CurInstIterator = BB.begin();
619 while (CurInstIterator != BB.end()) {
620 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
621 MadeChange |= optimizeCallInst(CI, ModifiedDT);
622 if (ModifiedDT)
623 return true;
624 }
625
626 return MadeChange;
627}
628
629bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
630 bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000631 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
632 if (II) {
633 switch (II->getIntrinsicID()) {
634 default:
635 break;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000636 case Intrinsic::masked_load:
Ayman Musac5490e52017-05-15 11:30:54 +0000637 // Scalarize unsupported vector masked load
638 if (!TTI->isLegalMaskedLoad(CI->getType())) {
639 scalarizeMaskedLoad(CI);
640 ModifiedDT = true;
641 return true;
642 }
643 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000644 case Intrinsic::masked_store:
Ayman Musac5490e52017-05-15 11:30:54 +0000645 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
646 scalarizeMaskedStore(CI);
647 ModifiedDT = true;
648 return true;
649 }
650 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000651 case Intrinsic::masked_gather:
Ayman Musac5490e52017-05-15 11:30:54 +0000652 if (!TTI->isLegalMaskedGather(CI->getType())) {
653 scalarizeMaskedGather(CI);
654 ModifiedDT = true;
655 return true;
656 }
657 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000658 case Intrinsic::masked_scatter:
Ayman Musac5490e52017-05-15 11:30:54 +0000659 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
660 scalarizeMaskedScatter(CI);
661 ModifiedDT = true;
662 return true;
663 }
664 return false;
665 }
Ayman Musac5490e52017-05-15 11:30:54 +0000666 }
667
668 return false;
669}