Enhance MLIRContext and operations with the ability to register diagnostic
handlers and to feed them with errors and warnings produced by the compiler.
Enhance Operation to be able to get its own MLIRContext on demand, simplifying
some clients. Change the verifier to emit certain errors with the diagnostic
handler.
This is steps towards reworking the verifier and diagnostic propagation but is
itself not particularly useful. More to come.
PiperOrigin-RevId: 206948643
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index be8c99a..a18a97f 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -46,6 +46,9 @@
Kind getKind() const { return kind; }
+ /// Return the context this operation is associated with.
+ MLIRContext *getContext() const;
+
/// Return the BasicBlock containing this instruction.
BasicBlock *getBlock() const { return block; }
@@ -152,6 +155,9 @@
ArrayRef<NamedAttribute> attributes,
MLIRContext *context);
+ /// Return the context this operation is associated with.
+ MLIRContext *getContext() const { return Instruction::getContext(); }
+
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
diff --git a/include/mlir/IR/MLIRContext.h b/include/mlir/IR/MLIRContext.h
index 2bd019a..f17b25b 100644
--- a/include/mlir/IR/MLIRContext.h
+++ b/include/mlir/IR/MLIRContext.h
@@ -18,10 +18,13 @@
#ifndef MLIR_IR_MLIRCONTEXT_H
#define MLIR_IR_MLIRCONTEXT_H
+#include "mlir/Support/LLVM.h"
+#include <functional>
#include <memory>
namespace mlir {
- class MLIRContextImpl;
+class MLIRContextImpl;
+class Attribute;
/// MLIRContext is the top-level object for a collection of MLIR modules. It
/// holds immortal uniqued objects like types, and the tables used to unique
@@ -42,6 +45,29 @@
// This is effectively private given that only MLIRContext.cpp can see the
// MLIRContextImpl type.
MLIRContextImpl &getImpl() const { return *impl.get(); }
+
+ // Diagnostic handler registration and use. MLIR supports the ability for the
+ // IR to carry arbitrary metadata about operation location information. If an
+ // error or warning is detected in the compiler, the pass in question can
+ // invoke the emitError/emitWarning method on an operation and have it
+ // reported through this interface.
+ //
+ // Tools using MLIR are encouraged to register error handlers and define a
+ // schema for their location information. If they don't, then warnings will
+ // be dropped and errors will terminate the process with exit(1).
+
+ /// Register a diagnostic handler with this LLVM context. The handler is
+ /// passed location information if present (nullptr if not) along with a
+ /// message and a boolean that indicates whether this is an error or warning.
+ void registerDiagnosticHandler(
+ const std::function<void(Attribute *location, StringRef message,
+ bool isError)> &handler);
+
+ /// This emits an diagnostic using the registered issue handle if present, or
+ /// with the default behavior if not. The MLIR compiler should not generally
+ /// interact with this, it should use methods on Operation instead.
+ void emitDiagnostic(Attribute *location, const Twine &message,
+ bool isError) const;
};
} // end namespace mlir
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 2222ab7..1e2d614 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -44,6 +44,9 @@
///
class Operation {
public:
+ /// Return the context this operation is associated with.
+ MLIRContext *getContext() const;
+
/// The name of an operation is the key identifier for it.
Identifier getName() const { return nameAndIsInstruction.getPointer(); }
@@ -137,11 +140,21 @@
/// value indicates whether the attribute was present or not.
RemoveResult removeAttr(Identifier name, MLIRContext *context);
+ /// Emit a warning about this operation, reporting up to any diagnostic
+ /// handlers that may be listening.
+ void emitWarning(const Twine &message) const;
+
+ /// Emit an error about fatal conditions with this operation, reporting up to
+ /// any diagnostic handlers that may be listening. NOTE: This may terminate
+ /// the containing application, only use when the IR is in an inconsistent
+ /// state.
+ void emitError(const Twine &message) const;
+
/// If this operation has a registered operation description in the
/// OperationSet, return it. Otherwise return null.
/// TODO: Shouldn't have to pass a Context here, Operation should eventually
/// be able to get to its own parent.
- const AbstractOperation *getAbstractOperation(MLIRContext *context) const;
+ const AbstractOperation *getAbstractOperation() const;
/// The getAs methods perform a dynamic cast from an Operation (like
/// OperationInst and OperationStmt) to a typed Op like DimOp. This returns
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index f43f0b6..65cad98 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -30,6 +30,7 @@
class MLFunction;
class StmtBlock;
class ForStmt;
+class MLIRContext;
/// Statement is a basic unit of execution within an ML function.
/// Statements can be nested within for and if statements effectively
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 39d191a..306702f 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -44,6 +44,9 @@
ArrayRef<NamedAttribute> attributes,
MLIRContext *context);
+ /// Return the context this operation is associated with.
+ MLIRContext *getContext() const;
+
//===--------------------------------------------------------------------===//
// Operands
//===--------------------------------------------------------------------===//
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index c3f0b1e..0076d01 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -16,7 +16,7 @@
// =============================================================================
#include "mlir/IR/Instructions.h"
-#include "mlir/IR/BasicBlock.h"
+#include "mlir/IR/CFGFunction.h"
using namespace mlir;
/// Replace all uses of 'this' value with the new value, updating anything in
@@ -69,6 +69,11 @@
free(this);
}
+/// Return the context this operation is associated with.
+MLIRContext *Instruction::getContext() const {
+ return getFunction()->getContext();
+}
+
CFGFunction *Instruction::getFunction() const {
return getBlock()->getFunction();
}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 4963000..8d8b013 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -28,7 +28,9 @@
#include "third_party/llvm/llvm/include/llvm/ADT/STLExtras.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringMap.h"
+#include "llvm/ADT/Twine.h"
#include "llvm/Support/Allocator.h"
+#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using namespace llvm;
@@ -184,6 +186,10 @@
/// This is the set of all operations that are registered with the system.
OperationSet operationSet;
+ /// This is the handler to use to report issues, or null if not registered.
+ std::function<void(Attribute *location, StringRef message, bool isError)>
+ issueHandler;
+
/// These are identifiers uniqued into this MLIRContext.
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
@@ -266,6 +272,36 @@
MLIRContext::~MLIRContext() {}
+/// Register an issue handler with this LLVM context. The issue handler is
+/// passed location information if present (nullptr if not) along with a
+/// message and a boolean that indicates whether this is an error or warning.
+void MLIRContext::registerDiagnosticHandler(
+ const std::function<void(Attribute *location, StringRef message,
+ bool isError)> &handler) {
+ getImpl().issueHandler = handler;
+}
+
+/// This emits a diagnostic using the registered issue handle if present, or
+/// with the default behavior if not. The MLIR compiler should not generally
+/// interact with this, it should use methods on Operation instead.
+void MLIRContext::emitDiagnostic(Attribute *location,
+ const llvm::Twine &message,
+ bool isError) const {
+ // If we had a handler registered, emit the diagnostic using it.
+ auto handler = getImpl().issueHandler;
+ if (handler)
+ return handler(location, message.str(), isError);
+
+ // The default behavior for warnings is to ignore them.
+ if (!isError)
+ return;
+
+ // The default behavior for errors is to emit them to stderr and exit.
+ llvm::errs() << message.str() << "\n";
+ llvm::errs().flush();
+ exit(1);
+}
+
/// Return the operation set associated with the specified MLIRContext object.
OperationSet &OperationSet::get(MLIRContext *context) {
return context->getImpl().operationSet;
@@ -273,10 +309,8 @@
/// If this operation has a registered operation description in the
/// OperationSet, return it. Otherwise return null.
-/// TODO: Shouldn't have to pass a Context here.
-const AbstractOperation *
-Operation::getAbstractOperation(MLIRContext *context) const {
- return OperationSet::get(context).lookup(getName().str());
+const AbstractOperation *Operation::getAbstractOperation() const {
+ return OperationSet::get(getContext()).lookup(getName().str());
}
//===----------------------------------------------------------------------===//
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index f2f9eb4..f26181e 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -18,6 +18,7 @@
#include "mlir/IR/Operation.h"
#include "AttributeListStorage.h"
#include "mlir/IR/Instructions.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Statements.h"
using namespace mlir;
@@ -34,6 +35,13 @@
Operation::~Operation() {}
+/// Return the context this operation is associated with.
+MLIRContext *Operation::getContext() const {
+ if (auto *inst = dyn_cast<OperationInst>(this))
+ return inst->getContext();
+ return cast<OperationStmt>(this)->getContext();
+}
+
/// Return the number of operands this operation has.
unsigned Operation::getNumOperands() const {
if (auto *inst = dyn_cast<OperationInst>(this)) {
@@ -126,3 +134,35 @@
}
return RemoveResult::NotFound;
}
+
+/// Emit a warning about this operation, reporting up to any diagnostic
+/// handlers that may be listening.
+void Operation::emitWarning(const Twine &message) const {
+ // Get the location information for this operation.
+ auto *loc = getAttr("location");
+
+ // If that fails, fall back to the internal location which is used in
+ // testcases.
+ if (!loc)
+ loc = getAttr(":location");
+
+ auto *context = getContext();
+ context->emitDiagnostic(loc, message, /*isError=*/false);
+}
+
+/// Emit an error about fatal conditions with this operation, reporting up to
+/// any diagnostic handlers that may be listening. NOTE: This may terminate
+/// the containing application, only use when the IR is in an inconsistent
+/// state.
+void Operation::emitError(const Twine &message) const {
+ // Get the location information for this operation.
+ auto *loc = getAttr("location");
+
+ // If that fails, fall back to the internal location which is used in
+ // testcases.
+ if (!loc)
+ loc = getAttr(":location");
+
+ auto *context = getContext();
+ context->emitDiagnostic(loc, message, /*isError=*/true);
+}
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index b55b597..3bbcdd4 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -177,6 +177,20 @@
free(this);
}
+/// Return the context this operation is associated with.
+MLIRContext *OperationStmt::getContext() const {
+ // If we have a result or operand type, that is a constant time way to get
+ // to the context.
+ if (getNumResults())
+ return getResult(0)->getType()->getContext();
+ if (getNumOperands())
+ return getOperand(0)->getType()->getContext();
+
+ // In the very odd case where we have no operands or results, fall back to
+ // doing a find.
+ return findFunction()->getContext();
+}
+
/// This drops all operand uses from this statement, which is an essential
/// step in breaking cyclic dependences between references when they are to
/// be deleted.
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 9e52556..6eb7a68 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -73,6 +73,11 @@
return true;
}
+ bool opFailure(const Twine &message, const Operation &value) {
+ value.emitError(message);
+ return true;
+ }
+
protected:
explicit Verifier(std::string *errorResult) : errorResult(errorResult) {}
@@ -272,15 +277,15 @@
bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
if (inst.getFunction() != &fn)
- return failure("operation in the wrong function", inst);
+ return opFailure("operation in the wrong function", inst);
// TODO: Check that operands are structurally ok.
// See if we can get operation info for this.
- if (auto *opInfo = inst.getAbstractOperation(fn.getContext())) {
+ if (auto *opInfo = inst.getAbstractOperation()) {
if (auto errorMessage = opInfo->verifyInvariants(&inst))
- return failure(Twine("'") + inst.getName().str() + "' op " + errorMessage,
- inst);
+ return opFailure(
+ Twine("'") + inst.getName().str() + "' op " + errorMessage, inst);
}
return false;
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 0f4a3a5..a468c9e 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1548,7 +1548,7 @@
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.
- if (auto *opInfo = op->getAbstractOperation(builder.getContext())) {
+ if (auto *opInfo = op->getAbstractOperation()) {
if (auto error = opInfo->verifyInvariants(op))
return emitError(loc, Twine("'") + op->getName().str() + "' op " + error);
}