Implement a proper function list in Module, which automatically maintains the
parent pointer and ensures that functions are deleted when the module is destroyed.
This exposed the fact that MLFunction had no dtor, and that the dtor in
CFGFunction was broken with cyclic references. Fix both of these problems.
PiperOrigin-RevId: 206051666
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index e23964d..4fd7e61 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -157,8 +157,8 @@
// Initializes module state, populating affine map state.
void ModuleState::initialize(const Module *module) {
- for (auto fn : module->functionList) {
- visitFunction(fn);
+ for (auto &fn : *module) {
+ visitFunction(&fn);
}
}
@@ -236,8 +236,8 @@
map->print(os);
os << '\n';
}
- for (auto *fn : module->functionList)
- print(fn);
+ for (auto const &fn : *module)
+ print(&fn);
}
void ModulePrinter::printAttribute(const Attribute *attr) {
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 72ec443..8476b06 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/StringRef.h"
using namespace mlir;
@@ -27,6 +28,64 @@
MLIRContext *Function::getContext() const { return getType()->getContext(); }
+/// Delete this function, dispatching on its kind to the concrete subclass.
+void Function::destroy() {
+ switch (getKind()) {
+ case Kind::ExtFunc:
+ delete cast<ExtFunction>(this);
+ break;
+ case Kind::MLFunc:
+ delete cast<MLFunction>(this);
+ break;
+ case Kind::CFGFunc:
+ delete cast<CFGFunction>(this);
+ break;
+ }
+}
+
+Module *llvm::ilist_traits<Function>::getContainingModule() {
+ size_t Offset( // Offset of the sublist member (per getSublistAccess) in Module.
+ size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
+ iplist<Function> *Anchor(static_cast<iplist<Function> *>(this));
+ return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
+}
+
+/// This is a trait method invoked when a Function is added to a Module. We
+/// record the containing module in the function's 'module' pointer.
+void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
+ assert(!function->getModule() && "already in a module!");
+ function->module = getContainingModule();
+}
+
+/// This is a trait method invoked when a Function is removed from a Module.
+/// We clear the function's module pointer, which must currently be set.
+void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
+ assert(function->module && "not already in a module!");
+ function->module = nullptr;
+}
+
+/// This is a trait method invoked when functions are moved from one list to
+/// another. We keep the module pointer of each moved function up to date.
+void llvm::ilist_traits<Function>::transferNodesFromList(
+ ilist_traits<Function> &otherList, function_iterator first,
+ function_iterator last) {
+ // If we are transferring functions within the same module, the Module
+ // pointer doesn't need to be updated.
+ Module *curParent = getContainingModule();
+ if (curParent == otherList.getContainingModule())
+ return;
+
+ // Update the 'module' member of each function.
+ for (; first != last; ++first)
+ first->module = curParent;
+}
+
+/// Unlink this function from its Module (which must exist) and delete it.
+void Function::eraseFromModule() {
+ assert(getModule() && "Function has no parent");
+ getModule()->getFunctions().erase(this);
+}
+
//===----------------------------------------------------------------------===//
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
@@ -43,9 +102,23 @@
: Function(name, type, Kind::CFGFunc) {
}
+CFGFunction::~CFGFunction() {
+ // Instructions may reference each other cyclically, so drop every
+ // instruction's references up front, before any of them is deleted.
+ for (auto &bb : *this) {
+ for (auto &inst : bb)
+ inst.dropAllReferences();
+ }
+}
+
//===----------------------------------------------------------------------===//
// MLFunction implementation.
//===----------------------------------------------------------------------===//
MLFunction::MLFunction(StringRef name, FunctionType *type)
: Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
+
+MLFunction::~MLFunction() {
+ // TODO: Drop all references here once more SSA stuff is supported.
+ // dropAllReferences();
+}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 8cbdabc..a1cdcb3 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -99,6 +99,14 @@
}
}
+/// Drop every operand use held by this instruction. This is an essential
+/// step in breaking cyclic dependences between instructions before they are
+/// deleted.
+void Instruction::dropAllReferences() {
+ for (auto &op : getInstOperands())
+ op.drop();
+}
+
//===----------------------------------------------------------------------===//
// OperationInst
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 96f7f73..9e52556 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -330,8 +330,8 @@
bool Module::verify(std::string *errorResult) const {
/// Check that each function is correct.
- for (auto fn : functionList) {
- if (fn->verify(errorResult))
+ for (auto &fn : *this) {
+ if (fn.verify(errorResult))
return true;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 55fd260..362f868 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -917,11 +917,10 @@
StringRef sRef = getTokenSpelling();
for (auto entry : dimsAndSymbols) {
- if (entry.first != sRef)
- continue;
-
- consumeToken(Token::bare_identifier);
- return entry.second;
+ if (entry.first == sRef) {
+ consumeToken(Token::bare_identifier);
+ return entry.second;
+ }
}
return (emitError("use of undeclared identifier"), nullptr);
@@ -1861,7 +1860,7 @@
"'");
}
- getModule()->functionList.push_back(function);
+ getModule()->getFunctions().push_back(function);
return finalizeFunction(function, braceLoc);
}
@@ -2053,7 +2052,7 @@
parseToken(Token::r_brace, "expected '}' to end mlfunc"))
return ParseFailure;
- getModule()->functionList.push_back(function);
+ getModule()->getFunctions().push_back(function);
return finalizeFunction(function, braceLoc);
}
@@ -2356,7 +2355,7 @@
return ParseFailure;
// Okay, the external function definition was parsed correctly.
- getModule()->functionList.push_back(new ExtFunction(name, type));
+ getModule()->getFunctions().push_back(new ExtFunction(name, type));
return ParseSuccess;
}
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 0a98e2b..9487cf9 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -96,8 +96,8 @@
}
void ModuleConverter::convertMLFunctions() {
- for (Function *fn : module->functionList) {
- if (auto mlFunc = dyn_cast<MLFunction>(fn))
+ for (Function &fn : *module) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
generatedFuncs[mlFunc] = convert(mlFunc);
}
}
@@ -105,22 +105,22 @@
// Creates CFG function equivalent to the given ML function.
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
- CFGFunction *cfgFunc =
+ auto *cfgFunc =
new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
- module->functionList.push_back(cfgFunc);
+ module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.
return FunctionConverter(cfgFunc).convert(mlFunc);
}
void ModuleConverter::replaceReferences() {
- for (Function *fn : module->functionList) {
- switch (fn->getKind()) {
+ for (Function &fn : *module) {
+ switch (fn.getKind()) {
case Function::Kind::CFGFunc:
- replaceReferences(cast<CFGFunction>(fn));
+ replaceReferences(&cast<CFGFunction>(fn));
break;
case Function::Kind::MLFunc:
- replaceReferences(cast<MLFunction>(fn));
+ replaceReferences(&cast<MLFunction>(fn));
break;
case Function::Kind::ExtFunc:
// nothing to do for external functions
@@ -139,20 +139,14 @@
// Removes all ML functions from the module.
void ModuleConverter::removeMLFunctions() {
- std::vector<Function *> &fnList = module->functionList;
-
- // Delete ML functions and its data.
- for (auto &fn : fnList) {
- if (auto mlFunc = dyn_cast<MLFunction>(fn)) {
- delete mlFunc;
- fn = nullptr;
- }
+ // Delete ML functions from the module.
+ for (auto it = module->begin(), e = module->end(); it != e;) {
+ // Manipulate iterator carefully to avoid deleting a function we're pointing
+ // at.
+ Function &fn = *it++;
+ if (auto mlFunc = dyn_cast<MLFunction>(&fn))
+ mlFunc->eraseFromModule();
}
-
- // Remove ML functions from the function list.
- fnList.erase(std::remove_if(fnList.begin(), fnList.end(),
- [](Function *fn) { return !fn; }),
- fnList.end());
}
//===----------------------------------------------------------------------===//
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index af0c7ca..337f558 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -44,8 +44,8 @@
/// Unrolls all the innermost loops of this Module.
bool MLFunctionPass::runOnModule(Module *m) {
bool changed = false;
- for (auto fn : m->functionList) {
- if (auto *mlFunc = dyn_cast<MLFunction>(fn))
+ for (auto &fn : *m) {
+ if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
changed |= runOnMLFunction(mlFunc);
}
return changed;