Support for affine integer sets
- introduce affine integer sets into the IR
- parse and print affine integer sets (either inline or outlined), similar to
affine maps
- use integer set for IfStmt's conditional, and implement parsing of IfStmt's
conditional
- fix an affine expr paren omission bug found while working on this.
TODO: parse/represent/print MLValue operands to affine integer set references.
PiperOrigin-RevId: 207779408
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 2f67c27..c2cd257 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -24,6 +24,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
@@ -72,6 +73,18 @@
ArrayRef<const AffineMap *> getAffineMapIds() const { return affineMapsById; }
+  /// Returns the id assigned to `integerSet` when it was recorded, or -1 if
+  /// the set was never recorded (the reference printer then prints it inline).
+  int getIntegerSetId(const IntegerSet *integerSet) const {
+    auto found = integerSetIds.find(integerSet);
+    return found != integerSetIds.end() ? found->second : -1;
+  }
+
+  /// Returns every recorded integer set; a set's position in the returned
+  /// array is the id it was assigned when first recorded.
+  ArrayRef<const IntegerSet *> getIntegerSetIds() const {
+    return ArrayRef<const IntegerSet *>(integerSetsById);
+  }
+
private:
void recordAffineMapReference(const AffineMap *affineMap) {
if (affineMapIds.count(affineMap) == 0) {
@@ -80,17 +93,31 @@
}
}
+  /// Assigns the next sequential id to `integerSet` the first time it is
+  /// seen and appends it to the id-ordered list; later calls for the same
+  /// set leave the state unchanged.
+  void recordIntegerSetReference(const IntegerSet *integerSet) {
+    // A single insert both checks for and adds the entry, replacing the
+    // separate count() + operator[] hash lookups.
+    auto inserted = integerSetIds.insert(
+        {integerSet, static_cast<int>(integerSetsById.size())});
+    if (inserted.second)
+      integerSetsById.push_back(integerSet);
+  }
+
// Visit functions.
void visitFunction(const Function *fn);
void visitExtFunction(const ExtFunction *fn);
void visitCFGFunction(const CFGFunction *fn);
void visitMLFunction(const MLFunction *fn);
+ void visitStatement(const Statement *stmt);
+ void visitForStmt(const ForStmt *forStmt);
+ void visitIfStmt(const IfStmt *ifStmt);
+ void visitOperationStmt(const OperationStmt *opStmt);
void visitType(const Type *type);
void visitAttribute(const Attribute *attr);
void visitOperation(const Operation *op);
DenseMap<const AffineMap *, int> affineMapIds;
std::vector<const AffineMap *> affineMapsById;
+
+ DenseMap<const IntegerSet *, int> integerSetIds;
+ std::vector<const IntegerSet *> integerSetsById;
};
} // end anonymous namespace
@@ -113,8 +140,8 @@
void ModuleState::visitAttribute(const Attribute *attr) {
if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr)) {
recordAffineMapReference(mapAttr->getValue());
- } else if (auto *array = dyn_cast<ArrayAttr>(attr)) {
- for (auto elt : array->getValue()) {
+ } else if (auto *arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ for (auto elt : arrayAttr->getValue()) {
visitAttribute(elt);
}
}
@@ -145,9 +172,42 @@
}
}
+/// Records the integer set used as the `if` condition, then walks the
+/// statements of the then clause and, when present, the else clause.
+void ModuleState::visitIfStmt(const IfStmt *ifStmt) {
+  recordIntegerSetReference(ifStmt->getCondition());
+
+  for (auto &thenStmt : *ifStmt->getThenClause())
+    visitStatement(&thenStmt);
+
+  if (!ifStmt->hasElseClause())
+    return;
+  for (auto &elseStmt : *ifStmt->getElseClause())
+    visitStatement(&elseStmt);
+}
+
+/// Walks every statement contained in `forStmt`.
+void ModuleState::visitForStmt(const ForStmt *forStmt) {
+  for (auto &bodyStmt : *forStmt) {
+    visitStatement(&bodyStmt);
+  }
+}
+
+/// Operation statements currently contribute nothing to the module state.
+void ModuleState::visitOperationStmt(const OperationStmt * /*opStmt*/) {
+  // TODO: visit any attributes if necessary.
+}
+
+/// Dispatches `stmt` to the visitor matching its kind; statements of any
+/// other kind are ignored.
+void ModuleState::visitStatement(const Statement *stmt) {
+  auto kind = stmt->getKind();
+  if (kind == Statement::Kind::If)
+    visitIfStmt(cast<IfStmt>(stmt));
+  else if (kind == Statement::Kind::For)
+    visitForStmt(cast<ForStmt>(stmt));
+  else if (kind == Statement::Kind::Operation)
+    visitOperationStmt(cast<OperationStmt>(stmt));
+}
+
void ModuleState::visitMLFunction(const MLFunction *fn) {
visitType(fn->getType());
- // TODO Visit function body statements (and attributes if required).
+  // Walk the top-level statements; statements nested inside `if`/`for`
+  // bodies are reached through visitIfStmt/visitForStmt.
+  for (auto &topLevelStmt : *fn)
+    visitStatement(&topLevelStmt);
}
void ModuleState::visitFunction(const Function *fn) {
@@ -161,7 +221,7 @@
}
}
-// Initializes module state, populating affine map state.
+// Initializes module state, populating affine map and integer set state.
void ModuleState::initialize(const Module *module) {
for (auto &fn : *module) {
visitFunction(&fn);
@@ -194,6 +254,8 @@
void printAffineMap(const AffineMap *map);
void printAffineExpr(const AffineExpr *expr);
+ void printAffineConstraint(const AffineExpr *expr, bool isEq);
+ void printIntegerSet(const IntegerSet *set);
protected:
raw_ostream &os;
@@ -203,6 +265,8 @@
void printFunctionResultType(const FunctionType *type);
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(const AffineMap *affineMap);
+ void printIntegerSetId(int integerSetId) const;
+ void printIntegerSetReference(const IntegerSet *integerSet);
/// This enum is used to represent the binding stength of the enclosing
/// context that an AffineExpr is being printed in, so we can intelligently
@@ -244,6 +308,22 @@
}
}
+/// Prints the outlined name of an integer set, e.g. `@@set0` for id 0.
+void ModulePrinter::printIntegerSetId(int integerSetId) const {
+  os << "@@set";
+  os << integerSetId;
+}
+
+/// Prints `integerSet` as a reference to its outlined definition when the
+/// module state assigned it an id, and inline otherwise.
+void ModulePrinter::printIntegerSetReference(const IntegerSet *integerSet) {
+  int setId = state.getIntegerSetId(integerSet);
+  if (setId < 0) {
+    // Not recorded in the module state: there is no outlined definition to
+    // refer to, so print the set's full form here.
+    integerSet->print(os);
+    return;
+  }
+  // The set is printed at the top of the module; emit only its id here.
+  printIntegerSetId(setId);
+}
+
void ModulePrinter::print(const Module *module) {
for (const auto &map : state.getAffineMapIds()) {
printAffineMapId(state.getAffineMapId(map));
@@ -251,6 +331,12 @@
map->print(os);
os << '\n';
}
+  // Outlined integer set definitions, one per line as `@@setN = <set>`,
+  // emitted after the affine map definitions and before any function.
+ for (const auto &set : state.getIntegerSetIds()) {
+ printIntegerSetId(state.getIntegerSetId(set));
+ os << " = ";
+ set->print(os);
+ os << '\n';
+ }
for (auto const &fn : *module)
print(&fn);
}
@@ -473,7 +559,9 @@
if (auto *rhs = dyn_cast<AffineConstantExpr>(binOp->getRHS())) {
if (rhs->getValue() < 0) {
printAffineExprInternal(binOp->getLHS(), BindingStrength::Weak);
- os << " - " << -rhs->getValue() << ')';
+ os << " - " << -rhs->getValue();
+ if (enclosingTightness == BindingStrength::Strong)
+ os << ')';
return;
}
}
@@ -486,6 +574,11 @@
os << ')';
}
+/// Prints one integer set constraint as `<expr> == 0` when `isEq` is true,
+/// or as `<expr> >= 0` otherwise.
+void ModulePrinter::printAffineConstraint(const AffineExpr *expr, bool isEq) {
+  printAffineExprInternal(expr, BindingStrength::Weak);
+  const char *relation = isEq ? " == 0" : " >= 0";
+  os << relation;
+}
+
void ModulePrinter::printAffineMap(const AffineMap *map) {
// Dimension identifiers.
os << '(';
@@ -524,6 +617,38 @@
os << ')';
}
+/// Prints an integer set in the form
+///   (d0, d1, ...)[s0, s1, ...] : (c0 ==/>= 0, c1 ==/>= 0, ...)
+/// The bracketed symbol list is omitted when the set has no symbols.
+void ModulePrinter::printIntegerSet(const IntegerSet *set) {
+  // Dimension identifiers.
+  os << '(';
+  for (unsigned i = 0, e = set->getNumDims(); i < e; ++i) {
+    if (i != 0)
+      os << ", ";
+    os << 'd' << i;
+  }
+  os << ')';
+
+  // Symbolic identifiers.
+  if (set->getNumSymbols() != 0) {
+    os << '[';
+    for (unsigned i = 0, e = set->getNumSymbols(); i < e; ++i) {
+      if (i != 0)
+        os << ", ";
+      os << 's' << i;
+    }
+    os << ']';
+  }
+
+  // Constraints, comma-separated. The index is unsigned so it compares
+  // cleanly against the constraint count (no signed/unsigned mix).
+  os << " : (";
+  for (unsigned i = 0, e = set->getNumConstraints(); i < e; ++i) {
+    if (i != 0)
+      os << ", ";
+    printAffineConstraint(set->getConstraint(i), set->isEq(i));
+  }
+  os << ')';
+}
+
//===----------------------------------------------------------------------===//
// Function printing
//===----------------------------------------------------------------------===//
@@ -582,6 +707,9 @@
void printAffineMap(const AffineMap *map) {
return ModulePrinter::printAffineMapReference(map);
}
+  /// Prints `set` through printIntegerSetReference: by its outlined id when
+  /// it has one, inline otherwise.
+  void printIntegerSet(const IntegerSet *set) {
+    ModulePrinter::printIntegerSetReference(set);
+  }
void printAffineExpr(const AffineExpr *expr) {
return ModulePrinter::printAffineExpr(expr);
}
@@ -1139,7 +1267,9 @@
}
void MLFunctionPrinter::print(const IfStmt *stmt) {
- os.indent(numSpaces) << "if () {\n";
+ os.indent(numSpaces) << "if (";
+ printIntegerSetReference(stmt->getCondition());
+ os << ") {\n";
print(stmt->getThenClause());
os.indent(numSpaces) << "}";
if (stmt->hasElseClause()) {
@@ -1181,6 +1311,11 @@
llvm::errs() << "\n";
}
+/// Prints this set to llvm::errs() followed by a newline, for debugging.
+void IntegerSet::dump() const {
+  auto &errStream = llvm::errs();
+  print(errStream);
+  errStream << "\n";
+}
+
void AffineExpr::print(raw_ostream &os) const {
ModuleState state(/*no context is known*/ nullptr);
ModulePrinter(os, state).printAffineExpr(this);
@@ -1191,6 +1326,11 @@
ModulePrinter(os, state).printAffineMap(this);
}
+/// Prints this set in its full inline form. No module is known here, so the
+/// printer state is built without a context.
+void IntegerSet::print(raw_ostream &os) const {
+  ModuleState state(/*no context is known*/ nullptr);
+  ModulePrinter printer(os, state);
+  printer.printIntegerSet(this);
+}
+
void SSAValue::print(raw_ostream &os) const {
switch (getKind()) {
case SSAValueKind::BBArgument: