MLStmt cloning and IV replacement for loop unrolling, add constant pool to
MLFunctions.
- MLStmt cloning and IV replacement
- While at this, fix the innermostLoopGatherer to actually gather all the
innermost loops (it was stopping its walk at the first innermost loop it
found)
- Improve comments for MLFunction statement classes, fix inheritance order.
- Fixed StmtBlock destructor.
PiperOrigin-RevId: 207049173
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 67cc7b8..4f7161b 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -18,6 +18,7 @@
#ifndef MLIR_IR_BUILDERS_H
#define MLIR_IR_BUILDERS_H
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
@@ -162,6 +163,12 @@
return op;
}
+ OperationInst *cloneOperation(const OperationInst &srcOpInst) {
+ auto *op = srcOpInst.clone();
+ block->getOperations().insert(insertPoint, op);
+ return op;
+ }
+
// Terminators.
ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
@@ -232,6 +239,12 @@
insertPoint = block->end();
}
+ /// Set the insertion point at the beginning of the specified block.
+ void setInsertionPointAtStart(StmtBlock *block) {
+ this->block = block;
+ insertPoint = block->begin();
+ }
+
OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
ArrayRef<Type *> resultTypes,
ArrayRef<NamedAttribute> attributes) {
@@ -241,6 +254,12 @@
return op;
}
+ OperationStmt *cloneOperation(const OperationStmt &srcOpStmt) {
+ auto *op = srcOpStmt.clone();
+ block->getStatements().insert(insertPoint, op);
+ return op;
+ }
+
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
@@ -252,6 +271,15 @@
return stmt;
}
+ // TODO: subsume with a generate create<ConstantInt>() method.
+ OperationStmt *createConstInt32Op(int value) {
+ std::pair<Identifier, Attribute *> namedAttr(
+ Identifier::get("value", context), getIntegerAttr(value));
+ auto *mlconst = createOperation(Identifier::get("constant", context), {},
+ {getIntegerType(32)}, {namedAttr});
+ return mlconst;
+ }
+
private:
StmtBlock *block = nullptr;
StmtBlock::iterator insertPoint;
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index a18a97f..1eac36c 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -158,6 +158,8 @@
/// Return the context this operation is associated with.
MLIRContext *getContext() const { return Instruction::getContext(); }
+ OperationInst *clone() const;
+
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index 65cad98..b5d2371 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -34,8 +34,8 @@
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
-/// forming a tree. Statements are organized into statement blocks
-/// represented by StmtBlock class.
+/// forming a tree. Child statements are organized into statement blocks
+/// represented by a 'StmtBlock' class.
class Statement : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
public:
enum class Kind {
@@ -77,6 +77,7 @@
private:
Kind kind;
+ /// The statement block that containts this statement.
StmtBlock *block = nullptr;
// allow ilist_traits access to 'block' field.
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 306702f..ebecf58 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -47,6 +47,8 @@
/// Return the context this operation is associated with.
MLIRContext *getContext() const;
+ OperationStmt *clone() const;
+
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
@@ -190,7 +192,7 @@
};
/// For statement represents an affine loop nest.
-class ForStmt : public Statement, public StmtBlock, public MLValue {
+class ForStmt : public Statement, public MLValue, public StmtBlock {
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
@@ -199,6 +201,10 @@
MLIRContext *context);
// Loop bounds and step are immortal objects and don't need to be deleted.
+ // With this dtor, ForStmt needs to inherit from MLValue before it does from
+ // StmtBlock since an MLValue can't be destroyed before the StmtBlock is ---
+ // the latter has uses for the induction variables, which is actually the
+ // MLValue here. FIXME: this dtor.
~ForStmt() {}
AffineConstantExpr *getLowerBound() const { return lowerBound; }
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index 8b03bf1..77f490a 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -26,10 +26,12 @@
#include "mlir/IR/Statement.h"
namespace mlir {
- class MLFunction;
- class IfStmt;
+class MLFunction;
+class IfStmt;
-/// Statement block represents an ordered list of statements.
+/// Statement block represents an ordered list of statements, with the order
+/// being the contiguous lexical order in which the statements appear as
+/// children of a parent statement in the ML Function.
class StmtBlock {
public:
enum class StmtBlockKind {
@@ -54,7 +56,7 @@
/// This is the list of statements in the block.
typedef llvm::iplist<Statement> StmtListType;
- StmtListType &getStatements() { return statements; }
+ StmtListType &getStatements() { return statements; }
const StmtListType &getStatements() const { return statements; }
// Iteration over the statements in the block.
@@ -82,14 +84,14 @@
}
Statement &front() { return statements.front(); }
const Statement &front() const {
- return const_cast<StmtBlock*>(this)->front();
+ return const_cast<StmtBlock *>(this)->front();
}
void print(raw_ostream &os) const;
void dump() const;
/// getSublistAccess() - Returns pointer to member of statement list
- static StmtListType StmtBlock::*getSublistAccess(Statement*) {
+ static StmtListType StmtBlock::*getSublistAccess(Statement *) {
return &StmtBlock::statements;
}
@@ -101,9 +103,8 @@
/// This is the list of statements in the block.
StmtListType statements;
- StmtBlock(const StmtBlock&) = delete;
- void operator=(const StmtBlock&) = delete;
-
+ StmtBlock(const StmtBlock &) = delete;
+ void operator=(const StmtBlock &) = delete;
};
} //end namespace mlir
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index a10cb3a..847907b 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -145,6 +145,21 @@
return inst;
}
+OperationInst *OperationInst::clone() const {
+ SmallVector<CFGValue *, 8> operands;
+ SmallVector<Type *, 8> resultTypes;
+
+ // TODO(clattner): switch to iterator logic.
+ // Put together the operands and results.
+ for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
+ operands.push_back(getInstOperand(i).get());
+
+ for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+ resultTypes.push_back(getInstResult(i).getType());
+
+ return create(getName(), operands, resultTypes, getAttrs(), getContext());
+}
+
OperationInst::OperationInst(Identifier name, unsigned numOperands,
unsigned numResults,
ArrayRef<NamedAttribute> attributes,
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 5bc99e0..4e2a7b0 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -120,7 +120,6 @@
/// Remove this statement (and its descendants) from its StmtBlock and delete
/// all of them.
-/// TODO: erase all descendents for ForStmt/IfStmt.
void Statement::eraseFromBlock() {
assert(getBlock() && "Statement has no block");
getBlock()->getStatements().erase(this);
@@ -155,6 +154,22 @@
return stmt;
}
+/// Clone an existing OperationStmt.
+OperationStmt *OperationStmt::clone() const {
+ SmallVector<MLValue *, 8> operands;
+ SmallVector<Type *, 8> resultTypes;
+
+ // TODO(clattner): switch this to iterator logic.
+ // Put together operands and results.
+ for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
+ operands.push_back(getStmtOperand(i).get());
+
+ for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+ resultTypes.push_back(getStmtResult(i).getType());
+
+ return create(getName(), operands, resultTypes, getAttrs(), getContext());
+}
+
OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
unsigned numResults,
ArrayRef<NamedAttribute> attributes,
@@ -205,9 +220,10 @@
ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
AffineConstantExpr *step, MLIRContext *context)
- : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
+ : Statement(Kind::For),
MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
- lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+ StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
+ upperBound(upperBound), step(step) {}
//===----------------------------------------------------------------------===//
// IfStmt
@@ -215,6 +231,6 @@
IfStmt::~IfStmt() {
delete thenClause;
- if (elseClause != nullptr)
+ if (elseClause)
delete elseClause;
}
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 160a463..fe110d2 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -19,6 +19,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
@@ -54,10 +55,13 @@
typedef llvm::iplist<Statement> StmtListType;
bool walkPostOrder(StmtListType::iterator Start,
StmtListType::iterator End) {
+ bool hasInnerLoops = false;
+ // We need to walk all elements since all innermost loops need to be
+ // gathered as opposed to determining whether this list has any inner
+ // loops or not.
while (Start != End)
- if (walkPostOrder(&(*Start++)))
- return true;
- return false;
+ hasInnerLoops |= walkPostOrder(&(*Start++));
+ return hasInnerLoops;
}
// FIXME: can't use base class method for this because that in turn would
@@ -73,12 +77,11 @@
}
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
- if (walkPostOrder(ifStmt->getThenClause()->begin(),
- ifStmt->getThenClause()->end()) ||
- walkPostOrder(ifStmt->getElseClause()->begin(),
- ifStmt->getElseClause()->end()))
- return true;
- return false;
+ bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
+ ifStmt->getThenClause()->end());
+ hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
+ ifStmt->getElseClause()->end());
+ return hasInnerLoops;
}
bool walkOpStmt(OperationStmt *opStmt) { return false; }
@@ -93,17 +96,45 @@
runOnForStmt(forStmt);
}
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Replace an IV with a constant value.
+static void replaceIterator(Statement *stmt, const ForStmt &iv,
+ MLValue *constVal) {
+ struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
+ // IV to be replaced.
+ const ForStmt *iv;
+ // Constant to be replaced with.
+ MLValue *constVal;
+
+ ReplaceIterator(const ForStmt &iv, MLValue *constVal)
+ : iv(&iv), constVal(constVal){};
+
+ void visitOperationStmt(OperationStmt *os) {
+ for (auto &operand : os->getStmtOperands()) {
+ if (operand.get() == static_cast<const MLValue *>(iv)) {
+ operand.set(constVal);
+ }
+ }
+ }
+ };
+
+ ReplaceIterator ri(iv, constVal);
+ ri.walk(stmt);
+}
+
+/// Unrolls this loop completely.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
auto trip_count = (ub - lb + 1) / step;
- auto *block = forStmt->getBlock();
- MLFuncBuilder builder(block);
+ auto *mlFunc = forStmt->Statement::findFunction();
+ MLFuncBuilder funcTopBuilder(mlFunc);
+ funcTopBuilder.setInsertionPointAtStart(mlFunc);
+ MLFuncBuilder builder(forStmt->getBlock());
for (int i = 0; i < trip_count; i++) {
+ auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
for (auto &stmt : forStmt->getStatements()) {
switch (stmt.getKind()) {
case Statement::Kind::For:
@@ -113,16 +144,13 @@
llvm_unreachable("unrolling loops that have only operations");
break;
case Statement::Kind::Operation:
- auto *op = cast<OperationStmt>(&stmt);
- // TODO: clone operands and result types.
- builder.createOperation(op->getName(), /*operands*/ {},
- /*resultTypes*/ {}, op->getAttrs());
- // TODO: loop iterator parsing not yet implemented; replace loop
- // iterator uses in unrolled body appropriately.
+ auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
+ // TODO(bondhugula): only generate constants when the IV actually
+ // appears in the body.
+ replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
break;
}
}
}
-
forStmt->eraseFromBlock();
}
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 2d42e4a..6a7d9cf 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,16 +1,64 @@
// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
-// CHECK-LABEL: mlfunc @loops() {
-mlfunc @loops() {
- // CHECK: for %i0 = 1 to 100 step 2 {
+// CHECK-LABEL: mlfunc @loops1() {
+mlfunc @loops1() {
+ // CHECK: %c0_i32 = constant 0 : i32
+ // CHECK-NEXT: %c1_i32 = constant 1 : i32
+ // CHECK-NEXT: %c2_i32 = constant 2 : i32
+ // CHECK-NEXT: %c3_i32 = constant 3 : i32
+ // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
- // CHECK: "custom"(){value: 1} : () -> ()
- // CHECK-NEXT: "custom"(){value: 1} : () -> ()
- // CHECK-NEXT: "custom"(){value: 1} : () -> ()
- // CHECK-NEXT: "custom"(){value: 1} : () -> ()
+ // CHECK: %c1_i32_0 = constant 1 : i32
+ // CHECK-NEXT: %c1_i32_1 = constant 1 : i32
+ // CHECK-NEXT: %c1_i32_2 = constant 1 : i32
+ // CHECK-NEXT: %c1_i32_3 = constant 1 : i32
for %j = 1 to 4 {
- "custom"(){value: 1} : () -> f32
+ %x = constant 1 : i32
}
} // CHECK: }
return // CHECK: return
} // CHECK }
+
+// CHECK-LABEL: mlfunc @loops2() {
+mlfunc @loops2() {
+ // CHECK: %c0_i32 = constant 0 : i32
+ // CHECK-NEXT: %c1_i32 = constant 1 : i32
+ // CHECK-NEXT: %c2_i32 = constant 2 : i32
+ // CHECK-NEXT: %c3_i32 = constant 3 : i32
+ // CHECK-NEXT: %c0_i32_0 = constant 0 : i32
+ // CHECK-NEXT: %c1_i32_1 = constant 1 : i32
+ // CHECK-NEXT: %c2_i32_2 = constant 2 : i32
+ // CHECK-NEXT: %c3_i32_3 = constant 3 : i32
+ // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
+ for %i = 1 to 100 step 2 {
+ // CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32_0)
+ // CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32_1)
+ // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_i32_2)
+ // CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_i32_3)
+ for %j = 1 to 4 {
+ %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
+ (affineint) -> (affineint)
+ }
+ } // CHECK: }
+
+ // CHECK: %c99 = constant 99 : affineint
+ %k = "constant"(){value: 99} : () -> affineint
+ // CHECK: for %i1 = 1 to 100 step 2 {
+ for %m = 1 to 100 step 2 {
+ // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
+ // CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c0_i32)[%c99]
+ // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
+ // CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1_i32)[%c99]
+ // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
+ // CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2_i32)[%c99]
+ // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
+ // CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3_i32)[%c99]
+ for %n = 1 to 4 {
+ %y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
+ (affineint) -> (affineint)
+ %z = "affine_apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } :
+ (affineint, affineint) -> (affineint)
+ } // CHECK }
+ } // CHECK }
+ return // CHECK: return
+} // CHECK }