Parse ML function arguments, return statement operands, and for statement loop header.
Loop bounds are presumed to be constants for now and are stored in ForStmt as affine constant expressions. ML function arguments, return statement operands, and the loop variable name are dropped for now.

PiperOrigin-RevId: 205256208
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7fbbf9b..71b4904 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1581,6 +1581,7 @@
   MLFuncBuilder builder;
 
   ParseResult parseForStmt();
+  AffineConstantExpr *parseIntConstant();
   ParseResult parseIfStmt();
   ParseResult parseElseClause(IfClause *elseClause);
   ParseResult parseStatements(StmtBlock *block);
@@ -1598,10 +1599,13 @@
 
   if (!consumeIf(Token::kw_return))
-    emitError("ML function must end with return statement");
-  // TODO: parse return statement operands
+    return emitError("ML function must end with return statement");
 
-  if (!consumeIf(Token::r_brace))
-    emitError("expected '}' in ML function");
+  // TODO: store return operands in the IR.
+  // NOTE(review): the closing '}' is no longer consumed here; confirm that
+  // parseOptionalSSAUseList consumes its end token.
+  SmallVector<SSAUseInfo, 4> dummyUseInfo;
+  if (parseOptionalSSAUseList(Token::r_brace, dummyUseInfo))
+    return ParseFailure;
 
   getModule()->functionList.push_back(function);
 
@@ -1616,17 +1618,76 @@
 ParseResult MLFunctionParser::parseForStmt() {
   consumeToken(Token::kw_for);
 
-  //TODO: parse loop header
-  ForStmt *stmt = builder.createFor();
+  // Parse induction variable.
+  if (getToken().isNot(Token::percent_identifier))
+    return emitError("expected SSA identifier for the loop variable");
 
-  // If parsing of the for statement body fails
-  // MLIR contains for statement with successfully parsed nested statements
+  // TODO: create SSA value definition from name
+  StringRef name = getTokenSpelling().drop_front();
+  (void)name;
+
+  consumeToken(Token::percent_identifier);
+
+  if (!consumeIf(Token::equal))
+    return emitError("expected '='");
+
+  // Parse loop bounds.
+  AffineConstantExpr *lowerBound = parseIntConstant();
+  if (!lowerBound)
+    return ParseFailure;
+
+  if (!consumeIf(Token::kw_to))
+    return emitError("expected 'to' between bounds");
+
+  AffineConstantExpr *upperBound = parseIntConstant();
+  if (!upperBound)
+    return ParseFailure;
+
+  // Parse the optional step. If it is absent, step stays null.
+  // NOTE(review): confirm createFor/ForStmt treat a null step as 1.
+  AffineConstantExpr *step = nullptr;
+  if (consumeIf(Token::kw_step)) {
+    step = parseIntConstant();
+    if (!step)
+      return ParseFailure;
+  }
+
+  // Create for statement.
+  ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
+
+  // If parsing of the for statement body fails, the IR contains the for
+  // statement with those nested statements that have been successfully
+  // parsed.
   if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
     return ParseFailure;
 
   return ParseSuccess;
 }
 
+/// Parse a non-negative integer literal as an affine constant expression.
+/// On failure, emit an error and return nullptr.
+///
+/// This method is a temporary workaround to parse simple loop bounds and
+/// steps.
+/// TODO: remove this method once it's no longer used.
+AffineConstantExpr *MLFunctionParser::parseIntConstant() {
+  if (getToken().isNot(Token::integer)) {
+    emitError("expected non-negative integer for now");
+    return nullptr;
+  }
+
+  // Reject literals that cannot be read as uint64_t or that do not fit in
+  // the signed 64-bit value passed to getConstantExpr.
+  auto val = getToken().getUInt64IntegerValue();
+  if (!val.hasValue() || static_cast<int64_t>(val.getValue()) < 0) {
+    emitError("constant too large for affineint");
+    return nullptr;
+  }
+
+  consumeToken(Token::integer);
+  return builder.getConstantExpr(static_cast<int64_t>(val.getValue()));
+}
+
 /// If statement.
 ///
 ///   ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
@@ -1642,13 +1693,14 @@
   //TODO: parse condition
 
   if (!consumeIf(Token::r_paren))
-    return emitError("expected )");
+    return emitError("expected ')'");
 
   IfStmt *ifStmt = builder.createIf();
   IfClause *thenClause = ifStmt->getThenClause();
 
