Continue wiring up diagnostic reporting infrastructure, still WIP.
- Implement a diagnostic hook in one of the paths in mlir-opt which
captures and reports the diagnostics nicely.
- Have the parser capture simple location information indicating where
each op came from in the source .mlir file.
- Add a verifyDominance() method to MLFuncVerifier to demo this, resolving b/112086163
- Add some PrettyStackTrace handlers to make crashes in the testsuite easier
to track down.
PiperOrigin-RevId: 207488548
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 3edb115..b5f8980 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -186,9 +186,9 @@
/// This is the set of all operations that are registered with the system.
OperationSet operationSet;
- /// This is the handler to use to report issues, or null if not registered.
- std::function<void(Attribute *location, StringRef message, bool isError)>
- issueHandler;
+ /// This is the handler to use to report diagnostics, or null if not
+ /// registered.
+ MLIRContext::DiagnosticHandlerTy diagnosticHandler;
/// These are identifiers uniqued into this MLIRContext.
llvm::StringMap<char, llvm::BumpPtrAllocator &> identifiers;
@@ -278,9 +278,8 @@
/// passed location information if present (nullptr if not) along with a
/// message and a boolean that indicates whether this is an error or warning.
void MLIRContext::registerDiagnosticHandler(
- const std::function<void(Attribute *location, StringRef message,
- bool isError)> &handler) {
- getImpl().issueHandler = handler;
+ const DiagnosticHandlerTy &handler) {
+ getImpl().diagnosticHandler = handler;
}
/// This emits a diagnostic using the registered issue handle if present, or
@@ -288,14 +287,14 @@
/// interact with this, it should use methods on Operation instead.
void MLIRContext::emitDiagnostic(Attribute *location,
const llvm::Twine &message,
- bool isError) const {
+ DiagnosticKind kind) const {
// If we had a handler registered, emit the diagnostic using it.
- auto handler = getImpl().issueHandler;
- if (handler)
- return handler(location, message.str(), isError);
+ auto handler = getImpl().diagnosticHandler;
+ if (handler && location)
+ return handler(location, message.str(), kind);
- // The default behavior for warnings is to ignore them.
- if (!isError)
+ // The default behavior for notes and warnings is to ignore them.
+ if (kind != DiagnosticKind::Error)
return;
// The default behavior for errors is to emit them to stderr and exit.
diff --git a/lib/IR/Operation.cpp b/lib/IR/Operation.cpp
index f26181e..af937bc 100644
--- a/lib/IR/Operation.cpp
+++ b/lib/IR/Operation.cpp
@@ -97,12 +97,12 @@
/// If an attribute exists with the specified name, change it to the new
/// value. Otherwise, add a new attribute with the specified name/value.
-void Operation::setAttr(Identifier name, Attribute *value,
- MLIRContext *context) {
+void Operation::setAttr(Identifier name, Attribute *value) {
assert(value && "attributes may never be null");
auto origAttrs = getAttrs();
SmallVector<NamedAttribute, 8> newAttrs(origAttrs.begin(), origAttrs.end());
+ auto *context = getContext();
// If we already have this attribute, replace it.
for (auto &elt : newAttrs)
@@ -119,8 +119,7 @@
/// Remove the attribute with the specified name if it exists. The return
/// value indicates whether the attribute was present or not.
-auto Operation::removeAttr(Identifier name, MLIRContext *context)
- -> RemoveResult {
+auto Operation::removeAttr(Identifier name) -> RemoveResult {
auto origAttrs = getAttrs();
for (unsigned i = 0, e = origAttrs.size(); i != e; ++i) {
if (origAttrs[i].first == name) {
@@ -128,26 +127,25 @@
newAttrs.reserve(origAttrs.size() - 1);
newAttrs.append(origAttrs.begin(), origAttrs.begin() + i);
newAttrs.append(origAttrs.begin() + i + 1, origAttrs.end());
- attrs = AttributeListStorage::get(newAttrs, context);
+ attrs = AttributeListStorage::get(newAttrs, getContext());
return RemoveResult::Removed;
}
}
return RemoveResult::NotFound;
}
+/// Emit a note about this operation (located via its ":location" attribute),
+/// reporting up to any diagnostic handlers that may be listening.
+void Operation::emitNote(const Twine &message) const {
+ getContext()->emitDiagnostic(getAttr(":location"), message,
+ MLIRContext::DiagnosticKind::Note);
+}
+
/// Emit a warning about this operation, reporting up to any diagnostic
/// handlers that may be listening.
void Operation::emitWarning(const Twine &message) const {
- // Get the location information for this operation.
- auto *loc = getAttr("location");
-
- // If that fails, fall back to the internal location which is used in
- // testcases.
- if (!loc)
- loc = getAttr(":location");
-
- auto *context = getContext();
- context->emitDiagnostic(loc, message, /*isError=*/false);
+ getContext()->emitDiagnostic(getAttr(":location"), message,
+ MLIRContext::DiagnosticKind::Warning);
}
/// Emit an error about fatal conditions with this operation, reporting up to
@@ -155,14 +153,6 @@
/// the containing application, only use when the IR is in an inconsistent
/// state.
void Operation::emitError(const Twine &message) const {
- // Get the location information for this operation.
- auto *loc = getAttr("location");
-
- // If that fails, fall back to the internal location which is used in
- // testcases.
- if (!loc)
- loc = getAttr(":location");
-
- auto *context = getContext();
- context->emitDiagnostic(loc, message, /*isError=*/true);
+ getContext()->emitDiagnostic(getAttr(":location"), message,
+ MLIRContext::DiagnosticKind::Error);
}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 6eb7a68..45df86a 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -37,7 +37,10 @@
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
+#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/Twine.h"
+#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -121,6 +124,9 @@
} // end anonymous namespace
bool CFGFuncVerifier::verify() {
+ llvm::PrettyStackTraceFormat fmt("MLIR Verifier: cfgfunc @%s",
+ fn.getName().c_str());
+
// TODO: Lots to be done here, including verifying dominance information when
// we have uses and defs.
// TODO: Verify the first block has no predecessors.
@@ -304,12 +310,85 @@
: Verifier(errorResult), fn(fn) {}
bool verify() {
- // TODO.
- return false;
+ llvm::PrettyStackTraceFormat fmt("MLIR Verifier: mlfunc @%s",
+ fn.getName().c_str());
+
+ // TODO: check basic structural properties.
+
+ return verifyDominance();
}
+
+ /// Walk all of the code in this MLFunc and verify that the operands of any
+ /// operations are properly dominated by their definitions.
+ bool verifyDominance();
};
} // end anonymous namespace
+/// Walk all of the code in this MLFunc and verify that every operand of an
+/// operation is dominated by its definition. Returns true on failure.
+bool MLFuncVerifier::verifyDominance() {
+ using HashTable = llvm::ScopedHashTable<const SSAValue *, bool>;
+ HashTable liveValues;
+ HashTable::ScopeTy topScope(liveValues);
+
+ // All of the arguments to the function are live for the whole function.
+ // TODO: Add arguments when they are supported.
+
+ // Recursively walk a statement list, opening one scope per block so that
+ // values it defines leave the table on exit. Returns true on a violation.
+ std::function<bool(const StmtBlock &block)> walkBlock;
+ walkBlock = [&](const StmtBlock &block) -> bool {
+ HashTable::ScopeTy blockScope(liveValues);
+
+ // The induction variable of a for statement is live within its body.
+ if (auto *forStmt = dyn_cast<ForStmt>(&block))
+ liveValues.insert(forStmt, true);
+
+ for (auto &stmt : block) {
+ // TODO: For and If will eventually have operands, we need to check them.
+ // When this happens, Statement should have a general getOperands() method
+ // we can use here first.
+ if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+ // Verify that each operand is live, i.e. already in the table.
+ unsigned operandNo = 0;
+ for (auto *opValue : opStmt->getOperands()) {
+ if (!liveValues.count(opValue)) {
+ opStmt->emitError("operand #" + Twine(operandNo) +
+ " does not dominate this use");
+ if (auto *useStmt = opValue->getDefiningStmt())
+ useStmt->emitNote("operand defined here");
+ return true;
+ }
+ ++operandNo;
+ }
+
+ // Operations define values; add their results to the hash table.
+ for (auto *result : opStmt->getResults())
+ liveValues.insert(result, true);
+ continue;
+ }
+
+ // If this is an if or for, recursively walk the block(s) it contains.
+ if (auto *ifStmt = dyn_cast<IfStmt>(&stmt)) {
+ if (walkBlock(*ifStmt->getThenClause()))
+ return true;
+
+ if (auto *elseClause = ifStmt->getElseClause())
+ if (walkBlock(*elseClause))
+ return true;
+ }
+ if (auto *forStmt = dyn_cast<ForStmt>(&stmt))
+ if (walkBlock(*forStmt))
+ return true;
+ }
+
+ return false;
+ };
+
+ // Walk the whole function, starting from its top-level block.
+ return walkBlock(fn);
+}
+
//===----------------------------------------------------------------------===//
// Entrypoints
//===----------------------------------------------------------------------===//
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index c77f0f4..ff3811f 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -105,6 +105,7 @@
MLIRContext *getContext() const { return state.context; }
Module *getModule() { return state.module; }
OperationSet &getOperationSet() const { return state.operationSet; }
+ llvm::SourceMgr &getSourceMgr() { return state.lex.getSourceMgr(); }
/// Return the current token the parser is inspecting.
const Token &getToken() const { return state.curToken; }
@@ -1553,6 +1554,19 @@
if (!op)
return ParseFailure;
+ // Apply location information to the instruction.
+ // TODO(clattner): make this more principled. We shouldn't overwrite existing
+ // location info, we should use a better serialized form, and we shouldn't
+ // be using the :location attribute. This is also pretty inefficient.
+ {
+ auto &sourceMgr = getSourceMgr();
+ auto fileID = sourceMgr.FindBufferContainingLoc(loc);
+ auto *srcBuffer = sourceMgr.getMemoryBuffer(fileID);
+ unsigned locationEncoding = loc.getPointer() - srcBuffer->getBufferStart();
+ op->setAttr(builder.getIdentifier(":location"),
+ builder.getIntegerAttr(locationEncoding));
+ }
+
// We just parsed an operation. If it is a recognized one, verify that it
// is structurally as we expect. If not, produce an error with a reasonable
// source location.