[mlir] Add a TypeAttr class, allow type attributes

PiperOrigin-RevId: 207235956
diff --git a/include/mlir/IR/Attributes.h b/include/mlir/IR/Attributes.h
index 7891b32..16b99fb 100644
--- a/include/mlir/IR/Attributes.h
+++ b/include/mlir/IR/Attributes.h
@@ -24,6 +24,7 @@
 namespace mlir {
 class MLIRContext;
 class AffineMap;
+class Type;
 
 /// Instances of the Attribute class are immutable, uniqued, immortal, and owned
 /// by MLIRContext.  As such, they are passed around by raw non-const pointer.
@@ -34,6 +35,7 @@
     Integer,
     Float,
     String,
+    Type,
     Array,
     AffineMap,
     // TODO: Function references.
@@ -173,6 +175,23 @@
   AffineMap *value;
 };
 
+class TypeAttr : public Attribute {
+public:
+  static TypeAttr *get(Type *type, MLIRContext *context);
+
+  Type *getValue() const { return value; }
+
+  /// Methods for support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Attribute *attr) {
+    return attr->getKind() == Kind::Type;
+  }
+
+private:
+  TypeAttr(Type *value) : Attribute(Kind::Type), value(value) {}
+  ~TypeAttr() = delete;
+  Type *value;
+};
+
 } // end namespace mlir.
 
 #endif
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 4f7161b..3e4bbd7 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -37,6 +37,7 @@
 class IntegerAttr;
 class FloatAttr;
 class StringAttr;
+class TypeAttr;
 class ArrayAttr;
 class AffineMapAttr;
 class AffineMap;
@@ -80,6 +81,7 @@
   StringAttr *getStringAttr(StringRef bytes);
   ArrayAttr *getArrayAttr(ArrayRef<Attribute *> value);
   AffineMapAttr *getAffineMapAttr(AffineMap *value);
+  TypeAttr *getTypeAttr(Type *type);
 
   // Affine Expressions and Affine Map.
   AffineMap *getAffineMap(unsigned dimCount, unsigned symbolCount,
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 6177e19..b2ebaa6 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -280,6 +280,9 @@
   case Attribute::Kind::AffineMap:
     printAffineMapReference(cast<AffineMapAttr>(attr)->getValue());
     break;
+  case Attribute::Kind::Type:
+    printType(cast<TypeAttr>(attr)->getValue());
+    break;
   }
 }
 
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 199f35d..715e460 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -100,6 +100,10 @@
   return AffineMapAttr::get(value, context);
 }
 
+TypeAttr *Builder::getTypeAttr(Type *type) {
+  return TypeAttr::get(type, context);
+}
+
 //===----------------------------------------------------------------------===//
 // Affine Expressions and Affine Map.
 //===----------------------------------------------------------------------===//
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index f558ae1..3edb115 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -249,6 +249,7 @@
   using ArrayAttrSet = DenseSet<ArrayAttr *, ArrayAttrKeyInfo>;
   ArrayAttrSet arrayAttrs;
   DenseMap<AffineMap *, AffineMapAttr *> affineMapAttrs;
+  DenseMap<Type *, TypeAttr *> typeAttrs;
   using AttributeListSet =
       DenseSet<AttributeListStorage *, AttributeListKeyInfo>;
   AttributeListSet attributeLists;
@@ -622,6 +623,16 @@
   return result;
 }
 
+TypeAttr *TypeAttr::get(Type *type, MLIRContext *context) {
+  auto *&result = context->getImpl().typeAttrs[type];
+  if (result)
+    return result;
+
+  result = context->getImpl().allocator.Allocate<TypeAttr>();
+  new (result) TypeAttr(type);
+  return result;
+}
+
 /// Perform a three-way comparison between the names of the specified
 /// NamedAttributes.
 static int compareNamedAttributes(const NamedAttribute *lhs,
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index c146a93..c77f0f4 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -21,8 +21,6 @@
 
 #include "mlir/Parser.h"
 #include "Lexer.h"
-#include "llvm/ADT/DenseMap.h"
-#include "llvm/Support/SourceMgr.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Attributes.h"
@@ -33,6 +31,8 @@
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/Types.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/SourceMgr.h"
 using namespace mlir;
 using llvm::SMLoc;
 using llvm::SourceMgr;
@@ -524,7 +524,6 @@
   return builder.getFunctionType(arguments, results);
 }
 
-
 /// Parse a list of types without an enclosing parenthesis.  The list must have
 /// at least one member.
 ///
@@ -574,6 +573,7 @@
 ///                    | integer-literal
 ///                    | float-literal
 ///                    | string-literal
+///                    | type
 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
 ///
 Attribute *Parser::parseAttribute() {
@@ -642,14 +642,19 @@
       return nullptr;
     return builder.getArrayAttr(elements);
   }
-  default:
+  case Token::hash_identifier:
+  case Token::l_paren: {
     // Try to parse affine map reference.
-    auto *affineMap = parseAffineMapReference();
-    if (affineMap != nullptr)
+    if (auto *affineMap = parseAffineMapReference())
       return builder.getAffineMapAttr(affineMap);
-
     return (emitError("expected constant attribute value"), nullptr);
   }
+  default: {
+    if (Type *type = parseType())
+      return builder.getTypeAttr(type);
+    return nullptr;
+  }
+  }
 }
 
 /// Attribute dictionary.
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index 45b8444..b2d1a3d 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -287,3 +287,10 @@
   return %x, %y, %z, %t, %f : i32, i23, i23, i1, i1
 }
 
+// CHECK-LABEL: cfgfunc @typeattr
+cfgfunc @typeattr() -> () {
+bb0:
+// CHECK: "foo"() {bar: tensor<??f32>} : () -> ()
+  "foo"(){bar: tensor<??f32>} : () -> ()
+  return
+}