[MLIR] AffineExpr final cleanups
This CL:
1. performs the global codemod AffineXExpr -> AffineXExprClass and
AffineXExprRef -> AffineXExpr;
2. simplifies function calls by removing the redundant MLIRContext parameter;
3. adds missing binary operator overloads of the form `scalar op AffineExpr`
where it makes sense.
PiperOrigin-RevId: 216242674
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index 2d314e0..fa2541a 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -33,14 +33,14 @@
/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
-/// products expression, 'localExprs' is expected to have the AffineExpr for it,
-/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
-/// [dims, symbols, locals, constant term].
+/// products expression, 'localExprs' is expected to hold the AffineExpr for
+/// each of them; that expression is substituted in place of the identifier.
+/// 'eq' is laid out as [dims, symbols, locals, constant term].
// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
-static AffineExprRef toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
- unsigned numSymbols,
- ArrayRef<AffineExprRef> localExprs,
- MLIRContext *context) {
+static AffineExpr toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+ unsigned numSymbols,
+ ArrayRef<AffineExpr> localExprs,
+ MLIRContext *context) {
// Assert expected numLocals = eq.size() - numDims - numSymbols - 1
assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
"unexpected number of local expressions");
@@ -74,7 +74,7 @@
namespace {
-// This class is used to flatten a pure affine expression (AffineExprRef,
+// This class is used to flatten a pure affine expression (AffineExpr,
// which is in a tree form) into a sum of products (w.r.t constants) when
// possible, and in that process simplifying the expression. The simplification
// performed includes the accumulation of contributions for each dimensional and
@@ -127,14 +127,14 @@
// Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
// expressions that could not be simplified.
unsigned numLocals;
- // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
+ // AffineExprs corresponding to the floordiv/ceildiv/mod expressions for
// which new identifiers were introduced; if the latter do not get canceled
- // out, these expressions are needed to reconstruct the AffineExprRef / tree
+ // out, these expressions are needed to reconstruct the AffineExpr / tree
// form. Note that these expressions themselves would have been simplified
// (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
// simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
// would be the local expression stored for q.
- SmallVector<AffineExprRef, 4> localExprs;
+ SmallVector<AffineExpr, 4> localExprs;
MLIRContext *context;
AffineExprFlattener(unsigned numDims, unsigned numSymbols,
@@ -144,10 +144,10 @@
operandExprStack.reserve(8);
}
- void visitMulExpr(AffineBinaryOpExprRef expr) {
+ void visitMulExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(expr->getRHS().isa<AffineConstantExprRef>());
+ assert(expr->getRHS().isa<AffineConstantExpr>());
// Get the RHS constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
@@ -158,7 +158,7 @@
}
}
- void visitAddExpr(AffineBinaryOpExprRef expr) {
+ void visitAddExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
const auto &rhs = operandExprStack.back();
auto &lhs = operandExprStack[operandExprStack.size() - 2];
@@ -171,10 +171,10 @@
operandExprStack.pop_back();
}
- void visitModExpr(AffineBinaryOpExprRef expr) {
+ void visitModExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(expr->getRHS().isa<AffineConstantExprRef>());
+ assert(expr->getRHS().isa<AffineConstantExpr>());
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
auto &lhs = operandExprStack.back();
@@ -200,32 +200,32 @@
addLocalId(a.floorDiv(b));
lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
}
- void visitCeilDivExpr(AffineBinaryOpExprRef expr) {
+ void visitCeilDivExpr(AffineBinaryOpExpr expr) {
visitDivExpr(expr, /*isCeil=*/true);
}
- void visitFloorDivExpr(AffineBinaryOpExprRef expr) {
+ void visitFloorDivExpr(AffineBinaryOpExpr expr) {
visitDivExpr(expr, /*isCeil=*/false);
}
- void visitDimExpr(AffineDimExprRef expr) {
+ void visitDimExpr(AffineDimExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getDimStartIndex() + expr->getPosition()] = 1;
}
- void visitSymbolExpr(AffineSymbolExprRef expr) {
+ void visitSymbolExpr(AffineSymbolExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getSymbolStartIndex() + expr->getPosition()] = 1;
}
- void visitConstantExpr(AffineConstantExprRef expr) {
+ void visitConstantExpr(AffineConstantExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
eq[getConstantIndex()] = expr->getValue();
}
private:
- void visitDivExpr(AffineBinaryOpExprRef expr, bool isCeil) {
+ void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
- assert(expr->getRHS().isa<AffineConstantExprRef>());
+ assert(expr->getRHS().isa<AffineConstantExpr>());
// This is a pure affine expr; the RHS is a positive constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
// TODO(bondhugula): handle division by zero at the same time the issue is
@@ -266,9 +266,9 @@
}
// Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
- // expr). localExpr is the simplified tree expression (AffineExprRef)
+ // expr). localExpr is the simplified tree expression (AffineExpr)
// corresponding to the quantifier.
- void addLocalId(AffineExprRef localExpr) {
+ void addLocalId(AffineExpr localExpr) {
for (auto &subExpr : operandExprStack) {
subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
}
@@ -284,8 +284,8 @@
} // end anonymous namespace
-AffineExprRef mlir::simplifyAffineExpr(AffineExprRef expr, unsigned numDims,
- unsigned numSymbols) {
+AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
+ unsigned numSymbols) {
// TODO(bondhugula): only pure affine for now. The simplification here can be
// extended to semi-affine maps in the future.
if (!expr->isPureAffine())
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index ac0ba3b..463a64b 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -47,7 +47,7 @@
struct AffineMapCompositionUpdate {
using PositionMap = DenseMap<unsigned, unsigned>;
- explicit AffineMapCompositionUpdate(ArrayRef<AffineExprRef> inputResults)
+ explicit AffineMapCompositionUpdate(ArrayRef<AffineExpr> inputResults)
: inputResults(inputResults), outputNumDims(0), outputNumSymbols(0) {}
// Map from 'curr' affine map dim position to 'output' affine map
@@ -65,7 +65,7 @@
// symbol position.
PositionMap inputSymbolMap;
// Results of 'input' affine map.
- ArrayRef<AffineExprRef> inputResults;
+ ArrayRef<AffineExpr> inputResults;
// Number of dimension operands for 'output' affine map.
unsigned outputNumDims;
// Number of symbol operands for 'output' affine map.
@@ -80,29 +80,29 @@
AffineExprComposer(const AffineMapCompositionUpdate &mapUpdate)
: mapUpdate(mapUpdate), walkingInputMap(false) {}
- AffineExprRef walk(AffineExprRef expr) {
+ AffineExpr walk(AffineExpr expr) {
switch (expr->getKind()) {
case AffineExprKind::Add:
return walkBinExpr(
- expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs + rhs; });
+ expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; });
case AffineExprKind::Mul:
return walkBinExpr(
- expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs * rhs; });
+ expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs * rhs; });
case AffineExprKind::Mod:
return walkBinExpr(
- expr, [](AffineExprRef lhs, AffineExprRef rhs) { return lhs % rhs; });
+ expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs % rhs; });
case AffineExprKind::FloorDiv:
- return walkBinExpr(expr, [](AffineExprRef lhs, AffineExprRef rhs) {
+ return walkBinExpr(expr, [](AffineExpr lhs, AffineExpr rhs) {
return lhs.floorDiv(rhs);
});
case AffineExprKind::CeilDiv:
- return walkBinExpr(expr, [](AffineExprRef lhs, AffineExprRef rhs) {
+ return walkBinExpr(expr, [](AffineExpr lhs, AffineExpr rhs) {
return lhs.ceilDiv(rhs);
});
case AffineExprKind::Constant:
return expr;
case AffineExprKind::DimId: {
- unsigned dimPosition = expr.cast<AffineDimExprRef>()->getPosition();
+ unsigned dimPosition = expr.cast<AffineDimExpr>()->getPosition();
if (walkingInputMap) {
return getAffineDimExpr(mapUpdate.inputDimMap.lookup(dimPosition),
expr->getContext());
@@ -123,7 +123,7 @@
return composer.walk(mapUpdate.inputResults[inputResultIndex]);
}
case AffineExprKind::SymbolId:
- unsigned symbolPosition = expr.cast<AffineSymbolExprRef>()->getPosition();
+ unsigned symbolPosition = expr.cast<AffineSymbolExpr>()->getPosition();
if (walkingInputMap) {
return getAffineSymbolExpr(
mapUpdate.inputSymbolMap.lookup(symbolPosition),
@@ -139,10 +139,9 @@
bool walkingInputMap)
: mapUpdate(mapUpdate), walkingInputMap(walkingInputMap) {}
- AffineExprRef
- walkBinExpr(AffineExprRef expr,
- std::function<AffineExprRef(AffineExprRef, AffineExprRef)> op) {
- auto binOpExpr = expr.cast<AffineBinaryOpExprRef>();
+ AffineExpr walkBinExpr(AffineExpr expr,
+ std::function<AffineExpr(AffineExpr, AffineExpr)> op) {
+ auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
return op(walk(binOpExpr->getLHS()), walk(binOpExpr->getRHS()));
}
@@ -197,7 +196,7 @@
}
AffineMap *MutableAffineMap::getAffineMap() {
- return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
+ return AffineMap::get(numDims, numSymbols, results, rangeSizes);
}
MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
@@ -295,10 +294,10 @@
DenseSet<unsigned> *positions;
AffineExprPositionGatherer(unsigned numDims, DenseSet<unsigned> *positions)
: numDims(numDims), positions(positions) {}
- void visitDimExpr(AffineDimExprRef expr) {
+ void visitDimExpr(AffineDimExpr expr) {
positions->insert(expr->getPosition());
}
- void visitSymbolExpr(AffineSymbolExprRef expr) {
+ void visitSymbolExpr(AffineSymbolExpr expr) {
positions->insert(numDims + expr->getPosition());
}
};
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 4d72808..7fc5b29 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -38,7 +38,7 @@
unsigned j = 0;
AffineBoundExprList::const_iterator it, e;
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
- if (auto cExpr = it->dyn_cast<AffineConstantExprRef>()) {
+ if (auto cExpr = it->dyn_cast<AffineConstantExpr>()) {
if (val == None) {
val = cExpr->getValue();
*idx = j;
@@ -52,8 +52,9 @@
return val;
}
-// Merge the two lists of AffineExpr's into a single one, avoiding duplicates.
-// lb specifies whether the bound lists are for a lower bound or an upper bound.
+// Merge the two lists of AffineExprs into a single one, avoiding
+// duplicates. 'lb' specifies whether the bound lists are for a lower bound or
+// an upper bound.
// TODO(bondhugula): clean this code up.
static void mergeBounds(const HyperRectangularSet &set,
AffineBoundExprList &lhsList,
@@ -68,7 +69,7 @@
}
if (it == lhsList.end()) {
// There can only be one constant affine expr in this bound list.
- if (auto cExpr = expr.dyn_cast<AffineConstantExprRef>()) {
+ if (auto cExpr = expr.dyn_cast<AffineConstantExpr>()) {
unsigned idx;
if (lb) {
auto cb = getReducedConstBound(
@@ -105,8 +106,8 @@
}
HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
- ArrayRef<ArrayRef<AffineExprRef>> lbs,
- ArrayRef<ArrayRef<AffineExprRef>> ubs,
+ ArrayRef<ArrayRef<AffineExpr>> lbs,
+ ArrayRef<ArrayRef<AffineExpr>> ubs,
MLIRContext *context,
IntegerSet *symbolContext)
: context(symbolContext ? MutableIntegerSet(symbolContext, context)
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index 522720e..b3e3afe 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -32,7 +32,7 @@
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExprRef mlir::getTripCountExpr(const ForStmt &forStmt) {
+AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
@@ -56,12 +56,12 @@
return nullptr;
// ub_expr - lb_expr + 1
- AffineExprRef lbExpr(lbMap->getResult(0));
- AffineExprRef ubExpr(ubMap->getResult(0));
+ AffineExpr lbExpr(lbMap->getResult(0));
+ AffineExpr ubExpr(ubMap->getResult(0));
auto loopSpanExpr = simplifyAffineExpr(
ubExpr - lbExpr + 1, std::max(lbMap->getNumDims(), ubMap->getNumDims()),
std::max(lbMap->getNumSymbols(), ubMap->getNumSymbols()));
- auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExprRef>();
+ auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
if (!cExpr)
return loopSpanExpr.ceilDiv(step);
loopSpan = cExpr->getValue();
@@ -84,7 +84,7 @@
if (!tripCountExpr)
return None;
- if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>())
+ if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>())
return constExpr->getValue();
return None;
@@ -99,7 +99,7 @@
if (!tripCountExpr)
return 1;
- if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExprRef>()) {
+ if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>()) {
uint64_t tripCount = constExpr->getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).