Implement initial support for function attributes, including parser, printer,
resolver support.

Still TODO are verifier support (to make sure you don't use an attribute for a
function in another module) and the TODO in ModuleParser::finalizeModule that I
will handle in the next patch.

PiperOrigin-RevId: 209361648
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index 16b99fb..8ccc73e 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -22,8 +22,9 @@
 #include "llvm/ADT/ArrayRef.h"
 
 namespace mlir {
-class MLIRContext;
 class AffineMap;
+class Function;
+class MLIRContext;
 class Type;
 
 /// Instances of the Attribute class are immutable, uniqued, immortal, and owned
@@ -38,7 +39,7 @@
     Type,
     Array,
     AffineMap,
-    // TODO: Function references.
+    Function,
   };
 
   /// Return the classification for this attribute.
@@ -192,6 +193,34 @@
   Type *value;
 };
 
+/// A function attribute represents a reference to a function object.
+///
+/// When working with IR, it is important to know that a function attribute can
+/// exist with a null Function inside of it, which occurs when a function object
+/// is deleted that had an attribute which referenced it.  No references to this
+/// attribute should persist across the transformation, but that attribute will
+/// remain in MLIRContext.
+class FunctionAttr : public Attribute {
+public:
+  static FunctionAttr *get(Function *value, MLIRContext *context);
+
+  Function *getValue() const { return value; }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Attribute *attr) {
+    return attr->getKind() == Kind::Function;
+  }
+
+  /// This function is used by the internals of the Function class to null out
+  /// attributes refering to functions that are about to be deleted.
+  static void dropFunctionReference(Function *value);
+
+private:
+  FunctionAttr(Function *value) : Attribute(Kind::Function), value(value) {}
+  ~FunctionAttr() = delete;
+  Function *value;
+};
+
 } // end namespace mlir.
 
 #endif
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 3324d4b..62c0dcb 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -39,6 +39,7 @@
 class StringAttr;
 class TypeAttr;
 class ArrayAttr;
+class FunctionAttr;
 class AffineMapAttr;
 class AffineMap;
 class AffineExpr;
@@ -85,6 +86,7 @@
   ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
   AffineMapAttr *getAffineMapAttr(AffineMap *value);
   TypeAttr *getTypeAttr(Type *type);
+  FunctionAttr *getFunctionAttr(Function *value);
 
   // Affine Expressions and Affine Map.
   AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
diff --git a/include/mlir/IR/Function.h b/include/mlir/IR/Function.h
index 15e08e3..f3b7aa0 100644
--- a/include/mlir/IR/Function.h
+++ b/include/mlir/IR/Function.h
@@ -66,7 +66,7 @@
 
 protected:
   Function(StringRef name, FunctionType *type, Kind kind);
-  ~Function() {}
+  ~Function();
 
 private:
   Kind kind;
diff --git a/include/mlir/IR/Identifier.h b/include/mlir/IR/Identifier.h
index 70528e3..af16636 100644
--- a/include/mlir/IR/Identifier.h
+++ b/include/mlir/IR/Identifier.h
@@ -40,10 +40,13 @@
   Identifier &operator=(const Identifier &other) = default;
 
   /// Return a StringRef for the string.
-  StringRef ref() const { return StringRef(pointer, size()); }
+  StringRef strref() const { return StringRef(pointer, size()); }
+
+  /// Identifiers implicitly convert to StringRefs.
+  operator StringRef() const { return strref(); }
 
   /// Return an std::string.
-  std::string str() const { return ref().str(); }
+  std::string str() const { return strref().str(); }
 
   /// Return a null terminated C string.
   const char *c_str() const { return pointer; }
@@ -59,7 +62,7 @@
   }
 
   /// Return true if this identifier is the specified string.
-  bool is(StringRef string) const { return ref().equals(string); }
+  bool is(StringRef string) const { return strref().equals(string); }
 
   const char *begin() const { return pointer; }
   const char *end() const { return pointer + size(); }
@@ -100,7 +103,7 @@
 
 // Make identifiers hashable.
 inline llvm::hash_code hash_value(Identifier arg) {
-  return llvm::hash_value(arg.ref());
+  return llvm::hash_value(arg.strref());
 }
 
 } // end namespace mlir
diff --git a/include/mlir/IR/Module.h b/include/mlir/IR/Module.h
index f6c543b..38c4ef8 100644
--- a/include/mlir/IR/Module.h
+++ b/include/mlir/IR/Module.h
@@ -60,14 +60,14 @@
   // Interfaces for working with the symbol table.
 
   /// Look up a function with the specified name, returning null if no such
