Clean up and extend MLFuncBuilder so it can create statements in the middle of a statement block. Rename Statement::getFunction() and StmtBlock::getFunction() to findFunction() to make it clear that these are not constant-time getters.
Fix b/112039912: loop induction variables were recorded as 'i' instead of '%i', which caused a "use of undefined SSA value" error.
PiperOrigin-RevId: 206884644
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 99fb1df..275f8ba 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -1065,9 +1065,10 @@
void BasicBlock::dump() const { print(llvm::errs()); }
void Statement::print(raw_ostream &os) const {
- ModuleState state(getFunction()->getContext());
+ MLFunction *function = findFunction();
+ ModuleState state(function->getContext());
ModulePrinter modulePrinter(os, state);
- MLFunctionPrinter(getFunction(), modulePrinter).print(this);
+ MLFunctionPrinter(function, modulePrinter).print(this);
}
void Statement::dump() const { print(llvm::errs()); }
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 1a094d9..9ba75e4 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -162,6 +162,6 @@
if (!step)
step = getConstantExpr(1);
auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
- block->getStatements().push_back(stmt);
+ block->getStatements().insert(insertPoint, stmt);
return stmt;
}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 3ac481f..b55b597 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -57,8 +57,10 @@
}
}
-MLFunction *Statement::getFunction() const {
- return this->getBlock()->getFunction();
+Statement *Statement::getParentStmt() const { return block->getParentStmt(); }
+
+MLFunction *Statement::findFunction() const {
+ return this->getBlock()->findFunction();
}
bool Statement::isInnermost() const {
diff --git a/lib/IR/StmtBlock.cpp b/lib/IR/StmtBlock.cpp
index 16ddb37..21b870f 100644
--- a/lib/IR/StmtBlock.cpp
+++ b/lib/IR/StmtBlock.cpp
@@ -35,7 +35,7 @@
}
}
-MLFunction *StmtBlock::getFunction() const {
+MLFunction *StmtBlock::findFunction() const {
StmtBlock *block = const_cast<StmtBlock *>(this);
while (block->getParentStmt() != nullptr)