Add tensor type.
PiperOrigin-RevId: 201830793
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index a2befc3..9c30660 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -60,6 +60,40 @@
return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
}
};
+struct RankedTensorTypeKeyInfo : DenseMapInfo<RankedTensorType*> {
+  // Ranked tensors are uniqued based on their element type and shape.
+  using KeyTy = std::pair<Type*, ArrayRef<int>>;
+  using DenseMapInfo<RankedTensorType*>::isEqual;
+
+  /// Hash a lookup key: the element type combined with every dimension.
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(DenseMapInfo<Type*>::getHashValue(key.first),
+                        hash_combine_range(key.second.begin(),
+                                           key.second.end()));
+  }
+
+  /// Hash an already-stored type by its key, not by its address.
+  /// When the DenseSet grows it rehashes existing entries through this
+  /// overload. With the inherited pointer hash, entries would move to buckets
+  /// that key-based lookups (insert_as) never probe, and identical tensor
+  /// types would stop being uniqued.
+  static unsigned getHashValue(const RankedTensorType *type) {
+    assert(type && "ranked tensor placeholder must be filled before rehash");
+    return getHashValue(KeyTy(type->getElementType(), type->getShape()));
+  }
+
+  /// Compare a lookup key with a stored type. Empty and tombstone buckets
+  /// hold sentinel pointers that must not be dereferenced.
+  static bool isEqual(const KeyTy &lhs, const RankedTensorType *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == KeyTy(rhs->getElementType(), rhs->getShape());
+  }
+};
+struct UnrankedTensorTypeKeyInfo : DenseMapInfo<UnrankedTensorType*> {
+  // Unranked tensors are uniqued based on their element type alone.
+  // NOTE(review): nothing in the visible code uses this; the context currently
+  // uniques unranked tensors with a DenseMap keyed directly on the element
+  // type. Kept consistent with the ranked variant in case it is adopted.
+  using KeyTy = Type*;
+  using DenseMapInfo<UnrankedTensorType*>::isEqual;
+
+  /// Hash a lookup key (the element type).
+  static unsigned getHashValue(KeyTy key) {
+    return DenseMapInfo<Type*>::getHashValue(key);
+  }
+
+  /// Hash an already-stored type by its key, so that entries rehashed on
+  /// growth stay reachable from key-based lookups.
+  static unsigned getHashValue(const UnrankedTensorType *type) {
+    assert(type && "unranked tensor placeholder must be filled before rehash");
+    return getHashValue(type->getElementType());
+  }
+
+  /// Compare a lookup key with a stored type. Empty and tombstone buckets
+  /// hold sentinel pointers that must not be dereferenced.
+  static bool isEqual(const KeyTy &lhs, const UnrankedTensorType *rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == rhs->getElementType();
+  }
+};
} // end anonymous namespace.
@@ -82,6 +116,14 @@
using VectorTypeSet = DenseSet<VectorType*, VectorTypeKeyInfo>;
VectorTypeSet vectors;
+  /// Uniquing set for ranked tensor types, keyed on (element type, shape).
+  using RankedTensorTypeSet =
+      DenseSet<RankedTensorType *, RankedTensorTypeKeyInfo>;
+  RankedTensorTypeSet rankedTensors;
+
+  /// Uniquing map for unranked tensor types, keyed on the element type.
+  using UnrankedTensorTypeMap = DenseMap<Type *, UnrankedTensorType *>;
+  UnrankedTensorTypeMap unrankedTensors;
+
public:
/// Copy the specified array of elements into memory managed by our bump
@@ -198,3 +240,69 @@
// Cache and return it.
return *existing.first = result;
}
+
+
+/// Common base constructor for ranked and unranked tensor types.
+/// Element types are restricted to primitives and vectors.
+TensorType::TensorType(TypeKind kind, Type *elementType, MLIRContext *context)
+    : Type(kind, context), elementType(elementType) {
+  assert((isa<PrimitiveType>(elementType) || isa<VectorType>(elementType)) &&
+         "tensor elements must be primitives or vectors");
+  // `isa<TensorType>(this)` would always succeed here: `this` is statically a
+  // TensorType*, and LLVM's isa<> treats upcasts as true without inspecting
+  // the kind. Check the kind directly so a subclass passing a wrong kind is
+  // caught.
+  assert((kind == TypeKind::RankedTensor ||
+          kind == TypeKind::UnrankedTensor) &&
+         "tensor types must use a tensor type kind");
+}
+
+/// Construct a ranked tensor type. The dimensions are not copied:
+/// `shapeElements` points into `shape`'s storage, so the caller must keep that
+/// array alive for the lifetime of the type.
+RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
+                                   MLIRContext *context)
+    : TensorType(TypeKind::RankedTensor, elementType, context),
+      shapeElements(shape.data()) {
+  // The rank (number of dimensions) is kept in the subclass data.
+  auto rank = shape.size();
+  setSubclassData(rank);
+}
+
+/// Construct an unranked tensor type. No shape is recorded; the element type
+/// is stored by the TensorType base.
+UnrankedTensorType::UnrankedTensorType(Type *elementType, MLIRContext *context)
+    : TensorType(TypeKind::UnrankedTensor, elementType, context) {}
+
+/// Return the unique RankedTensorType with the given shape and element type.
+/// The type is created in the element type's context on the first request.
+RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
+                                        Type *elementType) {
+  auto *context = elementType->getContext();
+  auto &impl = context->getImpl();
+
+  // Probe the uniquing set by (element type, shape). On a miss, a null
+  // placeholder is inserted and is replaced below.
+  auto probe = impl.rankedTensors.insert_as(
+      nullptr, RankedTensorTypeKeyInfo::KeyTy(elementType, shape));
+  if (!probe.second)
+    return *probe.first;
+
+  // Reserve bump-pointer storage for the new type first.
+  auto *storage = impl.allocator.Allocate<RankedTensorType>();
+
+  // The type only borrows its shape array, so hand it a context-owned copy.
+  ArrayRef<int> ownedShape = impl.copyInto(shape);
+
+  auto *type = new (storage) RankedTensorType(ownedShape, elementType, context);
+
+  // Fill the placeholder slot and hand back the new type.
+  *probe.first = type;
+  return type;
+}
+
+/// Return the unique UnrankedTensorType with the given element type.
+/// The type is created in the element type's context on the first request.
+UnrankedTensorType *UnrankedTensorType::get(Type *elementType) {
+  auto *context = elementType->getContext();
+  auto &impl = context->getImpl();
+
+  // Unranked tensors are keyed on the element type alone. On a miss, the map
+  // gains a null entry that is filled in below.
+  auto lookup = impl.unrankedTensors.insert({elementType, nullptr});
+  UnrankedTensorType *&slot = lookup.first->second;
+  if (!lookup.second)
+    return slot;
+
+  // Construct the type in the context's bump-pointer storage and record it.
+  auto *storage = impl.allocator.Allocate<UnrankedTensorType>();
+  slot = new (storage) UnrankedTensorType(elementType, context);
+  return slot;
+}