Disable the Loop Vectorizer in case of GEMM
Currently, when GEMM is handled by the pattern-matching based optimizations,
we rely only on the SLP Vectorizer, one of LLVM's two vectorizers. Since the
Loop Vectorizer can get in the way of optimal code generation, we disable it
for the innermost loop by marking that loop with an isl mark node and emitting
the corresponding loop metadata.
Reviewed-by: Tobias Grosser <tobias@grosser.es>
Differential Revision: https://reviews.llvm.org/D36928
llvm-svn: 311473
diff --git a/polly/lib/CodeGen/IRBuilder.cpp b/polly/lib/CodeGen/IRBuilder.cpp
index 94cdda2..1ad01a4 100644
--- a/polly/lib/CodeGen/IRBuilder.cpp
+++ b/polly/lib/CodeGen/IRBuilder.cpp
@@ -114,15 +114,27 @@
ParallelLoops.pop_back();
}
-void ScopAnnotator::annotateLoopLatch(BranchInst *B, Loop *L,
- bool IsParallel) const {
- if (!IsParallel)
- return;
+/// Attach "llvm.loop" metadata to the loop latch branch @p B.
+///
+/// @param B                        The latch branch to annotate.
+/// @param L                        The loop being annotated (not read by this
+///                                 implementation).
+/// @param IsParallel               Attach the loop ID stored as the last
+///                                 operand of ParallelLoops.back() (presumably
+///                                 the innermost parallel loop's ID).
+/// @param IsLoopVectorizerDisabled Attach a loop ID carrying the hint
+///                                 !{"llvm.loop.vectorize.enable", i1 false}.
+///
+/// If neither flag is set, MData stays null and setMetadata removes any
+/// "llvm.loop" metadata from @p B.
+void ScopAnnotator::annotateLoopLatch(BranchInst *B, Loop *L, bool IsParallel,
+ bool IsLoopVectorizerDisabled) const {
+ MDNode *MData = nullptr;
- assert(!ParallelLoops.empty() && "Expected a parallel loop to annotate");
- MDNode *Ids = ParallelLoops.back();
- MDNode *Id = cast<MDNode>(Ids->getOperand(Ids->getNumOperands() - 1));
- B->setMetadata("llvm.loop", Id);
+ if (IsLoopVectorizerDisabled) {
+ // Build the hint !{"llvm.loop.vectorize.enable", i1 false} and wrap it in a
+ // loop ID via getID (defined outside this view).
+ SmallVector<Metadata *, 3> Args;
+ LLVMContext &Ctx = SE->getContext();
+ Args.push_back(MDString::get(Ctx, "llvm.loop.vectorize.enable"));
+ auto *FalseValue = ConstantInt::get(Type::getInt1Ty(Ctx), 0);
+ Args.push_back(ValueAsMetadata::get(FalseValue));
+ MData = MDNode::concatenate(MData, getID(Ctx, MDNode::get(Ctx, Args)));
+ }
+
+ if (IsParallel) {
+ // Reuse the loop ID stored as the last operand of the ID list on top of
+ // ParallelLoops.
+ assert(!ParallelLoops.empty() && "Expected a parallel loop to annotate");
+ MDNode *Ids = ParallelLoops.back();
+ MDNode *Id = cast<MDNode>(Ids->getOperand(Ids->getNumOperands() - 1));
+ // NOTE(review): if IsLoopVectorizerDisabled was also set, MData is non-null
+ // here and MDNode::concatenate builds a new node from the operands of both
+ // IDs. Assuming getID returns a self-referential loop ID, operand 0 of the
+ // result is that first ID rather than the result itself, while LangRef
+ // requires operand 0 of a loop ID to be a self-reference; LLVM may then
+ // ignore both hints. Confirm and, if so, emit one self-referential loop ID
+ // carrying both pieces of information.
+ MData = MDNode::concatenate(MData, Id);
+ }
+
+ B->setMetadata("llvm.loop", MData);
}
/// Get the pointer operand
diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index 98c8e30..8ed572e 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -482,6 +482,27 @@
isl_ast_expr_free(Iterator);
}
+namespace {
+/// Check whether the Loop Vectorizer has been disabled for a for node.
+///
+/// A loop is marked this way by wrapping its body in an isl mark node whose
+/// identifier is named "Loop Vectorizer Disabled". Such marks are used to keep
+/// the Loop Vectorizer away from the innermost loop produced by the
+/// pattern-matching based GEMM optimizations.
+///
+/// @param Node The isl AST node to inspect; must be a for node.
+/// @return True if the body of @p Node is a "Loop Vectorizer Disabled" mark.
+bool IsLoopVectorizerDisabled(isl::ast_node Node) {
+  assert(isl_ast_node_get_type(Node.keep()) == isl_ast_node_for);
+  auto Body = Node.for_get_body();
+  // Only a mark directly wrapping the body counts; nested marks are ignored.
+  if (isl_ast_node_get_type(Body.keep()) != isl_ast_node_mark)
+    return false;
+  auto Id = Body.mark_get_id();
+  return Id.get_name() == "Loop Vectorizer Disabled";
+}
+} // namespace
+
void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For,
bool KnownParallel) {
isl_ast_node *Body;
@@ -497,6 +518,9 @@
Parallel = KnownParallel || (IslAstInfo::isParallel(For) &&
!IslAstInfo::isReductionParallel(For));
+ bool LoopVectorizerDisabled =
+ IsLoopVectorizerDisabled(isl::manage(isl_ast_node_copy(For)));
+
Body = isl_ast_node_for_get_body(For);
// isl_ast_node_for_is_degenerate(For)
@@ -532,7 +556,8 @@
bool UseGuardBB =
!SE.isKnownPredicate(Predicate, SE.getSCEV(ValueLB), SE.getSCEV(ValueUB));
IV = createLoop(ValueLB, ValueUB, ValueInc, Builder, LI, DT, ExitBlock,
- Predicate, &Annotator, Parallel, UseGuardBB);
+ Predicate, &Annotator, Parallel, UseGuardBB,
+ LoopVectorizerDisabled);
IDToValue[IteratorID] = IV;
create(Body);
diff --git a/polly/lib/CodeGen/LoopGenerators.cpp b/polly/lib/CodeGen/LoopGenerators.cpp
index 483a994..42f7e67 100644
--- a/polly/lib/CodeGen/LoopGenerators.cpp
+++ b/polly/lib/CodeGen/LoopGenerators.cpp
@@ -56,8 +56,8 @@
PollyIRBuilder &Builder, LoopInfo &LI,
DominatorTree &DT, BasicBlock *&ExitBB,
ICmpInst::Predicate Predicate,
- ScopAnnotator *Annotator, bool Parallel,
- bool UseGuard) {
+ ScopAnnotator *Annotator, bool Parallel, bool UseGuard,
+ bool LoopVectDisabled) {
Function *F = Builder.GetInsertBlock()->getParent();
LLVMContext &Context = F->getContext();
@@ -132,7 +132,7 @@
// Create the loop latch and annotate it as such.
BranchInst *B = Builder.CreateCondBr(LoopCondition, HeaderBB, ExitBB);
if (Annotator)
- Annotator->annotateLoopLatch(B, NewLoop, Parallel);
+ Annotator->annotateLoopLatch(B, NewLoop, Parallel, LoopVectDisabled);
IV->addIncoming(IncrementedIV, HeaderBB);
if (GuardBB)