Perform copying to created arrays according to the packing transformation

This is the fourth patch to apply the BLIS matmul optimization pattern to matmul
kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel, plus two
packing routines. The macro-kernel is implemented in terms of two additional
loops around a micro-kernel. The micro-kernel is a loop around a rank-1
(i.e., outer product) update. In this change we copy the data into the arrays
created for packing, which is the last step in implementing the packing
transformation.

Reviewed-by: Tobias Grosser <tobias@grosser.es>

Differential Revision: https://reviews.llvm.org/D23260

llvm-svn: 281441
diff --git a/polly/lib/CodeGen/BlockGenerators.cpp b/polly/lib/CodeGen/BlockGenerators.cpp
index 6480a3a..70c5d98 100644
--- a/polly/lib/CodeGen/BlockGenerators.cpp
+++ b/polly/lib/CodeGen/BlockGenerators.cpp
@@ -681,7 +681,9 @@
 
 void BlockGenerator::invalidateScalarEvolution(Scop &S) {
   for (auto &Stmt : S)
-    if (Stmt.isBlockStmt())
+    if (Stmt.isCopyStmt())
+      continue;
+    else if (Stmt.isBlockStmt())
       for (auto &Inst : *Stmt.getBasicBlock())
         SE.forgetValue(&Inst);
     else if (Stmt.isRegionStmt())
diff --git a/polly/lib/CodeGen/IRBuilder.cpp b/polly/lib/CodeGen/IRBuilder.cpp
index cedbe29..92fce35 100644
--- a/polly/lib/CodeGen/IRBuilder.cpp
+++ b/polly/lib/CodeGen/IRBuilder.cpp
@@ -61,7 +61,8 @@
   SetVector<Value *> BasePtrs;
   for (ScopStmt &Stmt : S)
     for (MemoryAccess *MA : Stmt)
-      BasePtrs.insert(MA->getBaseAddr());
+      if (!Stmt.isCopyStmt())
+        BasePtrs.insert(MA->getBaseAddr());
 
   std::string AliasScopeStr = "polly.alias.scope.";
   for (Value *BasePtr : BasePtrs)
diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp
index 4b7cde0..2f00b69 100644
--- a/polly/lib/CodeGen/IslAst.cpp
+++ b/polly/lib/CodeGen/IslAst.cpp
@@ -593,8 +593,7 @@
   P = isl_ast_node_print(RootNode, P, Options);
   AstStr = isl_printer_get_str(P);
 
-  isl_union_map *Schedule =
-      isl_union_map_intersect_domain(S.getSchedule(), S.getDomains());
+  auto *Schedule = S.getScheduleTree();
 
   DEBUG({
     dbgs() << S.getContextStr() << "\n";
@@ -609,7 +608,7 @@
   free(AstStr);
 
   isl_ast_expr_free(RunCondition);
-  isl_union_map_free(Schedule);
+  isl_schedule_free(Schedule);
   isl_ast_node_free(RootNode);
   isl_printer_free(P);
 }
diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index 0134323..cea97b7 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -767,6 +767,23 @@
   isl_ast_expr_free(Expr);
 }
 
+void IslNodeBuilder::generateCopyStmt(
+    ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) { // Emits one store for a copy statement; NewAccesses is annotated __isl_keep.
+  assert(Stmt->size() == 2); // Stmt must contain exactly two elements (the accesses iterated below).
+  auto ReadAccess = Stmt->begin();
+  auto WriteAccess = ReadAccess++; // Post-increment: WriteAccess keeps the first element, ReadAccess moves to the second.
+  assert((*ReadAccess)->isRead() && (*WriteAccess)->isMustWrite()); // Second element must be a read, first a must-write.
+  assert((*ReadAccess)->getElementType() == (*WriteAccess)->getElementType() &&
+         "Accesses use the same data type");
+  assert((*ReadAccess)->isArrayKind() && (*WriteAccess)->isArrayKind()); // Both accesses must be of array kind.
+  auto *AccessExpr =
+      isl_id_to_ast_expr_get(NewAccesses, (*ReadAccess)->getId()); // Look up the AST expression stored under the read access's id.
+  auto *LoadValue = ExprBuilder.create(AccessExpr); // Presumably materializes the value read via that expression -- TODO confirm in ExprBuilder.
+  AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId()); // Same lookup, now for the write access.
+  auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr); // Address derived from the write access's expression.
+  Builder.CreateStore(LoadValue, StoreAddr); // Store the value obtained for the read access to that address.
+}
+
 void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
   LoopToScevMapT LTS;
   isl_id *Id;
@@ -781,12 +798,17 @@
 
   Stmt = (ScopStmt *)isl_id_get_user(Id);
   auto *NewAccesses = createNewAccesses(Stmt, User);
-  createSubstitutions(Expr, Stmt, LTS);
+  if (Stmt->isCopyStmt()) {
+    generateCopyStmt(Stmt, NewAccesses);
+    isl_ast_expr_free(Expr);
+  } else {
+    createSubstitutions(Expr, Stmt, LTS);
 
-  if (Stmt->isBlockStmt())
-    BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
-  else
-    RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+    if (Stmt->isBlockStmt())
+      BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
+    else
+      RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+  }
 
   isl_id_to_ast_expr_free(NewAccesses);
   isl_ast_node_free(User);