Implement a proper function list in module, which auto-maintain the parent
pointer, and ensure that functions are deleted when the module is destroyed.

This exposed the fact that MLFunction had no dtor, and that the dtor in
CFGFunction was broken with cyclic references.  Fix both of these problems.

PiperOrigin-RevId: 206051666
diff --git a/include/mlir/IR/BasicBlock.h b/include/mlir/IR/BasicBlock.h
index d59a83b..c26203c 100644
--- a/include/mlir/IR/BasicBlock.h
+++ b/include/mlir/IR/BasicBlock.h
@@ -146,7 +146,7 @@
 } // end namespace mlir
 
 //===----------------------------------------------------------------------===//
-// ilist_traits for OperationInst
+// ilist_traits for BasicBlock
 //===----------------------------------------------------------------------===//
 
 namespace llvm {
diff --git a/include/mlir/IR/CFGFunction.h b/include/mlir/IR/CFGFunction.h
index 977217d..e01987d 100644
--- a/include/mlir/IR/CFGFunction.h
+++ b/include/mlir/IR/CFGFunction.h
@@ -28,6 +28,7 @@
 class CFGFunction : public Function {
 public:
   CFGFunction(StringRef name, FunctionType *type);
+  ~CFGFunction();
 
   //===--------------------------------------------------------------------===//
   // BasicBlock list management
@@ -67,15 +68,15 @@
     return const_cast<CFGFunction*>(this)->front();
   }
 
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
   /// getSublistAccess() - Returns pointer to member of block list
   static BasicBlockListType CFGFunction::*getSublistAccess(BasicBlock*) {
     return &CFGFunction::blocks;
   }
 
-  //===--------------------------------------------------------------------===//
-  // Other
-  //===--------------------------------------------------------------------===//
-
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Function *func) {
     return func->getKind() == Kind::CFGFunc;
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index e413172..7e70c37 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -25,13 +25,15 @@
 #define MLIR_IR_FUNCTION_H
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/ilist.h"
 
 namespace mlir {
 class FunctionType;
 class MLIRContext;
+class Module;
 
 /// This is the base class for all of the MLIR function types.
-class Function {
+class Function : public llvm::ilist_node_with_parent<Function, Module> {
 public:
   enum class Kind { ExtFunc, CFGFunc, MLFunc };
 
@@ -44,6 +46,14 @@
   FunctionType *getType() const { return type; }
 
   MLIRContext *getContext() const;
+  Module *getModule() { return module; }
+  const Module *getModule() const { return module; }
+
+  /// Unlink this instruction from its module and delete it.
+  void eraseFromModule();
+
+  /// Delete this object.
+  void destroy();
 
   /// Perform (potentially expensive) checks of invariants, used to detect
   /// compiler bugs.  On error, this fills in the string and return true,
@@ -59,10 +69,12 @@
 
 private:
   Kind kind;
+  Module *module = nullptr;
   std::string name;
   FunctionType *const type;
 
   void operator=(const Function &) = delete;
+  friend struct llvm::ilist_traits<Function>;
 };
 
 /// An extfunc declaration is a declaration of a function signature that is
@@ -77,7 +89,30 @@
   }
 };
 
-
 } // end namespace mlir
 
+//===----------------------------------------------------------------------===//
+// ilist_traits for Function
+//===----------------------------------------------------------------------===//
+
+namespace llvm {
+
+template <>
+struct ilist_traits<::mlir::Function>
+    : public ilist_alloc_traits<::mlir::Function> {
+  using Function = ::mlir::Function;
+  using function_iterator = simple_ilist<Function>::iterator;
+
+  static void deleteNode(Function *inst) { inst->destroy(); }
+
+  void addNodeToList(Function *function);
+  void removeNodeFromList(Function *function);
+  void transferNodesFromList(ilist_traits<Function> &otherList,
+                             function_iterator first, function_iterator last);
+
+private:
+  mlir::Module *getContainingModule();
+};
+} // end namespace llvm
+
 #endif  // MLIR_IR_FUNCTION_H
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index cbd9578..a625997 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -27,7 +27,6 @@
 #include "mlir/IR/Operation.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ilist.h"
-#include "llvm/ADT/ilist_node.h"
 #include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
@@ -108,6 +107,11 @@
     return getInstOperands()[idx];
   }
 
+  /// This drops all operand uses from this instruction, which is an essential
+  /// step in breaking cyclic dependences between references when they are to
+  /// be deleted.
+  void dropAllReferences();
+
 protected:
   Instruction(Kind kind) : kind(kind) {}
 
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index b0b9c96..8e45f8d 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -23,6 +23,7 @@
 #define MLIR_IR_MODULE_H
 
 #include "mlir/IR/Function.h"
