Continue revising diagnostic handling to simplify and generalize it, and improve related infra.
- Add a new -verify mode to the mlir-opt tool that allows writing test cases
for optimization and other passes that produce diagnostics.
- Refactor existing the -check-parser-errors flag to mlir-opt into a new
-split-input-file option which is orthogonal to -verify.
- Eliminate the special error hook the parser maintained and use the standard
MLIRContext's one instead.
- Enhance the default MLIRContext error reporter to print file/line/col of
errors when it is available.
- Add new createChecked() methods to the builder that create ops and invoke
the verify hook on them, use this to detected unhandled code in the
RaiseControlFlow pass.
- Teach mlir-opt about expected-error @+, it previously only worked with @-
PiperOrigin-RevId: 211305770
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index c677f38..cb4f4f5 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -191,7 +191,8 @@
/// Create an operation given the fields represented as an OperationState.
OperationInst *createOperation(const OperationState &state);
- /// Create operation of specific op type at the current insertion point.
+ /// Create operation of specific op type at the current insertion point
+ /// without verifying to see if it is valid.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Location *location, Args... args) {
OperationState state(getContext(), location, OpTy::getOperationName());
@@ -202,6 +203,27 @@
return result;
}
+ /// Create operation of specific op type at the current insertion point. If
+ /// the result is an invalid op (the verifier hook fails), emit a the
+ /// specified error message and return null.
+ template <typename OpTy, typename... Args>
+ OpPointer<OpTy> createChecked(const Twine &message, Location *location,
+ Args... args) {
+ OperationState state(getContext(), location, OpTy::getOperationName());
+ OpTy::build(this, &state, args...);
+ auto *inst = createOperation(state);
+ auto result = inst->template getAs<OpTy>();
+ assert(result && "Builder didn't return the right type");
+
+ // If the operation we produce is valid, return it.
+ if (!result->verify())
+ return result;
+ // Otherwise, emit the provided message and return null.
+ inst->emitError(message);
+ inst->eraseFromBlock();
+ return OpPointer<OpTy>();
+ }
+
OperationInst *cloneOperation(const OperationInst &srcOpInst) {
auto *op = srcOpInst.clone();
insert(op);
@@ -307,6 +329,27 @@
return result;
}
+ /// Create operation of specific op type at the current insertion point. If
+ /// the result is an invalid op (the verifier hook fails), emit an error and
+ /// return null.
+ template <typename OpTy, typename... Args>
+ OpPointer<OpTy> createChecked(const Twine &message, Location *location,
+ Args... args) {
+ OperationState state(getContext(), location, OpTy::getOperationName());
+ OpTy::build(this, &state, args...);
+ auto *stmt = createOperation(state);
+ auto result = stmt->template getAs<OpTy>();
+ assert(result && "Builder didn't return the right type");
+
+ // If the operation we produce is valid, return it.
+ if (!result->verify())
+ return result;
+ // Otherwise, emit the provided message and return null.
+ stmt->emitError(message);
+ stmt->eraseFromBlock();
+ return OpPointer<OpTy>();
+ }
+
/// Create a deep copy of the specified statement, remapping any operands that
/// use values outside of the statement using the map that is provided (
/// leaving them alone if no entry is present). Replaces references to cloned
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 6410a9e..5f00933 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -41,6 +41,7 @@
template <typename OpType>
class OpPointer {
public:
+ explicit OpPointer() : value(Operation::getNull<OpType>().value) {}
explicit OpPointer(OpType value) : value(value) {}
OpType &operator*() { return value; }
@@ -49,7 +50,7 @@
operator bool() const { return value.getOperation(); }
-public:
+private:
OpType value;
};
@@ -58,6 +59,7 @@
template <typename OpType>
class ConstOpPointer {
public:
+ explicit ConstOpPointer() : value(Operation::getNull<OpType>().value) {}
explicit ConstOpPointer(OpType value) : value(value) {}
const OpType &operator*() const { return value; }
@@ -67,7 +69,7 @@
/// Return true if non-null.
operator bool() const { return value.getOperation(); }
-public:
+private:
const OpType value;
};
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index f5e3b17..038f32f 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -216,6 +216,11 @@
/// OperationSet, return it. Otherwise return null.
const AbstractOperation *getAbstractOperation() const;
+ // Return a null OpPointer for the specified type.
+ template <typename OpClass> static OpPointer<OpClass> getNull() {
+ return OpPointer<OpClass>(OpClass(nullptr));
+ }
+
/// The getAs methods perform a dynamic cast from an Operation (like
/// OperationInst and OperationStmt) to a typed Op like DimOp. This returns
/// a null OpPointer on failure.
diff --git a/include/mlir/Parser.h b/include/mlir/Parser.h
index 951b3a4..0c4a7b9 100644
--- a/include/mlir/Parser.h
+++ b/include/mlir/Parser.h
@@ -22,29 +22,18 @@
#ifndef MLIR_PARSER_H
#define MLIR_PARSER_H
-#include <functional>
-
namespace llvm {
class SourceMgr;
- class SMDiagnostic;
} // end namespace llvm
namespace mlir {
class Module;
class MLIRContext;
-using SMDiagnosticHandlerTy =
- std::function<void(const llvm::SMDiagnostic &error)>;
-
-/// Default error reproter that prints out the error using the SourceMgr of the
-/// error.
-void defaultErrorReporter(const llvm::SMDiagnostic &error);
-
/// This parses the file specified by the indicated SourceMgr and returns an
-/// MLIR module if it was valid. If not, the errorReporter is used to report
-/// the error diagnostics and this function returns null.
-Module *parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
- SMDiagnosticHandlerTy errorReporter = nullptr);
+/// MLIR module if it was valid. If not, the error message is emitted through
+/// the error handler registered in the context, and a null pointer is returned.
+Module *parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context);
} // end namespace mlir
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 46d0103..32e2f52 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -319,11 +319,17 @@
if (kind != DiagnosticKind::Error)
return;
- // TODO(clattner): can improve this now!
+ auto &os = llvm::errs();
+
+ if (auto fileLoc = dyn_cast<FileLineColLoc>(location))
+ os << fileLoc->getFilename() << ':' << fileLoc->getLine() << ':'
+ << fileLoc->getColumn() << ": ";
+
+ os << "error: ";
// The default behavior for errors is to emit them to stderr and exit.
- llvm::errs() << message.str() << "\n";
- llvm::errs().flush();
+ os << message.str() << '\n';
+ os.flush();
exit(1);
}
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 043acd7..b4f8e1d 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -20,8 +20,9 @@
//===----------------------------------------------------------------------===//
#include "Lexer.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
#include "llvm/Support/SourceMgr.h"
-#include <cctype>
using namespace mlir;
using llvm::SMLoc;
using llvm::SourceMgr;
@@ -32,17 +33,30 @@
return c == '$' || c == '.' || c == '_' || c == '-';
}
-Lexer::Lexer(llvm::SourceMgr &sourceMgr, SMDiagnosticHandlerTy errorReporter)
- : sourceMgr(sourceMgr), errorReporter(errorReporter) {
+Lexer::Lexer(llvm::SourceMgr &sourceMgr, MLIRContext *context)
+ : sourceMgr(sourceMgr), context(context) {
auto bufferID = sourceMgr.getMainFileID();
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
}
+/// Encode the specified source location information into an attribute for
+/// attachment to the IR.
+Location *Lexer::getEncodedSourceLocation(llvm::SMLoc loc) {
+ auto &sourceMgr = getSourceMgr();
+ unsigned mainFileID = sourceMgr.getMainFileID();
+ auto lineAndColumn = sourceMgr.getLineAndColumn(loc, mainFileID);
+ auto *buffer = sourceMgr.getMemoryBuffer(mainFileID);
+ auto filename = UniquedFilename::get(buffer->getBufferIdentifier(), context);
+
+ return FileLineColLoc::get(filename, lineAndColumn.first,
+ lineAndColumn.second, context);
+}
+
/// emitError - Emit an error message and return an Token::error token.
Token Lexer::emitError(const char *loc, const Twine &message) {
- errorReporter(sourceMgr.GetMessage(SMLoc::getFromPointer(loc),
- SourceMgr::DK_Error, message));
+ context->emitDiagnostic(getEncodedSourceLocation(SMLoc::getFromPointer(loc)),
+ message, MLIRContext::DiagnosticKind::Error);
return formToken(Token::error, loc);
}
diff --git a/lib/Parser/Lexer.h b/lib/Parser/Lexer.h
index 51962fa..cbd4d0d 100644
--- a/lib/Parser/Lexer.h
+++ b/lib/Parser/Lexer.h
@@ -26,11 +26,12 @@
#include "Token.h"
namespace mlir {
+class Location;
/// This class breaks up the current file into a token stream.
class Lexer {
llvm::SourceMgr &sourceMgr;
- const SMDiagnosticHandlerTy errorReporter;
+ MLIRContext *context;
StringRef curBuffer;
const char *curPtr;
@@ -38,16 +39,20 @@
Lexer(const Lexer&) = delete;
void operator=(const Lexer&) = delete;
public:
- explicit Lexer(llvm::SourceMgr &sourceMgr,
- SMDiagnosticHandlerTy errorReporter);
+ explicit Lexer(llvm::SourceMgr &sourceMgr, MLIRContext *context);
- llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
+ llvm::SourceMgr &getSourceMgr() { return sourceMgr; }
- Token lexToken();
+ Token lexToken();
- /// Change the position of the lexer cursor. The next token we lex will start
- /// at the designated point in the input.
- void resetPointer(const char *newPointer) { curPtr = newPointer; }
+ /// Encode the specified source location information into a Location object
+ /// for attachment to the IR or error reporting.
+ Location *getEncodedSourceLocation(llvm::SMLoc loc);
+
+ /// Change the position of the lexer cursor. The next token we lex will start
+ /// at the designated point in the input.
+ void resetPointer(const char *newPointer) { curPtr = newPointer; }
+
private:
// Helpers.
Token formToken(Token::Kind kind, const char *tokStart) {
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 0935090..1de8a52 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -45,14 +45,6 @@
/// bool value. Failure is "true" in a boolean context.
enum ParseResult { ParseSuccess, ParseFailure };
-/// Return a uniqued filename for the main file the specified SourceMgr is
-/// looking at.
-static UniquedFilename getUniquedFilename(llvm::SourceMgr &sourceMgr,
- MLIRContext *context) {
- auto *buffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
- return UniquedFilename::get(buffer->getBufferIdentifier(), context);
-}
-
namespace {
class Parser;
@@ -61,13 +53,9 @@
/// methods to access this.
class ParserState {
public:
- ParserState(llvm::SourceMgr &sourceMgr, Module *module,
- SMDiagnosticHandlerTy errorReporter)
- : context(module->getContext()), module(module),
- filename(getUniquedFilename(sourceMgr, context)),
- lex(sourceMgr, errorReporter), curToken(lex.lexToken()),
- errorReporter(errorReporter), operationSet(OperationSet::get(context)) {
- }
+ ParserState(llvm::SourceMgr &sourceMgr, Module *module)
+ : context(module->getContext()), module(module), lex(sourceMgr, context),
+ curToken(lex.lexToken()), operationSet(OperationSet::get(context)) {}
// A map from affine map identifier to AffineMap.
llvm::StringMap<AffineMap *> affineMapDefinitions;
@@ -92,18 +80,12 @@
// This is the module we are parsing into.
Module *const module;
- /// The filename to use for location generation.
- UniquedFilename filename;
-
// The lexer for the source file we're parsing.
Lexer lex;
// This is the next token that hasn't been consumed yet.
Token curToken;
- // The diagnostic error reporter.
- SMDiagnosticHandlerTy const errorReporter;
-
// The active OperationSet we're parsing with.
OperationSet &operationSet;
};
@@ -136,7 +118,9 @@
/// Encode the specified source location information into an attribute for
/// attachment to the IR.
- Location *getEncodedSourceLocation(llvm::SMLoc loc);
+ Location *getEncodedSourceLocation(llvm::SMLoc loc) {
+ return state.lex.getEncodedSourceLocation(loc);
+ }
/// Emit an error and return failure.
ParseResult emitError(const Twine &message) {
@@ -221,25 +205,14 @@
// Helper methods.
//===----------------------------------------------------------------------===//
-/// Encode the specified source location information into an attribute for
-/// attachment to the IR.
-Location *Parser::getEncodedSourceLocation(llvm::SMLoc loc) {
- auto &sourceMgr = getSourceMgr();
- auto lineAndColumn =
- sourceMgr.getLineAndColumn(loc, sourceMgr.getMainFileID());
-
- return FileLineColLoc::get(state.filename, lineAndColumn.first,
- lineAndColumn.second, getContext());
-}
-
ParseResult Parser::emitError(SMLoc loc, const Twine &message) {
// If we hit a parse error in response to a lexer error, then the lexer
// already reported the error.
if (getToken().is(Token::error))
return ParseFailure;
- auto &sourceMgr = state.lex.getSourceMgr();
- state.errorReporter(sourceMgr.GetMessage(loc, SourceMgr::DK_Error, message));
+ getContext()->emitDiagnostic(getEncodedSourceLocation(loc), message,
+ MLIRContext::DiagnosticKind::Error);
return ParseFailure;
}
@@ -3026,78 +2999,34 @@
//===----------------------------------------------------------------------===//
-void mlir::defaultErrorReporter(const llvm::SMDiagnostic &error) {
- const auto &sourceMgr = *error.getSourceMgr();
- sourceMgr.PrintMessage(error.getLoc(), error.getKind(), error.getMessage());
-}
-
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
-Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr, MLIRContext *context,
- SMDiagnosticHandlerTy errorReporter) {
- if (!errorReporter)
- errorReporter = defaultErrorReporter;
-
- // We are going to replace the context's handler and redirect it to use the
- // error reporter. Save the existing handler and reinstate it when we're
- // done.
- auto existingContextHandler = context->getDiagnosticHandler();
-
- // Install a new handler that uses the error reporter.
- context->registerDiagnosticHandler([&](Location *location, StringRef message,
- MLIRContext::DiagnosticKind kind) {
- SourceMgr::DiagKind diagKind;
- switch (kind) {
- case MLIRContext::DiagnosticKind::Error:
- diagKind = SourceMgr::DK_Error;
- break;
- case MLIRContext::DiagnosticKind::Warning:
- diagKind = SourceMgr::DK_Warning;
- break;
- case MLIRContext::DiagnosticKind::Note:
- diagKind = SourceMgr::DK_Note;
- break;
- }
-
- StringRef filename;
- unsigned line = 0, column = 0;
- if (auto fileLoc = dyn_cast<FileLineColLoc>(location)) {
- filename = fileLoc->getFilename();
- line = fileLoc->getLine();
- column = fileLoc->getColumn();
- }
-
- auto diag = llvm::SMDiagnostic(sourceMgr, SMLoc(), filename, line, column,
- diagKind, message, /*LineStr=*/StringRef(),
- /*Ranges=*/{}, /*FixIts=*/{});
-
- errorReporter(diag);
- });
+Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) {
// This is the result module we are parsing into.
std::unique_ptr<Module> module(new Module(context));
- ParserState state(sourceMgr, module.get(), errorReporter);
+ ParserState state(sourceMgr, module.get());
if (ModuleParser(state).parseModule()) {
- context->registerDiagnosticHandler(existingContextHandler);
return nullptr;
}
// Make sure the parse module has no other structural problems detected by the
// verifier.
+ //
+ // TODO(clattner): The verifier should always emit diagnostics when we have
+ // more location information available. We shouldn't need this hook.
std::string errorResult;
module->verify(&errorResult);
// We don't have location information for general verifier errors, so emit the
- // error on the first line.
+ // error with an unknown location.
if (!errorResult.empty()) {
- auto *mainBuffer = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
- errorReporter(sourceMgr.GetMessage(
- SMLoc::getFromPointer(mainBuffer->getBufferStart()),
- SourceMgr::DK_Error, errorResult));
+ context->emitDiagnostic(UnknownLoc::get(context), errorResult,
+ MLIRContext::DiagnosticKind::Error);
return nullptr;
}
- context->registerDiagnosticHandler(existingContextHandler);
return module.release();
}
diff --git a/test/IR/invalid-affinemap.mlir b/test/IR/invalid-affinemap.mlir
index 8c8fd1c..60654ee 100644
--- a/test/IR/invalid-affinemap.mlir
+++ b/test/IR/invalid-affinemap.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -o - -check-parser-errors
+// RUN: mlir-opt %s -o - -split-input-file -verify
// Check different error cases.
// -----
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index bf9eae1..966ef61 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -o - -check-parser-errors
+// RUN: mlir-opt %s -o - -split-input-file -verify
cfgfunc @dim(tensor<1xf32>) {
bb(%0: tensor<1xf32>):
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index ab113e9..69f55b4 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -o - -check-parser-errors
+// RUN: mlir-opt %s -o - -split-input-file -verify
// Check different error cases.
// -----
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 5bbf6b7..97036e9 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -22,6 +22,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Location.h"
#include "mlir/IR/MLFunction.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
@@ -40,6 +41,7 @@
using namespace mlir;
using namespace llvm;
+using llvm::SMLoc;
static cl::opt<std::string>
inputFilename(cl::Positional, cl::desc("<input file>"), cl::init("-"));
@@ -49,8 +51,16 @@
cl::init("-"));
static cl::opt<bool>
-checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
- cl::init(false));
+ splitInputFile("split-input-file",
+ cl::desc("Split the input file into pieces and process each "
+ "chunk independently"),
+ cl::init(false));
+
+static cl::opt<bool>
+ verifyDiagnostics("verify",
+ cl::desc("Check that emitted diagnostics match "
+ "expected-* lines on the corresponding line"),
+ cl::init(false));
enum Passes {
ConvertToCFG,
@@ -93,18 +103,56 @@
// context initializations (e.g., op registrations).
extern void initializeMLIRContext(MLIRContext *ctx);
-/// Parses the memory buffer and, if successfully parsed, prints the parsed
-/// output. Optionally, convert ML functions into CFG functions.
-/// TODO: pull parsing and printing into separate functions.
-OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
- // Tell sourceMgr about this buffer, which is what the parser will pick up.
- SourceMgr sourceMgr;
- sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
+/// Given a MemoryBuffer along with a line and column within it, return the
+/// location being referenced.
+static SMLoc getLocFromLineAndCol(MemoryBuffer &membuf, unsigned lineNo,
+ unsigned columnNo) {
+ // TODO: This should really be upstreamed to be a method on llvm::SourceMgr.
+ // Doing so would allow it to use the offset cache that is already maintained
+ // by SrcBuffer, making this more efficient.
- // Parse the input file.
- MLIRContext context;
- initializeMLIRContext(&context);
- std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
+ // Scan for the correct line number.
+ const char *position = membuf.getBufferStart();
+ const char *end = membuf.getBufferEnd();
+
+ // We start counting line and column numbers from 1.
+ --lineNo;
+ --columnNo;
+
+ while (position < end && lineNo) {
+ auto curChar = *position++;
+
+ // Scan for newlines. If this isn't one, ignore it.
+ if (curChar != '\r' && curChar != '\n')
+ continue;
+
+ // We saw a line break, decrement our counter.
+ --lineNo;
+
+ // Check for \r\n and \n\r and treat it as a single escape. We know that
+ // looking past one character is safe because MemoryBuffer's are always nul
+ // terminated.
+ if (*position != curChar && (*position == '\r' || *position == '\n'))
+ ++position;
+ }
+
+ // If the line/column counter was invalid, return a pointer to the start of
+ // the buffer.
+ if (lineNo || position + columnNo > end)
+ return SMLoc::getFromPointer(membuf.getBufferStart());
+
+ // Otherwise return the right pointer.
+ return SMLoc::getFromPointer(position + columnNo);
+}
+
+/// Perform the actions on the input file indicated by the command line flags
+/// within the specified context.
+///
+/// This typically parses the main source file, runs zero or more optimization
+/// passes, then prints the output.
+///
+static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
+ std::unique_ptr<Module> module(parseSourceFile(sourceMgr, context));
if (!module)
return OptFailure;
@@ -138,115 +186,156 @@
auto output = getOutputStream();
module->print(output->os());
output->keep();
-
return OptSuccess;
}
-/// Split the memory buffer into multiple buffers using the marker -----.
-OptResult
-splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
- const char marker[] = "-----";
- SmallVector<StringRef, 2> sourceBuffers;
- buffer->getBuffer().split(sourceBuffers, marker);
+/// Parses the memory buffer. If successfully, run a series of passes against
+/// it and print the result.
+static OptResult processFile(std::unique_ptr<MemoryBuffer> ownedBuffer) {
+ // Tell sourceMgr about this buffer, which is what the parser will pick up.
+ SourceMgr sourceMgr;
+ auto &buffer = *ownedBuffer;
+ sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
- // Error reporter that verifies error reports matches expected error
- // substring.
- // TODO: Only checking for error cases below. Could be expanded to other kinds
- // of diagnostics.
- // TODO: Enable specifying errors on different lines (@-1).
- // TODO: Currently only checking if substring matches, enable regex checking.
- OptResult opt_result = OptSuccess;
- SourceMgr fileSourceMgr;
- fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
+ // Parse the input file.
+ MLIRContext context;
+ initializeMLIRContext(&context);
- // Record the expected errors's position, substring and whether it was seen.
- struct ExpectedError {
- int lineNo;
- StringRef substring;
- SMLoc fileLoc;
- bool matched;
- };
+ // If we are in verify mode then we have a lot of work to do, otherwise just
+ // perform the actions without worrying about it.
+ if (!verifyDiagnostics) {
- // Tracks offset of subbuffer into original buffer.
- const char *fileOffset =
- fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID())
- ->getBufferStart();
-
- for (auto &subbuffer : sourceBuffers) {
- SourceMgr sourceMgr;
- // Tell sourceMgr about this buffer, which is what the parser will pick up.
- sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer),
- SMLoc());
-
- // Extract the expected errors.
- llvm::Regex expected("expected-error *(@[+-][0-9]+)? *{{(.*)}}");
- SmallVector<ExpectedError, 2> expectedErrors;
- SmallVector<StringRef, 100> lines;
- subbuffer.split(lines, '\n');
- size_t bufOffset = 0;
- for (int lineNo = 0; lineNo < lines.size(); ++lineNo) {
- SmallVector<StringRef, 3> matches;
- if (expected.match(lines[lineNo], &matches)) {
- // Point to the start of expected-error.
- SMLoc errorStart =
- SMLoc::getFromPointer(fileOffset + bufOffset +
- lines[lineNo].size() - matches[2].size() - 2);
- ExpectedError expErr{lineNo + 1, matches[2], errorStart, false};
- int offset;
- if (!matches[1].empty() &&
- !matches[1].drop_front().getAsInteger(0, offset)) {
- expErr.lineNo += offset;
- }
- expectedErrors.push_back(expErr);
+ // Register a simple diagnostic handler that prints out info with context.
+ context.registerDiagnosticHandler([&](Location *location, StringRef message,
+ MLIRContext::DiagnosticKind kind) {
+ unsigned line = 1, column = 1;
+ if (auto fileLoc = dyn_cast<FileLineColLoc>(location)) {
+ line = fileLoc->getLine();
+ column = fileLoc->getColumn();
}
- bufOffset += lines[lineNo].size() + 1;
- }
- // Error checker that verifies reported error was expected.
- auto checker = [&](const SMDiagnostic &err) {
- for (auto &e : expectedErrors) {
- if (err.getLineNo() == e.lineNo &&
- err.getMessage().contains(e.substring)) {
- e.matched = true;
- return;
- }
- }
- // Report error if no match found.
- const auto &sourceMgr = *err.getSourceMgr();
- const char *bufferStart =
- sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())
- ->getBufferStart();
+ auto unexpectedLoc = getLocFromLineAndCol(buffer, line, column);
+ sourceMgr.PrintMessage(unexpectedLoc, SourceMgr::DK_Error, message);
+ });
- size_t offset = err.getLoc().getPointer() - bufferStart;
- SMLoc loc = SMLoc::getFromPointer(fileOffset + offset);
- fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
- "unexpected error: " + err.getMessage());
- opt_result = OptFailure;
- };
-
- // Parse the input file.
- MLIRContext context;
- initializeMLIRContext(&context);
- std::unique_ptr<Module> module(
- parseSourceFile(sourceMgr, &context, checker));
-
- // Verify that all expected errors were seen.
- for (auto err : expectedErrors) {
- if (!err.matched) {
- SMRange range(err.fileLoc,
- SMLoc::getFromPointer(err.fileLoc.getPointer() +
- err.substring.size()));
- fileSourceMgr.PrintMessage(
- err.fileLoc, SourceMgr::DK_Error,
- "expected error \"" + err.substring + "\" was not produced", range);
- opt_result = OptFailure;
- }
- }
-
- fileOffset += subbuffer.size() + strlen(marker);
+ // Run the test actions.
+ return performActions(sourceMgr, &context);
}
- return opt_result;
+ // Keep track of the result of this file processing. If there are no issues,
+ // then we succeed.
+ auto result = OptSuccess;
+
+ // Record the expected error's position, substring and whether it was seen.
+ struct ExpectedError {
+ unsigned lineNo;
+ StringRef substring;
+ SMLoc fileLoc;
+ bool matched = false;
+ };
+ SmallVector<ExpectedError, 2> expectedErrors;
+
+ // Error checker that verifies reported error was expected.
+ auto checker = [&](Location *location, StringRef message,
+ MLIRContext::DiagnosticKind kind) {
+ unsigned line = 1, column = 1;
+ if (auto *fileLoc = dyn_cast<FileLineColLoc>(location)) {
+ line = fileLoc->getLine();
+ column = fileLoc->getColumn();
+ }
+
+ // If this was an expected error, remember that we saw it and return.
+ for (auto &e : expectedErrors) {
+ if (line == e.lineNo && message.contains(e.substring)) {
+ e.matched = true;
+ return;
+ }
+ }
+
+ // If this error wasn't expected, produce an error out of mlir-opt saying
+ // so.
+ auto unexpectedLoc = getLocFromLineAndCol(buffer, line, column);
+ sourceMgr.PrintMessage(unexpectedLoc, SourceMgr::DK_Error,
+ "unexpected error: " + Twine(message));
+ result = OptFailure;
+ };
+
+ // Scan the file for expected-* designators and register a callback for the
+ // error handler.
+ // Extract the expected errors from the file.
+ llvm::Regex expected("expected-error *(@[+-][0-9]+)? *{{(.*)}}");
+ SmallVector<StringRef, 100> lines;
+ buffer.getBuffer().split(lines, '\n');
+ for (unsigned lineNo = 0, e = lines.size(); lineNo < e; ++lineNo) {
+ SmallVector<StringRef, 3> matches;
+ if (expected.match(lines[lineNo], &matches)) {
+ // Point to the start of expected-error.
+ SMLoc errorStart = SMLoc::getFromPointer(matches[0].data());
+ ExpectedError expErr{lineNo + 1, matches[2], errorStart, false};
+ auto offsetMatch = matches[1];
+ if (!offsetMatch.empty()) {
+ int offset;
+ // Get the integer value without the @ and +/- prefix.
+ if (!offsetMatch.drop_front(2).getAsInteger(0, offset)) {
+ if (offsetMatch[1] == '+')
+ expErr.lineNo += offset;
+ else
+ expErr.lineNo -= offset;
+ }
+ }
+ expectedErrors.push_back(expErr);
+ }
+ }
+
+ // Finally, register the error handler to capture them.
+ context.registerDiagnosticHandler(checker);
+
+ // Do any processing requested by command line flags. We don't care whether
+ // these actions succeed or fail, we only care what diagnostics they produce
+ // and whether they match our expectations.
+ performActions(sourceMgr, &context);
+
+ // Verify that all expected errors were seen.
+ for (auto &err : expectedErrors) {
+ if (!err.matched) {
+ SMRange range(err.fileLoc,
+ SMLoc::getFromPointer(err.fileLoc.getPointer() +
+ err.substring.size()));
+ sourceMgr.PrintMessage(
+ err.fileLoc, SourceMgr::DK_Error,
+ "expected error \"" + err.substring + "\" was not produced", range);
+ result = OptFailure;
+ }
+ }
+
+ return result;
+}
+
+/// Split the specified file on a marker and process each chunk independently
+/// according to the normal processFile logic. This is primarily used to
+/// allow a large number of small independent parser tests to be put into a
+/// single test, but could be used for other purposes as well.
+static OptResult
+splitAndProcessFile(std::unique_ptr<MemoryBuffer> originalBuffer) {
+ const char marker[] = "-----";
+ SmallVector<StringRef, 8> sourceBuffers;
+ originalBuffer->getBuffer().split(sourceBuffers, marker);
+
+ // Add the original buffer to the source manager.
+ SourceMgr fileSourceMgr;
+ fileSourceMgr.AddNewSourceBuffer(std::move(originalBuffer), SMLoc());
+
+ bool hadUnexpectedResult = false;
+
+ // Process each chunk in turn. If any fails, then return a failure of the
+ // tool.
+ for (auto &subBuffer : sourceBuffers) {
+ auto subMemBuffer = MemoryBuffer::getMemBufferCopy(subBuffer);
+ if (processFile(std::move(subMemBuffer)))
+ hadUnexpectedResult = true;
+ }
+
+ return hadUnexpectedResult ? OptFailure : OptSuccess;
}
int main(int argc, char **argv) {
@@ -263,8 +352,10 @@
return 1;
}
- if (checkParserErrors)
- return splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
+ // The split-input-file mode is a very specific mode that slices the file
+ // up into small pieces and checks each independently.
+ if (splitInputFile)
+ return splitAndProcessFile(std::move(*fileOrErr));
- return parseAndPrintMemoryBuffer(std::move(*fileOrErr));
+ return processFile(std::move(*fileOrErr));
}