Loop unrolling pass update
- fix/complete ForStmt cloning so that unrolling works for outer loops
- create IV constants only when the IV actually has uses
- test outer loop unrolling via a new short-trip-count unroll pass that fully
unrolls all loops whose trip count is <= a given parameter
- add unrolling test cases for multiple op results, outer loop unrolling
- fix/clean up the StmtWalker class while at it
- switch unroll loop iterator values from i32 to affineint
PiperOrigin-RevId: 207645967
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 27bb43f..eea3bf7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -39,14 +39,22 @@
void runOnMLFunction(MLFunction *f) override;
void runOnForStmt(ForStmt *forStmt);
};
+struct ShortLoopUnroll : public LoopUnroll {
+ const unsigned minTripCount; // Max trip count of loops to unroll.
+ void runOnMLFunction(MLFunction *f) override;
+ explicit ShortLoopUnroll(unsigned tripCount) : minTripCount(tripCount) {}
+};
} // end anonymous namespace
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
+/// Creates a loop unroll pass that fully unrolls every loop whose trip count
+/// is <= 'minTripCount'.
+MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
+ auto *shortLoopUnrollPass = new ShortLoopUnroll(minTripCount);
+ return shortLoopUnrollPass;
+}
+
/// Unrolls all the innermost loops of this MLFunction.
void LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
- // TODO: figure out the right reusable template here to better refactor code.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
@@ -65,11 +73,6 @@
return hasInnerLoops;
}
- // FIXME: can't use base class method for this because that in turn would
- // need to use the derived class method above. CRTP doesn't allow it, and
- // the compiler error resulting from it is also very misleading!
- void walkPostOrder(MLFunction *f) { walkPostOrder(f->begin(), f->end()); }
-
bool walkForStmtPostOrder(ForStmt *forStmt) {
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
if (!hasInnerLoops)
@@ -85,8 +88,11 @@
return hasInnerLoops;
}
- bool walkOpStmt(OperationStmt *opStmt) { return false; }
+ bool visitOperationStmt(OperationStmt *opStmt) { return false; }
+ // FIXME: can't use base class method for this because that in turn would
+ // need to use the derived class method above. CRTP doesn't allow it, and
+ // the compiler error resulting from it is also misleading.
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
@@ -97,28 +103,96 @@
runOnForStmt(forStmt);
}
-/// Replace all uses of 'oldVal' with 'newVal' in 'stmt'
-static void replaceAllStmtUses(Statement *stmt, MLValue *oldVal,
- MLValue *newVal) {
- struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
- // Value to be replaced.
- MLValue *oldVal;
- // Value to be replaced with.
- MLValue *newVal;
+/// Unrolls all loops with trip count <= minTripCount.
+void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
+ // Gathers all loops with trip count <= minTripCount.
+ class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
+ public:
+ std::vector<ForStmt *> loops; // Short loops, inner before outer.
+ const unsigned minTripCount;
+ explicit ShortLoopGatherer(unsigned tripCount) : minTripCount(tripCount) {}
- ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
- : oldVal(oldVal), newVal(newVal){};
+ void visitForStmt(ForStmt *forStmt) {
+ auto lb = forStmt->getLowerBound()->getValue();
+ auto ub = forStmt->getUpperBound()->getValue();
+ auto step = forStmt->getStep()->getValue();
- void visitOperationStmt(OperationStmt *os) {
- for (auto &operand : os->getStmtOperands()) {
- if (operand.get() == oldVal)
- operand.set(newVal);
- }
+ if ((ub - lb) / step + 1 <= minTripCount)
+ loops.push_back(forStmt);
}
};
- ReplaceUseWalker ri(oldVal, newVal);
- ri.walk(stmt);
+ ShortLoopGatherer slg(minTripCount);
+ // Walk post-order so inner loops are gathered (and unrolled) before the
+ // loops enclosing them: unrolling an outer loop erases its inner loops.
+ slg.walkPostOrder(f);
+ for (auto *forStmt : slg.loops)
+ runOnForStmt(forStmt);
+}
+
+/// Replaces each use of 'oldVal' by 'newVal' in the statements [begin, end).
+static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
+ MLValue *oldVal, MLValue *newVal) {
+ // TODO(bondhugula,clattner): make this cheaper by visiting only the uses of
+ // oldVal that lie in [begin, end), rather than every statement in the range
+ // and every operand of each operation found there.
+ for (; begin != end; ++begin) {
+ begin->replaceUses(oldVal, newVal);
+ }
+}
+
+/// Replace all uses of 'oldVal' with 'newVal' in the statements of 'block'
+/// (each statement's own replaceUses performs the substitution).
+// NOTE(review): no caller of this overload appears in this patch; if it stays
+// file-local it should be 'static' like the iterator-range overload.
+void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
+ // Delegate to the iterator-range overload so the replacement logic (and its
+ // TODO about visiting only the relevant uses) lives in a single place.
+ replaceUses(block->begin(), block->end(), oldVal, newVal);
+}
+
+/// Clone the statements of 'block' to the builder's insertion point, rewiring
+/// uses of each cloned op's results (within the clones) to the new results.
+// TODO(bondhugula,clattner): replace this with a parameterizable clone.
+void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
+ // Pairs of <old op stmt result whose uses need to be replaced,
+ // new result generated by the corresponding cloned op stmt>.
+ SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+
+ // First cloned statement. Remembered instead of decrementing the insertion
+ // point up front, which is not valid when inserting at the start of a block.
+ Statement *firstClone = nullptr;
+
+ for (auto &stmt : block.getStatements()) {
+ auto *cloneStmt = builder->clone(stmt);
+ if (!firstClone)
+ firstClone = cloneStmt;
+ // Every result of an op stmt gets a new ML Value in the clone: record the
+ // old/new pairs so uses of the old results can be rewired below.
+ if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+ for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
+ // TODO(bondhugula) *only* if needed later: storing of old/new
+ // results can be avoided by cloning the statement list in the
+ // reverse direction (and running the IR builder in the reverse
+ // (iplist.insertAfter()). That way, a newly created result can be
+ // immediately propagated to all its uses.
+ oldNewResultPairs.push_back(std::make_pair(
+ const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
+ &cloneOpStmt->getStmtResult(i)));
+ }
+ }
+ }
+
+ // Empty 'block': nothing was cloned, so there are no uses to rewire.
+ if (!firstClone)
+ return;
+
+ // Rewire uses of the old op results to the new ones within the clones only.
+ StmtBlock::iterator startOfClones(firstClone);
+ StmtBlock::iterator endOfClones = builder->getInsertionPoint();
+ for (auto &oldNew : oldNewResultPairs)
+ replaceUses(startOfClones, endOfClones, oldNew.first, oldNew.second);
}
/// Unroll this 'for stmt' / loop completely.
@@ -139,52 +213,25 @@
++StmtBlock::iterator(forStmt));
// Unroll the contents of 'forStmt'.
- for (int i = lb; i <= ub; i += step) {
- // TODO(bondhugula): generate constants only when IV actually appears.
- auto constOp = funcTopBuilder.create<ConstantIntOp>(i, 32);
- auto *ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
-
- // Iterator pointing to just before 'this' (i^th) unrolled iteration.
+ for (int64_t i = lb; i <= ub; i += step) {
+ MLValue *ivConst = nullptr;
+ if (!forStmt->use_empty()) {
+ auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
+ ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+ }
StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
- // Pairs of <old op stmt result whose uses need to be replaced,
- // new result generated by the corresponding cloned op stmt>.
- SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+ // Clone the loop body and insert it right after the loop - the latter will
+ // be erased after all unrolling has been done.
+ cloneStmtListFromBlock(&builder, *forStmt);
- for (auto &loopBodyStmt : forStmt->getStatements()) {
- auto *cloneStmt = builder.clone(loopBodyStmt);
- // Replace all uses of the IV in the clone with constant iteration value.
- replaceAllStmtUses(cloneStmt, forStmt, ivConst);
-
- // Whenever we have an op stmt, we'll have a new ML Value defined: replace
- // uses of the old result with this one.
- if (auto *opStmt = dyn_cast<OperationStmt>(&loopBodyStmt)) {
- if (opStmt->getNumResults()) {
- auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
- for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
- // Store old/new result pairs.
- // TODO *only* if needed later: storing of old/new results can be
- // avoided, by cloning the statement list in the reverse direction
- // (and running the IR builder in the reverse
- // (iplist.insertAfter()). That way, a newly created result can be
- // immediately propagated to all its uses, which would already been
- // cloned/inserted.
- oldNewResultPairs.push_back(std::make_pair(
- &opStmt->getStmtResult(i), &cloneOpStmt->getStmtResult(i)));
- }
- }
- }
- }
- // Replace uses of old op results' with the results in the just
- // unrolled body.
- StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
- for (auto it = ++beforeUnrolledBody; it != endOfUnrolledBody; it++) {
- for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
- replaceAllStmtUses(&(*it), oldNewResultPairs[i].first,
- oldNewResultPairs[i].second);
- }
+ // Replace unrolled loop IV with the unrolled constant.
+ if (ivConst) {
+ StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+ StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
+ replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
}
}
- // Erase the original for stmt from the block.
+ // Erase the original 'for' stmt from the block.
forStmt->eraseFromBlock();
}