Implement ML function arguments. Add a representation for the argument list in MLFunction using the llvm::TrailingObjects template. Implement argument iterators, parsing, and printing.

Unrelated minor change - remove OperationStmt::dropAllReferences(). Since an MLFunction does not have cyclic operand references (it's an AST), destruction can be safely done without a special pass to drop references.

PiperOrigin-RevId: 207583024
diff --git a/include/mlir/IR/CFGValue.h b/include/mlir/IR/CFGValue.h
index f8c2f23..d50792d 100644
--- a/include/mlir/IR/CFGValue.h
+++ b/include/mlir/IR/CFGValue.h
@@ -49,7 +49,7 @@
     case SSAValueKind::InstResult:
       return true;
 
-    case SSAValueKind::FnArgument:
+    case SSAValueKind::MLFuncArgument:
     case SSAValueKind::StmtResult:
     case SSAValueKind::ForStmt:
       return false;
diff --git a/include/mlir/IR/MLFunction.h b/include/mlir/IR/MLFunction.h
index 24e8353..248828c 100644
--- a/include/mlir/IR/MLFunction.h
+++ b/include/mlir/IR/MLFunction.h
@@ -23,30 +23,146 @@
 #define MLIR_IR_MLFUNCTION_H_
 
 #include "mlir/IR/Function.h"
+#include "mlir/IR/MLValue.h"
+#include "mlir/IR/Operation.h"
 #include "mlir/IR/StmtBlock.h"
+#include "llvm/Support/TrailingObjects.h"
 
 namespace mlir {
 
+template <typename ObjectType, typename ElementType> class ArgumentIterator;
+
 // MLFunction is defined as a sequence of statements that may
 // include nested affine for loops, conditionals and operations.
-class MLFunction : public Function, public StmtBlock {
+class MLFunction final
+    : public Function,
+      public StmtBlock,
+      private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
 public:
-  MLFunction(StringRef name, FunctionType *type);
-  ~MLFunction();
+  /// Creates a new MLFunction with the specified fields.
+  static MLFunction *create(StringRef name, FunctionType *type);
 
-  // TODO: add function arguments and return values once
-  // SSA values are implemented
+  /// Destroys this function and its trailing arguments, then frees it.
+  void destroy();
+
+  //===--------------------------------------------------------------------===//
+  // Arguments
+  //===--------------------------------------------------------------------===//
+
+  /// Returns number of arguments.
+  unsigned getNumArguments() const { return getType()->getInputs().size(); }
+
+  /// Gets argument.
+  MLFuncArgument *getArgument(unsigned idx) {
+    return &getArgumentsInternal()[idx];
+  }
+
+  const MLFuncArgument *getArgument(unsigned idx) const {
+    return &getArgumentsInternal()[idx];
+  }
+
+  // Supports non-const argument iteration.
+  using args_iterator = ArgumentIterator<MLFunction, MLFuncArgument>;
+  args_iterator args_begin();
+  args_iterator args_end();
+  llvm::iterator_range<args_iterator> getArguments();
+
+  // Supports const argument iteration.
+  using const_args_iterator =
+      ArgumentIterator<const MLFunction, const MLFuncArgument>;
+  const_args_iterator args_begin() const;
+  const_args_iterator args_end() const;
+  llvm::iterator_range<const_args_iterator> getArguments() const;
+
+  //===--------------------------------------------------------------------===//
+  // Other
+  //===--------------------------------------------------------------------===//
+
+  ~MLFunction();
 
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(const Function *func) {
     return func->getKind() == Function::Kind::MLFunc;
   }
-
   static bool classof(const StmtBlock *block) {
     return block->getStmtBlockKind() == StmtBlockKind::MLFunc;
   }
+
+private:
+  MLFunction(StringRef name, FunctionType *type);
+
+  // This stuff is used by the TrailingObjects template.
+  friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
+  size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
+    return getType()->getInputs().size();
+  }
+
+  // Internal functions to get argument list used by getArgument() methods.
+  ArrayRef<MLFuncArgument> getArgumentsInternal() const {
+    return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
+  }
+  MutableArrayRef<MLFuncArgument> getArgumentsInternal() {
+    return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
+  }
 };
 
+//===--------------------------------------------------------------------===//
+// ArgumentIterator
+//===--------------------------------------------------------------------===//
+
+/// This template implements the argument iterator in terms of getArgument(idx).
+template <typename ObjectType, typename ElementType>
+class ArgumentIterator final
+    : public IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
+                                     ObjectType, ElementType> {
+public:
+  /// Initializes the argument iterator to the specified index.
+  ArgumentIterator(ObjectType *object, unsigned index)
+      : IndexedAccessorIterator<ArgumentIterator<ObjectType, ElementType>,
+                                ObjectType, ElementType>(object, index) {}
+
+  /// Supports converting to the const variant. This is a no-op for the const
+  /// variant.
+  operator ArgumentIterator<const ObjectType, const ElementType>() const {
+    return ArgumentIterator<const ObjectType, const ElementType>(this->object,
+                                                                 this->index);
+  }
+
+  ElementType *operator*() const {
+    return this->object->getArgument(this->index);
+  }
+};
+
+//===--------------------------------------------------------------------===//
+// MLFunction iterator methods.
+//===--------------------------------------------------------------------===//
+
+inline MLFunction::args_iterator MLFunction::args_begin() {
+  return args_iterator(this, 0);
+}
+
+inline MLFunction::args_iterator MLFunction::args_end() {
+  return args_iterator(this, getNumArguments());
+}
+
+inline llvm::iterator_range<MLFunction::args_iterator>
+MLFunction::getArguments() {
+  return {args_begin(), args_end()};
+}
+
+inline MLFunction::const_args_iterator MLFunction::args_begin() const {
+  return const_args_iterator(this, 0);
+}
+
+inline MLFunction::const_args_iterator MLFunction::args_end() const {
+  return const_args_iterator(this, getNumArguments());
+}
+
+inline llvm::iterator_range<MLFunction::const_args_iterator>
+MLFunction::getArguments() const {
+  return {args_begin(), args_end()};
+}
+
 } // end namespace mlir
 
 #endif  // MLIR_IR_MLFUNCTION_H_
