[MLIR] Remove uses of AffineExpr* outside of IR
This CL makes the use of AffineExprWrap uniform outside of IR.
The public API of the AffineExpr builder is modified to use only AffineExprWrap.
A few places still access AffineExprWrap.expr directly. This is only while the
API is in transition, so that the remaining uses are easy to track (i.e. make
expr private and let the compiler report the errors).
Parser.cpp exhibits patterns that depend on nullptr values, so
converting it is left for another CL.
PiperOrigin-RevId: 215642005
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 74b607a..2a09d3c 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -32,55 +32,52 @@
/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
/// [dims, symbols, locals, constant term].
// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
-static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
- unsigned numSymbols,
- ArrayRef<AffineExpr *> localExprs,
- MLIRContext *context) {
+static AffineExprWrap toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+ unsigned numSymbols,
+ ArrayRef<AffineExprWrap> localExprs,
+ MLIRContext *context) {
// Assert expected numLocals = eq.size() - numDims - numSymbols - 1
assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
"unexpected number of local expressions");
- AffineExpr *expr = AffineConstantExpr::get(0, context);
+ auto expr = AffineConstantExpr::get(0, context);
// Dimensions and symbols.
for (unsigned j = 0; j < numDims + numSymbols; j++) {
- if (eq[j] != 0) {
- AffineExpr *id =
- j < numDims
- ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
- : AffineSymbolExpr::get(j - numDims, context);
- auto *term = AffineBinaryOpExpr::getMul(
- AffineConstantExpr::get(eq[j], context), id, context);
- expr = AffineBinaryOpExpr::getAdd(expr, term, context);
+ if (eq[j] == 0) {
+ continue;
}
+ auto id = j < numDims ? AffineDimExpr::get(j, context)
+ : AffineSymbolExpr::get(j - numDims, context);
+ expr = expr + id * eq[j];
}
// Local identifiers.
for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
- if (eq[j] != 0) {
- auto *term = AffineBinaryOpExpr::getMul(
- localExprs[j - numDims - numSymbols], eq[j], context);
- expr = AffineBinaryOpExpr::getAdd(expr, term, context);
+ if (eq[j] == 0) {
+ continue;
}
+ auto term = localExprs[j - numDims - numSymbols] * eq[j];
+ expr = expr + term;
}
// Constant term.
- unsigned constTerm = eq[eq.size() - 1];
+ int64_t constTerm = eq[eq.size() - 1]; // Signed: keep negative terms.
if (constTerm != 0)
- expr = AffineBinaryOpExpr::getAdd(expr, constTerm, context);
+ expr = expr + constTerm;
return expr;
}
namespace {
-// This class is used to flatten a pure affine expression (AffineExpr *, which
-// is in a tree form) into a sum of products (w.r.t constants) when possible,
-// and in that process simplifying the expression. The simplification performed
-// includes the accumulation of contributions for each dimensional and symbolic
-// identifier together, the simplification of floordiv/ceildiv/mod exprssions
-// and other simplifications that in turn happen as a result. A simplification
-// that this flattening naturally performs is of simplifying the numerator and
-// denominator of floordiv/ceildiv, and folding a modulo expression to a zero,
-// if possible. Three examples are below:
+// This class is used to flatten a pure affine expression (AffineExprWrap,
+// which is in a tree form) into a sum of products (w.r.t constants) when
+// possible, and in that process simplifying the expression. The simplification
+// performed includes the accumulation of contributions for each dimensional and
+// symbolic identifier together, the simplification of floordiv/ceildiv/mod
+// expressions and other simplifications that in turn happen as a result. A
+// simplification that this flattening naturally performs is of simplifying the
+// numerator and denominator of floordiv/ceildiv, and folding a modulo
+// expression to a zero, if possible. Three examples are below:
//
// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
@@ -127,12 +124,12 @@
unsigned numLocals;
// AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
// which new identifiers were introduced; if the latter do not get canceled
- // out, these expressions are needed to reconstruct the AffineExpr * / tree
+ // out, these expressions are needed to reconstruct the AffineExprWrap / tree
// form. Note that these expressions themselves would have been simplified
// (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
// simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
// would be the local expression stored for q.
- SmallVector<AffineExpr *, 4> localExprs;
+ SmallVector<AffineExprWrap, 4> localExprs;
MLIRContext *context;
AffineExprFlattener(unsigned numDims, unsigned numSymbols,
@@ -145,7 +142,7 @@
void visitMulExpr(AffineBinaryOpExpr *expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS()));
+ assert(isa<AffineConstantExpr>(expr->getRHS().expr));
// Get the RHS constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
@@ -172,7 +169,7 @@
void visitModExpr(AffineBinaryOpExpr *expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS()));
+ assert(isa<AffineConstantExpr>(expr->getRHS().expr));
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
auto &lhs = operandExprStack.back();
@@ -223,7 +220,7 @@
private:
void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
- assert(isa<AffineConstantExpr>(expr->getRHS()));
+ assert(isa<AffineConstantExpr>(expr->getRHS().expr));
// This is a pure affine expr; the RHS is a positive constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
// TODO(bondhugula): handle division by zero at the same time the issue is
@@ -262,9 +259,9 @@
}
// Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
- // expr). localExpr is the simplified tree expression (AffineExpr *)
+ // expr). localExpr is the simplified tree expression (AffineExprWrap)
// corresponding to the quantifier.
- void addLocalId(AffineExpr *localExpr) {
+ void addLocalId(AffineExprWrap localExpr) {
for (auto &subExpr : operandExprStack) {
subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
}
@@ -280,22 +277,21 @@
} // end anonymous namespace
-AffineExpr *mlir::simplifyAffineExpr(AffineExpr *expr, unsigned numDims,
- unsigned numSymbols,
- MLIRContext *context) {
+/// Simplifies a pure affine expression by flattening it and rebuilding it from
+/// the flattened form. Any other expression is returned unchanged.
+AffineExprWrap mlir::simplifyAffineExpr(AffineExprWrap expr, unsigned numDims,
+ unsigned numSymbols) {
// TODO(bondhugula): only pure affine for now. The simplification here can be
// extended to semi-affine maps in the future.
if (!expr->isPureAffine())
- return nullptr;
+ return expr; // Callers use the result as-is, so never hand back null.
- AffineExprFlattener flattener(numDims, numSymbols, context);
+ AffineExprFlattener flattener(numDims, numSymbols, expr->getContext());
flattener.walkPostOrder(expr);
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
- auto *simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
- flattener.localExprs, context);
+ auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
+ flattener.localExprs, expr->getContext());
flattener.operandExprStack.pop_back();
assert(flattener.operandExprStack.empty());
- if (simplifiedExpr == expr)
- return nullptr;
return simplifiedExpr;
}
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index 6965137..0e88123 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -49,11 +49,14 @@
// be pure for the simplification implemented.
void MutableAffineMap::simplify() {
// Simplify each of the results if possible.
+ // TODO(ntv): functional-style map
for (unsigned i = 0, e = getNumResults(); i < e; i++) {
- AffineExpr *sExpr =
- simplifyAffineExpr(getResult(i), numDims, numSymbols, context);
- if (sExpr)
- results[i] = sExpr;
+ // Keep the existing result when no simplified expression is produced, so a
+ // null expression is never stored in the map.
+ auto sExpr = simplifyAffineExpr(getResult(i), numDims, numSymbols);
+ if (!sExpr)
+ continue;
+ results[i] = sExpr;
}
}
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index c385a22..c745207 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,7 +38,7 @@
unsigned j = 0;
AffineBoundExprList::const_iterator it, e;
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
- if (auto *cExpr = dyn_cast<AffineConstantExpr>(*it)) {
+ if (auto *cExpr = dyn_cast<AffineConstantExpr>(it->expr)) {
if (val == None) {
val = cExpr->getValue();
*idx = j;
@@ -60,7 +60,7 @@
const AffineBoundExprList &rhsList, bool lb) {
// The list of bounds is going to be small. Just a linear search
// should be enough to create a list without duplicates.
- for (auto *expr : rhsList) {
+ for (auto expr : rhsList) {
AffineBoundExprList::const_iterator it;
for (it = lhsList.begin(); it != lhsList.end(); it++) {
if (expr == *it)
@@ -68,7 +68,7 @@
}
if (it == lhsList.end()) {
// There can only be one constant affine expr in this bound list.
- if (auto *cExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ if (auto *cExpr = dyn_cast<AffineConstantExpr>(expr.expr)) {
unsigned idx;
if (lb) {
auto cb = getReducedConstBound(
@@ -105,8 +105,8 @@
}
HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
- ArrayRef<ArrayRef<AffineExpr *>> lbs,
- ArrayRef<ArrayRef<AffineExpr *>> ubs,
+ ArrayRef<ArrayRef<AffineExprWrap>> lbs,
+ ArrayRef<ArrayRef<AffineExprWrap>> ubs,
MLIRContext *context,
IntegerSet *symbolContext)
: context(symbolContext ? MutableIntegerSet(symbolContext, context)
@@ -114,7 +114,7 @@
unsigned d = 0;
for (auto boundList : lbs) {
AffineBoundExprList lb;
- for (auto *expr : boundList) {
+ for (auto expr : boundList) {
assert(expr->isSymbolicOrConstant() &&
"bound expression should be symbolic or constant");
lb.push_back(expr);
@@ -125,7 +125,7 @@
d = 0;
for (auto boundList : ubs) {
AffineBoundExprList ub;
- for (auto *expr : boundList) {
+ for (auto expr : boundList) {
assert(expr->isSymbolicOrConstant() &&
"bound expression should be symbolic or constant");
ub.push_back(expr);
@@ -161,7 +161,7 @@
unsigned d = 0;
for (auto &lb : lowerBounds) {
os << "Dim " << d++ << "\n";
- for (auto *expr : lb) {
+ for (auto expr : lb) {
expr->print(os);
}
}
@@ -169,7 +169,7 @@
os << "Upper bounds\n";
for (auto &lb : upperBounds) {
os << "Dim " << d++ << "\n";
- for (auto *expr : lb) {
+ for (auto expr : lb) {
expr->print(os);
}
}
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 1549182..5c64b2f 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -27,12 +27,12 @@
#include "mlir/IR/Statements.h"
#include "mlir/Support/MathExtras.h"
-using mlir::AffineExpr;
+using namespace mlir;
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExprWrap mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
@@ -56,16 +56,12 @@
return nullptr;
// ub_expr - lb_expr + 1
- auto *lbExpr = lbMap->getResult(0);
- auto *ubExpr = ubMap->getResult(0);
- auto *loopSpanExpr = AffineBinaryOpExpr::getAdd(
- AffineBinaryOpExpr::getSub(ubExpr, lbExpr, context), 1, context);
-
- if (auto *expr = simplifyAffineExpr(loopSpanExpr, lbMap->getNumDims(),
- lbMap->getNumSymbols(), context))
- loopSpanExpr = expr;
-
- auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
+ AffineExprWrap lbExpr(lbMap->getResult(0));
+ AffineExprWrap ubExpr(ubMap->getResult(0));
+ auto loopSpanExpr = simplifyAffineExpr(
+ ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
+ std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
+ auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr.expr);
if (!cExpr)
return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
loopSpan = cExpr->getValue();
@@ -83,9 +79,10 @@
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
- AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
+ auto tripCountExpr = getTripCountExpr(forStmt);
- if (auto *constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
+ if (auto *constExpr =
+ dyn_cast_or_null<AffineConstantExpr>(tripCountExpr.expr))
return constExpr->getValue();
return None;
@@ -95,12 +92,12 @@
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
- AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
+ auto tripCountExpr = getTripCountExpr(forStmt);
if (!tripCountExpr)
return 1;
- if (auto *constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+ if (auto *constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr.expr)) {
uint64_t tripCount = constExpr->getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).
@@ -112,5 +109,5 @@
}
// Trip count is not a known constant; return its largest known divisor.
- return tripCountExpr->getLargestKnownDivisor();
+ return tripCountExpr.expr->getLargestKnownDivisor();
}