Implement parser and lexer support for most of the type grammar.
Semi-affine maps and address spaces are not yet supported (someone want to take
this on?). We also don't generate IR objects for types yet, which I plan to
tackle next.
PiperOrigin-RevId: 201754283
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index abad611..6dde8c0 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -73,7 +73,29 @@
consumeToken();
}
+  /// Consume the current token and return true when it is of the given kind;
+  /// otherwise leave it in place and return false.
+  bool consumeIf(Token::TokenKind kind) {
+    if (curToken.is(kind)) {
+      consumeToken(kind);
+      return true;
+    }
+    return false;
+  }
+
+ ParseResult parseCommaSeparatedList(Token::TokenKind rightToken,
+ const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList = true);
+
// Type parsing.
+ ParseResult parsePrimitiveType();
+ ParseResult parseElementType();
+ ParseResult parseVectorType();
+ ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
+ ParseResult parseTensorType();
+ ParseResult parseMemRefType();
+ ParseResult parseFunctionType();
+ ParseResult parseType();
+ ParseResult parseTypeList();
// Top level entity parsing.
ParseResult parseFunctionSignature(StringRef &name);
@@ -86,6 +108,11 @@
//===----------------------------------------------------------------------===//
ParseResult Parser::emitError(const Twine &message) {
+ // If we hit a parse error in response to a lexer error, then the lexer
+ // already emitted an error.
+ if (curToken.is(Token::error))
+ return ParseFailure;
+
// TODO(clattner): If/when we want to implement a -verify mode, this will need
// to package up errors into SMDiagnostic and report them.
lex.getSourceMgr().PrintMessage(curToken.getLoc(), SourceMgr::DK_Error,
@@ -93,12 +120,318 @@
return ParseFailure;
}
+/// Parse a sequence of comma-separated elements followed by a closing token
+/// of kind `rightToken`, which is consumed as well.  Every element is parsed
+/// by invoking `parseElement`.  A list with no elements is accepted only when
+/// `allowEmptyList` is true; otherwise an error is emitted.
+///
+///   abstract-list ::= rightToken                  // if allowEmptyList == true
+///   abstract-list ::= element (',' element)* rightToken
+///
+ParseResult Parser::
+parseCommaSeparatedList(Token::TokenKind rightToken,
+                        const std::function<ParseResult()> &parseElement,
+                        bool allowEmptyList) {
+  // Seeing the terminator right away means the list has no elements.
+  if (curToken.is(rightToken)) {
+    if (allowEmptyList) {
+      consumeToken(rightToken);
+      return ParseSuccess;
+    }
+    return emitError("expected list element");
+  }
+
+  // Parse one element, then keep going for as long as a comma follows.
+  do {
+    if (parseElement())
+      return ParseFailure;
+  } while (consumeIf(Token::comma));
+
+  // The list has to be closed by the terminator token.
+  if (consumeIf(rightToken))
+    return ParseSuccess;
+
+  // NOTE(review): the message names ')' although rightToken is arbitrary.
+  return emitError("expected ',' or ')'");
+}
//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
-// ... TODO
+/// Parse one of the built-in primitive types, each of which is spelled with
+/// a single keyword token.  Emits "expected type" for any other token.
+///
+///   primitive-type
+///      ::= `f16` | `bf16` | `f32` | `f64`      // Floating point
+///        | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
+///        | `int`
+///
+ParseResult Parser::parsePrimitiveType() {
+  // TODO: Build IR objects.
+  const auto kind = curToken.getKind();
+  switch (kind) {
+  case Token::kw_bf16:
+  case Token::kw_f16:
+  case Token::kw_f32:
+  case Token::kw_f64:
+  case Token::kw_i1:
+  case Token::kw_i8:
+  case Token::kw_i16:
+  case Token::kw_i32:
+  case Token::kw_i64:
+  case Token::kw_int:
+    // All primitive keywords are handled alike for now: just eat the token.
+    consumeToken(kind);
+    return ParseSuccess;
+  default:
+    return emitError("expected type");
+  }
+}
+
+/// Parse the element type of a tensor or memref type: a vector type when the
+/// current token is the `vector` keyword, a primitive type otherwise.
+///
+///   element-type ::= primitive-type | vector-type
+///
+ParseResult Parser::parseElementType() {
+  return curToken.is(Token::kw_vector) ? parseVectorType()
+                                       : parsePrimitiveType();
+}
+
+/// Parse a vector type.  The current token must be the `vector` keyword.
+///
+///   vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
+///   const-dimension-list ::= (integer-literal `x`)+
+///
+ParseResult Parser::parseVectorType() {
+  consumeToken(Token::kw_vector);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in vector type");
+
+  // The dimension list needs at least one entry.
+  if (!curToken.is(Token::integer))
+    return emitError("expected dimension size in vector type");
+
+  SmallVector<unsigned, 4> dimensions;
+  do {
+    // Reject integer tokens for which no unsigned value is available.
+    auto size = curToken.getUnsignedIntegerValue();
+    if (!size.hasValue())
+      return emitError("invalid dimension in vector type");
+    dimensions.push_back(size.getValue());
+    consumeToken(Token::integer);
+
+    // The separator shows up as a bare identifier beginning with 'x': either
+    // "x" by itself or 'x' glued to the text that follows it, e.g. "xf32".
+    const bool startsWithX = curToken.is(Token::bare_identifier) &&
+                             curToken.getSpelling()[0] == 'x';
+    if (!startsWithX)
+      return emitError("expected 'x' in vector dimension list");
+
+    // If more text is glued to the 'x', point the lexer just past the 'x' so
+    // that the remainder gets lexed as the next token.
+    if (curToken.getSpelling().size() != 1)
+      lex.resetPointer(curToken.getSpelling().data() + 1);
+
+    // Step over the 'x'.
+    consumeToken(Token::bare_identifier);
+  } while (curToken.is(Token::integer));
+
+  // The element type comes after the dimensions.
+  if (parsePrimitiveType())
+    return ParseFailure;
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in vector type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+/// Parse the ranked dimension list of a tensor or memref type, appending one
+/// entry per dimension to `dimensions`.  A `?` dimension is recorded as -1.
+///
+///   dimension-list-ranked ::= (dimension `x`)*
+///   dimension ::= `?` | integer-literal
+///
+ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
+  while (curToken.isAny(Token::integer, Token::question)) {
+    if (curToken.is(Token::integer)) {
+      // The size must have a value and still be non-negative as an int.
+      auto size = curToken.getUnsignedIntegerValue();
+      if (!size.hasValue() || static_cast<int>(size.getValue()) < 0)
+        return emitError("invalid dimension");
+      dimensions.push_back(static_cast<int>(size.getValue()));
+      consumeToken(Token::integer);
+    } else {
+      // A '?' dimension.
+      consumeToken(Token::question);
+      dimensions.push_back(-1);
+    }
+
+    // The separator shows up as a bare identifier beginning with 'x': either
+    // "x" by itself or 'x' glued to the text that follows it, e.g. "xf32".
+    const bool startsWithX = curToken.is(Token::bare_identifier) &&
+                             curToken.getSpelling()[0] == 'x';
+    if (!startsWithX)
+      return emitError("expected 'x' in dimension list");
+
+    // If more text is glued to the 'x', point the lexer just past the 'x' so
+    // that the remainder gets lexed as the next token.
+    if (curToken.getSpelling().size() != 1)
+      lex.resetPointer(curToken.getSpelling().data() + 1);
+
+    // Step over the 'x'.
+    consumeToken(Token::bare_identifier);
+  }
+
+  return ParseSuccess;
+}
+
+/// Parse a tensor type.  The current token must be the `tensor` keyword.
+///
+///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
+///   dimension-list ::= dimension-list-ranked | `??`
+///
+ParseResult Parser::parseTensorType() {
+  consumeToken(Token::kw_tensor);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in tensor type");
+
+  // A leading `??` selects the unranked form; otherwise a ranked dimension
+  // list is parsed.
+  SmallVector<int, 4> dimensions;
+  const bool isUnranked = consumeIf(Token::questionquestion);
+  if (!isUnranked && parseDimensionListRanked(dimensions))
+    return ParseFailure;
+
+  // The element type comes after the dimensions.
+  if (parseElementType())
+    return ParseFailure;
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in tensor type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+/// Parse a memref type.  The current token must be the `memref` keyword.
+/// Only the dimension list and the element type are parsed so far; the
+/// optional map composition and memory space are still TODO.
+///
+///   memref-type ::= `memref` `<` dimension-list-ranked element-type
+///                   (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
+///
+///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+ParseResult Parser::parseMemRefType() {
+  consumeToken(Token::kw_memref);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in memref type");
+
+  // Dimensions first, then the element type; stop at the first failure.
+  SmallVector<int, 4> dimensions;
+  if (parseDimensionListRanked(dimensions) || parseElementType())
+    return ParseFailure;
+
+  // TODO: Parse semi-affine-map-composition.
+  // TODO: Parse memory-space.
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in memref type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+
+
+/// Parse a function type.  The current token must be the `(` that opens the
+/// argument type list.
+///
+///   function-type ::= type-list-parens `->` type-list
+///
+ParseResult Parser::parseFunctionType() {
+  assert(curToken.is(Token::l_paren));
+
+  // Argument types.
+  if (parseTypeList())
+    return ParseFailure;
+
+  if (curToken.isNot(Token::arrow))
+    return emitError("expected '->' in function type");
+  consumeToken(Token::arrow);
+
+  // Result types.
+  if (parseTypeList())
+    return ParseFailure;
+
+  // TODO: Build IR object.
+  return ParseSuccess;
+}
+
+
+/// Parse an arbitrary type, choosing the specific parser from the current
+/// token.  Any token that does not start one of the composite types is
+/// handed to the primitive type parser.
+///
+///   type ::= primitive-type
+///          | vector-type
+///          | tensor-type
+///          | memref-type
+///          | function-type
+///   element-type ::= primitive-type | vector-type
+///
+ParseResult Parser::parseType() {
+  if (curToken.is(Token::kw_memref))
+    return parseMemRefType();
+  if (curToken.is(Token::kw_tensor))
+    return parseTensorType();
+  if (curToken.is(Token::kw_vector))
+    return parseVectorType();
+  if (curToken.is(Token::l_paren))
+    return parseFunctionType();
+  return parsePrimitiveType();
+}
+
+/// Parse a "type list": either one type on its own, or a parenthesized,
+/// comma-separated and possibly empty list of types.
+///
+///   type-list ::= type-list-parens | type
+///   type-list-parens ::= `(` `)`
+///                      | `(` type (`,` type)* `)`
+///
+ParseResult Parser::parseTypeList() {
+  // Without a leading '(' this is a single type.
+  if (curToken.isNot(Token::l_paren))
+    return parseType();
+  consumeToken(Token::l_paren);
+
+  auto parseOneType = [&]() -> ParseResult {
+    // TODO: Add to list of IR values we're parsing.
+    return parseType();
+  };
+
+  // Parse the elements up to and including the closing ')'.
+  if (parseCommaSeparatedList(Token::r_paren, parseOneType))
+    return ParseFailure;
+
+  // TODO: Build IR objects.
+  return ParseSuccess;
+}
+
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
@@ -119,13 +452,17 @@
if (curToken.isNot(Token::l_paren))
return emitError("expected '(' in function signature");
- consumeToken(Token::l_paren);
- // TODO: This should actually parse the full grammar here.
+ if (parseTypeList())
+ return ParseFailure;
- if (curToken.isNot(Token::r_paren))
- return emitError("expected ')' in function signature");
- consumeToken(Token::r_paren);
+ // Parse the return type if present.
+ if (consumeIf(Token::arrow)) {
+ if (parseTypeList())
+ return ParseFailure;
+
+ // TODO: Build IR object.
+ }
return ParseSuccess;
}