Change Lexer and Parser to take diagnostic reporter function.
Add diagnostic reporter function to lexer/parser and use that from mlir-opt to report errors instead of having the lexer/parser print the errors.
PiperOrigin-RevId: 201892004
diff --git a/include/mlir/Parser.h b/include/mlir/Parser.h
index 42f25bb..276c081 100644
--- a/include/mlir/Parser.h
+++ b/include/mlir/Parser.h
@@ -22,17 +22,24 @@
#ifndef MLIR_PARSER_H
#define MLIR_PARSER_H
+#include <functional>
+
namespace llvm {
class SourceMgr;
-}
+ class SMDiagnostic;
+} // end namespace llvm
namespace mlir {
class Module;
class MLIRContext;
+using SMDiagnosticHandlerTy = std::function<void(llvm::SMDiagnostic error)>;
+
/// This parses the file specified by the indicated SourceMgr and returns an
-/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
-Module *parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context);
+/// MLIR module if it was valid. If not, the errorReporter is used to report
+/// the error diagnostics and this function returns null.
+Module *parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
+ const SMDiagnosticHandlerTy &errorReporter);
} // end namespace mlir
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index b192f71..7f53886 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -25,7 +25,9 @@
using llvm::SMLoc;
using llvm::SourceMgr;
-Lexer::Lexer(llvm::SourceMgr &sourceMgr) : sourceMgr(sourceMgr) {
+Lexer::Lexer(llvm::SourceMgr &sourceMgr,
+ const SMDiagnosticHandlerTy &errorReporter)
+ : sourceMgr(sourceMgr), errorReporter(errorReporter) {
auto bufferID = sourceMgr.getMainFileID();
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
@@ -33,10 +35,8 @@
/// emitError - Emit an error message and return an Token::error token.
Token Lexer::emitError(const char *loc, const Twine &message) {
- // TODO(clattner): If/when we want to implement a -verify mode, this will need
- // to package up errors into SMDiagnostic and report them.
- sourceMgr.PrintMessage(SMLoc::getFromPointer(loc), SourceMgr::DK_Error,
- message);
+ errorReporter(sourceMgr.GetMessage(SMLoc::getFromPointer(loc),
+ SourceMgr::DK_Error, message));
return formToken(Token::error, loc);
}
diff --git a/lib/Parser/Lexer.h b/lib/Parser/Lexer.h
index 4f364bc..139dfa6 100644
--- a/lib/Parser/Lexer.h
+++ b/lib/Parser/Lexer.h
@@ -22,17 +22,15 @@
#ifndef MLIR_LIB_PARSER_LEXER_H
#define MLIR_LIB_PARSER_LEXER_H
+#include "mlir/Parser.h"
#include "Token.h"
-namespace llvm {
- class SourceMgr;
-}
-
namespace mlir {
/// This class breaks up the current file into a token stream.
class Lexer {
llvm::SourceMgr &sourceMgr;
+ const SMDiagnosticHandlerTy &errorReporter;
StringRef curBuffer;
const char *curPtr;
@@ -40,7 +38,8 @@
Lexer(const Lexer&) = delete;
void operator=(const Lexer&) = delete;
public:
- explicit Lexer(llvm::SourceMgr &sourceMgr);
+ explicit Lexer(llvm::SourceMgr &sourceMgr,
+ const SMDiagnosticHandlerTy &errorReporter);
llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
@@ -48,9 +47,7 @@
/// Change the position of the lexer cursor. The next token we lex will start
/// at the designated point in the input.
- void resetPointer(const char *newPointer) {
- curPtr = newPointer;
- }
+ void resetPointer(const char *newPointer) { curPtr = newPointer; }
private:
// Helpers.
Token formToken(Token::TokenKind kind, const char *tokStart) {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 828a8d5..5a79f1e 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -42,8 +42,12 @@
/// Main parser implementation.
class Parser {
public:
- Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context)
- : context(context), lex(sourceMgr), curToken(lex.lexToken()){
+ Parser(llvm::SourceMgr &sourceMgr, MLIRContext *context,
+ const SMDiagnosticHandlerTy &errorReporter)
+ : context(context),
+ lex(sourceMgr, errorReporter),
+ curToken(lex.lexToken()),
+ errorReporter(errorReporter) {
module.reset(new Module());
}
@@ -58,6 +62,9 @@
// This is the next token that hasn't been consumed yet.
Token curToken;
+ // The diagnostic error reporter.
+ const SMDiagnosticHandlerTy &errorReporter;
+
// This is the result module we are parsing into.
std::unique_ptr<Module> module;
@@ -131,13 +138,12 @@
ParseResult Parser::emitError(SMLoc loc, const Twine &message) {
// If we hit a parse error in response to a lexer error, then the lexer
- // already emitted an error.
+ // already reported the error.
if (curToken.is(Token::error))
return ParseFailure;
- // TODO(clattner): If/when we want to implement a -verify mode, this will need
- // to package up errors into SMDiagnostic and report them.
- lex.getSourceMgr().PrintMessage(loc, SourceMgr::DK_Error, message);
+ errorReporter(
+ lex.getSourceMgr().GetMessage(loc, SourceMgr::DK_Error, message));
return ParseFailure;
}
@@ -705,6 +711,7 @@
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
-Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context){
- return Parser(sourceMgr, context).parseModule();
+Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
+ const SMDiagnosticHandlerTy &errorReporter) {
+ return Parser(sourceMgr, context, errorReporter).parseModule();
}
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 75ed4f5..b90dcd2 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -64,9 +64,14 @@
SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
- // Parse the input file and emit any errors.
+ // Parse the input file.
MLIRContext context;
- std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
+ // Error reporter that simply prints the errors reported.
+ SMDiagnosticHandlerTy errorReporter = [&sourceMgr](llvm::SMDiagnostic err) {
+ sourceMgr.PrintMessage(err.getLoc(), err.getKind(), err.getMessage());
+ };
+ std::unique_ptr<Module> module(
+ parseSourceFile(sourceMgr, &context, errorReporter));
if (!module) return false;
// Print the output.