Add location specifier to MLIR Functions, and:
- Compress the identifier/kind of a Function into a single word.
- Eliminate otherFailure from verifier now that we always have a location
- Eliminate the error string from the verifier now that we always have
locations.
- Simplify the parser's handling of fn forward references, using the location
tracked by the function.
PiperOrigin-RevId: 211985101
diff --git a/include/mlir/IR/CFGFunction.h b/include/mlir/IR/CFGFunction.h
index 025e13a..16de0af 100644
--- a/include/mlir/IR/CFGFunction.h
+++ b/include/mlir/IR/CFGFunction.h
@@ -27,7 +27,7 @@
// blocks, each of which includes instructions.
class CFGFunction : public Function {
public:
- CFGFunction(StringRef name, FunctionType *type);
+ CFGFunction(Location *location, StringRef name, FunctionType *type);
~CFGFunction();
//===--------------------------------------------------------------------===//
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index f3b7aa0..6398cee 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -30,6 +30,7 @@
namespace mlir {
class FunctionType;
+class Location;
class MLIRContext;
class Module;
@@ -38,10 +39,13 @@
public:
enum class Kind { ExtFunc, CFGFunc, MLFunc };
- Kind getKind() const { return kind; }
+ Kind getKind() const { return (Kind)nameAndKind.getInt(); }
+
+ /// The source location the operation was defined or derived from.
+ Location *getLoc() const { return location; }
/// Return the name of this function, without the @.
- Identifier getName() const { return name; }
+ Identifier getName() const { return nameAndKind.getPointer(); }
/// Return the type of this function.
FunctionType *getType() const { return type; }
@@ -57,21 +61,42 @@
void destroy();
/// Perform (potentially expensive) checks of invariants, used to detect
- /// compiler bugs. On error, this fills in the string and return true,
- /// or aborts if the string was not provided.
- bool verify(std::string *errorResult = nullptr) const;
+ /// compiler bugs. On error, this reports the error through the MLIRContext
+ /// and returns true.
+ bool verify() const;
void print(raw_ostream &os) const;
void dump() const;
+ /// Emit an error about fatal conditions with this operation, reporting up to
+ /// any diagnostic handlers that may be listening. NOTE: This may terminate
+ /// the containing application, only use when the IR is in an inconsistent
+ /// state.
+ void emitError(const Twine &message) const;
+
+ /// Emit a warning about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitWarning(const Twine &message) const;
+
+ /// Emit a note about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitNote(const Twine &message) const;
+
protected:
- Function(StringRef name, FunctionType *type, Kind kind);
+ Function(Kind kind, Location *location, StringRef name, FunctionType *type);
~Function();
private:
- Kind kind;
+ /// The name of the function and the kind of function this is.
+ llvm::PointerIntPair<Identifier, 2, Kind> nameAndKind;
+
+ /// The module this function is embedded into.
Module *module = nullptr;
- Identifier name;
+
+ /// The source location the function was defined or derived from.
+ Location *location;
+
+ /// The type of the function.
FunctionType *const type;
void operator=(const Function &) = delete;
@@ -82,7 +107,7 @@
/// defined in some other module.
class ExtFunction : public Function {
public:
- ExtFunction(StringRef name, FunctionType *type);
+ ExtFunction(Location *location, StringRef name, FunctionType *type);
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Function *func) {
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index 03535ff..2d24c8e 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -151,8 +151,8 @@
Kind kind;
BasicBlock *block = nullptr;
- /// This holds information about the source location the operation was defined
- /// or derived from.
+ /// This holds information about the source location the instruction was
+ /// defined or derived from.
Location *location;
friend struct llvm::ilist_traits<OperationInst>;
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index f6dedb6..5f2497b 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -40,7 +40,8 @@
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
public:
/// Creates a new MLFunction with the specific type.
- static MLFunction *create(StringRef name, FunctionType *type);
+ static MLFunction *create(Location *location, StringRef name,
+ FunctionType *type);
/// Destroys this statement and its subclass data.
void destroy();
@@ -93,7 +94,7 @@
}
private:
- MLFunction(StringRef name, FunctionType *type);
+ MLFunction(Location *location, StringRef name, FunctionType *type);
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index 38c4ef8..9ea6d33 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -75,9 +75,9 @@
}
/// Perform (potentially expensive) checks of invariants, used to detect
- /// compiler bugs. On error, this fills in the string and return true,
- /// or aborts if the string was not provided.
- bool verify(std::string *errorResult = nullptr) const;
+ /// compiler bugs. On error, this reports the error through the MLIRContext
+ /// and returns true.
+ bool verify() const;
void print(raw_ostream &os) const;
void dump() const;
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index bfcfd6f..29494de 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
@@ -25,8 +26,10 @@
#include "llvm/ADT/StringRef.h"
using namespace mlir;
-Function::Function(StringRef name, FunctionType *type, Kind kind)
- : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
+Function::Function(Kind kind, Location *location, StringRef name,
+ FunctionType *type)
+ : nameAndKind(Identifier::get(name, type->getContext()), kind),
+ location(location), type(type) {}
Function::~Function() {
// Clean up function attributes referring to this function.
@@ -66,7 +69,7 @@
// Add this function to the symbol table of the module, uniquing the name if
// a conflict is detected.
- if (!module->symbolTable.insert({function->name, function}).second) {
+ if (!module->symbolTable.insert({function->getName(), function}).second) {
// If a conflict was detected, then the function will not have been added to
// the symbol table. Try suffixes until we get to a unique name that works.
SmallString<128> nameBuffer(function->getName().begin(),
@@ -79,8 +82,10 @@
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(module->uniquingCounter++);
- function->name = Identifier::get(nameBuffer, module->getContext());
- } while (!module->symbolTable.insert({function->name, function}).second);
+ function->nameAndKind.setPointer(
+ Identifier::get(nameBuffer, module->getContext()));
+ } while (
+ !module->symbolTable.insert({function->getName(), function}).second);
}
}
@@ -118,21 +123,41 @@
getModule()->getFunctions().erase(this);
}
+/// Emit a note about this instruction, reporting up to any diagnostic
+/// handlers that may be listening.
+void Function::emitNote(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Note);
+}
+
+/// Emit a warning about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+void Function::emitWarning(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Warning);
+}
+
+/// Emit an error about fatal conditions with this instruction, reporting up to
+/// any diagnostic handlers that may be listening. NOTE: This may terminate
+/// the containing application, only use when the IR is in an inconsistent
+/// state.
+void Function::emitError(const Twine &message) const {
+ getContext()->emitDiagnostic(getLoc(), message,
+ MLIRContext::DiagnosticKind::Error);
+}
//===----------------------------------------------------------------------===//
// ExtFunction implementation.
//===----------------------------------------------------------------------===//
-ExtFunction::ExtFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::ExtFunc) {
-}
+ExtFunction::ExtFunction(Location *location, StringRef name, FunctionType *type)
+ : Function(Kind::ExtFunc, location, name, type) {}
//===----------------------------------------------------------------------===//
// CFGFunction implementation.
//===----------------------------------------------------------------------===//
-CFGFunction::CFGFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::CFGFunc) {
-}
+CFGFunction::CFGFunction(Location *location, StringRef name, FunctionType *type)
+ : Function(Kind::CFGFunc, location, name, type) {}
CFGFunction::~CFGFunction() {
// Instructions may have cyclic references, which need to be dropped before we
@@ -150,13 +175,14 @@
//===----------------------------------------------------------------------===//
/// Create a new MLFunction with the specific fields.
-MLFunction *MLFunction::create(StringRef name, FunctionType *type) {
+MLFunction *MLFunction::create(Location *location, StringRef name,
+ FunctionType *type) {
const auto &argTypes = type->getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
// Initialize the MLFunction part of the function object.
- auto function = ::new (rawMem) MLFunction(name, type);
+ auto function = ::new (rawMem) MLFunction(location, name, type);
// Initialize the arguments.
auto arguments = function->getArgumentsInternal();
@@ -165,8 +191,9 @@
return function;
}
-MLFunction::MLFunction(StringRef name, FunctionType *type)
- : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
+MLFunction::MLFunction(Location *location, StringRef name, FunctionType *type)
+ : Function(Kind::MLFunc, location, name, type),
+ StmtBlock(StmtBlockKind::MLFunc) {}
MLFunction::~MLFunction() {
// Explicitly erase statements instead of relying of 'StmtBlock' destructor
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index a624979..33b0adf 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -53,33 +53,13 @@
///
class Verifier {
public:
- template <typename T>
- static void failure(const Twine &message, const T &value, raw_ostream &os) {
- // Print the error message and flush the stream in case printing the value
- // causes a crash.
- os << "MLIR verification failure: " + message + "\n";
- os.flush();
- value.print(os);
- }
-
- template <typename T>
- bool otherFailure(const Twine &message, const T &value) {
- // If the caller isn't trying to collect failure information, just print
- // the result and abort.
- if (!errorResult) {
- failure(message, value, llvm::errs());
- abort();
- }
-
- // Otherwise, emit the error into the string and return true.
- llvm::raw_string_ostream os(*errorResult);
- failure(message, value, os);
- os.flush();
+ bool failure(const Twine &message, const Operation &value) {
+ value.emitError(message);
return true;
}
- bool failure(const Twine &message, const Operation &value) {
- value.emitError(message);
+ bool failure(const Twine &message, const Function &fn) {
+ fn.emitError(message);
return true;
}
@@ -88,27 +68,28 @@
return true;
}
- bool failure(const Twine &message, const Function &fn) {
- return otherFailure(message, fn);
- }
-
bool failure(const Twine &message, const BasicBlock &bb) {
- return otherFailure(message, bb);
+ // Take the location information for the first instruction in the block.
+ if (!bb.empty())
+ return failure(message, static_cast<const Instruction &>(bb.front()));
+
+ // If the code is properly formed, there will be a terminator. Use its
+ // location.
+ if (auto *termInst = bb.getTerminator())
+ return failure(message, *termInst);
+
+ // Worst case, fall back to using the function's location.
+ return failure(message, fn);
}
bool verifyOperation(const Operation &op);
bool verifyAttribute(Attribute *attr, const Operation &op);
protected:
- explicit Verifier(std::string *errorResult, const Function &fn)
- : errorResult(errorResult), fn(fn),
- operationSet(OperationSet::get(fn.getContext())) {}
+ explicit Verifier(const Function &fn)
+ : fn(fn), operationSet(OperationSet::get(fn.getContext())) {}
private:
- /// If the verifier is returning errors back to a client, this is the error to
- /// fill in.
- std::string *errorResult;
-
/// The function being checked.
const Function &fn;
@@ -185,8 +166,7 @@
struct CFGFuncVerifier : public Verifier {
const CFGFunction &fn;
- CFGFuncVerifier(const CFGFunction &fn, std::string *errorResult)
- : Verifier(errorResult, fn), fn(fn) {}
+ CFGFuncVerifier(const CFGFunction &fn) : Verifier(fn), fn(fn) {}
bool verify();
bool verifyBlock(const BasicBlock &block);
@@ -354,8 +334,7 @@
const MLFunction &fn;
bool hadError = false;
- MLFuncVerifier(const MLFunction &fn, std::string *errorResult)
- : Verifier(errorResult, fn), fn(fn) {}
+ MLFuncVerifier(const MLFunction &fn) : Verifier(fn), fn(fn) {}
void visitOperationStmt(OperationStmt *opStmt) {
hadError |= verifyOperation(*opStmt);
@@ -487,33 +466,30 @@
//===----------------------------------------------------------------------===//
/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs. On error, this fills in the string and return true,
-/// or aborts if the string was not provided.
-bool Function::verify(std::string *errorResult) const {
+/// compiler bugs. On error, this reports the error through the MLIRContext and
+/// returns true.
+bool Function::verify() const {
switch (getKind()) {
case Kind::ExtFunc:
// No body, nothing can be wrong here.
return false;
case Kind::CFGFunc:
- return CFGFuncVerifier(*cast<CFGFunction>(this), errorResult).verify();
+ return CFGFuncVerifier(*cast<CFGFunction>(this)).verify();
case Kind::MLFunc:
- return MLFuncVerifier(*cast<MLFunction>(this), errorResult).verify();
+ return MLFuncVerifier(*cast<MLFunction>(this)).verify();
}
}
/// Perform (potentially expensive) checks of invariants, used to detect
-/// compiler bugs. On error, this fills in the string and return true,
-/// or aborts if the string was not provided.
-bool Module::verify(std::string *errorResult) const {
+/// compiler bugs. On error, this reports the error through the MLIRContext and
+/// returns true.
+bool Module::verify() const {
/// Check that each function is correct.
for (auto &fn : *this) {
- if (fn.verify(errorResult))
+ if (fn.verify())
return true;
}
- // Make sure the error string is empty on success.
- if (errorResult)
- errorResult->clear();
return false;
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 154a24c..bc6585c 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -68,9 +68,8 @@
llvm::StringMap<IntegerSet *> integerSetDefinitions;
// This keeps track of all forward references to functions along with the
- // temporary function used to represent them and the location of the first
- // reference.
- llvm::DenseMap<Identifier, std::pair<Function *, SMLoc>> functionForwardRefs;
+ // temporary function used to represent them.
+ llvm::DenseMap<Identifier, Function *> functionForwardRefs;
private:
ParserState(const ParserState &) = delete;
@@ -605,11 +604,9 @@
// If not, get or create a forward reference to one.
if (!function) {
auto &entry = state.functionForwardRefs[name];
- if (!entry.first) {
- entry.first = new ExtFunction(name, type);
- entry.second = nameLoc;
- }
- function = entry.first;
+ if (!entry)
+ entry = new ExtFunction(getEncodedSourceLocation(nameLoc), name, type);
+ function = entry;
}
if (function->getType() != type)
@@ -2771,7 +2768,7 @@
return ParseFailure;
// Okay, the external function definition was parsed correctly.
- auto *function = new ExtFunction(name, type);
+ auto *function = new ExtFunction(getEncodedSourceLocation(loc), name, type);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@@ -2796,7 +2793,7 @@
return ParseFailure;
// Okay, the CFG function signature was parsed correctly, create the function.
- auto *function = new CFGFunction(name, type);
+ auto *function = new CFGFunction(getEncodedSourceLocation(loc), name, type);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@@ -2823,7 +2820,8 @@
return ParseFailure;
// Okay, the ML function signature was parsed correctly, create the function.
- auto *function = MLFunction::create(name, type);
+ auto *function =
+ MLFunction::create(getEncodedSourceLocation(loc), name, type);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.
@@ -2907,11 +2905,13 @@
// Resolve the reference.
auto *resolvedFunction = getModule()->getNamedFunction(name);
- if (!resolvedFunction)
- return emitError(forwardRef.second.second,
- "reference to undefined function '" + name.str() + "'");
+ if (!resolvedFunction) {
+ forwardRef.second->emitError("reference to undefined function '" +
+ name.str() + "'");
+ return ParseFailure;
+ }
- remappingTable[builder.getFunctionAttr(forwardRef.second.first)] =
+ remappingTable[builder.getFunctionAttr(forwardRef.second)] =
builder.getFunctionAttr(resolvedFunction);
}
@@ -2951,7 +2951,7 @@
// Now that all references to the forward definition placeholders are
// resolved, we can deallocate the placeholders.
for (auto forwardRef : getState().functionForwardRefs)
- forwardRef.second.first->destroy();
+ forwardRef.second->destroy();
return ParseSuccess;
}
@@ -3018,19 +3018,8 @@
// Make sure the parse module has no other structural problems detected by the
// verifier.
- //
- // TODO(clattner): The verifier should always emit diagnostics when we have
- // more location information available. We shouldn't need this hook.
- std::string errorResult;
- module->verify(&errorResult);
-
- // We don't have location information for general verifier errors, so emit the
- // error with an unknown location.
- if (!errorResult.empty()) {
- context->emitDiagnostic(UnknownLoc::get(context), errorResult,
- MLIRContext::DiagnosticKind::Error);
+ if (module->verify())
return nullptr;
- }
return module.release();
}
diff --git a/lib/Transforms/ConvertToCFG.cpp b/lib/Transforms/ConvertToCFG.cpp
index 2f50713..97529ec 100644
--- a/lib/Transforms/ConvertToCFG.cpp
+++ b/lib/Transforms/ConvertToCFG.cpp
@@ -109,8 +109,8 @@
// Creates CFG function equivalent to the given ML function.
CFGFunction *ModuleConverter::convert(const MLFunction *mlFunc) {
// TODO: ensure that CFG function name is unique.
- auto *cfgFunc =
- new CFGFunction(mlFunc->getName().str() + "_cfg", mlFunc->getType());
+ auto *cfgFunc = new CFGFunction(
+ mlFunc->getLoc(), mlFunc->getName().str() + "_cfg", mlFunc->getType());
module->getFunctions().push_back(cfgFunc);
// Generates the body of the CFG function.
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 484b47d..ae5d3b2 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -103,14 +103,12 @@
// -----
-mlfunc @empty() {
- //expected-error@-3 {{ML function must end with return statement}}
+mlfunc @empty() { // expected-error {{ML function must end with return statement}}
}
// -----
-mlfunc @no_return() {
- // expected-error@-3 {{ML function must end with return statement}}
+mlfunc @no_return() { // expected-error {{ML function must end with return statement}}
"foo"() : () -> ()
}
@@ -297,7 +295,7 @@
// -----
-cfgfunc @bbargMismatch(i32, f32) { // expected-error @-2 {{first block of cfgfunc must have 2 arguments to match function signature}}
+cfgfunc @bbargMismatch(i32, f32) { // expected-error {{first block of cfgfunc must have 2 arguments to match function signature}}
bb42(%0: f32):
return
}
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 5bb86b4..0c434b1 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -190,20 +190,7 @@
delete pass;
// Verify that the result of the pass is still valid.
- std::string errorResult;
- module->verify(&errorResult);
-
- // We don't have location information for general verifier errors, so emit
- // the error with an unknown location.
- if (!errorResult.empty()) {
- context->emitDiagnostic(UnknownLoc::get(context), errorResult,
- MLIRContext::DiagnosticKind::Error);
-
- auto output = getOutputStream();
- module->print(output->os());
- output->keep();
- return OptFailure;
- }
+ module->verify();
}
// Print the output.