Implement induction variables. Pretty print induction variable operands as %i<ssa value number>. Add support for future pretty printing of ML function arguments as %arg<ssa value number>.
Induction variables are implemented by inheriting ForStmt from MLValue. ForStmt provides APIs that make this design decision invisible to the ForStmt users.
This CL in combination with cl/206253643 resolves http://b/111769060.
PiperOrigin-RevId: 206655937
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index cc31c12..f77f1d7 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -187,19 +187,17 @@
};
/// For statement represents an affine loop nest.
-class ForStmt : public Statement, public StmtBlock {
+class ForStmt : public Statement, public StmtBlock, private MLValue {
public:
// TODO: lower and upper bounds should be affine maps with
// dimension and symbol use lists.
explicit ForStmt(AffineConstantExpr *lowerBound,
- AffineConstantExpr *upperBound, AffineConstantExpr *step)
- : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
- lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+ AffineConstantExpr *upperBound, AffineConstantExpr *step,
+ MLIRContext *context);
// Loop bounds and step are immortal objects and don't need to be deleted.
~ForStmt() {}
- // TODO: represent induction variable
AffineConstantExpr *getLowerBound() const { return lowerBound; }
AffineConstantExpr *getUpperBound() const { return upperBound; }
AffineConstantExpr *getStep() const { return step; }
@@ -213,6 +211,16 @@
return block->getStmtBlockKind() == StmtBlockKind::For;
}
+ // For statement represents induction variable by inheriting
+ // from MLValue. This design is hidden behind interfaces.
+ static bool classof(const SSAValue *value) {
+ return value->getKind() == SSAValueKind::InductionVar;
+ }
+
+ /// MLValue methods
+ MLValue *getInductionVar() { return this; }
+ const MLValue *getInductionVar() const { return this; }
+
private:
AffineConstantExpr *lowerBound;
AffineConstantExpr *upperBound;
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index c928b34..2e09850 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -542,22 +542,40 @@
protected:
void numberValueID(const SSAValue *value) {
assert(!valueIDs.count(value) && "Value numbered multiple times");
- valueIDs[value] = nextValueID++;
+ unsigned id;
+ switch (value->getKind()) {
+ case SSAValueKind::BBArgument:
+ case SSAValueKind::InstResult:
+ case SSAValueKind::StmtResult:
+ id = nextValueID++;
+ break;
+ case SSAValueKind::FnArgument:
+ id = nextFnArgumentID++;
+ break;
+ case SSAValueKind::InductionVar:
+ id = nextInductionVarID++;
+ break;
+ }
+ valueIDs[value] = id;
}
- void printValueID(const SSAValue *value,
- bool dontPrintResultNo = false) const {
+ void printValueID(const SSAValue *value, bool printResultNo = true) const {
int resultNo = -1;
auto lookupValue = value;
- // If this is a reference to the result of a multi-result instruction, print
- // out the # identifier and make sure to map our lookup to the first result
- // of the instruction.
+ // If this is a reference to the result of a multi-result instruction or
+ // statement, print out the # identifier and make sure to map our lookup
+ // to the first result of the instruction.
if (auto *result = dyn_cast<InstResult>(value)) {
if (result->getOwner()->getNumResults() != 1) {
resultNo = result->getResultNumber();
lookupValue = result->getOwner()->getResult(0);
}
+ } else if (auto *result = dyn_cast<StmtResult>(value)) {
+ if (result->getOwner()->getNumResults() != 1) {
+ resultNo = result->getResultNumber();
+ lookupValue = result->getOwner()->getResult(0);
+ }
}
auto it = valueIDs.find(lookupValue);
@@ -566,8 +584,14 @@
return;
}
- os << '%' << it->getSecond();
- if (resultNo != -1 && !dontPrintResultNo)
+ os << '%';
+ if (isa<ForStmt>(value))
+
+ os << 'i';
+ else if (isa<FnArgument>(value))
+ os << "arg";
+ os << it->getSecond();
+ if (resultNo != -1 && printResultNo)
os << '#' << resultNo;
}
@@ -575,12 +599,14 @@
/// This is the value ID for each SSA value in the current function.
DenseMap<const SSAValue *, unsigned> valueIDs;
unsigned nextValueID = 0;
+ unsigned nextInductionVarID = 0;
+ unsigned nextFnArgumentID = 0;
};
} // end anonymous namespace
void FunctionPrinter::printOperation(const Operation *op) {
if (op->getNumResults()) {
- printValueID(op->getResult(0), /*dontPrintResultNo*/ true);
+ printValueID(op->getResult(0), /*printResultNo=*/false);
os << " = ";
}
@@ -874,6 +900,9 @@
if (stmt->getNumResults() != 0)
printer->numberValueID(stmt->getResult(0));
}
+ void visitForStmt(ForStmt *stmt) {
+ printer->numberValueID(stmt->getInductionVar());
+ }
MLFunctionPrinter *printer;
};
@@ -918,7 +947,9 @@
}
void MLFunctionPrinter::print(const ForStmt *stmt) {
- os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
+ os.indent(numSpaces) << "for ";
+ printOperand(stmt->getInductionVar());
+ os << " = " << *stmt->getLowerBound();
os << " to " << *stmt->getUpperBound();
if (stmt->getStep()->getValue() != 1)
os << " step " << *stmt->getStep();
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index e7cb9cb..1a094d9 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -161,7 +161,7 @@
AffineConstantExpr *step) {
if (!step)
step = getConstantExpr(1);
- auto *stmt = new ForStmt(lowerBound, upperBound, step);
+ auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
block->getStatements().push_back(stmt);
return stmt;
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 0731e54..6ace8ff 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
+#include "mlir/IR/Types.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
@@ -191,6 +192,16 @@
}
//===----------------------------------------------------------------------===//
+// ForStmt
+//===----------------------------------------------------------------------===//
+
+ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
+ AffineConstantExpr *step, MLIRContext *context)
+ : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
+ MLValue(MLValueKind::InductionVar, Type::getAffineInt(context)),
+ lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+
+//===----------------------------------------------------------------------===//
// IfStmt
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 654db98..3bcec37 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2086,10 +2086,8 @@
if (getToken().isNot(Token::percent_identifier))
return emitError("expected SSA identifier for the loop variable");
- // TODO: create SSA value definition from name
- StringRef name = getTokenSpelling().drop_front();
- (void)name;
-
+ auto loc = getToken().getLoc();
+ StringRef inductionVariableName = getTokenSpelling().drop_front();
consumeToken(Token::percent_identifier);
if (parseToken(Token::equal, "expected ="))
@@ -2116,14 +2114,20 @@
}
// Create for statement.
- ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
+ ForStmt *forStmt = builder.createFor(lowerBound, upperBound, step);
+
+ // Create SSA value definition for the induction variable.
+ addDefinition({inductionVariableName, 0, loc}, forStmt->getInductionVar());
// If parsing of the for statement body fails,
// MLIR contains for statement with those nested statements that have been
// successfully parsed.
- if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
+ if (parseStmtBlock(forStmt))
return ParseFailure;
+ // Reset insertion point to the current block.
+ builder.setInsertionPoint(forStmt->getBlock());
+
return ParseSuccess;
}
@@ -2174,6 +2178,9 @@
return ParseFailure;
}
+ // Reset insertion point to the current block.
+ builder.setInsertionPoint(ifStmt->getBlock());
+
return ParseSuccess;
}
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 38a79f8..6d1f15d 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -281,3 +281,12 @@
return
}
+// -----
+
+mlfunc @duplicate_induction_var() {
+ for %i = 1 to 10 { // expected-error {{previously defined here}}
+ for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}}
+ }
+ }
+ return
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 61604fb..fe82238 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -116,7 +116,7 @@
mlfunc @mlfunc_with_ops() {
// CHECK: %0 = "foo"() : () -> i64
%a = "foo"() : ()->i64
- // CHECK: for x = 1 to 10 {
+ // CHECK: for %i0 = 1 to 10 {
for %i = 1 to 10 {
// CHECK: %1 = "doo"() : () -> f32
%b = "doo"() : ()->f32
@@ -132,18 +132,34 @@
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
- // CHECK: for x = 1 to 100 step 2 {
+ // CHECK: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
- // CHECK: for x = 1 to 200 {
+ // CHECK: for %i1 = 1 to 200 {
for %j = 1 to 200 {
} // CHECK: }
} // CHECK: }
return // CHECK: return
} // CHECK: }
+// CHECK-LABEL: mlfunc @complex_loops() {
+mlfunc @complex_loops() {
+ for %i1 = 1 to 100 { // CHECK: for %i0 = 1 to 100 {
+ for %j1 = 1 to 100 { // CHECK: for %i1 = 1 to 100 {
+ "foo"() : () -> () // CHECK: "foo"() : () -> ()
+ } // CHECK: }
+ "boo"() : () -> () // CHECK: "boo"() : () -> ()
+ for %j2 = 1 to 10 { // CHECK: for %i2 = 1 to 10 {
+ for %k2 = 1 to 10 { // CHECK: for %i3 = 1 to 10 {
+ "goo"() : () -> () // CHECK: "goo"() : () -> ()
+ } // CHECK: }
+ } // CHECK: }
+ } // CHECK: }
+ return // CHECK: return
+} // CHECK: }
+
// CHECK-LABEL: mlfunc @ifstmt() {
mlfunc @ifstmt() {
- for %i = 1 to 10 { // CHECK for x = 1 to 10 {
+ for %i = 1 to 10 { // CHECK for %i0 = 1 to 10 {
if () { // CHECK if () {
} else if () { // CHECK } else if () {
} else { // CHECK } else {
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 3c42142..2d42e4a 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -2,7 +2,7 @@
// CHECK-LABEL: mlfunc @loops() {
mlfunc @loops() {
- // CHECK: for x = 1 to 100 step 2 {
+ // CHECK: for %i0 = 1 to 100 step 2 {
for %i = 1 to 100 step 2 {
// CHECK: "custom"(){value: 1} : () -> ()
// CHECK-NEXT: "custom"(){value: 1} : () -> ()