Implement parser and lexer support for most of the type grammar.

Semi-affine maps and address spaces are not yet supported (does someone want
to take this on?).  We also don't generate IR objects for types yet, which I
plan to tackle next.

PiperOrigin-RevId: 201754283
diff --git a/lib/Parser/Lexer.cpp b/lib/Parser/Lexer.cpp
index 5958658..2e79271 100644
--- a/lib/Parser/Lexer.cpp
+++ b/lib/Parser/Lexer.cpp
@@ -66,13 +66,33 @@
     // Ignore whitespace.
     return lexToken();
 
+  case ',': return formToken(Token::comma, tokStart);
   case '(': return formToken(Token::l_paren, tokStart);
   case ')': return formToken(Token::r_paren, tokStart);
   case '<': return formToken(Token::less, tokStart);
   case '>': return formToken(Token::greater, tokStart);
 
+  case '-':
+    if (*curPtr == '>') {
+      ++curPtr;
+      return formToken(Token::arrow, tokStart);
+    }
+    return emitError(tokStart, "unexpected character");
+
+  case '?':
+    if (*curPtr == '?') {
+      ++curPtr;
+      return formToken(Token::questionquestion, tokStart);
+    }
+
+    return formToken(Token::question, tokStart);
+
   case ';': return lexComment();
   case '@': return lexAtIdentifier(tokStart);
+
+  case '0': case '1': case '2': case '3': case '4':
+  case '5': case '6': case '7': case '8': case '9':
+    return lexNumber(tokStart);
   }
 }
 
@@ -114,9 +134,22 @@
   StringRef spelling(tokStart, curPtr-tokStart);
 
   Token::TokenKind kind = llvm::StringSwitch<Token::TokenKind>(spelling)
+    .Case("bf16",    Token::kw_bf16)
     .Case("cfgfunc", Token::kw_cfgfunc)
     .Case("extfunc", Token::kw_extfunc)
+    .Case("f16", Token::kw_f16)
+    .Case("f32", Token::kw_f32)
+    .Case("f64", Token::kw_f64)
+    .Case("i1", Token::kw_i1)
+    .Case("i16", Token::kw_i16)
+    .Case("i32", Token::kw_i32)
+    .Case("i64", Token::kw_i64)
+    .Case("i8", Token::kw_i8)
+    .Case("int", Token::kw_int)
+    .Case("memref", Token::kw_memref)
     .Case("mlfunc", Token::kw_mlfunc)
+    .Case("tensor", Token::kw_tensor)
+    .Case("vector", Token::kw_vector)
     .Default(Token::bare_identifier);
 
   return Token(kind, spelling);
@@ -135,3 +168,30 @@
     ++curPtr;
   return formToken(Token::at_identifier, tokStart);
 }
+
+/// Lex an integer literal; lexToken has already consumed its first digit.
+///   integer-literal ::= digit+ | `0x` hex_digit+
+/// Bytes go to <cctype> as unsigned char: a negative plain char (e.g. from
+/// UTF-8 input) is undefined behavior for isdigit/isxdigit.
+Token Lexer::lexNumber(const char *tokStart) {
+  assert(isdigit(static_cast<unsigned char>(curPtr[-1])));
+
+  // Handle the hexadecimal case.
+  if (curPtr[-1] == '0' && *curPtr == 'x') {
+    // `0x` not followed by a hex digit (e.g. `0xi32`) is the literal `0`
+    // followed by an identifier, as in the dimension list `tensor<0xi32>`.
+    if (!isxdigit(static_cast<unsigned char>(curPtr[1])))
+      return formToken(Token::integer, tokStart);
+
+    curPtr += 2;
+    while (isxdigit(static_cast<unsigned char>(*curPtr)))
+      ++curPtr;
+    return formToken(Token::integer, tokStart);
+  }
+
+  // Handle the normal decimal case.
+  while (isdigit(static_cast<unsigned char>(*curPtr)))
+    ++curPtr;
+
+  return formToken(Token::integer, tokStart);
+}
diff --git a/lib/Parser/Lexer.h b/lib/Parser/Lexer.h
index 5886c5c..4f364bc 100644
--- a/lib/Parser/Lexer.h
+++ b/lib/Parser/Lexer.h
@@ -46,6 +46,11 @@
 
   Token lexToken();
 