-  /// name exists.
+  /// name exists.  Function names never include the @ on them.
   Function *getNamedFunction(StringRef name);
   const Function *getNamedFunction(StringRef name) const {
     return const_cast<Module *>(this)->getNamedFunction(name);
   }
 
   /// Look up a function with the specified name, returning null if no such
-  /// name exists.
+  /// name exists.  Function names never include the @ on them.
   Function *getNamedFunction(Identifier name);
 
   const Function *getNamedFunction(Identifier name) const {
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 896b76e..8303f57 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -415,6 +415,16 @@
   case Attribute::Kind::Type:
     printType(cast<TypeAttr>(attr)->getValue());
     break;
+  case Attribute::Kind::Function: {
+    auto *function = cast<FunctionAttr>(attr)->getValue();
+    if (!function) {
+      os << "<<FUNCTION ATTR FOR DELETED FUNCTION>>";
+    } else {
+      os << '@' << function->getName() << " : ";
+      printType(function->getType());
+    }
+    break;
+  }
   }
 }
 
@@ -784,6 +794,11 @@
         }
       } else if (auto intOp = op->getAs<ConstantAffineIntOp>()) {
         specialName << 'c' << intOp->getValue();
+      } else if (auto constant = op->getAs<ConstantOp>()) {
+        if (isa<FunctionAttr>(constant->getValue()))
+          specialName << 'f';
+        else
+          specialName << "cst";
       }
     }
 
