Introduce a Parser::parseToken method to encapsulate a common pattern with
consumeIf+emitError. NFC.
PiperOrigin-RevId: 205753212
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index fe17b6c..90e3dd3 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -134,6 +134,10 @@
return true;
}
+ /// Consume the specified token if present and return success. On failure,
+ /// output a diagnostic and return failure.
+ ParseResult parseToken(Token::Kind expectedToken, const Twine &message);
+
/// Parse a comma-separated list of elements up until the specified end token.
ParseResult
parseCommaSeparatedListUntil(Token::Kind rightToken,
@@ -192,6 +196,15 @@
return ParseFailure;
}
+/// Consume the specified token if present and return success. On failure,
+/// output a diagnostic and return failure.
+ParseResult Parser::parseToken(Token::Kind expectedToken,
+ const Twine &message) {
+ if (consumeIf(expectedToken))
+ return ParseSuccess;
+ return emitError(message);
+}
+
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
ParseResult Parser::parseCommaSeparatedList(
@@ -225,14 +238,11 @@
return ParseSuccess;
}
- if (parseCommaSeparatedList(parseElement))
+ if (parseCommaSeparatedList(parseElement) ||
+ parseToken(rightToken, "expected ',' or '" +
+ Token::getTokenSpelling(rightToken) + "'"))
return ParseFailure;
- // Consume the end character.
- if (!consumeIf(rightToken))
- return emitError("expected ',' or '" + Token::getTokenSpelling(rightToken) +
- "'");
-
return ParseSuccess;
}
@@ -294,8 +304,8 @@
VectorType *Parser::parseVectorType() {
consumeToken(Token::kw_vector);
- if (!consumeIf(Token::less))
- return (emitError("expected '<' in vector type"), nullptr);
+ if (parseToken(Token::less, "expected '<' in vector type"))
+ return nullptr;
if (getToken().isNot(Token::integer))
return (emitError("expected dimension size in vector type"), nullptr);
@@ -328,8 +338,8 @@
if (!elementType)
return nullptr;
- if (!consumeIf(Token::greater))
- return (emitError("expected '>' in vector type"), nullptr);
+ if (parseToken(Token::greater, "expected '>' in vector type"))
+ return nullptr;
return VectorType::get(dimensions, elementType);
}
@@ -377,8 +387,8 @@
Type *Parser::parseTensorType() {
consumeToken(Token::kw_tensor);
- if (!consumeIf(Token::less))
- return (emitError("expected '<' in tensor type"), nullptr);
+ if (parseToken(Token::less, "expected '<' in tensor type"))
+ return nullptr;
bool isUnranked;
SmallVector<int, 4> dimensions;
@@ -396,8 +406,8 @@
if (!elementType)
return nullptr;
- if (!consumeIf(Token::greater))
- return (emitError("expected '>' in tensor type"), nullptr);
+ if (parseToken(Token::greater, "expected '>' in tensor type"))
+ return nullptr;
if (isUnranked)
return builder.getTensorType(elementType);
@@ -415,8 +425,8 @@
Type *Parser::parseMemRefType() {
consumeToken(Token::kw_memref);
- if (!consumeIf(Token::less))
- return (emitError("expected '<' in memref type"), nullptr);
+ if (parseToken(Token::less, "expected '<' in memref type"))
+ return nullptr;
SmallVector<int, 4> dimensions;
if (parseDimensionListRanked(dimensions))
@@ -427,8 +437,8 @@
if (!elementType)
return nullptr;
- if (!consumeIf(Token::comma))
- return (emitError("expected ',' in memref type"), nullptr);
+ if (parseToken(Token::comma, "expected ',' in memref type"))
+ return nullptr;
// Parse semi-affine-map-composition.
SmallVector<AffineMap *, 2> affineMapComposition;
@@ -480,15 +490,10 @@
Type *Parser::parseFunctionType() {
assert(getToken().is(Token::l_paren));
- SmallVector<Type *, 4> arguments;
- if (parseTypeList(arguments))
- return nullptr;
-
- if (!consumeIf(Token::arrow))
- return (emitError("expected '->' in function type"), nullptr);
-
- SmallVector<Type *, 4> results;
- if (parseTypeList(results))
+ SmallVector<Type *, 4> arguments, results;
+ if (parseTypeList(arguments) ||
+ parseToken(Token::arrow, "expected '->' in function type") ||
+ parseTypeList(results))
return nullptr;
return builder.getFunctionType(arguments, results);
@@ -648,8 +653,8 @@
auto nameId = builder.getIdentifier(getTokenSpelling());
consumeToken();
- if (!consumeIf(Token::colon))
- return emitError("expected ':' in attribute list");
+ if (parseToken(Token::colon, "expected ':' in attribute list"))
+ return ParseFailure;
auto attr = parseAttribute();
if (!attr)
@@ -879,15 +884,17 @@
///
/// affine-expr ::= `(` affine-expr `)`
AffineExpr *AffineMapParser::parseParentheticalExpr() {
- if (!consumeIf(Token::l_paren))
- return (emitError("expected '('"), nullptr);
+ if (parseToken(Token::l_paren, "expected '('"))
+ return nullptr;
if (getToken().is(Token::r_paren))
return (emitError("no expression inside parentheses"), nullptr);
+
auto *expr = parseAffineExpr();
if (!expr)
return nullptr;
- if (!consumeIf(Token::r_paren))
- return (emitError("expected ')'"), nullptr);
+ if (parseToken(Token::r_paren, "expected ')'"))
+ return nullptr;
+
return expr;
}
@@ -895,8 +902,8 @@
///
/// affine-expr ::= `-` affine-expr
AffineExpr *AffineMapParser::parseNegateExpression(AffineExpr *lhs) {
- if (!consumeIf(Token::minus))
- return (emitError("expected '-'"), nullptr);
+ if (parseToken(Token::minus, "expected '-'"))
+ return nullptr;
AffineExpr *operand = parseAffineOperandExpr(lhs);
// Since negation has the highest precedence of all ops (including high
@@ -1094,8 +1101,8 @@
/// Parse the list of symbolic identifiers to an affine map.
ParseResult AffineMapParser::parseSymbolIdList() {
- if (!consumeIf(Token::l_bracket))
- return emitError("expected '['");
+ if (parseToken(Token::l_bracket, "expected '['"))
+ return ParseFailure;
auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(false); };
return parseCommaSeparatedListUntil(Token::r_bracket, parseElt);
@@ -1103,8 +1110,9 @@
/// Parse the list of dimensional identifiers to an affine map.
ParseResult AffineMapParser::parseDimIdList() {
- if (!consumeIf(Token::l_paren))
- return emitError("expected '(' at start of dimensional identifiers list");
+ if (parseToken(Token::l_paren,
+ "expected '(' at start of dimensional identifiers list"))
+ return ParseFailure;
auto parseElt = [&]() -> ParseResult { return parseDimOrSymbolId(true); };
return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
@@ -1127,13 +1135,10 @@
if (parseSymbolIdList())
return nullptr;
}
- if (!consumeIf(Token::arrow)) {
- return (emitError("expected '->' or '['"), nullptr);
- }
- if (!consumeIf(Token::l_paren)) {
- emitError("expected '(' at start of affine map range");
+
+ if (parseToken(Token::arrow, "expected '->' or '['") ||
+ parseToken(Token::l_paren, "expected '(' at start of affine map range"))
return nullptr;
- }
SmallVector<AffineExpr *, 4> exprs;
auto parseElt = [&]() -> ParseResult {
@@ -1158,8 +1163,8 @@
if (consumeIf(Token::kw_size)) {
// Location of the l_paren token (if it exists) for error reporting later.
auto loc = getToken().getLoc();
- if (!consumeIf(Token::l_paren))
- return (emitError("expected '(' at start of affine map range"), nullptr);
+ if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
+ return nullptr;
auto parseRangeSize = [&]() -> ParseResult {
auto *elt = parseAffineExpr();
@@ -1389,8 +1394,8 @@
result.name = getTokenSpelling();
result.number = 0;
result.loc = getToken().getLoc();
- if (!consumeIf(Token::percent_identifier))
- return emitError("expected SSA operand");
+ if (parseToken(Token::percent_identifier, "expected SSA operand"))
+ return ParseFailure;
// If we have an affine map ID, it is a result number.
if (getToken().is(Token::hash_identifier)) {
@@ -1428,12 +1433,11 @@
template <typename ResultType>
ResultType FunctionParser::parseSSADefOrUseAndType(
const std::function<ResultType(SSAUseInfo, Type *)> &action) {
- SSAUseInfo useInfo;
- if (parseSSAUse(useInfo))
- return nullptr;
- if (!consumeIf(Token::colon))
- return (emitError("expected ':' and type for SSA operand"), nullptr);
+ SSAUseInfo useInfo;
+ if (parseSSAUse(useInfo) ||
+ parseToken(Token::colon, "expected ':' and type for SSA operand"))
+ return nullptr;
auto *type = parseType();
if (!type)
@@ -1472,11 +1476,9 @@
if (valueIDs.empty())
return ParseSuccess;
- if (!consumeIf(Token::colon))
- return emitError("expected ':' in operand list");
-
SmallVector<Type *, 4> types;
- if (parseTypeListNoParens(types))
+ if (parseToken(Token::colon, "expected ':' in operand list") ||
+ parseTypeListNoParens(types))
return ParseFailure;
if (valueIDs.size() != types.size())
@@ -1508,8 +1510,8 @@
if (getToken().is(Token::percent_identifier)) {
resultID = getTokenSpelling();
consumeToken(Token::percent_identifier);
- if (!consumeIf(Token::equal))
- return emitError("expected '=' after SSA name");
+ if (parseToken(Token::equal, "expected '=' after SSA name"))
+ return ParseFailure;
}
if (getToken().isNot(Token::string))
@@ -1521,16 +1523,14 @@
consumeToken(Token::string);
- if (!consumeIf(Token::l_paren))
- return emitError("expected '(' to start operand list");
-
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
- if (parseOptionalSSAUseList(operandInfos))
- return ParseFailure;
- if (!consumeIf(Token::r_paren))
- return emitError("expected ')' to end operand list");
+ if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
+ parseOptionalSSAUseList(operandInfos) ||
+ parseToken(Token::r_paren, "expected ')' to end operand list")) {
+ return ParseFailure;
+ }
SmallVector<NamedAttribute, 4> attributes;
if (getToken().is(Token::l_brace)) {
@@ -1538,8 +1538,8 @@
return ParseFailure;
}
- if (!consumeIf(Token::colon))
- return emitError("expected ':' followed by instruction type");
+ if (parseToken(Token::colon, "expected ':' followed by instruction type"))
+ return ParseFailure;
auto typeLoc = getToken().getLoc();
auto type = parseType();
@@ -1663,8 +1663,8 @@
ParseResult CFGFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
- if (!consumeIf(Token::l_brace))
- return emitError("expected '{' in CFG function");
+ if (parseToken(Token::l_brace, "expected '{' in CFG function"))
+ return ParseFailure;
// Make sure we have at least one block.
if (getToken().is(Token::r_brace))
@@ -1700,8 +1700,8 @@
ParseResult CFGFunctionParser::parseBasicBlock() {
SMLoc nameLoc = getToken().getLoc();
auto name = getTokenSpelling();
- if (!consumeIf(Token::bare_identifier))
- return emitError("expected basic block name");
+ if (parseToken(Token::bare_identifier, "expected basic block name"))
+ return ParseFailure;
auto *block = getBlockNamed(name, nameLoc);
@@ -1713,17 +1713,16 @@
// If an argument list is present, parse it.
if (consumeIf(Token::l_paren)) {
SmallVector<BBArgument *, 8> bbArgs;
- if (parseOptionalBasicBlockArgList(bbArgs, block))
+ if (parseOptionalBasicBlockArgList(bbArgs, block) ||
+ parseToken(Token::r_paren, "expected ')' to end argument list"))
return ParseFailure;
- if (!consumeIf(Token::r_paren))
- return emitError("expected ')' to end argument list");
}
// Add the block to the function.
function->push_back(block);
- if (!consumeIf(Token::colon))
- return emitError("expected ':' after basic block name");
+ if (parseToken(Token::colon, "expected ':' after basic block name"))
+ return ParseFailure;
// Set the insertion point to the block we want to insert new operations into.
builder.setInsertionPoint(block);
@@ -1776,8 +1775,9 @@
case Token::kw_br: {
consumeToken(Token::kw_br);
auto destBB = getBlockNamed(getTokenSpelling(), getToken().getLoc());
- if (!consumeIf(Token::bare_identifier))
- return (emitError("expected basic block name"), nullptr);
+ if (parseToken(Token::bare_identifier, "expected basic block name"))
+ return nullptr;
+
auto branch = builder.createBranchInst(destBB);
SmallVector<CFGValue *, 8> operands;
@@ -1821,23 +1821,21 @@
ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
- if (!consumeIf(Token::l_brace))
- return emitError("expected '{' in ML function");
-
// Parse statements in this function
- if (parseStatements(function))
- return ParseFailure;
- if (!consumeIf(Token::kw_return))
- emitError("ML function must end with return statement");
+ if (parseToken(Token::l_brace, "expected '{' in ML function") ||
+ parseStatements(function)) {
+ return ParseFailure;
+ }
// TODO: store return operands in the IR.
SmallVector<SSAUseInfo, 4> dummyUseInfo;
- if (parseOptionalSSAUseList(dummyUseInfo))
- return ParseFailure;
- if (!consumeIf(Token::r_brace))
- return emitError("expected '}' to end mlfunc");
+ if (parseToken(Token::kw_return,
+ "ML function must end with return statement") ||
+ parseOptionalSSAUseList(dummyUseInfo) ||
+ parseToken(Token::r_brace, "expected '}' to end mlfunc"))
+ return ParseFailure;
getModule()->functionList.push_back(function);
@@ -1862,16 +1860,16 @@
consumeToken(Token::percent_identifier);
- if (!consumeIf(Token::equal))
- return emitError("expected =");
+ if (parseToken(Token::equal, "expected ="))
+ return ParseFailure;
// Parse loop bounds
AffineConstantExpr *lowerBound = parseIntConstant();
if (!lowerBound)
return ParseFailure;
- if (!consumeIf(Token::kw_to))
- return emitError("expected 'to' between bounds");
+ if (parseToken(Token::kw_to, "expected 'to' between bounds"))
+ return ParseFailure;
AffineConstantExpr *upperBound = parseIntConstant();
if (!upperBound)
@@ -1921,13 +1919,13 @@
///
ParseResult MLFunctionParser::parseIfStmt() {
consumeToken(Token::kw_if);
- if (!consumeIf(Token::l_paren))
- return emitError("expected (");
+ if (parseToken(Token::l_paren, "expected ("))
+ return ParseFailure;
// TODO: parse condition
- if (!consumeIf(Token::r_paren))
- return emitError("expected ')'");
+ if (parseToken(Token::r_paren, "expected )"))
+ return ParseFailure;
IfStmt *ifStmt = builder.createIf();
IfClause *thenClause = ifStmt->getThenClause();
@@ -1939,7 +1937,7 @@
return ParseFailure;
if (consumeIf(Token::kw_else)) {
- IfClause *elseClause = ifStmt->createElseClause();
+ auto *elseClause = ifStmt->createElseClause();
if (parseElseClause(elseClause))
return ParseFailure;
}
@@ -1992,15 +1990,12 @@
/// Parse `{` ml-stmt* `}`
///
ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
- if (!consumeIf(Token::l_brace))
- return emitError("expected '{' before statement list");
-
- if (parseStatements(block))
+ if (parseToken(Token::l_brace, "expected '{' before statement list") ||
+ parseStatements(block) ||
+ parseToken(Token::r_brace,
+ "expected '}' at the end of the statement block"))
return ParseFailure;
- if (!consumeIf(Token::r_brace))
- return emitError("expected '}' at the end of the statement block");
-
return ParseSuccess;
}
@@ -2048,8 +2043,9 @@
consumeToken(Token::hash_identifier);
// Parse the '='
- if (!consumeIf(Token::equal))
- return emitError("expected '=' in affine map outlined definition");
+ if (parseToken(Token::equal,
+ "expected '=' in affine map outlined definition"))
+ return ParseFailure;
entry = parseAffineMapInline();
if (!entry)
@@ -2066,6 +2062,8 @@
ParseResult
ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
SmallVectorImpl<StringRef> &argNames) {
+ consumeToken(Token::l_paren);
+
auto parseElt = [&]() -> ParseResult {
// Parse argument name
if (getToken().isNot(Token::percent_identifier))
@@ -2075,8 +2073,8 @@
consumeToken(Token::percent_identifier);
argNames.push_back(name);
- if (!consumeIf(Token::colon))
- return emitError("expected ':'");
+ if (parseToken(Token::colon, "expected ':'"))
+ return ParseFailure;
// Parse argument type
auto elt = parseType();
@@ -2087,9 +2085,6 @@
return ParseSuccess;
};
- if (!consumeIf(Token::l_paren))
- llvm_unreachable("expected '('");
-
return parseCommaSeparatedListUntil(Token::r_paren, parseElt);
}