Parse ML function arguments, return statement operands, and for statement loop header.
Loop bounds are presumed to be constants for now and are stored in ForStmt as affine constant expressions.  ML function arguments, return statement operands, and the loop variable name are dropped for now.

PiperOrigin-RevId: 205256208
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index c41a886..1da2312 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -184,11 +184,10 @@
     return op;
   }
 
-  ForStmt *createFor() {
-    auto stmt = new ForStmt();
-    block->getStatements().push_back(stmt);
-    return stmt;
-  }
+  // Creates a for statement. When step is not specified, it is set to 1.
+  ForStmt *createFor(AffineConstantExpr *lowerBound,
+                     AffineConstantExpr *upperBound,
+                     AffineConstantExpr *step = nullptr);
 
   IfStmt *createIf() {
     auto stmt = new IfStmt();
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 902bf35..72bcd11 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -22,10 +22,11 @@
 #ifndef MLIR_IR_STATEMENTS_H
 #define MLIR_IR_STATEMENTS_H
 
-#include "mlir/Support/LLVM.h"
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Statement.h"
 #include "mlir/IR/StmtBlock.h"
+#include "mlir/Support/LLVM.h"
 
 namespace mlir {
 
@@ -50,11 +51,20 @@
 /// For statement represents an affine loop nest.
 class ForStmt : public Statement, public StmtBlock {
 public:
-  explicit ForStmt() : Statement(Kind::For), StmtBlock(StmtBlockKind::For) {}
+  // TODO: lower and upper bounds should be affine maps with
+  // dimension and symbol use lists.
+  explicit ForStmt(AffineConstantExpr *lowerBound,
+                   AffineConstantExpr *upperBound, AffineConstantExpr *step)
+      : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
+        lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+
   //TODO: delete nested statements or assert that they are gone.
   ~ForStmt() {}
 
-  // TODO: represent loop variable, bounds and step
+  // TODO: represent induction variable
+  AffineConstantExpr *getLowerBound() const { return lowerBound; }
+  AffineConstantExpr *getUpperBound() const { return upperBound; }
+  AffineConstantExpr *getStep() const { return step; }
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Statement *stmt) {
@@ -64,9 +74,14 @@
   static bool classof(const StmtBlock *block) {
     return block->getStmtBlockKind() == StmtBlockKind::For;
   }
+
+private:
+  AffineConstantExpr *lowerBound;
+  AffineConstantExpr *upperBound;
+  AffineConstantExpr *step;
 };
 
-/// If clause represents statements contained within then or else clause
+/// An if clause represents statements contained within a then or an else clause
 /// of an if statement.
 class IfClause : public StmtBlock {
 public:
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 424487c..570ae49 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -594,7 +594,12 @@
 void MLFunctionState::print(const OperationStmt *stmt) { printOperation(stmt); }
 
 void MLFunctionState::print(const ForStmt *stmt) {
-  os.indent(numSpaces) << "for {\n";
+  os.indent(numSpaces) << "for x = " << *stmt->getLowerBound();
+  os << " to " << *stmt->getUpperBound();
+  if (stmt->getStep()->getValue() != 1)
+    os << " step " << *stmt->getStep();
+
+  os << " {\n";
   print(static_cast<const StmtBlock *>(stmt));
   os.indent(numSpaces) << "}";
 }
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 8d27991..3d7e023 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -143,3 +143,17 @@
 AffineExpr *Builder::getCeilDivExpr(AffineExpr *lhs, AffineExpr *rhs) {
   return AffineCeilDivExpr::get(lhs, rhs, context);
 }
+
+//===----------------------------------------------------------------------===//
+// Statements
+//===----------------------------------------------------------------------===//
+
+ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
+                                  AffineConstantExpr *upperBound,
+                                  AffineConstantExpr *step) {
+  if (!step)
+    step = getConstantExpr(1);
+  auto stmt = new ForStmt(lowerBound, upperBound, step);
+  block->getStatements().push_back(stmt);
+  return stmt;
+}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 7fbbf9b..71b4904 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1581,6 +1581,7 @@
   MLFuncBuilder builder;
 
   ParseResult parseForStmt();
+  AffineConstantExpr *parseIntConstant();
   ParseResult parseIfStmt();
   ParseResult parseElseClause(IfClause *elseClause);
   ParseResult parseStatements(StmtBlock *block);
@@ -1598,10 +1599,11 @@
 
   if (!consumeIf(Token::kw_return))
     emitError("ML function must end with return statement");
-  // TODO: parse return statement operands
 
-  if (!consumeIf(Token::r_brace))
-    emitError("expected '}' in ML function");
+  // TODO: store return operands in the IR.
+  SmallVector<SSAUseInfo, 4> dummyUseInfo;
+  if (parseOptionalSSAUseList(Token::r_brace, dummyUseInfo))
+    return ParseFailure;
 
   getModule()->functionList.push_back(function);
 
@@ -1616,17 +1618,66 @@
 ParseResult MLFunctionParser::parseForStmt() {
   consumeToken(Token::kw_for);
 
-  //TODO: parse loop header
-  ForStmt *stmt = builder.createFor();
+  // Parse induction variable
+  if (getToken().isNot(Token::percent_identifier))
+    return emitError("expected SSA identifier for the loop variable");
 
-  // If parsing of the for statement body fails
-  // MLIR contains for statement with successfully parsed nested statements
+  // TODO: create SSA value definition from name
+  StringRef name = getTokenSpelling().drop_front();
+  (void)name;
+
+  consumeToken(Token::percent_identifier);
+
+  if (!consumeIf(Token::equal))
+    return emitError("expected =");
+
+  // Parse loop bounds
+  AffineConstantExpr *lowerBound = parseIntConstant();
+  if (!lowerBound)
+    return ParseFailure;
+
+  if (!consumeIf(Token::kw_to))
+    return emitError("expected 'to' between bounds");
+
+  AffineConstantExpr *upperBound = parseIntConstant();
+  if (!upperBound)
+    return ParseFailure;
+
+  // Parse step
+  AffineConstantExpr *step = nullptr;
+  if (consumeIf(Token::kw_step)) {
+    step = parseIntConstant();
+    if (!step)
+      return ParseFailure;
+  }
+
+  // Create for statement.
+  ForStmt *stmt = builder.createFor(lowerBound, upperBound, step);
+
+  // When parsing of the for statement body fails, the IR contains
+  // the for statement with the portion of the body that has been
+  // successfully parsed.
   if (parseStmtBlock(static_cast<StmtBlock *>(stmt)))
     return ParseFailure;
 
   return ParseSuccess;
 }
 
+// This method is a temporary workaround to parse simple loop bounds and
+// step.
+// TODO: remove this method once it's no longer used.
+AffineConstantExpr *MLFunctionParser::parseIntConstant() {
+  if (getToken().isNot(Token::integer))
+    return (emitError("expected non-negative integer for now"), nullptr);
+
+  auto val = getToken().getUInt64IntegerValue();
+  if (!val.hasValue() || (int64_t)val.getValue() < 0) {
+    return (emitError("constant too large for affineint"), nullptr);
+  }
+  consumeToken(Token::integer);
+  return builder.getConstantExpr((int64_t)val.getValue());
+}
+
 /// If statement.
 ///
 ///   ml-if-head ::= `if` ml-if-cond `{` ml-stmt* `}`
@@ -1642,13 +1693,14 @@
   //TODO: parse condition
 
   if (!consumeIf(Token::r_paren))
-    return emitError("expected )");
+    return emitError("expected ')'");
 
   IfStmt *ifStmt = builder.createIf();
   IfClause *thenClause = ifStmt->getThenClause();
 
-  // If parsing of the then or optional else clause fails MLIR contains
-  // if statement with successfully parsed nested statements.
+  // When parsing of an if statement body fails, the IR contains
+  // the if statement with the portion of the body that has been
+  // successfully parsed.
   if (parseStmtBlock(thenClause))
     return ParseFailure;
 
@@ -1735,7 +1787,10 @@
   ParseResult parseAffineMapDef();
 
   // Functions.
-  ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type);
+  ParseResult parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+                                  SmallVectorImpl<StringRef> &argNames);
+  ParseResult parseFunctionSignature(StringRef &name, FunctionType *&type,
+                                     SmallVectorImpl<StringRef> *argNames);
   ParseResult parseExtFunc();
   ParseResult parseCFGFunc();
   ParseResult parseMLFunc();
