//===- Parser.cpp - MLIR Parser Implementation ----------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements the parser for the MLIR textual form.
//
//===----------------------------------------------------------------------===//
21
22#include "mlir/Parser.h"
23#include "Lexer.h"
24#include "mlir/IR/Module.h"
25#include "llvm/Support/SourceMgr.h"
26using namespace mlir;
27using llvm::SourceMgr;
28
29namespace {
/// Outcome of a parsing routine.  ParseSuccess is the zero enumerator, so a
/// ParseResult used directly as a condition is "true" exactly when parsing
/// failed; this is what allows `if (parseFoo()) return ParseFailure;`.
enum ParseResult {
  ParseSuccess = 0,
  ParseFailure = 1
};
36
37/// Main parser implementation.
38class Parser {
39 public:
40 Parser(llvm::SourceMgr &sourceMgr) : lex(sourceMgr), curToken(lex.lexToken()){
41 module.reset(new Module());
42 }
43
44 Module *parseModule();
45private:
46 // State.
47 Lexer lex;
48
49 // This is the next token that hasn't been consumed yet.
50 Token curToken;
51
52 // This is the result module we are parsing into.
53 std::unique_ptr<Module> module;
54
55private:
56 // Helper methods.
57
58 /// Emit an error and return failure.
59 ParseResult emitError(const Twine &message);
60
61 /// Advance the current lexer onto the next token.
62 void consumeToken() {
63 assert(curToken.isNot(Token::eof, Token::error) &&
64 "shouldn't advance past EOF or errors");
65 curToken = lex.lexToken();
66 }
67
68 /// Advance the current lexer onto the next token, asserting what the expected
69 /// current token is. This is preferred to the above method because it leads
70 /// to more self-documenting code with better checking.
71 void consumeToken(Token::TokenKind kind) {
72 assert(curToken.is(kind) && "consumed an unexpected token");
73 consumeToken();
74 }
75
Chris Lattnerbb8fafc2018-06-22 15:52:02 -070076 /// If the current token has the specified kind, consume it and return true.
77 /// If not, return false.
78 bool consumeIf(Token::TokenKind kind) {
79 if (curToken.isNot(kind))
80 return false;
81 consumeToken(kind);
82 return true;
83 }
84
85 ParseResult parseCommaSeparatedList(Token::TokenKind rightToken,
86 const std::function<ParseResult()> &parseElement,
87 bool allowEmptyList = true);
88
Chris Lattnere79379a2018-06-22 10:39:19 -070089 // Type parsing.
Chris Lattnerbb8fafc2018-06-22 15:52:02 -070090 ParseResult parsePrimitiveType();
91 ParseResult parseElementType();
92 ParseResult parseVectorType();
93 ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
94 ParseResult parseTensorType();
95 ParseResult parseMemRefType();
96 ParseResult parseFunctionType();
97 ParseResult parseType();
98 ParseResult parseTypeList();
Chris Lattnere79379a2018-06-22 10:39:19 -070099
100 // Top level entity parsing.
101 ParseResult parseFunctionSignature(StringRef &name);
102 ParseResult parseExtFunc();
103};
104} // end anonymous namespace
105
//===----------------------------------------------------------------------===//
// Helper methods.
//===----------------------------------------------------------------------===//
109
110ParseResult Parser::emitError(const Twine &message) {
Chris Lattnerbb8fafc2018-06-22 15:52:02 -0700111 // If we hit a parse error in response to a lexer error, then the lexer
112 // already emitted an error.
113 if (curToken.is(Token::error))
114 return ParseFailure;
115
Chris Lattnere79379a2018-06-22 10:39:19 -0700116 // TODO(clattner): If/when we want to implement a -verify mode, this will need
117 // to package up errors into SMDiagnostic and report them.
118 lex.getSourceMgr().PrintMessage(curToken.getLoc(), SourceMgr::DK_Error,
119 message);
120 return ParseFailure;
121}
122
Chris Lattnerbb8fafc2018-06-22 15:52:02 -0700123/// Parse a comma-separated list of elements, terminated with an arbitrary
124/// token. This allows empty lists if allowEmptyList is true.
125///
126/// abstract-list ::= rightToken // if allowEmptyList == true
127/// abstract-list ::= element (',' element)* rightToken
128///
129ParseResult Parser::
130parseCommaSeparatedList(Token::TokenKind rightToken,
131 const std::function<ParseResult()> &parseElement,
132 bool allowEmptyList) {
133 // Handle the empty case.
134 if (curToken.is(rightToken)) {
135 if (!allowEmptyList)
136 return emitError("expected list element");
137 consumeToken(rightToken);
138 return ParseSuccess;
139 }
140
141 // Non-empty case starts with an element.
142 if (parseElement())
143 return ParseFailure;
144
145 // Otherwise we have a list of comma separated elements.
146 while (consumeIf(Token::comma)) {
147 if (parseElement())
148 return ParseFailure;
149 }
150
151 // Consume the end character.
152 if (!consumeIf(rightToken))
153 return emitError("expected ',' or ')'");
154
155 return ParseSuccess;
156}
Chris Lattnere79379a2018-06-22 10:39:19 -0700157
//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
161
Chris Lattnerbb8fafc2018-06-22 15:52:02 -0700162/// Parse the low-level fixed dtypes in the system.
163///
164/// primitive-type
165/// ::= `f16` | `bf16` | `f32` | `f64` // Floating point
166/// | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
167/// | `int`
168///
169ParseResult Parser::parsePrimitiveType() {
170 // TODO: Build IR objects.
171 switch (curToken.getKind()) {
172 default: return emitError("expected type");
173 case Token::kw_bf16:
174 consumeToken(Token::kw_bf16);
175 return ParseSuccess;
176 case Token::kw_f16:
177 consumeToken(Token::kw_f16);
178 return ParseSuccess;
179 case Token::kw_f32:
180 consumeToken(Token::kw_f32);
181 return ParseSuccess;
182 case Token::kw_f64:
183 consumeToken(Token::kw_f64);
184 return ParseSuccess;
185 case Token::kw_i1:
186 consumeToken(Token::kw_i1);
187 return ParseSuccess;
188 case Token::kw_i16:
189 consumeToken(Token::kw_i16);
190 return ParseSuccess;
191 case Token::kw_i32:
192 consumeToken(Token::kw_i32);
193 return ParseSuccess;
194 case Token::kw_i64:
195 consumeToken(Token::kw_i64);
196 return ParseSuccess;
197 case Token::kw_i8:
198 consumeToken(Token::kw_i8);
199 return ParseSuccess;
200 case Token::kw_int:
201 consumeToken(Token::kw_int);
202 return ParseSuccess;
203 }
204}
205
206/// Parse the element type of a tensor or memref type.
207///
208/// element-type ::= primitive-type | vector-type
209///
210ParseResult Parser::parseElementType() {
211 if (curToken.is(Token::kw_vector))
212 return parseVectorType();
213
214 return parsePrimitiveType();
215}
216
217/// Parse a vector type.
218///
219/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
220/// const-dimension-list ::= (integer-literal `x`)+
221///
222ParseResult Parser::parseVectorType() {
223 consumeToken(Token::kw_vector);
224
225 if (!consumeIf(Token::less))
226 return emitError("expected '<' in vector type");
227
228 if (curToken.isNot(Token::integer))
229 return emitError("expected dimension size in vector type");
230
231 SmallVector<unsigned, 4> dimensions;
232 while (curToken.is(Token::integer)) {
233 // Make sure this integer value is in bound and valid.
234 auto dimension = curToken.getUnsignedIntegerValue();
235 if (!dimension.hasValue())
236 return emitError("invalid dimension in vector type");
237 dimensions.push_back(dimension.getValue());
238
239 consumeToken(Token::integer);
240
241 // Make sure we have an 'x' or something like 'xbf32'.
242 if (curToken.isNot(Token::bare_identifier) ||
243 curToken.getSpelling()[0] != 'x')
244 return emitError("expected 'x' in vector dimension list");
245
246 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
247 if (curToken.getSpelling().size() != 1)
248 lex.resetPointer(curToken.getSpelling().data()+1);
249
250 // Consume the 'x'.
251 consumeToken(Token::bare_identifier);
252 }
253
254 // Parse the element type.
255 if (parsePrimitiveType())
256 return ParseFailure;
257
258 if (!consumeIf(Token::greater))
259 return emitError("expected '>' in vector type");
260
261 // TODO: Form IR object.
262
263 return ParseSuccess;
264}
265
266/// Parse a dimension list of a tensor or memref type. This populates the
267/// dimension list, returning -1 for the '?' dimensions.
268///
269/// dimension-list-ranked ::= (dimension `x`)*
270/// dimension ::= `?` | integer-literal
271///
272ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
273 while (curToken.isAny(Token::integer, Token::question)) {
274 if (consumeIf(Token::question)) {
275 dimensions.push_back(-1);
276 } else {
277 // Make sure this integer value is in bound and valid.
278 auto dimension = curToken.getUnsignedIntegerValue();
279 if (!dimension.hasValue() || (int)dimension.getValue() < 0)
280 return emitError("invalid dimension");
281 dimensions.push_back((int)dimension.getValue());
282 consumeToken(Token::integer);
283 }
284
285 // Make sure we have an 'x' or something like 'xbf32'.
286 if (curToken.isNot(Token::bare_identifier) ||
287 curToken.getSpelling()[0] != 'x')
288 return emitError("expected 'x' in dimension list");
289
290 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
291 if (curToken.getSpelling().size() != 1)
292 lex.resetPointer(curToken.getSpelling().data()+1);
293
294 // Consume the 'x'.
295 consumeToken(Token::bare_identifier);
296 }
297
298 return ParseSuccess;
299}
300
301/// Parse a tensor type.
302///
303/// tensor-type ::= `tensor` `<` dimension-list element-type `>`
304/// dimension-list ::= dimension-list-ranked | `??`
305///
306ParseResult Parser::parseTensorType() {
307 consumeToken(Token::kw_tensor);
308
309 if (!consumeIf(Token::less))
310 return emitError("expected '<' in tensor type");
311
312 bool isUnranked;
313 SmallVector<int, 4> dimensions;
314
315 if (consumeIf(Token::questionquestion)) {
316 isUnranked = true;
317 } else {
318 isUnranked = false;
319 if (parseDimensionListRanked(dimensions))
320 return ParseFailure;
321 }
322
323 // Parse the element type.
324 if (parseElementType())
325 return ParseFailure;
326
327 if (!consumeIf(Token::greater))
328 return emitError("expected '>' in tensor type");
329
330 // TODO: Form IR object.
331
332 return ParseSuccess;
333}
334
335/// Parse a memref type.
336///
337/// memref-type ::= `memref` `<` dimension-list-ranked element-type
338/// (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
339///
340/// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
341/// memory-space ::= integer-literal /* | TODO: address-space-id */
342///
343ParseResult Parser::parseMemRefType() {
344 consumeToken(Token::kw_memref);
345
346 if (!consumeIf(Token::less))
347 return emitError("expected '<' in memref type");
348
349 SmallVector<int, 4> dimensions;
350 if (parseDimensionListRanked(dimensions))
351 return ParseFailure;
352
353 // Parse the element type.
354 if (parseElementType())
355 return ParseFailure;
356
357 // TODO: Parse semi-affine-map-composition.
358 // TODO: Parse memory-space.
359
360 if (!consumeIf(Token::greater))
361 return emitError("expected '>' in memref type");
362
363 // TODO: Form IR object.
364
365 return ParseSuccess;
366}
367
368
369
370/// Parse a function type.
371///
372/// function-type ::= type-list-parens `->` type-list
373///
374ParseResult Parser::parseFunctionType() {
375 assert(curToken.is(Token::l_paren));
376
377 if (parseTypeList())
378 return ParseFailure;
379
380 if (!consumeIf(Token::arrow))
381 return emitError("expected '->' in function type");
382
383 if (parseTypeList())
384 return ParseFailure;
385
386 // TODO: Build IR object.
387 return ParseSuccess;
388}
389
390
391/// Parse an arbitrary type.
392///
393/// type ::= primitive-type
394/// | vector-type
395/// | tensor-type
396/// | memref-type
397/// | function-type
398/// element-type ::= primitive-type | vector-type
399///
400ParseResult Parser::parseType() {
401 switch (curToken.getKind()) {
402 case Token::kw_memref: return parseMemRefType();
403 case Token::kw_tensor: return parseTensorType();
404 case Token::kw_vector: return parseVectorType();
405 case Token::l_paren: return parseFunctionType();
406 default:
407 return parsePrimitiveType();
408 }
409}
410
411/// Parse a "type list", which is a singular type, or a parenthesized list of
412/// types.
413///
414/// type-list ::= type-list-parens | type
415/// type-list-parens ::= `(` `)`
416/// | `(` type (`,` type)* `)`
417///
418ParseResult Parser::parseTypeList() {
419 // If there is no parens, then it must be a singular type.
420 if (!consumeIf(Token::l_paren))
421 return parseType();
422
423 if (parseCommaSeparatedList(Token::r_paren,
424 [&]() -> ParseResult {
425 // TODO: Add to list of IR values we're parsing.
426 return parseType();
427 })) {
428 return ParseFailure;
429 }
430
431 // TODO: Build IR objects.
432 return ParseSuccess;
433}
434
Chris Lattnere79379a2018-06-22 10:39:19 -0700435
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
//===----------------------------------------------------------------------===//
439
440/// Parse a function signature, starting with a name and including the parameter
441/// list.
442///
443/// argument-list ::= type (`,` type)* | /*empty*/
444/// function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
445///
446ParseResult Parser::parseFunctionSignature(StringRef &name) {
447 if (curToken.isNot(Token::at_identifier))
448 return emitError("expected a function identifier like '@foo'");
449
450 name = curToken.getSpelling().drop_front();
451 consumeToken(Token::at_identifier);
452
453 if (curToken.isNot(Token::l_paren))
454 return emitError("expected '(' in function signature");
Chris Lattnere79379a2018-06-22 10:39:19 -0700455
Chris Lattnerbb8fafc2018-06-22 15:52:02 -0700456 if (parseTypeList())
457 return ParseFailure;
Chris Lattnere79379a2018-06-22 10:39:19 -0700458
Chris Lattnerbb8fafc2018-06-22 15:52:02 -0700459 // Parse the return type if present.
460 if (consumeIf(Token::arrow)) {
461 if (parseTypeList())
462 return ParseFailure;
463
464 // TODO: Build IR object.
465 }
Chris Lattnere79379a2018-06-22 10:39:19 -0700466
467 return ParseSuccess;
468}
469
470
471/// External function declarations.
472///
473/// ext-func ::= `extfunc` function-signature
474///
475ParseResult Parser::parseExtFunc() {
476 consumeToken(Token::kw_extfunc);
477
478 StringRef name;
479 if (parseFunctionSignature(name))
480 return ParseFailure;
481
482
483 // Okay, the external function definition was parsed correctly.
484 module->functionList.push_back(new Function(name));
485 return ParseSuccess;
486}
487
488
489/// This is the top-level module parser.
490Module *Parser::parseModule() {
491 while (1) {
492 switch (curToken.getKind()) {
493 default:
494 emitError("expected a top level entity");
495 return nullptr;
496
497 // If we got to the end of the file, then we're done.
498 case Token::eof:
499 return module.release();
500
501 // If we got an error token, then the lexer already emitted an error, just
502 // stop. Someday we could introduce error recovery if there was demand for
503 // it.
504 case Token::error:
505 return nullptr;
506
507 case Token::kw_extfunc:
508 if (parseExtFunc())
509 return nullptr;
510 break;
511
512 // TODO: cfgfunc, mlfunc, affine entity declarations, etc.
513 }
514 }
515}
516
//===----------------------------------------------------------------------===//
518
519/// This parses the file specified by the indicated SourceMgr and returns an
520/// MLIR module if it was valid. If not, it emits diagnostics and returns null.
521Module *mlir::parseSourceFile(llvm::SourceMgr &sourceMgr) {
522 return Parser(sourceMgr).parseModule();
523}