blob: 5bbf6b7bda5381c94d7a353dfcb202961c92a234 [file] [log] [blame]
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -07001//===- mlir-opt.cpp - MLIR Optimizer Driver -------------------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
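// This is a command line utility that parses an MLIR file, runs the selected
// optimization passes, then prints the result back out. With
// -check-parser-errors it instead verifies the parser's errors against
// expected-error annotations in the input. It is designed to support unit
// testing.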
21//
22//===----------------------------------------------------------------------===//
23
Chris Lattnerf7bdf952018-08-05 21:12:29 -070024#include "mlir/IR/Attributes.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070025#include "mlir/IR/MLFunction.h"
Chris Lattnerf7e22732018-06-22 22:03:48 -070026#include "mlir/IR/MLIRContext.h"
Chris Lattnere2259872018-06-21 15:22:42 -070027#include "mlir/IR/Module.h"
Chris Lattnere79379a2018-06-22 10:39:19 -070028#include "mlir/Parser.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070029#include "mlir/TensorFlow/ControlFlowOps.h"
30#include "mlir/TensorFlow/Passes.h"
Uday Bondhugula6c1f6602018-08-13 17:25:13 -070031#include "mlir/Transforms/Pass.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070032#include "mlir/Transforms/Passes.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070033#include "llvm/Support/CommandLine.h"
Chris Lattnere2259872018-06-21 15:22:42 -070034#include "llvm/Support/FileUtilities.h"
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070035#include "llvm/Support/InitLLVM.h"
Chris Lattnerf7bdf952018-08-05 21:12:29 -070036#include "llvm/Support/PrettyStackTrace.h"
Jacques Pienaarca4c4a02018-06-25 08:10:46 -070037#include "llvm/Support/Regex.h"
Jacques Pienaar39ffa102018-07-07 19:12:22 -070038#include "llvm/Support/SourceMgr.h"
Chris Lattnere2259872018-06-21 15:22:42 -070039#include "llvm/Support/ToolOutputFile.h"
Uday Bondhugula83a41c92018-08-30 17:35:15 -070040
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070041using namespace mlir;
Chris Lattnere2259872018-06-21 15:22:42 -070042using namespace llvm;
43
44static cl::opt<std::string>
45inputFilename(cl::Positional, cl::desc("<input file>"), cl::init("-"));
46
47static cl::opt<std::string>
48outputFilename("o", cl::desc("Output filename"), cl::value_desc("filename"),
49 cl::init("-"));
50
Jacques Pienaarbae40512018-06-24 09:10:36 -070051static cl::opt<bool>
52checkParserErrors("check-parser-errors", cl::desc("Check for parser errors"),
53 cl::init(false));
Chris Lattnere2259872018-06-21 15:22:42 -070054
Chris Lattnerdc3ba382018-07-29 14:13:03 -070055enum Passes {
56 ConvertToCFG,
Uday Bondhugula67701712018-08-21 16:01:23 -070057 LoopUnroll,
Uday Bondhugula6cd35022018-08-28 18:24:27 -070058 LoopUnrollAndJam,
Uday Bondhugula83a41c92018-08-30 17:35:15 -070059 SimplifyAffineExpr,
Chris Lattnerdc3ba382018-07-29 14:13:03 -070060 TFRaiseControlFlow,
61};
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070062
Chris Lattnerdc3ba382018-07-29 14:13:03 -070063static cl::list<Passes> passList(
64 "", cl::desc("Compiler passes to run"),
65 cl::values(clEnumValN(ConvertToCFG, "convert-to-cfg",
66 "Convert all ML functions in the module to CFG ones"),
Uday Bondhugula67701712018-08-21 16:01:23 -070067 clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
Uday Bondhugula6cd35022018-08-28 18:24:27 -070068 clEnumValN(LoopUnrollAndJam, "loop-unroll-jam",
69 "Unroll and jam loops"),
Uday Bondhugula83a41c92018-08-30 17:35:15 -070070 clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
71 "Simplify affine expressions"),
Chris Lattnerdc3ba382018-07-29 14:13:03 -070072 clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
73 "Dynamic TensorFlow Switch/Match nodes to a CFG")));
Chris Lattneree0c2ae2018-07-29 12:37:35 -070074
/// Outcome of a driver run. main() returns these enumerators directly, so
/// their values are the process exit code: 0 on success, 1 on failure.
enum OptResult { OptSuccess = 0, OptFailure = 1 };
76
Chris Lattnere2259872018-06-21 15:22:42 -070077/// Open the specified output file and return it, exiting if there is any I/O or
78/// other errors.
79static std::unique_ptr<ToolOutputFile> getOutputStream() {
80 std::error_code error;
MLIR Team61eadaa2018-07-30 15:00:47 -070081 auto result =
82 llvm::make_unique<ToolOutputFile>(outputFilename, error, sys::fs::F_None);
Chris Lattnere2259872018-06-21 15:22:42 -070083 if (error) {
84 llvm::errs() << error.message() << '\n';
85 exit(1);
86 }
87
88 return result;
89}
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -070090
Jacques Pienaarad804b12018-08-16 17:26:08 -070091// The function to initialize the MLIRContext for different ops is defined in
92// another compilation unit to allow different tests to link in different
93// context initializations (e.g., op registrations).
94extern void initializeMLIRContext(MLIRContext *ctx);
Chris Lattneree0c2ae2018-07-29 12:37:35 -070095
Jacques Pienaarbae40512018-06-24 09:10:36 -070096/// Parses the memory buffer and, if successfully parsed, prints the parsed
Tatiana Shpeisman6708b452018-07-24 10:15:13 -070097/// output. Optionally, convert ML functions into CFG functions.
98/// TODO: pull parsing and printing into separate functions.
Jacques Pienaar7b829702018-07-03 13:24:09 -070099OptResult parseAndPrintMemoryBuffer(std::unique_ptr<MemoryBuffer> buffer) {
Jacques Pienaarbae40512018-06-24 09:10:36 -0700100 // Tell sourceMgr about this buffer, which is what the parser will pick up.
101 SourceMgr sourceMgr;
102 sourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
103
Jacques Pienaar9c411be2018-06-24 19:17:35 -0700104 // Parse the input file.
Jacques Pienaarbae40512018-06-24 09:10:36 -0700105 MLIRContext context;
Jacques Pienaarad804b12018-08-16 17:26:08 -0700106 initializeMLIRContext(&context);
Jacques Pienaar7b829702018-07-03 13:24:09 -0700107 std::unique_ptr<Module> module(parseSourceFile(sourceMgr, &context));
108 if (!module)
109 return OptFailure;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700110
Chris Lattnerdc3ba382018-07-29 14:13:03 -0700111 // Run each of the passes that were selected.
112 for (auto passKind : passList) {
113 Pass *pass = nullptr;
114 switch (passKind) {
115 case ConvertToCFG:
116 pass = createConvertToCFGPass();
117 break;
Uday Bondhugula67701712018-08-21 16:01:23 -0700118 case LoopUnroll:
Uday Bondhugula6cd35022018-08-28 18:24:27 -0700119 pass = createLoopUnrollPass();
120 break;
121 case LoopUnrollAndJam:
122 pass = createLoopUnrollAndJamPass();
Uday Bondhugula134154e2018-08-06 18:40:34 -0700123 break;
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700124 case SimplifyAffineExpr:
125 pass = createSimplifyAffineExprPass();
126 break;
Chris Lattnerdc3ba382018-07-29 14:13:03 -0700127 case TFRaiseControlFlow:
128 pass = createRaiseTFControlFlowPass();
129 break;
130 }
Tatiana Shpeisman6708b452018-07-24 10:15:13 -0700131
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700132 pass->runOnModule(module.get());
133 delete pass;
134 module->verify();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700135 }
136
Jacques Pienaarbae40512018-06-24 09:10:36 -0700137 // Print the output.
138 auto output = getOutputStream();
139 module->print(output->os());
140 output->keep();
141
Jacques Pienaar7b829702018-07-03 13:24:09 -0700142 return OptSuccess;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700143}
144
145/// Split the memory buffer into multiple buffers using the marker -----.
Jacques Pienaar7b829702018-07-03 13:24:09 -0700146OptResult
147splitMemoryBufferForErrorChecking(std::unique_ptr<MemoryBuffer> buffer) {
Jacques Pienaarbae40512018-06-24 09:10:36 -0700148 const char marker[] = "-----";
149 SmallVector<StringRef, 2> sourceBuffers;
150 buffer->getBuffer().split(sourceBuffers, marker);
Jacques Pienaarbae40512018-06-24 09:10:36 -0700151
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700152 // Error reporter that verifies error reports matches expected error
153 // substring.
154 // TODO: Only checking for error cases below. Could be expanded to other kinds
155 // of diagnostics.
156 // TODO: Enable specifying errors on different lines (@-1).
157 // TODO: Currently only checking if substring matches, enable regex checking.
Jacques Pienaar7b829702018-07-03 13:24:09 -0700158 OptResult opt_result = OptSuccess;
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700159 SourceMgr fileSourceMgr;
160 fileSourceMgr.AddNewSourceBuffer(std::move(buffer), SMLoc());
161
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700162 // Record the expected errors's position, substring and whether it was seen.
163 struct ExpectedError {
164 int lineNo;
165 StringRef substring;
166 SMLoc fileLoc;
167 bool matched;
168 };
169
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700170 // Tracks offset of subbuffer into original buffer.
171 const char *fileOffset =
172 fileSourceMgr.getMemoryBuffer(fileSourceMgr.getMainFileID())
173 ->getBufferStart();
174
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700175 for (auto &subbuffer : sourceBuffers) {
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700176 SourceMgr sourceMgr;
177 // Tell sourceMgr about this buffer, which is what the parser will pick up.
Chris Lattnerea5c3dc2018-08-21 08:42:19 -0700178 sourceMgr.AddNewSourceBuffer(MemoryBuffer::getMemBufferCopy(subbuffer),
179 SMLoc());
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700180
Chris Lattnerf7bdf952018-08-05 21:12:29 -0700181 // Extract the expected errors.
Chris Lattnerea5c3dc2018-08-21 08:42:19 -0700182 llvm::Regex expected("expected-error *(@[+-][0-9]+)? *{{(.*)}}");
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700183 SmallVector<ExpectedError, 2> expectedErrors;
184 SmallVector<StringRef, 100> lines;
185 subbuffer.split(lines, '\n');
186 size_t bufOffset = 0;
187 for (int lineNo = 0; lineNo < lines.size(); ++lineNo) {
188 SmallVector<StringRef, 3> matches;
189 if (expected.match(lines[lineNo], &matches)) {
190 // Point to the start of expected-error.
191 SMLoc errorStart =
192 SMLoc::getFromPointer(fileOffset + bufOffset +
193 lines[lineNo].size() - matches[2].size() - 2);
194 ExpectedError expErr{lineNo + 1, matches[2], errorStart, false};
195 int offset;
196 if (!matches[1].empty() &&
197 !matches[1].drop_front().getAsInteger(0, offset)) {
198 expErr.lineNo += offset;
199 }
200 expectedErrors.push_back(expErr);
201 }
202 bufOffset += lines[lineNo].size() + 1;
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700203 }
204
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700205 // Error checker that verifies reported error was expected.
206 auto checker = [&](const SMDiagnostic &err) {
207 for (auto &e : expectedErrors) {
208 if (err.getLineNo() == e.lineNo &&
209 err.getMessage().contains(e.substring)) {
210 e.matched = true;
211 return;
212 }
213 }
214 // Report error if no match found.
215 const auto &sourceMgr = *err.getSourceMgr();
216 const char *bufferStart =
217 sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID())
218 ->getBufferStart();
219
220 size_t offset = err.getLoc().getPointer() - bufferStart;
221 SMLoc loc = SMLoc::getFromPointer(fileOffset + offset);
222 fileSourceMgr.PrintMessage(loc, SourceMgr::DK_Error,
223 "unexpected error: " + err.getMessage());
224 opt_result = OptFailure;
225 };
226
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700227 // Parse the input file.
228 MLIRContext context;
Jacques Pienaarad804b12018-08-16 17:26:08 -0700229 initializeMLIRContext(&context);
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700230 std::unique_ptr<Module> module(
231 parseSourceFile(sourceMgr, &context, checker));
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700232
Jacques Pienaar39ffa102018-07-07 19:12:22 -0700233 // Verify that all expected errors were seen.
234 for (auto err : expectedErrors) {
235 if (!err.matched) {
236 SMRange range(err.fileLoc,
237 SMLoc::getFromPointer(err.fileLoc.getPointer() +
238 err.substring.size()));
239 fileSourceMgr.PrintMessage(
240 err.fileLoc, SourceMgr::DK_Error,
241 "expected error \"" + err.substring + "\" was not produced", range);
242 opt_result = OptFailure;
243 }
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700244 }
Jacques Pienaarb2ddbb62018-06-26 08:56:55 -0700245
246 fileOffset += subbuffer.size() + strlen(marker);
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700247 }
248
Jacques Pienaar7b829702018-07-03 13:24:09 -0700249 return opt_result;
Jacques Pienaarbae40512018-06-24 09:10:36 -0700250}
251
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700252int main(int argc, char **argv) {
Chris Lattnerf7bdf952018-08-05 21:12:29 -0700253 llvm::PrettyStackTraceProgram x(argc, argv);
254 InitLLVM y(argc, argv);
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700255
Chris Lattnere2259872018-06-21 15:22:42 -0700256 cl::ParseCommandLineOptions(argc, argv, "MLIR modular optimizer driver\n");
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700257
Chris Lattnere79379a2018-06-22 10:39:19 -0700258 // Set up the input file.
259 auto fileOrErr = MemoryBuffer::getFileOrSTDIN(inputFilename);
260 if (std::error_code error = fileOrErr.getError()) {
261 llvm::errs() << argv[0] << ": could not open input file '" << inputFilename
262 << "': " << error.message() << "\n";
263 return 1;
264 }
265
Jacques Pienaarbae40512018-06-24 09:10:36 -0700266 if (checkParserErrors)
Jacques Pienaar7b829702018-07-03 13:24:09 -0700267 return splitMemoryBufferForErrorChecking(std::move(*fileOrErr));
Jacques Pienaarca4c4a02018-06-25 08:10:46 -0700268
Jacques Pienaar7b829702018-07-03 13:24:09 -0700269 return parseAndPrintMemoryBuffer(std::move(*fileOrErr));
Chris Lattnerc0c5e0f2018-06-21 09:49:33 -0700270}