Disable the Loop Vectorizer in case of GEMM

Currently, in the case of GEMM and the pattern-matching-based optimizations,
we use only the SLP Vectorizer out of the two LLVM vectorizers. Since the Loop
Vectorizer can get in the way of optimal code generation, we disable it for
the innermost loop by inserting a mark node and emitting the corresponding
loop metadata (llvm.loop.vectorize.enable set to false).

Reviewed-by: Tobias Grosser <tobias@grosser.es>

Differential Revision: https://reviews.llvm.org/D36928

llvm-svn: 311473
diff --git a/polly/lib/CodeGen/IRBuilder.cpp b/polly/lib/CodeGen/IRBuilder.cpp
index 94cdda2..1ad01a4 100644
--- a/polly/lib/CodeGen/IRBuilder.cpp
+++ b/polly/lib/CodeGen/IRBuilder.cpp
@@ -114,15 +114,27 @@
   ParallelLoops.pop_back();
 }
 
-void ScopAnnotator::annotateLoopLatch(BranchInst *B, Loop *L,
-                                      bool IsParallel) const {
-  if (!IsParallel)
-    return;
+void ScopAnnotator::annotateLoopLatch(BranchInst *B, Loop *L, bool IsParallel,
+                                      bool IsLoopVectorizerDisabled) const {
+  MDNode *MData = nullptr;
 
-  assert(!ParallelLoops.empty() && "Expected a parallel loop to annotate");
-  MDNode *Ids = ParallelLoops.back();
-  MDNode *Id = cast<MDNode>(Ids->getOperand(Ids->getNumOperands() - 1));
-  B->setMetadata("llvm.loop", Id);
+  if (IsLoopVectorizerDisabled) {
+    SmallVector<Metadata *, 3> Args;
+    LLVMContext &Ctx = SE->getContext();
+    Args.push_back(MDString::get(Ctx, "llvm.loop.vectorize.enable"));
+    auto *FalseValue = ConstantInt::get(Type::getInt1Ty(Ctx), 0);
+    Args.push_back(ValueAsMetadata::get(FalseValue));
+    MData = MDNode::concatenate(MData, getID(Ctx, MDNode::get(Ctx, Args)));
+  }
+
+  if (IsParallel) {
+    assert(!ParallelLoops.empty() && "Expected a parallel loop to annotate");
+    MDNode *Ids = ParallelLoops.back();
+    MDNode *Id = cast<MDNode>(Ids->getOperand(Ids->getNumOperands() - 1));
+    MData = MDNode::concatenate(MData, Id);
+  }
+
+  B->setMetadata("llvm.loop", MData);
 }
 
 /// Get the pointer operand
diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index 98c8e30..8ed572e 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -482,6 +482,27 @@
   isl_ast_expr_free(Iterator);
 }
 
+/// Check whether the Loop Vectorizer should be disabled for a for-node
+///
+/// The Loop Vectorizer is considered disabled if the body of the for-node
+/// @p Node is a mark node whose identifier is named
+/// "Loop Vectorizer Disabled".
+///
+/// @param Node The isl AST for-node to be checked.
+/// @return True if the body of @p Node is such a mark node, false otherwise.
+namespace {
+bool IsLoopVectorizerDisabled(isl::ast_node Node) {
+  assert(isl_ast_node_get_type(Node.keep()) == isl_ast_node_for);
+  auto Body = Node.for_get_body();
+  if (isl_ast_node_get_type(Body.keep()) != isl_ast_node_mark)
+    return false;
+  auto Id = Body.mark_get_id();
+  if (!strcmp(Id.get_name().c_str(), "Loop Vectorizer Disabled"))
+    return true;
+  return false;
+}
+} // namespace
+
 void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For,
                                          bool KnownParallel) {
   isl_ast_node *Body;
@@ -497,6 +518,9 @@
   Parallel = KnownParallel || (IslAstInfo::isParallel(For) &&
                                !IslAstInfo::isReductionParallel(For));
 
+  bool LoopVectorizerDisabled =
+      IsLoopVectorizerDisabled(isl::manage(isl_ast_node_copy(For)));
+
   Body = isl_ast_node_for_get_body(For);
 
   // isl_ast_node_for_is_degenerate(For)
@@ -532,7 +556,8 @@
   bool UseGuardBB =
       !SE.isKnownPredicate(Predicate, SE.getSCEV(ValueLB), SE.getSCEV(ValueUB));
   IV = createLoop(ValueLB, ValueUB, ValueInc, Builder, LI, DT, ExitBlock,
-                  Predicate, &Annotator, Parallel, UseGuardBB);
+                  Predicate, &Annotator, Parallel, UseGuardBB,
+                  LoopVectorizerDisabled);
   IDToValue[IteratorID] = IV;
 
   create(Body);
