Parse operations in ML functions. Add builder class for ML functions.

Refactor operation parsing so that its functionality is shared between CFG and
ML functions. ML function construction now goes through a builder, just as it
already does for CFG functions.

PiperOrigin-RevId: 204779279
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 626b92f..4b9133c 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -84,6 +84,9 @@
 
 namespace {
 
+typedef std::function<Operation *(Identifier, ArrayRef<NamedAttribute>)>
+    CreateOperationFunction;
+
 /// This class implement support for parsing global entities like types and
 /// shared entities like SSA names.  It is intended to be subclassed by
 /// specialized subparsers that include state, e.g. when a local symbol table.
@@ -167,6 +170,9 @@
   ParseResult parseSSAUseAndType();
   ParseResult parseOptionalSSAUseAndTypeList(Token::Kind endToken);
 
+  // Operations
+  ParseResult parseOperation(const CreateOperationFunction &createOpFunc);
+
 private:
   // The Parser is subclassed and reinstantiated.  Do not add additional
   // non-trivial state here, add it to the ParserState class.
@@ -1217,6 +1223,70 @@
       endToken, [&]() -> ParseResult { return parseSSAUseAndType(); });
 }
 
+//===----------------------------------------------------------------------===//
+// Operations
+//===----------------------------------------------------------------------===//
+
+/// Parse the CFG or MLFunc operation.
+///
+/// TODO(clattner): This is a change from the MLIR spec as written, it is an
+/// experiment that will eliminate "builtin" instructions as a thing.
+///
+///  operation ::=
+///    (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
+///    `:` function-type
+///
+ParseResult
+Parser::parseOperation(const CreateOperationFunction &createOpFunc) {
+  auto loc = getToken().getLoc();
+
+  StringRef resultID;
+  if (getToken().is(Token::percent_identifier)) {
+    resultID = getTokenSpelling().drop_front();
+    consumeToken(Token::percent_identifier);
+    if (!consumeIf(Token::equal))
+      return emitError("expected '=' after SSA name");
+  }
+
+  if (getToken().isNot(Token::string))
+    return emitError("expected operation name in quotes");
+
+  auto name = getToken().getStringValue();
+  if (name.empty())
+    return emitError("empty operation name is invalid");
+
+  consumeToken(Token::string);
+
+  if (!consumeIf(Token::l_paren))
+    return emitError("expected '(' to start operand list");
+
+  // Parse the operand list.
+  parseOptionalSSAUseList(Token::r_paren);
+
+  SmallVector<NamedAttribute, 4> attributes;
+  if (getToken().is(Token::l_brace)) {
+    if (parseAttributeDict(attributes))
+      return ParseFailure;
+  }
+
+  // TODO: Don't drop result name and operand names on the floor.
+  auto nameId = builder.getIdentifier(name);
+
+  auto oper = createOpFunc(nameId, attributes);
+
+  if (!oper)
+    return ParseFailure;
+
+  // We just parsed an operation.  If it is a recognized one, verify that it
+  // is structurally as we expect.  If not, produce an error with a reasonable
+  // source location.
+  if (auto *opInfo = oper->getAbstractOperation(builder.getContext())) {
+    if (auto error = opInfo->verifyInvariants(oper))
+      return emitError(loc, error);
+  }
+
+  return ParseSuccess;
+}
 
 //===----------------------------------------------------------------------===//
 // CFG Functions
@@ -1322,72 +1392,23 @@
   // Set the insertion point to the block we want to insert new operations into.
   builder.setInsertionPoint(block);
 
+  auto createOpFunc = [this](Identifier name,
+                             ArrayRef<NamedAttribute> attrs) -> Operation * {
+    return builder.createOperation(name, attrs);
+  };
+
   // Parse the list of operations that make up the body of the block.
   while (getToken().isNot(Token::kw_return, Token::kw_br)) {
-    auto loc = getToken().getLoc();
-    auto *inst = parseCFGOperation();
-    if (!inst)
+    if (parseOperation(createOpFunc))
       return ParseFailure;
-
-    // We just parsed an operation.  If it is a recognized one, verify that it
-    // is structurally as we expect.  If not, produce an error with a reasonable
-    // source location.
-    if (auto *opInfo = inst->getAbstractOperation(builder.getContext()))
-      if (auto error = opInfo->verifyInvariants(inst))
-        return emitError(loc, error);
   }
 
