Add default error reporter for parser.

Add a default error reporter for the parser that uses the SourceManager to print the error. Also and OptResult enum (mirroring ParseResult) to make the behavior self-documenting.

PiperOrigin-RevId: 203173647
diff --git a/include/mlir/Parser.h b/include/mlir/Parser.h
index 276c081..951b3a4 100644
--- a/include/mlir/Parser.h
+++ b/include/mlir/Parser.h
@@ -33,13 +33,18 @@
 class Module;
 class MLIRContext;
 
-using SMDiagnosticHandlerTy = std::function<void(llvm::SMDiagnostic error)>;
+using SMDiagnosticHandlerTy =
+    std::function<void(const llvm::SMDiagnostic &error)>;
+
+/// Default error reproter that prints out the error using the SourceMgr of the
+/// error.
+void defaultErrorReporter(const llvm::SMDiagnostic &error);
 
 /// This parses the file specified by the indicated SourceMgr and returns an
 /// MLIR module if it was valid.  If not, the errorReporter is used to report
 /// the error diagnostics and this function returns null.
 Module *parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
-                        const SMDiagnosticHandlerTy &errorReporter);
+                        SMDiagnosticHandlerTy errorReporter = nullptr);
 
 } // end namespace mlir
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 91f80e2..08157e5 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -48,11 +48,9 @@
 class Parser {
 public:
   Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
-         const SMDiagnosticHandlerTy &errorReporter)
-      : context(context),
-        lex(sourceMgr, errorReporter),
-        curToken(lex.lexToken()),
-        errorReporter(errorReporter) {
+         SMDiagnosticHandlerTy errorReporter)
+      : context(context), lex(sourceMgr, errorReporter),
+        curToken(lex.lexToken()), errorReporter(std::move(errorReporter)) {
     module.reset(new Module());
   }
 
@@ -68,7 +66,7 @@
   Token curToken;
 
   // The diagnostic error reporter.
-  const SMDiagnosticHandlerTy &errorReporter;
+  SMDiagnosticHandlerTy errorReporter;
 
   // This is the result module we are parsing into.
   std::unique_ptr<Module> module;
@@ -1049,10 +1047,16 @@
 
 //===----------------------------------------------------------------------===//
 
+void mlir::defaultErrorReporter(const llvm::SMDiagnostic &error) {
+  const auto &sourceMgr = *error.getSourceMgr();
+  sourceMgr.PrintMessage(error.getLoc(), error.getKind(), error.getMessage());
+}
+
 /// This parses the file specified by the indicated SourceMgr and returns an
 /// MLIR module if it was valid.  If not, it emits diagnostics and returns null.
 Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
-                              const SMDiagnosticHandlerTy &errorReporter) {
-  return Parser(sourceMgr, context, errorReporter).parseModule();
+                              SMDiagnosticHandlerTy errorReporter) {
+  return Parser(sourceMgr, context,
+                errorReporter ? std::move(errorReporter) : defaultErrorReporter)
+      .parseModule();
 }
-
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 4238bb8..67a090c 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -44,6 +44,8 @@
 checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
                   cl::init(false));
 
+enum OptResult { OptSuccess, OptFailure };
+
 /// Open the specified output file and return it, exiting if there is any I/O or
 /// other errors.
 static std::unique_ptr<ToolOutputFile> getOutputStream() {
@@ -59,30 +61,29 @@
 }
 
 /// Parses the memory buffer and, if successfully parsed, prints the parsed
-/// output. Returns whether parsing succeeded.
-bool parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer,
-                               const SMDiagnosticHandlerTy& errorReporter) {
+/// output.
+OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
 
   // Parse the input file.
   MLIRContext context;
-  std::unique_ptr<Module> module(
-      parseSourceFile(sourceMgr, &context, errorReporter));
-  if (!module) return false;
+  std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
+  if (!module)
+    return OptFailure;
 
   // Print the output.
   auto output = getOutputStream();
   module->print(output->os());
   output->keep();
 
-  // Success.
-  return true;
+  return OptSuccess;
 }
 
 /// Split the memory buffer into multiple buffers using the marker -----.
-bool splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
+OptResult
+splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
   const char marker[] = "-----";
   SmallVector<StringRef, 2> sourceBuffers;
   buffer->getBuffer().split(sourceBuffers, marker);
@@ -93,7 +94,7 @@
   // of diagnostics.
   // TODO: Enable specifying errors on different lines (@-1).
   // TODO: Currently only checking if substring matches, enable regex checking.
-  bool failed = false;
+  OptResult opt_result = OptSuccess;
   SourceMgr fileSourceMgr;
   fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
 
@@ -104,7 +105,7 @@
 
   // Create error checker that uses the helper function to relate the reported
   // error to the file being parsed.
-  SMDiagnosticHandlerTy checker = [&](SMDiagnostic err) {
+  SMDiagnosticHandlerTy checker = [&](const SMDiagnostic &err) {
     const auto &sourceMgr = *err.getSourceMgr();
     const char *bufferStart =
         sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())->getBufferStart();
@@ -122,7 +123,7 @@
       fileSourceMgr.PrintMessage(
           loc, SourceMgr::DK_Error,
           "unexpected error: " + err.getMessage());
-      failed = true;
+      opt_result = OptFailure;
       return;
     }
 
@@ -131,11 +132,11 @@
       const char checkPrefix[] = "expected-error {{";
       loc = SMLoc::getFromPointer(fileOffset + offset + line.find(checkPrefix) -
                                   err.getColumnNo() + strlen(checkPrefix));
-      fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
-                                 "\"" + err.getMessage() +
-                                     "\" did not contain expected substring \"" +
-                                     matches[1] + "\"");
-      failed = true;
+      fileSourceMgr.PrintMessage(
+          loc, SourceMgr::DK_Error,
+          "\"" + err.getMessage() + "\" did not contain expected substring \"" +
+              matches[1] + "\"");
+      opt_result = OptFailure;
       return;
     }
   };
@@ -155,7 +156,7 @@
                                  "too many errors expected: unable to verify "
                                  "more than one error per group");
       fileOffset += subbuffer.size() + strlen(marker);
-      failed = true;
+      opt_result = OptFailure;
       continue;
     }
 
@@ -177,13 +178,13 @@
       fileSourceMgr.PrintMessage(
           loc, SourceMgr::DK_Error,
           "expected error \"" + matches[1] + "\" was not produced", range);
-      failed = true;
+      opt_result = OptFailure;
     }
 
     fileOffset += subbuffer.size() + strlen(marker);
   }
 
-  return !failed;
+  return opt_result;
 }
 
 int main(int argc, char **argv) {
@@ -200,12 +201,7 @@
   }
 
   if (checkParserErrors)
-    return !splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
+    return splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
 
-  // Error reporter that simply prints the errors reported.
-  SMDiagnosticHandlerTy errorReporter = [](llvm::SMDiagnostic err) {
-    const auto& sourceMgr = *err.getSourceMgr();
-    sourceMgr.PrintMessage(err.getLoc(), err.getKind(), err.getMessage());
-  };
-  return !parseAndPrintMemoryBuffer(std::move(*fileOrErr), errorReporter);
+  return parseAndPrintMemoryBuffer(std::move(*fileOrErr));
 }