diff --git a/polly/lib/CodeGen/LoopGenerators.cpp b/polly/lib/CodeGen/LoopGenerators.cpp
index 483a994..42f7e67 100644
--- a/polly/lib/CodeGen/LoopGenerators.cpp
+++ b/polly/lib/CodeGen/LoopGenerators.cpp
@@ -56,8 +56,8 @@
                          PollyIRBuilder &Builder, LoopInfo &LI,
                          DominatorTree &DT, BasicBlock *&ExitBB,
                          ICmpInst::Predicate Predicate,
-                         ScopAnnotator *Annotator, bool Parallel,
-                         bool UseGuard) {
+                         ScopAnnotator *Annotator, bool Parallel, bool UseGuard,
+                         bool LoopVectDisabled) {
   Function *F = Builder.GetInsertBlock()->getParent();
   LLVMContext &Context = F->getContext();
 
@@ -132,7 +132,7 @@
   // Create the loop latch and annotate it as such.
   BranchInst *B = Builder.CreateCondBr(LoopCondition, HeaderBB, ExitBB);
   if (Annotator)
-    Annotator->annotateLoopLatch(B, NewLoop, Parallel);
+    Annotator->annotateLoopLatch(B, NewLoop, Parallel, LoopVectDisabled);
 
   IV->addIncoming(IncrementedIV, HeaderBB);
   if (GuardBB)
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index 31a43a2..f7d560b 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -993,7 +993,7 @@
 
   // Create a copy statement that corresponds to the memory access to the
   // matrix B, the second operand of the matrix multiplication.
-  Node = Node.parent().parent().parent().parent().parent();
+  Node = Node.parent().parent().parent().parent().parent().parent();
   Node = isl::manage(isl_schedule_node_band_split(Node.release(), 2)).child(0);
   auto AccRel = getMatMulAccRel(isl::manage(MapOldIndVar.copy()), 3, 7);
   unsigned FirstDimSize = MacroParams.Nc / MicroParams.Nr;
@@ -1046,7 +1046,7 @@
   ExtMap = ExtMap.intersect_range(Domain);
   ExtMap = ExtMap.set_tuple_id(isl::dim::out, NewStmt->getDomainId());
   Node = createExtensionNode(Node, ExtMap);
-  return Node.child(0).child(0).child(0).child(0);
+  return Node.child(0).child(0).child(0).child(0).child(0);
 }
 
 /// Get a relation mapping induction variables produced by schedule
@@ -1106,11 +1106,11 @@
   isl::union_set Options = IsolateOption.unite(AtomicOption);
   Options = Options.unite(getUnrollIsolatedSetOptions(Ctx));
   Node = Node.band_set_ast_build_options(Options);
-  Node = Node.parent().parent();
+  Node = Node.parent().parent().parent();
   IsolateOption = getIsolateOptions(Prefix, 3);
   Options = IsolateOption.unite(AtomicOption);
   Node = Node.band_set_ast_build_options(Options);
-  Node = Node.child(0).child(0);
+  Node = Node.child(0).child(0).child(0);
   return Node;
 }
 
@@ -1129,6 +1129,15 @@
   return Node.insert_mark(Id).child(0);
 }
 
+/// Insert "Loop Vectorizer Disabled" mark node.
+///
+/// @param Node The child of the mark node to be inserted.
+/// @return The modified isl_schedule_node.
+static isl::schedule_node markLoopVectorizerDisabled(isl::schedule_node Node) {
+  auto Id = isl::id::alloc(Node.get_ctx(), "Loop Vectorizer Disabled", nullptr);
+  return Node.insert_mark(Id).child(0);
+}
+
 /// Restore the initial ordering of dimensions of the band node
 ///
 /// In case the band node represents all the dimensions of the iteration
@@ -1187,6 +1196,7 @@
                                                         MacroKernelParams);
   if (!MapOldIndVar)
     return Node;
+  Node = markLoopVectorizerDisabled(Node.parent()).child(0);
   Node = isolateAndUnrollMatMulInnerLoops(Node, MicroKernelParams);
   return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
                                           MacroKernelParams, MMI);