Restrict ranges of extension maps

To prevent copy statements from accessing arrays out of bounds, restrict the
ranges of their extension maps according to the constraints of the domains of
their originating statements.

Reviewed-by: Michael Kruse <llvm@meinersbur.de>

Differential Revision: https://reviews.llvm.org/D25655

llvm-svn: 289815
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index df3a08f..7f68573 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -851,6 +851,8 @@
 static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
     __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
     MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
+  // Check whether memory accesses of the SCoP statement correspond to
+  // the matrix multiplication pattern and if this is true, obtain them.
   auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
   auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
   isl_id_free(InputDimsId);
@@ -860,6 +862,9 @@
     isl_map_free(MapOldIndVar);
     return Node;
   }
+
+  // Create a copy statement that corresponds to the memory access to the
+  // matrix B, the second operand of the matrix multiplication.
   Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
   Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
   Node = isl_schedule_node_parent(Node);
@@ -879,10 +884,19 @@
   isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
   ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 2, 1);
   auto *Domain = Stmt->getDomain();
+
+  // Restrict the domains of the copy statements to only execute when also its
+  // originating statement is executed.
+  auto *DomainId = isl_set_get_tuple_id(Domain);
   auto *NewStmt = Stmt->getParent()->addScopStmt(
       OldAcc, MemAccessB->getAccessRelation(), isl_set_copy(Domain));
+  ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, isl_id_copy(DomainId));
+  ExtMap = isl_map_intersect_range(ExtMap, isl_set_copy(Domain));
   ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
   Node = createExtensionNode(Node, ExtMap);
+
+  // Create a copy statement that corresponds to the memory access
+  // to the matrix A, the first operand of the matrix multiplication.
   Node = isl_schedule_node_child(Node, 0);
   AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 6);
   FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
@@ -896,7 +910,12 @@
   isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 0, 1);
   isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
   NewStmt = Stmt->getParent()->addScopStmt(
-      OldAcc, MemAccessA->getAccessRelation(), Domain);
+      OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain));
+
+  // Restrict the domains of the copy statements to only execute when also its
+  // originating statement is executed.
+  ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, DomainId);
+  ExtMap = isl_map_intersect_range(ExtMap, Domain);
   ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
   Node = createExtensionNode(Node, ExtMap);
   Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);