[mlir] Add a string type

PiperOrigin-RevId: 206977161
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index c6a09f1..67cc7b8 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -64,6 +64,7 @@
 
   OtherType *getAffineIntType();
   OtherType *getTFControlType();
+  OtherType *getTFStringType();
   IntegerType *getIntegerType(unsigned width);
   FunctionType *getFunctionType(ArrayRef<Type *> inputs,
                                 ArrayRef<Type *> results);
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 9c8233d..606b7b8 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -40,10 +40,11 @@
 
     // TensorFlow types.
     TFControl,
+    TFString,
 
     /// These are marker for the first and last 'other' type.
     FIRST_OTHER_TYPE = AffineInt,
-    LAST_OTHER_TYPE = TFControl,
+    LAST_OTHER_TYPE = TFString,
 
     // Floating point.
     BF16,
@@ -87,6 +88,7 @@
   static FloatType *getF64(MLIRContext *ctx);
   static OtherType *getAffineInt(MLIRContext *ctx);
   static OtherType *getTFControl(MLIRContext *ctx);
+  static OtherType *getTFString(MLIRContext *ctx);
 
   /// Print the current type.
   void print(raw_ostream &os) const;
@@ -201,6 +203,9 @@
 inline OtherType *Type::getTFControl(MLIRContext *ctx) {
   return OtherType::get(Kind::TFControl, ctx);
 }
+inline OtherType *Type::getTFString(MLIRContext *ctx) {
+  return OtherType::get(Kind::TFString, ctx);
+}
 
 /// Function types map from a list of inputs to a list of results.
 class FunctionType : public Type {
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9a94d66..f881872 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -303,6 +303,9 @@
   case Type::Kind::TFControl:
     os << "tf_control";
     return;
+  case Type::Kind::TFString:
+    os << "tf_string";
+    return;
 
   case Type::Kind::Integer: {
     auto *integer = cast<IntegerType>(type);
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 9ba75e4..199f35d 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -47,6 +47,8 @@
 
 OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
 
+OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+
 IntegerType *Builder::getIntegerType(unsigned width) {
   return Type::getInteger(width, context);
 }
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 8d8b013..f558ae1 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -260,7 +260,8 @@
 
   /// Copy the specified array of elements into memory managed by our bump
   /// pointer allocator.  This assumes the elements are all PODs.
-  template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
+  template <typename T>
+  ArrayRef<T> copyInto(ArrayRef<T> elements) {
     auto result = allocator.Allocate<T>(elements.size());
     std::uninitialized_copy(elements.begin(), elements.end(), result);
     return ArrayRef<T>(result, elements.size());
@@ -445,11 +446,14 @@
   return *existing.first = result;
 }
 
+static bool isValidTensorElementType(Type *type, MLIRContext *context) {
+  return isa<FloatType>(type) || isa<VectorType>(type) ||
+         isa<IntegerType>(type) || type == Type::getTFString(context);
+}
+
 TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
     : Type(kind, context), elementType(elementType) {
-  assert((isa<FloatType>(elementType) || isa<VectorType>(elementType) ||
-          isa<IntegerType>(elementType)) &&
-         "tensor elements must be primitives or vectors");
+  assert(isValidTensorElementType(elementType, context));
   assert(isa<TensorType>(this));
 }
 
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index a468c9e..24bda89 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -309,6 +309,9 @@
   case Token::kw_tf_control:
     consumeToken(Token::kw_tf_control);
     return builder.getTFControlType();
+  case Token::kw_tf_string:
+    consumeToken(Token::kw_tf_string);
+    return builder.getTFStringType();
   }
 }
 
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index 9e60328..6d71884 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -109,6 +109,7 @@
 TOK_KEYWORD(step)
 TOK_KEYWORD(tensor)
 TOK_KEYWORD(tf_control)
+TOK_KEYWORD(tf_string)
 TOK_KEYWORD(to)
 TOK_KEYWORD(true)
 TOK_KEYWORD(vector)