[mlir] Add a string type
PiperOrigin-RevId: 206977161
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index c6a09f1..67cc7b8 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -64,6 +64,7 @@
OtherType *getAffineIntType();
OtherType *getTFControlType();
+ OtherType *getTFStringType();
IntegerType *getIntegerType(unsigned width);
FunctionType *getFunctionType(ArrayRef<Type *> inputs,
ArrayRef<Type *> results);
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 9c8233d..606b7b8 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -40,10 +40,11 @@
// TensorFlow types.
TFControl,
+ TFString,
/// These are marker for the first and last 'other' type.
FIRST_OTHER_TYPE = AffineInt,
- LAST_OTHER_TYPE = TFControl,
+ LAST_OTHER_TYPE = TFString,
// Floating point.
BF16,
@@ -87,6 +88,7 @@
static FloatType *getF64(MLIRContext *ctx);
static OtherType *getAffineInt(MLIRContext *ctx);
static OtherType *getTFControl(MLIRContext *ctx);
+ static OtherType *getTFString(MLIRContext *ctx);
/// Print the current type.
void print(raw_ostream &os) const;
@@ -201,6 +203,9 @@
inline OtherType *Type::getTFControl(MLIRContext *ctx) {
return OtherType::get(Kind::TFControl, ctx);
}
+inline OtherType *Type::getTFString(MLIRContext *ctx) {
+ return OtherType::get(Kind::TFString, ctx);
+}
/// Function types map from a list of inputs to a list of results.
class FunctionType : public Type {
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 9a94d66..f881872 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -303,6 +303,9 @@
case Type::Kind::TFControl:
os << "tf_control";
return;
+ case Type::Kind::TFString:
+ os << "tf_string";
+ return;
case Type::Kind::Integer: {
auto *integer = cast<IntegerType>(type);
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 9ba75e4..199f35d 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -47,6 +47,8 @@
OtherType *Builder::getTFControlType() { return Type::getTFControl(context); }
+OtherType *Builder::getTFStringType() { return Type::getTFString(context); }
+
IntegerType *Builder::getIntegerType(unsigned width) {
return Type::getInteger(width, context);
}
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index 8d8b013..f558ae1 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -260,7 +260,8 @@
/// Copy the specified array of elements into memory managed by our bump
/// pointer allocator. This assumes the elements are all PODs.
- template <typename T> ArrayRef<T> copyInto(ArrayRef<T> elements) {
+ template <typename T>
+ ArrayRef<T> copyInto(ArrayRef<T> elements) {
auto result = allocator.Allocate<T>(elements.size());
std::uninitialized_copy(elements.begin(), elements.end(), result);
return ArrayRef<T>(result, elements.size());
@@ -445,11 +446,14 @@
return *existing.first = result;
}
+static bool isValidTensorElementType(Type *type, MLIRContext *context) {
+ return isa<FloatType>(type) || isa<VectorType>(type) ||
+ isa<IntegerType>(type) || type == Type::getTFString(context);
+}
+
TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
: Type(kind, context), elementType(elementType) {
- assert((isa<FloatType>(elementType) || isa<VectorType>(elementType) ||
- isa<IntegerType>(elementType)) &&
- "tensor elements must be primitives or vectors");
+ assert(isValidTensorElementType(elementType, context));
assert(isa<TensorType>(this));
}
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index a468c9e..24bda89 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -309,6 +309,9 @@
case Token::kw_tf_control:
consumeToken(Token::kw_tf_control);
return builder.getTFControlType();
+ case Token::kw_tf_string:
+ consumeToken(Token::kw_tf_string);
+ return builder.getTFStringType();
}
}
diff --git a/lib/Parser/TokenKinds.def b/lib/Parser/TokenKinds.def
index 9e60328..6d71884 100644
--- a/lib/Parser/TokenKinds.def
+++ b/lib/Parser/TokenKinds.def
@@ -109,6 +109,7 @@
TOK_KEYWORD(step)
TOK_KEYWORD(tensor)
TOK_KEYWORD(tf_control)
+TOK_KEYWORD(tf_string)
TOK_KEYWORD(to)
TOK_KEYWORD(true)
TOK_KEYWORD(vector)