Extend loop unrolling to unroll by a given factor; add builder for affine
apply op.
- add a builder for AffineApplyOp (the first builder for an operation that
takes a non-zero number of operands)
- add support for unrolling loops by a given factor; this uses the
AffineApplyOp builder to compute the shifted induction variable of each
unrolled copy.
While at it, change ForStmt's 'step' to be 'unsigned' instead of
'AffineConstantExpr *', and add setters for the ForStmt lower bound, upper
bound, and step.
Sample Input:
// CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() {
mlfunc @loop_nest_unroll_cleanup() {
for %i = 1 to 100 {
for %j = 0 to 17 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
%y = "addi32"(%x, %x) : (i32, i32) -> i32
}
}
return
}
Output:
$ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir
#map0 = (d0) -> (d0 + 1)
#map1 = (d0) -> (d0 + 2)
#map2 = (d0) -> (d0 + 3)
mlfunc @loop_nest_unroll_cleanup() {
for %i0 = 1 to 100 {
for %i1 = 0 to 15 step 4 {
%0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32
%1 = "addi32"(%0, %0) : (i32, i32) -> i32
%2 = affine_apply #map0(%i1)
%3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
%4 = affine_apply #map1(%i1)
%5 = "addi32"(%4, %4) : (affineint, affineint) -> i32
%6 = affine_apply #map2(%i1)
%7 = "addi32"(%6, %6) : (affineint, affineint) -> i32
}
for %i2 = 16 to 17 {
%8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
%9 = "addi32"(%8, %8) : (i32, i32) -> i32
}
}
return
}
PiperOrigin-RevId: 209676220
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 0f6d428..a4a11a7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -31,30 +31,49 @@
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
+using namespace llvm;
+
+/// Unroll factor applied to every loop unrolled by a factor. When given on the
+/// command line it overrides the factor passed to createLoopUnrollPass().
+static cl::opt<unsigned>
+    clUnrollFactor("unroll-factor",
+                   cl::desc("Use this unroll factor for all loops"),
+                   cl::Hidden);
+
+/// Requests full unrolling instead of unrolling by a factor.
+static cl::opt<bool> clUnrollFull("unroll-full",
+                                  cl::desc("Fully unroll loops"),
+                                  cl::Hidden);
+
+/// Only honored together with -unroll-full: restricts full unrolling to loops
+/// whose trip count is at most this value.
+static cl::opt<unsigned> clUnrollFullThreshold(
+    "unroll-full-threshold",
+    cl::desc("Unroll all loops with trip count less than or equal to this"),
+    cl::Hidden);
namespace {
-/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
-/// MLFunction.
+/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
+/// full unroll threshold were specified, in which case it fully unrolls all
+/// loops with trip count less than or equal to the threshold. The latter is for
+/// testing purposes, especially for testing outer loop unrolling.
struct LoopUnroll : public MLFunctionPass {
- void runOnMLFunction(MLFunction *f) override;
- void runOnForStmt(ForStmt *forStmt);
-};
+ Optional<unsigned> unrollFactor; // Ignored if -unroll-factor is given; None means 4.
+ Optional<bool> unrollFull; // If true, loops are fully unrolled instead.
-/// Unrolls all loops with trip count <= minTripCount.
-struct ShortLoopUnroll : public LoopUnroll {
- const unsigned minTripCount;
+ explicit LoopUnroll(Optional<unsigned> unrollFactor,
+ Optional<bool> unrollFull)
+ : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+
void runOnMLFunction(MLFunction *f) override;
- ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+ /// Unrolls 'forStmt' fully or by a factor. Returns false if nothing was done.
+ bool runOnForStmt(ForStmt *forStmt);
+ bool loopUnrollFull(ForStmt *forStmt);
+ bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
};
} // end anonymous namespace
-MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-
-MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
- return new ShortLoopUnroll(minTripCount);
+/// Creates a LoopUnroll pass. An argument equal to -1 means "not specified" and
+/// is forwarded as None; any other value is forwarded as given (for
+/// 'unrollFull', any non-zero value means true).
+MLFunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
+  Optional<unsigned> factor;
+  if (unrollFactor != -1)
+    factor = static_cast<unsigned>(unrollFactor);
+
+  Optional<bool> full;
+  if (unrollFull != -1)
+    full = (unrollFull != 0);
+
+  return new LoopUnroll(factor, full);
}
void LoopUnroll::runOnMLFunction(MLFunction *f) {
@@ -81,7 +100,6 @@
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
-
return true;
}
@@ -101,14 +119,6 @@
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
- InnermostLoopGatherer ilg;
- ilg.walkPostOrder(f);
- auto &loops = ilg.loops;
- for (auto *forStmt : loops)
- runOnForStmt(forStmt);
-}
-
-void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
public:
@@ -120,27 +130,55 @@
void visitForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
- auto step = forStmt->getStep()->getValue();
+ auto step = forStmt->getStep();
if ((ub - lb) / step + 1 <= minTripCount)
loops.push_back(forStmt);
}
};
- ShortLoopGatherer slg(minTripCount);
- // Do a post order walk so that loops are gathered from innermost to
- // outermost (or else unrolling an outer one may delete gathered inner ones).
- slg.walkPostOrder(f);
- auto &loops = slg.loops;
+ if (clUnrollFull.getNumOccurrences() > 0 &&
+ clUnrollFullThreshold.getNumOccurrences() > 0) {
+ ShortLoopGatherer slg(clUnrollFullThreshold);
+ // Do a post order walk so that loops are gathered from innermost to
+ // outermost (or else unrolling an outer one may delete gathered inner
+ // ones).
+ slg.walkPostOrder(f);
+ auto &loops = slg.loops;
+ for (auto *forStmt : loops)
+ loopUnrollFull(forStmt);
+ return;
+ }
+
+ InnermostLoopGatherer ilg;
+ ilg.walkPostOrder(f);
+ auto &loops = ilg.loops;
for (auto *forStmt : loops)
runOnForStmt(forStmt);
}
-/// Unroll this For loop completely.
-void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+/// Unrolls 'forStmt' according to the pass configuration. Returns false if
+/// nothing was done.
+///
+/// Full unrolling is used when -unroll-full is set to true on the command line
+/// or the pass was created with 'unrollFull' set to true. Otherwise the loop is
+/// unrolled by the -unroll-factor command-line value if one was given, else by
+/// the factor the pass was created with, else by a default factor of 4.
+bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+  // Honor the value of -unroll-full, not merely its presence, so that
+  // '-unroll-full=false' does not force full unrolling.
+  bool cmdLineFull = clUnrollFull.getNumOccurrences() > 0 && clUnrollFull;
+  if (cmdLineFull || (unrollFull.hasValue() && unrollFull.getValue()))
+    return loopUnrollFull(forStmt);
+
+  // An explicit command-line factor takes precedence over the one the pass
+  // was constructed with.
+  if (clUnrollFactor.getNumOccurrences() > 0)
+    return loopUnrollByFactor(forStmt, clUnrollFactor);
+  if (unrollFactor.hasValue())
+    return loopUnrollByFactor(forStmt, unrollFactor.getValue());
+
+  // Unroll by four otherwise.
+  return loopUnrollByFactor(forStmt, 4);
+}
+
+// Unrolls this loop completely.
+bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
- auto step = forStmt->getStep()->getValue();
+ auto step = forStmt->getStep();
// Builder to add constants need for the unrolled iterator.
auto *mlFunc = forStmt->findFunction();
@@ -164,9 +202,75 @@
// Clone the body of the loop.
for (auto &childStmt : *forStmt) {
- (void)builder.clone(childStmt, operandMapping);
+ builder.clone(childStmt, operandMapping);
}
}
// Erase the original 'for' stmt from the block.
forStmt->eraseFromBlock();
+ return true;
+}
+
+/// Unrolls 'forStmt' in place by 'unrollFactor'. Returns false if nothing was
+/// done.
+///
+/// The loop's step is multiplied by 'unrollFactor', and its body is followed by
+/// unrollFactor - 1 copies of itself. In copy k (k = 1 .. unrollFactor - 1),
+/// uses of the induction variable are remapped to 'iv + k * step' through an
+/// AffineApplyOp. If the trip count is not a multiple of 'unrollFactor', a
+/// cleanup loop covering the leftover iterations is inserted right after
+/// 'forStmt'. Loops with an empty body, or with a trip count smaller than
+/// 'unrollFactor', are left untouched.
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+  assert(unrollFactor >= 1 && "unroll factor should be >= 1");
+
+  // Also bail out on 0, so that release builds (where the assert is compiled
+  // out) never evaluate 'tripCount % 0' below.
+  if (unrollFactor <= 1 || forStmt->getStatements().empty())
+    return false;
+
+  auto lb = forStmt->getLowerBound()->getValue();
+  auto ub = forStmt->getUpperBound()->getValue();
+  auto step = forStmt->getStep();
+
+  // Bounds are inclusive. Integer arithmetic keeps the trip count exact even
+  // for bounds too large to be represented exactly as a float.
+  int64_t tripCount =
+      ub < lb ? 0 : (ub - lb) / static_cast<int64_t>(step) + 1;
+
+  // Too few iterations to form even one unrolled body: nothing is changed.
+  // TODO(bondhugula): option to specify cleanup loop unrolling.
+  if (tripCount < unrollFactor)
+    return false;
+
+  // Iterations executed by the unrolled loop; the remaining
+  // 'tripCount % unrollFactor' iterations are left to the cleanup loop.
+  int64_t unrolledTripCount = tripCount - tripCount % unrollFactor;
+
+  // Generate the cleanup loop if the trip count isn't a multiple of
+  // unrollFactor. It is cloned before the step and upper bound of 'forStmt'
+  // are changed below, and starts at the first iteration not covered by the
+  // unrolled loop.
+  if (unrolledTripCount != tripCount) {
+    DenseMap<const MLValue *, MLValue *> operandMap;
+    MLFuncBuilder cleanupBuilder(forStmt->getBlock(),
+                                 ++StmtBlock::iterator(forStmt));
+    auto *cleanupForStmt =
+        cast<ForStmt>(cleanupBuilder.clone(*forStmt, operandMap));
+    cleanupForStmt->setLowerBound(
+        cleanupBuilder.getConstantExpr(lb + unrolledTripCount * step));
+  }
+
+  // Builder to insert unrolled bodies right after the last statement in the
+  // body of 'forStmt'.
+  MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+  forStmt->setStep(step * unrollFactor);
+  forStmt->setUpperBound(
+      builder.getConstantExpr(lb + (unrolledTripCount - 1) * step));
+
+  // Remember the last statement of the original body. Copies are appended
+  // after it, so the original body is exactly [begin, srcBlockEnd].
+  StmtBlock::iterator srcBlockEnd = --forStmt->end();
+
+  // Append unrollFactor - 1 additional copies of the original body.
+  for (unsigned i = 1; i < unrollFactor; ++i) {
+    DenseMap<const MLValue *, MLValue *> operandMapping;
+
+    // If the induction variable is used, remap it in this copy to
+    // iv + i * step, computed with an affine_apply.
+    if (!forStmt->use_empty()) {
+      auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0),
+                                          builder.getConstantExpr(i * step));
+      auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
+      auto *ivUnroll =
+          builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
+      operandMapping[forStmt] = cast<MLValue>(ivUnroll);
+    }
+
+    // Clone the original body. The range is closed ('srcBlockEnd' included)
+    // because copies made in earlier iterations follow 'srcBlockEnd'.
+    for (auto it = forStmt->begin(); it != srcBlockEnd; ++it)
+      builder.clone(*it, operandMapping);
+    builder.clone(*srcBlockEnd, operandMapping);
+  }
+  return true;
+}