Sketch out loop unrolling transformation.
- Implement a full loop unroll for innermost loops.
- Use it to implement a pass that unrolls all the innermost loops of all
MLFunctions in a module. ForStmts, as currently parsed, have constant trip
counts (and constant loop bounds).
- Implement StmtVisitor, based on the Visitor pattern.
Loop IVs aren't currently parsed and represented as SSA values. Replacing uses
of loop IVs in unrolled bodies is thus a TODO. Class comments are sparse in
some places; I will add them after the first round of review comments.
A command-line flag (-unroll-innermost-loops) triggers this for now.
Original:
mlfunc @loops() {
for x = 1 to 100 step 2 {
for x = 1 to 4 {
"Const"(){value: 1} : () -> ()
}
}
return
}
After unrolling:
mlfunc @loops() {
for x = 1 to 100 step 2 {
"Const"(){value: 1} : () -> ()
"Const"(){value: 1} : () -> ()
"Const"(){value: 1} : () -> ()
"Const"(){value: 1} : () -> ()
}
return
}
PiperOrigin-RevId: 205933235
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index f1ccddd..4f62b28 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -21,10 +21,13 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/Parser.h"
+#include "mlir/Pass.h"
#include "mlir/Transforms/ConvertToCFG.h"
+#include "mlir/Transforms/Loop.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/InitLLVM.h"
@@ -49,6 +52,10 @@
"convert-to-cfg",
cl::desc("Convert all ML functions in the module to CFG ones"));
+static cl::opt<bool> unrollInnermostLoops( // opt-in: defaults to false
+    "unroll-innermost-loops", cl::init(false),
+    cl::desc("Unroll innermost loops"));
+
enum OptResult { OptSuccess, OptFailure };
/// Open the specified output file and return it, exiting if there is any I/O or
@@ -83,6 +90,11 @@
if (convertToCFGOpt)
convertToCFG(module.get());
+ if (unrollInnermostLoops) {
+ MLFunctionPass *loopUnroll = createLoopUnrollPass();
+ loopUnroll->runOnModule(module.get());
+ }
+
// Print the output.
auto output = getOutputStream();
module->print(output->os());