Introduce a Parser::parseToken method to encapsulate a common pattern with
consumeIf+emitError.  NFC.

PiperOrigin-RevId: 205753212
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index fe17b6c..90e3dd3 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -134,6 +134,10 @@
     return true;
   }
 
+  /// Consume the specified token if present and return success.  On failure,
+  /// output a diagnostic and return failure.
+  ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+
   /// Parse a comma-separated list of elements up until the specified end token.
   ParseResult
   parseCommaSeparatedListUntil(Token::Kind rightToken,
@@ -192,6 +196,15 @@
   return ParseFailure;
 }
 
+/// Consume the specified token if present and return success.  On failure,
+/// output a diagnostic and return failure.
+ParseResult Parser::parseToken(Token::Kind expectedToken,
+                               const Twine &message) {
+  if (consumeIf(expectedToken))
+    return ParseSuccess;
+  return emitError(message);
+}
+
 /// Parse a comma separated list of elements that must have at least one entry
 /// in it.
 ParseResult Parser::parseCommaSeparatedList(
@@ -225,14 +238,11 @@
     return ParseSuccess;
   }
 
-  if (parseCommaSeparatedList(parseElement))
+  if (parseCommaSeparatedList(parseElement) ||
+      parseToken(rightToken, "expected ',' or '" +
+                                 Token::getTokenSpelling(rightToken) + "'"))
     return ParseFailure;
 
-  // Consume the end character.
-  if (!consumeIf(rightToken))
-    return emitError("expected ',' or '" + Token::getTokenSpelling(rightToken) +
-                     "'");
-
   return ParseSuccess;
 }
 
@@ -294,8 +304,8 @@
 VectorType *Parser::parseVectorType() {
   consumeToken(Token::kw_vector);
 
-  if (!consumeIf(Token::less))
-    return (emitError("expected '<' in vector type"), nullptr);
+  if (parseToken(Token::less, "expected '<' in vector type"))
+    return nullptr;
 
   if (getToken().isNot(Token::integer))
     return (emitError("expected dimension size in vector type"), nullptr);
@@ -328,8 +338,8 @@
   if (!elementType)
     return nullptr;
 
-  if (!consumeIf(Token::greater))
-    return (emitError("expected '>' in vector type"), nullptr);
+  if (parseToken(Token::greater, "expected '>' in vector type"))
+    return nullptr;
 
   return VectorType::get(dimensions, elementType);
 }