+  /// Move the lexer cursor to newPointer; the next token lexed starts there.
+  /// newPointer is not validated and is assumed to lie in the current buffer.
+  void resetPointer(const char *newPointer) {
+    curPtr = newPointer;
+  }
 private:
   // Helpers.
   Token formToken(Token::TokenKind kind, const char *tokStart) {
@@ -58,6 +63,7 @@
   Token lexComment();
   Token lexBareIdentifierOrKeyword(const char *tokStart);
   Token lexAtIdentifier(const char *tokStart);
+  Token lexNumber(const char *tokStart);
 };
 
 } // end namespace mlir
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index abad611..6dde8c0 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -73,7 +73,29 @@
     consumeToken();
   }
 
+  /// If the current token has the specified kind, consume it and return true.
+  /// Otherwise leave the current token in place and return false.
+  bool consumeIf(Token::TokenKind kind) {
+    if (curToken.isNot(kind))
+      return false;
+    consumeToken(kind);
+    return true;
+  }
+
+  ParseResult parseCommaSeparatedList(Token::TokenKind rightToken,
+                               const std::function<ParseResult()> &parseElement,
+                                      bool allowEmptyList = true);
+
   // Type parsing.
+  ParseResult parsePrimitiveType();
+  ParseResult parseElementType();
+  ParseResult parseVectorType();
+  ParseResult parseDimensionListRanked(SmallVectorImpl<int> &dimensions);
+  ParseResult parseTensorType();
+  ParseResult parseMemRefType();
+  ParseResult parseFunctionType();
+  ParseResult parseType();
+  ParseResult parseTypeList();
 
   // Top level entity parsing.
   ParseResult parseFunctionSignature(StringRef &name);
