ShortLoopUnroll - bug fix.
Collect loops through a post-order walk instead of a pre-order one so that inner
loops are collected before the outer loops that surround them.
Add a complex test case.
PiperOrigin-RevId: 209041057
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index cacb207..0f6d428 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -36,10 +36,14 @@
using namespace mlir;
namespace {
+/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
+/// MLFunction.
struct LoopUnroll : public MLFunctionPass {
void runOnMLFunction(MLFunction *f) override;
void runOnForStmt(ForStmt *forStmt);
};
+
+/// Unrolls all loops with trip count <= minTripCount.
struct ShortLoopUnroll : public LoopUnroll {
const unsigned minTripCount;
void runOnMLFunction(MLFunction *f) override;
@@ -53,7 +57,6 @@
return new ShortLoopUnroll(minTripCount);
}
-/// Unrolls all the innermost loops of this MLFunction.
void LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
@@ -105,7 +108,6 @@
runOnForStmt(forStmt);
}
-/// Unrolls all loops with trip count <= minTripCount.
void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
@@ -126,7 +128,9 @@
};
ShortLoopGatherer slg(minTripCount);
- slg.walk(f);
+ // Do a post order walk so that loops are gathered from innermost to
+ // outermost (or else unrolling an outer one may delete gathered inner ones).
+ slg.walkPostOrder(f);
auto &loops = slg.loops;
for (auto *forStmt : loops)
runOnForStmt(forStmt);