[mlir] Add a tf_string type and allow it as a tensor element type
PiperOrigin-RevId: 206977161
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9a94d66..f881872 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -303,6 +303,9 @@
case Type::Kind::TFControl:
os << "tf_control";
return;
+ case Type::Kind::TFString:
+ os << "tf_string";
+ return;
case Type::Kind::Integer: {
auto *integer = cast<IntegerType>(type);
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 9ba75e4..199f35d 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -47,6 +47,8 @@
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
+/// Returns the string type (printed as `tf_string`) for this builder's context.
+OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+
IntegerType *Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 8d8b013..f558ae1 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -260,7 +260,8 @@
/// Copy the specified array of elements into memory managed by our bump
/// pointer allocator. This assumes the elements are all PODs.
- template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
+ template <typename T>
+ ArrayRef<T> copyInto(ArrayRef<T> elements) {
auto result = allocator.Allocate<T>(elements.size());
std::uninitialized_copy(elements.begin(), elements.end(), result);
return ArrayRef<T>(result, elements.size());
@@ -445,11 +446,14 @@
return *existing.first = result;
}
+/// Returns true if `type` may be used as the element type of a tensor:
+/// float, vector, and integer types, plus the `tf_string` type of `context`.
+// Only referenced from an assert below; compile it out in NDEBUG builds so
+// release builds do not emit an unused-function warning.
+#ifndef NDEBUG
+static bool isValidTensorElementType(Type *type, MLIRContext *context) {
+  return isa<FloatType>(type) || isa<VectorType>(type) ||
+         isa<IntegerType>(type) || type == Type::getTFString(context);
+}
+#endif
+
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: Type(kind, context), elementType(elementType) {
- assert((isa<FloatType>(elementType) || isa<VectorType>(elementType) ||
- isa<IntegerType>(elementType)) &&
- "tensor elements must be primitives or vectors");
+  // Keep a message on the assert so a failure explains which invariant broke.
+  assert(isValidTensorElementType(elementType, context) &&
+         "tensor elements must be primitives, vectors, or tf_string");
assert(isa<TensorType>(this));
}