Clean up and extend MLFuncBuilder to allow creating statements in the middle of a statement block. Rename Statement::getFunction() and StmtBlock::getFunction() to findFunction() to make it clear that this is not a constant-time getter.
Fix b/112039912 - we were recording 'i' instead of '%i' for loop induction variables, which caused a "use of undefined SSA value" error.

PiperOrigin-RevId: 206884644
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index b84da83..c6a09f1 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -195,9 +195,11 @@
 /// statement block.
 class MLFuncBuilder : public Builder {
 public:
-  MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {}
-
-  MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) {
+  /// Create an ML function builder and set the insertion point to the end of
+  /// the given statement block, i.e. an ML function, a for statement, or an
+  /// if statement clause.
+  MLFuncBuilder(StmtBlock *block)
+      : Builder(block->findFunction()->getContext()) {
     setInsertionPoint(block);
   }
 
@@ -209,6 +211,20 @@
     insertPoint = StmtBlock::iterator();
   }
 
+  /// Set the insertion point to the specified location.
+  /// Unlike CFGFuncBuilder, MLFuncBuilder allows setting the insertion
+  /// point to a different function.
+  void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
+    // TODO: check that insertPoint is in 'block', not in some other block.
+    this->block = block;
+    this->insertPoint = insertPoint;
+  }
+
+  /// Set the insertion point to the specified operation.
+  void setInsertionPoint(OperationStmt *stmt) {
+    setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
+  }
+
   /// Set the insertion point to the end of the specified block.
   void setInsertionPoint(StmtBlock *block) {
     this->block = block;
@@ -230,8 +246,8 @@
                      AffineConstantExpr *step = nullptr);
 
   IfStmt *createIf() {
-    auto stmt = new IfStmt();
-    block->getStatements().push_back(stmt);
+    auto *stmt = new IfStmt();
+    block->getStatements().insert(insertPoint, stmt);
     return stmt;
   }
 
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index d2bfdb2..f43f0b6 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -51,8 +51,13 @@
   /// Returns the statement block that contains this statement.
   StmtBlock *getBlock() const { return block; }
 
+  /// Returns the closest surrounding statement that contains this statement
+  /// or nullptr if this is a top-level statement.
+  Statement *getParentStmt() const;
+
   /// Returns the function that this statement is part of.
-  MLFunction *getFunction() const;
+  /// The function is determined by traversing the chain of parent statements.
+  MLFunction *findFunction() const;
 
   /// Returns true if there are no more loops nested under this stmt.
   bool isInnermost() const;
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index 8f4cc20..8b03bf1 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -45,7 +45,8 @@
   Statement *getParentStmt() const;
 
   /// Returns the function that this statement block is part of.
-  MLFunction *getFunction() const;
+  /// The function is determined by traversing the chain of parent statements.
+  MLFunction *findFunction() const;
 
   //===--------------------------------------------------------------------===//
   // Statement list management
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 99fb1df..275f8ba 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -1065,9 +1065,10 @@
 void BasicBlock::dump() const { print(llvm::errs()); }
 
 void Statement::print(raw_ostream &os) const {
-  ModuleState state(getFunction()->getContext());
+  MLFunction *function = findFunction();
+  ModuleState state(function->getContext());
   ModulePrinter modulePrinter(os, state);
-  MLFunctionPrinter(getFunction(), modulePrinter).print(this);
+  MLFunctionPrinter(function, modulePrinter).print(this);
 }
 
 void Statement::dump() const { print(llvm::errs()); }
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 1a094d9..9ba75e4 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -162,6 +162,6 @@
   if (!step)
     step = getConstantExpr(1);
   auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
-  block->getStatements().push_back(stmt);
+  block->getStatements().insert(insertPoint, stmt);
   return stmt;
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 3ac481f..b55b597 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -57,8 +57,10 @@
   }
 }
 
