Perform copying to created arrays according to the packing transformation
This is the fourth patch applying the BLIS matmul optimization pattern to matmul
kernels (http://www.cs.utexas.edu/users/flame/pubs/TOMS-BLIS-Analytical.pdf).
BLIS implements gemm as three nested loops around a macro-kernel, plus two
packing routines. The macro-kernel is implemented in terms of two additional
loops around a micro-kernel. The micro-kernel is a loop around a rank-1
(i.e., outer product) update. In this change we copy data into the newly
created arrays, which is the last step of the packing transformation.
Reviewed-by: Tobias Grosser <tobias@grosser.es>
Differential Revision: https://reviews.llvm.org/D23260
llvm-svn: 281441
diff --git a/polly/lib/CodeGen/BlockGenerators.cpp b/polly/lib/CodeGen/BlockGenerators.cpp
index 6480a3a..70c5d98 100644
--- a/polly/lib/CodeGen/BlockGenerators.cpp
+++ b/polly/lib/CodeGen/BlockGenerators.cpp
@@ -681,7 +681,9 @@
void BlockGenerator::invalidateScalarEvolution(Scop &S) {
for (auto &Stmt : S)
- if (Stmt.isBlockStmt())
+ if (Stmt.isCopyStmt())
+ continue;
+ else if (Stmt.isBlockStmt())
for (auto &Inst : *Stmt.getBasicBlock())
SE.forgetValue(&Inst);
else if (Stmt.isRegionStmt())
diff --git a/polly/lib/CodeGen/IRBuilder.cpp b/polly/lib/CodeGen/IRBuilder.cpp
index cedbe29..92fce35 100644
--- a/polly/lib/CodeGen/IRBuilder.cpp
+++ b/polly/lib/CodeGen/IRBuilder.cpp
@@ -61,7 +61,8 @@
SetVector<Value *> BasePtrs;
for (ScopStmt &Stmt : S)
for (MemoryAccess *MA : Stmt)
- BasePtrs.insert(MA->getBaseAddr());
+ if (!Stmt.isCopyStmt())
+ BasePtrs.insert(MA->getBaseAddr());
std::string AliasScopeStr = "polly.alias.scope.";
for (Value *BasePtr : BasePtrs)
diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp
index 4b7cde0..2f00b69 100644
--- a/polly/lib/CodeGen/IslAst.cpp
+++ b/polly/lib/CodeGen/IslAst.cpp
@@ -593,8 +593,7 @@
P = isl_ast_node_print(RootNode, P, Options);
AstStr = isl_printer_get_str(P);
- isl_union_map *Schedule =
- isl_union_map_intersect_domain(S.getSchedule(), S.getDomains());
+ auto *Schedule = S.getScheduleTree();
DEBUG({
dbgs() << S.getContextStr() << "\n";
@@ -609,7 +608,7 @@
free(AstStr);
isl_ast_expr_free(RunCondition);
- isl_union_map_free(Schedule);
+ isl_schedule_free(Schedule);
isl_ast_node_free(RootNode);
isl_printer_free(P);
}
diff --git a/polly/lib/CodeGen/IslNodeBuilder.cpp b/polly/lib/CodeGen/IslNodeBuilder.cpp
index 0134323..cea97b7 100644
--- a/polly/lib/CodeGen/IslNodeBuilder.cpp
+++ b/polly/lib/CodeGen/IslNodeBuilder.cpp
@@ -767,6 +767,23 @@
isl_ast_expr_free(Expr);
}
+void IslNodeBuilder::generateCopyStmt(
+ ScopStmt *Stmt, __isl_keep isl_id_to_ast_expr *NewAccesses) { // Emits a single store for a copy statement: value from its read access, address from its write access.
+ assert(Stmt->size() == 2); // exactly two accesses are expected on Stmt
+ auto ReadAccess = Stmt->begin();
+ auto WriteAccess = ReadAccess++; // post-increment: WriteAccess keeps the first access, ReadAccess advances to the second
+ assert((*ReadAccess)->isRead() && (*WriteAccess)->isMustWrite()); // second access must be a read, first a must-write
+ assert((*ReadAccess)->getElementType() == (*WriteAccess)->getElementType() &&
+ "Accesses use the same data type");
+ assert((*ReadAccess)->isArrayKind() && (*WriteAccess)->isArrayKind());
+ auto *AccessExpr =
+ isl_id_to_ast_expr_get(NewAccesses, (*ReadAccess)->getId()); // look up the AST expression registered under the read access's id
+ auto *LoadValue = ExprBuilder.create(AccessExpr); // NOTE(review): presumably emits the load of the source element - confirm against IslExprBuilder::create
+ AccessExpr = isl_id_to_ast_expr_get(NewAccesses, (*WriteAccess)->getId()); // same lookup, now for the write access
+ auto *StoreAddr = ExprBuilder.createAccessAddress(AccessExpr);
+ Builder.CreateStore(LoadValue, StoreAddr); // write the value obtained above to the write access's address
+}
+
void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
LoopToScevMapT LTS;
isl_id *Id;
@@ -781,12 +798,17 @@
Stmt = (ScopStmt *)isl_id_get_user(Id);
auto *NewAccesses = createNewAccesses(Stmt, User);
- createSubstitutions(Expr, Stmt, LTS);
+ if (Stmt->isCopyStmt()) {
+ generateCopyStmt(Stmt, NewAccesses);
+ isl_ast_expr_free(Expr);
+ } else {
+ createSubstitutions(Expr, Stmt, LTS);
- if (Stmt->isBlockStmt())
- BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
- else
- RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+ if (Stmt->isBlockStmt())
+ BlockGen.copyStmt(*Stmt, LTS, NewAccesses);
+ else
+ RegionGen.copyStmt(*Stmt, LTS, NewAccesses);
+ }
isl_id_to_ast_expr_free(NewAccesses);
isl_ast_node_free(User);