Introduce a new extract_element operation that does what it says.  Introduce a
new VectorOrTensorType class that provides a common interface between vector
and tensor since a number of operations will be uniform across them (including
extract_element).  Improve the LoadOp verifier.

I also updated the MLIR spec doc.

PiperOrigin-RevId: 209953189
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9f01c13..18507b0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -520,9 +520,100 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+/// Populate `result` for an extract_element of `aggregate` at `indices`.  The
+/// operands are the aggregate followed by the indices, and the aggregate's
+/// element type is appended as a result type.  The aggregate's type is
+/// converted with cast<>, not dyn_cast<>, so no null check follows it.
+void ExtractElementOp::build(Builder *builder, OperationState *result,
+                             SSAValue *aggregate,
+                             ArrayRef<SSAValue *> indices) {
+  auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+  result->addOperands(aggregate);
+  result->addOperands(indices);
+  result->types.push_back(aggregateType->getElementType());
+}
+
+/// Print the op as "extract_element", the aggregate, the indices in square
+/// brackets, the optional attribute dictionary, then " : " and the
+/// aggregate's type.
+void ExtractElementOp::print(OpAsmPrinter *p) const {
+  *p << "extract_element " << *getAggregate() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << *getAggregate()->getType();
+}
+
+/// Parse the pieces print() emits, in the same order: the aggregate operand,
+/// a square-delimited index list, an optional attribute dictionary, and a
+/// colon followed by the aggregate's type.  Every index is resolved against
+/// the affineint type, and the aggregate's element type becomes the result
+/// type.  The steps are chained with ||, so parsing stops at the first step
+/// that returns true and that value is returned.
+bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType aggregateInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  // Start from null rather than an indeterminate pointer.  It is only used
+  // by the steps that run once parseColonType has returned false.
+  VectorOrTensorType *type = nullptr;
+
+  auto affineIntTy = parser->getBuilder().getAffineIntType();
+  return parser->parseOperand(aggregateInfo) ||
+         parser->parseOperandList(indexInfo, -1,
+                                  OpAsmParser::Delimiter::Square) ||
+         parser->parseOptionalAttributeDict(result->attributes) ||
+         parser->parseColonType(type) ||
+         parser->resolveOperand(aggregateInfo, type, result->operands) ||
+         parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+         parser->addTypeToList(type->getElementType(), result->types);
+}
+
+/// Check the op's invariants, returning a message for the first violation
+/// found or nullptr when every check passes.
+const char *ExtractElementOp::verify() const {
+  if (getNumOperands() == 0)
+    return "expected an aggregate to index into";
+
+  auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+  if (!aggregateType)
+    return "first operand must be a vector or tensor";
+
+  if (getResult()->getType() != aggregateType->getElementType())
+    return "result type must match element type of aggregate";
+
+  for (auto *idx : getIndices())
+    if (!idx->getType()->isAffineInt())
+      return "index to extract_element must have 'affineint' type";
+
+  // Verify the # indices match if we have a ranked type.  A rank of -1 skips
+  // the check; otherwise every operand except the aggregate must be an index,
+  // hence the "- 1".
+  auto aggregateRank = aggregateType->getRankIfPresent();
+  if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
+    return "incorrect number of indices for extract_element";
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+/// Populate `result` for a load from `memref` at `indices`.  The operands are
+/// the memref followed by the indices, and the memref's element type is
+/// appended as a result type.  The memref's type is converted with cast<>, not
+/// dyn_cast<>, so no null check follows it.
+void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
+                   ArrayRef<SSAValue *> indices) {
+  auto *memrefType = cast<MemRefType>(memref->getType());
+  result->addOperands(memref);
+  result->addOperands(indices);
+  result->types.push_back(memrefType->getElementType());
+}
+
 void LoadOp::print(OpAsmPrinter *p) const {
   *p << "load " << *getMemRef() << '[';
   p->printOperands(getIndices());
@@ -555,6 +623,12 @@
   if (!memRefType)
     return "first operand must be a memref";
 
+  if (getResult()->getType() != memRefType->getElementType())
+    return "result type must match element type of memref";
+
+  if (memRefType->getRank() != getNumOperands() - 1)
+    return "incorrect number of indices for load";
+
   for (auto *idx : getIndices())
     if (!idx->getType()->isAffineInt())
       return "index to load must have 'affineint' type";
@@ -671,6 +745,7 @@
 /// Install the standard operations in the specified operation set.
 void mlir::registerStandardOperations(OperationSet &opSet) {
   opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
-                      ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
+                      ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
+                      ReturnOp, StoreOp>(
       /*prefix=*/"");
 }