Clean up and extend MLFuncBuilder to allow creating statements in the middle of a statement block. Rename Statement::getFunction() and StmtBlock::getFunction() to findFunction() to make it clear that this is not a constant-time getter.
Fix b/112039912: we were recording 'i' instead of '%i' for loop induction variables, causing a "use of undefined SSA value" error.
PiperOrigin-RevId: 206884644
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index b84da83..c6a09f1 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -195,9 +195,11 @@
/// statement block.
class MLFuncBuilder : public Builder {
public:
- MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {}
-
- MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) {
+ /// Create an ML function builder with the insertion point set to the end
+ /// of the given statement block, i.e. an ML function, a for statement or
+ /// an if statement clause.
+ MLFuncBuilder(StmtBlock *block)
+ : Builder(block->findFunction()->getContext()) {
setInsertionPoint(block);
}
@@ -209,6 +211,20 @@
insertPoint = StmtBlock::iterator();
}
+ /// Set the insertion point to \p insertPoint in \p block. New statements
+ /// are inserted before \p insertPoint. Unlike CFGFuncBuilder,
+ /// MLFuncBuilder allows moving the insertion point to a different function.
+ void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
+ // TODO: check that insertPoint belongs to `block` and not another block.
+ this->block = block;
+ this->insertPoint = insertPoint;
+ }
+
+ /// Set the insertion point right before the specified operation statement.
+ void setInsertionPoint(OperationStmt *stmt) {
+ setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
+ }
+
/// Set the insertion point to the end of the specified block.
void setInsertionPoint(StmtBlock *block) {
this->block = block;
@@ -230,8 +246,8 @@
AffineConstantExpr *step = nullptr);
IfStmt *createIf() {
- auto stmt = new IfStmt();
- block->getStatements().push_back(stmt);
+ auto *stmt = new IfStmt();
+ block->getStatements().insert(insertPoint, stmt); // Before insertPoint.
return stmt;
}
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index d2bfdb2..f43f0b6 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -51,8 +51,13 @@
/// Returns the statement block that contains this statement.
StmtBlock *getBlock() const { return block; }
+ /// Returns the closest surrounding statement that contains this statement,
+ /// or nullptr if this is a top-level statement.
+ Statement *getParentStmt() const;
+
/// Returns the function that this statement is part of.
- MLFunction *getFunction() const;
+ /// Found by walking up the chain of parent statements, so not constant time.
+ MLFunction *findFunction() const;
/// Returns true if there are no more loops nested under this stmt.
bool isInnermost() const;
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index 8f4cc20..8b03bf1 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -45,7 +45,8 @@
Statement *getParentStmt() const;
/// Returns the function that this statement block is part of.
- MLFunction *getFunction() const;
+ /// Found by walking up the chain of parent statements, so not constant time.
+ MLFunction *findFunction() const;
//===--------------------------------------------------------------------===//
// Statement list management