blob: e2ee9f28f3b5592491b1b2c459304803cb12eede [file] [log] [blame]
Eugene Zelenkofa57bd02017-09-27 23:26:01 +00001//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
// intrinsics
Ayman Musac5490e52017-05-15 11:30:54 +00003//
Chandler Carruth2946cd72019-01-19 08:50:56 +00004// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5// See https://llvm.org/LICENSE.txt for license information.
6// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Ayman Musac5490e52017-05-15 11:30:54 +00007//
8//===----------------------------------------------------------------------===//
9//
10// This pass replaces masked memory intrinsics - when unsupported by the target
11// - with a chain of basic blocks, that deal with the elements one-by-one if the
12// appropriate mask bit is set.
13//
14//===----------------------------------------------------------------------===//
15
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000016#include "llvm/ADT/Twine.h"
Ayman Musac5490e52017-05-15 11:30:54 +000017#include "llvm/Analysis/TargetTransformInfo.h"
David Blaikieb3bde2e2017-11-17 01:07:10 +000018#include "llvm/CodeGen/TargetSubtargetInfo.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000019#include "llvm/IR/BasicBlock.h"
20#include "llvm/IR/Constant.h"
21#include "llvm/IR/Constants.h"
22#include "llvm/IR/DerivedTypes.h"
23#include "llvm/IR/Function.h"
Ayman Musac5490e52017-05-15 11:30:54 +000024#include "llvm/IR/IRBuilder.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000025#include "llvm/IR/InstrTypes.h"
26#include "llvm/IR/Instruction.h"
27#include "llvm/IR/Instructions.h"
Reid Kleckner0e8c4bb2017-09-07 23:27:44 +000028#include "llvm/IR/IntrinsicInst.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000029#include "llvm/IR/Intrinsics.h"
30#include "llvm/IR/Type.h"
31#include "llvm/IR/Value.h"
32#include "llvm/Pass.h"
33#include "llvm/Support/Casting.h"
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000034#include <algorithm>
35#include <cassert>
Ayman Musac5490e52017-05-15 11:30:54 +000036
37using namespace llvm;
38
39#define DEBUG_TYPE "scalarize-masked-mem-intrin"
40
41namespace {
42
43class ScalarizeMaskedMemIntrin : public FunctionPass {
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000044 const TargetTransformInfo *TTI = nullptr;
Ayman Musac5490e52017-05-15 11:30:54 +000045
46public:
47 static char ID; // Pass identification, replacement for typeid
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000048
49 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
Ayman Musac5490e52017-05-15 11:30:54 +000050 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
51 }
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000052
Ayman Musac5490e52017-05-15 11:30:54 +000053 bool runOnFunction(Function &F) override;
54
55 StringRef getPassName() const override {
56 return "Scalarize Masked Memory Intrinsics";
57 }
58
59 void getAnalysisUsage(AnalysisUsage &AU) const override {
60 AU.addRequired<TargetTransformInfoWrapperPass>();
61 }
62
63private:
64 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
65 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
66};
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000067
68} // end anonymous namespace
Ayman Musac5490e52017-05-15 11:30:54 +000069
70char ScalarizeMaskedMemIntrin::ID = 0;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +000071
Matthias Braun1527baa2017-05-25 21:26:32 +000072INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
73 "Scalarize unsupported masked memory intrinsics", false, false)
Ayman Musac5490e52017-05-15 11:30:54 +000074
75FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
76 return new ScalarizeMaskedMemIntrin();
77}
78
Craig Topper8b4f0e12018-09-27 22:31:42 +000079static bool isConstantIntVector(Value *Mask) {
80 Constant *C = dyn_cast<Constant>(Mask);
81 if (!C)
82 return false;
83
84 unsigned NumElts = Mask->getType()->getVectorNumElements();
85 for (unsigned i = 0; i != NumElts; ++i) {
86 Constant *CElt = C->getAggregateElement(i);
87 if (!CElt || !isa<ConstantInt>(CElt))
88 return false;
89 }
90
91 return true;
92}
93
Ayman Musac5490e52017-05-15 11:30:54 +000094// Translate a masked load intrinsic like
95// <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
96// <16 x i1> %mask, <16 x i32> %passthru)
97// to a chain of basic blocks, with loading element one-by-one if
98// the appropriate mask bit is set
99//
100// %1 = bitcast i8* %addr to i32*
101// %2 = extractelement <16 x i1> %mask, i32 0
Craig Topper49dad8b2018-09-27 21:28:39 +0000102// br i1 %2, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000103//
104// cond.load: ; preds = %0
Craig Topper49dad8b2018-09-27 21:28:39 +0000105// %3 = getelementptr i32* %1, i32 0
106// %4 = load i32* %3
Craig Topper7d234d62018-09-27 21:28:52 +0000107// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
Ayman Musac5490e52017-05-15 11:30:54 +0000108// br label %else
109//
110// else: ; preds = %0, %cond.load
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ %passthru, %0 ]
112// %6 = extractelement <16 x i1> %mask, i32 1
113// br i1 %6, label %cond.load1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000114//
115// cond.load1: ; preds = %else
Craig Topper49dad8b2018-09-27 21:28:39 +0000116// %7 = getelementptr i32* %1, i32 1
117// %8 = load i32* %7
118// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
Ayman Musac5490e52017-05-15 11:30:54 +0000119// br label %else2
120//
121// else2: ; preds = %else, %cond.load1
Craig Topper49dad8b2018-09-27 21:28:39 +0000122// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
123// %10 = extractelement <16 x i1> %mask, i32 2
124// br i1 %10, label %cond.load4, label %else5
Ayman Musac5490e52017-05-15 11:30:54 +0000125//
Craig Topperd84f6052019-03-08 23:03:43 +0000126static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000127 Value *Ptr = CI->getArgOperand(0);
128 Value *Alignment = CI->getArgOperand(1);
129 Value *Mask = CI->getArgOperand(2);
130 Value *Src0 = CI->getArgOperand(3);
131
132 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
Craig Topper10ec0212018-09-27 22:31:40 +0000133 VectorType *VecType = cast<VectorType>(CI->getType());
Ayman Musac5490e52017-05-15 11:30:54 +0000134
Craig Topper10ec0212018-09-27 22:31:40 +0000135 Type *EltTy = VecType->getElementType();
Ayman Musac5490e52017-05-15 11:30:54 +0000136
137 IRBuilder<> Builder(CI->getContext());
138 Instruction *InsertPt = CI;
139 BasicBlock *IfBlock = CI->getParent();
Ayman Musac5490e52017-05-15 11:30:54 +0000140
141 Builder.SetInsertPoint(InsertPt);
142 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
143
144 // Short-cut if the mask is all-true.
Craig Topperdfe460d2018-09-27 21:28:41 +0000145 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
James Y Knight14359ef2019-02-01 20:44:24 +0000146 Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
Ayman Musac5490e52017-05-15 11:30:54 +0000147 CI->replaceAllUsesWith(NewI);
148 CI->eraseFromParent();
149 return;
150 }
151
152 // Adjust alignment for the scalar instruction.
Craig Topperbb50c382018-09-28 03:35:37 +0000153 AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
Craig Topper69f8c162019-03-09 02:08:41 +0000154 // Bitcast %addr from i8* to EltTy*
Ayman Musac5490e52017-05-15 11:30:54 +0000155 Type *NewPtrType =
Craig Topper69f8c162019-03-09 02:08:41 +0000156 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
Ayman Musac5490e52017-05-15 11:30:54 +0000157 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
158 unsigned VectorWidth = VecType->getNumElements();
159
Ayman Musac5490e52017-05-15 11:30:54 +0000160 // The result vector
Craig Topper7d234d62018-09-27 21:28:52 +0000161 Value *VResult = Src0;
Ayman Musac5490e52017-05-15 11:30:54 +0000162
Craig Topper8b4f0e12018-09-27 22:31:42 +0000163 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000164 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000165 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000166 continue;
Craig Topper69f8c162019-03-09 02:08:41 +0000167 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
James Y Knight14359ef2019-02-01 20:44:24 +0000168 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
Craig Topper69f8c162019-03-09 02:08:41 +0000169 VResult = Builder.CreateInsertElement(VResult, Load, Idx);
Ayman Musac5490e52017-05-15 11:30:54 +0000170 }
Craig Topper7d234d62018-09-27 21:28:52 +0000171 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000172 CI->eraseFromParent();
173 return;
174 }
175
Ayman Musac5490e52017-05-15 11:30:54 +0000176 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000177 // Fill the "else" block, created in the previous iteration
178 //
179 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
180 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000181 // br i1 %mask_1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000182 //
Ayman Musac5490e52017-05-15 11:30:54 +0000183
Craig Topper69f8c162019-03-09 02:08:41 +0000184 Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
Ayman Musac5490e52017-05-15 11:30:54 +0000185
186 // Create "cond" block
187 //
188 // %EltAddr = getelementptr i32* %1, i32 0
189 // %Elt = load i32* %EltAddr
190 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
191 //
Craig Topper4104c002018-10-30 20:33:58 +0000192 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
193 "cond.load");
Ayman Musac5490e52017-05-15 11:30:54 +0000194 Builder.SetInsertPoint(InsertPt);
195
Craig Topper69f8c162019-03-09 02:08:41 +0000196 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
James Y Knight14359ef2019-02-01 20:44:24 +0000197 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AlignVal);
Craig Topper69f8c162019-03-09 02:08:41 +0000198 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
Ayman Musac5490e52017-05-15 11:30:54 +0000199
200 // Create "else" block, fill it in the next iteration
201 BasicBlock *NewIfBlock =
202 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
203 Builder.SetInsertPoint(InsertPt);
204 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000205 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000206 OldBr->eraseFromParent();
Craig Topper4104c002018-10-30 20:33:58 +0000207 BasicBlock *PrevIfBlock = IfBlock;
Ayman Musac5490e52017-05-15 11:30:54 +0000208 IfBlock = NewIfBlock;
Craig Topper7d234d62018-09-27 21:28:52 +0000209
210 // Create the phi to join the new and previous value.
211 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
212 Phi->addIncoming(NewVResult, CondBlock);
213 Phi->addIncoming(VResult, PrevIfBlock);
214 VResult = Phi;
Ayman Musac5490e52017-05-15 11:30:54 +0000215 }
216
Craig Topper7d234d62018-09-27 21:28:52 +0000217 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000218 CI->eraseFromParent();
Craig Topperd84f6052019-03-08 23:03:43 +0000219
220 ModifiedDT = true;
Ayman Musac5490e52017-05-15 11:30:54 +0000221}
222
223// Translate a masked store intrinsic, like
224// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
225// <16 x i1> %mask)
226// to a chain of basic blocks, that stores element one-by-one if
227// the appropriate mask bit is set
228//
229// %1 = bitcast i8* %addr to i32*
230// %2 = extractelement <16 x i1> %mask, i32 0
Craig Topper49dad8b2018-09-27 21:28:39 +0000231// br i1 %2, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000232//
233// cond.store: ; preds = %0
Craig Topper49dad8b2018-09-27 21:28:39 +0000234// %3 = extractelement <16 x i32> %val, i32 0
235// %4 = getelementptr i32* %1, i32 0
236// store i32 %3, i32* %4
Ayman Musac5490e52017-05-15 11:30:54 +0000237// br label %else
238//
239// else: ; preds = %0, %cond.store
Craig Topper49dad8b2018-09-27 21:28:39 +0000240// %5 = extractelement <16 x i1> %mask, i32 1
241// br i1 %5, label %cond.store1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000242//
243// cond.store1: ; preds = %else
Craig Topper49dad8b2018-09-27 21:28:39 +0000244// %6 = extractelement <16 x i32> %val, i32 1
245// %7 = getelementptr i32* %1, i32 1
246// store i32 %6, i32* %7
Ayman Musac5490e52017-05-15 11:30:54 +0000247// br label %else2
248// . . .
Craig Topperd84f6052019-03-08 23:03:43 +0000249static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000250 Value *Src = CI->getArgOperand(0);
251 Value *Ptr = CI->getArgOperand(1);
252 Value *Alignment = CI->getArgOperand(2);
253 Value *Mask = CI->getArgOperand(3);
254
255 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
Craig Topper10ec0212018-09-27 22:31:40 +0000256 VectorType *VecType = cast<VectorType>(Src->getType());
Ayman Musac5490e52017-05-15 11:30:54 +0000257
258 Type *EltTy = VecType->getElementType();
259
260 IRBuilder<> Builder(CI->getContext());
261 Instruction *InsertPt = CI;
262 BasicBlock *IfBlock = CI->getParent();
263 Builder.SetInsertPoint(InsertPt);
264 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
265
266 // Short-cut if the mask is all-true.
Craig Topperdfe460d2018-09-27 21:28:41 +0000267 if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
Ayman Musac5490e52017-05-15 11:30:54 +0000268 Builder.CreateAlignedStore(Src, Ptr, AlignVal);
269 CI->eraseFromParent();
270 return;
271 }
272
273 // Adjust alignment for the scalar instruction.
Craig Topperbb50c382018-09-28 03:35:37 +0000274 AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
Craig Topper69f8c162019-03-09 02:08:41 +0000275 // Bitcast %addr from i8* to EltTy*
Ayman Musac5490e52017-05-15 11:30:54 +0000276 Type *NewPtrType =
Craig Topper69f8c162019-03-09 02:08:41 +0000277 EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
Ayman Musac5490e52017-05-15 11:30:54 +0000278 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
279 unsigned VectorWidth = VecType->getNumElements();
280
Craig Topper8b4f0e12018-09-27 22:31:42 +0000281 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000282 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000283 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000284 continue;
Craig Topper69f8c162019-03-09 02:08:41 +0000285 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
286 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
Ayman Musac5490e52017-05-15 11:30:54 +0000287 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
288 }
289 CI->eraseFromParent();
290 return;
291 }
292
293 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000294 // Fill the "else" block, created in the previous iteration
295 //
296 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
Craig Topper04236812018-09-27 18:01:48 +0000297 // br i1 %mask_1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000298 //
Craig Topper69f8c162019-03-09 02:08:41 +0000299 Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
Ayman Musac5490e52017-05-15 11:30:54 +0000300
301 // Create "cond" block
302 //
303 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
304 // %EltAddr = getelementptr i32* %1, i32 0
305 // %store i32 %OneElt, i32* %EltAddr
306 //
307 BasicBlock *CondBlock =
308 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
309 Builder.SetInsertPoint(InsertPt);
310
Craig Topper69f8c162019-03-09 02:08:41 +0000311 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
312 Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
Ayman Musac5490e52017-05-15 11:30:54 +0000313 Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
314
315 // Create "else" block, fill it in the next iteration
316 BasicBlock *NewIfBlock =
317 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
318 Builder.SetInsertPoint(InsertPt);
319 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000320 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000321 OldBr->eraseFromParent();
322 IfBlock = NewIfBlock;
323 }
324 CI->eraseFromParent();
Craig Topperd84f6052019-03-08 23:03:43 +0000325
326 ModifiedDT = true;
Ayman Musac5490e52017-05-15 11:30:54 +0000327}
328
329// Translate a masked gather intrinsic like
330// <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
331// <16 x i1> %Mask, <16 x i32> %Src)
332// to a chain of basic blocks, with loading element one-by-one if
333// the appropriate mask bit is set
334//
Craig Topper49dad8b2018-09-27 21:28:39 +0000335// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
336// %Mask0 = extractelement <16 x i1> %Mask, i32 0
337// br i1 %Mask0, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000338//
339// cond.load:
Craig Topper49dad8b2018-09-27 21:28:39 +0000340// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
341// %Load0 = load i32, i32* %Ptr0, align 4
342// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
Ayman Musac5490e52017-05-15 11:30:54 +0000343// br label %else
344//
345// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [%Src, %0]
347// %Mask1 = extractelement <16 x i1> %Mask, i32 1
348// br i1 %Mask1, label %cond.load1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000349//
350// cond.load1:
Craig Topper49dad8b2018-09-27 21:28:39 +0000351// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
352// %Load1 = load i32, i32* %Ptr1, align 4
353// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
Ayman Musac5490e52017-05-15 11:30:54 +0000354// br label %else2
355// . . .
// The phi created in the last "else" block is the final result. Elements whose
// mask bit is clear still hold the corresponding element of %Src, so no
// trailing select is emitted.
Craig Topperd84f6052019-03-08 23:03:43 +0000358static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000359 Value *Ptrs = CI->getArgOperand(0);
360 Value *Alignment = CI->getArgOperand(1);
361 Value *Mask = CI->getArgOperand(2);
362 Value *Src0 = CI->getArgOperand(3);
363
Craig Topper10ec0212018-09-27 22:31:40 +0000364 VectorType *VecType = cast<VectorType>(CI->getType());
James Y Knight14359ef2019-02-01 20:44:24 +0000365 Type *EltTy = VecType->getElementType();
Ayman Musac5490e52017-05-15 11:30:54 +0000366
367 IRBuilder<> Builder(CI->getContext());
368 Instruction *InsertPt = CI;
369 BasicBlock *IfBlock = CI->getParent();
Ayman Musac5490e52017-05-15 11:30:54 +0000370 Builder.SetInsertPoint(InsertPt);
371 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
372
373 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
374
Ayman Musac5490e52017-05-15 11:30:54 +0000375 // The result vector
Craig Topper6911bfe2018-09-27 21:28:59 +0000376 Value *VResult = Src0;
Ayman Musac5490e52017-05-15 11:30:54 +0000377 unsigned VectorWidth = VecType->getNumElements();
378
379 // Shorten the way if the mask is a vector of constants.
Craig Topper8b4f0e12018-09-27 22:31:42 +0000380 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000381 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000382 if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000383 continue;
Craig Topper69f8c162019-03-09 02:08:41 +0000384 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000385 LoadInst *Load =
James Y Knight14359ef2019-02-01 20:44:24 +0000386 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
Craig Topper69f8c162019-03-09 02:08:41 +0000387 VResult =
388 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000389 }
Craig Topper6911bfe2018-09-27 21:28:59 +0000390 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000391 CI->eraseFromParent();
392 return;
393 }
394
Ayman Musac5490e52017-05-15 11:30:54 +0000395 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Ayman Musac5490e52017-05-15 11:30:54 +0000396 // Fill the "else" block, created in the previous iteration
397 //
398 // %Mask1 = extractelement <16 x i1> %Mask, i32 1
Craig Topper04236812018-09-27 18:01:48 +0000399 // br i1 %Mask1, label %cond.load, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000400 //
Ayman Musac5490e52017-05-15 11:30:54 +0000401
Craig Topper69f8c162019-03-09 02:08:41 +0000402 Value *Predicate =
403 Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000404
405 // Create "cond" block
406 //
407 // %EltAddr = getelementptr i32* %1, i32 0
408 // %Elt = load i32* %EltAddr
409 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
410 //
Craig Topper4104c002018-10-30 20:33:58 +0000411 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
Ayman Musac5490e52017-05-15 11:30:54 +0000412 Builder.SetInsertPoint(InsertPt);
413
Craig Topper69f8c162019-03-09 02:08:41 +0000414 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000415 LoadInst *Load =
James Y Knight14359ef2019-02-01 20:44:24 +0000416 Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
Craig Topper69f8c162019-03-09 02:08:41 +0000417 Value *NewVResult =
418 Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000419
420 // Create "else" block, fill it in the next iteration
421 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
422 Builder.SetInsertPoint(InsertPt);
423 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000424 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000425 OldBr->eraseFromParent();
Craig Topper4104c002018-10-30 20:33:58 +0000426 BasicBlock *PrevIfBlock = IfBlock;
Ayman Musac5490e52017-05-15 11:30:54 +0000427 IfBlock = NewIfBlock;
Craig Topper6911bfe2018-09-27 21:28:59 +0000428
429 PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
430 Phi->addIncoming(NewVResult, CondBlock);
431 Phi->addIncoming(VResult, PrevIfBlock);
432 VResult = Phi;
Ayman Musac5490e52017-05-15 11:30:54 +0000433 }
434
Craig Topper6911bfe2018-09-27 21:28:59 +0000435 CI->replaceAllUsesWith(VResult);
Ayman Musac5490e52017-05-15 11:30:54 +0000436 CI->eraseFromParent();
Craig Topperd84f6052019-03-08 23:03:43 +0000437
438 ModifiedDT = true;
Ayman Musac5490e52017-05-15 11:30:54 +0000439}
440
441// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
443// <16 x i1> %Mask)
444// to a chain of basic blocks, that stores element one-by-one if
445// the appropriate mask bit is set.
446//
Craig Topper49dad8b2018-09-27 21:28:39 +0000447// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
448// %Mask0 = extractelement <16 x i1> %Mask, i32 0
449// br i1 %Mask0, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000450//
451// cond.store:
Craig Topper49dad8b2018-09-27 21:28:39 +0000452// %Elt0 = extractelement <16 x i32> %Src, i32 0
453// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
454// store i32 %Elt0, i32* %Ptr0, align 4
Ayman Musac5490e52017-05-15 11:30:54 +0000455// br label %else
456//
457// else:
Craig Topper49dad8b2018-09-27 21:28:39 +0000458// %Mask1 = extractelement <16 x i1> %Mask, i32 1
459// br i1 %Mask1, label %cond.store1, label %else2
Ayman Musac5490e52017-05-15 11:30:54 +0000460//
461// cond.store1:
Craig Topper49dad8b2018-09-27 21:28:39 +0000462// %Elt1 = extractelement <16 x i32> %Src, i32 1
463// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
464// store i32 %Elt1, i32* %Ptr1, align 4
Ayman Musac5490e52017-05-15 11:30:54 +0000465// br label %else2
466// . . .
Craig Topperd84f6052019-03-08 23:03:43 +0000467static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000468 Value *Src = CI->getArgOperand(0);
469 Value *Ptrs = CI->getArgOperand(1);
470 Value *Alignment = CI->getArgOperand(2);
471 Value *Mask = CI->getArgOperand(3);
472
473 assert(isa<VectorType>(Src->getType()) &&
474 "Unexpected data type in masked scatter intrinsic");
475 assert(isa<VectorType>(Ptrs->getType()) &&
476 isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
477 "Vector of pointers is expected in masked scatter intrinsic");
478
479 IRBuilder<> Builder(CI->getContext());
480 Instruction *InsertPt = CI;
481 BasicBlock *IfBlock = CI->getParent();
482 Builder.SetInsertPoint(InsertPt);
483 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
484
485 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
486 unsigned VectorWidth = Src->getType()->getVectorNumElements();
487
488 // Shorten the way if the mask is a vector of constants.
Craig Topper8b4f0e12018-09-27 22:31:42 +0000489 if (isConstantIntVector(Mask)) {
Ayman Musac5490e52017-05-15 11:30:54 +0000490 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
Craig Topperdfc0f282018-09-27 21:28:46 +0000491 if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
Ayman Musac5490e52017-05-15 11:30:54 +0000492 continue;
Craig Topper69f8c162019-03-09 02:08:41 +0000493 Value *OneElt =
494 Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
495 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000496 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
497 }
498 CI->eraseFromParent();
499 return;
500 }
Craig Topperdfe460d2018-09-27 21:28:41 +0000501
Ayman Musac5490e52017-05-15 11:30:54 +0000502 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
503 // Fill the "else" block, created in the previous iteration
504 //
Craig Topper04236812018-09-27 18:01:48 +0000505 // %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
506 // br i1 %Mask1, label %cond.store, label %else
Ayman Musac5490e52017-05-15 11:30:54 +0000507 //
Craig Topper69f8c162019-03-09 02:08:41 +0000508 Value *Predicate =
509 Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000510
511 // Create "cond" block
512 //
Craig Topper49dad8b2018-09-27 21:28:39 +0000513 // %Elt1 = extractelement <16 x i32> %Src, i32 1
514 // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
515 // %store i32 %Elt1, i32* %Ptr1
Ayman Musac5490e52017-05-15 11:30:54 +0000516 //
517 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
518 Builder.SetInsertPoint(InsertPt);
519
Craig Topper69f8c162019-03-09 02:08:41 +0000520 Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
521 Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
Ayman Musac5490e52017-05-15 11:30:54 +0000522 Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
523
524 // Create "else" block, fill it in the next iteration
525 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
526 Builder.SetInsertPoint(InsertPt);
527 Instruction *OldBr = IfBlock->getTerminator();
Craig Topper04236812018-09-27 18:01:48 +0000528 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
Ayman Musac5490e52017-05-15 11:30:54 +0000529 OldBr->eraseFromParent();
530 IfBlock = NewIfBlock;
531 }
532 CI->eraseFromParent();
Craig Topperd84f6052019-03-08 23:03:43 +0000533
534 ModifiedDT = true;
Ayman Musac5490e52017-05-15 11:30:54 +0000535}
536
Craig Topper9f0b17a2019-03-21 17:38:52 +0000537static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
538 Value *Ptr = CI->getArgOperand(0);
539 Value *Mask = CI->getArgOperand(1);
540 Value *PassThru = CI->getArgOperand(2);
541
542 VectorType *VecType = cast<VectorType>(CI->getType());
543
544 Type *EltTy = VecType->getElementType();
545
546 IRBuilder<> Builder(CI->getContext());
547 Instruction *InsertPt = CI;
548 BasicBlock *IfBlock = CI->getParent();
549
550 Builder.SetInsertPoint(InsertPt);
551 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
552
553 unsigned VectorWidth = VecType->getNumElements();
554
555 // The result vector
556 Value *VResult = PassThru;
557
558 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
559 // Fill the "else" block, created in the previous iteration
560 //
561 // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
562 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
563 // br i1 %mask_1, label %cond.load, label %else
564 //
565
566 Value *Predicate =
567 Builder.CreateExtractElement(Mask, Idx);
568
569 // Create "cond" block
570 //
571 // %EltAddr = getelementptr i32* %1, i32 0
572 // %Elt = load i32* %EltAddr
573 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
574 //
575 BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
576 "cond.load");
577 Builder.SetInsertPoint(InsertPt);
578
579 LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
580 Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
581
582 // Move the pointer if there are more blocks to come.
583 Value *NewPtr;
584 if ((Idx + 1) != VectorWidth)
585 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
586
587 // Create "else" block, fill it in the next iteration
588 BasicBlock *NewIfBlock =
589 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
590 Builder.SetInsertPoint(InsertPt);
591 Instruction *OldBr = IfBlock->getTerminator();
592 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
593 OldBr->eraseFromParent();
594 BasicBlock *PrevIfBlock = IfBlock;
595 IfBlock = NewIfBlock;
596
597 // Create the phi to join the new and previous value.
598 PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
599 ResultPhi->addIncoming(NewVResult, CondBlock);
600 ResultPhi->addIncoming(VResult, PrevIfBlock);
601 VResult = ResultPhi;
602
603 // Add a PHI for the pointer if this isn't the last iteration.
604 if ((Idx + 1) != VectorWidth) {
605 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
606 PtrPhi->addIncoming(NewPtr, CondBlock);
607 PtrPhi->addIncoming(Ptr, PrevIfBlock);
608 Ptr = PtrPhi;
609 }
610 }
611
612 CI->replaceAllUsesWith(VResult);
613 CI->eraseFromParent();
614
615 ModifiedDT = true;
616}
617
618static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
619 Value *Src = CI->getArgOperand(0);
620 Value *Ptr = CI->getArgOperand(1);
621 Value *Mask = CI->getArgOperand(2);
622
623 VectorType *VecType = cast<VectorType>(Src->getType());
624
625 IRBuilder<> Builder(CI->getContext());
626 Instruction *InsertPt = CI;
627 BasicBlock *IfBlock = CI->getParent();
628
629 Builder.SetInsertPoint(InsertPt);
630 Builder.SetCurrentDebugLocation(CI->getDebugLoc());
631
632 Type *EltTy = VecType->getVectorElementType();
633
634 unsigned VectorWidth = VecType->getNumElements();
635
636 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
637 // Fill the "else" block, created in the previous iteration
638 //
639 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
640 // br i1 %mask_1, label %cond.store, label %else
641 //
642 Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
643
644 // Create "cond" block
645 //
646 // %OneElt = extractelement <16 x i32> %Src, i32 Idx
647 // %EltAddr = getelementptr i32* %1, i32 0
648 // %store i32 %OneElt, i32* %EltAddr
649 //
650 BasicBlock *CondBlock =
651 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
652 Builder.SetInsertPoint(InsertPt);
653
654 Value *OneElt = Builder.CreateExtractElement(Src, Idx);
655 Builder.CreateAlignedStore(OneElt, Ptr, 1);
656
657 // Move the pointer if there are more blocks to come.
658 Value *NewPtr;
659 if ((Idx + 1) != VectorWidth)
660 NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
661
662 // Create "else" block, fill it in the next iteration
663 BasicBlock *NewIfBlock =
664 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
665 Builder.SetInsertPoint(InsertPt);
666 Instruction *OldBr = IfBlock->getTerminator();
667 BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
668 OldBr->eraseFromParent();
669 BasicBlock *PrevIfBlock = IfBlock;
670 IfBlock = NewIfBlock;
671
672 // Add a PHI for the pointer if this isn't the last iteration.
673 if ((Idx + 1) != VectorWidth) {
674 PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
675 PtrPhi->addIncoming(NewPtr, CondBlock);
676 PtrPhi->addIncoming(Ptr, PrevIfBlock);
677 Ptr = PtrPhi;
678 }
679 }
680 CI->eraseFromParent();
681
682 ModifiedDT = true;
683}
684
Ayman Musac5490e52017-05-15 11:30:54 +0000685bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
Ayman Musac5490e52017-05-15 11:30:54 +0000686 bool EverMadeChange = false;
687
688 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
689
690 bool MadeChange = true;
691 while (MadeChange) {
692 MadeChange = false;
693 for (Function::iterator I = F.begin(); I != F.end();) {
694 BasicBlock *BB = &*I++;
695 bool ModifiedDTOnIteration = false;
696 MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
697
698 // Restart BB iteration if the dominator tree of the Function was changed
699 if (ModifiedDTOnIteration)
700 break;
701 }
702
703 EverMadeChange |= MadeChange;
704 }
705
706 return EverMadeChange;
707}
708
709bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
710 bool MadeChange = false;
711
712 BasicBlock::iterator CurInstIterator = BB.begin();
713 while (CurInstIterator != BB.end()) {
714 if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
715 MadeChange |= optimizeCallInst(CI, ModifiedDT);
716 if (ModifiedDT)
717 return true;
718 }
719
720 return MadeChange;
721}
722
723bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
724 bool &ModifiedDT) {
Ayman Musac5490e52017-05-15 11:30:54 +0000725 IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
726 if (II) {
727 switch (II->getIntrinsicID()) {
728 default:
729 break;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000730 case Intrinsic::masked_load:
Ayman Musac5490e52017-05-15 11:30:54 +0000731 // Scalarize unsupported vector masked load
Craig Topper8de7bc02019-03-21 05:54:37 +0000732 if (TTI->isLegalMaskedLoad(CI->getType()))
733 return false;
734 scalarizeMaskedLoad(CI, ModifiedDT);
735 return true;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000736 case Intrinsic::masked_store:
Craig Topper8de7bc02019-03-21 05:54:37 +0000737 if (TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType()))
738 return false;
739 scalarizeMaskedStore(CI, ModifiedDT);
740 return true;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000741 case Intrinsic::masked_gather:
Craig Topper8de7bc02019-03-21 05:54:37 +0000742 if (TTI->isLegalMaskedGather(CI->getType()))
743 return false;
744 scalarizeMaskedGather(CI, ModifiedDT);
745 return true;
Eugene Zelenkofa57bd02017-09-27 23:26:01 +0000746 case Intrinsic::masked_scatter:
Craig Topper8de7bc02019-03-21 05:54:37 +0000747 if (TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType()))
748 return false;
749 scalarizeMaskedScatter(CI, ModifiedDT);
750 return true;
Craig Topper9f0b17a2019-03-21 17:38:52 +0000751 case Intrinsic::masked_expandload:
752 if (TTI->isLegalMaskedExpandLoad(CI->getType()))
753 return false;
754 scalarizeMaskedExpandLoad(CI, ModifiedDT);
755 return true;
756 case Intrinsic::masked_compressstore:
757 if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
758 return false;
759 scalarizeMaskedCompressStore(CI, ModifiedDT);
760 return true;
Ayman Musac5490e52017-05-15 11:30:54 +0000761 }
Ayman Musac5490e52017-05-15 11:30:54 +0000762 }
763
764 return false;
765}