@@ -86,6 +108,11 @@
 //===----------------------------------------------------------------------===//
 
 ParseResult Parser::emitError(const Twine &message) {
+  // If we hit a parse error in response to a lexer error, then the lexer
+  // already emitted an error.
+  if (curToken.is(Token::error))
+    return ParseFailure;
+
   // TODO(clattner): If/when we want to implement a -verify mode, this will need
   // to package up errors into SMDiagnostic and report them.
   lex.getSourceMgr().PrintMessage(curToken.getLoc(), SourceMgr::DK_Error,
@@ -93,12 +120,318 @@
   return ParseFailure;
 }
 
+/// Parse a comma-separated list of elements, terminated with an arbitrary
+/// token.  No opening token is consumed here; rightToken is consumed on
+/// success.  Empty lists are accepted only if allowEmptyList is true.
+///   abstract-list ::= rightToken                  // if allowEmptyList == true
+///   abstract-list ::= element (',' element)* rightToken
+/// A truthy ParseResult from parseElement is a failure and aborts the list.
+ParseResult Parser::
+parseCommaSeparatedList(Token::TokenKind rightToken,
+                        const std::function<ParseResult()> &parseElement,
+                        bool allowEmptyList) {
+  // Handle the empty case.
+  if (curToken.is(rightToken)) {
+    if (!allowEmptyList)
+      return emitError("expected list element");
+    consumeToken(rightToken);
+    return ParseSuccess;
+  }
+
+  // Non-empty case starts with an element.
+  if (parseElement())
+    return ParseFailure;
+
+  // Every further element must be preceded by a comma.
+  while (consumeIf(Token::comma)) {
+    if (parseElement())
+      return ParseFailure;
+  }
+
+  // Consume the terminator.  NOTE(review): message assumes rightToken is ')'.
+  if (!consumeIf(rightToken))
+    return emitError("expected ',' or ')'");
+
+  return ParseSuccess;
+}
 
 //===----------------------------------------------------------------------===//
 // Type Parsing
 //===----------------------------------------------------------------------===//
 
-// ... TODO
+/// Parse one of the low-level fixed dtypes by consuming its keyword.  Any
+/// other token is reported as "expected type".
+///   primitive-type
+///      ::= `f16` | `bf16` | `f32` | `f64`      // Floating point
+///        | `i1` | `i8` | `i16` | `i32` | `i64` // Sized integers
+///        | `int`
+///
+ParseResult Parser::parsePrimitiveType() {
+  // TODO: Build IR objects.
+  switch (curToken.getKind()) {
+  default: return emitError("expected type");
+  case Token::kw_bf16:
+    consumeToken(Token::kw_bf16);
+    return ParseSuccess;
+  case Token::kw_f16:
+    consumeToken(Token::kw_f16);
+    return ParseSuccess;
+  case Token::kw_f32:
+    consumeToken(Token::kw_f32);
+    return ParseSuccess;
+  case Token::kw_f64:
+    consumeToken(Token::kw_f64);
+    return ParseSuccess;
+  case Token::kw_i1:
+    consumeToken(Token::kw_i1);
+    return ParseSuccess;
+  case Token::kw_i16:
+    consumeToken(Token::kw_i16);
+    return ParseSuccess;
+  case Token::kw_i32:
+    consumeToken(Token::kw_i32);
+    return ParseSuccess;
+  case Token::kw_i64:
+    consumeToken(Token::kw_i64);
+    return ParseSuccess;
+  case Token::kw_i8:
+    consumeToken(Token::kw_i8);
+    return ParseSuccess;
+  case Token::kw_int:
+    consumeToken(Token::kw_int);
+    return ParseSuccess;
+  }
+}
+
+/// Parse the element type of a tensor or memref type.
+///
+///   element-type ::= primitive-type | vector-type
+/// A `vector` keyword selects vector-type; any other token goes to primitive.
+ParseResult Parser::parseElementType() {
+  if (curToken.is(Token::kw_vector))
+    return parseVectorType();
+
+  return parsePrimitiveType();
+}
+
+/// Parse a vector type; expects the current token to be `vector`.
+///   vector-type ::= `vector` `<` const-dimension-list primitive-type `>`
+///   const-dimension-list ::= (integer-literal `x`)+
+ParseResult Parser::parseVectorType() {
+  consumeToken(Token::kw_vector);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in vector type");
+
+  if (curToken.isNot(Token::integer))
+    return emitError("expected dimension size in vector type");
+
+  SmallVector<unsigned, 4> dimensions;
+  while (curToken.is(Token::integer)) {
+    StringRef spelling = curToken.getSpelling();
+    if (spelling.size() > 1 && spelling[1] == 'x') {
+      // `0x4xf32` is `0` `x` `4xf32`, not hex: record 0, re-lex from the 'x'.
+      dimensions.push_back(0);
+      lex.resetPointer(spelling.data() + 1);
+    } else {
+      // Make sure this integer value is in bound and valid.
+      auto dimension = curToken.getUnsignedIntegerValue();
+      if (!dimension.hasValue())
+        return emitError("invalid dimension in vector type");
+      dimensions.push_back(dimension.getValue());
+    }
+    consumeToken(Token::integer);
+
+    // Make sure we have an 'x' or something like 'xbf16'.
+    if (curToken.isNot(Token::bare_identifier) ||
+        curToken.getSpelling()[0] != 'x')
+      return emitError("expected 'x' in vector dimension list");
+    // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+    if (curToken.getSpelling().size() != 1)
+      lex.resetPointer(curToken.getSpelling().data()+1);
+    // Consume the 'x'.
+    consumeToken(Token::bare_identifier);
+  }
+
+  // Parse the element type.
+  if (parsePrimitiveType())
+    return ParseFailure;
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in vector type");
+  // TODO: Form IR object.
+  return ParseSuccess;
+}
+
+/// Parse a dimension list of a tensor or memref type, appending one entry per
+/// dimension to `dimensions` and recording -1 for each '?' dimension.
+///   dimension-list-ranked ::= (dimension `x`)*
+///   dimension ::= `?` | integer-literal
+ParseResult Parser::parseDimensionListRanked(SmallVectorImpl<int> &dimensions) {
+  while (curToken.isAny(Token::integer, Token::question)) {
+    if (consumeIf(Token::question)) {
+      dimensions.push_back(-1);
+    } else if (curToken.getSpelling().size() > 1 &&
+               curToken.getSpelling()[1] == 'x') {
+      // `0x4xf32` is `0` `x` `4xf32`, not hex: record 0, re-lex from the 'x'.
+      dimensions.push_back(0);
+      lex.resetPointer(curToken.getSpelling().data() + 1);
+      consumeToken(Token::integer);
+    } else {
+      // Reject values that overflow unsigned or go negative as an int.
+      auto dimension = curToken.getUnsignedIntegerValue();
+      if (!dimension.hasValue() || static_cast<int>(dimension.getValue()) < 0)
+        return emitError("invalid dimension");
+      dimensions.push_back(static_cast<int>(dimension.getValue()));
+      consumeToken(Token::integer);
+    }
+    // Make sure we have an 'x' or something like 'xbf16'.
+    if (curToken.isNot(Token::bare_identifier) ||
+        curToken.getSpelling()[0] != 'x')
+      return emitError("expected 'x' in dimension list");
+    // If we had a prefix of 'x', lex the next token immediately after the 'x'.
+    if (curToken.getSpelling().size() != 1)
+      lex.resetPointer(curToken.getSpelling().data()+1);
+    // Consume the 'x'.
+    consumeToken(Token::bare_identifier);
+  }
+  return ParseSuccess;
+}
+
+/// Parse a tensor type; expects the current token to be `tensor`.
+///
+///   tensor-type ::= `tensor` `<` dimension-list element-type `>`
+///   dimension-list ::= dimension-list-ranked | `??`
+///
+ParseResult Parser::parseTensorType() {
+  consumeToken(Token::kw_tensor);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in tensor type");
+
+  bool isUnranked;  // Recorded but not read yet (IR building is a TODO).
+  SmallVector<int, 4> dimensions;
+
+  if (consumeIf(Token::questionquestion)) {
+    isUnranked = true;
+  } else {
+    isUnranked = false;
+    if (parseDimensionListRanked(dimensions))
+      return ParseFailure;
+  }
+
+  // Parse the element type.  Note `??` is followed directly by it (no 'x').
+  if (parseElementType())
+    return ParseFailure;
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in tensor type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+/// Parse a memref type; expects the current token to be `memref`.  The map
+/// composition and memory space below are not parsed yet (see TODOs).
+///   memref-type ::= `memref` `<` dimension-list-ranked element-type
+///                   (`,` semi-affine-map-composition)? (`,` memory-space)? `>`
+///
+///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
+///   memory-space ::= integer-literal /* | TODO: address-space-id */
+///
+ParseResult Parser::parseMemRefType() {
+  consumeToken(Token::kw_memref);
+
+  if (!consumeIf(Token::less))
+    return emitError("expected '<' in memref type");
+
+  SmallVector<int, 4> dimensions;
+  if (parseDimensionListRanked(dimensions))
+    return ParseFailure;
+
+  // Parse the element type.
+  if (parseElementType())
+    return ParseFailure;
+
+  // TODO: Parse semi-affine-map-composition.
+  // TODO: Parse memory-space.
+
+  if (!consumeIf(Token::greater))
+    return emitError("expected '>' in memref type");
+
+  // TODO: Form IR object.
+
+  return ParseSuccess;
+}
+
+
+
+/// Parse a function type.  The current token must be '(' (asserted below),
+/// so the inputs are always parenthesized; the results may be a bare type.
+///   function-type ::= type-list-parens `->` type-list
+///
+ParseResult Parser::parseFunctionType() {
+  assert(curToken.is(Token::l_paren));
+
+  if (parseTypeList())
+    return ParseFailure;
+
+  if (!consumeIf(Token::arrow))
+    return emitError("expected '->' in function type");
+
+  if (parseTypeList())
+    return ParseFailure;
+
+  // TODO: Build IR object.
+  return ParseSuccess;
+}
+
+
+/// Parse an arbitrary type, dispatching on the current token.
+///
+///   type ::= primitive-type
+///          | vector-type
+///          | tensor-type
+///          | memref-type
+///          | function-type
+///   element-type ::= primitive-type | vector-type
+///
+ParseResult Parser::parseType() {
+  switch (curToken.getKind()) {
+  case Token::kw_memref: return parseMemRefType();
+  case Token::kw_tensor: return parseTensorType();
+  case Token::kw_vector: return parseVectorType();
+  case Token::l_paren:   return parseFunctionType();
+  default:  // Non-type tokens are reported by parsePrimitiveType.
+    return parsePrimitiveType();
+  }
+}
+
+/// Parse a "type list", which is a singular type, or a parenthesized list of
+/// types.  A leading '(' always starts a list, never a function type.
+///
+///   type-list ::= type-list-parens | type
+///   type-list-parens ::= `(` `)`
+///                      | `(` type (`,` type)* `)`
+///
+ParseResult Parser::parseTypeList() {
+  // Without a leading '(' this must be a single type.
+  if (!consumeIf(Token::l_paren))
+    return parseType();
+
+  if (parseCommaSeparatedList(Token::r_paren,
+                              [&]() -> ParseResult {
+    // TODO: Add to list of IR values we're parsing.
+    return parseType();
+  })) {
+    return ParseFailure;
+  }
+
+  // TODO: Build IR objects.
+  return ParseSuccess;
+}
+
 
 //===----------------------------------------------------------------------===//
 // Top-level entity parsing.
@@ -119,13 +452,17 @@
 
   if (curToken.isNot(Token::l_paren))
     return emitError("expected '(' in function signature");
-  consumeToken(Token::l_paren);
 
-  // TODO: This should actually parse the full grammar here.
+  if (parseTypeList())
+    return ParseFailure;
 
-  if (curToken.isNot(Token::r_paren))
-    return emitError("expected ')' in function signature");
-  consumeToken(Token::r_paren);
+  // Parse the return type if present.
+  if (consumeIf(Token::arrow)) {
+    if (parseTypeList())
+      return ParseFailure;
+
+    // TODO: Build IR object.
+  }
 
   return ParseSuccess;
 }
diff --git a/lib/Parser/Token.cpp b/lib/Parser/Token.cpp
index 551bd1e..c721cf1 100644
--- a/lib/Parser/Token.cpp
+++ b/lib/Parser/Token.cpp
@@ -35,3 +35,15 @@
 SMRange Token::getLocRange() const {
   return SMRange(getLoc(), getEndLoc());
 }
+#include "llvm/Support/raw_ostream.h"
+
+/// For an integer token (decimal or `0x` hexadecimal), return its value as an
+/// unsigned.  If it doesn't fit, return None.
+Optional<unsigned> Token::getUnsignedIntegerValue() {
+  bool isHex = spelling.size() > 1 && spelling[1] == 'x';
+  // Radix 0 auto-senses `0x`; decimal uses 10 so a leading 0 isn't octal.
+  unsigned result = 0;
+  if (spelling.getAsInteger(isHex ? 0 : 10, result))
+    return None;
+  return result;
+}
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index 03c967e..36b1abd 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -36,16 +36,34 @@
     at_identifier,      // @foo
     // TODO: @@foo, etc.
 
+    integer,            // 42
+
     // Punctuation.
+    arrow,              // ->
+    comma,              // ,
+    question,           // ?
+    questionquestion,   // ??
     l_paren, r_paren,   // ( )
     less, greater,      // < >
     // TODO: More punctuation.
 
     // Keywords.
+    kw_bf16,
     kw_cfgfunc,
     kw_extfunc,
+    kw_f16,
+    kw_f32,
+    kw_f64,
+    kw_i1,
+    kw_i16,
+    kw_i32,
+    kw_i64,
+    kw_i8,
+    kw_int,
+    kw_memref,
     kw_mlfunc,
-    // TODO: More keywords.
+    kw_tensor,
+    kw_vector,
   };
 
   Token(TokenKind kind, StringRef spelling)
@@ -78,8 +96,13 @@
     return !isAny(k1, k2, others...);
   }
 
+  // Helpers to decode specific sorts of tokens.
 
-  /// Location processing.
+  /// For an integer token, return its value as an unsigned.  If it doesn't fit,
+  /// return None.
+  Optional<unsigned> getUnsignedIntegerValue();
+
+  // Location processing.
   llvm::SMLoc getLoc() const;
   llvm::SMLoc getEndLoc() const;
   llvm::SMRange getLocRange() const;