[MLIR] Sketch AffineExpr value type
This CL sketches what it takes for AffineExpr to have full by-value semantics
instead of being a not-so-smart pointer.
This essentially makes the underlying class a simple storage struct and
implements the operations directly on the value type. Since operations are no
longer forwarded, we can fully isolate the storage class and enforce a hard
visibility barrier by moving detail::AffineExpr into
AffineExprDetail.h.
AffineExprDetail.h is only included where storage-related information is
needed.
PiperOrigin-RevId: 216385459
diff --git a/lib/Analysis/AffineAnalysis.cpp b/lib/Analysis/AffineAnalysis.cpp
index f9a159e..19b6638 100644
--- a/lib/Analysis/AffineAnalysis.cpp
+++ b/lib/Analysis/AffineAnalysis.cpp
@@ -30,7 +30,7 @@
/// Constructs an affine expression from a flat ArrayRef. If there are local
/// identifiers (neither dimensional nor symbolic) that appear in the sum of
-/// products expression, 'localExprs' is expected to have the AffineExprClass
+/// products expression, 'localExprs' is expected to have the AffineExpr
/// for it, and is substituted into. The ArrayRef 'eq' is expected to be in the
/// format [dims, symbols, locals, constant term].
// TODO(bondhugula): refactor getAddMulPureAffineExpr to reuse it from here.
@@ -124,7 +124,7 @@
// Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
// expressions that could not be simplified.
unsigned numLocals;
- // AffineExprClass's corresponding to the floordiv/ceildiv/mod expressions for
+ // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
// which new identifiers were introduced; if the latter do not get canceled
// out, these expressions are needed to reconstruct the AffineExpr / tree
// form. Note that these expressions themselves would have been simplified
@@ -144,7 +144,7 @@
void visitMulExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(expr->getRHS().isa<AffineConstantExpr>());
+ assert(expr.getRHS().isa<AffineConstantExpr>());
// Get the RHS constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
@@ -171,7 +171,7 @@
void visitModExpr(AffineBinaryOpExpr expr) {
assert(operandExprStack.size() >= 2);
// This is a pure affine expr; the RHS will be a constant.
- assert(expr->getRHS().isa<AffineConstantExpr>());
+ assert(expr.getRHS().isa<AffineConstantExpr>());
auto rhsConst = operandExprStack.back()[getConstantIndex()];
operandExprStack.pop_back();
auto &lhs = operandExprStack.back();
@@ -206,23 +206,23 @@
void visitDimExpr(AffineDimExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
- eq[getDimStartIndex() + expr->getPosition()] = 1;
+ eq[getDimStartIndex() + expr.getPosition()] = 1;
}
void visitSymbolExpr(AffineSymbolExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
- eq[getSymbolStartIndex() + expr->getPosition()] = 1;
+ eq[getSymbolStartIndex() + expr.getPosition()] = 1;
}
void visitConstantExpr(AffineConstantExpr expr) {
operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
auto &eq = operandExprStack.back();
- eq[getConstantIndex()] = expr->getValue();
+ eq[getConstantIndex()] = expr.getValue();
}
private:
void visitDivExpr(AffineBinaryOpExpr expr, bool isCeil) {
assert(operandExprStack.size() >= 2);
- assert(expr->getRHS().isa<AffineConstantExpr>());
+ assert(expr.getRHS().isa<AffineConstantExpr>());
// This is a pure affine expr; the RHS is a positive constant.
auto rhsConst = operandExprStack.back()[getConstantIndex()];
// TODO(bondhugula): handle division by zero at the same time the issue is
@@ -285,14 +285,14 @@
unsigned numSymbols) {
// TODO(bondhugula): only pure affine for now. The simplification here can be
// extended to semi-affine maps in the future.
- if (!expr->isPureAffine())
+ if (!expr.isPureAffine())
return nullptr;
- AffineExprFlattener flattener(numDims, numSymbols, expr->getContext());
+ AffineExprFlattener flattener(numDims, numSymbols, expr.getContext());
flattener.walkPostOrder(expr);
ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
auto simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
- flattener.localExprs, expr->getContext());
+ flattener.localExprs, expr.getContext());
flattener.operandExprStack.pop_back();
assert(flattener.operandExprStack.empty());
return simplifiedExpr;
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index 62ad477..9320326 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -81,7 +81,7 @@
: mapUpdate(mapUpdate), walkingInputMap(false) {}
AffineExpr walk(AffineExpr expr) {
- switch (expr->getKind()) {
+ switch (expr.getKind()) {
case AffineExprKind::Add:
return walkBinExpr(
expr, [](AffineExpr lhs, AffineExpr rhs) { return lhs + rhs; });
@@ -102,16 +102,16 @@
case AffineExprKind::Constant:
return expr;
case AffineExprKind::DimId: {
- unsigned dimPosition = expr.cast<AffineDimExpr>()->getPosition();
+ unsigned dimPosition = expr.cast<AffineDimExpr>().getPosition();
if (walkingInputMap) {
return getAffineDimExpr(mapUpdate.inputDimMap.lookup(dimPosition),
- expr->getContext());
+ expr.getContext());
}
// Check if we are just mapping this dim to another position.
if (mapUpdate.currDimMap.count(dimPosition) > 0) {
assert(mapUpdate.currDimToInputResultMap.count(dimPosition) == 0);
return getAffineDimExpr(mapUpdate.currDimMap.lookup(dimPosition),
- expr->getContext());
+ expr.getContext());
}
// We are substituting an input map result at 'dimPositon'
// Forward substitute currDimToInputResultMap[dimPosition] into this
@@ -123,14 +123,13 @@
return composer.walk(mapUpdate.inputResults[inputResultIndex]);
}
case AffineExprKind::SymbolId:
- unsigned symbolPosition = expr.cast<AffineSymbolExpr>()->getPosition();
+ unsigned symbolPosition = expr.cast<AffineSymbolExpr>().getPosition();
if (walkingInputMap) {
return getAffineSymbolExpr(
- mapUpdate.inputSymbolMap.lookup(symbolPosition),
- expr->getContext());
+ mapUpdate.inputSymbolMap.lookup(symbolPosition), expr.getContext());
}
return getAffineSymbolExpr(mapUpdate.currSymbolMap.lookup(symbolPosition),
- expr->getContext());
+ expr.getContext());
}
}
@@ -142,7 +141,7 @@
AffineExpr walkBinExpr(AffineExpr expr,
std::function<AffineExpr(AffineExpr, AffineExpr)> op) {
auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
- return op(walk(binOpExpr->getLHS()), walk(binOpExpr->getRHS()));
+ return op(walk(binOpExpr.getLHS()), walk(binOpExpr.getRHS()));
}
// Map update specifies to dim and symbol postion maps, as well as the input
@@ -177,7 +176,7 @@
}
bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
- if (results[idx]->isMultipleOf(factor))
+ if (results[idx].isMultipleOf(factor))
return true;
// TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
@@ -295,10 +294,10 @@
AffineExprPositionGatherer(unsigned numDims, DenseSet<unsigned> *positions)
: numDims(numDims), positions(positions) {}
void visitDimExpr(AffineDimExpr expr) {
- positions->insert(expr->getPosition());
+ positions->insert(expr.getPosition());
}
void visitSymbolExpr(AffineSymbolExpr expr) {
- positions->insert(numDims + expr->getPosition());
+ positions->insert(numDims + expr.getPosition());
}
};
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 7fc5b29..bd1361b 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -40,10 +40,10 @@
for (it = ubs.begin(), e = ubs.end(); it != e; it++, j++) {
if (auto cExpr = it->dyn_cast<AffineConstantExpr>()) {
if (val == None) {
- val = cExpr->getValue();
+ val = cExpr.getValue();
*idx = j;
- } else if (cmp(cExpr->getValue(), val.getValue())) {
- val = cExpr->getValue();
+ } else if (cmp(cExpr.getValue(), val.getValue())) {
+ val = cExpr.getValue();
*idx = j;
}
}
@@ -52,7 +52,7 @@
return val;
}
-// Merge the two lists of AffineExprClass's into a single one, avoiding
+// Merge the two lists of AffineExpr's into a single one, avoiding
// duplicates. lb specifies whether the bound lists are for a lower bound or an
// upper bound.
// TODO(bondhugula): clean this code up.
@@ -79,7 +79,7 @@
lhsList.push_back(expr);
continue;
}
- if (cExpr->getValue() < cb)
+ if (cExpr.getValue() < cb)
lhsList[idx] = expr;
// A constant value >= the existing bound constant.
continue;
@@ -93,7 +93,7 @@
lhsList.push_back(expr);
continue;
}
- if (cExpr->getValue() > cb)
+ if (cExpr.getValue() > cb)
lhsList[idx] = expr;
continue;
}
@@ -116,7 +116,7 @@
for (auto boundList : lbs) {
AffineBoundExprList lb;
for (auto expr : boundList) {
- assert(expr->isSymbolicOrConstant() &&
+ assert(expr.isSymbolicOrConstant() &&
"bound expression should be symbolic or constant");
lb.push_back(expr);
}
@@ -127,7 +127,7 @@
for (auto boundList : ubs) {
AffineBoundExprList ub;
for (auto expr : boundList) {
- assert(expr->isSymbolicOrConstant() &&
+ assert(expr.isSymbolicOrConstant() &&
"bound expression should be symbolic or constant");
ub.push_back(expr);
}
@@ -163,7 +163,7 @@
for (auto &lb : lowerBounds) {
os << "Dim " << d++ << "\n";
for (auto expr : lb) {
- expr->print(os);
+ expr.print(os);
}
}
d = 0;
@@ -171,7 +171,7 @@
for (auto &lb : upperBounds) {
os << "Dim " << d++ << "\n";
for (auto expr : lb) {
- expr->print(os);
+ expr.print(os);
}
}
}
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index b3e3afe..cb63de3 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -64,7 +64,7 @@
auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
if (!cExpr)
return loopSpanExpr.ceilDiv(step);
- loopSpan = cExpr->getValue();
+ loopSpan = cExpr.getValue();
}
// 0 iteration loops.
@@ -85,7 +85,7 @@
return None;
if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>())
- return constExpr->getValue();
+ return constExpr.getValue();
return None;
}
@@ -100,7 +100,7 @@
return 1;
if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>()) {
- uint64_t tripCount = constExpr->getValue();
+ uint64_t tripCount = constExpr.getValue();
// 0 iteration loops (greatest divisor is 2^64 - 1).
if (tripCount == 0)
@@ -111,5 +111,5 @@
}
// Trip count is not a known constant; return its largest known divisor.
- return tripCountExpr->getLargestKnownDivisor();
+ return tripCountExpr.getLargestKnownDivisor();
}