Implement a module-level symbol table for functions, enforcing uniqueness of
names across the module and auto-renaming conflicts.  Have the parser reject
malformed modules that have redefinitions.

PiperOrigin-RevId: 209227560
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index 7e70c37..15e08e3 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -24,6 +24,7 @@
 #ifndef MLIR_IR_FUNCTION_H
 #define MLIR_IR_FUNCTION_H
 
+#include "mlir/IR/Identifier.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ilist.h"
 
@@ -40,7 +41,7 @@
   Kind getKind() const { return kind; }
 
   /// Return the name of this function, without the @.
-  const std::string &getName() const { return name; }
+  Identifier getName() const { return name; }
 
   /// Return the type of this function.
   FunctionType *getType() const { return type; }
@@ -70,7 +71,7 @@
 private:
   Kind kind;
   Module *module = nullptr;
-  std::string name;
+  Identifier name;
   FunctionType *const type;
 
   void operator=(const Function &) = delete;
diff --git a/include/mlir/IR/Identifier.h b/include/mlir/IR/Identifier.h
index 6b9a782..70528e3 100644
--- a/include/mlir/IR/Identifier.h
+++ b/include/mlir/IR/Identifier.h
@@ -36,11 +36,17 @@
 public:
   /// Return an identifier for the specified string.
   static Identifier get(StringRef str, const MLIRContext *context);
+  Identifier(const Identifier &) = default;
+  Identifier &operator=(const Identifier &other) = default;
 
   /// Return a StringRef for the string.
-  StringRef str() const {
-    return StringRef(pointer, size());
-  }
+  StringRef ref() const { return StringRef(pointer, size()); }
+
+  /// Return an std::string.
+  std::string str() const { return ref().str(); }
+
+  /// Return a null terminated C string.
+  const char *c_str() const { return pointer; }
 
   /// Return a pointer to the start of the string data.
   const char *data() const {
@@ -53,12 +59,10 @@
   }
 
   /// Return true if this identifier is the specified string.
-  bool is(StringRef string) const {
-    return str().equals(string);
-  }
+  bool is(StringRef string) const { return ref().equals(string); }
 
-  Identifier(const Identifier&) = default;
-  Identifier &operator=(const Identifier &other) = default;
+  const char *begin() const { return pointer; }
+  const char *end() const { return pointer + size(); }
 
   void print(raw_ostream &os) const;
   void dump() const;
@@ -96,7 +100,7 @@
 
 // Make identifiers hashable.
 inline llvm::hash_code hash_value(Identifier arg) {
-  return llvm::hash_value(arg.str());
+  return llvm::hash_value(arg.ref());
 }
 
 } // end namespace mlir
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index ed12afc..f6dedb6 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -39,7 +39,7 @@
       public StmtBlock,
       private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
 public:
-  /// Creates a new MLFunction with the specific fields.
+  /// Creates a new MLFunction with the specific type.
   static MLFunction *create(StringRef name, FunctionType *type);
 
   /// Destroys this statement and its subclass data.
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index 8e45f8d..f6c543b 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -23,6 +23,7 @@
 #define MLIR_IR_MODULE_H
 
 #include "mlir/IR/Function.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/ilist.h"
 #include <vector>
 
@@ -36,8 +37,6 @@
 
   MLIRContext *getContext() const { return context; }
 
-  // TODO: We should have a symbol table for function names.
-
   /// This is the list of functions in the module.
   typedef llvm::iplist<Function> FunctionListType;
   FunctionListType &getFunctions() { return functions; }
@@ -58,6 +57,23 @@
   const_reverse_iterator rbegin() const { return functions.rbegin(); }
   const_reverse_iterator rend() const { return functions.rend(); }
 
+  // Interfaces for working with the symbol table.
+
+  /// Look up a function with the specified name, returning null if no such
+  /// name exists.
+  Function *getNamedFunction(StringRef name);
+  const Function *getNamedFunction(StringRef name) const {
+    return const_cast<Module *>(this)->getNamedFunction(name);
+  }
+
+  /// Look up a function with the specified name, returning null if no such
+  /// name exists.
+  Function *getNamedFunction(Identifier name);
+
+  const Function *getNamedFunction(Identifier name) const {
+    return const_cast<Module *>(this)->getNamedFunction(name);
+  }
+
   /// Perform (potentially expensive) checks of invariants, used to detect
-  /// compiler bugs.  On error, this fills in the string and return true,
+  /// compiler bugs.  On error, this fills in the string and returns true,
   /// or aborts if the string was not provided.
@@ -66,13 +82,23 @@
   void print(raw_ostream &os) const;
   void dump() const;
 
