Parse operations in ML functions. Add builder class for ML functions.
Refactors operation parsing to share functionality between CFG and ML functions. ML function construction now goes through a builder, similar to the way it is done for
CFG functions.
PiperOrigin-RevId: 204779279
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 626b92f..4b9133c 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -84,6 +84,9 @@
namespace {
+typedef std::function<Operation *(Identifier, ArrayRef<NamedAttribute>)>
+ CreateOperationFunction;
+
/// This class implement support for parsing global entities like types and
/// shared entities like SSA names. It is intended to be subclassed by
/// specialized subparsers that include state, e.g. when a local symbol table.
@@ -167,6 +170,9 @@
ParseResult parseSSAUseAndType();
ParseResult parseOptionalSSAUseAndTypeList(Token::Kind endToken);
+ // Operations
+ ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
+
private:
// The Parser is subclassed and reinstantiated. Do not add additional
// non-trivial state here, add it to the ParserState class.
@@ -1217,6 +1223,70 @@
endToken, [&]() -> ParseResult { return parseSSAUseAndType(); });
}
+//===----------------------------------------------------------------------===//
+// Operations
+//===----------------------------------------------------------------------===//
+
+/// Parse the CFG or MLFunc operation.
+///
+/// TODO(clattner): This is a change from the MLIR spec as written, it is an
+/// experiment that will eliminate "builtin" instructions as a thing.
+///
+/// operation ::=
+/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
+/// `:` function-type
+///
+ParseResult
+Parser::parseOperation(const CreateOperationFunction &createOpFunc) {
+ auto loc = getToken().getLoc();
+
+ StringRef resultID;
+ if (getToken().is(Token::percent_identifier)) {
+ resultID = getTokenSpelling().drop_front();
+ consumeToken(Token::percent_identifier);
+ if (!consumeIf(Token::equal))
+ return emitError("expected '=' after SSA name");
+ }
+
+ if (getToken().isNot(Token::string))
+ return emitError("expected operation name in quotes");
+
+ auto name = getToken().getStringValue();
+ if (name.empty())
+ return emitError("empty operation name is invalid");
+
+ consumeToken(Token::string);
+
+ if (!consumeIf(Token::l_paren))
+ return emitError("expected '(' to start operand list");
+
+ // Parse the operand list.
+ parseOptionalSSAUseList(Token::r_paren);
+
+ SmallVector<NamedAttribute, 4> attributes;
+ if (getToken().is(Token::l_brace)) {
+ if (parseAttributeDict(attributes))
+ return ParseFailure;
+ }
+
+ // TODO: Don't drop result name and operand names on the floor.
+ auto nameId = builder.getIdentifier(name);
+
+ auto oper = createOpFunc(nameId, attributes);
+
+ if (!oper)
+ return ParseFailure;
+
+ // We just parsed an operation. If it is a recognized one, verify that it
+ // is structurally as we expect. If not, produce an error with a reasonable
+ // source location.
+ if (auto *opInfo = oper->getAbstractOperation(builder.getContext())) {
+ if (auto error = opInfo->verifyInvariants(oper))
+ return emitError(loc, error);
+ }
+
+ return ParseSuccess;
+}
//===----------------------------------------------------------------------===//
// CFG Functions
@@ -1322,72 +1392,23 @@
// Set the insertion point to the block we want to insert new operations into.
builder.setInsertionPoint(block);
+ auto createOpFunc = [this](Identifier name,
+ ArrayRef<NamedAttribute> attrs) -> Operation * {
+ return builder.createOperation(name, attrs);
+ };
+
// Parse the list of operations that make up the body of the block.
while (getToken().isNot(Token::kw_return, Token::kw_br)) {
- auto loc = getToken().getLoc();
- auto *inst = parseCFGOperation();
- if (!inst)
+ if (parseOperation(createOpFunc))
return ParseFailure;
-
- // We just parsed an operation. If it is a recognized one, verify that it
- // is structurally as we expect. If not, produce an error with a reasonable
- // source location.
- if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
- if (auto error = opInfo->verifyInvariants(inst))
- return emitError(loc, error);
}
- auto *term = parseTerminator();
- if (!term)
+ if (!parseTerminator())
return ParseFailure;
return ParseSuccess;
}
-/// Parse the CFG operation.
-///
-/// TODO(clattner): This is a change from the MLIR spec as written, it is an
-/// experiment that will eliminate "builtin" instructions as a thing.
-///
-/// cfg-operation ::=
-/// (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
-/// `:` function-type
-///
-OperationInst *CFGFunctionParser::parseCFGOperation() {
- StringRef resultID;
- if (getToken().is(Token::percent_identifier)) {
- resultID = getTokenSpelling().drop_front();
- consumeToken();
- if (!consumeIf(Token::equal))
- return (emitError("expected '=' after SSA name"), nullptr);
- }
-
- if (getToken().isNot(Token::string))
- return (emitError("expected operation name in quotes"), nullptr);
-
- auto name = getToken().getStringValue();
- if (name.empty())
- return (emitError("empty operation name is invalid"), nullptr);
-
- consumeToken(Token::string);
-
- if (!consumeIf(Token::l_paren))
- return (emitError("expected '(' to start operand list"), nullptr);
-
- // Parse the operand list.
- parseOptionalSSAUseList(Token::r_paren);
-
- SmallVector<NamedAttribute, 4> attributes;
- if (getToken().is(Token::l_brace)) {
- if (parseAttributeDict(attributes))
- return nullptr;
- }
-
- // TODO: Don't drop result name and operand names on the floor.
- auto nameId = builder.getIdentifier(name);
- return builder.createOperation(nameId, attributes);
-}
-
/// Parse the terminator instruction for a basic block.
///
/// terminator-stmt ::= `br` bb-id branch-use-list?
@@ -1424,22 +1445,22 @@
/// Refined parser for MLFunction bodies.
class MLFunctionParser : public Parser {
public:
- MLFunction *function;
-
- /// This builder intentionally shadows the builder in the base class, with a
- /// more specific builder type.
- // TODO: MLFuncBuilder builder;
-
MLFunctionParser(ParserState &state, MLFunction *function)
- : Parser(state), function(function) {}
+ : Parser(state), function(function), builder(function) {}
ParseResult parseFunctionBody();
private:
- Statement *parseStatement();
- ForStmt *parseForStmt();
- IfStmt *parseIfStmt();
+ MLFunction *function;
+
+ /// This builder intentionally shadows the builder in the base class, with a
+ /// more specific builder type.
+ MLFuncBuilder builder;
+
+ ParseResult parseForStmt();
+ ParseResult parseIfStmt();
ParseResult parseElseClause(IfClause *elseClause);
+ ParseResult parseStatements(StmtBlock *block);
ParseResult parseStmtBlock(StmtBlock *block);
};
} // end anonymous namespace
@@ -1448,19 +1469,14 @@
if (!consumeIf(Token::l_brace))
return emitError("expected '{' in ML function");
- // Make sure we have at least one statement.
- if (getToken().is(Token::r_brace))
- return emitError("ML function must end with return statement");
+ // Parse statements in this function
+ if (parseStatements(function))
+ return ParseFailure;
- // Parse the list of instructions.
- while (!consumeIf(Token::kw_return)) {
- auto *stmt = parseStatement();
- if (!stmt)
- return ParseFailure;
- function->push_back(stmt);
- }
-
+ if (!consumeIf(Token::kw_return))
+ emitError("ML function must end with return statement");
// TODO: parse return statement operands
+
if (!consumeIf(Token::r_brace))
emitError("expected '}' in ML function");
@@ -1469,42 +1485,23 @@
return ParseSuccess;
}
-/// Statement.
-///
-/// ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
-///
-/// TODO: fix terminology in MLSpec document. ML functions
-/// contain operation statements, not instructions.
-///
-Statement *MLFunctionParser::parseStatement() {
- switch (getToken().getKind()) {
- default:
- //TODO: parse OperationStmt
- return (emitError("expected statement"), nullptr);
-
- case Token::kw_for:
- return parseForStmt();
-
- case Token::kw_if:
- return parseIfStmt();
- }
-}
-
/// For statement.
///
/// ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
/// (`step` integer-literal)? `{` ml-stmt* `}`
///
-ForStmt *MLFunctionParser::parseForStmt() {
+ParseResult MLFunctionParser::parseForStmt() {
consumeToken(Token::kw_for);
//TODO: parse loop header
- ForStmt *stmt = new ForStmt();
- if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
- delete stmt;
- return nullptr;
- }
- return stmt;
+ ForStmt *stmt = builder.createFor();
+
+ // If parsing of the for statement body fails, the MLIR still contains the
+ // for statement together with its successfully parsed nested statements.
+ if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
+ return ParseFailure;
+
+ return ParseSuccess;
}
/// If statement.
@@ -1514,45 +1511,69 @@
/// ml-if-stmt ::= ml-if-head
/// | ml-if-head `else` `{` ml-stmt* `}`
///
-IfStmt *MLFunctionParser::parseIfStmt() {
+ParseResult MLFunctionParser::parseIfStmt() {
consumeToken(Token::kw_if);
if (!consumeIf(Token::l_paren))
- return (emitError("expected ("), nullptr);
+ return emitError("expected (");
//TODO: parse condition
if (!consumeIf(Token::r_paren))
- return (emitError("expected )"), nullptr);
+ return emitError("expected )");
- IfStmt *ifStmt = new IfStmt();
+ IfStmt *ifStmt = builder.createIf();
IfClause *thenClause = ifStmt->getThenClause();
- if (parseStmtBlock(thenClause)) {
- delete ifStmt;
- return nullptr;
- }
+
+ // If parsing the then or optional else clause fails, the MLIR still contains
+ // the if statement with its successfully parsed nested statements.
+ if (parseStmtBlock(thenClause))
+ return ParseFailure;
if (consumeIf(Token::kw_else)) {
IfClause *elseClause = ifStmt->createElseClause();
- if (parseElseClause(elseClause)) {
- delete ifStmt;
- return nullptr;
- }
+ if (parseElseClause(elseClause))
+ return ParseFailure;
}
- return ifStmt;
+ return ParseSuccess;
}
ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
if (getToken().is(Token::kw_if)) {
- IfStmt *nextIf = parseIfStmt();
- if (!nextIf)
- return ParseFailure;
- elseClause->push_back(nextIf);
- return ParseSuccess;
+ builder.setInsertionPoint(elseClause);
+ return parseIfStmt();
}
- if (parseStmtBlock(elseClause))
- return ParseFailure;
+ return parseStmtBlock(elseClause);
+}
+
+///
+/// Parse a list of statements ending with `return` or `}`
+///
+ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
+ auto createOpFunc = [this](Identifier name,
+ ArrayRef<NamedAttribute> attrs) -> Operation * {
+ return builder.createOperation(name, attrs);
+ };
+
+ builder.setInsertionPoint(block);
+
+ while (getToken().isNot(Token::kw_return, Token::r_brace)) {
+ switch (getToken().getKind()) {
+ default:
+ if (parseOperation(createOpFunc))
+ return ParseFailure;
+ break;
+ case Token::kw_for:
+ if (parseForStmt())
+ return ParseFailure;
+ break;
+ case Token::kw_if:
+ if (parseIfStmt())
+ return ParseFailure;
+ break;
+ } // end switch
+ }
return ParseSuccess;
}
@@ -1564,12 +1585,11 @@
if (!consumeIf(Token::l_brace))
return emitError("expected '{' before statement list");
- while (!consumeIf(Token::r_brace)) {
- auto *stmt = parseStatement();
- if (!stmt)
- return ParseFailure;
- block->push_back(stmt);
- }
+ if (parseStatements(block))
+ return ParseFailure;
+
+ if (!consumeIf(Token::r_brace))
+ return emitError("expected '}' at the end of the statement block");
return ParseSuccess;
}