+#include "llvm/ADT/ilist.h"
 #include <vector>
 
 namespace mlir {
@@ -35,11 +36,27 @@
 
   MLIRContext *getContext() const { return context; }
 
-  // FIXME: wrong representation and API.
-  // TODO(someone): This should switch to llvm::iplist<Function>.
-  // TODO(someone): we also need a symbol table for function names +
-  // autorenaming like LLVM does.
-  std::vector<Function*> functionList;
+  // TODO: We should have a symbol table for function names.
+
+  /// This is the list of functions in the module.
+  typedef llvm::iplist<Function> FunctionListType;
+  FunctionListType &getFunctions() { return functions; }
+  const FunctionListType &getFunctions() const { return functions; }
+
+  // Iteration over the functions in the module.
+  using iterator = FunctionListType::iterator;
+  using const_iterator = FunctionListType::const_iterator;
+  using reverse_iterator = FunctionListType::reverse_iterator;
+  using const_reverse_iterator = FunctionListType::const_reverse_iterator;
+
+  iterator begin() { return functions.begin(); }
+  iterator end() { return functions.end(); }
+  const_iterator begin() const { return functions.begin(); }
+  const_iterator end() const { return functions.end(); }
+  reverse_iterator rbegin() { return functions.rbegin(); }
+  reverse_iterator rend() { return functions.rend(); }
+  const_reverse_iterator rbegin() const { return functions.rbegin(); }
+  const_reverse_iterator rend() const { return functions.rend(); }
 
   /// Perform (potentially expensive) checks of invariants, used to detect
   /// compiler bugs.  On error, this fills in the string and return true,
@@ -49,8 +66,14 @@
   void print(raw_ostream &os) const;
   void dump() const;
 
+  /// getSublistAccess() - Returns pointer to member of function list
+  static FunctionListType Module::*getSublistAccess(Function *) {
+    return &Module::functions;
+  }
+
 private:
   MLIRContext *context;
+  FunctionListType functions;
 };
 } // end namespace mlir
 
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index e23964d..4fd7e61 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -157,8 +157,8 @@
 
 // Initializes module state, populating affine map state.
 void ModuleState::initialize(const Module *module) {
-  for (auto fn : module->functionList) {
-    visitFunction(fn);
+  for (auto &fn : *module) {
+    visitFunction(&fn);
   }
 }
 
@@ -236,8 +236,8 @@
     map->print(os);
     os << '\n';
   }
-  for (auto *fn : module->functionList)
-    print(fn);
+  for (auto const &fn : *module)
+    print(&fn);
 }
 
 void ModulePrinter::printAttribute(const Attribute *attr) {
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 72ec443..8476b06 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/StringRef.h"
 using namespace mlir;
@@ -27,6 +28,64 @@
 
 MLIRContext *Function::getContext() const { return getType()->getContext(); }
 
+/// Delete this object.
+void Function::destroy() {
+  switch (getKind()) {
+  case Kind::ExtFunc:
+    delete cast<ExtFunction>(this);
+    break;
+  case Kind::MLFunc:
+    delete cast<MLFunction>(this);
+    break;
+  case Kind::CFGFunc:
+    delete cast<CFGFunction>(this);
+    break;
+  }
+}
+
+Module *llvm::ilist_traits<Function>::getContainingModule() {
+  size_t Offset(
+      size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
+  iplist<Function> *Anchor(static_cast<iplist<Function> *>(this));
+  return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
+}
+
+/// This is a trait method invoked when a Function is added to a Module.  We
+/// keep the module pointer up to date.
+void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
+  assert(!function->getModule() && "already in a module!");
+  function->module = getContainingModule();
+}
+
+/// This is a trait method invoked when a Function is removed from a Module.
+/// We keep the module pointer up to date.
+void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
+  assert(function->module && "not already in a module!");
+  function->module = nullptr;
+}
+
+/// This is a trait method invoked when an instruction is moved from one block
+/// to another.  We keep the block pointer up to date.
+void llvm::ilist_traits<Function>::transferNodesFromList(
+    ilist_traits<Function> &otherList, function_iterator first,
+    function_iterator last) {
+  // If we are transferring functions within the same module, the Module
+  // pointer doesn't need to be updated.
+  Module *curParent = getContainingModule();
+  if (curParent == otherList.getContainingModule())
+    return;
+
+  // Update the 'module' member of each function.
+  for (; first != last; ++first)
+    first->module = curParent;
+}
+
+/// Unlink this function from its Module and delete it.
+void Function::eraseFromModule() {
+  assert(getModule() && "Function has no parent");
+  getModule()->getFunctions().erase(this);
+}
+
 //===----------------------------------------------------------------------===//
 // ExtFunction implementation.
 //===----------------------------------------------------------------------===//
@@ -43,9 +102,23 @@
   : Function(name, type, Kind::CFGFunc) {
 }
 
+CFGFunction::~CFGFunction() {
+  // Instructions may have cyclic references, which need to be dropped before we
+  // can start deleting them.
+  for (auto &bb : *this) {
+    for (auto &inst : bb)
+      inst.dropAllReferences();
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // MLFunction implementation.
 //===----------------------------------------------------------------------===//
 
 MLFunction::MLFunction(StringRef name, FunctionType *type)
     : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
+
+MLFunction::~MLFunction() {
+  // TODO: When move SSA stuff is supported.
+  // dropAllReferences();
+}
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index 8cbdabc..a1cdcb3 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -99,6 +99,14 @@
   }
 }
 