+private:
+  friend struct llvm::ilist_traits<Function>;
+
   /// getSublistAccess() - Returns pointer to member of function list
   static FunctionListType Module::*getSublistAccess(Function *) {
     return &Module::functions;
   }
 
-private:
   MLIRContext *context;
+
+  /// This is a mapping from a name to the function with that name.
+  llvm::DenseMap<Identifier, Function *> symbolTable;
+
+  /// This is used when name conflicts are detected.
+  unsigned uniquingCounter = 0;
+
+  /// This is the actual list of functions the module contains.
   FunctionListType functions;
 };
 } // end namespace mlir
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9e4dd65..896b76e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -909,7 +909,7 @@
   // Filter out any attributes that shouldn't be included.
   SmallVector<NamedAttribute, 8> filteredAttrs;
   for (auto attr : attrs) {
-    auto attrName = attr.first.str();
+    auto attrName = attr.first.ref();
     // Never print attributes that start with a colon.  These are internal
     // attributes that represent location or other internal metadata.
     if (attrName.startswith(":"))
@@ -946,7 +946,7 @@
 
   // Check to see if this is a known operation.  If so, use the registered
   // custom printer hook.
-  if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
+  if (auto *opInfo = state.operationSet->lookup(op->getName().ref())) {
     opInfo->printAssembly(op, this);
     return;
   }
@@ -957,7 +957,7 @@
 
 void FunctionPrinter::printDefaultOp(const Operation *op) {
   os << '"';
-  printEscapedString(op->getName().str(), os);
+  printEscapedString(op->getName().ref(), os);
   os << "\"(";
 
   interleaveComma(op->getOperands(),
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 1c89de8..df3d7c8 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -20,12 +20,12 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 using namespace mlir;
 
 Function::Function(StringRef name, FunctionType *type, Kind kind)
-  : kind(kind), name(name.str()), type(type) {
-}
+    : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
 
 MLIRContext *Function::getContext() const { return getType()->getContext(); }
 
@@ -52,16 +52,39 @@
 }
 
 /// This is a trait method invoked when a Function is added to a Module.  We
-/// keep the module pointer up to date.
+/// keep the module pointer and module symbol table up to date.
 void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
   assert(!function->getModule() && "already in a module!");
-  function->module = getContainingModule();
+  Module *parentModule = getContainingModule();
+  function->module = parentModule;
+
+  // Record the function in the module's symbol table.  Normally the name is
+  // unused and the insertion succeeds right away.
+  auto &table = parentModule->symbolTable;
+  if (table.insert({function->name, function}).second)
+    return;
+
+  // The name was already present, so nothing was inserted.  Retry with "_<N>"
+  // suffixes, taking N from a module-wide counter so that repeated conflicts
+  // do not re-probe suffixes that were already tried (avoids N^2 behavior).
+  SmallString<128> candidate(function->getName().begin(),
+                             function->getName().end());
+  const unsigned baseLength = candidate.size();
+  do {
+    candidate.resize(baseLength);
+    candidate += '_';
+    candidate += std::to_string(parentModule->uniquingCounter++);
+    function->name = Identifier::get(candidate, parentModule->getContext());
+  } while (!table.insert({function->name, function}).second);
 }
 
 /// This is a trait method invoked when a Function is removed from a Module.
-/// We keep the module pointer up to date.
+/// We keep the module pointer and module symbol table up to date.
 void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
   assert(function->module && "not already in a module!");
+
+  // Remove the symbol table entry.
+  function->module->symbolTable.erase(function->getName());
   function->module = nullptr;
 }
 
@@ -76,9 +99,11 @@
   if (curParent == otherList.getContainingModule())
     return;
 
-  // Update the 'module' member of each function.
-  for (; first != last; ++first)
-    first->module = curParent;
+  // Update the 'module' member and symbol table records for each function.
+  for (; first != last; ++first) {
+    removeNodeFromList(&*first);
+    addNodeToList(&*first);
+  }
 }
 
 /// Unlink this function from its Module and delete it.
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index 99e5e32..5a33d26 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -19,3 +19,16 @@
 using namespace mlir;
 
 Module::Module(MLIRContext *context) : context(context) {}
+
+/// Look up a function by string name, returning null if no such name exists.
+/// The name is first converted to an Identifier in this module's context.
+Function *Module::getNamedFunction(StringRef name) {
+  return getNamedFunction(Identifier::get(name, context));
+}
+
+/// Return the function registered in the symbol table under the given
+/// identifier, or null if the module has no function with that name.
+Function *Module::getNamedFunction(Identifier name) {
+  // DenseMap::lookup yields a default-constructed (null) pointer on a miss.
+  return symbolTable.lookup(name);
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index f263804..ae87ad2 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1945,8 +1945,6 @@
                            "'");
   }
 