@@ -1769,14 +1824,50 @@
   return ParseSuccess;
 }
 
+/// Parse a (possibly empty) list of MLFunction arguments with types.
+///
+/// ml-argument      ::= ssa-id `:` type
+/// ml-argument-list ::= ml-argument (`,` ml-argument)* | /*empty*/
+///
+ParseResult
+ModuleParser::parseMLArgumentList(SmallVectorImpl<Type *> &argTypes,
+                                  SmallVectorImpl<StringRef> &argNames) {
+  auto parseElt = [&]() -> ParseResult {
+    // Parse argument name
+    if (getToken().isNot(Token::percent_identifier))
+      return emitError("expected SSA identifier");
+
+    StringRef name = getTokenSpelling().drop_front();
+    consumeToken(Token::percent_identifier);
+    argNames.push_back(name);
+
+    if (!consumeIf(Token::colon))
+      return emitError("expected ':'");
+
+    // Parse argument type
+    auto elt = parseType();
+    if (!elt)
+      return ParseFailure;
+    argTypes.push_back(elt);
+
+    return ParseSuccess;
+  };
+
+  if (!consumeIf(Token::l_paren))
+    llvm_unreachable("expected '('");
+
+  return parseCommaSeparatedList(Token::r_paren, parseElt);
+}
+
 /// Parse a function signature, starting with a name and including the parameter
 /// list.
 ///
