Perform copying to created arrays according to the packing transformation
This is the fourth patch to apply the BLIS matmul optimization pattern on matmul
kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel, plus two
packing routines. The macro-kernel is implemented in terms of two additional
loops around a micro-kernel. The micro-kernel is a loop around a rank-1
(i.e., outer product) update. In this change we perform copying to created
arrays, which is the last step to implement the packing transformation.
Reviewed-by: Tobias Grosser <tobias@grosser.es>
Differential Revision: https://reviews.llvm.org/D23260
llvm-svn: 281441
diff --git a/polly/lib/Transform/ScheduleOptimizer.cpp b/polly/lib/Transform/ScheduleOptimizer.cpp
index b89374d..050ad67 100644
--- a/polly/lib/Transform/ScheduleOptimizer.cpp
+++ b/polly/lib/Transform/ScheduleOptimizer.cpp
@@ -660,6 +660,76 @@
return IdentifiedAccess;
}
+/// Add constraints to the @p Dim dimension of @p ExtMap.
+///
+/// If @p ExtMap has the form [O0, O1, O2]->[I0, I1, I2], the constraint
+///   Bound * OM <= IM <= Bound * (OM + 1) - 1,
+/// where M is @p Dim and Bound is @p Bound, is added. It is built from two
+/// inequality constraints, one for the lower and one for the upper bound.
+///
+/// @param ExtMap The isl map to be modified.
+/// @param Dim    The index of the input/output dimension pair to constrain.
+/// @param Bound  The value that is used to specify the constraint. It must
+///               not be zero.
+/// @return The modified isl map.
+__isl_give isl_map *
+addExtensionMapMatMulDimConstraint(__isl_take isl_map *ExtMap, unsigned Dim,
+                                   unsigned Bound) {
+  assert(Bound != 0);
+  // Work with a signed copy of the bound: 'Bound * (-1)' would be evaluated
+  // as an unsigned product and only produce the intended coefficient through
+  // wraparound followed by a narrowing conversion.
+  const int SignedBound = static_cast<int>(Bound);
+  auto *ExtMapSpace = isl_map_get_space(ExtMap);
+  auto *ConstrSpace = isl_local_space_from_space(ExtMapSpace);
+  // Lower bound: IM - Bound * OM >= 0.
+  auto *Constr =
+      isl_constraint_alloc_inequality(isl_local_space_copy(ConstrSpace));
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, 1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, -SignedBound);
+  ExtMap = isl_map_add_constraint(ExtMap, Constr);
+  // Upper bound: Bound * OM - IM + (Bound - 1) >= 0. This is the last use of
+  // ConstrSpace, so it is passed on without taking another copy.
+  Constr = isl_constraint_alloc_inequality(ConstrSpace);
+  Constr = isl_constraint_set_coefficient_si(Constr, isl_dim_out, Dim, -1);
+  Constr =
+      isl_constraint_set_coefficient_si(Constr, isl_dim_in, Dim, SignedBound);
+  Constr = isl_constraint_set_constant_si(Constr, SignedBound - 1);
+  return isl_map_add_constraint(ExtMap, Constr);
+}
+
+/// Build the extension map that is specific to the matrix multiplication
+/// pattern.
+///
+/// The resulting relation has the form
+/// { [O0, O1, O2]->[I1, I2, I3] :
+///   FirstOutputDimBound * O0 <= I1 <= FirstOutputDimBound * (O0 + 1) - 1
+///   and SecondOutputDimBound * O1 <= I2 <= SecondOutputDimBound * (O1 + 1) - 1
+///   and ThirdOutputDimBound * O2 <= I3 <= ThirdOutputDimBound * (O2 + 1) - 1 }.
+/// A bound of zero is treated specially: instead of the range constraint, the
+/// corresponding output dimension is fixed to 0.
+///
+/// @param Ctx                  The isl context.
+/// @param FirstOutputDimBound  The bound used for the first dimension pair.
+/// @param SecondOutputDimBound The bound used for the second dimension pair.
+/// @param ThirdOutputDimBound  The bound used for the third dimension pair.
+/// @return The specified relation.
+__isl_give isl_map *getMatMulExt(isl_ctx *Ctx, unsigned FirstOutputDimBound,
+                                 unsigned SecondOutputDimBound,
+                                 unsigned ThirdOutputDimBound) {
+  const unsigned DimBounds[] = {FirstOutputDimBound, SecondOutputDimBound,
+                                ThirdOutputDimBound};
+  // Start from the unconstrained relation with three input and three output
+  // dimensions and restrict it one dimension pair at a time.
+  auto *Result = isl_map_universe(isl_space_alloc(Ctx, 0, 3, 3));
+  for (unsigned Dim = 0; Dim < 3; ++Dim) {
+    const unsigned DimBound = DimBounds[Dim];
+    if (DimBound == 0) {
+      Result = isl_map_fix_si(Result, isl_dim_out, Dim, 0);
+      continue;
+    }
+    Result = addExtensionMapMatMulDimConstraint(Result, Dim, DimBound);
+  }
+  return Result;
+}
+
/// Create an access relation that is specific to the matrix
/// multiplication pattern.
///
@@ -758,6 +828,14 @@
return isl_map_apply_range(MapOldIndVar, AccessRel);
}
+/// Create an extension node from @p ExtensionMap and graft it before @p Node.
+///
+/// @param Node         The schedule node before which the new node is grafted.
+/// @param ExtensionMap The map describing the extension; it is converted into
+///                     a union map before the extension node is created.
+/// @return The schedule node returned by isl_schedule_node_graft_before.
+__isl_give isl_schedule_node *
+createExtensionNode(__isl_take isl_schedule_node *Node,
+                    __isl_take isl_map *ExtensionMap) {
+  auto *ExtensionNode =
+      isl_schedule_node_from_extension(isl_union_map_from_map(ExtensionMap));
+  return isl_schedule_node_graft_before(Node, ExtensionNode);
+}
+
/// Apply the packing transformation.
///
/// The packing transformation can be described as a data-layout
@@ -772,9 +850,9 @@
/// @param MicroParams, MacroParams Parameters of the BLIS kernel
/// to be taken into account.
/// @return The optimized schedule node.
-static void optimizeDataLayoutMatrMulPattern(__isl_take isl_map *MapOldIndVar,
- MicroKernelParamsTy MicroParams,
- MacroKernelParamsTy MacroParams) {
+static __isl_give isl_schedule_node *optimizeDataLayoutMatrMulPattern(
+ __isl_take isl_schedule_node *Node, __isl_take isl_map *MapOldIndVar,
+ MicroKernelParamsTy MicroParams, MacroKernelParamsTy MacroParams) {
 auto InputDimsId = isl_map_get_tuple_id(MapOldIndVar, isl_dim_in);
 auto *Stmt = static_cast<ScopStmt *>(isl_id_get_user(InputDimsId));
 isl_id_free(InputDimsId);
@@ -782,8 +860,12 @@
 MemoryAccess *MemAccessB = identifyAccessB(Stmt);
 if (!MemAccessA || !MemAccessB) {
 isl_map_free(MapOldIndVar);
- return;
+ return Node;
 }
+ // Move five levels up in the schedule tree, split the band found there at
+ // position 2 and continue from the child of the split band.
+ Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+ Node = isl_schedule_node_parent(isl_schedule_node_parent(Node));
+ Node = isl_schedule_node_parent(Node);
+ Node = isl_schedule_node_child(isl_schedule_node_band_split(Node, 2), 0);
 auto *AccRel =
 getMatMulAccRel(isl_map_copy(MapOldIndVar), MacroParams.Kc, 3, 6);
 unsigned FirstDimSize = MacroParams.Mc * MacroParams.Kc / MicroParams.Mr;
@@ -791,14 +873,34 @@
 auto *SAI = Stmt->getParent()->createScopArrayInfo(
 MemAccessA->getElementType(), "Packed_A", {FirstDimSize, SecondDimSize});
 AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+ // Fetch the current access relation of A before it is replaced; it is handed
+ // to the statement created below together with the new relation.
+ auto *OldAcc = MemAccessA->getAccessRelation();
 MemAccessA->setNewAccessRelation(AccRel);
+ auto *ExtMap =
+ getMatMulExt(Stmt->getIslCtx(), MacroParams.Mc, 0, MacroParams.Kc);
+ ExtMap = isl_map_project_out(ExtMap, isl_dim_in, 1, 1);
+ auto *Domain = Stmt->getDomain();
+ auto *NewStmt = Stmt->getParent()->addScopStmt(
+ OldAcc, MemAccessA->getAccessRelation(), isl_set_copy(Domain));
+ ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
+ Node = createExtensionNode(Node, ExtMap);
+ Node = isl_schedule_node_child(Node, 0);
 AccRel = getMatMulAccRel(MapOldIndVar, MacroParams.Kc, 4, 7);
 FirstDimSize = MacroParams.Nc * MacroParams.Kc / MicroParams.Nr;
 SecondDimSize = MicroParams.Nr;
 SAI = Stmt->getParent()->createScopArrayInfo(
 MemAccessB->getElementType(), "Packed_B", {FirstDimSize, SecondDimSize});
 AccRel = isl_map_set_tuple_id(AccRel, isl_dim_out, SAI->getBasePtrId());
+ OldAcc = MemAccessB->getAccessRelation();
 MemAccessB->setNewAccessRelation(AccRel);
+ ExtMap = getMatMulExt(Stmt->getIslCtx(), 0, MacroParams.Nc, MacroParams.Kc);
+ // isl_map_move_dims takes ownership of its argument and returns the resulting
+ // map, so the result has to be assigned back; the passed-in pointer must not
+ // be relied upon afterwards.
+ ExtMap = isl_map_move_dims(ExtMap, isl_dim_out, 0, isl_dim_in, 1, 1);
+ ExtMap = isl_map_move_dims(ExtMap, isl_dim_in, 2, isl_dim_out, 0, 1);
+ NewStmt = Stmt->getParent()->addScopStmt(
+ OldAcc, MemAccessB->getAccessRelation(), Domain);
+ ExtMap = isl_map_set_tuple_id(ExtMap, isl_dim_out, NewStmt->getDomainId());
+ Node = createExtensionNode(Node, ExtMap);
+ Node = isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
+ return isl_schedule_node_child(isl_schedule_node_child(Node, 0), 0);
 }
/// Get a relation mapping induction variables produced by schedule
@@ -842,9 +944,8 @@
Node, MicroKernelParams, MacroKernelParams);
if (!MapOldIndVar)
return Node;
- optimizeDataLayoutMatrMulPattern(MapOldIndVar, MicroKernelParams,
- MacroKernelParams);
- return Node;
+ return optimizeDataLayoutMatrMulPattern(Node, MapOldIndVar, MicroKernelParams,
+ MacroKernelParams);
}
bool ScheduleTreeOptimizer::isMatrMultPattern(
@@ -901,7 +1002,7 @@
}
bool ScheduleTreeOptimizer::isProfitableSchedule(
- Scop &S, __isl_keep isl_union_map *NewSchedule) {
+ Scop &S, __isl_keep isl_schedule *NewSchedule) {
// To understand if the schedule has been optimized we check if the schedule
// has changed at all.
// TODO: We can improve this by tracking if any necessarily beneficial
@@ -911,9 +1012,15 @@
// optimizations, by comparing (yet to be defined) performance metrics
// before/after the scheduling optimizer
// (e.g., #stride-one accesses)
+ if (S.containsExtensionNode(NewSchedule))
+ return true;
+ auto *NewScheduleMap = isl_schedule_get_map(NewSchedule);
isl_union_map *OldSchedule = S.getSchedule();
- bool changed = !isl_union_map_is_equal(OldSchedule, NewSchedule);
+ assert(OldSchedule && "Only IslScheduleOptimizer can insert extension nodes "
+ "that make Scop::getSchedule() return nullptr.");
+ bool changed = !isl_union_map_is_equal(OldSchedule, NewScheduleMap);
isl_union_map_free(OldSchedule);
+ isl_union_map_free(NewScheduleMap);
return changed;
}
@@ -1090,10 +1197,8 @@
auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
isl_schedule *NewSchedule =
ScheduleTreeOptimizer::optimizeSchedule(Schedule, TTI);
- isl_union_map *NewScheduleMap = isl_schedule_get_map(NewSchedule);
- if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewScheduleMap)) {
- isl_union_map_free(NewScheduleMap);
+ if (!ScheduleTreeOptimizer::isProfitableSchedule(S, NewSchedule)) {
isl_schedule_free(NewSchedule);
return false;
}
@@ -1104,7 +1209,6 @@
if (OptimizedScops)
S.dump();
- isl_union_map_free(NewScheduleMap);
return false;
}