Loop unrolling update.
- Handle non-operation statements (if/for statements) in loops being unrolled,
so that unrolling of non-innermost loops now works.
- Update uses in the unrolled bodies so that they refer to the results of the
new (cloned) operations introduced in those bodies.
Unrolling now works for all kinds of loop nests: perfect nests, imperfect
nests, loops at any depth, and loops with any kind of operation in the body.
(IfStmt support is not done yet, and is hence untested.)
Add the missing dump/print method for StmtBlock.
TODO: add a test case for outer-loop unrolling.
PiperOrigin-RevId: 207314286
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index fe110d2..27bb43f 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -26,6 +26,7 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Pass.h"
+#include "mlir/IR/StandardOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/Passes.h"
@@ -96,61 +97,94 @@
runOnForStmt(forStmt);
}
-/// Replace an IV with a constant value.
-static void replaceIterator(Statement *stmt, const ForStmt &iv,
- MLValue *constVal) {
- struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
- // IV to be replaced.
- const ForStmt *iv;
- // Constant to be replaced with.
- MLValue *constVal;
+/// Walk 'stmt' with a StmtWalker and, in every operation statement visited,
+/// rewrite each operand that refers to 'oldVal' so that it refers to 'newVal'.
+/// Only operands of operation statements are rewritten; other statement kinds
+/// are not inspected by this walker.
+static void replaceAllStmtUses(Statement *stmt, MLValue *oldVal,
+ MLValue *newVal) {
+ struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
+ // Value whose uses are being replaced.
+ MLValue *oldVal;
+ // Value that replaces 'oldVal'.
+ MLValue *newVal;
- ReplaceIterator(const ForStmt &iv, MLValue *constVal)
- : iv(&iv), constVal(constVal){};
+ ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
+ : oldVal(oldVal), newVal(newVal) {}
void visitOperationStmt(OperationStmt *os) {
for (auto &operand : os->getStmtOperands()) {
- if (operand.get() == static_cast<const MLValue *>(iv)) {
- operand.set(constVal);
- }
+ if (operand.get() == oldVal)
+ operand.set(newVal);
}
}
};
- ReplaceIterator ri(iv, constVal);
+ ReplaceUseWalker ri(oldVal, newVal);
ri.walk(stmt);
}
-/// Unrolls this loop completely.
+/// Unroll this 'for stmt' / loop completely. For each value the induction
+/// variable (IV) takes in [lb, ub] (ub inclusive) with stride 'step', the loop
+/// body is cloned right after the 'for stmt', uses of the IV in the clone are
+/// replaced with a constant created at the start of the function, and uses of
+/// the original body's op results are rewired to the cloned results. The
+/// original 'for stmt' is erased at the end. Loops with a non-positive step are
+/// left untouched.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
- auto trip_count = (ub - lb + 1) / step;
+ // With a non-positive step the IV never advances past 'ub', so the unroll
+ // loop below would never terminate; leave such loops alone.
+ if (step <= 0)
+ return;
+ // Number of iterations, treating 'ub' as inclusive. Iterating over a count
+ // (instead of stepping the IV up to 'ub') avoids signed overflow when 'ub'
+ // is close to the maximum value of its type, and keeps the bound's own type
+ // rather than narrowing it to 'int'.
+ auto tripCount = ub < lb ? 0 : (ub - lb) / step + 1;
+
+ // Builder to add the constants needed for the unrolled iterator.
auto *mlFunc = forStmt->Statement::findFunction();
MLFuncBuilder funcTopBuilder(mlFunc);
funcTopBuilder.setInsertionPointAtStart(mlFunc);
+ // Builder to insert the unrolled bodies.
MLFuncBuilder builder(forStmt->getBlock());
- for (int i = 0; i < trip_count; i++) {
- auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
- for (auto &stmt : forStmt->getStatements()) {
- switch (stmt.getKind()) {
- case Statement::Kind::For:
- llvm_unreachable("unrolling loops that have only operations");
- break;
- case Statement::Kind::If:
- llvm_unreachable("unrolling loops that have only operations");
- break;
- case Statement::Kind::Operation:
- auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
- // TODO(bondhugula): only generate constants when the IV actually
- // appears in the body.
- replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
- break;
+ // Set the insertion point to right after where the for stmt ends.
+ builder.setInsertionPoint(forStmt->getBlock(),
+ ++StmtBlock::iterator(forStmt));
+
+ // Unroll the contents of 'forStmt', one copy of the body per iteration.
+ for (decltype(tripCount) iter = 0; iter < tripCount; ++iter) {
+ // Value of the IV in this iteration; cannot overflow since
+ // iter * step <= ub - lb.
+ auto ivValue = lb + iter * step;
+ // TODO(bondhugula): generate constants only when IV actually appears.
+ auto constOp = funcTopBuilder.create<ConstantIntOp>(ivValue, 32);
+ auto *ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+
+ // Iterator pointing to just before this ('iter'-th) unrolled iteration.
+ StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
+
+ // Pairs of <old op stmt result whose uses need to be replaced,
+ // new result generated by the corresponding cloned op stmt>.
+ SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+
+ for (auto &loopBodyStmt : forStmt->getStatements()) {
+ auto *cloneStmt = builder.clone(loopBodyStmt);
+ // Replace all uses of the IV in the clone with the constant IV value.
+ replaceAllStmtUses(cloneStmt, forStmt, ivConst);
+
+ // Whenever we have an op stmt, we'll have a new ML Value defined: replace
+ // uses of the old result with this one.
+ // NOTE(review): only op stmts at the top level of the body are recorded
+ // here; results of op stmts nested inside cloned for/if stmts are not.
+ // Confirm that 'builder.clone' remaps such nested uses itself.
+ if (auto *opStmt = dyn_cast<OperationStmt>(&loopBodyStmt)) {
+ if (opStmt->getNumResults()) {
+ auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+ for (unsigned resultIdx = 0, e = opStmt->getNumResults();
+ resultIdx < e; ++resultIdx) {
+ // Store old/new result pairs.
+ // TODO *only* if needed later: storing of old/new results can be
+ // avoided by cloning the statement list in the reverse direction
+ // (and running the IR builder in reverse, i.e., with
+ // iplist.insertAfter()). That way, a newly created result can be
+ // immediately propagated to all its uses, which would already have
+ // been cloned/inserted.
+ oldNewResultPairs.push_back(
+ std::make_pair(&opStmt->getStmtResult(resultIdx),
+ &cloneOpStmt->getStmtResult(resultIdx)));
+ }
+ }
+ }
+ }
+ // Replace uses of the old op results with the results in the just
+ // unrolled body.
+ StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
+ for (auto it = ++beforeUnrolledBody; it != endOfUnrolledBody; ++it) {
+ for (const auto &oldNewPair : oldNewResultPairs) {
+ replaceAllStmtUses(&*it, oldNewPair.first, oldNewPair.second);
}
}
}
+ // Erase the original for stmt from the block.
forStmt->eraseFromBlock();
}