Implement a module-level symbol table for functions, enforcing uniqueness of
names across the module and auto-renaming conflicts. Have the parser reject
malformed modules that have redefinitions.
PiperOrigin-RevId: 209227560
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index 7e70c37..15e08e3 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -24,6 +24,7 @@
#ifndef MLIR_IR_FUNCTION_H
#define MLIR_IR_FUNCTION_H
+#include "mlir/IR/Identifier.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/ilist.h"
@@ -40,7 +41,7 @@
Kind getKind() const { return kind; }
/// Return the name of this function, without the @.
- const std::string &getName() const { return name; }
+ Identifier getName() const { return name; }
/// Return the type of this function.
FunctionType *getType() const { return type; }
@@ -70,7 +71,7 @@
private:
Kind kind;
Module *module = nullptr;
- std::string name;
+ Identifier name;
FunctionType *const type;
void operator=(const Function &) = delete;
diff --git a/include/mlir/IR/Identifier.h b/include/mlir/IR/Identifier.h
index 6b9a782..70528e3 100644
--- a/include/mlir/IR/Identifier.h
+++ b/include/mlir/IR/Identifier.h
@@ -36,11 +36,17 @@
public:
/// Return an identifier for the specified string.
static Identifier get(StringRef str, const MLIRContext *context);
+ Identifier(const Identifier &) = default;
+ Identifier &operator=(const Identifier &other) = default;
/// Return a StringRef for the string.
- StringRef str() const {
- return StringRef(pointer, size());
- }
+ StringRef ref() const { return StringRef(pointer, size()); }
+
+ /// Return a copy of the identifier's characters as an std::string.
+ std::string str() const { return ref().str(); }
+
+ /// Return the identifier as a null-terminated C string.
+ const char *c_str() const { return pointer; }
/// Return a pointer to the start of the string data.
const char *data() const {
@@ -53,12 +59,10 @@
}
/// Return true if this identifier is the specified string.
- bool is(StringRef string) const {
- return str().equals(string);
- }
+ bool is(StringRef string) const { return ref().equals(string); }
- Identifier(const Identifier&) = default;
- Identifier &operator=(const Identifier &other) = default;
+ const char *begin() const { return pointer; }
+ const char *end() const { return pointer + size(); }
void print(raw_ostream &os) const;
void dump() const;
@@ -96,7 +100,7 @@
// Make identifiers hashable.
inline llvm::hash_code hash_value(Identifier arg) {
- return llvm::hash_value(arg.str());
+ return llvm::hash_value(arg.ref());
}
} // end namespace mlir
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index ed12afc..f6dedb6 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -39,7 +39,7 @@
public StmtBlock,
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
public:
- /// Creates a new MLFunction with the specific fields.
+ /// Creates a new MLFunction with the specified name and type.
static MLFunction *create(StringRef name, FunctionType *type);
/// Destroys this statement and its subclass data.
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index 8e45f8d..f6c543b 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -23,6 +23,7 @@
#define MLIR_IR_MODULE_H
#include "mlir/IR/Function.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ilist.h"
#include <vector>
@@ -36,8 +37,6 @@
MLIRContext *getContext() const { return context; }
- // TODO: We should have a symbol table for function names.
-
/// This is the list of functions in the module.
typedef llvm::iplist<Function> FunctionListType;
FunctionListType &getFunctions() { return functions; }
@@ -58,6 +57,23 @@
const_reverse_iterator rbegin() const { return functions.rbegin(); }
const_reverse_iterator rend() const { return functions.rend(); }
+ // Interfaces for working with the symbol table.
+
+ /// Look up a function with the specified name, returning null if no such
+ /// name exists.
+ Function *getNamedFunction(StringRef name);
+ const Function *getNamedFunction(StringRef name) const {
+ return const_cast<Module *>(this)->getNamedFunction(name);
+ }
+
+ /// Look up a function by its uniqued Identifier name, returning null if no
+ /// such name exists.
+ Function *getNamedFunction(Identifier name);
+
+ const Function *getNamedFunction(Identifier name) const {
+ return const_cast<Module *>(this)->getNamedFunction(name);
+ }
+
/// Perform (potentially expensive) checks of invariants, used to detect
/// compiler bugs. On error, this fills in the string and return true,
/// or aborts if the string was not provided.
@@ -66,13 +82,23 @@
void print(raw_ostream &os) const;
void dump() const;
+private:
+ friend struct llvm::ilist_traits<Function>;
+
/// getSublistAccess() - Returns pointer to member of function list
static FunctionListType Module::*getSublistAccess(Function *) {
return &Module::functions;
}
-private:
MLIRContext *context;
+
+ /// This is a mapping from a name to the function with that name.
+ llvm::DenseMap<Identifier, Function *> symbolTable;
+
+ /// Counter used to auto-rename a function whose name conflicts with another.
+ unsigned uniquingCounter = 0;
+
+ /// This is the actual list of functions the module contains.
FunctionListType functions;
};
} // end namespace mlir