Apply all necessary tilings and unrollings to get a micro-kernel

This is the first patch to apply the BLIS matmul optimization pattern
on matmul kernels
(http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel,
plus two packing routines. The macro-kernel is implemented in terms
of two additional loops around a micro-kernel. The micro-kernel
is a loop around a rank-1 (i.e., outer product) update.
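In this change we create the BLIS micro-kernel by applying
a combination of tiling and unrolling. The micro-kernel sizes
Mr and Nr are derived from the vector register width reported by
TargetTransformInfo and from the latency and throughput of vector
fused multiply-add instructions, which can be set with the new
-polly-target-latency-vector-fma and -polly-target-througput-vector-fma
options. In subsequent changes we will add the extraction of the
BLIS macro-kernel and implement the packing transformation.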

Contributed-by: Roman Gareev <gareevroman@gmail.com>
Reviewed-by: Tobias Grosser <tobias@grosser.es>

Differential Revision: http://reviews.llvm.org/D21140

llvm-svn: 273397
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index cc8118b..6d0450f 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -53,6 +53,7 @@
 #include "polly/Options.h"
 #include "polly/ScopInfo.h"
 #include "polly/Support/GICHelper.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Support/Debug.h"
 #include "isl/aff.h"
 #include "isl/band.h"
@@ -119,6 +120,20 @@
                                       cl::init(true), cl::ZeroOrMore,
                                       cl::cat(PollyCategory));
 
+static cl::opt<int> LatencyVectorFma(
+    "polly-target-latency-vector-fma",
+    cl::desc("The minimal number of cycles between issuing two "
+             "dependent consecutive vector fused multiply-add "
+             "instructions."),
+    cl::Hidden, cl::init(8), cl::ZeroOrMore, cl::cat(PollyCategory));
+// Used together with LatencyVectorFma to size the matmul micro-kernel.
+static cl::opt<int> ThrougputVectorFma(
+    "polly-target-througput-vector-fma",
+    cl::desc("The throughput of the processor floating-point arithmetic units "
+             "expressed in the number of vector fused multiply-add "
+             "instructions per clock cycle."),
+    cl::Hidden, cl::init(1), cl::ZeroOrMore, cl::cat(PollyCategory));
+
 static cl::opt<int> FirstLevelDefaultTileSize(
     "polly-default-tile-size",
     cl::desc("The default tile size (if not enough were provided by"
@@ -478,6 +493,23 @@
   return isl_map_set_tuple_id(IslMap, isl_dim_in, InputDimsId);
 }
 
+__isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeMatMulPattern(
+    __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
+  assert(TTI && "The target transform info should be provided.");
+  // Choose an Mr x Nr micro-kernel with Mr * Nr >= Nvec * LatencyVectorFma *
+  // ThrougputVectorFma; Nvec = doubles per vector register (2 if < 64 bits).
+  unsigned Nvec = TTI->getRegisterBitWidth(true) / 64;
+  if (Nvec == 0)
+    Nvec = 2;
+  // Compute in double: integer division would truncate before ceil() runs.
+  double Work = static_cast<double>(Nvec) * LatencyVectorFma *
+                ThrougputVectorFma;
+  int Nr = ceil(sqrt(Work) / Nvec) * Nvec;
+  int Mr = ceil(Work / Nr);
+  std::vector<int> MicroKernelParams{Mr, Nr};
+  return applyRegisterTiling(Node, MicroKernelParams, 1);
+}
+
 bool ScheduleTreeOptimizer::isMatrMultPattern(
     __isl_keep isl_schedule_node *Node) {
   auto *PartialSchedule =
@@ -508,16 +540,21 @@
   if (!isTileableBandNode(Node))
     return Node;
 
-  if (PMBasedOpts && isMatrMultPattern(Node))
+  if (PMBasedOpts && User && isMatrMultPattern(Node)) {
     DEBUG(dbgs() << "The matrix multiplication pattern was detected\n");
+    const llvm::TargetTransformInfo *TTI;
+    TTI = static_cast<const llvm::TargetTransformInfo *>(User);
+    Node = optimizeMatMulPattern(Node, TTI);
+  }
 
   return standardBandOpts(Node, User);
 }
 
 __isl_give isl_schedule *
-ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule) {
+ScheduleTreeOptimizer::optimizeSchedule(__isl_take isl_schedule *Schedule,
+                                        const llvm::TargetTransformInfo *TTI) {
   isl_schedule_node *Root = isl_schedule_get_root(Schedule);
-  Root = optimizeScheduleNode(Root);
+  Root = optimizeScheduleNode(Root, TTI);
   isl_schedule_free(Schedule);
   auto S = isl_schedule_node_get_schedule(Root);
   isl_schedule_node_free(Root);
@@ -525,8 +562,9 @@
 }
 
 __isl_give isl_schedule_node *ScheduleTreeOptimizer::optimizeScheduleNode(
-    __isl_take isl_schedule_node *Node) {
-  Node = isl_schedule_node_map_descendant_bottom_up(Node, optimizeBand, NULL);
+    __isl_take isl_schedule_node *Node, const llvm::TargetTransformInfo *TTI) {
+  void *User = const_cast<llvm::TargetTransformInfo *>(TTI);
+  Node = isl_schedule_node_map_descendant_bottom_up(Node, optimizeBand, User);
   return Node;
 }
 
@@ -714,7 +752,10 @@
     isl_printer_free(P);
   });
 
-  isl_schedule *NewSchedule = ScheduleTreeOptimizer::optimizeSchedule(Schedule);
+  Function &F = S.getFunction();
+  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  isl_schedule *NewSchedule =
+      ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
   isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
 
   if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
@@ -752,6 +793,7 @@
 void IslScheduleOptimizer::getAnalysisUsage(AnalysisUsage &AU) const {
   ScopPass::getAnalysisUsage(AU);
   AU.addRequired<DependenceInfo>();
+  AU.addRequired<TargetTransformInfoWrapperPass>(); // Used by matmul tiling.
 }
 
 Pass *polly::createIslScheduleOptimizerPass() {
@@ -762,5 +804,6 @@
                       "Polly - Optimize schedule of SCoP", false, false);
 INITIALIZE_PASS_DEPENDENCY(DependenceInfo);
 INITIALIZE_PASS_DEPENDENCY(ScopInfoRegionPass);
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass);
 INITIALIZE_PASS_END(IslScheduleOptimizer, "polly-opt-isl",
                     "Polly - Optimize schedule of SCoP", false, false)