Implement a module-level symbol table for functions, enforcing uniqueness of
names across the module and automatically renaming a function whose name
conflicts with one already in the module.  Have the parser reject malformed
modules that contain function redefinitions.

PiperOrigin-RevId: 209227560
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9e4dd65..896b76e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -909,7 +909,7 @@
   // Filter out any attributes that shouldn't be included.
   SmallVector<NamedAttribute, 8> filteredAttrs;
   for (auto attr : attrs) {
-    auto attrName = attr.first.str();
+    auto attrName = attr.first.ref();
     // Never print attributes that start with a colon.  These are internal
     // attributes that represent location or other internal metadata.
     if (attrName.startswith(":"))
@@ -946,7 +946,7 @@
 
   // Check to see if this is a known operation.  If so, use the registered
   // custom printer hook.
-  if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
+  if (auto *opInfo = state.operationSet->lookup(op->getName().ref())) {
     opInfo->printAssembly(op, this);
     return;
   }
@@ -957,7 +957,7 @@
 
 void FunctionPrinter::printDefaultOp(const Operation *op) {
   os << '"';
-  printEscapedString(op->getName().str(), os);
+  printEscapedString(op->getName().ref(), os);
   os << "\"(";
 
   interleaveComma(op->getOperands(),
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 1c89de8..df3d7c8 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -20,12 +20,12 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 using namespace mlir;
 
 Function::Function(StringRef name, FunctionType *type, Kind kind)
-  : kind(kind), name(name.str()), type(type) {
-}
+    : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
 
 MLIRContext *Function::getContext() const { return getType()->getContext(); }
 
@@ -52,16 +52,39 @@
 }
 
 /// This is a trait method invoked when a Function is added to a Module.  We
-/// keep the module pointer up to date.
+/// keep the module pointer and symbol table up to date, renaming on conflict.
 void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
   assert(!function->getModule() && "already in a module!");
-  function->module = getContainingModule();
+  auto *module = getContainingModule();
+  function->module = module;
+
+  // Add this function to the symbol table of the module, uniquing the name if
+  // a conflict is detected.
+  if (!module->symbolTable.insert({function->name, function}).second) {
+    // If a conflict was detected, then the function will not have been added to
+    // the symbol table.  Try suffixes until we get to a unique name that works.
+    SmallString<128> nameBuffer(function->getName().begin(),
+                                function->getName().end());
+    unsigned originalLength = nameBuffer.size();
+
+    // Iteratively try suffixes until we find one that isn't used.  We use a
+    // module level uniquing counter to avoid N^2 behavior.
+    do {
+      nameBuffer.resize(originalLength);
+      nameBuffer += '_';
+      nameBuffer += std::to_string(module->uniquingCounter++);
+      function->name = Identifier::get(nameBuffer, module->getContext());
+    } while (!module->symbolTable.insert({function->name, function}).second);
+  }
 }
 
 /// This is a trait method invoked when a Function is removed from a Module.
-/// We keep the module pointer up to date.
+/// We keep the module pointer and module symbol table up to date.
 void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
   assert(function->module && "not already in a module!");
+
+  // Remove the symbol table entry before the module pointer is cleared.
+  function->module->symbolTable.erase(function->getName());
   function->module = nullptr;
 }
 
@@ -76,9 +99,11 @@
   if (curParent == otherList.getContainingModule())
     return;
 
-  // Update the 'module' member of each function.
-  for (; first != last; ++first)
-    first->module = curParent;
+  // Update the 'module' member and symbol table records for each function.
+  for (; first != last; ++first) {
+    removeNodeFromList(&*first);
+    addNodeToList(&*first);
+  }
 }
 
 /// Unlink this function from its Module and delete it.
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index 99e5e32..5a33d26 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -19,3 +19,16 @@
 using namespace mlir;
 
 Module::Module(MLIRContext *context) : context(context) {}
+
+/// Look up a function by its string name, which is first converted to an
+/// Identifier in this module's context.  Returns null if no such name exists.
+Function *Module::getNamedFunction(StringRef name) {
+  return getNamedFunction(Identifier::get(name, context));
+}
+
+/// Look up a function by Identifier in the module's symbol table, returning
+/// null if the table has no entry for that name.
+Function *Module::getNamedFunction(Identifier name) {
+  auto it = symbolTable.find(name);
+  return it != symbolTable.end() ? it->second : nullptr;
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index f263804..ae87ad2 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1945,8 +1945,6 @@
                            "'");
   }
 
-  getModule()->getFunctions().push_back(function);
-
   return finalizeFunction(function, braceLoc);
 }
 
@@ -2122,8 +2120,6 @@
   if (parseStmtBlock(function))
     return ParseFailure;
 
-  getModule()->getFunctions().push_back(function);
-
   return finalizeFunction(function, braceLoc);
 }
 
@@ -2579,6 +2575,7 @@
 ///
 ParseResult ModuleParser::parseExtFunc() {
   consumeToken(Token::kw_extfunc);
+  auto loc = getToken().getLoc();
 
   StringRef name;
   FunctionType *type = nullptr;
@@ -2586,7 +2583,14 @@
     return ParseFailure;
 
   // Okay, the external function definition was parsed correctly.
-  getModule()->getFunctions().push_back(new ExtFunction(name, type));
+  auto *function = new ExtFunction(name, type);
+  getModule()->getFunctions().push_back(function);
+
+  // Verify no name collision / redefinition.
+  if (function->getName().ref() != name)
+    return emitError(loc,
+                     "redefinition of function named '" + name.str() + "'");
+
   return ParseSuccess;
 }
 
@@ -2596,6 +2600,7 @@
 ///
 ParseResult ModuleParser::parseCFGFunc() {
   consumeToken(Token::kw_cfgfunc);
+  auto loc = getToken().getLoc();
 
   StringRef name;
   FunctionType *type = nullptr;
@@ -2603,7 +2608,13 @@
     return ParseFailure;
 
   // Okay, the CFG function signature was parsed correctly, create the function.
-  auto function = new CFGFunction(name, type);
+  auto *function = new CFGFunction(name, type);
+  getModule()->getFunctions().push_back(function);
+
+  // Verify no name collision / redefinition.
+  if (function->getName().ref() != name)
+    return emitError(loc,
+                     "redefinition of function named '" + name.str() + "'");
 
   return CFGFunctionParser(getState(), function).parseFunctionBody();
 }
@@ -2624,7 +2635,13 @@
     return ParseFailure;
 
   // Okay, the ML function signature was parsed correctly, create the function.
-  auto function = MLFunction::create(name, type);
+  auto *function = MLFunction::create(name, type);
+  getModule()->getFunctions().push_back(function);
+
+  // Verify no name collision / redefinition.
+  if (function->getName().ref() != name)
+    return emitError(loc,
+                     "redefinition of function named '" + name.str() + "'");
 
   // Create the parser.
   auto parser = MLFunctionParser(getState(), function);
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 3863ea0..1bac43f 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -110,7 +110,7 @@
 CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
   // TODO: ensure that CFG function name is unique.
   auto *cfgFunc =
-      new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
+      new CFGFunction(mlFunc->getName().str() + "_cfg", mlFunc->getType());
   module->getFunctions().push_back(cfgFunc);
 
   // Generates the body of the CFG function.