Implement parser and lexer support for most of the type grammar.
Semi-affine maps and address spaces are not yet supported (does anyone want to
take this on?). We also don't generate IR objects for types yet, which I plan
to tackle next.
PiperOrigin-RevId: 201754283
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 5958658..2e79271 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -66,13 +66,33 @@
// Ignore whitespace.
return lexToken();
+ case ',': return formToken(Token::comma, tokStart);
case '(': return formToken(Token::l_paren, tokStart);
case ')': return formToken(Token::r_paren, tokStart);
case '<': return formToken(Token::less, tokStart);
case '>': return formToken(Token::greater, tokStart);
+ case '-':
+ if (*curPtr == '>') {
+ ++curPtr;
+ return formToken(Token::arrow, tokStart);
+ }
+ return emitError(tokStart, "unexpected character");
+
+ case '?':
+ if (*curPtr == '?') {
+ ++curPtr;
+ return formToken(Token::questionquestion, tokStart);
+ }
+
+ return formToken(Token::question, tokStart);
+
case ';': return lexComment();
case '@': return lexAtIdentifier(tokStart);
+
+ case '0': case '1': case '2': case '3': case '4':
+ case '5': case '6': case '7': case '8': case '9':
+ return lexNumber(tokStart);
}
}
@@ -114,9 +134,22 @@
StringRef spelling(tokStart, curPtr-tokStart);
Token::TokenKind kind = llvm::StringSwitch<Token::TokenKind>(spelling)
+ .Case("bf16", Token::kw_bf16)
.Case("cfgfunc", Token::kw_cfgfunc)
.Case("extfunc", Token::kw_extfunc)
+ .Case("f16", Token::kw_f16)
+ .Case("f32", Token::kw_f32)
+ .Case("f64", Token::kw_f64)
+ .Case("i1", Token::kw_i1)
+ .Case("i16", Token::kw_i16)
+ .Case("i32", Token::kw_i32)
+ .Case("i64", Token::kw_i64)
+ .Case("i8", Token::kw_i8)
+ .Case("int", Token::kw_int)
+ .Case("memref", Token::kw_memref)
.Case("mlfunc", Token::kw_mlfunc)
+ .Case("tensor", Token::kw_tensor)
+ .Case("vector", Token::kw_vector)
.Default(Token::bare_identifier);
return Token(kind, spelling);
@@ -135,3 +168,30 @@
++curPtr;
return formToken(Token::at_identifier, tokStart);
}
+
+/// Lex an integer literal.
+///
+///   integer-literal ::= digit+ | `0x` hex_digit+
+///
+/// On entry the first digit of the literal has already been consumed, i.e. it
+/// is `curPtr[-1]`. `tokStart` is forwarded to formToken as the start of the
+/// token. Returns an integer token covering the whole literal, or the result
+/// of emitError when a `0x` prefix is not followed by a hexadecimal digit.
+Token Lexer::lexNumber(const char *tokStart) {
+  // The <cctype> classifiers are only defined for EOF and for values that are
+  // representable as unsigned char. Plain `char` may be signed, so a byte
+  // >= 0x80 in the input would otherwise be passed as a negative value, which
+  // is undefined behavior. Always convert through unsigned char first.
+  auto isDec = [](char c) {
+    return isdigit(static_cast<unsigned char>(c)) != 0;
+  };
+  auto isHex = [](char c) {
+    return isxdigit(static_cast<unsigned char>(c)) != 0;
+  };
+
+  assert(isDec(curPtr[-1]));
+
+  // Handle the hexadecimal case: a leading '0' immediately followed by 'x'.
+  if (curPtr[-1] == '0' && *curPtr == 'x') {
+    ++curPtr;
+
+    // At least one hex digit is required after the prefix.
+    if (!isHex(*curPtr))
+      return emitError(curPtr, "expected hexadecimal digit");
+
+    while (isHex(*curPtr))
+      ++curPtr;
+
+    return formToken(Token::integer, tokStart);
+  }
+
+  // Handle the normal decimal case: swallow the remaining digits.
+  while (isDec(*curPtr))
+    ++curPtr;
+
+  return formToken(Token::integer, tokStart);
+}
diff --git a/lib/Parser/Lexer.h b/lib/Parser/Lexer.h
index 5886c5c..4f364bc 100644
--- a/lib/Parser/Lexer.h
+++ b/lib/Parser/Lexer.h
@@ -46,6 +46,11 @@
Token lexToken();
+  /// Reposition the lexer cursor: the next token that gets lexed will begin
+  /// at `newPointer`.
+  void resetPointer(const char *newPointer) { this->curPtr = newPointer; }
private:
// Helpers.
Token formToken(Token::TokenKind kind, const char *tokStart) {
@@ -58,6 +63,7 @@
Token lexComment();
Token lexBareIdentifierOrKeyword(const char *tokStart);
Token lexAtIdentifier(const char *tokStart);
+ Token lexNumber(const char *tokStart);
};
} // end namespace mlir
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index abad611..6dde8c0 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -73,7 +73,29 @@
consumeToken();
}
+  /// Conditionally advance past the current token. When the current token is
+  /// of the requested kind it is consumed and true is returned; otherwise
+  /// nothing is consumed and false is returned.
+  bool consumeIf(Token::TokenKind kind) {
+    if (curToken.is(kind)) {
+      consumeToken(kind);
+      return true;
+    }
+    return false;
+  }
+
+ ParseResult parseCommaSeparatedList(Token::TokenKind rightToken,
+ const std::function<ParseResult()> &parseElement,
+ bool allowEmptyList = true);
+
// Type parsing.
+ ParseResult parsePrimitiveType();
+ ParseResult parseElementType();
+ ParseResult parseVectorType();
+ ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
+ ParseResult parseTensorType();
+ ParseResult parseMemRefType();
+ ParseResult parseFunctionType();
+ ParseResult parseType();
+ ParseResult parseTypeList();
// Top level entity parsing.
ParseResult parseFunctionSignature(StringRef &name);
@@ -86,6 +108,11 @@
//===----------------------------------------------------------------------===//
ParseResult Parser::emitError(const Twine &message) {
+ // If we hit a parse error in response to a lexer error, then the lexer
+ // already emitted an error.
+ if (curToken.is(Token::error))
+ return ParseFailure;
+
// TODO(clattner): If/when we want to implement a -verify mode, this will need
// to package up errors into SMDiagnostic and report them.
lex.getSourceMgr().PrintMessage(curToken.getLoc(), SourceMgr::DK_Error,
@@ -93,12 +120,318 @@
return ParseFailure;
}
+/// Parse a list of elements separated by commas and closed by `rightToken`.
+/// Each element is parsed by invoking `parseElement`; the closing token is
+/// consumed on success. An immediately-closed (empty) list is accepted only
+/// when `allowEmptyList` is true.
+///
+///   abstract-list ::= rightToken                  // if allowEmptyList == true
+///   abstract-list ::= element (',' element)* rightToken
+///
+/// Returns ParseFailure as soon as an element fails to parse or the list is
+/// malformed, ParseSuccess otherwise.
+ParseResult Parser::
+parseCommaSeparatedList(Token::TokenKind rightToken,
+                        const std::function<ParseResult()> &parseElement,
+                        bool allowEmptyList) {
+  // An immediate terminator means the list is empty.
+  if (curToken.is(rightToken)) {
+    if (allowEmptyList) {
+      consumeToken(rightToken);
+      return ParseSuccess;
+    }
+    return emitError("expected list element");
+  }
+
+  // One mandatory element, then one more for every comma that follows.
+  do {
+    if (parseElement())
+      return ParseFailure;
+  } while (consumeIf(Token::comma));
+
+  // The list must now be closed by the terminator.
+  // NOTE(review): the diagnostic names ')' even though `rightToken` can be any
+  // token kind.
+  if (consumeIf(rightToken))
+    return ParseSuccess;
+
+  return emitError("expected ',' or ')'");
+}
//===----------------------------------------------------------------------===//
// Type Parsing
//===----------------------------------------------------------------------===//
-// ... TODO
+/// Parse one of the built-in fixed data types.
+///
+///   primitive-type
+///      ::= `f16` | `bf16` | `f32` | `f64`      // Floating point
+///        | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
+///        | `int`
+///
+/// The keyword is consumed on success. Any other token produces an
+/// "expected type" diagnostic.
+ParseResult Parser::parsePrimitiveType() {
+  // TODO: Build IR objects.
+  switch (curToken.getKind()) {
+  // Every primitive type keyword is handled the same way for now: consume the
+  // keyword and report success.
+  case Token::kw_bf16:
+  case Token::kw_f16:
+  case Token::kw_f32:
+  case Token::kw_f64:
+  case Token::kw_i1:
+  case Token::kw_i8:
+  case Token::kw_i16:
+  case Token::kw_i32:
+  case Token::kw_i64:
+  case Token::kw_int:
+    consumeToken(curToken.getKind());
+    return ParseSuccess;
+
+  default:
+    return emitError("expected type");
+  }
+}
+
+/// Parse the element type that appears inside a tensor or memref type.
+///
+///   element-type ::= primitive-type | vector-type
+///
+/// A leading `vector` keyword selects the vector-type production; anything
+/// else is handed to the primitive type parser.
+ParseResult Parser::parseElementType() {
+  return curToken.is(Token::kw_vector) ? parseVectorType()
+                                       : parsePrimitiveType();
+}
+
+/// Parse a vector type.
+///
+/// vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
+/// const-dimension-list ::= (integer-literal `x`)+
+///
+/// Expects the current token to be the `vector` keyword. At least one
+/// dimension is required. Returns ParseSuccess once the closing `>` has been
+/// consumed, and ParseFailure (after emitting a diagnostic) on any malformed
+/// input. The collected dimensions are not used yet; no IR object is built.
+///
+ParseResult Parser::parseVectorType() {
+ consumeToken(Token::kw_vector);
+
+ if (!consumeIf(Token::less))
+ return emitError("expected '<' in vector type");
+
+ // The dimension list must contain at least one integer.
+ if (curToken.isNot(Token::integer))
+ return emitError("expected dimension size in vector type");
+
+ SmallVector<unsigned, 4> dimensions;
+ while (curToken.is(Token::integer)) {
+ // Make sure this integer value is in bound and valid.
+ auto dimension = curToken.getUnsignedIntegerValue();
+ if (!dimension.hasValue())
+ return emitError("invalid dimension in vector type");
+ dimensions.push_back(dimension.getValue());
+
+ consumeToken(Token::integer);
+
+ // Make sure we have an 'x' or something like 'xf32', i.e. a bare
+ // identifier whose first character is the 'x' separator.
+ if (curToken.isNot(Token::bare_identifier) ||
+ curToken.getSpelling()[0] != 'x')
+ return emitError("expected 'x' in vector dimension list");
+
+ // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+ // The lexer cursor must be moved BEFORE the identifier is consumed so that
+ // the following token starts right after the 'x'.
+ if (curToken.getSpelling().size() != 1)
+ lex.resetPointer(curToken.getSpelling().data()+1);
+
+ // Consume the 'x'.
+ consumeToken(Token::bare_identifier);
+ }
+
+ // Parse the element type. Only primitive types are accepted here.
+ if (parsePrimitiveType())
+ return ParseFailure;
+
+ if (!consumeIf(Token::greater))
+ return emitError("expected '>' in vector type");
+
+ // TODO: Form IR object.
+
+ return ParseSuccess;
+}
+
+/// Parse a dimension list of a tensor or memref type. This populates the
+/// dimension list, storing -1 for the '?' dimensions.
+///
+/// dimension-list-ranked ::= (dimension `x`)*
+/// dimension ::= `?` | integer-literal
+///
+/// Parsed sizes are appended to `dimensions`. The list may be empty: parsing
+/// stops at the first token that is neither an integer nor `?`. Returns
+/// ParseFailure (after emitting a diagnostic) if a size is invalid or a
+/// dimension is not followed by 'x', ParseSuccess otherwise.
+///
+ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
+ while (curToken.isAny(Token::integer, Token::question)) {
+ if (consumeIf(Token::question)) {
+ // Unknown dimension.
+ dimensions.push_back(-1);
+ } else {
+ // Make sure this integer value is in bound and valid. The value must also
+ // be non-negative once converted to int, since it is stored as an int.
+ auto dimension = curToken.getUnsignedIntegerValue();
+ if (!dimension.hasValue() || (int)dimension.getValue() < 0)
+ return emitError("invalid dimension");
+ dimensions.push_back((int)dimension.getValue());
+ consumeToken(Token::integer);
+ }
+
+ // Make sure we have an 'x' or something like 'xf32', i.e. a bare
+ // identifier whose first character is the 'x' separator.
+ if (curToken.isNot(Token::bare_identifier) ||
+ curToken.getSpelling()[0] != 'x')
+ return emitError("expected 'x' in dimension list");
+
+ // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+ // The lexer cursor must be moved BEFORE the identifier is consumed so that
+ // the following token starts right after the 'x'.
+ if (curToken.getSpelling().size() != 1)
+ lex.resetPointer(curToken.getSpelling().data()+1);
+
+ // Consume the 'x'.
+ consumeToken(Token::bare_identifier);
+ }
+
+ return ParseSuccess;
+}
+
+/// Parse a tensor type, starting at the `tensor` keyword.
+///
+///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
+///   dimension-list ::= dimension-list-ranked | `??`
+///
+/// A `??` marks the tensor as unranked; otherwise a (possibly empty) ranked
+/// dimension list is read. No IR object is built yet.
+ParseResult Parser::parseTensorType() {
+  consumeToken(Token::kw_tensor);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in tensor type");
+
+  // Either the unranked marker or an explicit dimension list.
+  SmallVector<int, 4> dimensions;
+  const bool isUnranked = consumeIf(Token::questionquestion);
+  if (!isUnranked && parseDimensionListRanked(dimensions))
+    return ParseFailure;
+
+  // The element type follows the dimensions.
+  if (parseElementType())
+    return ParseFailure;
+
+  // The type must be closed by '>'.
+  if (consumeIf(Token::greater)) {
+    // TODO: Form IR object.
+    return ParseSuccess;
+  }
+
+  return emitError("expected '>' in tensor type");
+}
+
+/// Parse a memref type, starting at the `memref` keyword.
+///
+///   memref-type ::= `memref` `<` dimension-list-ranked element-type
+///                   (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
+///
+///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+/// Only the dimension list and the element type are handled so far; the
+/// optional map composition and memory space are not parsed, and no IR object
+/// is built.
+ParseResult Parser::parseMemRefType() {
+  consumeToken(Token::kw_memref);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in memref type");
+
+  // Ranked dimension list first, then the element type. The element type is
+  // only attempted when the dimension list parsed cleanly.
+  SmallVector<int, 4> dimensions;
+  if (parseDimensionListRanked(dimensions) || parseElementType())
+    return ParseFailure;
+
+  // TODO: Parse semi-affine-map-composition.
+  // TODO: Parse memory-space.
+
+  if (consumeIf(Token::greater)) {
+    // TODO: Form IR object.
+    return ParseSuccess;
+  }
+
+  return emitError("expected '>' in memref type");
+}
+
+
+
+/// Parse a function type. The current token must be the opening `(` of the
+/// argument list.
+///
+///   function-type ::= type-list-parens `->` type-list
+///
+ParseResult Parser::parseFunctionType() {
+  assert(curToken.is(Token::l_paren));
+
+  // Argument types.
+  if (parseTypeList())
+    return ParseFailure;
+
+  // The arrow separating arguments from results is mandatory.
+  if (curToken.isNot(Token::arrow))
+    return emitError("expected '->' in function type");
+  consumeToken(Token::arrow);
+
+  // Result types.
+  if (parseTypeList())
+    return ParseFailure;
+
+  // TODO: Build IR object.
+  return ParseSuccess;
+}
+
+
+/// Parse any type, dispatching on the current token.
+///
+///   type ::= primitive-type
+///          | vector-type
+///          | tensor-type
+///          | memref-type
+///          | function-type
+///   element-type ::= primitive-type | vector-type
+///
+/// Anything that does not start one of the composite types is handed to the
+/// primitive type parser, which also reports the error for a non-type token.
+ParseResult Parser::parseType() {
+  if (curToken.is(Token::kw_memref))
+    return parseMemRefType();
+  if (curToken.is(Token::kw_tensor))
+    return parseTensorType();
+  if (curToken.is(Token::kw_vector))
+    return parseVectorType();
+  if (curToken.is(Token::l_paren))
+    return parseFunctionType();
+
+  return parsePrimitiveType();
+}
+
+/// Parse a "type list": either one bare type, or a parenthesized,
+/// comma-separated (possibly empty) sequence of types.
+///
+///   type-list ::= type-list-parens | type
+///   type-list-parens ::= `(` `)`
+///                      | `(` type (`,` type)* `)`
+///
+ParseResult Parser::parseTypeList() {
+  // Without an opening paren this is a single type.
+  if (curToken.isNot(Token::l_paren))
+    return parseType();
+  consumeToken(Token::l_paren);
+
+  // Callback invoked once per list element.
+  auto parseOneType = [&]() -> ParseResult {
+    // TODO: Add to list of IR values we're parsing.
+    return parseType();
+  };
+
+  // Parses the elements and consumes the closing ')'.
+  if (parseCommaSeparatedList(Token::r_paren, parseOneType))
+    return ParseFailure;
+
+  // TODO: Build IR objects.
+  return ParseSuccess;
+}
+
//===----------------------------------------------------------------------===//
// Top-level entity parsing.
@@ -119,13 +452,17 @@
if (curToken.isNot(Token::l_paren))
return emitError("expected '(' in function signature");
- consumeToken(Token::l_paren);
- // TODO: This should actually parse the full grammar here.
+ if (parseTypeList())
+ return ParseFailure;
- if (curToken.isNot(Token::r_paren))
- return emitError("expected ')' in function signature");
- consumeToken(Token::r_paren);
+ // Parse the return type if present.
+ if (consumeIf(Token::arrow)) {
+ if (parseTypeList())
+ return ParseFailure;
+
+ // TODO: Build IR object.
+ }
return ParseSuccess;
}
diff --git a/lib/Parser/Token.cpp b/lib/Parser/Token.cpp
index 551bd1e..c721cf1 100644
--- a/lib/Parser/Token.cpp
+++ b/lib/Parser/Token.cpp
@@ -35,3 +35,15 @@
SMRange Token::getLocRange() const {
return SMRange(getLoc(), getEndLoc());
}
+#include "llvm/Support/raw_ostream.h"
+
+/// Decode the spelling of an integer token into an unsigned value. Returns
+/// None when the spelling cannot be converted (for example because the value
+/// does not fit in an unsigned).
+Optional<unsigned> Token::getUnsignedIntegerValue() {
+  // Decimal by default. A spelling whose second character is 'x' is treated
+  // as hexadecimal and parsed with radix 0, leaving the prefix handling to
+  // getAsInteger.
+  unsigned radix = 10;
+  if (spelling.size() > 1 && spelling[1] == 'x')
+    radix = 0;
+
+  unsigned value = 0;
+  const bool failed = spelling.getAsInteger(radix, value);
+  if (failed)
+    return None;
+  return value;
+}
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index 03c967e..36b1abd 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -36,16 +36,34 @@
at_identifier, // @foo
// TODO: @@foo, etc.
+ integer, // 42
+
// Punctuation.
+ arrow, // ->
+ comma, // ,
+ question, // ?
+ questionquestion, // ??
l_paren, r_paren, // ( )
less, greater, // < >
// TODO: More punctuation.
// Keywords.
+ kw_bf16,
kw_cfgfunc,
kw_extfunc,
+ kw_f16,
+ kw_f32,
+ kw_f64,
+ kw_i1,
+ kw_i16,
+ kw_i32,
+ kw_i64,
+ kw_i8,
+ kw_int,
+ kw_memref,
kw_mlfunc,
- // TODO: More keywords.
+ kw_tensor,
+ kw_vector,
};
Token(TokenKind kind, StringRef spelling)
@@ -78,8 +96,13 @@
return !isAny(k1, k2, others...);
}
+ // Helpers to decode specific sorts of tokens.
- /// Location processing.
+ /// For an integer token, return its value as an unsigned. If it doesn't fit,
+ /// return None.
+ Optional<unsigned> getUnsignedIntegerValue();
+
+ // Location processing.
llvm::SMLoc getLoc() const;
llvm::SMLoc getEndLoc() const;
llvm::SMRange getLocRange() const;