Implement the return statement as a RetOp operation. Add verification of return statement placement and operands. Add parser tests and parsing-error tests for return statements with a non-zero number of operands. Add a few missing tests for ForStmt parsing errors.
Prior to this CL, the return statement had no explicit representation in MLIR. Now it is represented as a ReturnOp standard operation and is pretty-printed according to the return statement syntax. This way, statement walkers can process ML function return operands without special-casing them.
PiperOrigin-RevId: 208092424
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 45ac4e6..f263804 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2117,20 +2117,9 @@
ParseResult MLFunctionParser::parseFunctionBody() {
auto braceLoc = getToken().getLoc();
- // Parse statements in this function
- if (parseToken(Token::l_brace, "expected '{' in ML function") ||
- parseStatements(function)) {
- return ParseFailure;
- }
-
- // TODO: store return operands in the IR.
- SmallVector<SSAUseInfo, 4> dummyUseInfo;
-
- if (parseToken(Token::kw_return,
- "ML function must end with return statement") ||
- parseOptionalSSAUseList(dummyUseInfo) ||
- parseToken(Token::r_brace, "expected '}' to end mlfunc"))
+ // Parse statements in this function.
+ if (parseStmtBlock(function))
return ParseFailure;
getModule()->getFunctions().push_back(function);
@@ -2154,7 +2143,7 @@
StringRef inductionVariableName = getTokenSpelling();
consumeToken(Token::percent_identifier);
- if (parseToken(Token::equal, "expected ="))
+ if (parseToken(Token::equal, "expected '='"))
return ParseFailure;
// Parse loop bounds
@@ -2387,7 +2376,10 @@
builder.setInsertionPointToEnd(block);
- while (getToken().isNot(Token::kw_return, Token::r_brace)) {
+ // Parse statements till we see '}' or 'return'.
+ // Return statement is parsed separately to emit a more intuitive error
+ // when '}' is missing after the return statement.
+ while (getToken().isNot(Token::r_brace, Token::kw_return)) {
switch (getToken().getKind()) {
default:
if (parseOperation(createOpFunc))
@@ -2404,6 +2396,11 @@
} // end switch
}
+ // Parse the return statement.
+ if (getToken().is(Token::kw_return))
+ if (parseOperation(createOpFunc))
+ return ParseFailure;
+
return ParseSuccess;
}
@@ -2413,8 +2410,7 @@
ParseResult MLFunctionParser::parseStmtBlock(StmtBlock *block) {
if (parseToken(Token::l_brace, "expected '{' before statement list") ||
parseStatements(block) ||
- parseToken(Token::r_brace,
- "expected '}' at the end of the statement block"))
+ parseToken(Token::r_brace, "expected '}' after statement list"))
return ParseFailure;
return ParseSuccess;