isl: vector code generation based on ISL ast
Original patch by Tobias Grosser, slightly modified by Sebastian Pop.
llvm-svn: 170420
diff --git a/polly/lib/CodeGen/CodeGeneration.cpp b/polly/lib/CodeGen/CodeGeneration.cpp
index 60302cb..ad70375 100644
--- a/polly/lib/CodeGen/CodeGeneration.cpp
+++ b/polly/lib/CodeGen/CodeGeneration.cpp
@@ -768,45 +768,9 @@
return true;
}
-int ClastStmtCodeGen::getNumberOfIterations(const clast_for *f) {
- isl_set *loopDomain = isl_set_copy(isl_set_from_cloog_domain(f->domain));
- isl_set *tmp = isl_set_copy(loopDomain);
-
- // Calculate a map similar to the identity map, but with the last input
- // and output dimension not related.
- // [i0, i1, i2, i3] -> [i0, i1, i2, o0]
- isl_space *Space = isl_set_get_space(loopDomain);
- Space = isl_space_drop_outputs(Space,
- isl_set_dim(loopDomain, isl_dim_set) - 2, 1);
- Space = isl_space_map_from_set(Space);
- isl_map *identity = isl_map_identity(Space);
- identity = isl_map_add_dims(identity, isl_dim_in, 1);
- identity = isl_map_add_dims(identity, isl_dim_out, 1);
-
- isl_map *map = isl_map_from_domain_and_range(tmp, loopDomain);
- map = isl_map_intersect(map, identity);
-
- isl_map *lexmax = isl_map_lexmax(isl_map_copy(map));
- isl_map *lexmin = isl_map_lexmin(map);
- isl_map *sub = isl_map_sum(lexmax, isl_map_neg(lexmin));
-
- isl_set *elements = isl_map_range(sub);
-
- if (!isl_set_is_singleton(elements)) {
- isl_set_free(elements);
- return -1;
- }
-
- isl_point *p = isl_set_sample_point(elements);
-
- isl_int v;
- isl_int_init(v);
- isl_point_get_coordinate(p, isl_dim_set, isl_set_n_dim(loopDomain) - 1, &v);
- int numberIterations = isl_int_get_si(v);
- isl_int_clear(v);
- isl_point_free(p);
-
- return (numberIterations) / isl_int_get_si(f->stride) + 1;
+// Compute the trip count of the clast loop For from its iteration domain and
+// its stride. Returns -1 when no trip count could be derived.
+int ClastStmtCodeGen::getNumberOfIterations(const clast_for *For) {
+  isl_set *LoopDomain = isl_set_copy(isl_set_from_cloog_domain(For->domain));
+  int NumberOfIterations = polly::getNumberOfIterations(LoopDomain);
+
+  // Propagate a failure unchanged. Feeding it through the stride arithmetic
+  // below would turn it into 0 or 1, which a caller cannot tell apart from a
+  // real trip count.
+  // NOTE(review): assumes polly::getNumberOfIterations signals failure with a
+  // negative value, as the inlined code it replaces did (it returned -1) --
+  // confirm against its definition.
+  if (NumberOfIterations < 0)
+    return -1;
+
+  return NumberOfIterations / isl_int_get_si(For->stride) + 1;
}
void ClastStmtCodeGen::codegenForVector(const clast_for *F) {
diff --git a/polly/lib/CodeGen/IslAst.cpp b/polly/lib/CodeGen/IslAst.cpp
index d3d65dc..20738a3 100644
--- a/polly/lib/CodeGen/IslAst.cpp
+++ b/polly/lib/CodeGen/IslAst.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "polly/CodeGen/CodeGeneration.h"
#include "polly/CodeGen/IslAst.h"
#include "polly/LinkAllPasses.h"
@@ -31,7 +32,6 @@
#include "isl/union_map.h"
#include "isl/list.h"
-#include "isl/ast.h"
#include "isl/ast_build.h"
#include "isl/set.h"
#include "isl/map.h"
@@ -68,24 +68,6 @@
};
} // End namespace polly.
-
-static void IslAstUserFree(void *User)
-{
- struct IslAstUser *UserStruct = (struct IslAstUser *) User;
- isl_ast_build_free(UserStruct->Context);
- isl_pw_multi_aff_free(UserStruct->PMA);
- free(UserStruct);
-}
-
-// Information about an ast node.
-struct AstNodeUserInfo {
- // The node is the outermost parallel loop.
- int IsOutermostParallel;
-
- // The node is the innermost parallel loop.
- int IsInnermostParallel;
-};
-
// Temporary information used when building the ast.
struct AstBuildUserInfo {
// The dependence information.
@@ -99,7 +81,7 @@
static __isl_give isl_printer *
printParallelFor(__isl_keep isl_ast_node *Node, __isl_take isl_printer *Printer,
__isl_take isl_ast_print_options *PrintOptions,
- AstNodeUserInfo *Info) {
+ IslAstUser *Info) {
if (Info) {
if (Info->IsInnermostParallel) {
Printer = isl_printer_start_line(Printer);
@@ -124,26 +106,29 @@
if (!Id)
return isl_ast_node_for_print(Node, Printer, PrintOptions);
- struct AstNodeUserInfo *Info = (struct AstNodeUserInfo *) isl_id_get_user(Id);
+ struct IslAstUser *Info = (struct IslAstUser *) isl_id_get_user(Id);
Printer = printParallelFor(Node, Printer, PrintOptions, Info);
isl_id_free(Id);
return Printer;
}
// Allocate an AstNodeInfo structure and initialize it with default values.
-static struct AstNodeUserInfo *allocateAstNodeUserInfo() {
- struct AstNodeUserInfo *NodeInfo;
- NodeInfo = (struct AstNodeUserInfo *) malloc(sizeof(struct AstNodeUserInfo));
+static struct IslAstUser *allocateIslAstUser() {
+ struct IslAstUser *NodeInfo;
+ NodeInfo = (struct IslAstUser *) malloc(sizeof(struct IslAstUser));
+ NodeInfo->PMA = 0;
+ NodeInfo->Context = 0;
NodeInfo->IsOutermostParallel = 0;
NodeInfo->IsInnermostParallel = 0;
return NodeInfo;
}
// Free the AstNodeInfo structure.
-static void freeAstNodeUserInfo(void *Ptr) {
- struct AstNodeUserInfo *Info;
- Info = (struct AstNodeUserInfo *) Ptr;
- free(Info);
+static void freeIslAstUser(void *Ptr) {
+ struct IslAstUser *UserStruct = (struct IslAstUser *) Ptr;
+ isl_ast_build_free(UserStruct->Context);
+ isl_pw_multi_aff_free(UserStruct->PMA);
+ free(UserStruct);
}
// Check if the current scheduling dimension is parallel.
@@ -200,7 +185,7 @@
// Mark a for node openmp parallel, if it is the outermost parallel for node.
static void markOpenmpParallel(__isl_keep isl_ast_build *Build,
struct AstBuildUserInfo *BuildInfo,
- struct AstNodeUserInfo *NodeInfo) {
+ struct IslAstUser *NodeInfo) {
if (BuildInfo->InParallelFor)
return;
@@ -219,14 +204,10 @@
//
static __isl_give isl_id *astBuildBeforeFor(__isl_keep isl_ast_build *Build,
void *User) {
- isl_id *Id;
- struct AstBuildUserInfo *BuildInfo;
- struct AstNodeUserInfo *NodeInfo;
-
- BuildInfo = (struct AstBuildUserInfo *) User;
- NodeInfo = allocateAstNodeUserInfo();
- Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo);
- Id = isl_id_set_free_user(Id, freeAstNodeUserInfo);
+ struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *) User;
+ struct IslAstUser *NodeInfo = allocateIslAstUser();
+ isl_id *Id = isl_id_alloc(isl_ast_build_get_ctx(Build), "", NodeInfo);
+ Id = isl_id_set_free_user(Id, freeIslAstUser);
markOpenmpParallel(Build, BuildInfo, NodeInfo);
@@ -286,7 +267,7 @@
isl_id *Id = isl_ast_node_get_annotation(Node);
if (!Id)
return Node;
- struct AstNodeUserInfo *Info = (struct AstNodeUserInfo *) isl_id_get_user(Id);
+ struct IslAstUser *Info = (struct IslAstUser *) isl_id_get_user(Id);
struct AstBuildUserInfo *BuildInfo = (struct AstBuildUserInfo *) User;
if (Info) {
if (Info->IsOutermostParallel)
@@ -296,28 +277,36 @@
Info->IsInnermostParallel = 1;
}
- isl_id_free(Id);
+ if (!Info->Context)
+ Info->Context = isl_ast_build_copy(Build);
+ isl_id_free(Id);
return Node;
}
static __isl_give isl_ast_node *
-AtEachDomain(__isl_keep isl_ast_node *Node,
+AtEachDomain(__isl_take isl_ast_node *Node,
__isl_keep isl_ast_build *Context, void *User)
{
- isl_map *Map;
- struct IslAstUser *UserStruct;
+ struct IslAstUser *Info = NULL;
+ isl_id *Id = isl_ast_node_get_annotation(Node);
- UserStruct = (struct IslAstUser *) malloc(sizeof(struct IslAstUser));
+ if (Id)
+ Info = (struct IslAstUser *) isl_id_get_user(Id);
- Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context));
- UserStruct->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map));
- UserStruct->Context = isl_ast_build_copy(Context);
+ if (!Info) {
+ // Allocate annotations once: parallel for detection might have already
+ // allocated the annotations for this node.
+ Info = allocateIslAstUser();
+ Id = isl_id_alloc(isl_ast_node_get_ctx(Node), NULL, Info);
+ Id = isl_id_set_free_user(Id, &freeIslAstUser);
+ }
- isl_id *Annotation = isl_id_alloc(isl_ast_node_get_ctx(Node), NULL,
- UserStruct);
- Annotation = isl_id_set_free_user(Annotation, &IslAstUserFree);
- return isl_ast_node_set_annotation(Node, Annotation);
+ isl_map *Map = isl_map_from_union_map(isl_ast_build_get_schedule(Context));
+ Info->PMA = isl_pw_multi_aff_from_map(isl_map_reverse(Map));
+ Info->Context = isl_ast_build_copy(Context);
+
+ return isl_ast_node_set_annotation(Node, Id);
}
IslAst::IslAst(Scop *Scop, Dependences &D) : S(Scop) {
@@ -343,7 +332,7 @@
isl_union_map_dump(Schedule);
);
- if (DetectParallel) {
+ if (DetectParallel || PollyVectorizerChoice != VECTORIZER_NONE) {
BuildInfo.Deps = &D;
BuildInfo.InParallelFor = 0;
diff --git a/polly/lib/CodeGen/IslCodeGeneration.cpp b/polly/lib/CodeGen/IslCodeGeneration.cpp
index 9bfdd1d..90ec0f9 100644
--- a/polly/lib/CodeGen/IslCodeGeneration.cpp
+++ b/polly/lib/CodeGen/IslCodeGeneration.cpp
@@ -26,6 +26,7 @@
#include "polly/TempScopInfo.h"
#include "polly/CodeGen/IslAst.h"
#include "polly/CodeGen/BlockGenerators.h"
+#include "polly/CodeGen/CodeGeneration.h"
#include "polly/CodeGen/LoopGenerators.h"
#include "polly/CodeGen/Utils.h"
#include "polly/Support/GICHelper.h"
@@ -579,8 +580,23 @@
__isl_give isl_ast_expr *getUpperBound(__isl_keep isl_ast_node *For,
CmpInst::Predicate &Predicate);
+ unsigned getNumberOfIterations(__isl_keep isl_ast_node *For);
+
void createFor(__isl_take isl_ast_node *For);
+ void createForVector(__isl_take isl_ast_node *For, int VectorWidth);
+ void createForSequential(__isl_take isl_ast_node *For);
+ void createSubstitutions(__isl_take isl_pw_multi_aff *PMA,
+ __isl_take isl_ast_build *Context,
+ ScopStmt *Stmt, ValueMapT &VMap);
+ void createSubstitutionsVector(__isl_take isl_pw_multi_aff *PMA,
+ __isl_take isl_ast_build *Context,
+ ScopStmt *Stmt, VectorValueMapT &VMap,
+ std::vector<Value*> &IVS,
+ __isl_take isl_id *IteratorID);
void createIf(__isl_take isl_ast_node *If);
+ void createUserVector(__isl_take isl_ast_node *User,
+ std::vector<Value*> &IVS, __isl_take isl_id *IteratorID,
+ __isl_take isl_union_map *Schedule);
void createUser(__isl_take isl_ast_node *User);
void createBlock(__isl_take isl_ast_node *Block);
};
@@ -635,7 +651,128 @@
return UB;
}
-void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
+// Return the number of iterations of the loop For, derived from the schedule
+// of the isl_ast_build stored in the node's IslAstUser annotation.
+//
+// Every failure path returns -1; because the return type is unsigned this
+// reaches the caller as the largest unsigned value.
+unsigned IslNodeBuilder::getNumberOfIterations(__isl_keep isl_ast_node *For) {
+  isl_id *Annotation = isl_ast_node_get_annotation(For);
+  if (!Annotation)
+    return -1;
+
+  struct IslAstUser *Info = (struct IslAstUser *) isl_id_get_user(Annotation);
+
+  // A freshly allocated IslAstUser carries a null Context; without a build
+  // context there is no schedule to inspect.
+  if (!Info || !Info->Context) {
+    isl_id_free(Annotation);
+    return -1;
+  }
+
+  isl_union_map *Schedule = isl_ast_build_get_schedule(Info->Context);
+  isl_set *LoopDomain = isl_set_from_union_set(isl_union_map_range(Schedule));
+  isl_id_free(Annotation);
+
+  // Report a failure of the helper the same way as the failures above instead
+  // of letting the "+ 1" below turn it into a trip count of zero.
+  // NOTE(review): assumes polly::getNumberOfIterations signals failure with a
+  // negative value -- confirm against its definition.
+  int NumberOfIterations = polly::getNumberOfIterations(LoopDomain);
+  if (NumberOfIterations < 0)
+    return -1;
+
+  return NumberOfIterations + 1;
+}
+
+// Generate vector code for a single user statement inside a vectorized loop.
+//
+// User       - the annotated user node; freed before returning.
+// IVS        - one induction-variable value per vector lane.
+// IteratorID - id of the loop iterator; handed on to
+//              createSubstitutionsVector.
+// Schedule   - schedule of the surrounding loop; consumed here.
+void IslNodeBuilder::createUserVector(__isl_take isl_ast_node *User,
+                                      std::vector<Value*> &IVS,
+                                      __isl_take isl_id *IteratorID,
+                                      __isl_take isl_union_map *Schedule) {
+  isl_id *Annotation = isl_ast_node_get_annotation(User);
+  assert(Annotation && "Vector user statement is not annotated");
+
+  struct IslAstUser *NodeInfo =
+    static_cast<struct IslAstUser *>(isl_id_get_user(Annotation));
+  assert(NodeInfo && "Vector user statement annotation does not contain info");
+
+  // The output tuple id of the PMA carries the ScopStmt as its user pointer.
+  isl_id *StmtId = isl_pw_multi_aff_get_tuple_id(NodeInfo->PMA, isl_dim_out);
+  ScopStmt *Stmt = static_cast<ScopStmt *>(isl_id_get_user(StmtId));
+
+  // Restrict the loop schedule to the domain of this statement and turn the
+  // result into a plain map.
+  isl_map *StmtSchedule = isl_map_from_union_map(
+    isl_union_map_intersect_domain(Schedule,
+                                   isl_union_set_from_set(Stmt->getDomain())));
+
+  // One value map per vector lane.
+  VectorValueMapT LaneMaps(IVS.size());
+  createSubstitutionsVector(isl_pw_multi_aff_copy(NodeInfo->PMA),
+                            isl_ast_build_copy(NodeInfo->Context),
+                            Stmt, LaneMaps, IVS, IteratorID);
+  VectorBlockGenerator::generate(Builder, *Stmt, LaneMaps, StmtSchedule, P);
+
+  isl_map_free(StmtSchedule);
+  isl_id_free(StmtId);
+  isl_id_free(Annotation);
+  isl_ast_node_free(User);
+}
+
+// Emit vector code for the loop For instead of a scalar loop.
+//
+// One induction-variable value is materialized per lane (VectorWidth values
+// in total, starting at the loop's initial value and advancing by its
+// increment). Each statement of the body, which must be a user node or a
+// block of nodes, is then handed to createUserVector. Takes ownership of For.
+void IslNodeBuilder::createForVector(__isl_take isl_ast_node *For,
+                                     int VectorWidth) {
+  isl_ast_node *LoopBody = isl_ast_node_for_get_body(For);
+  isl_ast_expr *InitExpr = isl_ast_node_for_get_init(For);
+  isl_ast_expr *IncExpr = isl_ast_node_for_get_inc(For);
+  isl_ast_expr *IteratorExpr = isl_ast_node_for_get_iterator(For);
+  isl_id *IteratorID = isl_ast_expr_get_id(IteratorExpr);
+
+  // The predicate is an out-parameter of getUpperBound; it is not used here.
+  CmpInst::Predicate Predicate;
+  isl_ast_expr *UpperBoundExpr = getUpperBound(For, Predicate);
+
+  // NOTE(review): the three expressions are not freed in this function --
+  // presumably ExprBuilder.create consumes them; confirm.
+  Value *LowerBound = ExprBuilder.create(InitExpr);
+  Value *UpperBound = ExprBuilder.create(UpperBoundExpr);
+  Value *Stride = ExprBuilder.create(IncExpr);
+
+  // Widest type among the iterator, both bounds and the stride.
+  Type *WideType = ExprBuilder.getType(IteratorExpr);
+  WideType = ExprBuilder.getWidestType(WideType, LowerBound->getType());
+  WideType = ExprBuilder.getWidestType(WideType, UpperBound->getType());
+  WideType = ExprBuilder.getWidestType(WideType, Stride->getType());
+
+  if (LowerBound->getType() != WideType)
+    LowerBound = Builder.CreateSExt(LowerBound, WideType);
+  if (UpperBound->getType() != WideType)
+    UpperBound = Builder.CreateSExt(UpperBound, WideType);
+  if (Stride->getType() != WideType)
+    Stride = Builder.CreateSExt(Stride, WideType);
+
+  // Lane 0 uses the lower bound, every further lane adds the stride to the
+  // value of the lane before it.
+  std::vector<Value*> IVS(VectorWidth);
+  Value *LaneIV = LowerBound;
+  IVS[0] = LaneIV;
+  for (int Lane = 1; Lane < VectorWidth; ++Lane) {
+    LaneIV = Builder.CreateAdd(LaneIV, Stride, "p_vector_iv");
+    IVS[Lane] = LaneIV;
+  }
+
+  isl_id *Annotation = isl_ast_node_get_annotation(For);
+  assert(Annotation && "For statement is not annotated");
+
+  struct IslAstUser *NodeInfo =
+    static_cast<struct IslAstUser *>(isl_id_get_user(Annotation));
+  assert(NodeInfo && "For statement annotation does not contain info");
+
+  isl_union_map *Schedule = isl_ast_build_get_schedule(NodeInfo->Context);
+  assert(Schedule && "For statement annotation does not contain its schedule");
+
+  // Bind the iterator to the lower bound while the body is generated; the
+  // binding is removed again below.
+  IDToValue[IteratorID] = LowerBound;
+
+  const enum isl_ast_node_type BodyType = isl_ast_node_get_type(LoopBody);
+  if (BodyType == isl_ast_node_user) {
+    // createUserVector frees the body node.
+    createUserVector(LoopBody, IVS, isl_id_copy(IteratorID),
+                     isl_union_map_copy(Schedule));
+  } else if (BodyType == isl_ast_node_block) {
+    isl_ast_node_list *Children = isl_ast_node_block_get_children(LoopBody);
+
+    for (int Idx = 0; Idx < isl_ast_node_list_n_ast_node(Children); ++Idx) {
+      isl_ast_node *Child = isl_ast_node_list_get_ast_node(Children, Idx);
+      createUserVector(Child, IVS, isl_id_copy(IteratorID),
+                       isl_union_map_copy(Schedule));
+    }
+
+    isl_ast_node_free(LoopBody);
+    isl_ast_node_list_free(Children);
+  } else {
+    isl_ast_node_dump(LoopBody);
+    llvm_unreachable("Unhandled isl_ast_node in vectorizer");
+  }
+
+  IDToValue.erase(IteratorID);
+  isl_id_free(IteratorID);
+  isl_id_free(Annotation);
+  isl_union_map_free(Schedule);
+
+  isl_ast_node_free(For);
+  isl_ast_expr_free(IteratorExpr);
+}
+
+void IslNodeBuilder::createForSequential(__isl_take isl_ast_node *For) {
isl_ast_node *Body;
isl_ast_expr *Init, *Inc, *Iterator, *UB;
isl_id *IteratorID;
@@ -696,6 +833,19 @@
isl_id_free(IteratorID);
}
+// Generate code for the loop For, choosing between the vector and the
+// sequential code path.
+//
+// The vector path is taken only if a vectorizer is selected, the loop is
+// innermost parallel, and its iteration count lies in the range 2..16.
+void IslNodeBuilder::createFor(__isl_take isl_ast_node *For) {
+  if (PollyVectorizerChoice == VECTORIZER_NONE || !isInnermostParallel(For)) {
+    createForSequential(For);
+    return;
+  }
+
+  int TripCount = getNumberOfIterations(For);
+  bool FitsVector = TripCount > 1 && TripCount <= 16;
+
+  if (FitsVector)
+    createForVector(For, TripCount);
+  else
+    createForSequential(For);
+}
+
void IslNodeBuilder::createIf(__isl_take isl_ast_node *If) {
isl_ast_expr *Cond = isl_ast_node_if_get_cond(If);
@@ -738,26 +888,18 @@
isl_ast_node_free(If);
}
-void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
- ValueMapT VMap;
- struct IslAstUser *UserInfo;
- isl_id *Annotation, *Id;
- ScopStmt *Stmt;
-
- Annotation = isl_ast_node_get_annotation(User);
- UserInfo = (struct IslAstUser *) isl_id_get_user(Annotation);
- Id = isl_pw_multi_aff_get_tuple_id(UserInfo->PMA, isl_dim_out);
- Stmt = (ScopStmt *) isl_id_get_user(Id);
-
- for (unsigned i = 0; i < isl_pw_multi_aff_dim(UserInfo->PMA, isl_dim_out);
+void IslNodeBuilder::createSubstitutions(__isl_take isl_pw_multi_aff *PMA,
+ __isl_take isl_ast_build *Context,
+ ScopStmt *Stmt, ValueMapT &VMap) {
+ for (unsigned i = 0; i < isl_pw_multi_aff_dim(PMA, isl_dim_out);
++i) {
isl_pw_aff *Aff;
isl_ast_expr *Expr;
const Value *OldIV;
Value *V;
- Aff = isl_pw_multi_aff_get_pw_aff(UserInfo->PMA, i);
- Expr = isl_ast_build_expr_from_pw_aff(UserInfo->Context, Aff);
+ Aff = isl_pw_multi_aff_get_pw_aff(PMA, i);
+ Expr = isl_ast_build_expr_from_pw_aff(Context, Aff);
OldIV = Stmt->getInductionVariableForDimension(i);
V = ExprBuilder.create(Expr);
@@ -768,6 +910,48 @@
VMap[OldIV] = V;
}
+ isl_pw_multi_aff_free(PMA);
+ isl_ast_build_free(Context);
+}
+
+// Fill one value map per vector lane.
+//
+// For every entry of IVS the iterator identified by IteratorID is temporarily
+// bound to that lane's value and createSubstitutions is run into the matching
+// element of VMap. The binding that was in place on entry is restored
+// afterwards. PMA, Context and IteratorID are freed before returning.
+void IslNodeBuilder::createSubstitutionsVector(__isl_take isl_pw_multi_aff *PMA,
+  __isl_take isl_ast_build *Context, ScopStmt *Stmt, VectorValueMapT &VMap,
+  std::vector<Value*> &IVS, __isl_take isl_id *IteratorID) {
+  Value *SavedIV = IDToValue[IteratorID];
+
+  for (std::vector<Value*>::size_type Lane = 0; Lane != IVS.size(); ++Lane) {
+    IDToValue[IteratorID] = IVS[Lane];
+    // createSubstitutions consumes its PMA and Context arguments, so each lane
+    // gets its own copies.
+    createSubstitutions(isl_pw_multi_aff_copy(PMA),
+                        isl_ast_build_copy(Context), Stmt, VMap[Lane]);
+  }
+
+  IDToValue[IteratorID] = SavedIV;
+
+  isl_id_free(IteratorID);
+  isl_pw_multi_aff_free(PMA);
+  isl_ast_build_free(Context);
+}
+
+void IslNodeBuilder::createUser(__isl_take isl_ast_node *User) {
+ ValueMapT VMap;
+ struct IslAstUser *Info;
+ isl_id *Annotation, *Id;
+ ScopStmt *Stmt;
+
+ Annotation = isl_ast_node_get_annotation(User);
+ assert(Annotation && "Scalar user statement is not annotated");
+
+ Info = (struct IslAstUser *) isl_id_get_user(Annotation);
+ assert(Info && "Scalar user statement annotation does not contain info");
+
+ Id = isl_pw_multi_aff_get_tuple_id(Info->PMA, isl_dim_out);
+ Stmt = (ScopStmt *) isl_id_get_user(Id);
+
+ createSubstitutions(isl_pw_multi_aff_copy(Info->PMA),
+ isl_ast_build_copy(Info->Context), Stmt, VMap);
+
BlockGenerator::generate(Builder, *Stmt, VMap, P);
isl_ast_node_free(User);