-///   argument-list ::= type (`,` type)* | /*empty*/
+///   argument-list ::= type (`,` type)* | /*empty*/ | ml-argument-list
 ///   function-signature ::= function-id `(` argument-list `)` (`->` type-list)?
 ///
-ParseResult ModuleParser::parseFunctionSignature(StringRef &name,
-                                                 FunctionType *&type) {
+ParseResult
+ModuleParser::parseFunctionSignature(StringRef &name, FunctionType *&type,
+                                     SmallVectorImpl<StringRef> *argNames) {
   if (getToken().isNot(Token::at_identifier))
     return emitError("expected a function identifier like '@foo'");
 
@@ -1786,8 +1877,15 @@
   if (getToken().isNot(Token::l_paren))
     return emitError("expected '(' in function signature");
 
-  SmallVector<Type *, 4> arguments;
-  if (parseTypeList(arguments))
+  SmallVector<Type *, 4> argTypes;
+  ParseResult parseResult;
+
+  if (argNames)
+    parseResult = parseMLArgumentList(argTypes, *argNames);
+  else
+    parseResult = parseTypeList(argTypes);
+
+  if (parseResult)
     return ParseFailure;
 
   // Parse the return type if present.
@@ -1796,7 +1894,7 @@
     if (parseTypeList(results))
       return ParseFailure;
   }
-  type = builder.getFunctionType(arguments, results);
+  type = builder.getFunctionType(argTypes, results);
   return ParseSuccess;
 }
 
@@ -1809,7 +1907,7 @@
 
   StringRef name;
   FunctionType *type = nullptr;
-  if (parseFunctionSignature(name, type))
+  if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
     return ParseFailure;
 
   // Okay, the external function definition was parsed correctly.
@@ -1826,7 +1924,7 @@
 
   StringRef name;
   FunctionType *type = nullptr;
-  if (parseFunctionSignature(name, type))
+  if (parseFunctionSignature(name, type, /*arguments*/ nullptr))
     return ParseFailure;
 
   // Okay, the CFG function signature was parsed correctly, create the function.
@@ -1844,10 +1942,11 @@
 
   StringRef name;
   FunctionType *type = nullptr;
-
+  SmallVector<StringRef, 4> argNames;
   // FIXME: Parse ML function signature (args + types)
   // by passing pointer to SmallVector<identifier> into parseFunctionSignature
-  if (parseFunctionSignature(name, type))
+
+  if (parseFunctionSignature(name, type, &argNames))
     return ParseFailure;
 
   // Okay, the ML function signature was parsed correctly, create the function.
