Extend loop unrolling to unroll by a given factor; add builder for affine
apply op.
- add builder for AffineApplyOp (first one for an operation that has
non-zero operands)
- add support for loop unrolling by a given factor; uses the affine apply op
builder.
While on this, change 'step' of ForStmt to be 'unsigned' instead of
AffineConstantExpr *. Add setters for ForStmt lb, ub, step.
Sample Input:
// CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() {
mlfunc @loop_nest_unroll_cleanup() {
for %i = 1 to 100 {
for %j = 0 to 17 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
%y = "addi32"(%x, %x) : (i32, i32) -> i32
}
}
return
}
Output:
$ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir
#map0 = (d0) -> (d0 + 1)
#map1 = (d0) -> (d0 + 2)
#map2 = (d0) -> (d0 + 3)
mlfunc @loop_nest_unroll_cleanup() {
for %i0 = 1 to 100 {
for %i1 = 0 to 17 step 4 {
%0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32
%1 = "addi32"(%0, %0) : (i32, i32) -> i32
%2 = affine_apply #map0(%i1)
%3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
%4 = affine_apply #map1(%i1)
%5 = "addi32"(%4, %4) : (affineint, affineint) -> i32
%6 = affine_apply #map2(%i1)
%7 = "addi32"(%6, %6) : (affineint, affineint) -> i32
}
for %i2 = 16 to 17 {
%8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
%9 = "addi32"(%8, %8) : (i32, i32) -> i32
}
}
return
}
PiperOrigin-RevId: 209676220
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 62c0dcb..e8d3894 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -252,7 +252,8 @@
this->insertPoint = insertPoint;
}
- /// Set the insertion point to the specified operation.
+ /// Set the insertion point to the specified operation, which will cause
+ /// subsequent insertions to go right before it.
void setInsertionPoint(Statement *stmt) {
setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
}
@@ -298,8 +299,7 @@
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
- AffineConstantExpr *upperBound,
- AffineConstantExpr *step = nullptr);
+ AffineConstantExpr *upperBound, int64_t step = 1);
IfStmt *createIf(IntegerSet *condition) {
auto *stmt = new IfStmt(condition);
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index fed0d4e..af01a1d 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -47,6 +47,7 @@
struct OperationState {
Identifier name;
SmallVector<SSAValue *, 4> operands;
+ /// Types of the results of this operation.
SmallVector<Type *, 4> types;
SmallVector<NamedAttribute, 4> attributes;
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 3b615cc..63921ca 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -77,6 +77,10 @@
class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
+ /// Builds an affine apply op with the specified map and operands.
+ static OperationState build(Builder *builder, AffineMap *map,
+ ArrayRef<SSAValue *> operands);
+
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
return getAttrOfType<AffineMapAttr>("map")->getValue();
@@ -163,6 +167,7 @@
///
class ConstantFloatOp : public ConstantOp {
public:
+ /// Builds a constant float op producing a float of the specified type.
static OperationState build(Builder *builder, double value, FloatType *type);
double getValue() const {
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index f851d9f..1a68aab 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -199,7 +199,7 @@
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
explicit ForStmt(AffineConstantExpr *lowerBound,
- AffineConstantExpr *upperBound, AffineConstantExpr *step,
+ AffineConstantExpr *upperBound, int64_t step,
MLIRContext *context);
~ForStmt() {
@@ -216,7 +216,11 @@
AffineConstantExpr *getLowerBound() const { return lowerBound; }
AffineConstantExpr *getUpperBound() const { return upperBound; }
- AffineConstantExpr *getStep() const { return step; }
+ int64_t getStep() const { return step; }
+
+ void setLowerBound(AffineConstantExpr *lb) { lowerBound = lb; }
+ void setUpperBound(AffineConstantExpr *ub) { upperBound = ub; }
+ void setStep(unsigned s) { step = s; }
using Statement::dump;
using Statement::print;
@@ -242,7 +246,7 @@
// an affinemap and its operands as AffineBound.
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;
- AffineConstantExpr *step;
+ int64_t step;
};
/// An if clause represents statements contained within a then or an else clause
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index 3944c13..3f983d8 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -28,9 +28,9 @@
class MLFunctionPass;
class ModulePass;
-/// Loop unrolling passes.
-MLFunctionPass *createLoopUnrollPass();
-MLFunctionPass *createLoopUnrollPass(unsigned);
+// Loop unrolling passes.
+/// Creates a loop unrolling pass.
+MLFunctionPass *createLoopUnrollPass(int unrollFactor, int unrollFull);
/// Replaces all ML functions in the module with equivalent CFG functions.
/// Function references are appropriately patched to refer to the newly
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 8303f57..6bcaad4 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -1316,8 +1316,8 @@
printOperand(stmt);
os << " = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
- if (stmt->getStep()->getValue() != 1)
- os << " step " << *stmt->getStep();
+ if (stmt->getStep() != 1)
+ os << " step " << stmt->getStep();
os << " {\n";
print(static_cast<const StmtBlock *>(stmt));
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 5fce94c..b1dce25 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -103,8 +103,8 @@
return ArrayAttr::get(value, context);
}
-AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
- return AffineMapAttr::get(value, context);
+AffineMapAttr *Builder::getAffineMapAttr(AffineMap *map) {
+ return AffineMapAttr::get(map, context);
}
TypeAttr *Builder::getTypeAttr(Type *type) {
@@ -207,9 +207,7 @@
ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
- AffineConstantExpr *step) {
- if (!step)
- step = getConstantExpr(1);
+ int64_t step) {
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
block->getStatements().insert(insertPoint, stmt);
return stmt;
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 6d8c366..cec23e0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -305,6 +305,22 @@
}
//===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
+ ArrayRef<SSAValue *> operands) {
+ SmallVector<Type *, 4> resultTypes(map->getNumResults(),
+ builder->getAffineIntType());
+
+ OperationState result(
+ builder->getIdentifier("affine_apply"), operands, resultTypes,
+ {{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
+
+ return result;
+}
+
+//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index d20b45f..7da08c2 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -198,7 +198,7 @@
//===----------------------------------------------------------------------===//
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
- AffineConstantExpr *step, MLIRContext *context)
+ int64_t step, MLIRContext *context)
: Statement(Kind::For),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 174d6ca..0debe2d 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2164,7 +2164,7 @@
ParseResult MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
- // Parse induction variable
+ // Parse induction variable.
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier for the loop variable");
@@ -2175,7 +2175,7 @@
if (parseToken(Token::equal, "expected '='"))
return ParseFailure;
- // Parse loop bounds
+ // Parse loop bounds.
AffineConstantExpr *lowerBound = parseIntConstant();
if (!lowerBound)
return ParseFailure;
@@ -2187,12 +2187,13 @@
if (!upperBound)
return ParseFailure;
- // Parse step
- AffineConstantExpr *step = nullptr;
+ // Parse step.
+ int64_t step = 1;
if (consumeIf(Token::kw_step)) {
- step = parseIntConstant();
- if (!step)
+ AffineConstantExpr *stepExpr = parseIntConstant();
+ if (!stepExpr)
return ParseFailure;
+ step = stepExpr->getValue();
}
// Create for statement.
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 0f6d428..a4a11a7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -31,30 +31,49 @@
#include "mlir/Transforms/Pass.h"
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
+using namespace llvm;
+
+// Loop unrolling factor.
+static llvm::cl::opt<unsigned>
+ clUnrollFactor("unroll-factor", cl::Hidden,
+ cl::desc("Use this unroll factor for all loops"));
+
+static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
+ cl::desc("Fully unroll loops"));
+
+static llvm::cl::opt<unsigned> clUnrollFullThreshold(
+ "unroll-full-threshold", cl::Hidden,
+ cl::desc("Unroll all loops with trip count less than or equal to this"));
namespace {
-/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
-/// MLFunction.
+/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
+/// full unroll threshold was specified, in which case, fully unrolls all loops
+/// with trip count less than the specified threshold. The latter is for testing
+/// purposes, especially for testing outer loop unrolling.
struct LoopUnroll : public MLFunctionPass {
- void runOnMLFunction(MLFunction *f) override;
- void runOnForStmt(ForStmt *forStmt);
-};
+ Optional<unsigned> unrollFactor;
+ Optional<bool> unrollFull;
-/// Unrolls all loops with trip count <= minTripCount.
-struct ShortLoopUnroll : public LoopUnroll {
- const unsigned minTripCount;
+ explicit LoopUnroll(Optional<unsigned> unrollFactor,
+ Optional<bool> unrollFull)
+ : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+
void runOnMLFunction(MLFunction *f) override;
- ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+ /// Unroll this for stmt. Returns false if nothing was done.
+ bool runOnForStmt(ForStmt *forStmt);
+ bool loopUnrollFull(ForStmt *forStmt);
+ bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
};
} // end anonymous namespace
-MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-
-MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
- return new ShortLoopUnroll(minTripCount);
+MLFunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
+ return new LoopUnroll(unrollFactor == -1 ? None
+ : Optional<unsigned>(unrollFactor),
+ unrollFull == -1 ? None : Optional<bool>(unrollFull));
}
void LoopUnroll::runOnMLFunction(MLFunction *f) {
@@ -81,7 +100,6 @@
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
-
return true;
}
@@ -101,14 +119,6 @@
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
- InnermostLoopGatherer ilg;
- ilg.walkPostOrder(f);
- auto &loops = ilg.loops;
- for (auto *forStmt : loops)
- runOnForStmt(forStmt);
-}
-
-void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all loops with trip count <= minTripCount.
class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
public:
@@ -120,27 +130,55 @@
void visitForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
- auto step = forStmt->getStep()->getValue();
+ auto step = forStmt->getStep();
if ((ub - lb) / step + 1 <= minTripCount)
loops.push_back(forStmt);
}
};
- ShortLoopGatherer slg(minTripCount);
- // Do a post order walk so that loops are gathered from innermost to
- // outermost (or else unrolling an outer one may delete gathered inner ones).
- slg.walkPostOrder(f);
- auto &loops = slg.loops;
+ if (clUnrollFull.getNumOccurrences() > 0 &&
+ clUnrollFullThreshold.getNumOccurrences() > 0) {
+ ShortLoopGatherer slg(clUnrollFullThreshold);
+ // Do a post order walk so that loops are gathered from innermost to
+ // outermost (or else unrolling an outer one may delete gathered inner
+ // ones).
+ slg.walkPostOrder(f);
+ auto &loops = slg.loops;
+ for (auto *forStmt : loops)
+ loopUnrollFull(forStmt);
+ return;
+ }
+
+ InnermostLoopGatherer ilg;
+ ilg.walkPostOrder(f);
+ auto &loops = ilg.loops;
for (auto *forStmt : loops)
runOnForStmt(forStmt);
}
-/// Unroll this For loop completely.
-void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+/// Unroll a for stmt. Default unroll factor is 4.
+bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+ // Unroll completely if full loop unroll was specified.
+ if (clUnrollFull.getNumOccurrences() > 0 ||
+ (unrollFull.hasValue() && unrollFull.getValue()))
+ return loopUnrollFull(forStmt);
+
+ // Unroll by the specified factor if one was specified.
+ if (clUnrollFactor.getNumOccurrences() > 0)
+ return loopUnrollByFactor(forStmt, clUnrollFactor);
+ else if (unrollFactor.hasValue())
+ return loopUnrollByFactor(forStmt, unrollFactor.getValue());
+
+ // Unroll by four otherwise.
+ return loopUnrollByFactor(forStmt, 4);
+}
+
+// Unrolls this loop completely.
+bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
- auto step = forStmt->getStep()->getValue();
+ auto step = forStmt->getStep();
// Builder to add constants need for the unrolled iterator.
auto *mlFunc = forStmt->findFunction();
@@ -164,9 +202,75 @@
// Clone the body of the loop.
for (auto &childStmt : *forStmt) {
- (void)builder.clone(childStmt, operandMapping);
+ builder.clone(childStmt, operandMapping);
}
}
// Erase the original 'for' stmt from the block.
forStmt->eraseFromBlock();
+ return true;
+}
+
+/// Unrolls this loop by the specified unroll factor.
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+ assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
+
+ if (unrollFactor == 1 || forStmt->getStatements().empty())
+ return false;
+
+ auto lb = forStmt->getLowerBound()->getValue();
+ auto ub = forStmt->getUpperBound()->getValue();
+ auto step = forStmt->getStep();
+
+ int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
+
+ // If the trip count is lower than the unroll factor, no unrolled body.
+ // TODO(bondhugula): option to specify cleanup loop unrolling.
+ if (tripCount < unrollFactor)
+ return true;
+
+ // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
+ if (tripCount % unrollFactor) {
+ DenseMap<const MLValue *, MLValue *> operandMap;
+ MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
+ auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
+ cleanupForStmt->setLowerBound(builder.getConstantExpr(
+ lb + (tripCount - tripCount % unrollFactor) * step));
+ }
+
+ // Builder to insert unrolled bodies right after the last statement in the
+ // body of 'forStmt'.
+ MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+ forStmt->setStep(step * unrollFactor);
+ forStmt->setUpperBound(builder.getConstantExpr(
+ lb + (tripCount - tripCount % unrollFactor - 1) * step));
+
+ // Keep a pointer to the last statement in the original block so that we know
+ // what to clone (since we are doing this in-place).
+ StmtBlock::iterator srcBlockEnd = --forStmt->end();
+
+ // Unroll the contents of 'forStmt' (unrollFactor-1 additional copies
+ // appended).
+ for (unsigned i = 1; i < unrollFactor; i++) {
+ DenseMap<const MLValue *, MLValue *> operandMapping;
+
+ // If the induction variable is used, create a remapping to the value for
+ // this unrolled instance.
+ if (!forStmt->use_empty()) {
+ // iv' = iv + 1/2/3...unrollFactor-1;
+ auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0),
+ builder.getConstantExpr(i * step));
+ auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
+ auto *ivUnroll =
+ builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
+ operandMapping[forStmt] = cast<MLValue>(ivUnroll);
+ }
+
+ // Clone the original body of the loop (this doesn't include the last stmt).
+ for (auto it = forStmt->begin(); it != srcBlockEnd; it++) {
+ builder.clone(*it, operandMapping);
+ }
+ // Clone the last statement in the original body.
+ builder.clone(*srcBlockEnd, operandMapping);
+ }
+ return true;
}
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 0f9f208..71536c1 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,5 +1,7 @@
-// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
-// RUN: %S/../../mlir-opt %s -o - -unroll-short-loops | FileCheck %s --check-prefix SHORT
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-full | FileCheck %s
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=3 | FileCheck %s --check-prefix UNROLL-BY-3
// CHECK: #map0 = (d0) -> (d0 + 1)
@@ -279,3 +281,87 @@
%ret = load %C[%zero_idx, %zero_idx] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2>
return %ret : i32
}
+
+// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_no_cleanup() {
+mlfunc @unroll_unit_stride_no_cleanup() {
+ // UNROLL-BY-4: for %i0 = 1 to 100 {
+ for %i = 1 to 100 {
+ // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
+ // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: }
+ for %j = 1 to 8 {
+ %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+ %y = "addi32"(%x, %x) : (i32, i32) -> i32
+ }
+ // empty loop
+ // UNROLL-BY-4: for %i2 = 1 to 8 {
+ for %k = 1 to 8 {
+ }
+ }
+ return
+}
+
+// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_cleanup() {
+mlfunc @unroll_unit_stride_cleanup() {
+ // UNROLL-BY-4: for %i0 = 1 to 100 {
+ for %i = 1 to 100 {
+ // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
+ // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: }
+ // UNROLL-BY-4-NEXT: for [[L2:%i[0-9]+]] = 9 to 10 {
+ // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
+ // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32
+ // UNROLL-BY-4-NEXT: }
+ for %j = 1 to 10 {
+ %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+ %y = "addi32"(%x, %x) : (i32, i32) -> i32
+ }
+ }
+ return
+}
+
+// UNROLL-BY-3-LABEL: mlfunc @unroll_non_unit_stride_cleanup() {
+mlfunc @unroll_non_unit_stride_cleanup() {
+ // UNROLL-BY-3: for %i0 = 1 to 100 {
+ for %i = 1 to 100 {
+ // UNROLL-BY-3: for [[L1:%i[0-9]+]] = 2 to 12 step 15 {
+ // UNROLL-BY-3-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+ // UNROLL-BY-3-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+ // UNROLL-BY-3-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-3-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+ // UNROLL-BY-3-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+ // UNROLL-BY-3-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+ // UNROLL-BY-3-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+ // UNROLL-BY-3-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+ // UNROLL-BY-3-NEXT: }
+ // UNROLL-BY-3-NEXT: for [[L2:%i[0-9]+]] = 17 to 20 step 5 {
+ // UNROLL-BY-3-NEXT: %8 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
+ // UNROLL-BY-3-NEXT: %9 = "addi32"(%8, %8) : (i32, i32) -> i32
+ // UNROLL-BY-3-NEXT: }
+ for %j = 2 to 20 step 5 {
+ %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+ %y = "addi32"(%x, %x) : (i32, i32) -> i32
+ }
+ }
+ return
+}
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 2c6b76b..252e1f2 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -53,8 +53,7 @@
enum Passes {
ConvertToCFG,
- UnrollInnermostLoops,
- UnrollShortLoops,
+ LoopUnroll,
TFRaiseControlFlow,
};
@@ -62,10 +61,7 @@
"", cl::desc("Compiler passes to run"),
cl::values(clEnumValN(ConvertToCFG, "convert-to-cfg",
"Convert all ML functions in the module to CFG ones"),
- clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
- "Unroll innermost loops"),
- clEnumValN(UnrollShortLoops, "unroll-short-loops",
- "Unroll loops of trip count <= 2"),
+ clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
"Dynamic TensorFlow Switch/Match nodes to a CFG")));
@@ -112,11 +108,8 @@
case ConvertToCFG:
pass = createConvertToCFGPass();
break;
- case UnrollInnermostLoops:
- pass = createLoopUnrollPass();
- break;
- case UnrollShortLoops:
- pass = createLoopUnrollPass(2);
+ case LoopUnroll:
+ pass = createLoopUnrollPass(-1, -1);
break;
case TFRaiseControlFlow:
pass = createRaiseTFControlFlowPass();