Finish support for function attributes, and improve lots of things:
- Have the parser rewrite forward references to their resolved values at the
end of parsing.
- Implement verifier support for detecting malformed function attrs.
- Add an efficient query that tells whether an attribute (which, in general,
may be recursive) is or contains a function.
As part of this, improve other general infrastructure:
- Implement support for verifying OperationStmts in ML functions, refactoring
and generalizing support for operations in the verifier.
- Refactor the location handling code in mlir-opt so that mlir-opt invocations
that do not expect errors report error locations precisely.
- Fix the parser to detect verifier failures and report them through the
errorReporter instead of printing the error and crashing.
This regresses the location info for verifier errors detected by the parser:
they were previously reported at the function's location and are now reported
at the start of the file. This will be resolved in future patches by adding
support for attributes on functions, which we can use to manage location
information.
PiperOrigin-RevId: 209600980
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 5310daa..174d6ca 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -26,10 +26,11 @@
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSet.h"
-#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/SourceMgr.h"
@@ -997,16 +998,10 @@
 ///
 /// affine-expr ::= integer-literal
 AffineExpr *AffineParser::parseIntegerExpr() {
-  // No need to handle negative numbers separately here. They are naturally
-  // handled via the unary negation operator, although (FIXME) MININT_64 still
-  // not correctly handled.
-  if (getToken().isNot(Token::integer))
-    return (emitError("expected integer"), nullptr);
-
   auto val = getToken().getUInt64IntegerValue();
-  if (!val.hasValue() || (int64_t)val.getValue() < 0) {
+  if (!val.hasValue() || (int64_t)val.getValue() < 0)
     return (emitError("constant too large for affineint"), nullptr);
-  }
+  // FIXME: 2^63 is rejected above, so -2^63 (INT64_MIN) can't be parsed.
   consumeToken(Token::integer);
   return builder.getConstantExpr((int64_t)val.getValue());
 }
