[MLIR] Templated AffineExprBaseRef
This CL implements AffineExprBaseRef as a templated type to allow LLVM-style
casts to work properly. This also allows making AffineExprBaseRef::expr
private.
To achieve this, it is necessary to use llvm::simplify_type and make
AffineConstantExpr derive from both AffineExpr and llvm::simplify_type<AffineExprRef>.
Note that llvm::simplify_type is only an interface that enables the proper
template resolution of isa/cast/dyn_cast; it otherwise holds no state.
Lastly, note that certain dyn_cast operations wanted the const AffineExpr* form
of AffineExprBaseRef, so I made the implicit constructor take that by default
and documented the immutable behavior. I think this is consistent with the
decision to make uniqued types immutable by convention and never use const on
them.
PiperOrigin-RevId: 215642247
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 2a09d3c..ebcc50a 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -32,10 +32,10 @@
/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
/// [dims, symbols, locals, constant term].
// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
-static AffineExprWrap toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
- unsigned numSymbols,
- ArrayRef<AffineExprWrap> localExprs,
- MLIRContext *context) {
+static AffineExprRef toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+ unsigned numSymbols,
+ ArrayRef<AffineExprRef> localExprs,
+ MLIRContext *context) {
// Assert expected numLocals = eq.size() - numDims - numSymbols - 1
assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
"unexpected number of local expressions");
@@ -69,7 +69,7 @@
namespace {
-// This class is used to flatten a pure affine expression (AffineExprWrap,
+// This class is used to flatten a pure affine expression (AffineExprRef,
// which is in a tree form) into a sum of products (w.r.t constants) when
// possible, and in that process simplifying the expression. The simplification
// performed includes the accumulation of contributions for each dimensional and
@@ -124,12 +124,12 @@
unsigned numLocals;
// AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
// which new identifiers were introduced; if the latter do not get canceled
- // out, these expressions are needed to reconstruct the AffineExprWrap / tree
+ // out, these expressions are needed to reconstruct the AffineExprRef / tree
// form. Note that these expressions themselves would have been simplified
// (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
// simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
// would be the local expression stored for q.
- SmallVector<AffineExprWrap, 4> localExprs;
+ SmallVector<AffineExprRef, 4> localExprs;
MLIRContext *context;
AffineExprFlattener(unsigned numDims, unsigned numSymbols,
@@ -142,7 +142,7 @@
void visitMulExpr(AffineBinaryOpExpr *expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS().expr));
+ assert(isa<AffineConstantExpr>(expr->getRHS()));
// Get the RHS constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
@@ -169,7 +169,7 @@
void visitModExpr(AffineBinaryOpExpr *expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS().expr));
+ assert(isa<AffineConstantExpr>(expr->getRHS()));
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
auto &lhs = operandExprStack.back();
@@ -220,7 +220,7 @@
private:
void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
- assert(isa<AffineConstantExpr>(expr->getRHS().expr));
+ assert(isa<AffineConstantExpr>(expr->getRHS()));
// This is a pure affine expr; the RHS is a positive constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
// TODO(bondhugula): handle division by zero at the same time the issue is
@@ -259,9 +259,9 @@
}
// Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
- // expr). localExpr is the simplified tree expression (AffineExprWrap )
+ // expr). localExpr is the simplified tree expression (AffineExprRef)
// corresponding to the quantifier.
- void addLocalId(AffineExprWrap localExpr) {
+ void addLocalId(AffineExprRef localExpr) {
for (auto &subExpr : operandExprStack) {
subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
}
@@ -277,8 +277,8 @@
} // end anonymous namespace
-AffineExprWrap mlir::simplifyAffineExpr(AffineExprWrap expr, unsigned numDims,
- unsigned numSymbols) {
+AffineExprRef mlir::simplifyAffineExpr(AffineExprRef expr, unsigned numDims,
+ unsigned numSymbols) {
// TODO(bondhugula): only pure affine for now. The simplification here can be
// extended to semi-affine maps in the future.
if (!expr->isPureAffine())
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index c745207..3a6183b 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,7 +38,7 @@
unsigned j = 0;
AffineBoundExprList::const_iterator it, e;
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
- if (auto *cExpr = dyn_cast<AffineConstantExpr>(it->expr)) {
+ if (auto *cExpr = dyn_cast<AffineConstantExpr>(*it)) {
if (val == None) {
val = cExpr->getValue();
*idx = j;
@@ -68,7 +68,7 @@
}
if (it == lhsList.end()) {
// There can only be one constant affine expr in this bound list.
- if (auto *cExpr = dyn_cast<AffineConstantExpr>(expr.expr)) {
+ if (auto cExpr = dyn_cast<AffineConstantExpr>(expr)) {
unsigned idx;
if (lb) {
auto cb = getReducedConstBound(
@@ -105,8 +105,8 @@
}
HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
- ArrayRef<ArrayRef<AffineExprWrap>> lbs,
- ArrayRef<ArrayRef<AffineExprWrap>> ubs,
+ ArrayRef<ArrayRef<AffineExprRef>> lbs,
+ ArrayRef<ArrayRef<AffineExprRef>> ubs,
MLIRContext *context,
IntegerSet *symbolContext)
: context(symbolContext ? MutableIntegerSet(symbolContext, context)
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 5c64b2f..babe95f 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -32,7 +32,7 @@
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExprWrap mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
@@ -56,12 +56,12 @@
return nullptr;
// ub_expr - lb_expr + 1
- AffineExprWrap lbExpr(lbMap->getResult(0));
- AffineExprWrap ubExpr(ubMap->getResult(0));
+ AffineExprRef lbExpr(lbMap->getResult(0));
+ AffineExprRef ubExpr(ubMap->getResult(0));
auto loopSpanExpr = simplifyAffineExpr(
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
- auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr.expr);
+ auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
if (!cExpr)
return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
loopSpan = cExpr->getValue();
@@ -81,8 +81,7 @@
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
auto tripCountExpr = getTripCountExpr(forStmt);
- if (auto *constExpr =
- dyn_cast_or_null<AffineConstantExpr>(tripCountExpr.expr))
+ if (auto constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
return constExpr->getValue();
return None;
@@ -97,7 +96,7 @@
if (!tripCountExpr)
return 1;
- if (auto *constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr.expr)) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
uint64_t tripCount = constExpr->getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).
@@ -109,5 +108,5 @@
}
// Trip count is not a known constant; return its largest known divisor.
- return tripCountExpr.expr->getLargestKnownDivisor();
+ return tripCountExpr->getLargestKnownDivisor();
}