Add a location specifier to MLIR Functions, and:
- Compress the identifier/kind of a Function into a single word.
- Add emitNote/emitWarning/emitError helpers to Function that report at the
function's location.
- Eliminate otherFailure from the verifier now that we always have a location.
- Eliminate the error string from the verifier now that we always have
locations.
- Simplify the parser's handling of function forward references by using the
location tracked by the placeholder function instead of a separate SMLoc.
PiperOrigin-RevId: 211985101
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index bfcfd6f..29494de 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
@@ -25,8 +26,10 @@
#include "llvm/ADT/StringRef.h"
using namespace mlir;
-Function::Function(StringRef name, FunctionType *type, Kind kind)
- : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
+Function::Function(Kind kind, Location *location, StringRef name,
+ FunctionType *type)
+ : nameAndKind(Identifier::get(name, type->getContext()), kind),
+ location(location), type(type) {}
Function::~Function() {
// Clean up function attributes referring to this function.
@@ -66,7 +69,7 @@
// Add this function to the symbol table of the module, uniquing the name if
// a conflict is detected.
- if (!module->symbolTable.insert({function->name, function}).second) {
+ if (!module->symbolTable.insert({function->getName(), function}).second) {
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
SmallString<128> nameBuffer(function->getName().begin(),
@@ -79,8 +82,10 @@
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(module->uniquingCounter++);
- function->name = Identifier::get(nameBuffer, module->getContext());
- } while (!module->symbolTable.insert({function->name, function}).second);
+ function->nameAndKind.setPointer(
+ Identifier::get(nameBuffer, module->getContext()));
+ } while (
+ !module->symbolTable.insert({function->getName(), function}).second);
}
}
@@ -118,21 +123,41 @@
getModule()->getFunctions().erase(this);
}
+/// Emit a note about this function, attached to the function's location and
+/// reported up to any diagnostic handlers that may be listening.
+void Function::emitNote(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Note);
+}
+
+/// Emit a warning about this function, attached to the function's location
+/// and reported up to any diagnostic handlers that may be listening.
+void Function::emitWarning(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Warning);
+}
+
+/// Emit an error about fatal conditions with this function, reporting up to
+/// any diagnostic handlers that may be listening. NOTE: This may terminate
+/// the containing application, only use when the IR is in an inconsistent
+/// state.
+void Function::emitError(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Error);
+}
//===----------------------------------------------------------------------===//
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
-ExtFunction::ExtFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::ExtFunc) {
-}
+ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type)
+ : Function(Kind::ExtFunc, location, name, type) {}
//===----------------------------------------------------------------------===//
// CFGFunction implementation.
//===----------------------------------------------------------------------===//
-CFGFunction::CFGFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::CFGFunc) {
-}
+CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type)
+ : Function(Kind::CFGFunc, location, name, type) {}
CFGFunction::~CFGFunction() {
// Instructions may have cyclic references, which need to be dropped before we
@@ -150,13 +175,14 @@
//===----------------------------------------------------------------------===//
/// Create a new MLFunction with the specific fields.
-MLFunction *MLFunction::create(StringRef name, FunctionType *type) {
+MLFunction *MLFunction::create(Location *location, StringRef name,
+ FunctionType *type) {
const auto &argTypes = type->getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
// Initialize the MLFunction part of the function object.
- auto function = ::new (rawMem) MLFunction(name, type);
+ auto function = ::new (rawMem) MLFunction(location, name, type);
// Initialize the arguments.
auto arguments = function->getArgumentsInternal();
@@ -165,8 +191,9 @@
return function;
}
-MLFunction::MLFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
+MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type)
+ : Function(Kind::MLFunc, location, name, type),
+ StmtBlock(StmtBlockKind::MLFunc) {}
MLFunction::~MLFunction() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index a624979..33b0adf 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -53,33 +53,13 @@
///
class Verifier {
public:
- template <typename T>
- static void failure(const Twine &message, const T &value, raw_ostream &os) {
- // Print the error message and flush the stream in case printing the value
- // causes a crash.
- os << "MLIR verification failure: " + message + "\n";
- os.flush();
- value.print(os);
- }
-
- template <typename T>
- bool otherFailure(const Twine &message, const T &value) {
- // If the caller isn't trying to collect failure information, just print
- // the result and abort.
- if (!errorResult) {
- failure(message, value, llvm::errs());
- abort();
- }
-
- // Otherwise, emit the error into the string and return true.
- llvm::raw_string_ostream os(*errorResult);
- failure(message, value, os);
- os.flush();
+ bool failure(const Twine &message, const Operation &value) {
+ value.emitError(message);
return true;
}
- bool failure(const Twine &message, const Operation &value) {
- value.emitError(message);
+ bool failure(const Twine &message, const Function &fn) {
+ fn.emitError(message);
return true;
}
@@ -88,27 +68,28 @@
return true;
}
- bool failure(const Twine &message, const Function &fn) {
- return otherFailure(message, fn);
- }
-
bool failure(const Twine &message, const BasicBlock &bb) {
- return otherFailure(message, bb);
+ // Take the location information for the first instruction in the block.
+ if (!bb.empty())
+ return failure(message, static_cast<const Instruction &>(bb.front()));
+
+ // If the code is properly formed, there will be a terminator. Use its
+ // location.
+ if (auto *termInst = bb.getTerminator())
+ return failure(message, *termInst);
+
+ // Worst case, fall back to using the function's location.
+ return failure(message, fn);
}
bool verifyOperation(const Operation &op);
bool verifyAttribute(Attribute *attr, const Operation &op);
protected:
- explicit Verifier(std::string *errorResult, const Function &fn)
- : errorResult(errorResult), fn(fn),
- operationSet(OperationSet::get(fn.getContext())) {}
+ explicit Verifier(const Function &fn)
+ : fn(fn), operationSet(OperationSet::get(fn.getContext())) {}
private:
- /// If the verifier is returning errors back to a client, this is the error to
- /// fill in.
- std::string *errorResult;
-
/// The function being checked.
const Function &fn;
@@ -185,8 +166,7 @@
struct CFGFuncVerifier : public Verifier {
const CFGFunction &fn;
- CFGFuncVerifier(const CFGFunction &fn, std::string *errorResult)
- : Verifier(errorResult, fn), fn(fn) {}
+ CFGFuncVerifier(const CFGFunction &fn) : Verifier(fn), fn(fn) {}
bool verify();
bool verifyBlock(const BasicBlock &block);
@@ -354,8 +334,7 @@
const MLFunction &fn;
bool hadError = false;
- MLFuncVerifier(const MLFunction &fn, std::string *errorResult)
- : Verifier(errorResult, fn), fn(fn) {}
+ MLFuncVerifier(const MLFunction &fn) : Verifier(fn), fn(fn) {}
void visitOperationStmt(OperationStmt *opStmt) {
hadError |= verifyOperation(*opStmt);
@@ -487,33 +466,30 @@
//===----------------------------------------------------------------------===//
/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs. On error, this fills in the string and return true,
-/// or aborts if the string was not provided.
-bool Function::verify(std::string *errorResult) const {
+/// compiler bugs. On error, this reports the error through the MLIRContext and
+/// returns true.
+bool Function::verify() const {
switch (getKind()) {
case Kind::ExtFunc:
// No body, nothing can be wrong here.
return false;
case Kind::CFGFunc:
- return CFGFuncVerifier(*cast<CFGFunction>(this), errorResult).verify();
+ return CFGFuncVerifier(*cast<CFGFunction>(this)).verify();
case Kind::MLFunc:
- return MLFuncVerifier(*cast<MLFunction>(this), errorResult).verify();
+ return MLFuncVerifier(*cast<MLFunction>(this)).verify();
}
}
/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs. On error, this fills in the string and return true,
-/// or aborts if the string was not provided.
-bool Module::verify(std::string *errorResult) const {
+/// compiler bugs. On error, this reports the error through the MLIRContext and
+/// returns true.
+bool Module::verify() const {
/// Check that each function is correct.
for (auto &fn : *this) {
- if (fn.verify(errorResult))
+ if (fn.verify())
return true;
}
- // Make sure the error string is empty on success.
- if (errorResult)
- errorResult->clear();
return false;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 154a24c..bc6585c 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -68,9 +68,8 @@
llvm::StringMap<IntegerSet *> integerSetDefinitions;
// This keeps track of all forward references to functions along with the
- // temporary function used to represent them and the location of the first
- // reference.
- llvm::DenseMap<Identifier, std::pair<Function *, SMLoc>> functionForwardRefs;
+ // temporary function used to represent them.
+ llvm::DenseMap<Identifier, Function *> functionForwardRefs;
private:
ParserState(const ParserState &) = delete;
@@ -605,11 +604,9 @@
// If not, get or create a forward reference to one.
if (!function) {
auto &entry = state.functionForwardRefs[name];
- if (!entry.first) {
- entry.first = new ExtFunction(name, type);
- entry.second = nameLoc;
- }
- function = entry.first;
+ if (!entry)
+ entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type);
+ function = entry;
}
if (function->getType() != type)
@@ -2771,7 +2768,7 @@
return ParseFailure;
// Okay, the external function definition was parsed correctly.
- auto *function = new ExtFunction(name, type);
+ auto *function = new ExtFunction(getEncodedSourceLocation(loc), name, type);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@@ -2796,7 +2793,7 @@
return ParseFailure;
// Okay, the CFG function signature was parsed correctly, create the function.
- auto *function = new CFGFunction(name, type);
+ auto *function = new CFGFunction(getEncodedSourceLocation(loc), name, type);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@@ -2823,7 +2820,8 @@
return ParseFailure;
// Okay, the ML function signature was parsed correctly, create the function.
- auto *function = MLFunction::create(name, type);
+ auto *function =
+ MLFunction::create(getEncodedSourceLocation(loc), name, type);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@@ -2907,11 +2905,13 @@
// Resolve the reference.
auto *resolvedFunction = getModule()->getNamedFunction(name);
- if (!resolvedFunction)
- return emitError(forwardRef.second.second,
- "reference to undefined function '" + name.str() + "'");
+ if (!resolvedFunction) {
+ forwardRef.second->emitError("reference to undefined function '" +
+ name.str() + "'");
+ return ParseFailure;
+ }
- remappingTable[builder.getFunctionAttr(forwardRef.second.first)] =
+ remappingTable[builder.getFunctionAttr(forwardRef.second)] =
builder.getFunctionAttr(resolvedFunction);
}
@@ -2951,7 +2951,7 @@
// Now that all references to the forward definition placeholders are
// resolved, we can deallocate the placeholders.
for (auto forwardRef : getState().functionForwardRefs)
- forwardRef.second.first->destroy();
+ forwardRef.second->destroy();
return ParseSuccess;
}
@@ -3018,19 +3018,8 @@
// Make sure the parse module has no other structural problems detected by the
// verifier.
- //
- // TODO(clattner): The verifier should always emit diagnostics when we have
- // more location information available. We shouldn't need this hook.
- std::string errorResult;
- module->verify(&errorResult);
-
- // We don't have location information for general verifier errors, so emit the
- // error with an unknown location.
- if (!errorResult.empty()) {
- context->emitDiagnostic(UnknownLoc::get(context), errorResult,
- MLIRContext::DiagnosticKind::Error);
+ if (module->verify())
return nullptr;
- }
return module.release();
}
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 2f50713..97529ec 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -109,8 +109,8 @@
// Creates CFG function equivalent to the given ML function.
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
- auto *cfgFunc =
- new CFGFunction(mlFunc->getName().str() + "_cfg", mlFunc->getType());
+ auto *cfgFunc = new CFGFunction(
+ mlFunc->getLoc(), mlFunc->getName().str() + "_cfg", mlFunc->getType());
module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.