[MLIR] Value types for AffineXXXExpr
This CL makes AffineExprRef into a value type.
Notably:
1. Drops LLVM `isa`, `cast`, and `dyn_cast` on the pointer type and uses member
functions on the value type instead. It may still be possible to use `classof`
(in a follow-up CL).
2. AffineBaseExprRef aggressively casts away constness: if we mean the type to be
immutable, then let's jump in with both feet.
3. Drops implicit conversions to the underlying pointer type, because they always
result in surprising behavior and are not needed in practice once enough
cleanup has been applied.
The remaining downside I see is that we still need to mix `.` (member access on
the value type) and `operator->` (access to the underlying storage). There is an
ugly solution that forwards the methods, but it ends up duplicating the class
hierarchy, which I tried to avoid as much as possible. Maybe it's not that bad
anymore, though, since AffineExpr.h would still contain a single class hierarchy
(the duplication would be an implementation detail in the .cpp file).
PiperOrigin-RevId: 216188003
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index ebcc50a..2f58500 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -139,10 +139,10 @@
operandExprStack.reserve(8);
}
- void visitMulExpr(AffineBinaryOpExpr *expr) {
+ void visitMulExpr(AffineBinaryOpExprRef expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS()));
+ assert(expr->getRHS().isa<AffineConstantExprRef>());
// Get the RHS constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
@@ -153,7 +153,7 @@
}
}
- void visitAddExpr(AffineBinaryOpExpr *expr) {
+ void visitAddExpr(AffineBinaryOpExprRef expr) {
assert(operandExprStack.size() >= 2);
const auto &rhs = operandExprStack.back();
auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -166,10 +166,10 @@
operandExprStack.pop_back();
}
- void visitModExpr(AffineBinaryOpExpr *expr) {
+ void visitModExpr(AffineBinaryOpExprRef expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS()));
+ assert(expr->getRHS().isa<AffineConstantExprRef>());
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
auto &lhs = operandExprStack.back();
@@ -195,32 +195,32 @@
AffineConstantExpr::get(rhsConst, context), context));
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
}
- void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+ void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
visitDivExpr(expr, /*isCeil=*/true);
}
- void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+ void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
visitDivExpr(expr, /*isCeil=*/false);
}
- void visitDimExpr(AffineDimExpr *expr) {
+ void visitDimExpr(AffineDimExprRef expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getDimStartIndex() + expr->getPosition()] = 1;
}
- void visitSymbolExpr(AffineSymbolExpr *expr) {
+ void visitSymbolExpr(AffineSymbolExprRef expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getSymbolStartIndex() + expr->getPosition()] = 1;
}
- void visitConstantExpr(AffineConstantExpr *expr) {
+ void visitConstantExpr(AffineConstantExprRef expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getConstantIndex()] = expr->getValue();
}
private:
- void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
+ void visitDivExpr(AffineBinaryOpExprRef expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
- assert(isa<AffineConstantExpr>(expr->getRHS()));
+ assert(expr->getRHS().isa<AffineConstantExprRef>());
// This is a pure affine expr; the RHS is a positive constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
// TODO(bondhugula): handle division by zero at the same time the issue is
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 772ec85..4d72808 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,8 +38,7 @@
unsigned j = 0;
AffineBoundExprList::const_iterator it, e;
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
- if (auto *cExpr = const_cast<AffineConstantExpr *>(
- dyn_cast<AffineConstantExpr>(*it))) {
+ if (auto cExpr = it->dyn_cast<AffineConstantExprRef>()) {
if (val == None) {
val = cExpr->getValue();
*idx = j;
@@ -69,7 +68,7 @@
}
if (it == lhsList.end()) {
// There can only be one constant affine expr in this bound list.
- if (auto cExpr = dyn_cast<AffineConstantExpr>(expr)) {
+ if (auto cExpr = expr.dyn_cast<AffineConstantExprRef>()) {
unsigned idx;
if (lb) {
auto cb = getReducedConstBound(
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index babe95f..0b50494 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -61,7 +61,7 @@
auto loopSpanExpr = simplifyAffineExpr(
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
- auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
+ auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExprRef>();
if (!cExpr)
return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
loopSpan = cExpr->getValue();
@@ -81,7 +81,10 @@
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
auto tripCountExpr = getTripCountExpr(forStmt);
- if (auto constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
+ if (!tripCountExpr)
+ return None;
+
+ if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>())
return constExpr->getValue();
return None;
@@ -96,7 +99,7 @@
if (!tripCountExpr)
return 1;
- if (auto constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
+ if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>()) {
uint64_t tripCount = constExpr->getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).