blob: 30659a4df47bbd900aceb164ff74119a3b7fbcb9 [file] [log] [blame]
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
Ayman Musac5490e52017-05-15 11:30:54 +00003//
4// The LLVM Compiler Infrastructure
5//
6// This file is distributed under the University of Illinois Open Source
7// License. See LICENSE.TXT for details.
8//
9//===----------------------------------------------------------------------===//
10//
11// This pass replaces masked memory intrinsics - when unsupported by the target
12// - with a chain of basic blocks, that deal with the elements one-by-one if the
13// appropriate mask bit is set.
14//
15//===----------------------------------------------------------------------===//
16
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000017#include "llvm/ADT/Twine.h"
Ayman Musac5490e52017-05-15 11:30:54 +000018#include "llvm/Analysis/TargetTransformInfo.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000019#include "llvm/CodeGen/TargetSubtargetInfo.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000020#include "llvm/IR/BasicBlock.h"
21#include "llvm/IR/Constant.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
24#include "llvm/IR/Function.h"
Ayman Musac5490e52017-05-15 11:30:54 +000025#include "llvm/IR/IRBuilder.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000026#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
28#include "llvm/IR/Instructions.h"
Reid Kleckner0e8c4bb2017-09-07 23:27:44 +000029#include "llvm/IR/IntrinsicInst.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000030#include "llvm/IR/Intrinsics.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
33#include "llvm/Pass.h"
34#include "llvm/Support/Casting.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000035#include <algorithm>
36#include <cassert>
Ayman Musac5490e52017-05-15 11:30:54 +000037
38using namespace llvm;
39
40#define DEBUG_TYPE "scalarize-masked-mem-intrin"
41
42namespace {
43
44class ScalarizeMaskedMemIntrin : public FunctionPass {
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000045 const TargetTransformInfo *TTI = nullptr;
Ayman Musac5490e52017-05-15 11:30:54 +000046
47public:
48 static char ID; // Pass identification, replacement for typeid
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000049
50 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
Ayman Musac5490e52017-05-15 11:30:54 +000051 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
52 }
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000053
Ayman Musac5490e52017-05-15 11:30:54 +000054 bool runOnFunction(Function &F) override;
55
56 StringRef getPassName() const override {
57 return "Scalarize Masked Memory Intrinsics";
58 }
59
60 void getAnalysisUsage(AnalysisUsage &AU) const override {
61 AU.addRequired<TargetTransformInfoWrapperPass>();
62 }
63
64private:
65 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67};
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000068
69} // end anonymous namespace
Ayman Musac5490e52017-05-15 11:30:54 +000070
71char ScalarizeMaskedMemIntrin::ID = 0;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000072
Matthias Braun1527baa2017-05-25 21:26:32 +000073INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74 "Scalarize unsupported masked memory intrinsics", false, false)
Ayman Musac5490e52017-05-15 11:30:54 +000075
76FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
77 return new ScalarizeMaskedMemIntrin();
78}
79
80// Translate a masked load intrinsic like
81// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
82// <16 x i1> %mask, <16 x i32> %passthru)
83// to a chain of basic blocks, with loading element one-by-one if
84// the appropriate mask bit is set
85//
86// %1 = bitcast i8* %addr to i32*
87// %2 = extractelement <16 x i1> %mask, i32 0
88// %3 = icmp eq i1 %2, true
89// br i1 %3, label %cond.load, label %else
90//
91// cond.load: ; preds = %0
92// %4 = getelementptr i32* %1, i32 0
93// %5 = load i32* %4
94// %6 = insertelement <16 x i32> undef, i32 %5, i32 0
95// br label %else
96//
97// else: ; preds = %0, %cond.load
98// %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
99// %7 = extractelement <16 x i1> %mask, i32 1
100// %8 = icmp eq i1 %7, true
101// br i1 %8, label %cond.load1, label %else2
102//
103// cond.load1: ; preds = %else
104// %9 = getelementptr i32* %1, i32 1
105// %10 = load i32* %9
106// %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
107// br label %else2
108//
109// else2: ; preds = %else, %cond.load1
110// %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
111// %12 = extractelement <16 x i1> %mask, i32 2
112// %13 = icmp eq i1 %12, true
113// br i1 %13, label %cond.load4, label %else5
114//
115static void scalarizeMaskedLoad(CallInst *CI) {
116 Value *Ptr = CI->getArgOperand(0);
117 Value *Alignment = CI->getArgOperand(1);
118 Value *Mask = CI->getArgOperand(2);
119 Value *Src0 = CI->getArgOperand(3);
120
121 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
122 VectorType *VecType = dyn_cast<VectorType>(CI->getType());
123 assert(VecType && "Unexpected return type of masked load intrinsic");
124
125 Type *EltTy = CI->getType()->getVectorElementType();
126
127 IRBuilder<> Builder(CI->getContext());
128 Instruction *InsertPt = CI;
129 BasicBlock *IfBlock = CI->getParent();
130 BasicBlock *CondBlock = nullptr;
131 BasicBlock *PrevIfBlock = CI->getParent();
132
133 Builder.SetInsertPoint(InsertPt);
134 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
135
136 // Short-cut if the mask is all-true.
137 bool IsAllOnesMask =
138 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
139
140 if (IsAllOnesMask) {
141 Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
142 CI->replaceAllUsesWith(NewI);
143 CI->eraseFromParent();
144 return;
145 }
146
147 // Adjust alignment for the scalar instruction.
148 AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
149 // Bitcast %addr fron i8* to EltTy*
150 Type *NewPtrType =
151 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
152 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
153 unsigned VectorWidth = VecType->getNumElements();
154
155 Value *UndefVal = UndefValue::get(VecType);
156
157 // The result vector
158 Value *VResult = UndefVal;
159
160 if (isa<ConstantVector>(Mask)) {
161 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
162 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
163 continue;
164 Value *Gep =
165 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
166 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
167 VResult =
168 Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
169 }
170 Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
171 CI->replaceAllUsesWith(NewI);
172 CI->eraseFromParent();
173 return;
174 }
175
176 PHINode *Phi = nullptr;
177 Value *PrevPhi = UndefVal;
178
179 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000180 // Fill the "else" block, created in the previous iteration
181 //
182 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
183 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000184 // br i1 %mask_1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000185 //
186 if (Idx > 0) {
187 Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
188 Phi->addIncoming(VResult, CondBlock);
189 Phi->addIncoming(PrevPhi, PrevIfBlock);
190 PrevPhi = Phi;
191 VResult = Phi;
192 }
193
194 Value *Predicate =
195 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000196
197 // Create "cond" block
198 //
199 // %EltAddr = getelementptr i32* %1, i32 0
200 // %Elt = load i32* %EltAddr
201 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
202 //
203 CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
204 Builder.SetInsertPoint(InsertPt);
205
206 Value *Gep =
207 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
208 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
209 VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
210
211 // Create "else" block, fill it in the next iteration
212 BasicBlock *NewIfBlock =
213 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
214 Builder.SetInsertPoint(InsertPt);
215 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000216 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000217 OldBr->eraseFromParent();
218 PrevIfBlock = IfBlock;
219 IfBlock = NewIfBlock;
220 }
221
222 Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
223 Phi->addIncoming(VResult, CondBlock);
224 Phi->addIncoming(PrevPhi, PrevIfBlock);
225 Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
226 CI->replaceAllUsesWith(NewI);
227 CI->eraseFromParent();
228}
229
230// Translate a masked store intrinsic, like
231// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
232// <16 x i1> %mask)
233// to a chain of basic blocks, that stores element one-by-one if
234// the appropriate mask bit is set
235//
236// %1 = bitcast i8* %addr to i32*
237// %2 = extractelement <16 x i1> %mask, i32 0
238// %3 = icmp eq i1 %2, true
239// br i1 %3, label %cond.store, label %else
240//
241// cond.store: ; preds = %0
242// %4 = extractelement <16 x i32> %val, i32 0
243// %5 = getelementptr i32* %1, i32 0
244// store i32 %4, i32* %5
245// br label %else
246//
247// else: ; preds = %0, %cond.store
248// %6 = extractelement <16 x i1> %mask, i32 1
249// %7 = icmp eq i1 %6, true
250// br i1 %7, label %cond.store1, label %else2
251//
252// cond.store1: ; preds = %else
253// %8 = extractelement <16 x i32> %val, i32 1
254// %9 = getelementptr i32* %1, i32 1
255// store i32 %8, i32* %9
256// br label %else2
257// . . .
258static void scalarizeMaskedStore(CallInst *CI) {
259 Value *Src = CI->getArgOperand(0);
260 Value *Ptr = CI->getArgOperand(1);
261 Value *Alignment = CI->getArgOperand(2);
262 Value *Mask = CI->getArgOperand(3);
263
264 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
265 VectorType *VecType = dyn_cast<VectorType>(Src->getType());
266 assert(VecType && "Unexpected data type in masked store intrinsic");
267
268 Type *EltTy = VecType->getElementType();
269
270 IRBuilder<> Builder(CI->getContext());
271 Instruction *InsertPt = CI;
272 BasicBlock *IfBlock = CI->getParent();
273 Builder.SetInsertPoint(InsertPt);
274 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
275
276 // Short-cut if the mask is all-true.
277 bool IsAllOnesMask =
278 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
279
280 if (IsAllOnesMask) {
281 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
282 CI->eraseFromParent();
283 return;
284 }
285
286 // Adjust alignment for the scalar instruction.
287 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
288 // Bitcast %addr fron i8* to EltTy*
289 Type *NewPtrType =
290 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
291 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
292 unsigned VectorWidth = VecType->getNumElements();
293
294 if (isa<ConstantVector>(Mask)) {
295 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
296 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
297 continue;
298 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
299 Value *Gep =
300 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
301 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
302 }
303 CI->eraseFromParent();
304 return;
305 }
306
307 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000308 // Fill the "else" block, created in the previous iteration
309 //
310 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000311 // br i1 %mask_1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000312 //
313 Value *Predicate =
314 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000315
316 // Create "cond" block
317 //
318 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
319 // %EltAddr = getelementptr i32* %1, i32 0
320 // %store i32 %OneElt, i32* %EltAddr
321 //
322 BasicBlock *CondBlock =
323 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
324 Builder.SetInsertPoint(InsertPt);
325
326 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
327 Value *Gep =
328 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
329 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
330
331 // Create "else" block, fill it in the next iteration
332 BasicBlock *NewIfBlock =
333 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
334 Builder.SetInsertPoint(InsertPt);
335 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000336 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000337 OldBr->eraseFromParent();
338 IfBlock = NewIfBlock;
339 }
340 CI->eraseFromParent();
341}
342
343// Translate a masked gather intrinsic like
344// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
345// <16 x i1> %Mask, <16 x i32> %Src)
346// to a chain of basic blocks, with loading element one-by-one if
347// the appropriate mask bit is set
348//
349// % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
350// % Mask0 = extractelement <16 x i1> %Mask, i32 0
351// % ToLoad0 = icmp eq i1 % Mask0, true
352// br i1 % ToLoad0, label %cond.load, label %else
353//
354// cond.load:
355// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
356// % Load0 = load i32, i32* % Ptr0, align 4
357// % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
358// br label %else
359//
360// else:
361// %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
362// % Mask1 = extractelement <16 x i1> %Mask, i32 1
363// % ToLoad1 = icmp eq i1 % Mask1, true
364// br i1 % ToLoad1, label %cond.load1, label %else2
365//
366// cond.load1:
367// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
368// % Load1 = load i32, i32* % Ptr1, align 4
369// % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
370// br label %else2
371// . . .
372// % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
373// ret <16 x i32> %Result
374static void scalarizeMaskedGather(CallInst *CI) {
375 Value *Ptrs = CI->getArgOperand(0);
376 Value *Alignment = CI->getArgOperand(1);
377 Value *Mask = CI->getArgOperand(2);
378 Value *Src0 = CI->getArgOperand(3);
379
380 VectorType *VecType = dyn_cast<VectorType>(CI->getType());
381
382 assert(VecType && "Unexpected return type of masked load intrinsic");
383
384 IRBuilder<> Builder(CI->getContext());
385 Instruction *InsertPt = CI;
386 BasicBlock *IfBlock = CI->getParent();
387 BasicBlock *CondBlock = nullptr;
388 BasicBlock *PrevIfBlock = CI->getParent();
389 Builder.SetInsertPoint(InsertPt);
390 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
391
392 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
393
394 Value *UndefVal = UndefValue::get(VecType);
395
396 // The result vector
397 Value *VResult = UndefVal;
398 unsigned VectorWidth = VecType->getNumElements();
399
400 // Shorten the way if the mask is a vector of constants.
401 bool IsConstMask = isa<ConstantVector>(Mask);
402
403 if (IsConstMask) {
404 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
405 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
406 continue;
407 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
408 "Ptr" + Twine(Idx));
409 LoadInst *Load =
410 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
411 VResult = Builder.CreateInsertElement(
412 VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
413 }
414 Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
415 CI->replaceAllUsesWith(NewI);
416 CI->eraseFromParent();
417 return;
418 }
419
420 PHINode *Phi = nullptr;
421 Value *PrevPhi = UndefVal;
422
423 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000424 // Fill the "else" block, created in the previous iteration
425 //
426 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
Craig Topper04236812018-09-27 18:01:48 +0000427 // br i1 %Mask1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000428 //
429 if (Idx > 0) {
430 Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
431 Phi->addIncoming(VResult, CondBlock);
432 Phi->addIncoming(PrevPhi, PrevIfBlock);
433 PrevPhi = Phi;
434 VResult = Phi;
435 }
436
437 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
438 "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000439
440 // Create "cond" block
441 //
442 // %EltAddr = getelementptr i32* %1, i32 0
443 // %Elt = load i32* %EltAddr
444 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
445 //
446 CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
447 Builder.SetInsertPoint(InsertPt);
448
449 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
450 "Ptr" + Twine(Idx));
451 LoadInst *Load =
452 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
453 VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
454 "Res" + Twine(Idx));
455
456 // Create "else" block, fill it in the next iteration
457 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
458 Builder.SetInsertPoint(InsertPt);
459 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000460 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000461 OldBr->eraseFromParent();
462 PrevIfBlock = IfBlock;
463 IfBlock = NewIfBlock;
464 }
465
466 Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
467 Phi->addIncoming(VResult, CondBlock);
468 Phi->addIncoming(PrevPhi, PrevIfBlock);
469 Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
470 CI->replaceAllUsesWith(NewI);
471 CI->eraseFromParent();
472}
473
474// Translate a masked scatter intrinsic, like
475// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
476// <16 x i1> %Mask)
477// to a chain of basic blocks, that stores element one-by-one if
478// the appropriate mask bit is set.
479//
480// % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
481// % Mask0 = extractelement <16 x i1> % Mask, i32 0
482// % ToStore0 = icmp eq i1 % Mask0, true
483// br i1 %ToStore0, label %cond.store, label %else
484//
485// cond.store:
486// % Elt0 = extractelement <16 x i32> %Src, i32 0
487// % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
488// store i32 %Elt0, i32* % Ptr0, align 4
489// br label %else
490//
491// else:
492// % Mask1 = extractelement <16 x i1> % Mask, i32 1
493// % ToStore1 = icmp eq i1 % Mask1, true
494// br i1 % ToStore1, label %cond.store1, label %else2
495//
496// cond.store1:
497// % Elt1 = extractelement <16 x i32> %Src, i32 1
498// % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
499// store i32 % Elt1, i32* % Ptr1, align 4
500// br label %else2
501// . . .
502static void scalarizeMaskedScatter(CallInst *CI) {
503 Value *Src = CI->getArgOperand(0);
504 Value *Ptrs = CI->getArgOperand(1);
505 Value *Alignment = CI->getArgOperand(2);
506 Value *Mask = CI->getArgOperand(3);
507
508 assert(isa<VectorType>(Src->getType()) &&
509 "Unexpected data type in masked scatter intrinsic");
510 assert(isa<VectorType>(Ptrs->getType()) &&
511 isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
512 "Vector of pointers is expected in masked scatter intrinsic");
513
514 IRBuilder<> Builder(CI->getContext());
515 Instruction *InsertPt = CI;
516 BasicBlock *IfBlock = CI->getParent();
517 Builder.SetInsertPoint(InsertPt);
518 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
519
520 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
521 unsigned VectorWidth = Src->getType()->getVectorNumElements();
522
523 // Shorten the way if the mask is a vector of constants.
524 bool IsConstMask = isa<ConstantVector>(Mask);
525
526 if (IsConstMask) {
527 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
528 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
529 continue;
530 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
531 "Elt" + Twine(Idx));
532 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
533 "Ptr" + Twine(Idx));
534 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
535 }
536 CI->eraseFromParent();
537 return;
538 }
539 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
540 // Fill the "else" block, created in the previous iteration
541 //
Craig Topper04236812018-09-27 18:01:48 +0000542 // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
543 // br i1 %Mask1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000544 //
545 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
546 "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000547
548 // Create "cond" block
549 //
550 // % Elt1 = extractelement <16 x i32> %Src, i32 1
551 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
552 // %store i32 % Elt1, i32* % Ptr1
553 //
554 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
555 Builder.SetInsertPoint(InsertPt);
556
557 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
558 "Elt" + Twine(Idx));
559 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
560 "Ptr" + Twine(Idx));
561 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
562
563 // Create "else" block, fill it in the next iteration
564 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
565 Builder.SetInsertPoint(InsertPt);
566 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000567 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000568 OldBr->eraseFromParent();
569 IfBlock = NewIfBlock;
570 }
571 CI->eraseFromParent();
572}
573
574bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
Ayman Musac5490e52017-05-15 11:30:54 +0000575 bool EverMadeChange = false;
576
577 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
578
579 bool MadeChange = true;
580 while (MadeChange) {
581 MadeChange = false;
582 for (Function::iterator I = F.begin(); I != F.end();) {
583 BasicBlock *BB = &*I++;
584 bool ModifiedDTOnIteration = false;
585 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
586
587 // Restart BB iteration if the dominator tree of the Function was changed
588 if (ModifiedDTOnIteration)
589 break;
590 }
591
592 EverMadeChange |= MadeChange;
593 }
594
595 return EverMadeChange;
596}
597
598bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
599 bool MadeChange = false;
600
601 BasicBlock::iterator CurInstIterator = BB.begin();
602 while (CurInstIterator != BB.end()) {
603 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
604 MadeChange |= optimizeCallInst(CI, ModifiedDT);
605 if (ModifiedDT)
606 return true;
607 }
608
609 return MadeChange;
610}
611
612bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
613 bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000614 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
615 if (II) {
616 switch (II->getIntrinsicID()) {
617 default:
618 break;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000619 case Intrinsic::masked_load:
Ayman Musac5490e52017-05-15 11:30:54 +0000620 // Scalarize unsupported vector masked load
621 if (!TTI->isLegalMaskedLoad(CI->getType())) {
622 scalarizeMaskedLoad(CI);
623 ModifiedDT = true;
624 return true;
625 }
626 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000627 case Intrinsic::masked_store:
Ayman Musac5490e52017-05-15 11:30:54 +0000628 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
629 scalarizeMaskedStore(CI);
630 ModifiedDT = true;
631 return true;
632 }
633 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000634 case Intrinsic::masked_gather:
Ayman Musac5490e52017-05-15 11:30:54 +0000635 if (!TTI->isLegalMaskedGather(CI->getType())) {
636 scalarizeMaskedGather(CI);
637 ModifiedDT = true;
638 return true;
639 }
640 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000641 case Intrinsic::masked_scatter:
Ayman Musac5490e52017-05-15 11:30:54 +0000642 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
643 scalarizeMaskedScatter(CI);
644 ModifiedDT = true;
645 return true;
646 }
647 return false;
648 }
Ayman Musac5490e52017-05-15 11:30:54 +0000649 }
650
651 return false;
652}