-  auto *term = parseTerminator();
-  if (!term)
+  if (!parseTerminator())
     return ParseFailure;
 
   return ParseSuccess;
 }
 
-/// Parse the CFG operation.
-///
-/// TODO(clattner): This is a change from the MLIR spec as written, it is an
-/// experiment that will eliminate "builtin" instructions as a thing.
-///
-///  cfg-operation ::=
-///    (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
-///    `:` function-type
-///
-OperationInst *CFGFunctionParser::parseCFGOperation() {
-  StringRef resultID;
-  if (getToken().is(Token::percent_identifier)) {
-    resultID = getTokenSpelling().drop_front();
-    consumeToken();
-    if (!consumeIf(Token::equal))
-      return (emitError("expected '=' after SSA name"), nullptr);
-  }
-
-  if (getToken().isNot(Token::string))
-    return (emitError("expected operation name in quotes"), nullptr);
-
-  auto name = getToken().getStringValue();
-  if (name.empty())
-    return (emitError("empty operation name is invalid"), nullptr);
-
-  consumeToken(Token::string);
-
-  if (!consumeIf(Token::l_paren))
-    return (emitError("expected '(' to start operand list"), nullptr);
-
-  // Parse the operand list.
-  parseOptionalSSAUseList(Token::r_paren);
-
-  SmallVector<NamedAttribute, 4> attributes;
-  if (getToken().is(Token::l_brace)) {
-    if (parseAttributeDict(attributes))
-      return nullptr;
-  }
-
-  // TODO: Don't drop result name and operand names on the floor.
-  auto nameId = builder.getIdentifier(name);
-  return builder.createOperation(nameId, attributes);
-}
-
 /// Parse the terminator instruction for a basic block.
 ///
 ///   terminator-stmt ::= `br` bb-id branch-use-list?
@@ -1424,22 +1445,22 @@
 /// Refined parser for MLFunction bodies.
 class MLFunctionParser : public Parser {
 public:
-  MLFunction *function;
-
-  /// This builder intentionally shadows the builder in the base class, with a
-  /// more specific builder type.
-  // TODO: MLFuncBuilder builder;
-
   MLFunctionParser(ParserState &state, MLFunction *function)
-      : Parser(state), function(function) {}
+      : Parser(state), function(function), builder(function) {}
 
   ParseResult parseFunctionBody();
 
 private:
-  Statement *parseStatement();
-  ForStmt *parseForStmt();
-  IfStmt *parseIfStmt();
+  MLFunction *function;
+
+  /// This builder intentionally shadows the builder in the base class, with a
+  /// more specific builder type.
+  MLFuncBuilder builder;
+
+  ParseResult parseForStmt();
+  ParseResult parseIfStmt();
   ParseResult parseElseClause(IfClause *elseClause);
+  ParseResult parseStatements(StmtBlock *block);
   ParseResult parseStmtBlock(StmtBlock *block);
 };
 } // end anonymous namespace
@@ -1448,19 +1469,14 @@
   if (!consumeIf(Token::l_brace))
     return emitError("expected '{' in ML function");
 
-  // Make sure we have at least one statement.
-  if (getToken().is(Token::r_brace))
-    return emitError("ML function must end with return statement");
+  // Parse statements in this function
+  if (parseStatements(function))
+    return ParseFailure;
 
-  // Parse the list of instructions.
-  while (!consumeIf(Token::kw_return)) {
-    auto *stmt = parseStatement();
-    if (!stmt)
-      return ParseFailure;
-    function->push_back(stmt);
-  }
-
+  if (!consumeIf(Token::kw_return))
+    emitError("ML function must end with return statement");
   // TODO: parse return statement operands
+
   if (!consumeIf(Token::r_brace))
     emitError("expected '}' in ML function");
 
@@ -1469,42 +1485,23 @@
   return ParseSuccess;
 }
 
-/// Statement.
-///
-///    ml-stmt ::= instruction | ml-for-stmt | ml-if-stmt
-///
-/// TODO: fix terminology in MLSpec document. ML functions
-/// contain operation statements, not instructions.
-///
-Statement *MLFunctionParser::parseStatement() {
-  switch (getToken().getKind()) {
-  default:
-    //TODO: parse OperationStmt
-    return (emitError("expected statement"), nullptr);
-
-  case Token::kw_for:
-    return parseForStmt();
-
-  case Token::kw_if:
-    return parseIfStmt();
-  }
-}
-
 /// For statement.
 ///
 ///    ml-for-stmt ::= `for` ssa-id `=` lower-bound `to` upper-bound
 ///                   (`step` integer-literal)? `{` ml-stmt* `}`
 ///