-MLFunction *Statement::getFunction() const {
-  return this->getBlock()->getFunction();
+Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
+
+MLFunction *Statement::findFunction() const {
+  return this->getBlock()->findFunction();
 }
 
 bool Statement::isInnermost() const {
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 16ddb37..21b870f 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -35,7 +35,7 @@
   }
 }
 
-MLFunction *StmtBlock::getFunction() const {
+MLFunction *StmtBlock::findFunction() const {
   StmtBlock *block = const_cast<StmtBlock *>(this);
 
   while (block->getParentStmt() != nullptr)
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index aa08cb6..0f4a3a5 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1335,14 +1335,14 @@
     return (emitError(useInfo.loc, "reference to invalid result number"),
             nullptr);
 
+  // Otherwise, this is a forward reference.  If we are in an ML function,
+  // return an error.  In a CFG function, create a placeholder and remember
+  // that we did so.
   if (getKind() == Kind::MLFunc)
     return (
         emitError(useInfo.loc, "use of undefined SSA value " + useInfo.name),
         nullptr);
 
-  // Otherwise, this is a forward reference.  If we are in ML function return
-  // an error. In CFG function, create a placeholder and remember
-  // that we did so.
   auto *result = createForwardReferencePlaceholder(useInfo.loc, type);
   entries[useInfo.number].first = result;
   entries[useInfo.number].second = useInfo.loc;
@@ -2102,7 +2102,7 @@
     return emitError("expected SSA identifier for the loop variable");
 
   auto loc = getToken().getLoc();
-  StringRef inductionVariableName = getTokenSpelling().drop_front();
+  StringRef inductionVariableName = getTokenSpelling();
   consumeToken(Token::percent_identifier);
 
   if (parseToken(Token::equal, "expected ="))
@@ -2143,6 +2143,8 @@
   // Reset insertion point to the current block.
   builder.setInsertionPoint(forStmt->getBlock());
 
+  // TODO: remove definition of the induction variable.
+
   return ParseSuccess;
 }
 
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 489a98f..160a463 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -101,9 +101,7 @@
   auto trip_count = (ub - lb + 1) / step;
 
   auto *block = forStmt->getBlock();
-
-  MLFuncBuilder builder(forStmt->Statement::getFunction());
-  builder.setInsertionPoint(block);
+  MLFuncBuilder builder(block);
 
   for (int i = 0; i < trip_count; i++) {
     for (auto &stmt : forStmt->getStatements()) {
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 6d1f15d..bc88db9 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -285,7 +285,7 @@
 
 mlfunc @duplicate_induction_var() {
   for %i = 1 to 10 {   // expected-error {{previously defined here}}
-    for %i = 1 to 10 { // expected-error {{redefinition of SSA value 'i'}}
+    for %i = 1 to 10 { // expected-error {{redefinition of SSA value '%i'}}
     }
   }
   return
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index ce101e5..0ed0938 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -145,7 +145,8 @@
 mlfunc @complex_loops() {
   for %i1 = 1 to 100 {      // CHECK:   for %i0 = 1 to 100 {
     for %j1 = 1 to 100 {    // CHECK:     for %i1 = 1 to 100 {
-       "foo"() : () -> ()   // CHECK:       "foo"() : () -> ()
+       // CHECK: "foo"(%i0, %i1) : (affineint, affineint) -> ()
+       "foo"(%i1, %j1) : (affineint,affineint) -> ()
     }                       // CHECK:     }
     "boo"() : () -> ()      // CHECK:     "boo"() : () -> ()
     for %j2 = 1 to 10 {     // CHECK:     for %i2 = 1 to 10 {
@@ -157,6 +158,7 @@
   return                    // CHECK:   return
 }                           // CHECK: }
 
+
 // CHECK-LABEL: mlfunc @ifstmt() {
 mlfunc @ifstmt() {
   for %i = 1 to 10 {    // CHECK   for %i0 = 1 to 10 {