@@ -909,7 +924,7 @@
   // Filter out any attributes that shouldn't be included.
   SmallVector<NamedAttribute, 8> filteredAttrs;
   for (auto attr : attrs) {
-    auto attrName = attr.first.ref();
+    auto attrName = attr.first.strref();
     // Never print attributes that start with a colon.  These are internal
     // attributes that represent location or other internal metadata.
     if (attrName.startswith(":"))
@@ -946,7 +961,7 @@
 
   // Check to see if this is a known operation.  If so, use the registered
   // custom printer hook.
-  if (auto *opInfo = state.operationSet->lookup(op->getName().ref())) {
+  if (auto *opInfo = state.operationSet->lookup(op->getName())) {
     opInfo->printAssembly(op, this);
     return;
   }
@@ -957,7 +972,7 @@
 
 void FunctionPrinter::printDefaultOp(const Operation *op) {
   os << '"';
-  printEscapedString(op->getName().ref(), os);
+  printEscapedString(op->getName(), os);
   os << "\"(";
 
   interleaveComma(op->getOperands(),
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index be334ef..5fce94c 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -111,6 +111,10 @@
   return TypeAttr::get(type, context);
 }
 
+FunctionAttr *Builder::getFunctionAttr(Function *value) {
+  return FunctionAttr::get(value, context);
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions, Affine Maps, and Integet Sets.
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/Function.cpp b/lib/IR/Function.cpp
index df3d7c8..bfcfd6f 100644
--- a/lib/IR/Function.cpp
+++ b/lib/IR/Function.cpp
@@ -15,6 +15,7 @@
 // limitations under the License.
 // =============================================================================
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Module.h"
@@ -27,6 +28,11 @@
 Function::Function(StringRef name, FunctionType *type, Kind kind)
     : kind(kind), name(Identifier::get(name, type->getContext())), type(type) {}
 
+Function::~Function() {
+  // Clean up function attributes referring to this function.
+  FunctionAttr::dropFunctionReference(this);
+}
+
 MLIRContext *Function::getContext() const { return getType()->getContext(); }
 
 /// Delete this object.
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 644a74a..a960223 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Function.h"
 #include "mlir/IR/Identifier.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/OperationSet.h"
@@ -254,6 +255,7 @@
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
+  DenseMap<Function *, FunctionAttr *> functionAttrs;
 
 public:
   MLIRContextImpl() : identifiers(allocator) {
@@ -633,6 +635,34 @@
   return result;
 }
 
+FunctionAttr *FunctionAttr::get(Function *value, MLIRContext *context) {
+  auto *&result = context->getImpl().functionAttrs[value];
+  if (result)
+    return result;
+
+  result = context->getImpl().allocator.Allocate<FunctionAttr>();
+  new (result) FunctionAttr(value);
+  return result;
+}
+
+/// This function is used by the internals of the Function class to null out
+/// attributes refering to functions that are about to be deleted.
+void FunctionAttr::dropFunctionReference(Function *value) {
+  // Check to see if there was an attribute referring to this function.
+  auto &functionAttrs = value->getContext()->getImpl().functionAttrs;
+
+  // If not, then we're done.
+  auto it = functionAttrs.find(value);
+  if (it == functionAttrs.end())
+    return;
+
+  // If so, null out the function reference in the attribute (to avoid dangling
+  // pointers) and remove the entry from the map so the map doesn't contain
+  // dangling keys.
+  it->second->value = nullptr;
+  functionAttrs.erase(it);
+}
+
 /// Perform a three-way comparison between the names of the specified
 /// NamedAttributes.
 static int compareNamedAttributes(const NamedAttribute *lhs,
diff --git a/lib/IR/Module.cpp b/lib/IR/Module.cpp
index 5a33d26..36b892a 100644
--- a/lib/IR/Module.cpp
+++ b/lib/IR/Module.cpp
@@ -21,13 +21,13 @@
 Module::Module(MLIRContext *context) : context(context) {}
 
 /// Look up a function with the specified name, returning null if no such
-/// name exists.
+/// name exists.  Function names never include the @ on them.
 Function *Module::getNamedFunction(StringRef name) {
   return getNamedFunction(Identifier::get(name, context));
 }
 
 /// Look up a function with the specified name, returning null if no such
-/// name exists.
+/// name exists.  Function names never include the @ on them.
 Function *Module::getNamedFunction(Identifier name) {
   auto it = symbolTable.find(name);
   return it != symbolTable.end() ? it->second : nullptr;
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 723dcc2..6d8c366 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -204,16 +204,25 @@
 void ConstantOp::print(OpAsmPrinter *p) const {
   *p << "constant " << *getValue();
   p->printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/"value");
-  *p << " : " << *getType();
+
+  if (!isa<FunctionAttr>(getValue()))
+    *p << " : " << *getType();
 }
 
 bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
   Attribute *valueAttr;
   Type *type;
 
-  return parser->parseAttribute(valueAttr, "value", result->attributes) ||
-         parser->parseOptionalAttributeDict(result->attributes) ||
-         parser->parseColonType(type) ||
+  if (parser->parseAttribute(valueAttr, "value", result->attributes) ||
+      parser->parseOptionalAttributeDict(result->attributes))
+    return true;
+
+  // 'constant' taking a function reference doesn't get a redundant type
+  // specifier.  The attribute itself carries it.
+  if (auto *fnAttr = dyn_cast<FunctionAttr>(valueAttr))
+    return parser->addTypeToList(fnAttr->getValue()->getType(), result->types);
+
+  return parser->parseColonType(type) ||
          parser->addTypeToList(type, result->types);
 }
 
@@ -244,7 +253,9 @@
   }
 
   if (isa<FunctionType>(type)) {
-    // TODO: Verify a function attr.
+    if (!isa<FunctionAttr>(value))
+      return "requires 'value' to be a function reference";
+    return nullptr;
   }
 
   return "requires a result type that aligns with the 'value' attribute";
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index ae87ad2..5310daa 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -58,9 +58,15 @@
 
   // A map from affine map identifier to AffineMap.
   llvm::StringMap<AffineMap *> affineMapDefinitions;
+
   // A map from integer set identifier to IntegerSet.
   llvm::StringMap<IntegerSet *> integerSetDefinitions;
 
+  // This keeps track of all forward references to functions along with the
+  // temporary function used to represent them and the location of the first
+  // reference.
+  llvm::DenseMap<Identifier, std::pair<Function *, SMLoc>> functionForwardRefs;
+
 private:
   ParserState(const ParserState &) = delete;
   void operator=(const ParserState &) = delete;
@@ -579,6 +585,7 @@
 ///                    | string-literal
 ///                    | type
 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
+///                    | function-id `:` function-type
 ///
 Attribute *Parser::parseAttribute() {
   switch (getToken().getKind()) {
@@ -653,6 +660,42 @@
       return builder.getAffineMapAttr(affineMap);
     return (emitError("expected constant attribute value"), nullptr);
   }
+
+  case Token::at_identifier: {
+    auto nameLoc = getToken().getLoc();
+    Identifier name = builder.getIdentifier(getTokenSpelling().drop_front());
+    consumeToken(Token::at_identifier);
+
+    if (parseToken(Token::colon, "expected ':' and function type"))
+      return nullptr;
+    auto typeLoc = getToken().getLoc();
+    Type *type = parseType();
+    if (!type)
+      return nullptr;
+    auto fnType = dyn_cast<FunctionType>(type);
+    if (!fnType)
+      return (emitError(typeLoc, "expected function type"), nullptr);
+
+    // See if the function has already been defined in the module.
+    Function *function = getModule()->getNamedFunction(name);
+
+    // If not, get or create a forward reference to one.
+    if (!function) {
+      auto &entry = state.functionForwardRefs[name];
+      if (!entry.first) {
+        entry.first = new ExtFunction(name, fnType);
+        entry.second = nameLoc;
+      }
+      function = entry.first;
+    }
+
+    if (function->getType() != type)
+      return (emitError(typeLoc, "reference to function with mismatched type"),
+              nullptr);
+
+    return builder.getFunctionAttr(function);
+  }
+
   default: {
     if (Type *type = parseType())
       return builder.getTypeAttr(type);
@@ -2426,6 +2469,8 @@
   ParseResult parseModule();
 
 private:
+  ParseResult finalizeModule();
+
   ParseResult parseAffineMapDef();
   ParseResult parseIntegerSetDef();
 
@@ -2587,7 +2632,7 @@
   getModule()->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
-  if (function->getName().ref() != name)
+  if (function->getName() != name)
     return emitError(loc,
                      "redefinition of function named '" + name.str() + "'");
 
@@ -2612,7 +2657,7 @@
   getModule()->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
-  if (function->getName().ref() != name)
+  if (function->getName() != name)
     return emitError(loc,
                      "redefinition of function named '" + name.str() + "'");
 
@@ -2639,7 +2684,7 @@
   getModule()->getFunctions().push_back(function);
 
   // Verify no name collision / redefinition.
-  if (function->getName().ref() != name)
+  if (function->getName() != name)
     return emitError(loc,
                      "redefinition of function named '" + name.str() + "'");
 
@@ -2655,6 +2700,27 @@
   return parser.parseFunctionBody();
 }
 
+/// Finish the end of module parsing - when the result is valid, do final
+/// checking.
+ParseResult ModuleParser::finalizeModule() {
+
+  // Resolve all forward references.
+  for (auto forwardRef : getState().functionForwardRefs) {
+    auto name = forwardRef.first;
+
+    // Resolve the reference.
+    auto *resolvedFunction = getModule()->getNamedFunction(name);
+    if (!resolvedFunction)
+      return emitError(forwardRef.second.second,
+                       "reference to undefined function '" + name.str() + "'");
+
+    // TODO(clattner): actually go through and update references in the module
+    // to the new function.
+  }
+
+  return ParseSuccess;
+}
+
 /// This is the top-level module parser.
 ParseResult ModuleParser::parseModule() {
   while (1) {
@@ -2665,7 +2731,7 @@
 
       // If we got to the end of the file, then we're done.
     case Token::eof:
-      return ParseSuccess;
+      return finalizeModule();
 
     // If we got an error token, then the lexer already emitted an error, just
     // stop.  Someday we could introduce error recovery if there was demand for
diff --git a/test/IR/core-ops.mlir b/test/IR/core-ops.mlir
index 626654a..2897129 100644
--- a/test/IR/core-ops.mlir
+++ b/test/IR/core-ops.mlir
@@ -49,8 +49,15 @@
   // CHECK: %c43 = constant 43 {crazy: "foo"} : affineint
   %8 = constant 43 {crazy: "foo"} : affineint
 
-  // CHECK: %4 = constant 4.300000e+01 : bf16
+  // CHECK: %cst = constant 4.300000e+01 : bf16
   %9 = constant 43.0 : bf16
+
+  // CHECK: %f = constant @cfgfunc_with_ops : (f32) -> ()
+  %10 = constant @cfgfunc_with_ops : (f32) -> ()
+
+  // CHECK: %f_1 = constant @affine_apply : () -> ()
+  %11 = constant @affine_apply : () -> ()
+
   return
 }
 
diff --git a/test/IR/invalid.mlir b/test/IR/invalid.mlir
index 9edc152..9defcb0 100644
--- a/test/IR/invalid.mlir
+++ b/test/IR/invalid.mlir
@@ -402,3 +402,19 @@
 
 extfunc @redef()
 extfunc @redef()  // expected-error {{redefinition of function named 'redef'}}
+
+// -----
+
+cfgfunc @foo() {
+bb0:
+  %x = constant @foo : (i32) -> ()  // expected-error {{reference to function with mismatched type}}
+  return
+}
+
+// -----
+
+cfgfunc @foo() {
+bb0:
+  %x = constant @bar : (i32) -> ()  // expected-error {{reference to undefined function 'bar'}}
+  return
+}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index a5eb9e6..fe6cbb2 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -215,6 +215,8 @@
   // CHECK: "foo"() {cfgfunc: [], d: 1.000000e-09, i123: 7, if: "foo"} : () -> ()
   "foo"() {if: "foo", cfgfunc: [], i123: 7, d: 1.e-9} : () -> ()
 
+  // CHECK: "foo"() {fn: @attributes : () -> (), if: @ifstmt : (i32) -> ()} : () -> ()
+  "foo"() {fn: @attributes : () -> (), if: @ifstmt : (i32) -> ()} : () -> ()
   return
 }