blob: b039cdd01d476efeb4861bcdb14cd7e0d9dfcadb [file] [log] [blame]
//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
Ayman Musac5490e52017-05-15 11:30:54 +00003//
4// The LLVM Compiler Infrastructure
5//
6// This file is distributed under the University of Illinois Open Source
7// License. See LICENSE.TXT for details.
8//
9//===----------------------------------------------------------------------===//
10//
11// This pass replaces masked memory intrinsics - when unsupported by the target
12// - with a chain of basic blocks, that deal with the elements one-by-one if the
13// appropriate mask bit is set.
14//
15//===----------------------------------------------------------------------===//
16
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000017#include "llvm/ADT/Twine.h"
Ayman Musac5490e52017-05-15 11:30:54 +000018#include "llvm/Analysis/TargetTransformInfo.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000019#include "llvm/CodeGen/TargetSubtargetInfo.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000020#include "llvm/IR/BasicBlock.h"
21#include "llvm/IR/Constant.h"
22#include "llvm/IR/Constants.h"
23#include "llvm/IR/DerivedTypes.h"
24#include "llvm/IR/Function.h"
Ayman Musac5490e52017-05-15 11:30:54 +000025#include "llvm/IR/IRBuilder.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000026#include "llvm/IR/InstrTypes.h"
27#include "llvm/IR/Instruction.h"
28#include "llvm/IR/Instructions.h"
Reid Kleckner0e8c4bb2017-09-07 23:27:44 +000029#include "llvm/IR/IntrinsicInst.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000030#include "llvm/IR/Intrinsics.h"
31#include "llvm/IR/Type.h"
32#include "llvm/IR/Value.h"
33#include "llvm/Pass.h"
34#include "llvm/Support/Casting.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000035#include <algorithm>
36#include <cassert>
Ayman Musac5490e52017-05-15 11:30:54 +000037
38using namespace llvm;
39
40#define DEBUG_TYPE "scalarize-masked-mem-intrin"
41
42namespace {
43
44class ScalarizeMaskedMemIntrin : public FunctionPass {
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000045 const TargetTransformInfo *TTI = nullptr;
Ayman Musac5490e52017-05-15 11:30:54 +000046
47public:
48 static char ID; // Pass identification, replacement for typeid
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000049
50 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
Ayman Musac5490e52017-05-15 11:30:54 +000051 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
52 }
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000053
Ayman Musac5490e52017-05-15 11:30:54 +000054 bool runOnFunction(Function &F) override;
55
56 StringRef getPassName() const override {
57 return "Scalarize Masked Memory Intrinsics";
58 }
59
60 void getAnalysisUsage(AnalysisUsage &AU) const override {
61 AU.addRequired<TargetTransformInfoWrapperPass>();
62 }
63
64private:
65 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
66 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
67};
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000068
69} // end anonymous namespace
Ayman Musac5490e52017-05-15 11:30:54 +000070
71char ScalarizeMaskedMemIntrin::ID = 0;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000072
Matthias Braun1527baa2017-05-25 21:26:32 +000073INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74 "Scalarize unsupported masked memory intrinsics", false, false)
Ayman Musac5490e52017-05-15 11:30:54 +000075
76FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
77 return new ScalarizeMaskedMemIntrin();
78}
79
80// Translate a masked load intrinsic like
81// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
82// <16 x i1> %mask, <16 x i32> %passthru)
83// to a chain of basic blocks, with loading element one-by-one if
84// the appropriate mask bit is set
85//
86// %1 = bitcast i8* %addr to i32*
87// %2 = extractelement <16 x i1> %mask, i32 0
Craig Topper49dad8b2018-09-27 21:28:39 +000088// br i1 %2, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +000089//
90// cond.load: ; preds = %0
Craig Topper49dad8b2018-09-27 21:28:39 +000091// %3 = getelementptr i32* %1, i32 0
92// %4 = load i32* %3
Craig Topper7d234d62018-09-27 21:28:52 +000093// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
Ayman Musac5490e52017-05-15 11:30:54 +000094// br label %else
95//
96// else: ; preds = %0, %cond.load
Craig Topper49dad8b2018-09-27 21:28:39 +000097// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
98// %6 = extractelement <16 x i1> %mask, i32 1
99// br i1 %6, label %cond.load1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000100//
101// cond.load1: ; preds = %else
Craig Topper49dad8b2018-09-27 21:28:39 +0000102// %7 = getelementptr i32* %1, i32 1
103// %8 = load i32* %7
104// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
Ayman Musac5490e52017-05-15 11:30:54 +0000105// br label %else2
106//
107// else2: ; preds = %else, %cond.load1
Craig Topper49dad8b2018-09-27 21:28:39 +0000108// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
109// %10 = extractelement <16 x i1> %mask, i32 2
110// br i1 %10, label %cond.load4, label %else5
Ayman Musac5490e52017-05-15 11:30:54 +0000111//
112static void scalarizeMaskedLoad(CallInst *CI) {
113 Value *Ptr = CI->getArgOperand(0);
114 Value *Alignment = CI->getArgOperand(1);
115 Value *Mask = CI->getArgOperand(2);
116 Value *Src0 = CI->getArgOperand(3);
117
118 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
119 VectorType *VecType = dyn_cast<VectorType>(CI->getType());
120 assert(VecType && "Unexpected return type of masked load intrinsic");
121
122 Type *EltTy = CI->getType()->getVectorElementType();
123
124 IRBuilder<> Builder(CI->getContext());
125 Instruction *InsertPt = CI;
126 BasicBlock *IfBlock = CI->getParent();
127 BasicBlock *CondBlock = nullptr;
128 BasicBlock *PrevIfBlock = CI->getParent();
129
130 Builder.SetInsertPoint(InsertPt);
131 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
132
133 // Short-cut if the mask is all-true.
Craig Topperdfe460d2018-09-27 21:28:41 +0000134 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Ayman Musac5490e52017-05-15 11:30:54 +0000135 Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
136 CI->replaceAllUsesWith(NewI);
137 CI->eraseFromParent();
138 return;
139 }
140
141 // Adjust alignment for the scalar instruction.
142 AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
143 // Bitcast %addr fron i8* to EltTy*
144 Type *NewPtrType =
145 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
146 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
147 unsigned VectorWidth = VecType->getNumElements();
148
Ayman Musac5490e52017-05-15 11:30:54 +0000149 // The result vector
Craig Topper7d234d62018-09-27 21:28:52 +0000150 Value *VResult = Src0;
Ayman Musac5490e52017-05-15 11:30:54 +0000151
Craig Topperdfc0f282018-09-27 21:28:46 +0000152 if (isa<Constant>(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000153 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000154 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000155 continue;
156 Value *Gep =
157 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
158 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
159 VResult =
160 Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
161 }
Craig Topper7d234d62018-09-27 21:28:52 +0000162 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000163 CI->eraseFromParent();
164 return;
165 }
166
Ayman Musac5490e52017-05-15 11:30:54 +0000167 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000168 // Fill the "else" block, created in the previous iteration
169 //
170 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
171 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000172 // br i1 %mask_1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000173 //
Ayman Musac5490e52017-05-15 11:30:54 +0000174
175 Value *Predicate =
176 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000177
178 // Create "cond" block
179 //
180 // %EltAddr = getelementptr i32* %1, i32 0
181 // %Elt = load i32* %EltAddr
182 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
183 //
184 CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
185 Builder.SetInsertPoint(InsertPt);
186
187 Value *Gep =
188 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
189 LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
Craig Topper7d234d62018-09-27 21:28:52 +0000190 Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
191 Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000192
193 // Create "else" block, fill it in the next iteration
194 BasicBlock *NewIfBlock =
195 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
196 Builder.SetInsertPoint(InsertPt);
197 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000198 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000199 OldBr->eraseFromParent();
200 PrevIfBlock = IfBlock;
201 IfBlock = NewIfBlock;
Craig Topper7d234d62018-09-27 21:28:52 +0000202
203 // Create the phi to join the new and previous value.
204 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
205 Phi->addIncoming(NewVResult, CondBlock);
206 Phi->addIncoming(VResult, PrevIfBlock);
207 VResult = Phi;
Ayman Musac5490e52017-05-15 11:30:54 +0000208 }
209
Craig Topper7d234d62018-09-27 21:28:52 +0000210 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000211 CI->eraseFromParent();
212}
213
214// Translate a masked store intrinsic, like
215// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
216// <16 x i1> %mask)
217// to a chain of basic blocks, that stores element one-by-one if
218// the appropriate mask bit is set
219//
220// %1 = bitcast i8* %addr to i32*
221// %2 = extractelement <16 x i1> %mask, i32 0
Craig Topper49dad8b2018-09-27 21:28:39 +0000222// br i1 %2, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000223//
224// cond.store: ; preds = %0
Craig Topper49dad8b2018-09-27 21:28:39 +0000225// %3 = extractelement <16 x i32> %val, i32 0
226// %4 = getelementptr i32* %1, i32 0
227// store i32 %3, i32* %4
Ayman Musac5490e52017-05-15 11:30:54 +0000228// br label %else
229//
230// else: ; preds = %0, %cond.store
Craig Topper49dad8b2018-09-27 21:28:39 +0000231// %5 = extractelement <16 x i1> %mask, i32 1
232// br i1 %5, label %cond.store1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000233//
234// cond.store1: ; preds = %else
Craig Topper49dad8b2018-09-27 21:28:39 +0000235// %6 = extractelement <16 x i32> %val, i32 1
236// %7 = getelementptr i32* %1, i32 1
237// store i32 %6, i32* %7
Ayman Musac5490e52017-05-15 11:30:54 +0000238// br label %else2
239// . . .
240static void scalarizeMaskedStore(CallInst *CI) {
241 Value *Src = CI->getArgOperand(0);
242 Value *Ptr = CI->getArgOperand(1);
243 Value *Alignment = CI->getArgOperand(2);
244 Value *Mask = CI->getArgOperand(3);
245
246 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
247 VectorType *VecType = dyn_cast<VectorType>(Src->getType());
248 assert(VecType && "Unexpected data type in masked store intrinsic");
249
250 Type *EltTy = VecType->getElementType();
251
252 IRBuilder<> Builder(CI->getContext());
253 Instruction *InsertPt = CI;
254 BasicBlock *IfBlock = CI->getParent();
255 Builder.SetInsertPoint(InsertPt);
256 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
257
258 // Short-cut if the mask is all-true.
Craig Topperdfe460d2018-09-27 21:28:41 +0000259 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Ayman Musac5490e52017-05-15 11:30:54 +0000260 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
261 CI->eraseFromParent();
262 return;
263 }
264
265 // Adjust alignment for the scalar instruction.
266 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
267 // Bitcast %addr fron i8* to EltTy*
268 Type *NewPtrType =
269 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
270 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
271 unsigned VectorWidth = VecType->getNumElements();
272
Craig Topperdfc0f282018-09-27 21:28:46 +0000273 if (isa<Constant>(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000274 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000275 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000276 continue;
277 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
278 Value *Gep =
279 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
280 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
281 }
282 CI->eraseFromParent();
283 return;
284 }
285
286 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000287 // Fill the "else" block, created in the previous iteration
288 //
289 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000290 // br i1 %mask_1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000291 //
292 Value *Predicate =
293 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000294
295 // Create "cond" block
296 //
297 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
298 // %EltAddr = getelementptr i32* %1, i32 0
299 // %store i32 %OneElt, i32* %EltAddr
300 //
301 BasicBlock *CondBlock =
302 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
303 Builder.SetInsertPoint(InsertPt);
304
305 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
306 Value *Gep =
307 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
308 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
309
310 // Create "else" block, fill it in the next iteration
311 BasicBlock *NewIfBlock =
312 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
313 Builder.SetInsertPoint(InsertPt);
314 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000315 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000316 OldBr->eraseFromParent();
317 IfBlock = NewIfBlock;
318 }
319 CI->eraseFromParent();
320}
321
322// Translate a masked gather intrinsic like
323// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
324// <16 x i1> %Mask, <16 x i32> %Src)
325// to a chain of basic blocks, with loading element one-by-one if
326// the appropriate mask bit is set
327//
Craig Topper49dad8b2018-09-27 21:28:39 +0000328// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
329// %Mask0 = extractelement <16 x i1> %Mask, i32 0
330// br i1 %Mask0, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000331//
332// cond.load:
Craig Topper49dad8b2018-09-27 21:28:39 +0000333// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
334// %Load0 = load i32, i32* %Ptr0, align 4
335// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
Ayman Musac5490e52017-05-15 11:30:54 +0000336// br label %else
337//
338// else:
Craig Topper49dad8b2018-09-27 21:28:39 +0000339// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
340// %Mask1 = extractelement <16 x i1> %Mask, i32 1
341// br i1 %Mask1, label %cond.load1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000342//
343// cond.load1:
Craig Topper49dad8b2018-09-27 21:28:39 +0000344// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
345// %Load1 = load i32, i32* %Ptr1, align 4
346// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
Ayman Musac5490e52017-05-15 11:30:54 +0000347// br label %else2
348// . . .
Craig Topper49dad8b2018-09-27 21:28:39 +0000349// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
Ayman Musac5490e52017-05-15 11:30:54 +0000350// ret <16 x i32> %Result
351static void scalarizeMaskedGather(CallInst *CI) {
352 Value *Ptrs = CI->getArgOperand(0);
353 Value *Alignment = CI->getArgOperand(1);
354 Value *Mask = CI->getArgOperand(2);
355 Value *Src0 = CI->getArgOperand(3);
356
357 VectorType *VecType = dyn_cast<VectorType>(CI->getType());
358
359 assert(VecType && "Unexpected return type of masked load intrinsic");
360
361 IRBuilder<> Builder(CI->getContext());
362 Instruction *InsertPt = CI;
363 BasicBlock *IfBlock = CI->getParent();
364 BasicBlock *CondBlock = nullptr;
365 BasicBlock *PrevIfBlock = CI->getParent();
366 Builder.SetInsertPoint(InsertPt);
367 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
368
369 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
370
371 Value *UndefVal = UndefValue::get(VecType);
372
373 // The result vector
374 Value *VResult = UndefVal;
375 unsigned VectorWidth = VecType->getNumElements();
376
377 // Shorten the way if the mask is a vector of constants.
Craig Topperdfc0f282018-09-27 21:28:46 +0000378 if (isa<Constant>(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000379 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000380 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000381 continue;
382 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
383 "Ptr" + Twine(Idx));
384 LoadInst *Load =
385 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
386 VResult = Builder.CreateInsertElement(
387 VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
388 }
389 Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
390 CI->replaceAllUsesWith(NewI);
391 CI->eraseFromParent();
392 return;
393 }
394
395 PHINode *Phi = nullptr;
396 Value *PrevPhi = UndefVal;
397
398 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000399 // Fill the "else" block, created in the previous iteration
400 //
401 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
Craig Topper04236812018-09-27 18:01:48 +0000402 // br i1 %Mask1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000403 //
404 if (Idx > 0) {
405 Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
406 Phi->addIncoming(VResult, CondBlock);
407 Phi->addIncoming(PrevPhi, PrevIfBlock);
408 PrevPhi = Phi;
409 VResult = Phi;
410 }
411
412 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
413 "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000414
415 // Create "cond" block
416 //
417 // %EltAddr = getelementptr i32* %1, i32 0
418 // %Elt = load i32* %EltAddr
419 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
420 //
421 CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
422 Builder.SetInsertPoint(InsertPt);
423
424 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
425 "Ptr" + Twine(Idx));
426 LoadInst *Load =
427 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
428 VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
429 "Res" + Twine(Idx));
430
431 // Create "else" block, fill it in the next iteration
432 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
433 Builder.SetInsertPoint(InsertPt);
434 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000435 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000436 OldBr->eraseFromParent();
437 PrevIfBlock = IfBlock;
438 IfBlock = NewIfBlock;
439 }
440
441 Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
442 Phi->addIncoming(VResult, CondBlock);
443 Phi->addIncoming(PrevPhi, PrevIfBlock);
444 Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
445 CI->replaceAllUsesWith(NewI);
446 CI->eraseFromParent();
447}
448
449// Translate a masked scatter intrinsic, like
450// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
451// <16 x i1> %Mask)
452// to a chain of basic blocks, that stores element one-by-one if
453// the appropriate mask bit is set.
454//
Craig Topper49dad8b2018-09-27 21:28:39 +0000455// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
456// %Mask0 = extractelement <16 x i1> %Mask, i32 0
457// br i1 %Mask0, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000458//
459// cond.store:
Craig Topper49dad8b2018-09-27 21:28:39 +0000460// %Elt0 = extractelement <16 x i32> %Src, i32 0
461// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
462// store i32 %Elt0, i32* %Ptr0, align 4
Ayman Musac5490e52017-05-15 11:30:54 +0000463// br label %else
464//
465// else:
Craig Topper49dad8b2018-09-27 21:28:39 +0000466// %Mask1 = extractelement <16 x i1> %Mask, i32 1
467// br i1 %Mask1, label %cond.store1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000468//
469// cond.store1:
Craig Topper49dad8b2018-09-27 21:28:39 +0000470// %Elt1 = extractelement <16 x i32> %Src, i32 1
471// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
472// store i32 %Elt1, i32* %Ptr1, align 4
Ayman Musac5490e52017-05-15 11:30:54 +0000473// br label %else2
474// . . .
475static void scalarizeMaskedScatter(CallInst *CI) {
476 Value *Src = CI->getArgOperand(0);
477 Value *Ptrs = CI->getArgOperand(1);
478 Value *Alignment = CI->getArgOperand(2);
479 Value *Mask = CI->getArgOperand(3);
480
481 assert(isa<VectorType>(Src->getType()) &&
482 "Unexpected data type in masked scatter intrinsic");
483 assert(isa<VectorType>(Ptrs->getType()) &&
484 isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
485 "Vector of pointers is expected in masked scatter intrinsic");
486
487 IRBuilder<> Builder(CI->getContext());
488 Instruction *InsertPt = CI;
489 BasicBlock *IfBlock = CI->getParent();
490 Builder.SetInsertPoint(InsertPt);
491 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
492
493 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
494 unsigned VectorWidth = Src->getType()->getVectorNumElements();
495
496 // Shorten the way if the mask is a vector of constants.
Craig Topperdfc0f282018-09-27 21:28:46 +0000497 if (isa<Constant>(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000498 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000499 if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000500 continue;
501 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
502 "Elt" + Twine(Idx));
503 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
504 "Ptr" + Twine(Idx));
505 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
506 }
507 CI->eraseFromParent();
508 return;
509 }
Craig Topperdfe460d2018-09-27 21:28:41 +0000510
Ayman Musac5490e52017-05-15 11:30:54 +0000511 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
512 // Fill the "else" block, created in the previous iteration
513 //
Craig Topper04236812018-09-27 18:01:48 +0000514 // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
515 // br i1 %Mask1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000516 //
517 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
518 "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000519
520 // Create "cond" block
521 //
Craig Topper49dad8b2018-09-27 21:28:39 +0000522 // %Elt1 = extractelement <16 x i32> %Src, i32 1
523 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
524 // %store i32 %Elt1, i32* %Ptr1
Ayman Musac5490e52017-05-15 11:30:54 +0000525 //
526 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
527 Builder.SetInsertPoint(InsertPt);
528
529 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
530 "Elt" + Twine(Idx));
531 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
532 "Ptr" + Twine(Idx));
533 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
534
535 // Create "else" block, fill it in the next iteration
536 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
537 Builder.SetInsertPoint(InsertPt);
538 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000539 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000540 OldBr->eraseFromParent();
541 IfBlock = NewIfBlock;
542 }
543 CI->eraseFromParent();
544}
545
546bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
Ayman Musac5490e52017-05-15 11:30:54 +0000547 bool EverMadeChange = false;
548
549 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
550
551 bool MadeChange = true;
552 while (MadeChange) {
553 MadeChange = false;
554 for (Function::iterator I = F.begin(); I != F.end();) {
555 BasicBlock *BB = &*I++;
556 bool ModifiedDTOnIteration = false;
557 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
558
559 // Restart BB iteration if the dominator tree of the Function was changed
560 if (ModifiedDTOnIteration)
561 break;
562 }
563
564 EverMadeChange |= MadeChange;
565 }
566
567 return EverMadeChange;
568}
569
570bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
571 bool MadeChange = false;
572
573 BasicBlock::iterator CurInstIterator = BB.begin();
574 while (CurInstIterator != BB.end()) {
575 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
576 MadeChange |= optimizeCallInst(CI, ModifiedDT);
577 if (ModifiedDT)
578 return true;
579 }
580
581 return MadeChange;
582}
583
584bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
585 bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000586 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
587 if (II) {
588 switch (II->getIntrinsicID()) {
589 default:
590 break;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000591 case Intrinsic::masked_load:
Ayman Musac5490e52017-05-15 11:30:54 +0000592 // Scalarize unsupported vector masked load
593 if (!TTI->isLegalMaskedLoad(CI->getType())) {
594 scalarizeMaskedLoad(CI);
595 ModifiedDT = true;
596 return true;
597 }
598 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000599 case Intrinsic::masked_store:
Ayman Musac5490e52017-05-15 11:30:54 +0000600 if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
601 scalarizeMaskedStore(CI);
602 ModifiedDT = true;
603 return true;
604 }
605 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000606 case Intrinsic::masked_gather:
Ayman Musac5490e52017-05-15 11:30:54 +0000607 if (!TTI->isLegalMaskedGather(CI->getType())) {
608 scalarizeMaskedGather(CI);
609 ModifiedDT = true;
610 return true;
611 }
612 return false;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000613 case Intrinsic::masked_scatter:
Ayman Musac5490e52017-05-15 11:30:54 +0000614 if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
615 scalarizeMaskedScatter(CI);
616 ModifiedDT = true;
617 return true;
618 }
619 return false;
620 }
Ayman Musac5490e52017-05-15 11:30:54 +0000621 }
622
623 return false;
624}