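Implement a module-level symbol table for functions. The table enforces that
function names are unique within a module: a function added under a name that
is already taken is automatically renamed with a numeric suffix. Have the
parser reject malformed modules that contain function redefinitions by
reporting an error.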
PiperOrigin-RevId: 209227560
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index 7e70c37..15e08e3 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -24,6 +24,7 @@
#ifndef MLIR_IR_FUNCTION_H
#define MLIR_IR_FUNCTION_H
+#include "mlir/IR/Identifier.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
@@ -40,7 +41,7 @@
Kind getKind() const { return kind; }
/// Return the name of this function, without the @.
- const std::string &getName() const { return name; }
+ Identifier getName() const { return name; }
/// Return the type of this function.
FunctionType *getType() const { return type; }
@@ -70,7 +71,7 @@
private:
Kind kind;
Module *module = nullptr;
- std::string name;
+ Identifier name;
FunctionType *const type;
void operator=(const Function &) = delete;
diff --git a/include/mlir/IR/Identifier.h b/include/mlir/IR/Identifier.h
index 6b9a782..70528e3 100644
--- a/include/mlir/IR/Identifier.h
+++ b/include/mlir/IR/Identifier.h
@@ -36,11 +36,17 @@
public:
/// Return an identifier for the specified string.
static Identifier get(StringRef str, const MLIRContext *context);
+ Identifier(const Identifier &) = default;
+ Identifier &operator=(const Identifier &other) = default;
/// Return a StringRef for the string.
- StringRef str() const {
- return StringRef(pointer, size());
- }
+ StringRef ref() const { return StringRef(pointer, size()); }
+
+ /// Return an std::string.
+ std::string str() const { return ref().str(); }
+
+ /// Return a null terminated C string.
+ const char *c_str() const { return pointer; }
/// Return a pointer to the start of the string data.
const char *data() const {
@@ -53,12 +59,10 @@
}
/// Return true if this identifier is the specified string.
- bool is(StringRef string) const {
- return str().equals(string);
- }
+ bool is(StringRef string) const { return ref().equals(string); }
- Identifier(const Identifier&) = default;
- Identifier &operator=(const Identifier &other) = default;
+ const char *begin() const { return pointer; }
+ const char *end() const { return pointer + size(); }
void print(raw_ostream &os) const;
void dump() const;
@@ -96,7 +100,7 @@
// Make identifiers hashable.
inline llvm::hash_code hash_value(Identifier arg) {
- return llvm::hash_value(arg.str());
+ return llvm::hash_value(arg.ref());
}
} // end namespace mlir
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index ed12afc..f6dedb6 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -39,7 +39,7 @@
public StmtBlock,
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
public:
- /// Creates a new MLFunction with the specific fields.
+ /// Creates a new MLFunction with the specific type.
static MLFunction *create(StringRef name, FunctionType *type);
/// Destroys this statement and its subclass data.
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index 8e45f8d..f6c543b 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -23,6 +23,7 @@
#define MLIR_IR_MODULE_H
#include "mlir/IR/Function.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ilist.h"
#include <vector>
@@ -36,8 +37,6 @@
MLIRContext *getContext() const { return context; }
- // TODO: We should have a symbol table for function names.
-
/// This is the list of functions in the module.
typedef llvm::iplist<Function> FunctionListType;
FunctionListType &getFunctions() { return functions; }
@@ -58,6 +57,23 @@
const_reverse_iterator rbegin() const { return functions.rbegin(); }
const_reverse_iterator rend() const { return functions.rend(); }
+ // Interfaces for working with the symbol table.
+
+ /// Look up a function with the specified name, returning null if no such
+ /// name exists.
+ Function *getNamedFunction(StringRef name);
+ const Function *getNamedFunction(StringRef name) const {
+ return const_cast<Module *>(this)->getNamedFunction(name);
+ }
+
+ /// Look up a function with the specified name, returning null if no such
+ /// name exists.
+ Function *getNamedFunction(Identifier name);
+
+ const Function *getNamedFunction(Identifier name) const {
+ return const_cast<Module *>(this)->getNamedFunction(name);
+ }
+
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this fills in the string and return true,
/// or aborts if the string was not provided.
@@ -66,13 +82,23 @@
void print(raw_ostream &os) const;
void dump() const;
+private:
+ friend struct llvm::ilist_traits<Function>;
+
/// getSublistAccess() - Returns pointer to member of function list
static FunctionListType Module::*getSublistAccess(Function *) {
return &Module::functions;
}
-private:
MLIRContext *context;
+
+ /// This is a mapping from a name to the function with that name.
+ llvm::DenseMap<Identifier, Function *> symbolTable;
+
+ /// This is used when name conflicts are detected.
+ unsigned uniquingCounter = 0;
+
+ /// This is the actual list of functions the module contains.
FunctionListType functions;
};
} // end namespace mlir
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9e4dd65..896b76e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -909,7 +909,7 @@
// Filter out any attributes that shouldn't be included.
SmallVector<NamedAttribute, 8> filteredAttrs;
for (auto attr : attrs) {
- auto attrName = attr.first.str();
+ auto attrName = attr.first.ref();
// Never print attributes that start with a colon. These are internal
// attributes that represent location or other internal metadata.
if (attrName.startswith(":"))
@@ -946,7 +946,7 @@
// Check to see if this is a known operation. If so, use the registered
// custom printer hook.
- if (auto opInfo = state.operationSet->lookup(op->getName().str())) {
+ if (auto *opInfo = state.operationSet->lookup(op->getName().ref())) {
opInfo->printAssembly(op, this);
return;
}
@@ -957,7 +957,7 @@
void FunctionPrinter::printDefaultOp(const Operation *op) {
os << '"';
- printEscapedString(op->getName().str(), os);
+ printEscapedString(op->getName().ref(), os);
os << "\"(";
interleaveComma(op->getOperands(),
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 1c89de8..df3d7c8 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -20,12 +20,12 @@
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
+#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
using namespace mlir;
Function::Function(StringRef name, FunctionType *type, Kind kind)
- : kind(kind), name(name.str()), type(type) {
-}
+ : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
MLIRContext *Function::getContext() const { return getType()->getContext(); }
@@ -52,16 +52,39 @@
}
/// This is a trait method invoked when a Function is added to a Module. We
-/// keep the module pointer up to date.
+/// keep the module pointer and module symbol table up to date.
void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
assert(!function->getModule() && "already in a module!");
- function->module = getContainingModule();
+ auto *module = getContainingModule();
+ function->module = module;
+
+ // Add this function to the symbol table of the module, uniquing the name if
+ // a conflict is detected.
+ if (!module->symbolTable.insert({function->name, function}).second) {
+ // If a conflict was detected, then the function will not have been added to
+ // the symbol table. Try suffixes until we get to a unique name that works.
+ SmallString<128> nameBuffer(function->getName().begin(),
+ function->getName().end());
+ unsigned originalLength = nameBuffer.size();
+
+ // Iteratively try suffixes until we find one that isn't used. We use a
+ // module level uniquing counter to avoid N^2 behavior.
+ do {
+ nameBuffer.resize(originalLength);
+ nameBuffer += '_';
+ nameBuffer += std::to_string(module->uniquingCounter++);
+ function->name = Identifier::get(nameBuffer, module->getContext());
+ } while (!module->symbolTable.insert({function->name, function}).second);
+ }
}
/// This is a trait method invoked when a Function is removed from a Module.
-/// We keep the module pointer up to date.
+/// We keep the module pointer and module symbol table up to date.
void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
assert(function->module && "not already in a module!");
+
+ // Remove the symbol table entry.
+ function->module->symbolTable.erase(function->getName());
function->module = nullptr;
}
@@ -76,9 +99,11 @@
if (curParent == otherList.getContainingModule())
return;
- // Update the 'module' member of each function.
- for (; first != last; ++first)
- first->module = curParent;
+ // Update the 'module' member and symbol table records for each function.
+ for (; first != last; ++first) {
+ removeNodeFromList(&*first);
+ addNodeToList(&*first);
+ }
}
/// Unlink this function from its Module and delete it.
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index 99e5e32..5a33d26 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -19,3 +19,16 @@
using namespace mlir;
Module::Module(MLIRContext *context) : context(context) {}
+
+/// Look up a function with the specified name, returning null if no such
+/// name exists.
+Function *Module::getNamedFunction(StringRef name) {
+ return getNamedFunction(Identifier::get(name, context));
+}
+
+/// Look up a function with the specified name, returning null if no such
+/// name exists.
+Function *Module::getNamedFunction(Identifier name) {
+ auto it = symbolTable.find(name);
+ return it != symbolTable.end() ? it->second : nullptr;
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index f263804..ae87ad2 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1945,8 +1945,6 @@
"'");
}
- getModule()->getFunctions().push_back(function);
-
return finalizeFunction(function, braceLoc);
}
@@ -2122,8 +2120,6 @@
if (parseStmtBlock(function))
return ParseFailure;
- getModule()->getFunctions().push_back(function);
-
return finalizeFunction(function, braceLoc);
}
@@ -2579,6 +2575,7 @@
///
ParseResult ModuleParser::parseExtFunc() {
consumeToken(Token::kw_extfunc);
+ auto loc = getToken().getLoc();
StringRef name;
FunctionType *type = nullptr;
@@ -2586,7 +2583,14 @@
return ParseFailure;
// Okay, the external function definition was parsed correctly.
- getModule()->getFunctions().push_back(new ExtFunction(name, type));
+ auto *function = new ExtFunction(name, type);
+ getModule()->getFunctions().push_back(function);
+
+ // Verify no name collision / redefinition.
+ if (function->getName().ref() != name)
+ return emitError(loc,
+ "redefinition of function named '" + name.str() + "'");
+
return ParseSuccess;
}
@@ -2596,6 +2600,7 @@
///
ParseResult ModuleParser::parseCFGFunc() {
consumeToken(Token::kw_cfgfunc);
+ auto loc = getToken().getLoc();
StringRef name;
FunctionType *type = nullptr;
@@ -2603,7 +2608,13 @@
return ParseFailure;
// Okay, the CFG function signature was parsed correctly, create the function.
- auto function = new CFGFunction(name, type);
+ auto *function = new CFGFunction(name, type);
+ getModule()->getFunctions().push_back(function);
+
+ // Verify no name collision / redefinition.
+ if (function->getName().ref() != name)
+ return emitError(loc,
+ "redefinition of function named '" + name.str() + "'");
return CFGFunctionParser(getState(), function).parseFunctionBody();
}
@@ -2624,7 +2635,13 @@
return ParseFailure;
// Okay, the ML function signature was parsed correctly, create the function.
- auto function = MLFunction::create(name, type);
+ auto *function = MLFunction::create(name, type);
+ getModule()->getFunctions().push_back(function);
+
+ // Verify no name collision / redefinition.
+ if (function->getName().ref() != name)
+ return emitError(loc,
+ "redefinition of function named '" + name.str() + "'");
// Create the parser.
auto parser = MLFunctionParser(getState(), function);
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 3863ea0..1bac43f 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -110,7 +110,7 @@
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
auto *cfgFunc =
- new CFGFunction(mlFunc->getName() + "_cfg", mlFunc->getType());
+ new CFGFunction(mlFunc->getName().str() + "_cfg", mlFunc->getType());
module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.
diff --git a/test/IR/affine-map.mlir b/test/IR/affine-map.mlir
index 4a8ca6b..42c9608 100644
--- a/test/IR/affine-map.mlir
+++ b/test/IR/affine-map.mlir
@@ -300,5 +300,5 @@
// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
extfunc @f45(memref<100x100x100xi8, #map45>)
-// CHECK: extfunc @f45(memref<100x100x100xi8, #map{{[0-9]+}}>)
-extfunc @f45(memref<100x100x100xi8, #map46>)
+// CHECK: extfunc @f46(memref<100x100x100xi8, #map{{[0-9]+}}>)
+extfunc @f46(memref<100x100x100xi8, #map46>)
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 0a544c0..9edc152 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -397,3 +397,8 @@
// expected-error@-1 {{'return' op must be the last statement in the ML function}}
}
}
+
+// -----
+
+extfunc @redef()
+extfunc @redef() // expected-error {{redefinition of function named 'redef'}}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 5a36299..a5eb9e6 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -62,11 +62,11 @@
// Test memref inline affine map compositions.
-// CHECK: extfunc @memrefs2(memref<2x4x8xi8, #map{{[0-9]+}}>)
-extfunc @memrefs2(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
+// CHECK: extfunc @memrefs3(memref<2x4x8xi8, #map{{[0-9]+}}>)
+extfunc @memrefs3(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2)>)
-// CHECK: extfunc @memrefs23(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
-extfunc @memrefs23(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
+// CHECK: extfunc @memrefs33(memref<2x4x8xi8, #map{{[0-9]+}}, #map{{[0-9]+}}, 1>)
+extfunc @memrefs33(memref<2x4x8xi8, (d0, d1, d2) -> (d0, d1, d2), (d0, d1, d2) -> (d1, d0, d2), 1>)
// CHECK: extfunc @functions((memref<1x?x4x?x?xi32, #map0>, memref<i8, #map1>) -> (), () -> ())
extfunc @functions((memref<1x?x4x?x?xi32, #map0, 0>, memref<i8, #map1, 0>) -> (), ()->())