-  // If parsing of the then or optional else clause fails MLIR contains
-  // if statement with successfully parsed nested statements.
+  // When parsing of an if statement body fails, the IR contains
+  // the if statement with the portion of the body that has been
+  // successfully parsed.
   if (parseStmtBlock(thenClause))
     return ParseFailure;
 
@@ -1735,7 +1787,10 @@
   ParseResult parseAffineMapDef();
 
   // Functions.
-  ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
+  ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+                                  SmallVectorImpl<StringRef> &argNames);
+  ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
+                                     SmallVectorImpl<StringRef> *argNames);
   ParseResult parseExtFunc();
   ParseResult parseCFGFunc();
   ParseResult parseMLFunc();
@@ -1769,14 +1824,58 @@
   return ParseSuccess;
 }
 
+/// Parse a (possibly empty) list of MLFunction arguments with types.
+///
+///   ml-argument      ::= ssa-id `:` type
+///   ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
+///
+/// Argument names (without the leading '%') are appended to argNames and
+/// the corresponding types to argTypes, in source order. The current token
+/// must be '('.
+///
+ParseResult
+ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+                                  SmallVectorImpl<StringRef> &argNames) {
+  auto parseElt = [&]() -> ParseResult {
+    // Parse argument name.
+    if (getToken().isNot(Token::percent_identifier))
+      return emitError("expected SSA identifier");
+
+    StringRef name = getTokenSpelling().drop_front();
+    consumeToken(Token::percent_identifier);
+    argNames.push_back(name);
+
+    if (!consumeIf(Token::colon))
+      return emitError("expected ':'");
+
+    // Parse argument type.
+    auto elt = parseType();
+    if (!elt)
+      return ParseFailure;
+    argTypes.push_back(elt);
+
+    return ParseSuccess;
+  };
+
+  // The caller (parseFunctionSignature) has already checked for '('.
+  consumeToken(Token::l_paren);
+
+  return parseCommaSeparatedList(Token::r_paren, parseElt);
+}
+
 /// Parse a function signature, starting with a name and including the parameter
 /// list.
 ///
-///   argument-list ::= type (`,` type)* | /*empty*/
+///   argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list
 ///   function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
 ///
-ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
-                                                 FunctionType *&type) {
+/// When argNames is non-null, the arguments are parsed as an
+/// ml-argument-list and their names are appended to *argNames; otherwise
+/// only a type list is accepted.
+///
+ParseResult
+ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
+                                     SmallVectorImpl<StringRef> *argNames) {
   if (getToken().isNot(Token::at_identifier))
     return emitError("expected a function identifier like '@foo'");
 
@@ -1786,8 +1877,11 @@
   if (getToken().isNot(Token::l_paren))
     return emitError("expected '(' in function signature");
 
-  SmallVector<Type *, 4> arguments;
-  if (parseTypeList(arguments))
+  // With argNames, parse named `%name : type` arguments; otherwise types only.
+  SmallVector<Type *, 4> argTypes;
+  ParseResult parseResult = argNames ? parseMLArgumentList(argTypes, *argNames)
+                                     : parseTypeList(argTypes);
+  if (parseResult)
     return ParseFailure;
 
   // Parse the return type if present.
@@ -1796,7 +1894,7 @@
     if (parseTypeList(results))
       return ParseFailure;
   }
-  type = builder.getFunctionType(arguments, results);
+  type = builder.getFunctionType(argTypes, results);
   return ParseSuccess;
 }
 
@@ -1809,7 +1907,7 @@
 
   StringRef name;
   FunctionType *type = nullptr;
-  if (parseFunctionSignature(name, type))
+  if (parseFunctionSignature(name, type, /*argNames=*/nullptr))
     return ParseFailure;
 
   // Okay, the external function definition was parsed correctly.
@@ -1826,7 +1924,7 @@
 
   StringRef name;
   FunctionType *type = nullptr;
-  if (parseFunctionSignature(name, type))
+  if (parseFunctionSignature(name, type, /*argNames=*/nullptr))
     return ParseFailure;
 
   // Okay, the CFG function signature was parsed correctly, create the function.
@@ -1844,10 +1942,9 @@
 
   StringRef name;
   FunctionType *type = nullptr;
-
-  // FIXME: Parse ML function signature (args + types)
-  // by passing pointer to SmallVector<identifier> into parseFunctionSignature
-  if (parseFunctionSignature(name, type))
+  SmallVector<StringRef, 4> argNames;
+  // TODO: use argNames to define SSA values for the function arguments.
+  if (parseFunctionSignature(name, type, &argNames))
     return ParseFailure;
 
   // Okay, the ML function signature was parsed correctly, create the function.