blob: 26417e20cb1d9ebb65c6150d3ac757745577cd42 [file] [log] [blame]
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -07001//===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
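// This is a command line utility that parses an MLIR file, runs the
// optimization passes selected on the command line over it, then prints the
// result back out. With -check-parser-errors it instead checks the parser
// diagnostics against `expected-error` annotations in the input. It is
// designed to support unit testing.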
21//
22//===----------------------------------------------------------------------===//
23
Chris Lattnerf7bdf952018-08-05 21:12:29 -070024#include "mlir/IR/Attributes.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070025#include "mlir/IR/MLFunction.h"
Chris Lattnerf7e22732018-06-22 22:03:48 -070026#include "mlir/IR/MLIRContext.h"
Chris Lattnere2259872018-06-21 15:22:42 -070027#include "mlir/IR/Module.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070028#include "mlir/IR/Pass.h"
Chris Lattnere79379a2018-06-22 10:39:19 -070029#include "mlir/Parser.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070030#include "mlir/TensorFlow/ControlFlowOps.h"
31#include "mlir/TensorFlow/Passes.h"
32#include "mlir/Transforms/Passes.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070033#include "llvm/Support/CommandLine.h"
Chris Lattnere2259872018-06-21 15:22:42 -070034#include "llvm/Support/FileUtilities.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070035#include "llvm/Support/InitLLVM.h"
Chris Lattnerf7bdf952018-08-05 21:12:29 -070036#include "llvm/Support/PrettyStackTrace.h"
Jacques Pienaarca4c4a02018-06-25 08:10:46 -070037#include "llvm/Support/Regex.h"
Jacques Pienaar39ffa102018-07-07 19:12:22 -070038#include "llvm/Support/SourceMgr.h"
Chris Lattnere2259872018-06-21 15:22:42 -070039#include "llvm/Support/ToolOutputFile.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070040using namespace mlir;
Chris Lattnere2259872018-06-21 15:22:42 -070041using namespace llvm;
42
43static cl::opt<std::string>
44inputFilename(cl::Positional, cl::desc("<input file>"), cl::init("-"));
45
46static cl::opt<std::string>
47outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"),
48 cl::init("-"));
49
Jacques Pienaarbae40512018-06-24 09:10:36 -070050static cl::opt<bool>
51checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
52 cl::init(false));
Chris Lattnere2259872018-06-21 15:22:42 -070053
/// Passes that can be requested on the command line. The flag that selects
/// each one is noted beside it (see passList).
enum Passes {
  ConvertToCFG = 0,         // -convert-to-cfg
  UnrollInnermostLoops = 1, // -unroll-innermost-loops
  UnrollShortLoops = 2,     // -unroll-short-loops
  TFRaiseControlFlow = 3,   // -tf-raise-control-flow
};
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070060
Chris Lattnerdc3ba382018-07-29 14:13:03 -070061static cl::list<Passes> passList(
62 "", cl::desc("Compiler passes to run"),
63 cl::values(clEnumValN(ConvertToCFG, "convert-to-cfg",
64 "Convert all ML functions in the module to CFG ones"),
65 clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
66 "Unroll innermost loops"),
Uday Bondhugula134154e2018-08-06 18:40:34 -070067 clEnumValN(UnrollShortLoops, "unroll-short-loops",
68 "Unroll loops of trip count <= 2"),
Chris Lattnerdc3ba382018-07-29 14:13:03 -070069 clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
70 "Dynamic TensorFlow Switch/Match nodes to a CFG")));
Chris Lattneree0c2ae2018-07-29 12:37:35 -070071
/// Outcome of a driver run. main() returns these values directly as the
/// process exit status, so OptSuccess must remain 0.
enum OptResult {
  OptSuccess = 0,
  OptFailure = 1,
};
73
Chris Lattnere2259872018-06-21 15:22:42 -070074/// Open the specified output file and return it, exiting if there is any I/O or
75/// other errors.
76static std::unique_ptr<ToolOutputFile> getOutputStream() {
77 std::error_code error;
MLIR Team61eadaa2018-07-30 15:00:47 -070078 auto result =
79 llvm::make_unique<ToolOutputFile>(outputFilename, error, sys::fs::F_None);
Chris Lattnere2259872018-06-21 15:22:42 -070080 if (error) {
81 llvm::errs() << error.message() << '\n';
82 exit(1);
83 }
84
85 return result;
86}
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070087
Chris Lattneree0c2ae2018-07-29 12:37:35 -070088static void initializeMLIRContext(MLIRContext &ctx) {
89 TFControlFlow::registerOperations(ctx);
90}
91
Jacques Pienaarbae40512018-06-24 09:10:36 -070092/// Parses the memory buffer and, if successfully parsed, prints the parsed
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070093/// output. Optionally, convert ML functions into CFG functions.
94/// TODO: pull parsing and printing into separate functions.
Jacques Pienaar7b829702018-07-03 13:24:09 -070095OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
Jacques Pienaarbae40512018-06-24 09:10:36 -070096 // Tell sourceMgr about this buffer, which is what the parser will pick up.
97 SourceMgr sourceMgr;
98 sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
99
Jacques Pienaar9c411be2018-06-24 19:17:35 -0700100 // Parse the input file.
Jacques Pienaarbae40512018-06-24 09:10:36 -0700101 MLIRContext context;
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700102 initializeMLIRContext(context);
Jacques Pienaar7b829702018-07-03 13:24:09 -0700103 std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
104 if (!module)
105 return OptFailure;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700106
Chris Lattnerdc3ba382018-07-29 14:13:03 -0700107 // Run each of the passes that were selected.
108 for (auto passKind : passList) {
109 Pass *pass = nullptr;
110 switch (passKind) {
111 case ConvertToCFG:
112 pass = createConvertToCFGPass();
113 break;
114 case UnrollInnermostLoops:
115 pass = createLoopUnrollPass();
116 break;
Uday Bondhugula134154e2018-08-06 18:40:34 -0700117 case UnrollShortLoops:
118 pass = createLoopUnrollPass(2);
119 break;
Chris Lattnerdc3ba382018-07-29 14:13:03 -0700120 case TFRaiseControlFlow:
121 pass = createRaiseTFControlFlowPass();
122 break;
123 }
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700124
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700125 pass->runOnModule(module.get());
126 delete pass;
127 module->verify();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700128 }
129
Jacques Pienaarbae40512018-06-24 09:10:36 -0700130 // Print the output.
131 auto output = getOutputStream();
132 module->print(output->os());
133 output->keep();
134
Jacques Pienaar7b829702018-07-03 13:24:09 -0700135 return OptSuccess;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700136}
137
138/// Split the memory buffer into multiple buffers using the marker -----.
Jacques Pienaar7b829702018-07-03 13:24:09 -0700139OptResult
140splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
Jacques Pienaarbae40512018-06-24 09:10:36 -0700141 const char marker[] = "-----";
142 SmallVector<StringRef, 2> sourceBuffers;
143 buffer->getBuffer().split(sourceBuffers, marker);
Jacques Pienaarbae40512018-06-24 09:10:36 -0700144
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700145 // Error reporter that verifies error reports matches expected error
146 // substring.
147 // TODO: Only checking for error cases below. Could be expanded to other kinds
148 // of diagnostics.
149 // TODO: Enable specifying errors on different lines (@-1).
150 // TODO: Currently only checking if substring matches, enable regex checking.
Jacques Pienaar7b829702018-07-03 13:24:09 -0700151 OptResult opt_result = OptSuccess;
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700152 SourceMgr fileSourceMgr;
153 fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
154
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700155 // Record the expected errors's position, substring and whether it was seen.
156 struct ExpectedError {
157 int lineNo;
158 StringRef substring;
159 SMLoc fileLoc;
160 bool matched;
161 };
162
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700163 // Tracks offset of subbuffer into original buffer.
164 const char *fileOffset =
165 fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID())
166 ->getBufferStart();
167
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700168 for (auto &subbuffer : sourceBuffers) {
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700169 SourceMgr sourceMgr;
170 // Tell sourceMgr about this buffer, which is what the parser will pick up.
Chris Lattnerf7bdf952018-08-05 21:12:29 -0700171 auto bufferId = sourceMgr.AddNewSourceBuffer(
172 MemoryBuffer::getMemBufferCopy(subbuffer), SMLoc());
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700173
Chris Lattnerf7bdf952018-08-05 21:12:29 -0700174 // Extract the expected errors.
James Molloy61a656c2018-07-22 15:45:24 -0700175 llvm::Regex expected("expected-error(@[+-][0-9]+)? *{{(.*)}}");
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700176 SmallVector<ExpectedError, 2> expectedErrors;
177 SmallVector<StringRef, 100> lines;
178 subbuffer.split(lines, '\n');
179 size_t bufOffset = 0;
180 for (int lineNo = 0; lineNo < lines.size(); ++lineNo) {
181 SmallVector<StringRef, 3> matches;
182 if (expected.match(lines[lineNo], &matches)) {
183 // Point to the start of expected-error.
184 SMLoc errorStart =
185 SMLoc::getFromPointer(fileOffset + bufOffset +
186 lines[lineNo].size() - matches[2].size() - 2);
187 ExpectedError expErr{lineNo + 1, matches[2], errorStart, false};
188 int offset;
189 if (!matches[1].empty() &&
190 !matches[1].drop_front().getAsInteger(0, offset)) {
191 expErr.lineNo += offset;
192 }
193 expectedErrors.push_back(expErr);
194 }
195 bufOffset += lines[lineNo].size() + 1;
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700196 }
197
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700198 // Error checker that verifies reported error was expected.
199 auto checker = [&](const SMDiagnostic &err) {
200 for (auto &e : expectedErrors) {
201 if (err.getLineNo() == e.lineNo &&
202 err.getMessage().contains(e.substring)) {
203 e.matched = true;
204 return;
205 }
206 }
207 // Report error if no match found.
208 const auto &sourceMgr = *err.getSourceMgr();
209 const char *bufferStart =
210 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())
211 ->getBufferStart();
212
213 size_t offset = err.getLoc().getPointer() - bufferStart;
214 SMLoc loc = SMLoc::getFromPointer(fileOffset + offset);
215 fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
216 "unexpected error: " + err.getMessage());
217 opt_result = OptFailure;
218 };
219
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700220 // Parse the input file.
221 MLIRContext context;
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700222 initializeMLIRContext(context);
Chris Lattnerf7bdf952018-08-05 21:12:29 -0700223
224 // TODO: refactor into initializeMLIRContext so the normal parser pass
225 // gets to use this.
226 context.registerDiagnosticHandler([&](Attribute *location,
227 StringRef message,
228 MLIRContext::DiagnosticKind kind) {
229 auto offset = cast<IntegerAttr>(location)->getValue();
230 auto ptr = sourceMgr.getMemoryBuffer(bufferId)->getBufferStart() + offset;
231 SourceMgr::DiagKind diagKind;
232 switch (kind) {
233 case MLIRContext::DiagnosticKind::Error:
234 diagKind = SourceMgr::DK_Error;
235 break;
236 case MLIRContext::DiagnosticKind::Warning:
237 diagKind = SourceMgr::DK_Warning;
238 break;
239 case MLIRContext::DiagnosticKind::Note:
240 diagKind = SourceMgr::DK_Note;
241 break;
242 }
243 checker(
244 sourceMgr.GetMessage(SMLoc::getFromPointer(ptr), diagKind, message));
245 });
246
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700247 std::unique_ptr<Module> module(
248 parseSourceFile(sourceMgr, &context, checker));
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700249
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700250 // Verify that all expected errors were seen.
251 for (auto err : expectedErrors) {
252 if (!err.matched) {
253 SMRange range(err.fileLoc,
254 SMLoc::getFromPointer(err.fileLoc.getPointer() +
255 err.substring.size()));
256 fileSourceMgr.PrintMessage(
257 err.fileLoc, SourceMgr::DK_Error,
258 "expected error \"" + err.substring + "\" was not produced", range);
259 opt_result = OptFailure;
260 }
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700261 }
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700262
263 fileOffset += subbuffer.size() + strlen(marker);
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700264 }
265
Jacques Pienaar7b829702018-07-03 13:24:09 -0700266 return opt_result;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700267}
268
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700269int main(int argc, char **argv) {
Chris Lattnerf7bdf952018-08-05 21:12:29 -0700270 llvm::PrettyStackTraceProgram x(argc, argv);
271 InitLLVM y(argc, argv);
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700272
Chris Lattnere2259872018-06-21 15:22:42 -0700273 cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700274
Chris Lattnere79379a2018-06-22 10:39:19 -0700275 // Set up the input file.
276 auto fileOrErr = MemoryBuffer::getFileOrSTDIN(inputFilename);
277 if (std::error_code error = fileOrErr.getError()) {
278 llvm::errs() << argv[0] << ": could not open input file '" << inputFilename
279 << "': " << error.message() << "\n";
280 return 1;
281 }
282
Jacques Pienaarbae40512018-06-24 09:10:36 -0700283 if (checkParserErrors)
Jacques Pienaar7b829702018-07-03 13:24:09 -0700284 return splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700285
Jacques Pienaar7b829702018-07-03 13:24:09 -0700286 return parseAndPrintMemoryBuffer(std::move(*fileOrErr));
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700287}