[MLIR] Templated AffineExprBaseRef
This CL implements AffineExprBaseRef as a templated type to allow LLVM-style
casts to work properly. This also allows making AffineExprBaseRef::expr
private.
To achieve this, it is necessary to use llvm::simplify_type and make
AffineConstantExpr derive from both AffineExpr and
llvm::simplify_type<AffineExprRef>.
Note that llvm::simplify_type is just an interface to enable the proper
template resolution of isa/cast/dyn_cast but it otherwise holds no value.
Lastly, note that certain dyn_cast operations wanted the const AffineExpr* form
of AffineExprBaseRef, so I made the implicit constructor take that by default
and documented the immutable behavior. I think this is consistent with the
decision to make uniqued types immutable by convention and never use const on
them.
PiperOrigin-RevId: 215642247
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 5c64b2f..babe95f 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -32,7 +32,7 @@
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExprWrap mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
@@ -56,12 +56,12 @@
return nullptr;
// ub_expr - lb_expr + 1
- AffineExprWrap lbExpr(lbMap->getResult(0));
- AffineExprWrap ubExpr(ubMap->getResult(0));
+ AffineExprRef lbExpr(lbMap->getResult(0));
+ AffineExprRef ubExpr(ubMap->getResult(0));
auto loopSpanExpr = simplifyAffineExpr(
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
- auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr.expr);
+ auto *cExpr = dyn_cast<AffineConstantExpr>(loopSpanExpr);
if (!cExpr)
return AffineBinaryOpExpr::getCeilDiv(loopSpanExpr, step, context);
loopSpan = cExpr->getValue();
@@ -81,8 +81,7 @@
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
auto tripCountExpr = getTripCountExpr(forStmt);
- if (auto *constExpr =
- dyn_cast_or_null<AffineConstantExpr>(tripCountExpr.expr))
+ if (auto constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
return constExpr->getValue();
return None;
@@ -97,7 +96,7 @@
if (!tripCountExpr)
return 1;
- if (auto *constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr.expr)) {
+ if (auto constExpr = dyn_cast<AffineConstantExpr>(tripCountExpr)) {
uint64_t tripCount = constExpr->getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).
@@ -109,5 +108,5 @@
}
// Trip count is not a known constant; return its largest known divisor.
- return tripCountExpr.expr->getLargestKnownDivisor();
+ return tripCountExpr->getLargestKnownDivisor();
}