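Implement a module-level symbol table for functions. It enforces that function
names are unique within a module: a function added under a name that is already
taken is automatically renamed by appending a numeric suffix. The parser now
rejects malformed modules that contain function redefinitions.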
PiperOrigin-RevId: 209227560
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9e4dd65..896b76e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -909,7 +909,7 @@
// Filter out any attributes that shouldn't be included.
SmallVector<NamedAttribute, 8> filteredAttrs;
for (auto attr : attrs) {
- auto attrName = attr.first.str();
+ auto attrName = attr.first.ref();
// Never print attributes that start with a colon. These are internal
// attributes that represent location or other internal metadata.
if (attrName.startswith(":"))
@@ -946,7 +946,7 @@
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
- if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
+ if (auto *opInfo = state.operationSet->lookup(op->getName().ref())) {
opInfo->printAssembly(op, this);
return;
}
@@ -957,7 +957,7 @@
void FunctionPrinter::printDefaultOp(const Operation *op) {
os << '"';
- printEscapedString(op->getName().str(), os);
+ printEscapedString(op->getName().ref(), os);
os << "\"(";
interleaveComma(op->getOperands(),
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 1c89de8..df3d7c8 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -20,12 +20,12 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
using namespace mlir;
Function::Function(StringRef name, FunctionType *type, Kind kind)
- : kind(kind), name(name.str()), type(type) {
-}
+ : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
MLIRContext *Function::getContext() const { return getType()->getContext(); }
@@ -52,16 +52,39 @@
}
/// This is a trait method invoked when a Function is added to a Module. We
-/// keep the module pointer up to date.
+/// keep the module pointer and module symbol table up to date.
void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
assert(!function->getModule() && "already in a module!");
- function->module = getContainingModule();
+ auto *module = getContainingModule();
+ function->module = module;
+
+ // Add this function to the symbol table of the module, uniquing the name if
+ // a conflict is detected.
+ if (!module->symbolTable.insert({function->name, function}).second) {
+ // If a conflict was detected, then the function will not have been added to
+ // the symbol table. Try suffixes until we get to a unique name that works.
+ SmallString<128> nameBuffer(function->getName().begin(),
+ function->getName().end());
+ unsigned originalLength = nameBuffer.size();
+
+ // Iteratively try suffixes until we find one that isn't used. We use a
+ // module level uniquing counter to avoid N^2 behavior.
+ do {
+ nameBuffer.resize(originalLength);
+ nameBuffer += '_';
+ nameBuffer += std::to_string(module->uniquingCounter++);
+ function->name = Identifier::get(nameBuffer, module->getContext());
+ } while (!module->symbolTable.insert({function->name, function}).second);
+ }
}
/// This is a trait method invoked when a Function is removed from a Module.
-/// We keep the module pointer up to date.
+/// We keep the module pointer and module symbol table up to date.
void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
assert(function->module && "not already in a module!");
+
+ // Remove the symbol table entry.
+ function->module->symbolTable.erase(function->getName());
function->module = nullptr;
}
@@ -76,9 +99,11 @@
if (curParent == otherList.getContainingModule())
return;
- // Update the 'module' member of each function.
- for (; first != last; ++first)
- first->module = curParent;
+ // Update the 'module' member and symbol table records for each function.
+ for (; first != last; ++first) {
+ removeNodeFromList(&*first);
+ addNodeToList(&*first);
+ }
}
/// Unlink this function from its Module and delete it.
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index 99e5e32..5a33d26 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -19,3 +19,16 @@
using namespace mlir;
Module::Module(MLIRContext *context) : context(context) {}
+
+/// Look up a function with the specified name, returning null if no such
+/// name exists.
+Function *Module::getNamedFunction(StringRef name) {
+ return getNamedFunction(Identifier::get(name, context));
+}
+
+/// Look up a function with the specified name, returning null if no such
+/// name exists.
+Function *Module::getNamedFunction(Identifier name) {
+ auto it = symbolTable.find(name);
+ return it != symbolTable.end() ? it->second : nullptr;
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index f263804..ae87ad2 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1945,8 +1945,6 @@
"'");
}
- getModule()->getFunctions().push_back(function);
-
return finalizeFunction(function, braceLoc);
}
@@ -2122,8 +2120,6 @@
if (parseStmtBlock(function))
return ParseFailure;
- getModule()->getFunctions().push_back(function);
-
return finalizeFunction(function, braceLoc);
}
@@ -2579,6 +2575,7 @@
///
ParseResult ModuleParser::parseExtFunc() {
consumeToken(Token::kw_extfunc);
+ auto loc = getToken().getLoc();
StringRef name;
FunctionType *type = nullptr;
@@ -2586,7 +2583,14 @@
return ParseFailure;
// Okay, the external function definition was parsed correctly.
- getModule()->getFunctions().push_back(new ExtFunction(name, type));
+ auto *function = new ExtFunction(name, type);
+ getModule()->getFunctions().push_back(function);
+
+ // Verify no name collision / redefinition.
+ if (function->getName().ref() != name)
+ return emitError(loc,
+ "redefinition of function named '" + name.str() + "'");
+
return ParseSuccess;
}
@@ -2596,6 +2600,7 @@
///
ParseResult ModuleParser::parseCFGFunc() {
consumeToken(Token::kw_cfgfunc);
+ auto loc = getToken().getLoc();
StringRef name;
FunctionType *type = nullptr;
@@ -2603,7 +2608,13 @@
return ParseFailure;
// Okay, the CFG function signature was parsed correctly, create the function.
- auto function = new CFGFunction(name, type);
+ auto *function = new CFGFunction(name, type);
+ getModule()->getFunctions().push_back(function);
+
+ // Verify no name collision / redefinition.
+ if (function->getName().ref() != name)
+ return emitError(loc,
+ "redefinition of function named '" + name.str() + "'");
return CFGFunctionParser(getState(), function).parseFunctionBody();
}
@@ -2624,7 +2635,13 @@
return ParseFailure;
// Okay, the ML function signature was parsed correctly, create the function.
- auto function = MLFunction::create(name, type);
+ auto *function = MLFunction::create(name, type);
+ getModule()->getFunctions().push_back(function);
+
+ // Verify no name collision / redefinition.
+ if (function->getName().ref() != name)
+ return emitError(loc,
+ "redefinition of function named '" + name.str() + "'");
// Create the parser.
auto parser = MLFunctionParser(getState(), function);
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 3863ea0..1bac43f 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -110,7 +110,7 @@
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
auto *cfgFunc =
- new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
+ new CFGFunction(mlFunc->getName().str() + "_cfg", mlFunc->getType());
module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.