Implement parser and lexer support for most of the type grammar.

Semi-affine maps and address spaces are not yet supported (would anyone like to
take this on?).  We also don't generate IR objects for types yet; I plan to
tackle that next.

PiperOrigin-RevId: 201754283
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index abad611..6dde8c0 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -73,7 +73,29 @@
     consumeToken();
   }
 
+  /// If the current token has the specified kind, consume it and return true.
+  /// Otherwise leave the current token untouched and return false.
+  bool consumeIf(Token::TokenKind kind) {
+    if (curToken.isNot(kind))
+      return false;
+    consumeToken(kind);
+    return true;
+  }
+
+  ParseResult parseCommaSeparatedList(Token::TokenKind rightToken,
+                               const std::function<ParseResult()> &parseElement,
+                                      bool allowEmptyList = true);
+
   // Type parsing.
+  ParseResult parsePrimitiveType();
+  ParseResult parseElementType();
+  ParseResult parseVectorType();
+  ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
+  ParseResult parseTensorType();
+  ParseResult parseMemRefType();
+  ParseResult parseFunctionType();
+  ParseResult parseType();
+  ParseResult parseTypeList();
 
   // Top level entity parsing.
   ParseResult parseFunctionSignature(StringRef &name);
@@ -86,6 +108,11 @@
 //===----------------------------------------------------------------------===//
 
 ParseResult Parser::emitError(const Twine &message) {
+  // If we hit a parse error in response to a lexer error, then the lexer
+  // already emitted an error.
+  if (curToken.is(Token::error))
+    return ParseFailure;
+
   // TODO(clattner): If/when we want to implement a -verify mode, this will need
   // to package up errors into SMDiagnostic and report them.
   lex.getSourceMgr().PrintMessage(curToken.getLoc(), SourceMgr::DK_Error,
@@ -93,12 +120,318 @@
   return ParseFailure;
 }
 
+/// Parse a comma-separated list of elements, each parsed by parseElement, and
+/// consume the closing rightToken.  Empty lists need allowEmptyList == true.
+///
+///   abstract-list ::= rightToken                  // if allowEmptyList == true
+///   abstract-list ::= element (',' element)* rightToken
+///
+ParseResult Parser::
+parseCommaSeparatedList(Token::TokenKind rightToken,
+                        const std::function<ParseResult()> &parseElement,
+                        bool allowEmptyList) {
+  // Empty list: the terminator appears before any element.
+  if (curToken.is(rightToken)) {
+    if (!allowEmptyList)
+      return emitError("expected list element");
+    consumeToken(rightToken);
+    return ParseSuccess;
+  }
+
+  // Non-empty case starts with an element.
+  if (parseElement())
+    return ParseFailure;
+
+  // Each further element must be preceded by a comma.
+  while (consumeIf(Token::comma)) {
+    if (parseElement())
+      return ParseFailure;
+  }
+
+  // Consume rightToken.  NOTE(review): the message assumes rightToken is ')'.
+  if (!consumeIf(rightToken))
+    return emitError("expected ',' or ')'");
+
+  return ParseSuccess;
+}
 
 //===----------------------------------------------------------------------===//
 // Type Parsing
 //===----------------------------------------------------------------------===//
 
-// ... TODO
+/// Parse a primitive type, which is always a single keyword token.
+///
+///   primitive-type
+///      ::= `f16` | `bf16` | `f32` | `f64`      // Floating point
+///        | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
+///        | `int`
+///
+ParseResult Parser::parsePrimitiveType() {
+  // TODO: Build IR objects.  For now each keyword is only consumed.
+  switch (curToken.getKind()) {
+  default: return emitError("expected type");
+  case Token::kw_bf16:
+    consumeToken(Token::kw_bf16);
+    return ParseSuccess;
+  case Token::kw_f16:
+    consumeToken(Token::kw_f16);
+    return ParseSuccess;
+  case Token::kw_f32:
+    consumeToken(Token::kw_f32);
+    return ParseSuccess;
+  case Token::kw_f64:
+    consumeToken(Token::kw_f64);
+    return ParseSuccess;
+  case Token::kw_i1:
+    consumeToken(Token::kw_i1);
+    return ParseSuccess;
+  case Token::kw_i16:
+    consumeToken(Token::kw_i16);
+    return ParseSuccess;
+  case Token::kw_i32:
+    consumeToken(Token::kw_i32);
+    return ParseSuccess;
+  case Token::kw_i64:
+    consumeToken(Token::kw_i64);
+    return ParseSuccess;
+  case Token::kw_i8:
+    consumeToken(Token::kw_i8);
+    return ParseSuccess;
+  case Token::kw_int:
+    consumeToken(Token::kw_int);
+    return ParseSuccess;
+  }
+}
+
+/// Parse the element type of a tensor or memref: a vector or primitive type.
+///
+///   element-type ::= primitive-type | vector-type
+///
+ParseResult Parser::parseElementType() {
+  if (curToken.is(Token::kw_vector))
+    return parseVectorType();
+
+  return parsePrimitiveType();
+}
+
+/// Parse a vector type.
+///
+///   vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
+///   const-dimension-list ::= (integer-literal `x`)+
+///
+ParseResult Parser::parseVectorType() {
+  consumeToken(Token::kw_vector);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in vector type");
+
+  if (curToken.isNot(Token::integer))
+    return emitError("expected dimension size in vector type");
+
+  SmallVector<unsigned, 4> dimensions;
+  while (curToken.is(Token::integer)) {
+    // Make sure this integer value is in bounds and valid.
+    auto dimension = curToken.getUnsignedIntegerValue();
+    if (!dimension.hasValue())
+      return emitError("invalid dimension in vector type");
+    dimensions.push_back(dimension.getValue());
+
+    consumeToken(Token::integer);
+
+    // Make sure we have an 'x' or something like 'xbf16'.
+    if (curToken.isNot(Token::bare_identifier) ||
+        curToken.getSpelling()[0] != 'x')
+      return emitError("expected 'x' in vector dimension list");
+
+    // If more follows the 'x' (e.g. 'xf32'), re-lex from right after the 'x'.
+    if (curToken.getSpelling().size() != 1)
+      lex.resetPointer(curToken.getSpelling().data()+1);
+
+    // Consume the 'x'.
+    consumeToken(Token::bare_identifier);
+  }
+
+  // Parse the element type.
+  if (parsePrimitiveType())
+    return ParseFailure;
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in vector type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+/// Parse a dimension list of a tensor or memref type.  This appends to the
+/// dimension list, recording -1 for each '?' dimension.
+///
+///   dimension-list-ranked ::= (dimension `x`)*
+///   dimension ::= `?` | integer-literal
+///
+ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
+  while (curToken.isAny(Token::integer, Token::question)) {
+    if (consumeIf(Token::question)) {
+      dimensions.push_back(-1);
+    } else {
+      // Make sure this integer value is valid and fits in a non-negative int.
+      auto dimension = curToken.getUnsignedIntegerValue();
+      if (!dimension.hasValue() || (int)dimension.getValue() < 0)
+        return emitError("invalid dimension");
+      dimensions.push_back((int)dimension.getValue());
+      consumeToken(Token::integer);
+    }
+
+    // Make sure we have an 'x' or something like 'xbf16'.
+    if (curToken.isNot(Token::bare_identifier) ||
+        curToken.getSpelling()[0] != 'x')
+      return emitError("expected 'x' in dimension list");
+
+    // If more follows the 'x' (e.g. 'xf32'), re-lex from right after the 'x'.
+    if (curToken.getSpelling().size() != 1)
+      lex.resetPointer(curToken.getSpelling().data()+1);
+
+    // Consume the 'x'.
+    consumeToken(Token::bare_identifier);
+  }
+
+  return ParseSuccess;
+}
+
+/// Parse a tensor type, either unranked (`??`) or with a ranked dimension list.
+///
+///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
+///   dimension-list ::= dimension-list-ranked | `??`
+///
+ParseResult Parser::parseTensorType() {
+  consumeToken(Token::kw_tensor);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in tensor type");
+
+  bool isUnranked;
+  SmallVector<int, 4> dimensions;
+
+  if (consumeIf(Token::questionquestion)) {
+    isUnranked = true;
+  } else {
+    isUnranked = false;
+    if (parseDimensionListRanked(dimensions))
+      return ParseFailure;
+  }
+
+  // Parse the element type.
+  if (parseElementType())
+    return ParseFailure;
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in tensor type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+/// Parse a memref type.  The optional map and memory-space parts are TODO.
+///
+///   memref-type ::= `memref` `<` dimension-list-ranked element-type
+///                   (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
+///
+///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+ParseResult Parser::parseMemRefType() {
+  consumeToken(Token::kw_memref);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in memref type");
+
+  SmallVector<int, 4> dimensions;
+  if (parseDimensionListRanked(dimensions))
+    return ParseFailure;
+
+  // Parse the element type.
+  if (parseElementType())
+    return ParseFailure;
+
+  // TODO: Parse semi-affine-map-composition.
+  // TODO: Parse memory-space.
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in memref type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+
+
+/// Parse a function type.  The current token must already be the leading '('.
+///
+///   function-type ::= type-list-parens `->` type-list
+///
+ParseResult Parser::parseFunctionType() {
+  assert(curToken.is(Token::l_paren));
+
+  if (parseTypeList())
+    return ParseFailure;
+
+  if (!consumeIf(Token::arrow))
+    return emitError("expected '->' in function type");
+
+  if (parseTypeList())
+    return ParseFailure;
+
+  // TODO: Build IR object.
+  return ParseSuccess;
+}
+
+
+/// Parse an arbitrary type, dispatching on the kind of the current token.
+///
+///   type ::= primitive-type
+///          | vector-type
+///          | tensor-type
+///          | memref-type
+///          | function-type
+///   element-type ::= primitive-type | vector-type
+///
+ParseResult Parser::parseType() {
+  switch (curToken.getKind()) {
+  case Token::kw_memref: return parseMemRefType();
+  case Token::kw_tensor: return parseTensorType();
+  case Token::kw_vector: return parseVectorType();
+  case Token::l_paren:   return parseFunctionType();
+  default:
+    return parsePrimitiveType();
+  }
+}
+
+/// Parse a "type list", which is a singular type, or a parenthesized list of
+/// types.  The closing ')' of a parenthesized list is consumed as well.
+///
+///   type-list ::= type-list-parens | type
+///   type-list-parens ::= `(` `)`
+///                      | `(` type (`,` type)* `)`
+///
+ParseResult Parser::parseTypeList() {
+  // If there are no parens, then it must be a singular type.
+  if (!consumeIf(Token::l_paren))
+    return parseType();
+
+  if (parseCommaSeparatedList(Token::r_paren,
+                              [&]() -> ParseResult {
+    // TODO: Add to list of IR values we're parsing.
+    return parseType();
+  })) {
+    return ParseFailure;
+  }
+
+  // TODO: Build IR objects.
+  return ParseSuccess;
+}
+
 
 //===----------------------------------------------------------------------===//
 // Top-level entity parsing.
@@ -119,13 +452,17 @@
 
   if (curToken.isNot(Token::l_paren))
     return emitError("expected '(' in function signature");
-  consumeToken(Token::l_paren);
 
-  // TODO: This should actually parse the full grammar here.
+  if (parseTypeList())
+    return ParseFailure;
 
-  if (curToken.isNot(Token::r_paren))
-    return emitError("expected ')' in function signature");
-  consumeToken(Token::r_paren);
+  // Parse the return type if present.
+  if (consumeIf(Token::arrow)) {
+    if (parseTypeList())
+      return ParseFailure;
+
+    // TODO: Build IR object.
+  }
 
   return ParseSuccess;
 }