Add tensor type.

PiperOrigin-RevId: 201830793
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a2befc3..9c30660 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -60,6 +60,40 @@
     return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
   }
 };
+struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
+  // Ranked tensors are uniqued by (element type, shape); KeyTy is that pair.
+  using KeyTy = std::pair<Type*, ArrayRef<int>>;
+  using DenseMapInfo<RankedTensorType*>::getHashValue;
+  using DenseMapInfo<RankedTensorType*>::isEqual;
+  // Hash a lookup key from the element type pointer and the shape values.
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
+                        hash_combine_range(key.second.begin(),
+                                           key.second.end()));
+  }
+  // Compare a lookup key with a stored type; sentinel slots never match.
+  static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+  }
+};
+struct UnrankedTensorTypeKeyInfo : DenseMapInfo<UnrankedTensorType*> {
+  // Unranked tensors are uniqued by element type alone (KeyTy is a Type*).
+  using KeyTy = Type*;
+  using DenseMapInfo<UnrankedTensorType*>::getHashValue;
+  using DenseMapInfo<UnrankedTensorType*>::isEqual;
+  // Hash a lookup key, which is just the element type pointer.
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(DenseMapInfo<Type*>::getHashValue(key));
+  }
+  // Compare a lookup key with a stored type; sentinel slots never match.
+  static bool isEqual(const KeyTy &lhs, const UnrankedTensorType *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == rhs->getElementType();
+  }
+};
 } // end anonymous namespace.
 
 
@@ -82,6 +116,14 @@
   using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
   VectorTypeSet vectors;
 
+  /// Uniquing set for ranked tensor types, keyed by element type and shape.
+  using RankedTensorTypeSet = DenseSet<RankedTensorType*,
+                                       RankedTensorTypeKeyInfo>;
+  RankedTensorTypeSet rankedTensors;
+
+  /// Uniquing map from an element type to its unranked tensor type.
+  DenseMap<Type*, UnrankedTensorType*> unrankedTensors;
+
 
 public:
   /// Copy the specified array of elements into memory managed by our bump
@@ -198,3 +240,69 @@
   // Cache and return it.
   return *existing.first = result;
 }
+
+/// Tensor base constructor; asserts the element type is primitive or vector.
+TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
+  : Type(kind, context), elementType(elementType) {
+  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
+         "tensor elements must be primitives or vectors");
+  assert(isa<TensorType>(this));
+}
+
+RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
+                                   MLIRContext *context)
+  : TensorType(TypeKind::RankedTensor, elementType, context),
+    shapeElements(shape.data()) { // keeps the pointer; no copy is made here
+  setSubclassData(shape.size()); // the dimension count is the subclass data
+}
+
+UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
+  : TensorType(TypeKind::UnrankedTensor, elementType, context) {
+} // Empty body: only the TensorType base is initialized here.
+
+/// Return the uniqued ranked tensor type for this shape and element type.
+RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
+                                        Type *elementType) {
+  auto *context = elementType->getContext();
+  auto &impl = context->getImpl();
+
+  // Probe the uniquing set by key, reserving a null slot if nothing matches.
+  RankedTensorTypeKeyInfo::KeyTy key(elementType, shape);
+  auto existing = impl.rankedTensors.insert_as(nullptr, key);
+
+  // No insertion happened, so an equal type is already cached: reuse it.
+  if (!existing.second)
+    return *existing.first;
+
+  // First use: allocate the type's storage from the bump pointer allocator.
+  auto *result = impl.allocator.Allocate<RankedTensorType>();
+
+  // Copy the shape into the bump pointer; the type keeps only a pointer.
+  shape = impl.copyInto(shape);
+  // Construct the type in the storage allocated above via placement new.
+  new (result) RankedTensorType(shape, elementType, context);
+
+  // Fill the reserved slot with the new type and return it.
+  return *existing.first = result;
+}
+
+/// Return the uniqued unranked tensor type for this element type.
+UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
+  auto *context = elementType->getContext();
+  auto &impl = context->getImpl();
+
+  // Probe the map by element type, reserving a null entry if none exists.
+  auto existing = impl.unrankedTensors.insert({elementType, nullptr});
+
+  // No insertion happened, so the type is already cached: reuse it.
+  if (!existing.second)
+    return existing.first->second;
+
+  // First use: allocate the type's storage from the bump pointer allocator.
+  auto *result = impl.allocator.Allocate<UnrankedTensorType>();
+  // Construct the type in that storage via placement new.
+  new (result) UnrankedTensorType(elementType, context);
+
+  // Fill the reserved entry with the new type and return it.
+  return existing.first->second = result;
+}