Extend loop unroll and unroll-and-jam to handle affine bounds, and refactor related code.
- extend loop unroll-and-jam, like loop unroll, to handle affine bounds
- extend both loop unroll and unroll-and-jam to generate a cleanup loop when the
trip count is not a multiple of the unroll factor.
- extend promotion of single-iteration loops to work with affine bounds
- fix typo bugs in loop unroll
- refactor code shared between loop unroll and loop unroll-and-jam
- move prototypes of non-pass transforms to LoopUtils.h
- add more builder methods.
- introduce loopUnrollUpTo(factor) to unroll by either factor or trip count,
whichever is less.
- remove Statement::isInnermost (currently unused; it will come back later in the
right place and form)
PiperOrigin-RevId: 213471227
diff --git a/lib/Analysis/LoopAnalysis.cpp b/lib/Analysis/LoopAnalysis.cpp
index dc56d62..fa283ae 100644
--- a/lib/Analysis/LoopAnalysis.cpp
+++ b/lib/Analysis/LoopAnalysis.cpp
@@ -31,7 +31,7 @@
/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
-AffineExpr *mlir::getTripCount(const ForStmt &forStmt) {
+AffineExpr *mlir::getTripCountExpr(const ForStmt &forStmt) {
// upper_bound - lower_bound + 1
int64_t loopSpan;
@@ -43,32 +43,22 @@
int64_t ub = forStmt.getConstantUpperBound();
loopSpan = ub - lb + 1;
} else {
- const AffineBound lb = forStmt.getLowerBound();
- const AffineBound ub = forStmt.getUpperBound();
- auto lbMap = lb.getMap();
- auto ubMap = ub.getMap();
+ auto *lbMap = forStmt.getLowerBoundMap();
+ auto *ubMap = forStmt.getUpperBoundMap();
// TODO(bondhugula): handle max/min of multiple expressions.
- if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1 ||
- lbMap->getNumDims() != ubMap->getNumDims() ||
- lbMap->getNumSymbols() != ubMap->getNumSymbols()) {
+ if (lbMap->getNumResults() != 1 || ubMap->getNumResults() != 1)
return nullptr;
- }
// TODO(bondhugula): handle bounds with different operands.
- unsigned i, e = lb.getNumOperands();
- for (i = 0; i < e; i++) {
- if (lb.getStmtOperand(i).get() != ub.getStmtOperand(i).get())
- break;
- }
// Bounds have different operands, unhandled for now.
- if (i != e)
+ if (!forStmt.matchingBoundOperandList())
return nullptr;
// ub_expr - lb_expr + 1
+ auto *lbExpr = lbMap->getResult(0);
+ auto *ubExpr = ubMap->getResult(0);
auto *loopSpanExpr = AffineBinaryOpExpr::getAdd(
- AffineBinaryOpExpr::getSub(ubMap->getResult(0), lbMap->getResult(0),
- context),
- 1, context);
+ AffineBinaryOpExpr::getSub(ubExpr, lbExpr, context), 1, context);
if (auto *expr = simplifyAffineExpr(loopSpanExpr, lbMap->getNumDims(),
lbMap->getNumSymbols(), context))
@@ -95,7 +85,7 @@
/// method uses affine expression analysis (in turn using getTripCount) and is
/// able to determine constant trip count in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
- AffineExpr *tripCountExpr = getTripCount(forStmt);
+ AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
if (auto *constExpr = dyn_cast_or_null<AffineConstantExpr>(tripCountExpr))
return constExpr->getValue();
@@ -107,7 +97,7 @@
/// expression analysis is used (indirectly through getTripCount), and
/// this method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
- AffineExpr *tripCountExpr = getTripCount(forStmt);
+ AffineExpr *tripCountExpr = getTripCountExpr(forStmt);
if (!tripCountExpr)
return 1;