[ScalarizeMaskedMemIntrin] Add support for scalarizing expandload and compressstore intrinsics.
This adds support for scalarizing these intrinsics as well as the X86TargetTransformInfo support to avoid scalarizing them in the cases X86 can handle.
I've omitted handling special cases for constant masks for this first pass, though CodeGenPrepare can constant fold the branch conditions and remove some of the control flow anyway.
Fixes PR40994 and covers most of PR3666. Might want to implement constant masks to close that.
Differential Revision: https://reviews.llvm.org/D59180
llvm-svn: 356687
diff --git a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
index 3a27035..e2ee9f2 100644
--- a/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
+++ b/llvm/lib/CodeGen/ScalarizeMaskedMemIntrin.cpp
@@ -534,6 +534,154 @@
ModifiedDT = true;
}
+static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
+ Value *Ptr = CI->getArgOperand(0); // Address the next enabled lane is loaded from.
+ Value *Mask = CI->getArgOperand(1); // One bit per lane, tested lane by lane below.
+ Value *PassThru = CI->getArgOperand(2); // Supplies the lanes whose mask bit is off.
+ // Lower the masked.expandload call into a chain of per-lane conditional loads.
+ VectorType *VecType = cast<VectorType>(CI->getType());
+
+ Type *EltTy = VecType->getElementType();
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+ // Everything is emitted in front of the call, which stays in place until the end.
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ unsigned VectorWidth = VecType->getNumElements();
+
+ // The result vector; starts as the pass-through value and is updated per lane.
+ Value *VResult = PassThru;
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
+ // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+ // br i1 %mask_1, label %cond.load, label %else
+ //
+ // (On the first iteration there is no phi yet; the test goes into CI's own block.)
+ Value *Predicate =
+ Builder.CreateExtractElement(Mask, Idx);
+
+ // Create "cond" block
+ //
+ // %Elt = load i32, i32* %Ptr, align 1
+ // %VResult.new = insertelement <16 x i32> %VResult, i32 %Elt, i32 Idx
+ // %Ptr.next = getelementptr inbounds i32, i32* %Ptr, i32 1 ; not for the last lane
+ //
+ BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
+ "cond.load");
+ Builder.SetInsertPoint(InsertPt);
+ // The scalar load is emitted with an alignment of 1.
+ LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, 1);
+ Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);
+
+ // Move the pointer if there are more blocks to come. NewPtr is only read under the same condition below.
+ Value *NewPtr;
+ if ((Idx + 1) != VectorWidth)
+ NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock =
+ CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator(); // Swapped for the conditional branch below.
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ BasicBlock *PrevIfBlock = IfBlock;
+ IfBlock = NewIfBlock;
+
+ // Create the phi to join the new and previous value.
+ PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
+ ResultPhi->addIncoming(NewVResult, CondBlock);
+ ResultPhi->addIncoming(VResult, PrevIfBlock);
+ VResult = ResultPhi;
+ // The pointer only advances on the path through the "cond" block, so:
+ // Add a PHI for the pointer if this isn't the last iteration.
+ if ((Idx + 1) != VectorWidth) {
+ PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+ PtrPhi->addIncoming(NewPtr, CondBlock);
+ PtrPhi->addIncoming(Ptr, PrevIfBlock);
+ Ptr = PtrPhi;
+ }
+ }
+ // The last phi holds the fully merged vector; use it in place of the call.
+ CI->replaceAllUsesWith(VResult);
+ CI->eraseFromParent();
+ // Basic blocks were split above; report that through the out-parameter.
+ ModifiedDT = true;
+}
+
+static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
+ Value *Src = CI->getArgOperand(0); // Vector whose enabled lanes are stored.
+ Value *Ptr = CI->getArgOperand(1); // Address the next enabled lane is stored to.
+ Value *Mask = CI->getArgOperand(2); // One bit per lane, tested lane by lane below.
+ // Lower the masked.compressstore call into a chain of per-lane conditional stores.
+ VectorType *VecType = cast<VectorType>(Src->getType());
+
+ IRBuilder<> Builder(CI->getContext());
+ Instruction *InsertPt = CI;
+ BasicBlock *IfBlock = CI->getParent();
+ // Everything is emitted in front of the call, which stays in place until the end.
+ Builder.SetInsertPoint(InsertPt);
+ Builder.SetCurrentDebugLocation(CI->getDebugLoc());
+
+ Type *EltTy = VecType->getVectorElementType();
+
+ unsigned VectorWidth = VecType->getNumElements();
+
+ for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
+ // Fill the "else" block, created in the previous iteration
+ //
+ // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
+ // br i1 %mask_1, label %cond.store, label %else
+ // (From the second iteration on, the block already starts with %ptr.phi.else.)
+ Value *Predicate = Builder.CreateExtractElement(Mask, Idx);
+
+ // Create "cond" block
+ //
+ // %OneElt = extractelement <16 x i32> %Src, i32 Idx
+ // store i32 %OneElt, i32* %Ptr, align 1
+ // %Ptr.next = getelementptr inbounds i32, i32* %Ptr, i32 1 ; not for the last lane
+ //
+ BasicBlock *CondBlock =
+ IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
+ Builder.SetInsertPoint(InsertPt);
+ // The scalar store is emitted with an alignment of 1.
+ Value *OneElt = Builder.CreateExtractElement(Src, Idx);
+ Builder.CreateAlignedStore(OneElt, Ptr, 1);
+
+ // Move the pointer if there are more blocks to come. NewPtr is only read under the same condition below.
+ Value *NewPtr;
+ if ((Idx + 1) != VectorWidth)
+ NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);
+
+ // Create "else" block, fill it in the next iteration
+ BasicBlock *NewIfBlock =
+ CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
+ Builder.SetInsertPoint(InsertPt);
+ Instruction *OldBr = IfBlock->getTerminator(); // Swapped for the conditional branch below.
+ BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
+ OldBr->eraseFromParent();
+ BasicBlock *PrevIfBlock = IfBlock;
+ IfBlock = NewIfBlock;
+ // The pointer only advances on the path through the "cond" block, so:
+ // Add a PHI for the pointer if this isn't the last iteration.
+ if ((Idx + 1) != VectorWidth) {
+ PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
+ PtrPhi->addIncoming(NewPtr, CondBlock);
+ PtrPhi->addIncoming(Ptr, PrevIfBlock);
+ Ptr = PtrPhi;
+ }
+ }
+ CI->eraseFromParent();
+ // Basic blocks were split above; report that through the out-parameter.
+ ModifiedDT = true;
+}
+
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
bool EverMadeChange = false;
@@ -600,6 +748,16 @@
return false;
scalarizeMaskedScatter(CI, ModifiedDT);
return true;
+ case Intrinsic::masked_expandload:
+ if (TTI->isLegalMaskedExpandLoad(CI->getType()))
+ return false;
+ scalarizeMaskedExpandLoad(CI, ModifiedDT);
+ return true;
+ case Intrinsic::masked_compressstore:
+ if (TTI->isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
+ return false;
+ scalarizeMaskedCompressStore(CI, ModifiedDT);
+ return true;
}
}