Implement ML function arguments. Add a representation for the argument list in MLFunction using the TrailingObjects template. Implement argument iterators, parsing, and printing.

Unrelated minor change: remove OperationStmt::dropAllReferences(). Since an MLFunction does not have cyclic operand references (it's an AST), destruction can safely be done without a special pass to drop references.

PiperOrigin-RevId: 207583024
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6875120..cb7560e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -200,6 +200,7 @@
   ModuleState &state;
 
   void printFunctionSignature(const Function *fn);
+  void printFunctionResultType(const FunctionType *type);
   void printAffineMapId(int affineMapId) const;
   void printAffineMapReference(const AffineMap *affineMap);
 
@@ -527,14 +528,7 @@
 // Function printing
 //===----------------------------------------------------------------------===//
 
-void ModulePrinter::printFunctionSignature(const Function *fn) {
-  auto type = fn->getType();
-
-  os << "@" << fn->getName() << '(';
-  interleaveComma(type->getInputs(),
-                  [&](Type *eltType) { printType(eltType); });
-  os << ')';
-
+void ModulePrinter::printFunctionResultType(const FunctionType *type) {
   switch (type->getResults().size()) {
   case 0:
     break;
@@ -551,6 +545,17 @@
   }
 }
 
+void ModulePrinter::printFunctionSignature(const Function *fn) {
+  auto type = fn->getType();
+
+  os << "@" << fn->getName() << '(';
+  interleaveComma(type->getInputs(),
+                  [&](Type *eltType) { printType(eltType); });
+  os << ')';
+
+  printFunctionResultType(type);
+}
+
 void ModulePrinter::print(const ExtFunction *fn) {
   os << "extfunc ";
   printFunctionSignature(fn);
@@ -627,7 +632,7 @@
         // done with it.
         valueIDs[value] = nextValueID++;
         return;
-      case SSAValueKind::FnArgument:
+      case SSAValueKind::MLFuncArgument:
         specialName << "arg" << nextArgumentID++;
         break;
       case SSAValueKind::ForStmt:
@@ -1018,6 +1023,9 @@
   // Prints ML function
   void print();
 
+  // Prints ML function signature
+  void printFunctionSignature();
+
   // Methods to print ML function statements
   void print(const Statement *stmt);
   void print(const OperationStmt *stmt);
@@ -1039,12 +1047,18 @@
 MLFunctionPrinter::MLFunctionPrinter(const MLFunction *function,
                                      const ModulePrinter &other)
     : FunctionPrinter(other), function(function), numSpaces(0) {
+  assert(function && "Cannot print nullptr function");
   numberValues();
 }
 
 /// Number all of the SSA values in this ML function.
 void MLFunctionPrinter::numberValues() {
-  // Visits all operation statements and numbers the first result.
+  // Numbers ML function arguments
+  for (auto *arg : function->getArguments())
+    numberValueID(arg);
+
+  // Walks the ML function statements, numbering each 'for' statement and
+  // the first result of each operation statement.
   struct NumberValuesPass : public StmtWalker<NumberValuesPass> {
     NumberValuesPass(MLFunctionPrinter *printer) : printer(printer) {}
     void visitOperationStmt(OperationStmt *stmt) {
@@ -1062,14 +1076,30 @@
 
 void MLFunctionPrinter::print() {
   os << "mlfunc ";
-  // FIXME: should print argument names rather than just signature
-  printFunctionSignature(function);
+  printFunctionSignature();
   os << " {\n";
   print(function);
   os << "  return\n";
   os << "}\n\n";
 }
 
+void MLFunctionPrinter::printFunctionSignature() {
+  auto type = function->getType();
+
+  os << "@" << function->getName() << '(';
+
+  for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+    if (i > 0)
+      os << ", ";
+    auto *arg = function->getArgument(i);
+    printOperand(arg);
+    os << " : ";
+    printType(arg->getType());
+  }
+  os << ")";
+  printFunctionResultType(type);
+}
+
 void MLFunctionPrinter::print(const StmtBlock *block) {
   numSpaces += indentWidth;
   for (auto &stmt : block->getStatements()) {
@@ -1169,7 +1199,7 @@
     return;
   case SSAValueKind::InstResult:
     return getDefiningInst()->print(os);
-  case SSAValueKind::FnArgument:
+  case SSAValueKind::MLFuncArgument:
     // TODO: Improve this.
     os << "<function argument>\n";
     return;
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index 96c6dfc..4719b35 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -36,7 +36,7 @@
     delete cast<ExtFunction>(this);
     break;
   case Kind::MLFunc:
-    delete cast<MLFunction>(this);
+    cast<MLFunction>(this)->destroy();
     break;
   case Kind::CFGFunc:
     delete cast<CFGFunction>(this);
@@ -118,13 +118,37 @@
 // MLFunction implementation.
 //===----------------------------------------------------------------------===//
 
+/// Create a new MLFunction with the given name and function type.
+MLFunction *MLFunction::create(StringRef name, FunctionType *type) {
+  const auto &argTypes = type->getInputs();
+  auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
+  void *rawMem = malloc(byteSize);
+
+  // Initialize the MLFunction part of the function object.
+  auto function = ::new (rawMem) MLFunction(name, type);
+
+  // Initialize the arguments.
+  auto arguments = function->getArgumentsInternal();
+  for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
+    new (&arguments[i]) MLFuncArgument(argTypes[i], function);
+  return function;
+}
+
 MLFunction::MLFunction(StringRef name, FunctionType *type)
     : Function(name, type, Kind::MLFunc), StmtBlock(StmtBlockKind::MLFunc) {}
 
 MLFunction::~MLFunction() {
-  struct DropReferencesPass : public StmtWalker<DropReferencesPass> {
-    void visitOperationStmt(OperationStmt *stmt) { stmt->dropAllReferences(); }
-  };
-  DropReferencesPass pass;
-  pass.walk(const_cast<MLFunction *>(this));
+  // Explicitly erase statements instead of relying on the 'StmtBlock'
+  // destructor, since child statements need to be destroyed before the
+  // function arguments are destroyed.
+  clear();
+
+  // Explicitly run the destructors for the function arguments.
+  for (auto &arg : getArgumentsInternal())
+    arg.~MLFuncArgument();
+}
+
+void MLFunction::destroy() {
+  this->~MLFunction();
+  free(this);
 }
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 44e44c8..0a76587 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -221,14 +221,6 @@
   return findFunction()->getContext();
 }
 
-/// This drops all operand uses from this statement, which is an essential
-/// step in breaking cyclic dependences between references when they are to
-/// be deleted.
-void OperationStmt::dropAllReferences() {
-  for (auto &op : getStmtOperands())
-    op.drop();
-}
-
 //===----------------------------------------------------------------------===//
 // ForStmt
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index 45df86a..ee8c68f 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -332,7 +332,8 @@
   HashTable::ScopeTy topScope(liveValues);
 
   // All of the arguments to the function are live for the whole function.
-  // TODO: Add arguments when they are supported.
+  for (auto *arg : fn.getArguments())
+    liveValues.insert(arg, true);
 
   // This recursive function walks the statement list pushing scopes onto the
   // stack as it goes, and popping them to remove them from the table.
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index ff3811f..e84cb86 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2371,7 +2371,7 @@
     if (getToken().isNot(Token::percent_identifier))
       return emitError("expected SSA identifier");
 
-    StringRef name = getTokenSpelling().drop_front();
+    StringRef name = getTokenSpelling();
     consumeToken(Token::percent_identifier);
     argNames.push_back(name);
 
@@ -2474,16 +2474,24 @@
   StringRef name;
   FunctionType *type = nullptr;
   SmallVector<StringRef, 4> argNames;
-  // FIXME: Parse ML function signature (args + types)
-  // by passing pointer to SmallVector<identifier> into parseFunctionSignature
 
+  auto loc = getToken().getLoc();
   if (parseFunctionSignature(name, type, &argNames))
     return ParseFailure;
 
   // Okay, the ML function signature was parsed correctly, create the function.
-  auto function = new MLFunction(name, type);
+  auto function = MLFunction::create(name, type);
 
-  return MLFunctionParser(getState(), function).parseFunctionBody();
+  // Create the parser.
+  auto parser = MLFunctionParser(getState(), function);
+
+  // Add definitions of the function arguments.
+  for (unsigned i = 0, e = function->getNumArguments(); i != e; ++i) {
+    if (parser.addDefinition({argNames[i], 0, loc}, function->getArgument(i)))
+      return ParseFailure;
+  }
+
+  return parser.parseFunctionBody();
 }
 
 /// This is the top-level module parser.