-  getModule()->getFunctions().push_back(function);
-
   return finalizeFunction(function, braceLoc);
 }
 
@@ -2122,8 +2120,6 @@
   if (parseStmtBlock(function))
     return ParseFailure;
 
-  getModule()->getFunctions().push_back(function);
-
   return finalizeFunction(function, braceLoc);
 }
 
@@ -2579,6 +2575,7 @@
 ///
 ParseResult ModuleParser::parseExtFunc() {
   consumeToken(Token::kw_extfunc);
+  auto loc = getToken().getLoc();
 
   StringRef name;
   FunctionType *type = nullptr;
@@ -2586,7 +2583,14 @@
     return ParseFailure;
 
   // Okay, the external function definition was parsed correctly.
-  getModule()->getFunctions().push_back(new ExtFunction(name, type));
+  auto *function = new ExtFunction(name, type);
+  getModule()->getFunctions().push_back(function);
+
+  // Verify no name collision / redefinition.
+  if (function->getName().ref() != name)
+    return emitError(loc,
+                     "redefinition of function named '" + name.str() + "'");
+
   return ParseSuccess;
 }
 
@@ -2596,6 +2600,7 @@
 ///
 ParseResult ModuleParser::parseCFGFunc() {
   consumeToken(Token::kw_cfgfunc);
+  auto loc = getToken().getLoc();
 
   StringRef name;
   FunctionType *type = nullptr;
@@ -2603,7 +2608,13 @@
     return ParseFailure;
 
   // Okay, the CFG function signature was parsed correctly, create the function.
-  auto function = new CFGFunction(name, type);
+  auto *function = new CFGFunction(name, type);
+  getModule()->getFunctions().push_back(function);
+
+  // Verify no name collision / redefinition.
+  if (function->getName().ref() != name)
+    return emitError(loc,
+                     "redefinition of function named '" + name.str() + "'");
 
   return CFGFunctionParser(getState(), function).parseFunctionBody();
 }
@@ -2624,7 +2635,13 @@
     return ParseFailure;
 
   // Okay, the ML function signature was parsed correctly, create the function.
-  auto function = MLFunction::create(name, type);
+  auto *function = MLFunction::create(name, type);
+  getModule()->getFunctions().push_back(function);
+
+  // Verify no name collision / redefinition.
+  if (function->getName().ref() != name)
+    return emitError(loc,
+                     "redefinition of function named '" + name.str() + "'");
 
   // Create the parser.
   auto parser = MLFunctionParser(getState(), function);
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 3863ea0..1bac43f 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -110,7 +110,7 @@
 CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
-  // TODO: ensure that CFG function name is unique.
+  // The module symbol table uniques the name on insertion if it conflicts.
   auto *cfgFunc =
-      new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
+      new CFGFunction(mlFunc->getName().str() + "_cfg", mlFunc->getType());
   module->getFunctions().push_back(cfgFunc);
 
   // Generates the body of the CFG function.
diff --git a/test/IR/affine-map.mlir b/test/IR/affine-map.mlir
index 4a8ca6b..42c9608 100644
--- a/test/IR/affine-map.mlir
+++ b/test/IR/affine-map.mlir
@@ -300,5 +300,5 @@
 // CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
 extfunc @f45(memref<100x100x100xi8, #map45>)
 
-// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
-extfunc @f45(memref<100x100x100xi8, #map46>)
+// CHECK: extfunc @f46(memref<100x100x100xi8, #map{{[0-9]+}}>)
+extfunc @f46(memref<100x100x100xi8, #map46>)
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 0a544c0..9edc152 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -397,3 +397,8 @@
     // expected-error@-1 {{'return' op must be the last statement in the ML function}}
   }
 }
+
+// -----
+
+extfunc @redef()
+extfunc @redef()  // expected-error {{redefinition of function named 'redef'}}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 5a36299..a5eb9e6 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -62,11 +62,11 @@
 
 // Test memref inline affine map compositions.
 
-// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}>)
-extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
+// CHECK: extfunc @memrefs3(memref<2x4x8xi8, #map{{[0-9]+}}>)
+extfunc @memrefs3(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
 
-// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
-extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
+// CHECK: extfunc @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
+extfunc @memrefs33(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
 
 // CHECK: extfunc @functions((memref<1x?x4x?x?xi32, #map0>, memref<i8, #map1>) -> (), () -> ())
 extfunc @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())