Clean up and extend MLFuncBuilder to allow creating statements in the middle of a statement block. Rename Statement::getFunction() and StmtBlock::getFunction() to findFunction() to make it clear that this is not a constant-time getter.
Fix b/112039912 - we were recording 'i' instead of '%i' for loop induction variables causing "use of undefined SSA value" error.

PiperOrigin-RevId: 206884644
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index b84da83..c6a09f1 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -195,9 +195,11 @@
 /// statement block.
 class MLFuncBuilder : public Builder {
 public:
-  MLFuncBuilder(MLFunction *function) : Builder(function->getContext()) {}
-
-  MLFuncBuilder(StmtBlock *block) : MLFuncBuilder(block->getFunction()) {
+  /// Create an ML function builder and set the insertion point to the end of
+  /// the given statement block, i.e. the body of an ML function, a for
+  /// statement, or a clause of an if statement.
+  MLFuncBuilder(StmtBlock *block)
+      : Builder(block->findFunction()->getContext()) {
     setInsertionPoint(block);
   }
 
@@ -209,6 +211,20 @@
     insertPoint = StmtBlock::iterator();
   }
 
+  /// Set the insertion point to the given position within the given block.
+  /// Unlike CFGFuncBuilder, MLFuncBuilder allows moving the insertion
+  /// point into a different function.
+  void setInsertionPoint(StmtBlock *block, StmtBlock::iterator insertPoint) {
+    // TODO: check that insertPoint is in this rather than some other block.
+    this->block = block;
+    this->insertPoint = insertPoint;
+  }
+
+  /// Set the insertion point right before the specified operation statement.
+  void setInsertionPoint(OperationStmt *stmt) {
+    setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
+  }
+
   /// Set the insertion point to the end of the specified block.
   void setInsertionPoint(StmtBlock *block) {
     this->block = block;
@@ -230,8 +246,8 @@
                      AffineConstantExpr *step = nullptr);
 
   IfStmt *createIf() {
-    auto stmt = new IfStmt();
-    block->getStatements().push_back(stmt);
+    auto *stmt = new IfStmt();
+    block->getStatements().insert(insertPoint, stmt);
     return stmt;
   }
 
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index d2bfdb2..f43f0b6 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -51,8 +51,13 @@
   /// Returns the statement block that contains this statement.
   StmtBlock *getBlock() const { return block; }
 
+  /// Returns the closest surrounding statement that contains this statement
+  /// or nullptr if this is a top-level statement.
+  Statement *getParentStmt() const;
+
   /// Returns the function that this statement is part of.
-  MLFunction *getFunction() const;
+  /// The function is determined by traversing the chain of parent statements.
+  MLFunction *findFunction() const;
 
   /// Returns true if there are no more loops nested under this stmt.
   bool isInnermost() const;
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index 8f4cc20..8b03bf1 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -45,7 +45,8 @@
   Statement *getParentStmt() const;
 
   /// Returns the function that this statement block is part of.
-  MLFunction *getFunction() const;
+  /// The function is determined by traversing the chain of parent statements.
+  MLFunction *findFunction() const;
 
   //===--------------------------------------------------------------------===//
   // Statement list management