Finish support for function attributes, and improve lots of things:
- Have the parser rewrite forward references to their resolved values at the
end of parsing.
- Implement verifier support for detecting malformed function attrs.
- Add an efficient query that tells whether an attribute (which may be
recursive, e.g. an array attribute) contains a function.
As part of this, improve other general infrastructure:
- Implement support for verifying OperationStmts in ML functions, refactoring
and generalizing support for operations in the verifier.
- Refactor location handling code in mlir-opt so that invocations which do not
expect errors report error locations precisely.
- Fix parser to detect verifier failures and report them through errorReporter
instead of printing the error and crashing.
This regresses the location info for verifier errors reported by the parser,
which were previously ascribed to the function. This will be resolved in future
patches by adding support for attributes on functions (as opposed to the
function-referencing attributes added here), which we can use to manage location
information.
PiperOrigin-RevId: 209600980
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 272395c..b74e113 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -33,11 +33,13 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/CFGFunction.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OperationSet.h"
#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
#include "llvm/ADT/ScopedHashTable.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/PrettyStackTrace.h"
@@ -81,31 +83,93 @@
return true;
}
+ bool verifyOperation(const Operation &op);
+ bool verifyAttribute(Attribute *attr, const Operation &op);
+
protected:
- explicit Verifier(std::string *errorResult) : errorResult(errorResult) {}
+ explicit Verifier(std::string *errorResult, const Function &fn)
+ : errorResult(errorResult), fn(fn),
+ operationSet(OperationSet::get(fn.getContext())) {}
private:
+ /// If the verifier is returning errors back to a client, this is the error to
+ /// fill in.
std::string *errorResult;
+
+ /// The function being checked.
+ const Function &fn;
+
+ /// The operation set installed in the current MLIR context.
+ OperationSet &operationSet;
};
} // end anonymous namespace
+/// Check function refs in 'attr', recursing into arrays; true on error.
+bool Verifier::verifyAttribute(Attribute *attr, const Operation &op) {
+ if (!attr->isOrContainsFunction())
+ return false;
+
+ // If we have a function attribute, check that the referenced function still
+ // exists and is in the same module as the function being verified.
+ if (auto *fnAttr = dyn_cast<FunctionAttr>(attr)) {
+ if (!fnAttr->getValue())
+ return opFailure("attribute refers to deallocated function!", op);
+
+ if (fnAttr->getValue()->getModule() != fn.getModule())
+ return opFailure("attribute refers to function '" +
+ Twine(fnAttr->getValue()->getName()) +
+ "' defined in another module!",
+ op);
+ return false;
+ }
+
+ // Otherwise, this must be an array attribute; verify each element.
+ for (auto *elt : cast<ArrayAttr>(attr)->getValue()) {
+ if (verifyAttribute(elt, op))
+ return true;
+ }
+
+ return false;
+}
+
+/// Check the invariants of 'op' (instruction or statement); true on error.
+bool Verifier::verifyOperation(const Operation &op) {
+ if (op.getOperationFunction() != &fn)
+ return opFailure("operation in the wrong function", op);
+
+ // TODO: Check that operands are non-nil and structurally ok.
+
+ // Verify all attributes are ok. We need to check Function attributes, since
+ // they are actually mutable (the function they refer to can be deleted), and
+ // we have to check array attributes that can refer to them.
+ for (auto attr : op.getAttrs()) {
+ if (verifyAttribute(attr.second, op))
+ return true;
+ }
+
+ // If abstract operation info is registered, run its verifyInvariants hook.
+ if (auto *opInfo = op.getAbstractOperation()) {
+ if (auto *errorMessage = opInfo->verifyInvariants(&op))
+ return opFailure(Twine("'") + op.getName().str() + "' op " + errorMessage,
+ op);
+ }
+
+ return false;
+}
+
//===----------------------------------------------------------------------===//
// CFG Functions
//===----------------------------------------------------------------------===//
namespace {
-class CFGFuncVerifier : public Verifier {
-public:
+struct CFGFuncVerifier : public Verifier {
const CFGFunction &fn;
- OperationSet &operationSet;
CFGFuncVerifier(const CFGFunction &fn, std::string *errorResult)
- : Verifier(errorResult), fn(fn),
- operationSet(OperationSet::get(fn.getContext())) {}
+ : Verifier(errorResult, fn), fn(fn) {}
bool verify();
bool verifyBlock(const BasicBlock &block);
- bool verifyOperation(const OperationInst &inst);
bool verifyTerminator(const TerminatorInst &term);
bool verifyReturn(const ReturnInst &inst);
bool verifyBranch(const BranchInst &inst);
@@ -281,39 +345,31 @@
return false;
}
-bool CFGFuncVerifier::verifyOperation(const OperationInst &inst) {
- if (inst.getFunction() != &fn)
- return opFailure("operation in the wrong function", inst);
-
- // TODO: Check that operands are structurally ok.
-
- // See if we can get operation info for this.
- if (auto *opInfo = inst.getAbstractOperation()) {
- if (auto errorMessage = opInfo->verifyInvariants(&inst))
- return opFailure(
- Twine("'") + inst.getName().str() + "' op " + errorMessage, inst);
- }
-
- return false;
-}
-
//===----------------------------------------------------------------------===//
// ML Functions
//===----------------------------------------------------------------------===//
namespace {
-class MLFuncVerifier : public Verifier {
-public:
+struct MLFuncVerifier : public Verifier, public StmtWalker<MLFuncVerifier> {
const MLFunction &fn;
+ bool hadError = false;
MLFuncVerifier(const MLFunction &fn, std::string *errorResult)
- : Verifier(errorResult), fn(fn) {}
+ : Verifier(errorResult, fn), fn(fn) {}
+
+ void visitOperationStmt(OperationStmt *opStmt) {
+ hadError |= verifyOperation(*opStmt);
+ }
bool verify() {
llvm::PrettyStackTraceFormat fmt("MLIR Verifier: mlfunc @%s",
fn.getName().c_str());
- // TODO: check basic structural properties
+ // Check basic structural properties.
+ walk(const_cast<MLFunction *>(&fn));
+ if (hadError)
+ return true;
+
// TODO: check that operation is not a return statement unless it's
// the last one in the function.
if (verifyReturn())