Determination of statements that contain matrix multiplication

Add detection of statements that contain, in particular, matrix
multiplications and can be optimized with the analytical model of [1]
to try to get close-to-peak performance. It can be enabled via
-polly-pattern-matching-based-opts, which is false by default.

Refs:
[1] - http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf

Contributed-by: Roman Gareev <gareevroman@gmail.com>
Reviewed-by: Tobias Grosser <tobias@grosser.es>

Differential Revision: http://reviews.llvm.org/D20575

llvm-svn: 271128
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index b517879..7859194 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -166,6 +166,11 @@
                       cl::Hidden, cl::ZeroOrMore, cl::CommaSeparated,
                       cl::cat(PollyCategory));
 
+static cl::opt<bool> // Off by default; consulted in optimizeBand().
+    PMBasedOpts("polly-pattern-matching-based-opts",
+                cl::desc("Perform optimizations based on pattern matching"),
+                cl::init(false), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 /// @brief Create an isl_union_set, which describes the isolate option based
 ///        on IsoalteDomain.
 ///
@@ -359,11 +364,8 @@
 }
 
 __isl_give isl_schedule_node *
-ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
-                                    void *User) {
-  if (!isTileableBandNode(Node))
-    return Node;
-
+ScheduleTreeOptimizer::standardBandOpts(__isl_take isl_schedule_node *Node,
+                                        void *User) {
   if (FirstLevelTiling)
     Node = tileNode(Node, "1st level tiling", FirstLevelTileSizes,
                     FirstLevelDefaultTileSize);
@@ -396,6 +398,110 @@
   return Node;
 }
 
+/// @brief Check whether output dimensions of the map rely on the specified
+///        input dimension.
+///
+/// The dimension is projected out and re-inserted unconstrained; it is used
+/// iff the result differs from the original. isl errors count as "used".
+///
+/// @param IslMap The isl map to be considered (ownership is taken).
+/// @param DimNum The number of an input dimension to be checked.
+static bool isInputDimUsed(__isl_take isl_map *IslMap, unsigned DimNum) {
+  auto *Unconstrained =
+      isl_map_project_out(isl_map_copy(IslMap), isl_dim_in, DimNum, 1);
+  Unconstrained = isl_map_insert_dims(Unconstrained, isl_dim_in, DimNum, 1);
+  Unconstrained = isl_map_set_tuple_id(
+      Unconstrained, isl_dim_in, isl_map_get_tuple_id(IslMap, isl_dim_in));
+  Unconstrained = isl_map_set_tuple_id(
+      Unconstrained, isl_dim_out, isl_map_get_tuple_id(IslMap, isl_dim_out));
+  bool IsUsed = isl_map_is_equal(Unconstrained, IslMap) != isl_bool_true;
+  isl_map_free(Unconstrained);
+  isl_map_free(IslMap);
+  return IsUsed;
+}
+
+/// @brief Check if the SCoP statement could probably be optimized with
+///        analytical modeling.
+///
+/// containsMatrMult tries to determine whether the following conditions
+/// are true:
+/// 1. all memory accesses of the statement will have stride 0 or 1,
+///    if we interchange loops (switch the variable used in the inner
+///    loop to the outer loop).
+/// 2. all memory accesses of the statement except the last one are read
+///    memory accesses and the last one is a write memory access.
+/// 3. all subscripts of the last memory access of the statement don't contain
+///    the variable used in the inner loop.
+///
+/// @param PartialSchedule The PartialSchedule that contains a SCoP statement
+///        to check.
+static bool containsMatrMult(__isl_keep isl_map *PartialSchedule) {
+  auto InputDimsId = isl_map_get_tuple_id(PartialSchedule, isl_dim_in);
+  auto *ScpStmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
+  isl_id_free(InputDimsId);
+  if (ScpStmt->size() <= 1)
+    return false;
+  auto MemA = ScpStmt->begin();
+  // Every access but the last must be a read; MemA ends on the last access.
+  for (unsigned i = 0; i < ScpStmt->size() - 1 && MemA != ScpStmt->end();
+       i++, MemA++)
+    if (!(*MemA)->isRead() or
+        ((*MemA)->isArrayKind() and
+         !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
+           (*MemA)->isStrideZero(isl_map_copy(PartialSchedule)))))
+      return false;
+  if (!(*MemA)->isWrite() or !(*MemA)->isArrayKind() or
+      !((*MemA)->isStrideOne(isl_map_copy(PartialSchedule)) or
+        (*MemA)->isStrideZero(isl_map_copy(PartialSchedule))))
+    return false;
+  auto DimNum = isl_map_dim(PartialSchedule, isl_dim_in);
+  return !isInputDimUsed((*MemA)->getAccessRelation(), DimNum - 1);
+}
+
+/// @brief Rotate the output dimensions of the map so the last becomes first.
+///
+/// @param IslMap The isl map to be modified (taken and returned).
+static __isl_give isl_map *circularShiftOutputDims(__isl_take isl_map *IslMap) {
+  auto *InTupleId = isl_map_get_tuple_id(IslMap, isl_dim_in);
+  const auto LastOut = isl_map_dim(IslMap, isl_dim_out) - 1;
+  IslMap = isl_map_move_dims(IslMap, isl_dim_in, 0, isl_dim_out, LastOut, 1);
+  IslMap = isl_map_move_dims(IslMap, isl_dim_out, 0, isl_dim_in, 0, 1);
+  return isl_map_set_tuple_id(IslMap, isl_dim_in, InTupleId);
+}
+
+bool ScheduleTreeOptimizer::isMatrMultPattern(
+    __isl_keep isl_schedule_node *Node) {
+  auto *PartialSchedule =
+      isl_schedule_node_band_get_partial_schedule_union_map(Node);
+  if (isl_union_map_n_map(PartialSchedule) != 1) {
+    // The union map is owned here; release it on this early exit too.
+    isl_union_map_free(PartialSchedule);
+    return false;
+  }
+  auto *NewPartialSchedule = isl_map_from_union_map(PartialSchedule);
+  if (isl_map_dim(NewPartialSchedule, isl_dim_in) != 3) {
+    isl_map_free(NewPartialSchedule);
+    return false;
+  }
+  // Test the statement with the last schedule dimension moved to the front.
+  NewPartialSchedule = circularShiftOutputDims(NewPartialSchedule);
+  const bool IsMatrMult = containsMatrMult(NewPartialSchedule);
+  isl_map_free(NewPartialSchedule);
+  return IsMatrMult;
+}
+
+__isl_give isl_schedule_node *
+ScheduleTreeOptimizer::optimizeBand(__isl_take isl_schedule_node *Node,
+                                    void *User) {
+  if (isTileableBandNode(Node)) {
+    // Detection is only reported for now; the standard path runs either way.
+    if (PMBasedOpts && isMatrMultPattern(Node))
+      DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+    return standardBandOpts(Node, User);
+  }
+  return Node; // Not tileable: hand the band back untouched.
+}
+
 __isl_give isl_schedule *
 ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
   isl_schedule_node *Root = isl_schedule_get_root(Schedule);