Loop unrolling pass update
- fix and complete ForStmt cloning so that unrolling works for outer loops
- create IV constants only when the loop IV is actually used
- test outer loop unrolling by adding a short-trip-count unroll pass that
unrolls loops with trip count <= <parameter>
- add unrolling test cases for ops with multiple results and for outer loop unrolling
- fix/clean up the StmtWalker class while at it
- switch the constants that replace unrolled loop IVs from i32 to affineint
PiperOrigin-RevId: 207645967
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index b6cc934..59aca69 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -265,39 +265,16 @@
return op;
}
- OperationStmt *clone(const OperationStmt &srcOpStmt) {
- auto *op = srcOpStmt.clone();
- block->getStatements().insert(insertPoint, op);
- return op;
- }
-
// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
return OpTy::build(this, args...);
}
- ForStmt *clone(const ForStmt &srcForStmt) {
- auto *forStmt = srcForStmt.clone();
- block->getStatements().insert(insertPoint, forStmt);
- return forStmt;
- }
-
- IfStmt *clone(const IfStmt &srcIfStmt) {
- auto *ifStmt = srcIfStmt.clone();
- block->getStatements().insert(insertPoint, ifStmt);
- return ifStmt;
- }
-
Statement *clone(const Statement &stmt) {
- switch (stmt.getKind()) {
- case Statement::Kind::Operation:
- return clone(cast<const OperationStmt>(stmt));
- case Statement::Kind::If:
- return clone(cast<const IfStmt>(stmt));
- case Statement::Kind::For:
- return clone(cast<const ForStmt>(stmt));
- }
+ Statement *cloneStmt = stmt.clone();
+ block->getStatements().insert(insertPoint, cloneStmt);
+ return cloneStmt;
}
// Creates for statement. When step is not specified, it is set to 1.
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 8cfe432..f8701bf 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -158,12 +158,13 @@
};
/// This is a refinement of the "constant" op for the case where it is
-/// returning an integer value (either an IntegerType or AffineInt).
+/// returning an integer value of IntegerType.
///
/// %1 = "constant"(){value: 42}
///
class ConstantIntOp : public ConstantOp {
public:
+ /// Build a constant int op producing an integer of the specified width.
template <class Builder>
static OpPointer<ConstantIntOp> build(Builder *builder, int64_t value,
unsigned width) {
@@ -186,6 +187,36 @@
explicit ConstantIntOp(const Operation *state) : ConstantOp(state) {}
};
+/// This is a refinement of the "constant" op for the case where it is
+/// returning an integer value of AffineInt type.
+///
+/// %1 = "constant"(){value: 99} : () -> affineint
+///
+class ConstantAffineIntOp : public ConstantOp {
+public:
+ /// Build a constant int op producing an affineint.
+ template <class Builder>
+ static OpPointer<ConstantAffineIntOp> build(Builder *builder, int64_t value) {
+ std::pair<Identifier, Attribute *> namedAttr(
+ builder->getIdentifier("value"), builder->getIntegerAttr(value));
+ auto *type = builder->getAffineIntType();
+
+ return OpPointer<ConstantAffineIntOp>(
+ ConstantAffineIntOp(builder->createOperation(
+ builder->getIdentifier("constant"), {}, type, {namedAttr})));
+ }
+
+ int64_t getValue() const {
+ return getAttrOfType<IntegerAttr>("value")->getValue();
+ }
+
+ static bool isClassFor(const Operation *op);
+
+private:
+ friend class Operation;
+ explicit ConstantAffineIntOp(const Operation *state) : ConstantOp(state) {}
+};
+
/// The "dim" operation takes a memref or tensor operand and returns an
/// "affineint". It requires a single integer attribute named "index". It
/// returns the size of the specified dimension. For example:
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index 2326d50..43bbd2c 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -31,6 +31,7 @@
class StmtBlock;
class ForStmt;
class MLIRContext;
+class MLValue;
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
@@ -72,6 +73,9 @@
void print(raw_ostream &os) const;
void dump() const;
+ /// Replace all uses of 'oldVal' with 'newVal' in this statement.
+ void replaceUses(MLValue *oldVal, MLValue *newVal);
+
protected:
Statement(Kind kind) : kind(kind) {}
// Statements are deleted through the destroy() member because this class
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index a8a1f20..50463f0 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -24,11 +24,11 @@
#include "mlir/IR/Statement.h"
#include "mlir/Support/LLVM.h"
-#include "llvm/Support/raw_ostream.h"
namespace mlir {
class MLFunction;
class IfStmt;
+class MLValue;
/// Statement block represents an ordered list of statements, with the order
/// being the contiguous lexical order in which the statements appear as
diff --git a/include/mlir/IR/StmtVisitor.h b/include/mlir/IR/StmtVisitor.h
index 73f17e3..3c10743 100644
--- a/include/mlir/IR/StmtVisitor.h
+++ b/include/mlir/IR/StmtVisitor.h
@@ -123,19 +123,24 @@
walk(&(*Start++));
}
}
+ template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
+ while (Start != End) {
+ walkPostOrder(&(*Start++));
+ }
+ }
// Define walkers for MLFunction and all MLFunction statement kinds.
void walk(MLFunction *f) {
static_cast<SubClass *>(this)->visitMLFunction(f);
- static_cast<SubClass *>(this)->walk(f->begin(), f->end());
+ walk(f->begin(), f->end());
}
void walkPostOrder(MLFunction *f) {
- walk(f->begin(), f->end());
- return static_cast<SubClass *>(this)->visitMLFunction(f);
+ walkPostOrder(f->begin(), f->end());
+ static_cast<SubClass *>(this)->visitMLFunction(f);
}
- void walkOpStmt(OperationStmt *opStmt) {
+ RetTy walkOpStmt(OperationStmt *opStmt) {
return static_cast<SubClass *>(this)->visitOperationStmt(opStmt);
}
@@ -204,7 +209,8 @@
// When visiting a specific stmt directly during a walk, these methods get
// called. These are typically O(1) complexity and shouldn't be recursively
- // processing their descendants in some way.
+ // processing their descendants in some way. When using RetTy, all of these
+ // need to be overridden.
void visitMLFunction(MLFunction *f) {}
void visitForStmt(ForStmt *forStmt) {}
void visitIfStmt(IfStmt *ifStmt) {}
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index 7352ad0..3944c13 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -28,8 +28,9 @@
class MLFunctionPass;
class ModulePass;
-/// A loop unrolling pass.
+/// Loop unrolling passes.
MLFunctionPass *createLoopUnrollPass();
+MLFunctionPass *createLoopUnrollPass(unsigned);
/// Replaces all ML functions in the module with equivalent CFG functions.
/// Function references are appropriately patched to refer to the newly
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 0a76587..068e498 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -85,11 +85,35 @@
llvm_unreachable("cloning for if's not implemented yet");
return cast<IfStmt>(this)->clone();
case Kind::For:
- llvm_unreachable("cloning for loops not implemented yet");
return cast<ForStmt>(this)->clone();
}
}
+/// Replaces all uses of oldVal with newVal.
+// TODO(bondhugula,clattner): do this more efficiently by walking those uses of
+// oldVal that fall within this statement.
+void Statement::replaceUses(MLValue *oldVal, MLValue *newVal) {
+ struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
+ // Value to be replaced.
+ MLValue *oldVal;
+ // Value to be replaced with.
+ MLValue *newVal;
+
+ ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
+ : oldVal(oldVal), newVal(newVal){};
+
+ void visitOperationStmt(OperationStmt *os) {
+ for (auto &operand : os->getStmtOperands()) {
+ if (operand.get() == oldVal)
+ operand.set(newVal);
+ }
+ }
+ };
+
+ ReplaceUseWalker ri(oldVal, newVal);
+ ri.walk(this);
+}
+
//===----------------------------------------------------------------------===//
// ilist_traits for Statement
//===----------------------------------------------------------------------===//
@@ -233,12 +257,34 @@
upperBound(upperBound), step(step) {}
ForStmt *ForStmt::clone() const {
- auto *stmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
- Statement::findFunction()->getContext());
+ auto *forStmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
+ Statement::findFunction()->getContext());
+
+ // Pairs of <old op stmt result whose uses need to be replaced,
+ // new result generated by the corresponding cloned op stmt>.
+ SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
for (auto &s : getStatements()) {
- stmt->getStatements().push_back(s.clone());
+ auto *cloneStmt = s.clone();
+ forStmt->getStatements().push_back(cloneStmt);
+ if (auto *opStmt = dyn_cast<OperationStmt>(&s)) {
+ auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+ for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
+ oldNewResultPairs.push_back(
+ std::make_pair(const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
+ &cloneOpStmt->getStmtResult(i)));
+ }
+ }
}
- return stmt;
+ // Replace uses of the old op results with the newly created ones.
+ for (unsigned i = 0, e = oldNewResultPairs.size(); i < e; i++) {
+ for (auto &stmt : *forStmt) {
+ stmt.replaceUses(oldNewResultPairs[i].first, oldNewResultPairs[i].second);
+ }
+ }
+
+ // Replace uses of old loop IV with the new one.
+ forStmt->Statement::replaceUses(const_cast<ForStmt *>(this), forStmt);
+ return forStmt;
}
//===----------------------------------------------------------------------===//
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 27bb43f..eea3bf7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -39,14 +39,22 @@
void runOnMLFunction(MLFunction *f) override;
void runOnForStmt(ForStmt *forStmt);
};
+struct ShortLoopUnroll : public LoopUnroll {
+ const unsigned minTripCount;
+ void runOnMLFunction(MLFunction *f) override;
+ ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+};
} // end anonymous namespace
MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
+MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
+ return new ShortLoopUnroll(minTripCount);
+}
+
/// Unrolls all the innermost loops of this MLFunction.
void LoopUnroll::runOnMLFunction(MLFunction *f) {
// Gathers all innermost loops through a post order pruned walk.
- // TODO: figure out the right reusable template here to better refactor code.
class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
public:
// Store innermost loops as we walk.
@@ -65,11 +73,6 @@
return hasInnerLoops;
}
- // FIXME: can't use base class method for this because that in turn would
- // need to use the derived class method above. CRTP doesn't allow it, and
- // the compiler error resulting from it is also very misleading!
- void walkPostOrder(MLFunction *f) { walkPostOrder(f->begin(), f->end()); }
-
bool walkForStmtPostOrder(ForStmt *forStmt) {
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
if (!hasInnerLoops)
@@ -85,8 +88,11 @@
return hasInnerLoops;
}
- bool walkOpStmt(OperationStmt *opStmt) { return false; }
+ bool visitOperationStmt(OperationStmt *opStmt) { return false; }
+ // FIXME: can't use base class method for this because that in turn would
+ // need to use the derived class method above. CRTP doesn't allow it, and
+ // the compiler error resulting from it is also misleading.
using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
};
@@ -97,28 +103,96 @@
runOnForStmt(forStmt);
}
-/// Replace all uses of 'oldVal' with 'newVal' in 'stmt'
-static void replaceAllStmtUses(Statement *stmt, MLValue *oldVal,
- MLValue *newVal) {
- struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
- // Value to be replaced.
- MLValue *oldVal;
- // Value to be replaced with.
- MLValue *newVal;
+/// Unrolls all loops with trip count <= minTripCount.
+void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
+ // Gathers all loops with trip count <= minTripCount.
+ class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
+ public:
+ // Store short loops as we walk.
+ std::vector<ForStmt *> loops;
+ const unsigned minTripCount;
+ ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
- ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
- : oldVal(oldVal), newVal(newVal){};
+ void visitForStmt(ForStmt *forStmt) {
+ auto lb = forStmt->getLowerBound()->getValue();
+ auto ub = forStmt->getUpperBound()->getValue();
+ auto step = forStmt->getStep()->getValue();
- void visitOperationStmt(OperationStmt *os) {
- for (auto &operand : os->getStmtOperands()) {
- if (operand.get() == oldVal)
- operand.set(newVal);
- }
+ if ((ub - lb) / step + 1 <= minTripCount)
+ loops.push_back(forStmt);
}
};
- ReplaceUseWalker ri(oldVal, newVal);
- ri.walk(stmt);
+ ShortLoopGatherer slg(minTripCount);
+ slg.walk(f);
+ auto &loops = slg.loops;
+ for (auto *forStmt : loops)
+ runOnForStmt(forStmt);
+}
+
+/// Replace all uses of oldVal with newVal from begin to end.
+static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
+ MLValue *oldVal, MLValue *newVal) {
+ // TODO(bondhugula,clattner): do this more efficiently by walking those uses
+ // of oldVal that fall within this list of statements (instead of iterating
+ // through all statements / through all operands of operations found).
+ for (auto it = begin; it != end; it++) {
+ it->replaceUses(oldVal, newVal);
+ }
+}
+
+/// Replace all uses of oldVal with newVal.
+void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
+ // TODO(bondhugula,clattner): do this more efficiently by walking those uses
+ // of oldVal that fall within this StmtBlock (instead of iterating through
+ // all statements / through all operands of operations found).
+ for (auto it = block->begin(); it != block->end(); it++) {
+ it->replaceUses(oldVal, newVal);
+ }
+}
+
+/// Clone the list of statements from 'block' and insert the clones at the
+/// builder's current insertion point.
+// TODO(bondhugula,clattner): replace this with a parameterizable clone.
+void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
+ // Pairs of <old op stmt result whose uses need to be replaced,
+ // new result generated by the corresponding cloned op stmt>.
+ SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+
+ // Iterator pointing to just before the statements about to be cloned.
+ StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
+
+ for (auto &stmt : block.getStatements()) {
+ auto *cloneStmt = builder->clone(stmt);
+ // Whenever we have an op stmt, we'll have a new ML Value defined: replace
+ // uses of the old result with this one.
+ if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ if (opStmt->getNumResults()) {
+ auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+ for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
+ // Store old/new result pairs.
+ // TODO(bondhugula) *only* if needed later: storing of old/new
+ // results can be avoided by cloning the statement list in the
+ // reverse direction (and running the IR builder in reverse, using
+ // iplist.insertAfter()). That way, a newly created result can be
+ // immediately propagated to all its uses.
+ oldNewResultPairs.push_back(std::make_pair(
+ const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
+ &cloneOpStmt->getStmtResult(i)));
+ }
+ }
+ }
+ }
+
+ // Compute the range of statements that were just cloned.
+ StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+ StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
+
+ // Replace uses of the old op results with the newly created ones.
+ for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
+ replaceUses(startOfUnrolledBody, endOfUnrolledBody,
+ oldNewResultPairs[i].first, oldNewResultPairs[i].second);
+ }
}
/// Unroll this 'for stmt' / loop completely.
@@ -139,52 +213,25 @@
++StmtBlock::iterator(forStmt));
// Unroll the contents of 'forStmt'.
- for (int i = lb; i <= ub; i += step) {
- // TODO(bondhugula): generate constants only when IV actually appears.
- auto constOp = funcTopBuilder.create<ConstantIntOp>(i, 32);
- auto *ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
-
- // Iterator pointing to just before 'this' (i^th) unrolled iteration.
+ for (int64_t i = lb; i <= ub; i += step) {
+ MLValue *ivConst = nullptr;
+ if (!forStmt->use_empty()) {
+ auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
+ ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+ }
StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
- // Pairs of <old op stmt result whose uses need to be replaced,
- // new result generated by the corresponding cloned op stmt>.
- SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+ // Clone the loop body and insert it right after the loop - the latter will
+ // be erased after all unrolling has been done.
+ cloneStmtListFromBlock(&builder, *forStmt);
- for (auto &loopBodyStmt : forStmt->getStatements()) {
- auto *cloneStmt = builder.clone(loopBodyStmt);
- // Replace all uses of the IV in the clone with constant iteration value.
- replaceAllStmtUses(cloneStmt, forStmt, ivConst);
-
- // Whenever we have an op stmt, we'll have a new ML Value defined: replace
- // uses of the old result with this one.
- if (auto *opStmt = dyn_cast<OperationStmt>(&loopBodyStmt)) {
- if (opStmt->getNumResults()) {
- auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
- for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
- // Store old/new result pairs.
- // TODO *only* if needed later: storing of old/new results can be
- // avoided, by cloning the statement list in the reverse direction
- // (and running the IR builder in the reverse
- // (iplist.insertAfter()). That way, a newly created result can be
- // immediately propagated to all its uses, which would already been
- // cloned/inserted.
- oldNewResultPairs.push_back(std::make_pair(
- &opStmt->getStmtResult(i), &cloneOpStmt->getStmtResult(i)));
- }
- }
- }
- }
- // Replace uses of old op results' with the results in the just
- // unrolled body.
- StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
- for (auto it = ++beforeUnrolledBody; it != endOfUnrolledBody; it++) {
- for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
- replaceAllStmtUses(&(*it), oldNewResultPairs[i].first,
- oldNewResultPairs[i].second);
- }
+ // Replace uses of the loop IV in the just-cloned body with this iteration's constant.
+ if (ivConst) {
+ StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+ StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
+ replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
}
}
- // Erase the original for stmt from the block.
+ // Erase the original 'for' stmt from the block.
forStmt->eraseFromBlock();
}
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index f091ba1..42f1c2a 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,17 +1,14 @@
// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
+// RUN: %S/../../mlir-opt %s -o - -unroll-short-loops | FileCheck %s --check-prefix SHORT
// CHECK-LABEL: mlfunc @loop_nest_simplest() {
mlfunc @loop_nest_simplest() {
- // CHECK: %c1_i32 = constant 1 : i32
- // CHECK-NEXT: %c2_i32 = constant 2 : i32
- // CHECK-NEXT: %c3_i32 = constant 3 : i32
- // CHECK-NEXT: %c4_i32 = constant 4 : i32
- // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
+ // CHECK: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
- // CHECK: %c1_i32_0 = constant 1 : i32
+ // CHECK: %c1_i32 = constant 1 : i32
+ // CHECK-NEXT: %c1_i32_0 = constant 1 : i32
// CHECK-NEXT: %c1_i32_1 = constant 1 : i32
// CHECK-NEXT: %c1_i32_2 = constant 1 : i32
- // CHECK-NEXT: %c1_i32_3 = constant 1 : i32
for %j = 1 to 4 {
%x = constant 1 : i32
}
@@ -21,16 +18,16 @@
// CHECK-LABEL: mlfunc @loop_nest_simple_iv_use() {
mlfunc @loop_nest_simple_iv_use() {
- // CHECK: %c1_i32 = constant 1 : i32
- // CHECK-NEXT: %c2_i32 = constant 2 : i32
- // CHECK-NEXT: %c3_i32 = constant 3 : i32
- // CHECK-NEXT: %c4_i32 = constant 4 : i32
+ // CHECK: %c1 = constant 1 : affineint
+ // CHECK-NEXT: %c2 = constant 2 : affineint
+ // CHECK-NEXT: %c3 = constant 3 : affineint
+ // CHECK-NEXT: %c4 = constant 4 : affineint
// CHECK-NEXT: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
- // CHECK: %0 = "addi32"(%c1_i32, %c1_i32) : (i32, i32) -> i32
- // CHECK-NEXT: %1 = "addi32"(%c2_i32, %c2_i32) : (i32, i32) -> i32
- // CHECK-NEXT: %2 = "addi32"(%c3_i32, %c3_i32) : (i32, i32) -> i32
- // CHECK-NEXT: %3 = "addi32"(%c4_i32, %c4_i32) : (i32, i32) -> i32
+ // CHECK: %0 = "addi32"(%c1, %c1) : (affineint, affineint) -> i32
+ // CHECK-NEXT: %1 = "addi32"(%c2, %c2) : (affineint, affineint) -> i32
+ // CHECK-NEXT: %2 = "addi32"(%c3, %c3) : (affineint, affineint) -> i32
+ // CHECK-NEXT: %3 = "addi32"(%c4, %c4) : (affineint, affineint) -> i32
for %j = 1 to 4 {
%x = "addi32"(%j, %j) : (affineint, affineint) -> i32
}
@@ -38,29 +35,57 @@
return // CHECK: return
} // CHECK }
+// Operations in the loop body have results that are used therein.
+// CHECK-LABEL: mlfunc @loop_nest_body_def_use() {
+mlfunc @loop_nest_body_def_use() {
+ // CHECK: %c0 = constant 0 : affineint
+ // CHECK-NEXT: %c1 = constant 1 : affineint
+ // CHECK-NEXT: %c2 = constant 2 : affineint
+ // CHECK-NEXT: %c3 = constant 3 : affineint
+ // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
+ for %i = 1 to 100 step 2 {
+ // CHECK: %c0_0 = constant 0 : affineint
+ %c0 = constant 0 : affineint
+ // CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c0)
+ // CHECK-NEXT: %1 = "addi32"(%0, %c0_0) : (affineint, affineint) -> affineint
+ // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c1)
+ // CHECK-NEXT: %3 = "addi32"(%2, %c0_0) : (affineint, affineint) -> affineint
+ // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2)
+ // CHECK-NEXT: %5 = "addi32"(%4, %c0_0) : (affineint, affineint) -> affineint
+ // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c3)
+ // CHECK-NEXT: %7 = "addi32"(%6, %c0_0) : (affineint, affineint) -> affineint
+ for %j = 0 to 3 {
+ %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
+ (affineint) -> (affineint)
+ %y = "addi32"(%x, %c0) : (affineint, affineint) -> affineint
+ }
+ } // CHECK: }
+ return // CHECK: return
+} // CHECK }
+
// CHECK-LABEL: mlfunc @loop_nest_strided() {
mlfunc @loop_nest_strided() {
- // CHECK: %c3_i32 = constant 3 : i32
- // CHECK-NEXT: %c5_i32 = constant 5 : i32
- // CHECK-NEXT: %c7_i32 = constant 7 : i32
- // CHECK-NEXT: %c3_i32_0 = constant 3 : i32
- // CHECK-NEXT: %c5_i32_1 = constant 5 : i32
+ // CHECK: %c3 = constant 3 : affineint
+ // CHECK-NEXT: %c5 = constant 5 : affineint
+ // CHECK-NEXT: %c7 = constant 7 : affineint
+ // CHECK-NEXT: %c3_0 = constant 3 : affineint
+ // CHECK-NEXT: %c5_1 = constant 5 : affineint
// CHECK-NEXT: for %i0 = 1 to 100 {
for %i = 1 to 100 {
- // CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c3_i32_0)
+ // CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c3_0)
// CHECK-NEXT: %1 = "addi32"(%0, %0) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c5_i32_1)
+ // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c5_1)
// CHECK-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> affineint
for %j = 3 to 6 step 2 {
%x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
(affineint) -> (affineint)
%y = "addi32"(%x, %x) : (affineint, affineint) -> affineint
}
- // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
+ // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c3)
// CHECK-NEXT: %5 = "addi32"(%4, %4) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c5_i32)
+ // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c5)
// CHECK-NEXT: %7 = "addi32"(%6, %6) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c7_i32)
+ // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c7)
// CHECK-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> affineint
for %k = 3 to 7 step 2 {
%z = "affine_apply" (%k) { map: (d0) -> (d0 + 1) } :
@@ -71,29 +96,26 @@
return // CHECK: return
} // CHECK }
-// Operations in the loop body have results that are used therein.
-// CHECK-LABEL: mlfunc @loop_nest_body_def_use() {
-mlfunc @loop_nest_body_def_use() {
- // CHECK: %c0_i32 = constant 0 : i32
- // CHECK-NEXT: %c1_i32 = constant 1 : i32
- // CHECK-NEXT: %c2_i32 = constant 2 : i32
- // CHECK-NEXT: %c3_i32 = constant 3 : i32
- // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
- for %i = 1 to 100 step 2 {
- // CHECK: %c0 = constant 0 : affineint
- %c0 = constant 0 : affineint
- // CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
- // CHECK-NEXT: %1 = "addi32"(%0, %c0) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
- // CHECK-NEXT: %3 = "addi32"(%2, %c0) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
- // CHECK-NEXT: %5 = "addi32"(%4, %c0) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
- // CHECK-NEXT: %7 = "addi32"(%6, %c0) : (affineint, affineint) -> affineint
- for %j = 0 to 3 {
- %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
- (affineint) -> (affineint)
- %y = "addi32"(%x, %c0) : (affineint, affineint) -> affineint
+// CHECK-LABEL: mlfunc @loop_nest_multiple_results() {
+mlfunc @loop_nest_multiple_results() {
+ // CHECK: %c0 = constant 0 : affineint
+ // CHECK-NEXT: %c1 = constant 1 : affineint
+ for %i = 1 to 100 {
+ // CHECK: %0 = affine_apply (d0, d1) -> (d0 + 1, d1 + 2)(%i0, %c0)
+ // CHECK-NEXT: %1 = "addi32"(%0#0, %0#1) : (affineint, affineint) -> affineint
+ // CHECK-NEXT: %2 = affine_apply (d0, d1) -> (d0 + 3, d1 + 4)(%i0, %c0)
+ // CHECK-NEXT: %3 = "fma"(%2#0, %2#1, %0#0) : (affineint, affineint, affineint) -> (affineint, affineint)
+ // CHECK-NEXT: %4 = affine_apply (d0, d1) -> (d0 + 1, d1 + 2)(%i0, %c1)
+ // CHECK-NEXT: %5 = "addi32"(%4#0, %4#1) : (affineint, affineint) -> affineint
+ // CHECK-NEXT: %6 = affine_apply (d0, d1) -> (d0 + 3, d1 + 4)(%i0, %c1)
+ // CHECK-NEXT: %7 = "fma"(%6#0, %6#1, %4#0) : (affineint, affineint, affineint) -> (affineint, affineint)
+ for %j = 0 to 1 step 1 {
+ %x = "affine_apply" (%i, %j) { map: (d0, d1) -> (d0 + 1, d1 + 2) } :
+ (affineint, affineint) -> (affineint, affineint)
+ %y = "addi32"(%x#0, %x#1) : (affineint, affineint) -> affineint
+ %z = "affine_apply" (%i, %j) { map: (d0, d1) -> (d0 + 3, d1 + 4) } :
+ (affineint, affineint) -> (affineint, affineint)
+ %w = "fma"(%z#0, %z#1, %x#0) : (affineint, affineint, affineint) -> (affineint, affineint)
}
} // CHECK: }
return // CHECK: return
@@ -103,27 +125,27 @@
// Imperfect loop nest. Unrolling innermost here yields a perfect nest.
// CHECK-LABEL: mlfunc @loop_nest_seq_imperfect(%arg0 : memref<128x128xf32>) {
mlfunc @loop_nest_seq_imperfect(%a : memref<128x128xf32>) {
- // CHECK: %c1_i32 = constant 1 : i32
- // CHECK-NEXT: %c2_i32 = constant 2 : i32
- // CHECK-NEXT: %c3_i32 = constant 3 : i32
- // CHECK-NEXT: %c4_i32 = constant 4 : i32
+ // CHECK: %c1 = constant 1 : affineint
+ // CHECK-NEXT: %c2 = constant 2 : affineint
+ // CHECK-NEXT: %c3 = constant 3 : affineint
+ // CHECK-NEXT: %c4 = constant 4 : affineint
// CHECK-NEXT: %c128 = constant 128 : affineint
%c128 = constant 128 : affineint
// CHECK: for %i0 = 1 to 100 {
for %i = 1 to 100 {
// CHECK: %0 = "vld"(%i0) : (affineint) -> i32
%ld = "vld"(%i) : (affineint) -> i32
- // CHECK: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
- // CHECK-NEXT: %2 = "vmulf"(%c1_i32, %1) : (i32, affineint) -> affineint
+ // CHECK: %1 = affine_apply (d0) -> (d0 + 1)(%c1)
+ // CHECK-NEXT: %2 = "vmulf"(%c1, %1) : (affineint, affineint) -> affineint
// CHECK-NEXT: %3 = "vaddf"(%2, %2) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
- // CHECK-NEXT: %5 = "vmulf"(%c2_i32, %4) : (i32, affineint) -> affineint
+ // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2)
+ // CHECK-NEXT: %5 = "vmulf"(%c2, %4) : (affineint, affineint) -> affineint
// CHECK-NEXT: %6 = "vaddf"(%5, %5) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %7 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
- // CHECK-NEXT: %8 = "vmulf"(%c3_i32, %7) : (i32, affineint) -> affineint
+ // CHECK-NEXT: %7 = affine_apply (d0) -> (d0 + 1)(%c3)
+ // CHECK-NEXT: %8 = "vmulf"(%c3, %7) : (affineint, affineint) -> affineint
// CHECK-NEXT: %9 = "vaddf"(%8, %8) : (affineint, affineint) -> affineint
- // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4_i32)
- // CHECK-NEXT: %11 = "vmulf"(%c4_i32, %10) : (i32, affineint) -> affineint
+ // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4)
+ // CHECK-NEXT: %11 = "vmulf"(%c4, %10) : (affineint, affineint) -> affineint
// CHECK-NEXT: %12 = "vaddf"(%11, %11) : (affineint, affineint) -> affineint
for %j = 1 to 4 {
%x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
@@ -141,21 +163,21 @@
// CHECK-LABEL: mlfunc @loop_nest_seq_multiple() {
mlfunc @loop_nest_seq_multiple() {
- // CHECK: %c1_i32 = constant 1 : i32
- // CHECK-NEXT: %c2_i32 = constant 2 : i32
- // CHECK-NEXT: %c3_i32 = constant 3 : i32
- // CHECK-NEXT: %c4_i32 = constant 4 : i32
- // CHECK-NEXT: %c0_i32 = constant 0 : i32
- // CHECK-NEXT: %c1_i32_0 = constant 1 : i32
- // CHECK-NEXT: %c2_i32_1 = constant 2 : i32
- // CHECK-NEXT: %c3_i32_2 = constant 3 : i32
- // CHECK-NEXT: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
+ // CHECK: %c1 = constant 1 : affineint
+ // CHECK-NEXT: %c2 = constant 2 : affineint
+ // CHECK-NEXT: %c3 = constant 3 : affineint
+ // CHECK-NEXT: %c4 = constant 4 : affineint
+ // CHECK-NEXT: %c0 = constant 0 : affineint
+ // CHECK-NEXT: %c1_0 = constant 1 : affineint
+ // CHECK-NEXT: %c2_1 = constant 2 : affineint
+ // CHECK-NEXT: %c3_2 = constant 3 : affineint
+ // CHECK-NEXT: %0 = affine_apply (d0) -> (d0 + 1)(%c0)
// CHECK-NEXT: "mul"(%0, %0) : (affineint, affineint) -> ()
- // CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32_0)
+ // CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_0)
// CHECK-NEXT: "mul"(%1, %1) : (affineint, affineint) -> ()
- // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_i32_1)
+ // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_1)
// CHECK-NEXT: "mul"(%2, %2) : (affineint, affineint) -> ()
- // CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_i32_2)
+ // CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_2)
// CHECK-NEXT: "mul"(%3, %3) : (affineint, affineint) -> ()
for %j = 0 to 3 {
%x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
@@ -167,14 +189,14 @@
%k = "constant"(){value: 99} : () -> affineint
// CHECK: for %i0 = 1 to 100 step 2 {
for %m = 1 to 100 step 2 {
- // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
- // CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1_i32)[%c99]
- // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
- // CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2_i32)[%c99]
- // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
- // CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3_i32)[%c99]
- // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4_i32)
- // CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c4_i32)[%c99]
+ // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c1)
+ // CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1)[%c99]
+ // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c2)
+ // CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2)[%c99]
+ // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c3)
+ // CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3)[%c99]
+ // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4)
+ // CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c4)[%c99]
for %n = 1 to 4 {
%y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
(affineint) -> (affineint)
@@ -184,3 +206,23 @@
} // CHECK }
return // CHECK: return
} // CHECK }
+
+// SHORT-LABEL: mlfunc @loop_nest_outer_unroll() {
+mlfunc @loop_nest_outer_unroll() {
+ // SHORT: for %i0 = 1 to 4 {
+ // SHORT-NEXT: %0 = affine_apply (d0) -> (d0 + 1)(%i0)
+ // SHORT-NEXT: %1 = "addi32"(%0, %0) : (affineint, affineint) -> affineint
+ // SHORT-NEXT: }
+ // SHORT-NEXT: for %i1 = 1 to 4 {
+ // SHORT-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%i1)
+ // SHORT-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> affineint
+ // SHORT-NEXT: }
+ for %i = 1 to 2 {
+ for %j = 1 to 4 {
+ %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
+ (affineint) -> (affineint)
+ %y = "addi32"(%x, %x) : (affineint, affineint) -> affineint
+ }
+ }
+ return // SHORT: return
+} // SHORT }
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index c95b838..26417e2 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -54,6 +54,7 @@
enum Passes {
ConvertToCFG,
UnrollInnermostLoops,
+ UnrollShortLoops,
TFRaiseControlFlow,
};
@@ -63,6 +64,8 @@
"Convert all ML functions in the module to CFG ones"),
clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
"Unroll innermost loops"),
+ clEnumValN(UnrollShortLoops, "unroll-short-loops",
+ "Unroll loops of trip count <= 2"),
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
"Dynamic TensorFlow Switch/Match nodes to a CFG")));
@@ -111,6 +114,9 @@
case UnrollInnermostLoops:
pass = createLoopUnrollPass();
break;
+ case UnrollShortLoops:
+ pass = createLoopUnrollPass(2);
+ break;
case TFRaiseControlFlow:
pass = createRaiseTFControlFlowPass();
break;