+/// This drops all operand uses from this instruction, which is an essential
+/// step in breaking cyclic dependences between references when they are to
+/// be deleted.
+void Instruction::dropAllReferences() {
+  for (auto &op : getInstOperands())
+    op.drop();
+}
+
 //===----------------------------------------------------------------------===//
 // OperationInst
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 96f7f73..9e52556 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -330,8 +330,8 @@
 bool Module::verify(std::string *errorResult) const {
 
   /// Check that each function is correct.
-  for (auto fn : functionList) {
-    if (fn->verify(errorResult))
+  for (auto &fn : *this) {
+    if (fn.verify(errorResult))
       return true;
   }
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 55fd260..362f868 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -917,11 +917,10 @@
 
   StringRef sRef = getTokenSpelling();
   for (auto entry : dimsAndSymbols) {
-    if (entry.first != sRef)
-      continue;
-
-    consumeToken(Token::bare_identifier);
-    return entry.second;
+    if (entry.first == sRef) {
+      consumeToken(Token::bare_identifier);
+      return entry.second;
+    }
   }
 
   return (emitError("use of undeclared identifier"), nullptr);
@@ -1861,7 +1860,7 @@
                            "'");
   }
 
-  getModule()->functionList.push_back(function);
+  getModule()->getFunctions().push_back(function);
 
   return finalizeFunction(function, braceLoc);
 }
@@ -2053,7 +2052,7 @@
       parseToken(Token::r_brace, "expected '}' to end mlfunc"))
     return ParseFailure;
 
-  getModule()->functionList.push_back(function);
+  getModule()->getFunctions().push_back(function);
 
   return finalizeFunction(function, braceLoc);
 }
@@ -2356,7 +2355,7 @@
     return ParseFailure;
 
   // Okay, the external function definition was parsed correctly.
-  getModule()->functionList.push_back(new ExtFunction(name, type));
+  getModule()->getFunctions().push_back(new ExtFunction(name, type));
   return ParseSuccess;
 }
 
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 0a98e2b..9487cf9 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -96,8 +96,8 @@
 }
 
 void ModuleConverter::convertMLFunctions() {
-  for (Function *fn : module->functionList) {
-    if (auto mlFunc = dyn_cast<MLFunction>(fn))
+  for (Function &fn : *module) {
+    if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
       generatedFuncs[mlFunc] = convert(mlFunc);
   }
 }
@@ -105,22 +105,22 @@
 // Creates CFG function equivalent to the given ML function.
 CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
   // TODO: ensure that CFG function name is unique.
-  CFGFunction *cfgFunc =
+  auto *cfgFunc =
       new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
-  module->functionList.push_back(cfgFunc);
+  module->getFunctions().push_back(cfgFunc);
 
   // Generates the body of the CFG function.
   return FunctionConverter(cfgFunc).convert(mlFunc);
 }
 
 void ModuleConverter::replaceReferences() {
-  for (Function *fn : module->functionList) {
-    switch (fn->getKind()) {
+  for (Function &fn : *module) {
+    switch (fn.getKind()) {
     case Function::Kind::CFGFunc:
-      replaceReferences(cast<CFGFunction>(fn));
+      replaceReferences(&cast<CFGFunction>(fn));
       break;
     case Function::Kind::MLFunc:
-      replaceReferences(cast<MLFunction>(fn));
+      replaceReferences(&cast<MLFunction>(fn));
       break;
     case Function::Kind::ExtFunc:
       // nothing to do for external functions
@@ -139,20 +139,14 @@
 
 // Removes all ML functions from the module.
 void ModuleConverter::removeMLFunctions() {
-  std::vector<Function *> &fnList = module->functionList;
-
-  // Delete ML functions and its data.
-  for (auto &fn : fnList) {
-    if (auto mlFunc = dyn_cast<MLFunction>(fn)) {
-      delete mlFunc;
-      fn = nullptr;
-    }
+  // Delete ML functions from the module.
+  for (auto it = module->begin(), e = module->end(); it != e;) {
+    // Manipulate iterator carefully to avoid deleting a function we're pointing
+    // at.
+    Function &fn = *it++;
+    if (auto mlFunc = dyn_cast<MLFunction>(&fn))
+      mlFunc->eraseFromModule();
   }
-
-  // Remove ML functions from the function list.
-  fnList.erase(std::remove_if(fnList.begin(), fnList.end(),
-                              [](Function *fn) { return !fn; }),
-               fnList.end());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index af0c7ca..337f558 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -44,8 +44,8 @@
 /// Unrolls all the innermost loops of this Module.
 bool MLFunctionPass::runOnModule(Module *m) {
   bool changed = false;
-  for (auto fn : m->functionList) {
-    if (auto *mlFunc = dyn_cast<MLFunction>(fn))
+  for (auto &fn : *m) {
+    if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
       changed |= runOnMLFunction(mlFunc);
   }
   return changed;