Perform copying to created arrays according to the packing transformation

This is the fourth patch to apply the BLIS matmul optimization pattern on matmul
kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel, plus two
packing routines. The macro-kernel is implemented in terms of two additional
loops around a micro-kernel. The micro-kernel is a loop around a rank-1
(i.e., outer product) update. In this change we copy the data into the created
arrays, which is the last step in implementing the packing transformation.

Reviewed-by: Tobias Grosser <tobias@grosser.es>

Differential Revision: https://reviews.llvm.org/D23260

llvm-svn: 281441
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index b89374d..050ad67 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -660,6 +660,76 @@
   return IdentifiedAccess;
 }
 
+/// Add constraints to the @p Dim dimension of @p ExtMap.
+///
+/// If @p ExtMap has the following form [O0, O1, O2]->[I0, I1, I2],
+/// the following constraint will be added
+/// Bound * OM <= IM <= Bound * (OM + 1) - 1,
+/// where M is @p Dim and Bound is @p Bound (which must be nonzero).
+///
+/// @param ExtMap The isl map to be modified.
+/// @param Dim The input/output dimension to be constrained.
+/// @param Bound The value that is used to specify the constraint.
+/// @return The modified isl map.
+__isl_give isl_map *
+addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim,
+                                   unsigned Bound) {
+  assert(Bound != 0);
+  auto *ExtMapSpace = isl_map_get_space(ExtMap);
+  auto *ConstrSpace = isl_local_space_from_space(ExtMapSpace);
+  auto *Constr =
+      isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace));
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound * (-1));
+  ExtMap = isl_map_add_constraint(ExtMap, Constr); // IM >= Bound * OM
+  Constr = isl_constraint_alloc_inequality(ConstrSpace);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, Bound);
+  Constr = isl_constraint_set_constant_si(Constr, Bound - 1);
+  return isl_map_add_constraint(ExtMap, Constr); // IM <= Bound * (OM + 1) - 1
+}
+
+/// Create an access relation that is specific to the matrix multiplication
+/// pattern.
+///
+/// Create an access relation of the following form:
+/// { [O0, O1, O2]->[I0, I1, I2] :
+///   FirstOutputDimBound * O0 <= I0 <= FirstOutputDimBound * (O0 + 1) - 1
+///   and SecondOutputDimBound * O1 <= I1 <= SecondOutputDimBound * (O1 + 1) - 1
+///   and ThirdOutputDimBound * O2 <= I2 <= ThirdOutputDimBound * (O2 + 1) - 1}
+///   A bound that is zero is handled differently: the corresponding output
+///   dimension is fixed to zero instead of being constrained by the range
+///   above (e.g. FirstOutputDimBound == 0 yields I0 = 0).
+///
+/// @param Ctx The isl context.
+/// @param FirstOutputDimBound,
+///        SecondOutputDimBound,
+///        ThirdOutputDimBound The parameters of the access relation.
+/// @return The specified access relation.
+__isl_give isl_map *getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound,
+                                 unsigned SecondOutputDimBound,
+                                 unsigned ThirdOutputDimBound) {
+  auto *NewRelSpace = isl_space_alloc(Ctx, 0, 3, 3);
+  auto *extensionMap = isl_map_universe(NewRelSpace);
+  if (!FirstOutputDimBound)
+    extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 0, 0);
+  else
+    extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 0,
+                                                      FirstOutputDimBound);
+  if (!SecondOutputDimBound)
+    extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 1, 0);
+  else
+    extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 1,
+                                                      SecondOutputDimBound);
+  if (!ThirdOutputDimBound)
+    extensionMap = isl_map_fix_si(extensionMap, isl_dim_out, 2, 0);
+  else
+    extensionMap = addExtensionMapMatMulDimConstraint(extensionMap, 2,
+                                                      ThirdOutputDimBound);
+  return extensionMap;
+}
+
 /// Create an access relation that is specific to the matrix
 ///        multiplication pattern.
 ///
@@ -758,6 +828,14 @@
   return isl_map_apply_range(MapOldIndVar, AccessRel);
 }
 
