Implement a proper function list in Module that automatically maintains the parent
pointer, and ensure that functions are deleted when the module is destroyed.
This exposed the fact that MLFunction had no dtor, and that the dtor in
CFGFunction broke in the presence of cyclic references. Fix both of these problems.
PiperOrigin-RevId: 206051666
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index d59a83b..c26203c 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -146,7 +146,7 @@
} // end namespace mlir
//===----------------------------------------------------------------------===//
-// ilist_traits for OperationInst
+// ilist_traits for BasicBlock
//===----------------------------------------------------------------------===//
namespace llvm {
diff --git a/include/mlir/IR/CFGFunction.h b/include/mlir/IR/CFGFunction.h
index 977217d..e01987d 100644
--- a/include/mlir/IR/CFGFunction.h
+++ b/include/mlir/IR/CFGFunction.h
@@ -28,6 +28,7 @@
class CFGFunction : public Function {
public:
CFGFunction(StringRef name, FunctionType *type);
+ ~CFGFunction();
//===--------------------------------------------------------------------===//
// BasicBlock list management
@@ -67,15 +68,15 @@
return const_cast<CFGFunction*>(this)->front();
}
+ //===--------------------------------------------------------------------===//
+ // Other
+ //===--------------------------------------------------------------------===//
+
/// getSublistAccess() - Returns pointer to member of block list
static BasicBlockListType CFGFunction::*getSublistAccess(BasicBlock*) {
return &CFGFunction::blocks;
}
- //===--------------------------------------------------------------------===//
- // Other
- //===--------------------------------------------------------------------===//
-
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Function *func) {
return func->getKind() == Kind::CFGFunc;
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index e413172..7e70c37 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -25,13 +25,15 @@
#define MLIR_IR_FUNCTION_H
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ilist.h"
namespace mlir {
class FunctionType;
class MLIRContext;
+class Module;
/// This is the base class for all of the MLIR function types.
-class Function {
+class Function : public llvm::ilist_node_with_parent<Function, Module> {
public:
enum class Kind { ExtFunc, CFGFunc, MLFunc };
@@ -44,6 +46,14 @@
FunctionType *getType() const { return type; }
MLIRContext *getContext() const;
+ Module *getModule() { return module; }
+ const Module *getModule() const { return module; }
+
+ /// Unlink this function from its module and delete it.
+ void eraseFromModule();
+
+ /// Delete this object.
+ void destroy();
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this fills in the string and return true,
@@ -59,10 +69,12 @@
private:
Kind kind;
+ Module *module = nullptr;
std::string name;
FunctionType *const type;
void operator=(const Function &) = delete;
+ friend struct llvm::ilist_traits<Function>;
};
/// An extfunc declaration is a declaration of a function signature that is
@@ -77,7 +89,30 @@
}
};
-
} // end namespace mlir
+//===----------------------------------------------------------------------===//
+// ilist_traits for Function
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+// Traits that let Module's llvm::iplist<Function> keep each Function's module pointer in sync.
+template <>
+struct ilist_traits<::mlir::Function>
+ : public ilist_alloc_traits<::mlir::Function> {
+ using Function = ::mlir::Function;
+ using function_iterator = simple_ilist<Function>::iterator;
+ /// Invoked when the list deletes a node; defers to Function::destroy().
+ static void deleteNode(Function *function) { function->destroy(); }
+ /// List membership callbacks (declared here, defined out of line).
+ void addNodeToList(Function *function);
+ void removeNodeFromList(Function *function);
+ void transferNodesFromList(ilist_traits<Function> &otherList,
+ function_iterator first, function_iterator last);
+
+private:
+ mlir::Module *getContainingModule();
+};
+} // end namespace llvm
+
#endif // MLIR_IR_FUNCTION_H
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index cbd9578..a625997 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -27,7 +27,6 @@
#include "mlir/IR/Operation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
-#include "llvm/ADT/ilist_node.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@@ -108,6 +107,11 @@
return getInstOperands()[idx];
}
+ /// Drops all operand uses held by this instruction. This is an essential
+ /// step in breaking cyclic references between instructions before they
+ /// are deleted.
+ void dropAllReferences();
+
protected:
Instruction(Kind kind) : kind(kind) {}
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index b0b9c96..8e45f8d 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -23,6 +23,7 @@
#define MLIR_IR_MODULE_H
#include "mlir/IR/Function.h"
+#include "llvm/ADT/ilist.h"
#include <vector>
namespace mlir {
@@ -35,11 +36,27 @@
MLIRContext *getContext() const { return context; }
- // FIXME: wrong representation and API.
- // TODO(someone): This should switch to llvm::iplist<Function>.
- // TODO(someone): we also need a symbol table for function names +
- // autorenaming like LLVM does.
- std::vector<Function*> functionList;
+ // TODO: We should have a symbol table for function names.
+
+ /// This is the list of functions in the module.
+ typedef llvm::iplist<Function> FunctionListType;
+ FunctionListType &getFunctions() { return functions; }
+ const FunctionListType &getFunctions() const { return functions; }
+
+ // Iteration over the functions in the module.
+ using iterator = FunctionListType::iterator;
+ using const_iterator = FunctionListType::const_iterator;
+ using reverse_iterator = FunctionListType::reverse_iterator;
+ using const_reverse_iterator = FunctionListType::const_reverse_iterator;
+
+ iterator begin() { return functions.begin(); }
+ iterator end() { return functions.end(); }
+ const_iterator begin() const { return functions.begin(); }
+ const_iterator end() const { return functions.end(); }
+ reverse_iterator rbegin() { return functions.rbegin(); }
+ reverse_iterator rend() { return functions.rend(); }
+ const_reverse_iterator rbegin() const { return functions.rbegin(); }
+ const_reverse_iterator rend() const { return functions.rend(); }
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this fills in the string and return true,
@@ -49,8 +66,14 @@
void print(raw_ostream &os) const;
void dump() const;
+ /// getSublistAccess() - Returns pointer to member of function list
+ static FunctionListType Module::*getSublistAccess(Function *) {
+ return &Module::functions;
+ }
+
private:
MLIRContext *context;
+ FunctionListType functions;
};
} // end namespace mlir