MLStmt cloning and IV replacement for loop unrolling; add a constant pool to
MLFunctions.
- Implement MLStmt cloning and IV replacement.
- While at it, fix innermostLoopGatherer to actually gather all the
innermost loops (it was stopping its walk at the first innermost loop it
found).
- Improve comments for MLFunction statement classes; fix inheritance order.
- Fix the StmtBlock destructor.
PiperOrigin-RevId: 207049173
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 160a463..fe110d2 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
@@ -54,10 +55,13 @@
typedef llvm::iplist<Statement> StmtListType;
bool walkPostOrder(StmtListType::iterator Start,
StmtListType::iterator End) {
+ bool hasInnerLoops = false;
+ // Visit every element rather than stopping at the first one that reports
+ // inner loops: all innermost loops in [Start, End) must be gathered, not
+ // just the first one found.
while (Start != End)
- if (walkPostOrder(&(*Start++)))
- return true;
- return false;
+ hasInnerLoops |= walkPostOrder(&(*Start++)); // '|=' never short-circuits.
+ return hasInnerLoops; // True if any element reported inner loops.
}
// FIXME: can't use base class method for this because that in turn would
@@ -73,12 +77,11 @@
}
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
- if (walkPostOrder(ifStmt->getThenClause()->begin(),
- ifStmt->getThenClause()->end()) ||
- walkPostOrder(ifStmt->getElseClause()->begin(),
- ifStmt->getElseClause()->end()))
- return true;
- return false;
+ bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
+ ifStmt->getThenClause()->end()); // Then clause.
+ if (auto *elseClause = ifStmt->getElseClause()) // Guard a possibly-null else.
+ hasInnerLoops |= walkPostOrder(elseClause->begin(), elseClause->end());
+ return hasInnerLoops; // Both clauses are walked: no short-circuit.
}
bool walkOpStmt(OperationStmt *opStmt) { return false; }
@@ -93,17 +96,45 @@
runOnForStmt(forStmt);
}
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Replaces each operand equal to the IV \p iv, in every operation statement
+/// reached by walking \p stmt, with \p constVal.
+static void replaceIterator(Statement *stmt, const ForStmt &iv,
+                            MLValue *constVal) {
+  struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
+    // Loop whose IV uses are being replaced.
+    const ForStmt *iv;
+    // Value substituted for every use of the IV.
+    MLValue *constVal;
+
+    ReplaceIterator(const ForStmt &iv, MLValue *constVal)
+        : iv(&iv), constVal(constVal) {}
+
+    void visitOperationStmt(OperationStmt *os) {
+      for (auto &operand : os->getStmtOperands()) {
+        if (operand.get() == static_cast<const MLValue *>(iv))
+          operand.set(constVal);
+      }
+    }
+  };
+
+  ReplaceIterator ri(iv, constVal);
+  ri.walk(stmt);
+}
+
+/// Fully unrolls \p forStmt (iteration i gets IV = lb + i * step); erases it.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
- auto trip_count = (ub - lb + 1) / step;
- auto *block = forStmt->getBlock();
- MLFuncBuilder builder(block);
+ auto trip_count = ub < lb ? 0 : (ub - lb) / step + 1; // Assumes ub is inclusive.
+ auto *mlFunc = forStmt->Statement::findFunction();
+ MLFuncBuilder funcTopBuilder(mlFunc);
+ funcTopBuilder.setInsertionPointAtStart(mlFunc);
+ MLFuncBuilder builder(forStmt->getBlock());
for (int i = 0; i < trip_count; i++) {
+ auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(lb + i * step)->getResult(0);
for (auto &stmt : forStmt->getStatements()) {
switch (stmt.getKind()) {
case Statement::Kind::For:
@@ -113,16 +144,13 @@
llvm_unreachable("unrolling loops that have only operations");
break;
case Statement::Kind::Operation:
- auto *op = cast<OperationStmt>(&stmt);
- // TODO: clone operands and result types.
- builder.createOperation(op->getName(), /*operands*/ {},
- /*resultTypes*/ {}, op->getAttrs());
- // TODO: loop iterator parsing not yet implemented; replace loop
- // iterator uses in unrolled body appropriately.
+ auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
+ // TODO(bondhugula): only generate constants when the IV actually
+ // appears in the body.
+ replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
break;
}
}
}
-
forStmt->eraseFromBlock();
}