@@ -377,8 +387,8 @@
 Type *Parser::parseTensorType() {
   consumeToken(Token::kw_tensor);
 
-  if (!consumeIf(Token::less))
-    return (emitError("expected '<' in tensor type"), nullptr);
+  if (parseToken(Token::less, "expected '<' in tensor type"))
+    return nullptr;
 
   bool isUnranked;
   SmallVector<int, 4> dimensions;
@@ -396,8 +406,8 @@
   if (!elementType)
     return nullptr;
 
-  if (!consumeIf(Token::greater))
-    return (emitError("expected '>' in tensor type"), nullptr);
+  if (parseToken(Token::greater, "expected '>' in tensor type"))
+    return nullptr;
 
   if (isUnranked)
     return builder.getTensorType(elementType);
@@ -415,8 +425,8 @@
 Type *Parser::parseMemRefType() {
   consumeToken(Token::kw_memref);
 
-  if (!consumeIf(Token::less))
-    return (emitError("expected '<' in memref type"), nullptr);
+  if (parseToken(Token::less, "expected '<' in memref type"))
+    return nullptr;
 
   SmallVector<int, 4> dimensions;
   if (parseDimensionListRanked(dimensions))
@@ -427,8 +437,8 @@
   if (!elementType)
     return nullptr;
 
-  if (!consumeIf(Token::comma))
-    return (emitError("expected ',' in memref type"), nullptr);
+  if (parseToken(Token::comma, "expected ',' in memref type"))
+    return nullptr;
 
   // Parse semi-affine-map-composition.
   SmallVector<AffineMap *, 2> affineMapComposition;
@@ -480,15 +490,10 @@
 Type *Parser::parseFunctionType() {
   assert(getToken().is(Token::l_paren));
 
-  SmallVector<Type *, 4> arguments;
-  if (parseTypeList(arguments))
-    return nullptr;
-
-  if (!consumeIf(Token::arrow))
-    return (emitError("expected '->' in function type"), nullptr);
-
-  SmallVector<Type *, 4> results;
-  if (parseTypeList(results))
+  SmallVector<Type *, 4> arguments, results;
+  if (parseTypeList(arguments) ||
+      parseToken(Token::arrow, "expected '->' in function type") ||
+      parseTypeList(results))
     return nullptr;
 
   return builder.getFunctionType(arguments, results);
@@ -648,8 +653,8 @@
     auto nameId = builder.getIdentifier(getTokenSpelling());
     consumeToken();
 
-    if (!consumeIf(Token::colon))
-      return emitError("expected ':' in attribute list");
+    if (parseToken(Token::colon, "expected ':' in attribute list"))
+      return ParseFailure;
 
     auto attr = parseAttribute();
     if (!attr)
@@ -879,15 +884,17 @@
 ///
 ///   affine-expr ::= `(` affine-expr `)`
 AffineExpr *AffineMapParser::parseParentheticalExpr() {
-  if (!consumeIf(Token::l_paren))
-    return (emitError("expected '('"), nullptr);
+  if (parseToken(Token::l_paren, "expected '('"))
+    return nullptr;
   if (getToken().is(Token::r_paren))
     return (emitError("no expression inside parentheses"), nullptr);
+
   auto *expr = parseAffineExpr();
   if (!expr)
     return nullptr;
-  if (!consumeIf(Token::r_paren))
-    return (emitError("expected ')'"), nullptr);
+  if (parseToken(Token::r_paren, "expected ')'"))
+    return nullptr;
+
   return expr;
 }
 
@@ -895,8 +902,8 @@
 ///
 ///   affine-expr ::= `-` affine-expr
 AffineExpr *AffineMapParser::parseNegateExpression(AffineExpr *lhs) {
-  if (!consumeIf(Token::minus))
-    return (emitError("expected '-'"), nullptr);
+  if (parseToken(Token::minus, "expected '-'"))
+    return nullptr;
 
   AffineExpr *operand = parseAffineOperandExpr(lhs);
   // Since negation has the highest precedence of all ops (including high
@@ -1094,8 +1101,8 @@
 
 /// Parse the list of symbolic identifiers to an affine map.
 ParseResult AffineMapParser::parseSymbolIdList() {
-  if (!consumeIf(Token::l_bracket))
-    return emitError("expected '['");
+  if (parseToken(Token::l_bracket, "expected '['"))
+    return ParseFailure;
 
   auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
   return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
@@ -1103,8 +1110,9 @@
 
 /// Parse the list of dimensional identifiers to an affine map.
 ParseResult AffineMapParser::parseDimIdList() {
-  if (!consumeIf(Token::l_paren))
-    return emitError("expected '(' at start of dimensional identifiers list");
+  if (parseToken(Token::l_paren,
+                 "expected '(' at start of dimensional identifiers list"))
+    return ParseFailure;
 
   auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(true); };
   return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
@@ -1127,13 +1135,10 @@
     if (parseSymbolIdList())
       return nullptr;
   }
-  if (!consumeIf(Token::arrow)) {
-    return (emitError("expected '->' or '['"), nullptr);
-  }
-  if (!consumeIf(Token::l_paren)) {
-    emitError("expected '(' at start of affine map range");
+
+  if (parseToken(Token::arrow, "expected '->' or '['") ||
+      parseToken(Token::l_paren, "expected '(' at start of affine map range"))
     return nullptr;
-  }
 
   SmallVector<AffineExpr *, 4> exprs;
   auto parseElt = [&]() -> ParseResult {
@@ -1158,8 +1163,8 @@
   if (consumeIf(Token::kw_size)) {
     // Location of the l_paren token (if it exists) for error reporting later.
     auto loc = getToken().getLoc();
-    if (!consumeIf(Token::l_paren))
-      return (emitError("expected '(' at start of affine map range"), nullptr);
+    if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
+      return nullptr;
 
     auto parseRangeSize = [&]() -> ParseResult {
       auto *elt = parseAffineExpr();
@@ -1389,8 +1394,8 @@
   result.name = getTokenSpelling();
   result.number = 0;
   result.loc = getToken().getLoc();
-  if (!consumeIf(Token::percent_identifier))
-    return emitError("expected SSA operand");
+  if (parseToken(Token::percent_identifier, "expected SSA operand"))
+    return ParseFailure;
 
   // If we have an affine map ID, it is a result number.
   if (getToken().is(Token::hash_identifier)) {
@@ -1428,12 +1433,11 @@
 template <typename ResultType>
 ResultType FunctionParser::parseSSADefOrUseAndType(
     const std::function<ResultType(SSAUseInfo, Type *)> &action) {
-  SSAUseInfo useInfo;
-  if (parseSSAUse(useInfo))
-    return nullptr;
 
-  if (!consumeIf(Token::colon))
-    return (emitError("expected ':' and type for SSA operand"), nullptr);
+  SSAUseInfo useInfo;
+  if (parseSSAUse(useInfo) ||
+      parseToken(Token::colon, "expected ':' and type for SSA operand"))
+    return nullptr;
 
   auto *type = parseType();
   if (!type)
@@ -1472,11 +1476,9 @@
   if (valueIDs.empty())
     return ParseSuccess;
 
-  if (!consumeIf(Token::colon))
-    return emitError("expected ':' in operand list");
-
   SmallVector<Type *, 4> types;
-  if (parseTypeListNoParens(types))
+  if (parseToken(Token::colon, "expected ':' in operand list") ||
+      parseTypeListNoParens(types))
     return ParseFailure;
 
   if (valueIDs.size() != types.size())
@@ -1508,8 +1510,8 @@
   if (getToken().is(Token::percent_identifier)) {
     resultID = getTokenSpelling();
     consumeToken(Token::percent_identifier);
-    if (!consumeIf(Token::equal))
-      return emitError("expected '=' after SSA name");
+    if (parseToken(Token::equal, "expected '=' after SSA name"))
+      return ParseFailure;
   }
 
   if (getToken().isNot(Token::string))
@@ -1521,16 +1523,14 @@
 
   consumeToken(Token::string);
 
-  if (!consumeIf(Token::l_paren))
-    return emitError("expected '(' to start operand list");
-
   // Parse the operand list.
   SmallVector<SSAUseInfo, 8> operandInfos;
-  if (parseOptionalSSAUseList(operandInfos))
-    return ParseFailure;
 
-  if (!consumeIf(Token::r_paren))
-    return emitError("expected ')' to end operand list");
+  if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+      parseOptionalSSAUseList(operandInfos) ||
+      parseToken(Token::r_paren, "expected ')' to end operand list")) {
+    return ParseFailure;
+  }
 
   SmallVector<NamedAttribute, 4> attributes;
   if (getToken().is(Token::l_brace)) {
@@ -1538,8 +1538,8 @@
       return ParseFailure;
   }
 
-  if (!consumeIf(Token::colon))
-    return emitError("expected ':' followed by instruction type");
+  if (parseToken(Token::colon, "expected ':' followed by instruction type"))
+    return ParseFailure;
 
   auto typeLoc = getToken().getLoc();
   auto type = parseType();
@@ -1663,8 +1663,8 @@
 
 ParseResult CFGFunctionParser::parseFunctionBody() {
   auto braceLoc = getToken().getLoc();
-  if (!consumeIf(Token::l_brace))
-    return emitError("expected '{' in CFG function");
+  if (parseToken(Token::l_brace, "expected '{' in CFG function"))
+    return ParseFailure;
 
   // Make sure we have at least one block.
   if (getToken().is(Token::r_brace))
@@ -1700,8 +1700,8 @@
 ParseResult CFGFunctionParser::parseBasicBlock() {
   SMLoc nameLoc = getToken().getLoc();
   auto name = getTokenSpelling();
-  if (!consumeIf(Token::bare_identifier))
-    return emitError("expected basic block name");
+  if (parseToken(Token::bare_identifier, "expected basic block name"))
+    return ParseFailure;
 
   auto *block = getBlockNamed(name, nameLoc);
 
@@ -1713,17 +1713,16 @@
   // If an argument list is present, parse it.
   if (consumeIf(Token::l_paren)) {
     SmallVector<BBArgument *, 8> bbArgs;
-    if (parseOptionalBasicBlockArgList(bbArgs, block))
+    if (parseOptionalBasicBlockArgList(bbArgs, block) ||
+        parseToken(Token::r_paren, "expected ')' to end argument list"))
       return ParseFailure;
-    if (!consumeIf(Token::r_paren))
-      return emitError("expected ')' to end argument list");
   }
 
   // Add the block to the function.
   function->push_back(block);
 
-  if (!consumeIf(Token::colon))
-    return emitError("expected ':' after basic block name");
+  if (parseToken(Token::colon, "expected ':' after basic block name"))
+    return ParseFailure;
 
   // Set the insertion point to the block we want to insert new operations into.
   builder.setInsertionPoint(block);
@@ -1776,8 +1775,9 @@
   case Token::kw_br: {
     consumeToken(Token::kw_br);
     auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
-    if (!consumeIf(Token::bare_identifier))
-      return (emitError("expected basic block name"), nullptr);
+    if (parseToken(Token::bare_identifier, "expected basic block name"))
+      return nullptr;
+
     auto branch = builder.createBranchInst(destBB);
 
     SmallVector<CFGValue *, 8> operands;
@@ -1821,23 +1821,21 @@
 
 ParseResult MLFunctionParser::parseFunctionBody() {
   auto braceLoc = getToken().getLoc();
-  if (!consumeIf(Token::l_brace))
-    return emitError("expected '{' in ML function");
-
   // Parse statements in this function
-  if (parseStatements(function))
-    return ParseFailure;
 
-  if (!consumeIf(Token::kw_return))
-    emitError("ML function must end with return statement");
+  if (parseToken(Token::l_brace, "expected '{' in ML function") ||
+      parseStatements(function)) {
+    return ParseFailure;
+  }
 
   // TODO: store return operands in the IR.
   SmallVector<SSAUseInfo, 4> dummyUseInfo;
-  if (parseOptionalSSAUseList(dummyUseInfo))
-    return ParseFailure;
 
-  if (!consumeIf(Token::r_brace))
-    return emitError("expected '}' to end mlfunc");
+  if (parseToken(Token::kw_return,
+                 "ML function must end with return statement") ||
+      parseOptionalSSAUseList(dummyUseInfo) ||
+      parseToken(Token::r_brace, "expected '}' to end mlfunc"))
+    return ParseFailure;
 
   getModule()->functionList.push_back(function);
 
@@ -1862,16 +1860,16 @@
 
   consumeToken(Token::percent_identifier);
 
-  if (!consumeIf(Token::equal))
-    return emitError("expected =");
+  if (parseToken(Token::equal, "expected ="))
+    return ParseFailure;
 
   // Parse loop bounds
   AffineConstantExpr *lowerBound = parseIntConstant();
   if (!lowerBound)
     return ParseFailure;
 
-  if (!consumeIf(Token::kw_to))
-    return emitError("expected 'to' between bounds");
+  if (parseToken(Token::kw_to, "expected 'to' between bounds"))
+    return ParseFailure;
 
   AffineConstantExpr *upperBound = parseIntConstant();
   if (!upperBound)
@@ -1921,13 +1919,13 @@
 ///
 ParseResult MLFunctionParser::parseIfStmt() {
   consumeToken(Token::kw_if);
-  if (!consumeIf(Token::l_paren))
-    return emitError("expected (");
+  if (parseToken(Token::l_paren, "expected ("))
+    return ParseFailure;
 
   // TODO: parse condition
 
-  if (!consumeIf(Token::r_paren))
-    return emitError("expected ')'");
+  if (parseToken(Token::r_paren, "expected )"))
+    return ParseFailure;
 
   IfStmt *ifStmt = builder.createIf();
   IfClause *thenClause = ifStmt->getThenClause();
@@ -1939,7 +1937,7 @@
     return ParseFailure;
 
   if (consumeIf(Token::kw_else)) {
-    IfClause *elseClause = ifStmt->createElseClause();
+    auto *elseClause = ifStmt->createElseClause();
     if (parseElseClause(elseClause))
       return ParseFailure;
   }
@@ -1992,15 +1990,12 @@
 /// Parse `{` ml-stmt* `}`
 ///
 ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
-  if (!consumeIf(Token::l_brace))
-    return emitError("expected '{' before statement list");
-
-  if (parseStatements(block))
+  if (parseToken(Token::l_brace, "expected '{' before statement list") ||
+      parseStatements(block) ||
+      parseToken(Token::r_brace,
+                 "expected '}' at the end of the statement block"))
     return ParseFailure;
 
-  if (!consumeIf(Token::r_brace))
-    return emitError("expected '}' at the end of the statement block");
-
   return ParseSuccess;
 }
 
@@ -2048,8 +2043,9 @@
   consumeToken(Token::hash_identifier);
 
   // Parse the '='
-  if (!consumeIf(Token::equal))
-    return emitError("expected '=' in affine map outlined definition");
+  if (parseToken(Token::equal,
+                 "expected '=' in affine map outlined definition"))
+    return ParseFailure;
 
   entry = parseAffineMapInline();
   if (!entry)
@@ -2066,6 +2062,8 @@
 ParseResult
 ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
                                   SmallVectorImpl<StringRef> &argNames) {
+  consumeToken(Token::l_paren);
+
   auto parseElt = [&]() -> ParseResult {
     // Parse argument name
     if (getToken().isNot(Token::percent_identifier))
@@ -2075,8 +2073,8 @@
     consumeToken(Token::percent_identifier);
     argNames.push_back(name);
 
-    if (!consumeIf(Token::colon))
-      return emitError("expected ':'");
+    if (parseToken(Token::colon, "expected ':'"))
+      return ParseFailure;
 
     // Parse argument type
     auto elt = parseType();
@@ -2087,9 +2085,6 @@
     return ParseSuccess;
   };
 
-  if (!consumeIf(Token::l_paren))
-    llvm_unreachable("expected '('");
-
   return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
 }