Implement a proper function list in the module, which automatically maintains
the parent pointer and ensures that functions are deleted when the module is
destroyed.
This exposed the fact that MLFunction had no destructor, and that the
destructor in CFGFunction was broken in the presence of cyclic references.
Fix both of these problems.
PiperOrigin-RevId: 206051666
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 0a98e2b..9487cf9 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -96,8 +96,8 @@
}
void ModuleConverter::convertMLFunctions() {
- for (Function *fn : module->functionList) {
- if (auto mlFunc = dyn_cast<MLFunction>(fn))
+ for (Function &fn : *module) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
generatedFuncs[mlFunc] = convert(mlFunc);
}
}
@@ -105,22 +105,22 @@
// Creates CFG function equivalent to the given ML function.
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
- CFGFunction *cfgFunc =
+ auto *cfgFunc =
new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
- module->functionList.push_back(cfgFunc);
+ module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.
return FunctionConverter(cfgFunc).convert(mlFunc);
}
void ModuleConverter::replaceReferences() {
- for (Function *fn : module->functionList) {
- switch (fn->getKind()) {
+ for (Function &fn : *module) {
+ switch (fn.getKind()) {
case Function::Kind::CFGFunc:
- replaceReferences(cast<CFGFunction>(fn));
+ replaceReferences(&cast<CFGFunction>(fn));
break;
case Function::Kind::MLFunc:
- replaceReferences(cast<MLFunction>(fn));
+ replaceReferences(&cast<MLFunction>(fn));
break;
case Function::Kind::ExtFunc:
// nothing to do for external functions
@@ -139,20 +139,14 @@
// Removes all ML functions from the module.
void ModuleConverter::removeMLFunctions() {
- std::vector<Function *> &fnList = module->functionList;
-
- // Delete ML functions and its data.
- for (auto &fn : fnList) {
- if (auto mlFunc = dyn_cast<MLFunction>(fn)) {
- delete mlFunc;
- fn = nullptr;
- }
+ // Delete ML functions from the module.
+ for (auto it = module->begin(), e = module->end(); it != e;) {
+ // Manipulate iterator carefully to avoid deleting a function we're pointing
+ // at.
+ Function &fn = *it++;
+ if (auto mlFunc = dyn_cast<MLFunction>(&fn))
+ mlFunc->eraseFromModule();
}
-
- // Remove ML functions from the function list.
- fnList.erase(std::remove_if(fnList.begin(), fnList.end(),
- [](Function *fn) { return !fn; }),
- fnList.end());
}
//===----------------------------------------------------------------------===//
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index af0c7ca..337f558 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -44,8 +44,8 @@
/// Unrolls all the innermost loops of this Module.
bool MLFunctionPass::runOnModule(Module *m) {
bool changed = false;
- for (auto fn : m->functionList) {
- if (auto *mlFunc = dyn_cast<MLFunction>(fn))
+ for (auto &fn : *m) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
changed |= runOnMLFunction(mlFunc);
}
return changed;