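Introduce a new extract_element operation, which extracts a single element from
a vector or tensor at the given indices. Introduce a new VectorOrTensorType class
that provides a common interface for vector and tensor types, since a number of
operations (including extract_element) will behave uniformly across them.
Improve the LoadOp verifier: it now checks that the result type matches the
memref element type and that the number of indices matches the memref rank.
Also add a build method for LoadOp.
I also updated the MLIR spec doc.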
PiperOrigin-RevId: 209953189
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9f01c13..18507b0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -520,9 +520,77 @@
}
//===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+/// Populate `result` for an extract_element of `aggregate` at `indices`.
+/// The aggregate must have vector or tensor type (enforced by cast<>); the
+/// single result has the aggregate's element type.
+void ExtractElementOp::build(Builder *builder, OperationState *result,
+                             SSAValue *aggregate,
+                             ArrayRef<SSAValue *> indices) {
+  auto *elementType =
+      cast<VectorOrTensorType>(aggregate->getType())->getElementType();
+
+  // Operand layout: the aggregate first, followed by every index.
+  result->addOperands(aggregate);
+  result->addOperands(indices);
+  result->types.push_back(elementType);
+}
+
+/// Print the op in the form
+///   extract_element <aggregate>[<indices>] <optional attr-dict> : <aggregate type>
+void ExtractElementOp::print(OpAsmPrinter *p) const {
+  auto *aggregate = getAggregate();
+  *p << "extract_element " << *aggregate << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  // Only the aggregate's type is printed: the parser derives the result type
+  // (its element type) and the index types (affineint) from it.
+  *p << " : " << *aggregate->getType();
+}
+
+/// Parse the custom form
+///   extract_element <aggregate>[<indices>] <optional attr-dict> : <aggregate type>
+/// Returns true on failure, matching the OpAsmParser convention used below.
+bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType aggregateOperand;
+  SmallVector<OpAsmParser::OperandType, 4> indexOperands;
+  VectorOrTensorType *aggregateType;
+  auto indexType = parser->getBuilder().getAffineIntType();
+
+  // Syntactic part: operand names, the bracketed index list (-1 presumably
+  // means "any number of indices"), attributes, and the trailing type.
+  if (parser->parseOperand(aggregateOperand) ||
+      parser->parseOperandList(indexOperands, -1,
+                               OpAsmParser::Delimiter::Square) ||
+      parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonType(aggregateType))
+    return true;
+
+  // Semantic part: the aggregate takes the parsed type, every index is an
+  // affineint, and the single result has the aggregate's element type.
+  if (parser->resolveOperand(aggregateOperand, aggregateType,
+                             result->operands) ||
+      parser->resolveOperands(indexOperands, indexType, result->operands))
+    return true;
+  return parser->addTypeToList(aggregateType->getElementType(), result->types);
+}
+
+/// Check that an extract_element op is well formed:
+///  - its first operand is a vector or tensor,
+///  - the result type is that aggregate's element type,
+///  - every index has 'affineint' type, and
+///  - for a ranked aggregate, there is exactly one index per dimension.
+/// Returns nullptr on success, otherwise a description of the problem.
+const char *ExtractElementOp::verify() const {
+  if (getNumOperands() == 0)
+    return "expected an aggregate to index into";
+
+  auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+  if (!aggregateType)
+    return "first operand must be a vector or tensor";
+
+  if (getResult()->getType() != aggregateType->getElementType())
+    return "result type must match element type of aggregate";
+
+  for (auto *idx : getIndices())
+    if (!idx->getType()->isAffineInt())
+      return "index to extract_element must have 'affineint' type";
+
+  // getRankIfPresent() yields -1 for an unranked aggregate, in which case the
+  // index count cannot be checked. Once that sentinel is excluded the rank is
+  // non-negative, so compare it as unsigned against the (unsigned) operand
+  // count rather than mixing signed and unsigned operands.
+  auto aggregateRank = aggregateType->getRankIfPresent();
+  if (aggregateRank != -1 &&
+      static_cast<unsigned>(aggregateRank) != getNumOperands() - 1)
+    return "incorrect number of indices for extract_element";
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//
+/// Populate `result` for a load from `memref` at `indices`. The memref operand
+/// must have MemRefType (enforced by cast<>); the single result has the
+/// memref's element type.
+void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
+                   ArrayRef<SSAValue *> indices) {
+  auto *elementType = cast<MemRefType>(memref->getType())->getElementType();
+
+  // Operand layout: the memref first, followed by every index.
+  result->addOperands(memref);
+  result->addOperands(indices);
+  result->types.push_back(elementType);
+}
+
void LoadOp::print(OpAsmPrinter *p) const {
*p << "load " << *getMemRef() << '[';
p->printOperands(getIndices());
@@ -555,6 +623,12 @@
if (!memRefType)
return "first operand must be a memref";
+ if (getResult()->getType() != memRefType->getElementType())
+ return "result type must match element type of memref";
+
+ if (memRefType->getRank() != getNumOperands() - 1)
+ return "incorrect number of indices for load";
+
for (auto *idx : getIndices())
if (!idx->getType()->isAffineInt())
return "index to load must have 'affineint' type";
@@ -671,6 +745,7 @@
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
+  // Every op below is registered with an empty prefix (last argument);
+  // presumably this lets names such as "extract_element" appear unqualified.
opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
- ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
+ ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
+ ReturnOp, StoreOp>(
/*prefix=*/"");
}