diff --git a/include/mlir/IR/MLValue.h b/include/mlir/IR/MLValue.h
index 911207a..99e4628 100644
--- a/include/mlir/IR/MLValue.h
+++ b/include/mlir/IR/MLValue.h
@@ -34,7 +34,7 @@
 /// function.  This should be kept as a proper subtype of SSAValueKind,
 /// including having all of the values of the enumerators align.
 enum class MLValueKind {
-  FnArgument = (int)SSAValueKind::FnArgument,
+  MLFuncArgument = (int)SSAValueKind::MLFuncArgument,
   StmtResult = (int)SSAValueKind::StmtResult,
   ForStmt = (int)SSAValueKind::ForStmt,
 };
@@ -47,7 +47,7 @@
 public:
   static bool classof(const SSAValue *value) {
     switch (value->getKind()) {
-    case SSAValueKind::FnArgument:
+    case SSAValueKind::MLFuncArgument:
     case SSAValueKind::StmtResult:
     case SSAValueKind::ForStmt:
       return true;
@@ -63,10 +63,10 @@
 };
 
 /// This is the value defined by an argument of an ML function.
-class FnArgument : public MLValue {
+class MLFuncArgument : public MLValue {
 public:
   static bool classof(const SSAValue *value) {
-    return value->getKind() == SSAValueKind::FnArgument;
+    return value->getKind() == SSAValueKind::MLFuncArgument;
   }
 
   MLFunction *getOwner() { return owner; }
@@ -74,8 +74,8 @@
 
 private:
   friend class MLFunction; // For access to private constructor.
-  FnArgument(Type *type, MLFunction *owner)
-      : MLValue(MLValueKind::FnArgument, type), owner(owner) {}
+  MLFuncArgument(Type *type, MLFunction *owner)
+      : MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
 
   /// The owner of this operand.
   /// TODO: can encode this more efficiently to avoid the space hit of this
diff --git a/include/mlir/IR/SSAValue.h b/include/mlir/IR/SSAValue.h
index a584d67c..3a6e901 100644
--- a/include/mlir/IR/SSAValue.h
+++ b/include/mlir/IR/SSAValue.h
@@ -34,11 +34,11 @@
 
 /// This enumerates all of the SSA value kinds in the MLIR system.
 enum class SSAValueKind {
-  BBArgument, // basic block argument
-  InstResult, // instruction result
-  FnArgument, // ML function argument
-  StmtResult, // statement result
-  ForStmt,    // for statement induction variable
+  BBArgument,     // basic block argument
+  InstResult,     // instruction result
+  MLFuncArgument, // ML function argument
+  StmtResult,     // statement result
+  ForStmt,        // for statement induction variable
 };
 
 /// This is the common base class for all values in the MLIR system,
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6875120..cb7560e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -200,6 +200,7 @@
   ModuleState &state;
 
   void printFunctionSignature(const Function *fn);
+  void printFunctionResultType(const FunctionType *type);
   void printAffineMapId(int affineMapId) const;
   void printAffineMapReference(const AffineMap *affineMap);
 
@@ -527,14 +528,7 @@
 // Function printing
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::printFunctionSignature(const Function *fn) {
-  auto type = fn->getType();
-
-  os << "@" << fn->getName() << '(';
-  interleaveComma(type->getInputs(),
-                  [&](Type *eltType) { printType(eltType); });
-  os << ')';
-
+void ModulePrinter::printFunctionResultType(const FunctionType *type) {
   switch (type->getResults().size()) {
   case 0:
     break;
@@ -551,6 +545,17 @@
   }
 }
 
+void ModulePrinter::printFunctionSignature(const Function *fn) {
+  auto type = fn->getType();
+
+  os << "@" << fn->getName() << '(';
+  interleaveComma(type->getInputs(),
+                  [&](Type *eltType) { printType(eltType); });
+  os << ')';
+
+  printFunctionResultType(type);
+}
+
 void ModulePrinter::print(const ExtFunction *fn) {
   os << "extfunc ";
   printFunctionSignature(fn);
@@ -627,7 +632,7 @@
         // done with it.
         valueIDs[value] = nextValueID++;
         return;
-      case SSAValueKind::FnArgument:
+      case SSAValueKind::MLFuncArgument:
         specialName << "arg" << nextArgumentID++;
         break;
       case SSAValueKind::ForStmt:
@@ -1018,6 +1023,9 @@
   // Prints ML function
   void print();
 
+  // Prints ML function signature
+  void printFunctionSignature();
+
   // Methods to print ML function statements
   void print(const Statement *stmt);
   void print(const OperationStmt *stmt);
@@ -1039,12 +1047,18 @@
 MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
                                      const ModulePrinter &other)
     : FunctionPrinter(other), function(function), numSpaces(0) {
+  assert(function && "Cannot print nullptr function");
   numberValues();
 }
 
 /// Number all of the SSA values in this ML function.
 void MLFunctionPrinter::numberValues() {
-  // Visits all operation statements and numbers the first result.
+  // Numbers ML function arguments
+  for (auto *arg : function->getArguments())
+    numberValueID(arg);
+
+  // Walks the ML function statements, numbering 'for' statements and
+  // the first result of each operation statement.
   struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
     NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
     void visitOperationStmt(OperationStmt *stmt) {
@@ -1062,14 +1076,30 @@
 
 void MLFunctionPrinter::print() {
   os << "mlfunc ";
-  // FIXME: should print argument names rather than just signature
-  printFunctionSignature(function);
+  printFunctionSignature();
   os << " {\n";
   print(function);
   os << "  return\n";
   os << "}\n\n";
 }
 
+void MLFunctionPrinter::printFunctionSignature() {
+  auto type = function->getType();
+
+  os << "@" << function->getName() << '(';
+
+  for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+    if (i > 0)
+      os << ", ";
+    auto *arg = function->getArgument(i);
+    printOperand(arg);
+    os << " : ";
+    printType(arg->getType());
+  }
+  os << ")";
+  printFunctionResultType(type);
+}
+
 void MLFunctionPrinter::print(const StmtBlock *block) {
   numSpaces += indentWidth;
   for (auto &stmt : block->getStatements()) {
@@ -1169,7 +1199,7 @@
     return;
   case SSAValueKind::InstResult:
     return getDefiningInst()->print(os);
-  case SSAValueKind::FnArgument:
+  case SSAValueKind::MLFuncArgument:
     // TODO: Improve this.
     os << "<function argument>\n";
     return;
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 96c6dfc..4719b35 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -36,7 +36,7 @@
     delete cast<ExtFunction>(this);
     break;
   case Kind::MLFunc:
-    delete cast<MLFunction>(this);
+    cast<MLFunction>(this)->destroy();
     break;
   case Kind::CFGFunc:
     delete cast<CFGFunction>(this);
@@ -118,13 +118,37 @@
 // MLFunction implementation.
 //===----------------------------------------------------------------------===//
 
+/// Create a new MLFunction with the specified fields.
+MLFunction *MLFunction::create(StringRef name, FunctionType *type) {
+  const auto &argTypes = type->getInputs();
+  auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
+  void *rawMem = malloc(byteSize);
+
+  // Initialize the MLFunction part of the function object.
+  auto function = ::new (rawMem) MLFunction(name, type);
+
+  // Initialize the arguments.
+  auto arguments = function->getArgumentsInternal();
+  for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
+    new (&arguments[i]) MLFuncArgument(argTypes[i], function);
+  return function;
+}
+
 MLFunction::MLFunction(StringRef name, FunctionType *type)
     : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
 
 MLFunction::~MLFunction() {
-  struct DropReferencesPass : public StmtWalker<DropReferencesPass> {
-    void visitOperationStmt(OperationStmt *stmt) { stmt->dropAllReferences(); }
-  };
-  DropReferencesPass pass;
-  pass.walk(const_cast<MLFunction *>(this));
+  // Explicitly erase statements instead of relying on 'StmtBlock' destructor
+  // since child statements need to be destroyed before function arguments
+  // are destroyed.
+  clear();
+
+  // Explicitly run the destructors for the function arguments.
+  for (auto &arg : getArgumentsInternal())
+    arg.~MLFuncArgument();
+}
+
+void MLFunction::destroy() {
+  this->~MLFunction();
+  free(this);
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 44e44c8..0a76587 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -221,14 +221,6 @@
   return findFunction()->getContext();
 }
 
-/// This drops all operand uses from this statement, which is an essential
-/// step in breaking cyclic dependences between references when they are to
-/// be deleted.
-void OperationStmt::dropAllReferences() {
-  for (auto &op : getStmtOperands())
-    op.drop();
-}
-
 //===----------------------------------------------------------------------===//
 // ForStmt
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 45df86a..ee8c68f 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -332,7 +332,8 @@
   HashTable::ScopeTy topScope(liveValues);
 
   // All of the arguments to the function are live for the whole function.
-  // TODO: Add arguments when they are supported.
+  for (auto *arg : fn.getArguments())
+    liveValues.insert(arg, true);
 
   // This recursive function walks the statement list pushing scopes onto the
   // stack as it goes, and popping them to remove them from the table.
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index ff3811f..e84cb86 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2371,7 +2371,7 @@
     if (getToken().isNot(Token::percent_identifier))
       return emitError("expected SSA identifier");
 
-    StringRef name = getTokenSpelling().drop_front();
+    StringRef name = getTokenSpelling();
     consumeToken(Token::percent_identifier);
     argNames.push_back(name);
 
@@ -2474,16 +2474,24 @@
   StringRef name;
   FunctionType *type = nullptr;
   SmallVector<StringRef, 4> argNames;
-  // FIXME: Parse ML function signature (args + types)
-  // by passing pointer to SmallVector<identifier> into parseFunctionSignature
 
+  auto loc = getToken().getLoc();
   if (parseFunctionSignature(name, type, &argNames))
     return ParseFailure;
 
   // Okay, the ML function signature was parsed correctly, create the function.
-  auto function = new MLFunction(name, type);
+  auto function = MLFunction::create(name, type);
 
-  return MLFunctionParser(getState(), function).parseFunctionBody();
+  // Create the parser.
+  auto parser = MLFunctionParser(getState(), function);
+
+  // Add definitions of the function arguments.
+  for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+    if (parser.addDefinition({argNames[i], 0, loc}, function->getArgument(i)))
+      return ParseFailure;
+  }
+
+  return parser.parseFunctionBody();
 }
 
 /// This is the top-level module parser.
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index b2d1a3d..7c218f7 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -107,8 +107,15 @@
   return     // CHECK:  return
 }            // CHECK: }
 
-// CHECK-LABEL: mlfunc @mlfunc_with_args(f16) {
-mlfunc @mlfunc_with_args(%a : f16) {
+// CHECK-LABEL: mlfunc @mlfunc_with_one_arg(%arg0 : i1) {
+mlfunc @mlfunc_with_one_arg(%c : i1) {
+  // CHECK: %0 = "foo"(%arg0) : (i1) -> i2
+  %b = "foo"(%c) : (i1) -> (i2)
+  return     // CHECK: return
+}
+
+// CHECK-LABEL: mlfunc @mlfunc_with_args(%arg0 : f16, %arg1 : i8) {
+mlfunc @mlfunc_with_args(%a : f16, %b : i8) {
   return  %a  // CHECK: return
 }
 
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index aef6a01..f091ba1 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -101,7 +101,7 @@
 
 
 // Imperfect loop nest. Unrolling innermost here yields a perfect nest.
-// CHECK-LABEL: mlfunc @loop_nest_seq_imperfect(memref<128x128xf32>) {
+// CHECK-LABEL: mlfunc @loop_nest_seq_imperfect(%arg0 : memref<128x128xf32>) {
 mlfunc @loop_nest_seq_imperfect(%a : memref<128x128xf32>) {
   // CHECK: %c1_i32 = constant 1 : i32
   // CHECK-NEXT: %c2_i32 = constant 2 : i32