Introduce a new extract_element operation, which extracts a single element from a
vector or tensor at the given indices.  Introduce a new VectorOrTensorType class
that provides a common interface for vector and tensor types, since a number of
operations (including extract_element) behave uniformly across them.  Improve the
LoadOp verifier to check the result type and the number of indices, and add a
LoadOp builder.

Also update the MLIR spec doc.

PiperOrigin-RevId: 209953189
diff --git a/lib/IR/MLIRContext.cpp b/lib/IR/MLIRContext.cpp
index e59b59c..024e1b6 100644
--- a/lib/IR/MLIRContext.cpp
+++ b/lib/IR/MLIRContext.cpp
@@ -454,17 +454,6 @@
   return *existing.first = result;
 }
 
-static bool isValidTensorElementType(Type *type, MLIRContext *context) {
-  return isa<FloatType>(type) || isa<VectorType>(type) ||
-         isa<IntegerType>(type) || type == Type::getTFString(context);
-}
-
-TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
-    : Type(kind, context), elementType(elementType) {
-  assert(isValidTensorElementType(elementType, context));
-  assert(isa<TensorType>(this));
-}
-
 RankedTensorType *RankedTensorType::get(ArrayRef<int> shape,
                                         Type *elementType) {
   auto *context = elementType->getContext();
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 9f01c13..18507b0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -520,9 +520,77 @@
 }
 
 //===----------------------------------------------------------------------===//
+// ExtractElementOp
+//===----------------------------------------------------------------------===//
+
+void ExtractElementOp::build(Builder *builder, OperationState *result,
+                             SSAValue *aggregate,
+                             ArrayRef<SSAValue *> indices) {
+  auto *aggregateType = cast<VectorOrTensorType>(aggregate->getType());
+  result->addOperands(aggregate);
+  result->addOperands(indices);
+  result->types.push_back(aggregateType->getElementType());
+}
+
+void ExtractElementOp::print(OpAsmPrinter *p) const {
+  *p << "extract_element " << *getAggregate() << '[';
+  p->printOperands(getIndices());
+  *p << ']';
+  p->printOptionalAttrDict(getAttrs());
+  *p << " : " << *getAggregate()->getType();
+}
+
+bool ExtractElementOp::parse(OpAsmParser *parser, OperationState *result) {
+  OpAsmParser::OperandType aggregateInfo;
+  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
+  VectorOrTensorType *type;
+
+  auto affineIntTy = parser->getBuilder().getAffineIntType();
+  return parser->parseOperand(aggregateInfo) ||
+         parser->parseOperandList(indexInfo, -1,
+                                  OpAsmParser::Delimiter::Square) ||
+         parser->parseOptionalAttributeDict(result->attributes) ||
+         parser->parseColonType(type) ||
+         parser->resolveOperand(aggregateInfo, type, result->operands) ||
+         parser->resolveOperands(indexInfo, affineIntTy, result->operands) ||
+         parser->addTypeToList(type->getElementType(), result->types);
+}
+
+const char *ExtractElementOp::verify() const {
+  if (getNumOperands() == 0)
+    return "expected an aggregate to index into";
+
+  auto *aggregateType = dyn_cast<VectorOrTensorType>(getAggregate()->getType());
+  if (!aggregateType)
+    return "first operand must be a vector or tensor";
+
+  if (getResult()->getType() != aggregateType->getElementType())
+    return "result type must match element type of aggregate";
+
+  for (auto *idx : getIndices())
+    if (!idx->getType()->isAffineInt())
+      return "index to extract_element must have 'affineint' type";
+
+  // Verify the # indices match if we have a ranked type.
+  auto aggregateRank = aggregateType->getRankIfPresent();
+  if (aggregateRank != -1 && aggregateRank != getNumOperands() - 1)
+    return "incorrect number of indices for extract_element";
+
+  return nullptr;
+}
+
+//===----------------------------------------------------------------------===//
 // LoadOp
 //===----------------------------------------------------------------------===//
 
+void LoadOp::build(Builder *builder, OperationState *result, SSAValue *memref,
+                   ArrayRef<SSAValue *> indices) {
+  auto *memrefType = cast<MemRefType>(memref->getType());
+  result->addOperands(memref);
+  result->addOperands(indices);
+  result->types.push_back(memrefType->getElementType());
+}
+
 void LoadOp::print(OpAsmPrinter *p) const {
   *p << "load " << *getMemRef() << '[';
   p->printOperands(getIndices());
@@ -555,6 +623,12 @@
   if (!memRefType)
     return "first operand must be a memref";
 
+  if (getResult()->getType() != memRefType->getElementType())
+    return "result type must match element type of memref";
+
+  if (memRefType->getRank() != getNumOperands() - 1)
+    return "incorrect number of indices for load";
+
   for (auto *idx : getIndices())
     if (!idx->getType()->isAffineInt())
       return "index to load must have 'affineint' type";
@@ -671,6 +745,7 @@
 /// Install the standard operations in the specified operation set.
 void mlir::registerStandardOperations(OperationSet &opSet) {
   opSet.addOperations<AddFOp, AffineApplyOp, AllocOp, CallOp, CallIndirectOp,
-                      ConstantOp, DeallocOp, DimOp, LoadOp, ReturnOp, StoreOp>(
+                      ConstantOp, DeallocOp, DimOp, ExtractElementOp, LoadOp,
+                      ReturnOp, StoreOp>(
       /*prefix=*/"");
 }
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index 7dfad79..d32fae3 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -35,10 +35,40 @@
     numResults(numResults), inputsAndResults(inputsAndResults) {
 }
 
+VectorOrTensorType::VectorOrTensorType(Kind kind, MLIRContext *context,
+                                       Type *elementType, unsigned subClassData)
+    : Type(kind, context, subClassData), elementType(elementType) {}
+
+/// If this is a ranked tensor or vector type, return its rank.  If it is an
+/// unranked tensor, return -1.
+int VectorOrTensorType::getRankIfPresent() const {
+  switch (getKind()) {
+  default:
+    llvm_unreachable("not a VectorOrTensorType");
+  case Kind::Vector:
+    return cast<VectorType>(this)->getRank();
+  case Kind::RankedTensor:
+    return cast<RankedTensorType>(this)->getRank();
+  case Kind::UnrankedTensor:
+    return -1;
+  }
+}
+
 VectorType::VectorType(ArrayRef<unsigned> shape, Type *elementType,
                        MLIRContext *context)
-    : Type(Kind::Vector, context, shape.size()), shapeElements(shape.data()),
-      elementType(elementType) {}
+    : VectorOrTensorType(Kind::Vector, context, elementType, shape.size()),
+      shapeElements(shape.data()) {}
+
+/// Return true if the specified element type is ok in a tensor.
+static bool isValidTensorElementType(Type *type, MLIRContext *context) {
+  return isa<FloatType>(type) || isa<VectorType>(type) ||
+         isa<IntegerType>(type) || type == Type::getTFString(context);
+}
+
+TensorType::TensorType(Kind kind, Type *elementType, MLIRContext *context)
+    : VectorOrTensorType(kind, context, elementType) {
+  assert(isValidTensorElementType(elementType, context));
+}
 
 RankedTensorType::RankedTensorType(ArrayRef<int> shape, Type *elementType,
                                    MLIRContext *context)
@@ -65,7 +95,7 @@
 unsigned MemRefType::getNumDynamicDims() const {
   unsigned numDynamicDims = 0;
   for (int dimSize : getShape()) {
-    if (dimSize < 0)
+    if (dimSize == -1)
       ++numDynamicDims;
   }
   return numDynamicDims;