Add default error reporter for parser.
Add a default error reporter for the parser that uses the SourceManager to print the error. Also and OptResult enum (mirroring ParseResult) to make the behavior self-documenting.
PiperOrigin-RevId: 203173647
diff --git a/include/mlir/Parser.h b/include/mlir/Parser.h
index 276c081..951b3a4 100644
--- a/include/mlir/Parser.h
+++ b/include/mlir/Parser.h
@@ -33,13 +33,18 @@
class Module;
class MLIRContext;
-using SMDiagnosticHandlerTy = std::function<void(llvm::SMDiagnostic error)>;
+using SMDiagnosticHandlerTy =
+ std::function<void(const llvm::SMDiagnostic &error)>;
+
+/// Default error reproter that prints out the error using the SourceMgr of the
+/// error.
+void defaultErrorReporter(const llvm::SMDiagnostic &error);
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, the errorReporter is used to report
/// the error diagnostics and this function returns null.
Module *parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
- const SMDiagnosticHandlerTy &errorReporter);
+ SMDiagnosticHandlerTy errorReporter = nullptr);
} // end namespace mlir
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 91f80e2..08157e5 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -48,11 +48,9 @@
class Parser {
public:
Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
- const SMDiagnosticHandlerTy &errorReporter)
- : context(context),
- lex(sourceMgr, errorReporter),
- curToken(lex.lexToken()),
- errorReporter(errorReporter) {
+ SMDiagnosticHandlerTy errorReporter)
+ : context(context), lex(sourceMgr, errorReporter),
+ curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
module.reset(new Module());
}
@@ -68,7 +66,7 @@
Token curToken;
// The diagnostic error reporter.
- const SMDiagnosticHandlerTy &errorReporter;
+ SMDiagnosticHandlerTy errorReporter;
// This is the result module we are parsing into.
std::unique_ptr<Module> module;
@@ -1049,10 +1047,16 @@
//===----------------------------------------------------------------------===//
+void mlir::defaultErrorReporter(const llvm::SMDiagnostic &error) {
+ const auto &sourceMgr = *error.getSourceMgr();
+ sourceMgr.PrintMessage(error.getLoc(), error.getKind(), error.getMessage());
+}
+
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
- const SMDiagnosticHandlerTy &errorReporter) {
- return Parser(sourceMgr, context, errorReporter).parseModule();
+ SMDiagnosticHandlerTy errorReporter) {
+ return Parser(sourceMgr, context,
+ errorReporter ? std::move(errorReporter) : defaultErrorReporter)
+ .parseModule();
}
-
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 4238bb8..67a090c 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -44,6 +44,8 @@
checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
cl::init(false));
+enum OptResult { OptSuccess, OptFailure };
+
/// Open the specified output file and return it, exiting if there is any I/O or
/// other errors.
static std::unique_ptr<ToolOutputFile> getOutputStream() {
@@ -59,30 +61,29 @@
}
/// Parses the memory buffer and, if successfully parsed, prints the parsed
-/// output. Returns whether parsing succeeded.
-bool parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer,
- const SMDiagnosticHandlerTy& errorReporter) {
+/// output.
+OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
// Parse the input file.
MLIRContext context;
- std::unique_ptr<Module> module(
- parseSourceFile(sourceMgr, &context, errorReporter));
- if (!module) return false;
+ std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
+ if (!module)
+ return OptFailure;
// Print the output.
auto output = getOutputStream();
module->print(output->os());
output->keep();
- // Success.
- return true;
+ return OptSuccess;
}
/// Split the memory buffer into multiple buffers using the marker -----.
-bool splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
+OptResult
+splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
const char marker[] = "-----";
SmallVector<StringRef, 2> sourceBuffers;
buffer->getBuffer().split(sourceBuffers, marker);
@@ -93,7 +94,7 @@
// of diagnostics.
// TODO: Enable specifying errors on different lines (@-1).
// TODO: Currently only checking if substring matches, enable regex checking.
- bool failed = false;
+ OptResult opt_result = OptSuccess;
SourceMgr fileSourceMgr;
fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
@@ -104,7 +105,7 @@
// Create error checker that uses the helper function to relate the reported
// error to the file being parsed.
- SMDiagnosticHandlerTy checker = [&](SMDiagnostic err) {
+ SMDiagnosticHandlerTy checker = [&](const SMDiagnostic &err) {
const auto &sourceMgr = *err.getSourceMgr();
const char *bufferStart =
sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())->getBufferStart();
@@ -122,7 +123,7 @@
fileSourceMgr.PrintMessage(
loc, SourceMgr::DK_Error,
"unexpected error: " + err.getMessage());
- failed = true;
+ opt_result = OptFailure;
return;
}
@@ -131,11 +132,11 @@
const char checkPrefix[] = "expected-error {{";
loc = SMLoc::getFromPointer(fileOffset + offset + line.find(checkPrefix) -
err.getColumnNo() + strlen(checkPrefix));
- fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
- "\"" + err.getMessage() +
- "\" did not contain expected substring \"" +
- matches[1] + "\"");
- failed = true;
+ fileSourceMgr.PrintMessage(
+ loc, SourceMgr::DK_Error,
+ "\"" + err.getMessage() + "\" did not contain expected substring \"" +
+ matches[1] + "\"");
+ opt_result = OptFailure;
return;
}
};
@@ -155,7 +156,7 @@
"too many errors expected: unable to verify "
"more than one error per group");
fileOffset += subbuffer.size() + strlen(marker);
- failed = true;
+ opt_result = OptFailure;
continue;
}
@@ -177,13 +178,13 @@
fileSourceMgr.PrintMessage(
loc, SourceMgr::DK_Error,
"expected error \"" + matches[1] + "\" was not produced", range);
- failed = true;
+ opt_result = OptFailure;
}
fileOffset += subbuffer.size() + strlen(marker);
}
- return !failed;
+ return opt_result;
}
int main(int argc, char **argv) {
@@ -200,12 +201,7 @@
}
if (checkParserErrors)
- return !splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
+ return splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
- // Error reporter that simply prints the errors reported.
- SMDiagnosticHandlerTy errorReporter = [](llvm::SMDiagnostic err) {
- const auto& sourceMgr = *err.getSourceMgr();
- sourceMgr.PrintMessage(err.getLoc(), err.getKind(), err.getMessage());
- };
- return !parseAndPrintMemoryBuffer(std::move(*fileOrErr), errorReporter);
+ return parseAndPrintMemoryBuffer(std::move(*fileOrErr));
}