+__isl_give isl_schedule_node *
+createExtensionNode(__isl_take isl_schedule_node *Node,
+                    __isl_take isl_map *ExtensionMap) {
+  auto *Extension = isl_union_map_from_map(ExtensionMap); // wrap as union map
+  auto *NewNode = isl_schedule_node_from_extension(Extension);
+  return isl_schedule_node_graft_before(Node, NewNode); // graft before Node
+}
+
 /// Apply the packing transformation.
 ///
 /// The packing transformation can be described as a data-layout
@@ -772,9 +850,9 @@
 /// @param MicroParams, MacroParams Parameters of the BLIS kernel
 ///                                 to be taken into account.
 /// @return The optimized schedule node.
-static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
-                                             MicroKernelParamsTy MicroParams,
-                                             MacroKernelParamsTy MacroParams) {
+static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
+    __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
+    MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
   auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
   auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
   isl_id_free(InputDimsId);
@@ -782,8 +860,12 @@
   MemoryAccess *MemAccessB = identifyAccessB(Stmt);
   if (!MemAccessA || !MemAccessB) {
     isl_map_free(MapOldIndVar);
-    return;
+    return Node;
   }
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+  Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+  Node = isl_schedule_node_parent(Node);
+  Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0);
   auto *AccRel =
       getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6);
   unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
@@ -791,14 +873,34 @@
   auto *SAI = Stmt->getParent()->createScopArrayInfo(
       MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize});
   AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+  auto *OldAcc = MemAccessA->getAccessRelation();
   MemAccessA->setNewAccessRelation(AccRel);
+  auto *ExtMap =
+      getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc);
+  ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1);
+  auto *Domain = Stmt->getDomain();
+  auto *NewStmt = Stmt->getParent()->addScopStmt(
+      OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain));
+  ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
+  Node = createExtensionNode(Node, ExtMap);
+  Node = isl_schedule_node_child(Node, 0);
   AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7);
   FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr;
   SecondDimSize = MicroParams.Nr;
   SAI = Stmt->getParent()->createScopArrayInfo(
       MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize});
   AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+  OldAcc = MemAccessB->getAccessRelation();
   MemAccessB->setNewAccessRelation(AccRel);
+  ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc);
+  ExtMap = isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1);
+  ExtMap = isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
+  NewStmt = Stmt->getParent()->addScopStmt(
+      OldAcc, MemAccessB->getAccessRelation(), Domain);
+  ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
+  Node = createExtensionNode(Node, ExtMap);
+  Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
+  return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
 }
 
 /// Get a relation mapping induction variables produced by schedule
@@ -842,9 +944,8 @@
       Node, MicroKernelParams, MacroKernelParams);
   if (!MapOldIndVar)
     return Node;
-  optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
-                                   MacroKernelParams);
-  return Node;
+  return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
+                                          MacroKernelParams);
 }
 
 bool ScheduleTreeOptimizer::isMatrMultPattern(
@@ -901,7 +1002,7 @@
 }
 
 bool ScheduleTreeOptimizer::isProfitableSchedule(
-    Scop &S, __isl_keep isl_union_map *NewSchedule) {
+    Scop &S, __isl_keep isl_schedule *NewSchedule) {
   // To understand if the schedule has been optimized we check if the schedule
   // has changed at all.
   // TODO: We can improve this by tracking if any necessarily beneficial
@@ -911,9 +1012,15 @@
   // optimizations, by comparing (yet to be defined) performance metrics
   // before/after the scheduling optimizer
   // (e.g., #stride-one accesses)
+  if (S.containsExtensionNode(NewSchedule)) // always counts as changed
+    return true;
+  auto *NewScheduleMap = isl_schedule_get_map(NewSchedule); // freed below
   isl_union_map *OldSchedule = S.getSchedule();
-  bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule);
+  assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes "
+                        "that make Scop::getSchedule() return nullptr.");
+  bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap);
   isl_union_map_free(OldSchedule);
+  isl_union_map_free(NewScheduleMap);
   return changed;
 }
 
@@ -1090,10 +1197,8 @@
   auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
   isl_schedule *NewSchedule =
       ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
-  isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
 
-  if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
-    isl_union_map_free(NewScheduleMap);
+  if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) {
     isl_schedule_free(NewSchedule);
     return false;
   }
@@ -1104,7 +1209,6 @@
   if (OptimizedScops)
     S.dump();
 
-  isl_union_map_free(NewScheduleMap);
   return false;
 }