Introduce a new extract_element operation that reads a single element out of a
vector or tensor at the given indices. Introduce a new VectorOrTensorType class
that provides a common interface between vector and tensor types, since a number
of operations will be uniform across them (including extract_element). Improve
the LoadOp verifier to check the result type and the number of indices.
The MLIR spec doc is updated as well.
PiperOrigin-RevId: 209953189
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index e59b59c..024e1b6 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -454,17 +454,6 @@
return *existing.first = result;
}
-static bool isValidTensorElementType(Type *type, MLIRContext *context) {
- return isa<FloatType>(type) || isa<VectorType>(type) ||
- isa<IntegerType>(type) || type == Type::getTFString(context);
-}
-
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
- : Type(kind, context), elementType(elementType) {
- assert(isValidTensorElementType(elementType, context));
- assert(isa<TensorType>(this));
-}
-
RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
Type *elementType) {
auto *context = elementType->getContext();
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9f01c13..18507b0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -520,9 +520,77 @@
}
//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+void ExtractElementOp::build(Builder *builder, OperationState *result,
+ SSAValue *aggregate,
+ ArrayRef<SSAValue *> indices) {
+ auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+ result->addOperands(aggregate);
+ result->addOperands(indices);
+ result->types.push_back(aggregateType->getElementType());
+}
+
+void ExtractElementOp::print(OpAsmPrinter *p) const {
+ *p << "extract_element " << *getAggregate() << '[';
+ p->printOperands(getIndices());
+ *p << ']';
+ p->printOptionalAttrDict(getAttrs());
+ *p << " : " << *getAggregate()->getType();
+}
+
+bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
+ OpAsmParser::OperandType aggregateInfo;
+ SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+ VectorOrTensorType *type;
+
+ auto affineIntTy = parser->getBuilder().getAffineIntType();
+ return parser->parseOperand(aggregateInfo) ||
+ parser->parseOperandList(indexInfo, -1,
+ OpAsmParser::Delimiter::Square) ||
+ parser->parseOptionalAttributeDict(result->attributes) ||
+ parser->parseColonType(type) ||
+ parser->resolveOperand(aggregateInfo, type, result->operands) ||
+ parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+ parser->addTypeToList(type->getElementType(), result->types);
+}
+
+const char *ExtractElementOp::verify() const {
+ if (getNumOperands() == 0)
+ return "expected an aggregate to index into";
+
+ auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+ if (!aggregateType)
+ return "first operand must be a vector or tensor";
+
+ if (getResult()->getType() != aggregateType->getElementType())
+ return "result type must match element type of aggregate";
+
+ for (auto *idx : getIndices())
+ if (!idx->getType()->isAffineInt())
+ return "index to extract_element must have 'affineint' type";
+
+ // Verify the # indices match if we have a ranked type.
+ auto aggregateRank = aggregateType->getRankIfPresent();
+ if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
+ return "incorrect number of indices for extract_element";
+
+ return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
+void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
+ ArrayRef<SSAValue *> indices) {
+ auto *memrefType = cast<MemRefType>(memref->getType());
+ result->addOperands(memref);
+ result->addOperands(indices);
+ result->types.push_back(memrefType->getElementType());
+}
+
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
@@ -555,6 +623,12 @@
if (!memRefType)
return "first operand must be a memref";
+ if (getResult()->getType() != memRefType->getElementType())
+ return "result type must match element type of memref";
+
+ if (memRefType->getRank() != getNumOperands() - 1)
+ return "incorrect number of indices for load";
+
for (auto *idx : getIndices())
if (!idx->getType()->isAffineInt())
return "index to load must have 'affineint' type";
@@ -671,6 +745,7 @@
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
- ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
+ ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
+ ReturnOp, StoreOp>(
/*prefix=*/"");
}
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 7dfad79..d32fae3 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -35,10 +35,40 @@
numResults(numResults), inputsAndResults(inputsAndResults) {
}
+VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
+ Type *elementType, unsigned subClassData)
+ : Type(kind, context, subClassData), elementType(elementType) {}
+
+/// If this is a ranked tensor or vector type, return the rank. If it is an
+/// unranked tensor, return -1.
+int VectorOrTensorType::getRankIfPresent() const {
+ switch (getKind()) {
+ default:
+ llvm_unreachable("not a VectorOrTensorType");
+ case Kind::Vector:
+ return cast<VectorType>(this)->getRank();
+ case Kind::RankedTensor:
+ return cast<RankedTensorType>(this)->getRank();
+ case Kind::UnrankedTensor:
+ return -1;
+ }
+}
+
VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
MLIRContext *context)
- : Type(Kind::Vector, context, shape.size()), shapeElements(shape.data()),
- elementType(elementType) {}
+ : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
+ shapeElements(shape.data()) {}
+
+/// Return true if the specified element type is ok in a tensor.
+static bool isValidTensorElementType(Type *type, MLIRContext *context) {
+ return isa<FloatType>(type) || isa<VectorType>(type) ||
+ isa<IntegerType>(type) || type == Type::getTFString(context);
+}
+
+TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
+ : VectorOrTensorType(kind, context, elementType) {
+ assert(isValidTensorElementType(elementType, context));
+}
RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
MLIRContext *context)
@@ -65,7 +95,7 @@
unsigned MemRefType::getNumDynamicDims() const {
unsigned numDynamicDims = 0;
for (int dimSize : getShape()) {
- if (dimSize < 0)
+ if (dimSize == -1)
++numDynamicDims;
}
return numDynamicDims;