Implement call and call_indirect ops.

This also fixes an infinite recursion in VariadicOperands that this change exposed.

PiperOrigin-RevId: 209692932
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6bcaad4..cd90e3e 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -247,6 +247,7 @@
   }
 
   void print(const Module *module);
+  void printFunctionReference(const Function *func);
   void printAttribute(const Attribute *attr);
   void printType(const Type *type);
   void print(const Function *fn);
@@ -387,6 +388,10 @@
   }
 }
 
+void ModulePrinter::printFunctionReference(const Function *func) {
+  os << '@' << func->getName();
+}
+
 void ModulePrinter::printAttribute(const Attribute *attr) {
   switch (attr->getKind()) {
   case Attribute::Kind::Bool:
@@ -420,7 +425,8 @@
     if (!function) {
       os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
     } else {
-      os << '@' << function->getName() << " : ";
+      printFunctionReference(function);
+      os << " : ";
       printType(function->getType());
     }
     break;
@@ -768,6 +774,9 @@
   void printAffineExpr(const AffineExpr *expr) {
     return ModulePrinter::printAffineExpr(expr);
   }
+  void printFunctionReference(const Function *func) {
+    return ModulePrinter::printFunctionReference(func);
+  }
 
   void printOperand(const SSAValue *value) { printValueID(value); }
 
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index b1dce25..73beea6 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -111,7 +111,7 @@
   return TypeAttr::get(type, context);
 }
 
-FunctionAttr *Builder::getFunctionAttr(Function *value) {
+FunctionAttr *Builder::getFunctionAttr(const Function *value) {
   return FunctionAttr::get(value, context);
 }
 
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 68839d8..e59b59c 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -255,7 +255,7 @@
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
-  DenseMap<Function *, FunctionAttr *> functionAttrs;
+  DenseMap<const Function *, FunctionAttr *> functionAttrs;
 
 public:
   MLIRContextImpl() : identifiers(allocator) {
@@ -648,16 +648,20 @@
   return result;
 }
 
-FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
+FunctionAttr *FunctionAttr::get(const Function *value, MLIRContext *context) {
+  assert(value && "Cannot get FunctionAttr for a null function");
+
   auto *&result = context->getImpl().functionAttrs[value];
   if (result)
     return result;
 
   result = context->getImpl().allocator.Allocate<FunctionAttr>();
-  new (result) FunctionAttr(value);
+  new (result) FunctionAttr(const_cast<Function *>(value));
   return result;
 }
 
+FunctionType *FunctionAttr::getType() const { return getValue()->getType(); }
+
 /// This function is used by the internals of the Function class to null out
 /// attributes refering to functions that are about to be deleted.
 void FunctionAttr::dropFunctionReference(Function *value) {
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index cec23e0..fb7dfe7 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -129,7 +129,7 @@
   // Check that affine map attribute was specified.
   auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
   if (!affineMapAttr)
-    return "requires an affine map.";
+    return "requires an affine map";
 
   // Check input and output dimensions match.
   auto *map = affineMapAttr->getValue();
@@ -198,7 +198,152 @@
 }
 
 //===----------------------------------------------------------------------===//
-// ConstantOp
+// CallOp
+//===----------------------------------------------------------------------===//
+
+OperationState CallOp::build(Builder *builder, Function *callee,
+                             ArrayRef<SSAValue *> operands) {
+  OperationState result(builder->getIdentifier("call"));
+  result.operands.append(operands.begin(), operands.end());
+  result.attributes.push_back(
+      {builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
+  result.types.append(callee->getType()->getResults().begin(),
+                      callee->getType()->getResults().end());
+  return result;
+}
+
+bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
+  StringRef calleeName;
+  llvm::SMLoc calleeLoc;
+  FunctionType *calleeType = nullptr;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  Function *callee = nullptr;
+  if (parser->parseFunctionName(calleeName, calleeLoc) ||
+      parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
+                               OpAsmParser::Delimiter::Paren) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(calleeType) ||
+      parser->resolveFunctionName(calleeName, calleeType, calleeLoc, callee) ||
+      parser->addTypesToList(calleeType->getResults(), result->types) ||
+      parser->resolveOperands(operands, calleeType->getInputs(), calleeLoc,
+                              result->operands))
+    return true;
+
+  auto &builder = parser->getBuilder();
+  result->attributes.push_back(
+      {builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
+
+  return false;
+}
+
+void CallOp::print(OpAsmPrinter *p) const {
+  *p << "call ";
+  p->printFunctionReference(getCallee());
+  *p << '(';
+  p->printOperands(getOperands());
+  *p << ')';
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
+  *p << " : " << *getCallee()->getType();
+}
+
+const char *CallOp::verify() const {
+  // Check that the callee attribute was specified.
+  auto *fnAttr = getAttrOfType<FunctionAttr>("callee");
+  if (!fnAttr)
+    return "requires a 'callee' function attribute";
+
+  // Verify that the operand and result types match the callee.
+  auto *fnType = fnAttr->getValue()->getType();
+  if (fnType->getNumInputs() != getNumOperands())
+    return "incorrect number of operands for callee";
+
+  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
+    if (getOperand(i)->getType() != fnType->getInput(i))
+      return "operand type mismatch";
+  }
+
+  if (fnType->getNumResults() != getNumResults())
+    return "incorrect number of results for callee";
+
+  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType->getResult(i))
+      return "result type mismatch";
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// CallIndirectOp
+//===----------------------------------------------------------------------===//
+
+OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
+                                     ArrayRef<SSAValue *> operands) {
+  auto *fnType = cast<FunctionType>(callee->getType());
+
+  OperationState result(builder->getIdentifier("call_indirect"));
+  result.operands.push_back(callee);
+  result.operands.append(operands.begin(), operands.end());
+  result.types.append(fnType->getResults().begin(), fnType->getResults().end());
+  return result;
+}
+
+bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
+  FunctionType *calleeType = nullptr;
+  OpAsmParser::OperandType callee;
+  llvm::SMLoc operandsLoc;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  return parser->parseOperand(callee) ||
+         parser->getCurrentLocation(&operandsLoc) ||
+         parser->parseOperandList(operands, /*requiredOperandCount=*/-1,
+                                  OpAsmParser::Delimiter::Paren) ||
+         parser->parseOptionalAttributeDict(result->attributes) ||
+         parser->parseColonType(calleeType) ||
+         parser->resolveOperand(callee, calleeType, result->operands) ||
+         parser->resolveOperands(operands, calleeType->getInputs(), operandsLoc,
+                                 result->operands) ||
+         parser->addTypesToList(calleeType->getResults(), result->types);
+}
+
+void CallIndirectOp::print(OpAsmPrinter *p) const {
+  *p << "call_indirect ";
+  p->printOperand(getCallee());
+  *p << '(';
+  auto operandRange = getOperands();
+  p->printOperands(++operandRange.begin(), operandRange.end());
+  *p << ')';
+  p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"callee");
+  *p << " : " << *getCallee()->getType();
+}
+
+const char *CallIndirectOp::verify() const {
+  // The callee must be a function.
+  auto *fnType = dyn_cast<FunctionType>(getCallee()->getType());
+  if (!fnType)
+    return "callee must have function type";
+
+  // Verify that the operand and result types match the callee.
+  if (fnType->getNumInputs() != getNumOperands() - 1)
+    return "incorrect number of operands for callee";
+
+  for (unsigned i = 0, e = fnType->getNumInputs(); i != e; ++i) {
+    if (getOperand(i + 1)->getType() != fnType->getInput(i))
+      return "operand type mismatch";
+  }
+
+  if (fnType->getNumResults() != getNumResults())
+    return "incorrect number of results for callee";
+
+  for (unsigned i = 0, e = fnType->getNumResults(); i != e; ++i) {
+    if (getResult(i)->getType() != fnType->getResult(i))
+      return "result type mismatch";
+  }
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
+// Constant*Op
 //===----------------------------------------------------------------------===//
 
 void ConstantOp::print(OpAsmPrinter *p) const {
@@ -444,10 +589,10 @@
 bool ReturnOp::parse(OpAsmParser *parser, OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> opInfo;
   SmallVector<Type *, 2> types;
-
-  return parser->parseOperandList(opInfo, -1, OpAsmParser::Delimiter::None) ||
+  llvm::SMLoc loc;
+  return parser->getCurrentLocation(&loc) || parser->parseOperandList(opInfo) ||
          (!opInfo.empty() && parser->parseColonTypeList(types)) ||
-         parser->resolveOperands(opInfo, types, result->operands);
+         parser->resolveOperands(opInfo, types, loc, result->operands);
 }
 
 void ReturnOp::print(OpAsmPrinter *p) const {
@@ -541,7 +686,7 @@
 
 /// Install the standard operations in the specified operation set.
 void mlir::registerStandardOperations(OperationSet &opSet) {
-  opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DeallocOp,
-                      DimOp, LoadOp, ReturnOp, StoreOp>(
+  opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
+                      ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
       /*prefix=*/"");
 }
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 0debe2d..d2ea5bf 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -33,6 +33,7 @@
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/IR/Types.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/PrettyStackTrace.h"
 #include "llvm/Support/SourceMgr.h"
 using namespace mlir;
 using llvm::SMLoc;
@@ -180,6 +181,8 @@
   ParseResult parseTypeList(SmallVectorImpl<Type *> &elements);
 
   // Attribute parsing.
+  Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
+                                     FunctionType *type);
   Attribute *parseAttribute();
   ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
 
@@ -578,6 +581,33 @@
 // Attribute parsing.
 //===----------------------------------------------------------------------===//
 
+/// Given a parsed reference to a function name like @foo and a type that it
+/// corresponds to, resolve it to a concrete function object (possibly
+/// synthesizing a forward reference) or emit an error and return null on
+/// failure.
+Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
+                                           FunctionType *type) {
+  Identifier name = builder.getIdentifier(nameStr.drop_front());
+
+  // See if the function has already been defined in the module.
+  Function *function = getModule()->getNamedFunction(name);
+
+  // If not, get or create a forward reference to one.
+  if (!function) {
+    auto &entry = state.functionForwardRefs[name];
+    if (!entry.first) {
+      entry.first = new ExtFunction(name, type);
+      entry.second = nameLoc;
+    }
+    function = entry.first;
+  }
+
+  if (function->getType() != type)
+    return (emitError(nameLoc, "reference to function with mismatched type"),
+            nullptr);
+  return function;
+}
+
 /// Attribute parsing.
 ///
 ///  attribute-value ::= bool-literal
@@ -664,7 +694,7 @@
 
   case Token::at_identifier: {
     auto nameLoc = getToken().getLoc();
-    Identifier name = builder.getIdentifier(getTokenSpelling().drop_front());
+    auto nameStr = getTokenSpelling();
     consumeToken(Token::at_identifier);
 
     if (parseToken(Token::colon, "expected ':' and function type"))
@@ -673,28 +703,12 @@
     Type *type = parseType();
     if (!type)
       return nullptr;
-    auto fnType = dyn_cast<FunctionType>(type);
+    auto *fnType = dyn_cast<FunctionType>(type);
     if (!fnType)
       return (emitError(typeLoc, "expected function type"), nullptr);
 
-    // See if the function has already been defined in the module.
-    Function *function = getModule()->getNamedFunction(name);
-
-    // If not, get or create a forward reference to one.
-    if (!function) {
-      auto &entry = state.functionForwardRefs[name];
-      if (!entry.first) {
-        entry.first = new ExtFunction(name, fnType);
-        entry.second = nameLoc;
-      }
-      function = entry.first;
-    }
-
-    if (function->getType() != type)
-      return (emitError(typeLoc, "reference to function with mismatched type"),
-              nullptr);
-
-    return builder.getFunctionAttr(function);
+    auto *function = resolveFunctionReference(nameStr, nameLoc, fnType);
+    return function ? builder.getFunctionAttr(function) : nullptr;
   }
 
   default: {
@@ -1701,6 +1715,10 @@
   // High level parsing methods.
   //===--------------------------------------------------------------------===//
 
+  bool getCurrentLocation(llvm::SMLoc *loc) override {
+    *loc = parser.getToken().getLoc();
+    return false;
+  }
   bool parseComma(llvm::SMLoc *loc = nullptr) override {
     if (loc)
       *loc = parser.getToken().getLoc();
@@ -1753,6 +1771,19 @@
     return parser.parseAttributeDict(result) == ParseFailure;
   }
 
+  /// Parse a function name like '@foo' and return the name in a form that can
+  /// be passed to resolveFunctionName when a function type is available.
+  virtual bool parseFunctionName(StringRef &result, llvm::SMLoc &loc) {
+    loc = parser.getToken().getLoc();
+
+    if (parser.getToken().isNot(Token::at_identifier))
+      return emitError(loc, "expected function name");
+
+    result = parser.getTokenSpelling();
+    parser.consumeToken(Token::at_identifier);
+    return false;
+  }
+
   bool parseOperand(OperandType &result) override {
     FunctionParser::SSAUseInfo useInfo;
     if (parser.parseSSAUse(useInfo))
@@ -1822,6 +1853,13 @@
     return false;
   }
 
+  /// Resolve a parsed function name and a type into a function reference.
+  virtual bool resolveFunctionName(StringRef name, FunctionType *type,
+                                   llvm::SMLoc loc, Function *&result) {
+    result = parser.resolveFunctionReference(name, loc, type);
+    return result == nullptr;
+  }
+
   //===--------------------------------------------------------------------===//
   // Methods for interacting with the parser
   //===--------------------------------------------------------------------===//
@@ -1830,7 +1868,7 @@
 
   llvm::SMLoc getNameLoc() const override { return nameLoc; }
 
-  bool resolveOperand(OperandType operand, Type *type,
+  bool resolveOperand(const OperandType &operand, Type *type,
                       SmallVectorImpl<SSAValue *> &result) override {
     FunctionParser::SSAUseInfo operandInfo = {operand.name, operand.number,
                                               operand.location};
@@ -1872,6 +1910,11 @@
 
   consumeToken();
 
+  // If the custom op parser crashes, produce some indication to help debugging.
+  std::string opNameStr = opName.str();
+  llvm::PrettyStackTraceFormat fmt("MLIR Parser: custom op parser '%s'",
+                                   opNameStr.c_str());
+
   // Have the op implementation take a crack and parsing this.
   OperationState opState(builder.getIdentifier(opName));
   if (opDefinition->parseAssembly(&opAsmParser, &opState))