Sketch out loop unrolling transformation.
- Implement a full loop unroll for innermost loops.
- Use it to implement a pass that unrolls all the innermost loops of all
mlfunctions in a module. ForStmts currently parsed have constant trip
counts (and constant loop bounds).
- Implement StmtVisitor (based on the Visitor pattern).
Loop IVs aren't currently parsed and represented as SSA values. Replacing uses
of loop IVs in unrolled bodies is thus a TODO. Class comments are sparse in some places - will add them after one round of comments.
A cmd-line flag triggers this for now.
Original:
mlfunc @loops() {
for x = 1 to 100 step 2 {
for x = 1 to 4 {
"Const"(){value: 1} : () -> ()
}
}
return
}
After unrolling:
mlfunc @loops() {
for x = 1 to 100 step 2 {
"Const"(){value: 1} : () -> ()
"Const"(){value: 1} : () -> ()
"Const"(){value: 1} : () -> ()
"Const"(){value: 1} : () -> ()
}
return
}
PiperOrigin-RevId: 205933235
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index b469482..6b272e0 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -48,6 +49,18 @@
return this->getBlock()->getFunction();
}
+unsigned Statement::getNumNestedLoops() const { // Counts ForStmt visits.
+ struct NestedLoopCounter : public StmtVisitor<NestedLoopCounter> {
+ unsigned numNestedLoops; // Incremented once per visitForStmt call.
+ NestedLoopCounter() : numNestedLoops(0) {}
+ void visitForStmt(const ForStmt *fs) { numNestedLoops++; } // fs is unused.
+ };
+ // NOTE(review): traversal depth depends on StmtVisitor::visit - confirm.
+ NestedLoopCounter nlc;
+ nlc.visit(const_cast<Statement *>(this)); // visit() is given a non-const ptr.
+ return nlc.numNestedLoops;
+}
+
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
@@ -91,7 +104,9 @@
first->block = curParent;
}
-/// Remove this statement from its StmtBlock and delete it.
+/// Remove this statement (and its descendants) from its StmtBlock and delete
+/// all of them.
+/// TODO: erase all descendants for ForStmt/IfStmt.
void Statement::eraseFromBlock() {
assert(getBlock() && "Statement has no block");
getBlock()->getStatements().erase(this);