-ForStmt *MLFunctionParser::parseForStmt() {
+ParseResult MLFunctionParser::parseForStmt() {
   consumeToken(Token::kw_for);
 
   //TODO: parse loop header
-  ForStmt *stmt = new ForStmt();
-  if (parseStmtBlock(static_cast<StmtBlock *>(stmt))) {
-    delete stmt;
-    return nullptr;
-  }
-  return stmt;
+  ForStmt *stmt = builder.createFor();
+
+  // If parsing of the for statement body fails, the MLIR still contains the
+  // for statement along with any nested statements that parsed successfully.
+  if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
+    return ParseFailure;
+
+  return ParseSuccess;
 }
 
 /// If statement.
@@ -1514,45 +1511,69 @@
 ///   ml-if-stmt ::= ml-if-head
 ///               | ml-if-head `else` `{` ml-stmt* `}`
 ///
-IfStmt *MLFunctionParser::parseIfStmt() {
+ParseResult MLFunctionParser::parseIfStmt() {
   consumeToken(Token::kw_if);
   if (!consumeIf(Token::l_paren))
-    return (emitError("expected ("), nullptr);
+    return emitError("expected (");
 
   //TODO: parse condition
 
   if (!consumeIf(Token::r_paren))
-    return (emitError("expected )"), nullptr);
+    return emitError("expected )");
 
-  IfStmt *ifStmt = new IfStmt();
+  IfStmt *ifStmt = builder.createIf();
   IfClause *thenClause = ifStmt->getThenClause();
-  if (parseStmtBlock(thenClause)) {
-    delete ifStmt;
-    return nullptr;
-  }
+
+  // If parsing the then or optional else clause fails, the MLIR still contains
+  // the if statement along with its successfully parsed nested statements.
+  if (parseStmtBlock(thenClause))
+    return ParseFailure;
 
   if (consumeIf(Token::kw_else)) {
     IfClause *elseClause = ifStmt->createElseClause();
-    if (parseElseClause(elseClause)) {
-      delete ifStmt;
-      return nullptr;
-    }
+    if (parseElseClause(elseClause))
+      return ParseFailure;
   }
 
-  return ifStmt;
+  return ParseSuccess;
 }
 
 ParseResult MLFunctionParser::parseElseClause(IfClause *elseClause) {
   if (getToken().is(Token::kw_if)) {
-      IfStmt *nextIf = parseIfStmt();
-      if (!nextIf)
-        return ParseFailure;
-      elseClause->push_back(nextIf);
-    return ParseSuccess;
+    builder.setInsertionPoint(elseClause);
+    return parseIfStmt();
   }
 
-  if (parseStmtBlock(elseClause))
-      return ParseFailure;
+  return parseStmtBlock(elseClause);
+}
+
+///
+/// Parse a list of statements ending with `return` or `}`
+///
+ParseResult MLFunctionParser::parseStatements(StmtBlock *block) {
+  auto createOpFunc = [this](Identifier name,
+                             ArrayRef<NamedAttribute> attrs) -> Operation * {
+    return builder.createOperation(name, attrs);
+  };
+
+  builder.setInsertionPoint(block);
+
+  while (getToken().isNot(Token::kw_return, Token::r_brace)) {
+    switch (getToken().getKind()) {
+    default:
+      if (parseOperation(createOpFunc))
+        return ParseFailure;
+      break;
+    case Token::kw_for:
+      if (parseForStmt())
+        return ParseFailure;
+      break;
+    case Token::kw_if:
+      if (parseIfStmt())
+        return ParseFailure;
+      break;
+    } // end switch
+  }
 
   return ParseSuccess;
 }
@@ -1564,12 +1585,11 @@
   if (!consumeIf(Token::l_brace))
     return emitError("expected '{' before statement list");
 
-  while (!consumeIf(Token::r_brace)) {
-    auto *stmt = parseStatement();
-    if (!stmt)
-      return ParseFailure;
-    block->push_back(stmt);
-  }
+  if (parseStatements(block))
+    return ParseFailure;
+
+  if (!consumeIf(Token::r_brace))
+    return emitError("expected '}' at the end of the statement block");
 
   return ParseSuccess;
 }