@@ -1454,11 +1449,6 @@
return ParseFailure;
}
- // Run the verifier on this function. If an error is detected, report it.
- std::string errorString;
- if (func->verify(&errorString))
- return emitError(loc, errorString);
-
return ParseSuccess;
}
@@ -2220,9 +2210,6 @@
// Reset insertion point to the current block.
builder.setInsertionPointToEnd(forStmt->getBlock());
-
- // TODO: remove definition of the induction variable.
-
return ParseSuccess;
}
@@ -2700,11 +2687,65 @@
return parser.parseFunctionBody();
}
+/// Return `input` with every FunctionAttr that has an entry in
+/// `remappingTable` replaced by its mapped value, recursing into ArrayAttr
+/// elements. If nothing in it is in the table, `input` is returned unmodified.
+static Attribute *
+remapFunctionAttrs(Attribute *input,
+                   DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable,
+                   MLIRContext *context) {
+  // Fast path: an attribute that neither is nor contains a function cannot
+  // reference anything in the table.
+  if (!input->isOrContainsFunction())
+    return input;
+
+  // A FunctionAttr maps to its table entry, or to itself if it has none.
+  if (auto *fnAttr = dyn_cast<FunctionAttr>(input)) {
+    auto it = remappingTable.find(fnAttr);
+    return it != remappingTable.end() ? it->second : input;
+  }
+
+  // Otherwise this is assumed to be an ArrayAttr (cast<> asserts if not).
+  auto *arrayAttr = cast<ArrayAttr>(input);
+  SmallVector<Attribute *, 8> remappedElts;
+  bool anyChange = false;
+  for (auto *elt : arrayAttr->getValue()) {
+    auto *newElt = remapFunctionAttrs(elt, remappingTable, context);
+    remappedElts.push_back(newElt);
+    anyChange |= (elt != newElt);
+  }
+
+  if (!anyChange)
+    return input;
+
+  return ArrayAttr::get(remappedElts, context);
+}
+
+/// Rewrite each attribute of `op` that is or contains a FunctionAttr found in
+/// `remappingTable`; used to resolve forward function references.
+static void remapFunctionAttrsInOperation(
+    Operation *op, DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable) {
+  for (auto attr : op->getAttrs()) {
+    // If remapping hands back the same attribute, nothing it references is in
+    // the table (or it holds no function at all), so there is nothing to do.
+    auto *newVal =
+        remapFunctionAttrs(attr.second, remappingTable, op->getContext());
+    if (newVal == attr.second)
+      continue;
+
+    // Otherwise, replace the existing attribute with the new one. Mutating
+    // while iterating relies on getAttrs() referring to uniqued, immortal
+    // attribute-list storage that setAttr swaps out rather than edits.
+    op->setAttr(attr.first, newVal);
+  }
+}
+
/// Finish the end of module parsing - when the result is valid, do final
/// checking.
ParseResult ModuleParser::finalizeModule() {
- // Resolve all forward references.
+ // Resolve all forward references, building a remapping table of attributes.
+ DenseMap<FunctionAttr *, FunctionAttr *> remappingTable;
for (auto forwardRef : getState().functionForwardRefs) {
auto name = forwardRef.first;
@@ -2714,10 +2755,47 @@
return emitError(forwardRef.second.second,
"reference to undefined function '" + name.str() + "'");
- // TODO(clattner): actually go through and update references in the module
- // to the new function.
+ remappingTable[builder.getFunctionAttr(forwardRef.second.first)] =
+ builder.getFunctionAttr(resolvedFunction);
}
+ // If there was nothing to remap, then we're done.
+ if (remappingTable.empty())
+ return ParseSuccess;
+
+ // Otherwise, walk the entire module replacing uses of one attribute set with
+ // the correct ones.
+ for (auto &fn : *getModule()) {
+ if (auto *cfgFn = dyn_cast<CFGFunction>(&fn)) {
+ for (auto &bb : *cfgFn) {
+ for (auto &inst : bb) {
+ remapFunctionAttrsInOperation(&inst, remappingTable);
+ }
+ }
+ }
+
+ // Otherwise, look at MLFunctions. We ignore ExtFunctions.
+ auto *mlFn = dyn_cast<MLFunction>(&fn);
+ if (!mlFn)
+ continue;
+
+ struct MLFnWalker : public StmtWalker<MLFnWalker> {
+ MLFnWalker(DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable)
+ : remappingTable(remappingTable) {}
+ void visitOperationStmt(OperationStmt *opStmt) {
+ remapFunctionAttrsInOperation(opStmt, remappingTable);
+ }
+
+ DenseMap<FunctionAttr *, FunctionAttr *> &remappingTable;
+ };
+
+ MLFnWalker(remappingTable).walk(mlFn);
+ }
+
+ // Now that all references to the forward definition placeholders are
+ // resolved, we can deallocate the placeholders.
+ for (auto forwardRef : getState().functionForwardRefs)
+ forwardRef.second.first->destroy();
return ParseSuccess;
}
@@ -2778,16 +2856,75 @@
 /// MLIR module if it was valid. If not, it emits diagnostics and returns null.
 Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
                               SMDiagnosticHandlerTy errorReporter) {
+  if (!errorReporter)
+    errorReporter = defaultErrorReporter;
+
+  // We are going to replace the context's handler and redirect it to use the
+  // error reporter. Save the existing handler so it can be reinstated before
+  // every return: the new handler captures locals of this function (such as
+  // errorReporter) by reference, so it must not outlive this call.
+  auto existingContextHandler = context->getDiagnosticHandler();
+
+  // Install a new handler that uses the error reporter.
+  context->registerDiagnosticHandler([&](Attribute *location, StringRef message,
+                                         MLIRContext::DiagnosticKind kind) {
+    // Locations are expected to be IntegerAttr byte offsets into the main
+    // buffer. A null/non-integer location or an out-of-range offset is
+    // reported at the start of the buffer instead of asserting in cast<> or
+    // forming a pointer outside the buffer.
+    auto *mainBuffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+    const char *ptr = mainBuffer->getBufferStart();
+    if (auto *locAttr = dyn_cast_or_null<IntegerAttr>(location)) {
+      auto offset = locAttr->getValue();
+      if (offset >= 0 &&
+          static_cast<size_t>(offset) <= mainBuffer->getBufferSize())
+        ptr += offset;
+    }
+
+    // Default to an error so diagKind is never left uninitialized.
+    SourceMgr::DiagKind diagKind = SourceMgr::DK_Error;
+    switch (kind) {
+    case MLIRContext::DiagnosticKind::Error:
+      diagKind = SourceMgr::DK_Error;
+      break;
+    case MLIRContext::DiagnosticKind::Warning:
+      diagKind = SourceMgr::DK_Warning;
+      break;
+    case MLIRContext::DiagnosticKind::Note:
+      diagKind = SourceMgr::DK_Note;
+      break;
+    }
+    errorReporter(
+        sourceMgr.GetMessage(SMLoc::getFromPointer(ptr), diagKind, message));
+  });
+
   // This is the result module we are parsing into.
   std::unique_ptr<Module> module(new Module(context));
 
-  ParserState state(sourceMgr, module.get(),
-                    errorReporter ? errorReporter : defaultErrorReporter);
-  if (ModuleParser(state).parseModule())
+  ParserState state(sourceMgr, module.get(), errorReporter);
+  if (ModuleParser(state).parseModule()) {
+    context->registerDiagnosticHandler(existingContextHandler);
     return nullptr;
+  }
 
   // Make sure the parse module has no other structural problems detected by the
   // verifier.
-  module->verify();
+  std::string errorResult;
+  module->verify(&errorResult);
+
+  // The parser's diagnostic handler is no longer needed; reinstate the
+  // original one before either remaining return below.
+  context->registerDiagnosticHandler(existingContextHandler);
+
+  // We don't have location information for general verifier errors, so emit the
+  // error on the first line.
+  if (!errorResult.empty()) {
+    auto *mainBuffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
+    errorReporter(sourceMgr.GetMessage(
+        SMLoc::getFromPointer(mainBuffer->getBufferStart()),
+        SourceMgr::DK_Error, errorResult));
+    return nullptr;
+  }
+
   return module.release();
 }