| //===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===// |
| // |
| // Copyright 2019 The MLIR Authors. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| // ============================================================================= |
| // |
| // This is a command line utility that parses an MLIR file, runs an optimization |
| // pass, then prints the result back out. It is designed to support unit |
| // testing. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/IR/MLFunction.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/Module.h" |
| #include "mlir/IR/Pass.h" |
| #include "mlir/Parser.h" |
| #include "mlir/TensorFlow/ControlFlowOps.h" |
| #include "mlir/TensorFlow/Passes.h" |
| #include "mlir/Transforms/Passes.h" |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FileUtilities.h" |
| #include "llvm/Support/InitLLVM.h" |
| #include "llvm/Support/Regex.h" |
| #include "llvm/Support/SourceMgr.h" |
| #include "llvm/Support/ToolOutputFile.h" |
| using namespace mlir; |
| using namespace llvm; |
| |
| static cl::opt<std::string> |
| inputFilename(cl::Positional, cl::desc("<input file>"), cl::init("-")); |
| |
| static cl::opt<std::string> |
| outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"), |
| cl::init("-")); |
| |
| static cl::opt<bool> |
| checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"), |
| cl::init(false)); |
| |
| // TODO(clattner): replace these bool options with an enum list option. |
| static cl::opt<bool> convertToCFGOpt( |
| "convert-to-cfg", |
| cl::desc("Convert all ML functions in the module to CFG ones")); |
| |
| static cl::opt<bool> unrollInnermostLoops("unroll-innermost-loops", |
| cl::desc("Unroll innermost loops"), |
| cl::init(false)); |
| |
| static cl::opt<bool> raiseTFControlFlow( |
| "tf-raise-control-flow", |
| cl::desc("Raise TensorFlow Switch/Match nodes to a CFG")); |
| |
/// Outcome of one driver run. main() returns these values directly as the
/// process exit status, so OptSuccess must remain 0.
enum OptResult { OptSuccess = 0, OptFailure = 1 };
| |
| /// Open the specified output file and return it, exiting if there is any I/O or |
| /// other errors. |
| static std::unique_ptr<ToolOutputFile> getOutputStream() { |
| std::error_code error; |
| auto result = make_unique<ToolOutputFile>(outputFilename, error, |
| sys::fs::F_None); |
| if (error) { |
| llvm::errs() << error.message() << '\n'; |
| exit(1); |
| } |
| |
| return result; |
| } |
| |
| static void initializeMLIRContext(MLIRContext &ctx) { |
| TFControlFlow::registerOperations(ctx); |
| } |
| |
| /// Parses the memory buffer and, if successfully parsed, prints the parsed |
| /// output. Optionally, convert ML functions into CFG functions. |
| /// TODO: pull parsing and printing into separate functions. |
| OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) { |
| // Tell sourceMgr about this buffer, which is what the parser will pick up. |
| SourceMgr sourceMgr; |
| sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc()); |
| |
| // Parse the input file. |
| MLIRContext context; |
| initializeMLIRContext(context); |
| std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context)); |
| if (!module) |
| return OptFailure; |
| |
| // Convert ML functions into CFG functions |
| if (convertToCFGOpt) { |
| auto *pass = createConvertToCFGPass(); |
| pass->runOnModule(module.get()); |
| delete pass; |
| module->verify(); |
| } |
| |
| if (unrollInnermostLoops) { |
| auto *pass = createLoopUnrollPass(); |
| pass->runOnModule(module.get()); |
| delete pass; |
| module->verify(); |
| } |
| |
| if (raiseTFControlFlow) { |
| auto *pass = createRaiseTFControlFlowPass(); |
| pass->runOnModule(module.get()); |
| delete pass; |
| module->verify(); |
| } |
| |
| // Print the output. |
| auto output = getOutputStream(); |
| module->print(output->os()); |
| output->keep(); |
| |
| return OptSuccess; |
| } |
| |
| /// Split the memory buffer into multiple buffers using the marker -----. |
| OptResult |
| splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) { |
| const char marker[] = "-----"; |
| SmallVector<StringRef, 2> sourceBuffers; |
| buffer->getBuffer().split(sourceBuffers, marker); |
| |
| // Error reporter that verifies error reports matches expected error |
| // substring. |
| // TODO: Only checking for error cases below. Could be expanded to other kinds |
| // of diagnostics. |
| // TODO: Enable specifying errors on different lines (@-1). |
| // TODO: Currently only checking if substring matches, enable regex checking. |
| OptResult opt_result = OptSuccess; |
| SourceMgr fileSourceMgr; |
| fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc()); |
| |
| // Record the expected errors's position, substring and whether it was seen. |
| struct ExpectedError { |
| int lineNo; |
| StringRef substring; |
| SMLoc fileLoc; |
| bool matched; |
| }; |
| |
| // Tracks offset of subbuffer into original buffer. |
| const char *fileOffset = |
| fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID()) |
| ->getBufferStart(); |
| |
| for (auto &subbuffer : sourceBuffers) { |
| SourceMgr sourceMgr; |
| // Tell sourceMgr about this buffer, which is what the parser will pick up. |
| sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer), |
| SMLoc()); |
| |
| // Extracing the expected errors. |
| llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}"); |
| SmallVector<ExpectedError, 2> expectedErrors; |
| SmallVector<StringRef, 100> lines; |
| subbuffer.split(lines, '\n'); |
| size_t bufOffset = 0; |
| for (int lineNo = 0; lineNo < lines.size(); ++lineNo) { |
| SmallVector<StringRef, 3> matches; |
| if (expected.match(lines[lineNo], &matches)) { |
| // Point to the start of expected-error. |
| SMLoc errorStart = |
| SMLoc::getFromPointer(fileOffset + bufOffset + |
| lines[lineNo].size() - matches[2].size() - 2); |
| ExpectedError expErr{lineNo + 1, matches[2], errorStart, false}; |
| int offset; |
| if (!matches[1].empty() && |
| !matches[1].drop_front().getAsInteger(0, offset)) { |
| expErr.lineNo += offset; |
| } |
| expectedErrors.push_back(expErr); |
| } |
| bufOffset += lines[lineNo].size() + 1; |
| } |
| |
| // Error checker that verifies reported error was expected. |
| auto checker = [&](const SMDiagnostic &err) { |
| for (auto &e : expectedErrors) { |
| if (err.getLineNo() == e.lineNo && |
| err.getMessage().contains(e.substring)) { |
| e.matched = true; |
| return; |
| } |
| } |
| // Report error if no match found. |
| const auto &sourceMgr = *err.getSourceMgr(); |
| const char *bufferStart = |
| sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()) |
| ->getBufferStart(); |
| |
| size_t offset = err.getLoc().getPointer() - bufferStart; |
| SMLoc loc = SMLoc::getFromPointer(fileOffset + offset); |
| fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error, |
| "unexpected error: " + err.getMessage()); |
| opt_result = OptFailure; |
| }; |
| |
| // Parse the input file. |
| MLIRContext context; |
| initializeMLIRContext(context); |
| std::unique_ptr<Module> module( |
| parseSourceFile(sourceMgr, &context, checker)); |
| |
| // Verify that all expected errors were seen. |
| for (auto err : expectedErrors) { |
| if (!err.matched) { |
| SMRange range(err.fileLoc, |
| SMLoc::getFromPointer(err.fileLoc.getPointer() + |
| err.substring.size())); |
| fileSourceMgr.PrintMessage( |
| err.fileLoc, SourceMgr::DK_Error, |
| "expected error \"" + err.substring + "\" was not produced", range); |
| opt_result = OptFailure; |
| } |
| } |
| |
| fileOffset += subbuffer.size() + strlen(marker); |
| } |
| |
| return opt_result; |
| } |
| |
| int main(int argc, char **argv) { |
| InitLLVM x(argc, argv); |
| |
| cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n"); |
| |
| // Set up the input file. |
| auto fileOrErr = MemoryBuffer::getFileOrSTDIN(inputFilename); |
| if (std::error_code error = fileOrErr.getError()) { |
| llvm::errs() << argv[0] << ": could not open input file '" << inputFilename |
| << "': " << error.message() << "\n"; |
| return 1; |
| } |
| |
| if (checkParserErrors) |
| return splitMemoryBufferForErrorChecking(std::move(*fileOrErr)); |
| |
| return parseAndPrintMemoryBuffer(std::move(*fileOrErr)); |
| } |