diff --git a/lib/Parser/Token.h b/lib/Parser/Token.h
index 73baaac..f847bcf4 100644
--- a/lib/Parser/Token.h
+++ b/lib/Parser/Token.h
@@ -76,7 +76,7 @@
   /// return None.
   Optional<unsigned> getUnsignedIntegerValue() const;
 
-  /// For an integer token, return its value as an int64_t.  If it doesn't fit,
+  /// For an integer token, return its value as a uint64_t.  If it doesn't fit,
   /// return None.
   Optional<uint64_t> getUInt64IntegerValue() const;
 
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index dda5cae..de6758c 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -104,7 +104,9 @@
 TOK_KEYWORD(mod)
 TOK_KEYWORD(return)
 TOK_KEYWORD(size)
+TOK_KEYWORD(step)
 TOK_KEYWORD(tensor)
+TOK_KEYWORD(to)
 TOK_KEYWORD(true)
 TOK_KEYWORD(vector)
 
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 4f67c8e..49fd2da 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -130,12 +130,24 @@
 
 // -----
 
+mlfunc @malformed_for() {
+  for %i = 1 too 10 { // expected-error {{expected 'to' between bounds}}
+  }
+}
+
+// -----
+
 mlfunc @incomplete_for() {
-  for
+  for %i = 1 to 10 step 2
 }        // expected-error {{expected '{' before statement list}}
 
 // -----
 
+mlfunc @nonconstant_step(%1 : i32) {
+  for %2 = 1 to 5 step %1 { // expected-error {{expected non-negative integer for now}}
+
+// -----
+
 mlfunc @non_statement() {
   asd   // expected-error {{expected operation name in quotes}}
 }
@@ -160,7 +172,6 @@
   return
 }
 
-
 // -----
 
 cfgfunc @redef() {
@@ -168,4 +179,16 @@
   %x = "dim"(){index: 0} : ()->i32
   %x = "dim"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value %x}}
   return
-}
\ No newline at end of file
+}
+
+mlfunc @missing_rbrace() {
+  return %a
+mlfunc @d {return} // expected-error {{expected ',' or '}'}}
+
+// -----
+
+mlfunc @malformed_type(%a : intt) { // expected-error {{expected type}}
+}
+
+// -----
+
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 3a7986c..9b384b4 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -96,22 +96,32 @@
   return     // CHECK:  return
 }            // CHECK: }
 
+// CHECK-LABEL: mlfunc @mlfunc_with_args(f16) {
+mlfunc @mlfunc_with_args(%a : f16) {
+  return  %a  // CHECK: return
+}
+
 // CHECK-LABEL: cfgfunc @cfgfunc_with_ops() {
 cfgfunc @cfgfunc_with_ops() {
 bb0:
   %t = "getTensor"() : () -> tensor<4x4x?xf32>
+
   // CHECK: dim xxx, 2 : sometype
   %a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
 
   // CHECK: addf xx, yy : sometype
   "addf"() : () -> ()
+
+  // CHECK:   return
   return
 }
 
 // CHECK-LABEL: mlfunc @loops() {
 mlfunc @loops() {
-  for {      // CHECK:   for {
-    for {    // CHECK:     for {
+  // CHECK: for x = 1 to 100 step 2 {
+  for %i = 1 to 100 step 2 {
+    // CHECK: for x = 1 to 200 {
+    for %j = 1 to 200 {
     }        // CHECK:     }
   }          // CHECK:   }
   return     // CHECK:   return
@@ -119,14 +129,14 @@
 
 // CHECK-LABEL: mlfunc @ifstmt() {
 mlfunc @ifstmt() {
-  for {             // CHECK   for {
-    if () {         // CHECK     if () {
-    } else if () {  // CHECK     } else if () {
-    } else {        // CHECK     } else {
-    }               // CHECK     }
-  }                 // CHECK   }
-  return            // CHECK   return
-}                   // CHECK }
+  for %i = 1 to 10 {    // CHECK   for x = 1 to 10 {
+    if () {             // CHECK     if () {
+    } else if () {      // CHECK     } else if () {
+    } else {            // CHECK     } else {
+    }                   // CHECK     }
+  }                     // CHECK   }
+  return                // CHECK   return
+}                       // CHECK }
 
 // CHECK-LABEL: cfgfunc @attributes() {
 cfgfunc @attributes() {