blob: 8ec042b3c4c0e20a4652044e856175bd8ecb5044 [file] [log] [blame]
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -07001//===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This is a command line utility that parses an MLIR file, runs an optimization
19// pass, then prints the result back out. It is designed to support unit
20// testing.
21//
22//===----------------------------------------------------------------------===//
23
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070024#include "mlir/IR/MLFunction.h"
Chris Lattnerf7e22732018-06-22 22:03:48 -070025#include "mlir/IR/MLIRContext.h"
Chris Lattnere2259872018-06-21 15:22:42 -070026#include "mlir/IR/Module.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070027#include "mlir/IR/Pass.h"
Chris Lattnere79379a2018-06-22 10:39:19 -070028#include "mlir/Parser.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070029#include "mlir/TensorFlow/ControlFlowOps.h"
30#include "mlir/TensorFlow/Passes.h"
31#include "mlir/Transforms/Passes.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070032#include "llvm/Support/CommandLine.h"
Chris Lattnere2259872018-06-21 15:22:42 -070033#include "llvm/Support/FileUtilities.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070034#include "llvm/Support/InitLLVM.h"
Jacques Pienaarca4c4a02018-06-25 08:10:46 -070035#include "llvm/Support/Regex.h"
Jacques Pienaar39ffa102018-07-07 19:12:22 -070036#include "llvm/Support/SourceMgr.h"
Chris Lattnere2259872018-06-21 15:22:42 -070037#include "llvm/Support/ToolOutputFile.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070038using namespace mlir;
Chris Lattnere2259872018-06-21 15:22:42 -070039using namespace llvm;
40
41static cl::opt<std::string>
42inputFilename(cl::Positional, cl::desc("<input file>"), cl::init("-"));
43
44static cl::opt<std::string>
45outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"),
46 cl::init("-"));
47
Jacques Pienaarbae40512018-06-24 09:10:36 -070048static cl::opt<bool>
49checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
50 cl::init(false));
Chris Lattnere2259872018-06-21 15:22:42 -070051
Chris Lattneree0c2ae2018-07-29 12:37:35 -070052// TODO(clattner): replace these bool options with an enum list option.
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070053static cl::opt<bool> convertToCFGOpt(
54 "convert-to-cfg",
55 cl::desc("Convert all ML functions in the module to CFG ones"));
56
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070057static cl::opt<bool> unrollInnermostLoops("unroll-innermost-loops",
58 cl::desc("Unroll innermost loops"),
59 cl::init(false));
60
Chris Lattneree0c2ae2018-07-29 12:37:35 -070061static cl::opt<bool> raiseTFControlFlow(
62 "tf-raise-control-flow",
63 cl::desc("Raise TensorFlow Switch/Match nodes to a CFG"));
64
Jacques Pienaar7b829702018-07-03 13:24:09 -070065enum OptResult { OptSuccess, OptFailure };
66
Chris Lattnere2259872018-06-21 15:22:42 -070067/// Open the specified output file and return it, exiting if there is any I/O or
68/// other errors.
69static std::unique_ptr<ToolOutputFile> getOutputStream() {
70 std::error_code error;
71 auto result = make_unique<ToolOutputFile>(outputFilename, error,
72 sys::fs::F_None);
73 if (error) {
74 llvm::errs() << error.message() << '\n';
75 exit(1);
76 }
77
78 return result;
79}
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070080
Chris Lattneree0c2ae2018-07-29 12:37:35 -070081static void initializeMLIRContext(MLIRContext &ctx) {
82 TFControlFlow::registerOperations(ctx);
83}
84
Jacques Pienaarbae40512018-06-24 09:10:36 -070085/// Parses the memory buffer and, if successfully parsed, prints the parsed
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070086/// output. Optionally, convert ML functions into CFG functions.
87/// TODO: pull parsing and printing into separate functions.
Jacques Pienaar7b829702018-07-03 13:24:09 -070088OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
Jacques Pienaarbae40512018-06-24 09:10:36 -070089 // Tell sourceMgr about this buffer, which is what the parser will pick up.
90 SourceMgr sourceMgr;
91 sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
92
Jacques Pienaar9c411be2018-06-24 19:17:35 -070093 // Parse the input file.
Jacques Pienaarbae40512018-06-24 09:10:36 -070094 MLIRContext context;
Chris Lattneree0c2ae2018-07-29 12:37:35 -070095 initializeMLIRContext(context);
Jacques Pienaar7b829702018-07-03 13:24:09 -070096 std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
97 if (!module)
98 return OptFailure;
Jacques Pienaarbae40512018-06-24 09:10:36 -070099
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700100 // Convert ML functions into CFG functions
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700101 if (convertToCFGOpt) {
102 auto *pass = createConvertToCFGPass();
103 pass->runOnModule(module.get());
104 delete pass;
105 module->verify();
106 }
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700107
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700108 if (unrollInnermostLoops) {
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700109 auto *pass = createLoopUnrollPass();
110 pass->runOnModule(module.get());
111 delete pass;
112 module->verify();
113 }
114
115 if (raiseTFControlFlow) {
116 auto *pass = createRaiseTFControlFlowPass();
117 pass->runOnModule(module.get());
118 delete pass;
119 module->verify();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700120 }
121
Jacques Pienaarbae40512018-06-24 09:10:36 -0700122 // Print the output.
123 auto output = getOutputStream();
124 module->print(output->os());
125 output->keep();
126
Jacques Pienaar7b829702018-07-03 13:24:09 -0700127 return OptSuccess;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700128}
129
130/// Split the memory buffer into multiple buffers using the marker -----.
Jacques Pienaar7b829702018-07-03 13:24:09 -0700131OptResult
132splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
Jacques Pienaarbae40512018-06-24 09:10:36 -0700133 const char marker[] = "-----";
134 SmallVector<StringRef, 2> sourceBuffers;
135 buffer->getBuffer().split(sourceBuffers, marker);
Jacques Pienaarbae40512018-06-24 09:10:36 -0700136
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700137 // Error reporter that verifies error reports matches expected error
138 // substring.
139 // TODO: Only checking for error cases below. Could be expanded to other kinds
140 // of diagnostics.
141 // TODO: Enable specifying errors on different lines (@-1).
142 // TODO: Currently only checking if substring matches, enable regex checking.
Jacques Pienaar7b829702018-07-03 13:24:09 -0700143 OptResult opt_result = OptSuccess;
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700144 SourceMgr fileSourceMgr;
145 fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
146
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700147 // Record the expected errors's position, substring and whether it was seen.
148 struct ExpectedError {
149 int lineNo;
150 StringRef substring;
151 SMLoc fileLoc;
152 bool matched;
153 };
154
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700155 // Tracks offset of subbuffer into original buffer.
156 const char *fileOffset =
157 fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID())
158 ->getBufferStart();
159
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700160 for (auto &subbuffer : sourceBuffers) {
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700161 SourceMgr sourceMgr;
162 // Tell sourceMgr about this buffer, which is what the parser will pick up.
163 sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer),
164 SMLoc());
165
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700166 // Extracing the expected errors.
James Molloy61a656c2018-07-22 15:45:24 -0700167 llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700168 SmallVector<ExpectedError, 2> expectedErrors;
169 SmallVector<StringRef, 100> lines;
170 subbuffer.split(lines, '\n');
171 size_t bufOffset = 0;
172 for (int lineNo = 0; lineNo < lines.size(); ++lineNo) {
173 SmallVector<StringRef, 3> matches;
174 if (expected.match(lines[lineNo], &matches)) {
175 // Point to the start of expected-error.
176 SMLoc errorStart =
177 SMLoc::getFromPointer(fileOffset + bufOffset +
178 lines[lineNo].size() - matches[2].size() - 2);
179 ExpectedError expErr{lineNo + 1, matches[2], errorStart, false};
180 int offset;
181 if (!matches[1].empty() &&
182 !matches[1].drop_front().getAsInteger(0, offset)) {
183 expErr.lineNo += offset;
184 }
185 expectedErrors.push_back(expErr);
186 }
187 bufOffset += lines[lineNo].size() + 1;
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700188 }
189
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700190 // Error checker that verifies reported error was expected.
191 auto checker = [&](const SMDiagnostic &err) {
192 for (auto &e : expectedErrors) {
193 if (err.getLineNo() == e.lineNo &&
194 err.getMessage().contains(e.substring)) {
195 e.matched = true;
196 return;
197 }
198 }
199 // Report error if no match found.
200 const auto &sourceMgr = *err.getSourceMgr();
201 const char *bufferStart =
202 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())
203 ->getBufferStart();
204
205 size_t offset = err.getLoc().getPointer() - bufferStart;
206 SMLoc loc = SMLoc::getFromPointer(fileOffset + offset);
207 fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
208 "unexpected error: " + err.getMessage());
209 opt_result = OptFailure;
210 };
211
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700212 // Parse the input file.
213 MLIRContext context;
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700214 initializeMLIRContext(context);
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700215 std::unique_ptr<Module> module(
216 parseSourceFile(sourceMgr, &context, checker));
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700217
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700218 // Verify that all expected errors were seen.
219 for (auto err : expectedErrors) {
220 if (!err.matched) {
221 SMRange range(err.fileLoc,
222 SMLoc::getFromPointer(err.fileLoc.getPointer() +
223 err.substring.size()));
224 fileSourceMgr.PrintMessage(
225 err.fileLoc, SourceMgr::DK_Error,
226 "expected error \"" + err.substring + "\" was not produced", range);
227 opt_result = OptFailure;
228 }
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700229 }
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700230
231 fileOffset += subbuffer.size() + strlen(marker);
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700232 }
233
Jacques Pienaar7b829702018-07-03 13:24:09 -0700234 return opt_result;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700235}
236
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700237int main(int argc, char **argv) {
Chris Lattnere2259872018-06-21 15:22:42 -0700238 InitLLVM x(argc, argv);
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700239
Chris Lattnere2259872018-06-21 15:22:42 -0700240 cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700241
Chris Lattnere79379a2018-06-22 10:39:19 -0700242 // Set up the input file.
243 auto fileOrErr = MemoryBuffer::getFileOrSTDIN(inputFilename);
244 if (std::error_code error = fileOrErr.getError()) {
245 llvm::errs() << argv[0] << ": could not open input file '" << inputFilename
246 << "': " << error.message() << "\n";
247 return 1;
248 }
249
Jacques Pienaarbae40512018-06-24 09:10:36 -0700250 if (checkParserErrors)
Jacques Pienaar7b829702018-07-03 13:24:09 -0700251 return splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700252
Jacques Pienaar7b829702018-07-03 13:24:09 -0700253 return parseAndPrintMemoryBuffer(std::move(*fileOrErr));
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700254}