Add standard op for MLIR 'alloc' instruction (with parser and associated tests).
Adds field to MemRefType to query number of dynamic dimensions.

PiperOrigin-RevId: 206633162
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index ff07e73..585e24b 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -86,6 +86,44 @@
   explicit AffineApplyOp(const Operation *state) : OpBase(state) {}
 };
 
+/// The "alloc" operation allocates a region of memory, as specified by its
+/// memref type. For example:
+///
+///   %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+///
+/// The optional list of dimension operands are bound to the dynamic dimensions
+/// specified in its memref type. In the example below, the ssa value '%d' is
+/// bound to the second dimension of the memref (which is dynamic).
+///
+///   %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1>
+///
+/// The optional list of symbol operands are bound to the symbols of the
+/// memrefs affine map. In the example below, the ssa value '%s' is bound to
+/// the symbol 's0' in the affine map specified in the allocs memref type.
+///
+///   %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+///
+/// This operation returns a single ssa value of memref type, which can be used
+/// by subsequent load and store operations.
+
+class AllocOp
+    : public OpBase<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
+public:
+  SSAValue *getMemRef() { return getOperation()->getResult(0); }
+  const SSAValue *getMemRef() const { return getOperation()->getResult(0); }
+
+  static StringRef getOperationName() { return "alloc"; }
+
+  // Hooks to customize behavior of this op.
+  const char *verify() const;
+  static OpAsmParserResult parse(OpAsmParser *parser);
+  void print(OpAsmPrinter *p) const;
+
+private:
+  friend class Operation;
+  explicit AllocOp(const Operation *state) : OpBase(state) {}
+};
+
 /// The "constant" operation requires a single attribute named "value".
 /// It returns its value as an SSA value.  For example:
 ///
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 75ba406..7e24c79 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -346,6 +346,9 @@
   /// Returns the memory space in which data referred to by this memref resides.
   unsigned getMemorySpace() const { return memorySpace; }
 
+  /// Returns the number of dimensions with dynamic size.
+  unsigned getNumDynamicDims();
+
   static bool classof(const Type *type) {
     return type->getKind() == Kind::MemRef;
   }
@@ -356,11 +359,11 @@
   /// An array of integers which stores the shape dimension sizes.
   const int *shapeElements;
   /// The number of affine maps in the 'affineMapList' array.
-  unsigned numAffineMaps;
-  /// List of affine maps in affine map composition.
+  const unsigned numAffineMaps;
+  /// List of affine maps in the memref's layout/index map composition.
   AffineMap *const *const affineMapList;
   /// Memory space in which data referenced by memref resides.
-  unsigned memorySpace;
+  const unsigned memorySpace;
 
   MemRefType(ArrayRef<int> shape, Type *elementType,
              ArrayRef<AffineMap*> affineMapList, unsigned memorySpace,
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 55c0c25..206b39c 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -25,6 +25,42 @@
 #include "llvm/Support/raw_ostream.h"
 using namespace mlir;
 
+static void printDimAndSymbolList(Operation::const_operand_iterator begin,
+                                  Operation::const_operand_iterator end,
+                                  unsigned numDims, OpAsmPrinter *p) {
+  *p << '(';
+  p->printOperands(begin, begin + numDims);
+  *p << ')';
+
+  if (begin + numDims != end) {
+    *p << '[';
+    p->printOperands(begin + numDims, end);
+    *p << ']';
+  }
+}
+
+// Parses dimension and symbol list, and sets 'numDims' to the number of
+// dimension operands parsed.
+// Returns 'false' on success and 'true' on error.
+static bool
+parseDimAndSymbolList(OpAsmParser *parser,
+                      SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
+                      SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
+  if (parser->parseOperandList(opInfos, -1,
+                               OpAsmParser::Delimeter::ParenDelimeter))
+    return true;
+  // Store number of dimensions for validation by caller.
+  numDims = opInfos.size();
+
+  // Parse the optional symbol operands.
+  auto *affineIntTy = parser->getBuilder().getAffineIntType();
+  if (parser->parseOperandList(
+          opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
+      parser->resolveOperands(opInfos, affineIntTy, operands))
+    return true;
+  return false;
+}
+
 // TODO: Have verify functions return std::string to enable more descriptive
 // error messages.
 OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
@@ -60,18 +96,14 @@
   auto *affineIntTy = builder.getAffineIntType();
 
   AffineMapAttr *mapAttr;
-  if (parser->parseAttribute(mapAttr) ||
-      parser->parseOperandList(opInfos, -1,
-                               OpAsmParser::Delimeter::ParenDelimeter))
-    return {};
-  unsigned numDims = opInfos.size();
-
-  if (parser->parseOperandList(
-          opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
-      parser->resolveOperands(opInfos, affineIntTy, operands))
+  if (parser->parseAttribute(mapAttr))
     return {};
 
+  unsigned numDims;
+  if (parseDimAndSymbolList(parser, opInfos, operands, numDims))
+    return {};
   auto *map = mapAttr->getValue();
+
   if (map->getNumDims() != numDims ||
       numDims + map->getNumSymbols() != opInfos.size()) {
     parser->emitError(parser->getNameLoc(),
@@ -88,17 +120,7 @@
 void AffineApplyOp::print(OpAsmPrinter *p) const {
   auto *map = getAffineMap();
   *p << "affine_apply " << *map;
-
-  auto opit = operand_begin();
-  *p << '(';
-  p->printOperands(opit, opit + map->getNumDims());
-  *p << ')';
-
-  if (map->getNumSymbols()) {
-    *p << '[';
-    p->printOperands(opit + map->getNumDims(), operand_end());
-    *p << ']';
-  }
+  printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
 }
 
 const char *AffineApplyOp::verify() const {
@@ -121,6 +143,57 @@
   return nullptr;
 }
 
+void AllocOp::print(OpAsmPrinter *p) const {
+  MemRefType *type = cast<MemRefType>(getMemRef()->getType());
+  *p << "alloc";
+  // Print dynamic dimension operands.
+  printDimAndSymbolList(operand_begin(), operand_end(),
+                        type->getNumDynamicDims(), p);
+  // Print memref type.
+  *p << " : " << *type;
+}
+
+OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
+  MemRefType *type;
+  SmallVector<SSAValue *, 4> operands;
+  SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
+
+  // Parse the dimension operands and optional symbol operands.
+  unsigned numDimOperands;
+  if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands))
+    return {};
+
+  // Parse memref type.
+  if (parser->parseColonType(type))
+    return {};
+
+  // Check numDynamicDims against number of question marks in memref type.
+  if (numDimOperands != type->getNumDynamicDims()) {
+    parser->emitError(parser->getNameLoc(),
+                      "Dynamic dimensions count mismatch: dimension operand "
+                      "count does not equal memref dynamic dimension count.");
+    return {};
+  }
+
+  // Check that the number of symbol operands matches the number of symbols in
+  // the first affinemap of the memref's affine map composition.
+  // Note that a memref must specify at least one affine map in the composition.
+  if ((operandsInfo.size() - numDimOperands) !=
+      type->getAffineMaps()[0]->getNumSymbols()) {
+    parser->emitError(parser->getNameLoc(),
+                      "AffineMap symbol count mismatch: symbol operand "
+                      "count does not equal memref affine map symbol count.");
+    return {};
+  }
+
+  return OpAsmParserResult(operands, type);
+}
+
+const char *AllocOp::verify() const {
+  // TODO(andydavis): Verify alloc.
+  return nullptr;
+}
+
 /// The constant op requires an attribute, and furthermore requires that it
 /// matches the return type.
 const char *ConstantOp::verify() const {
@@ -240,6 +313,7 @@
 
 /// Install the standard operations in the specified operation set.
 void mlir::registerStandardOperations(OperationSet &opSet) {
-  opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>(
-      /*prefix=*/"");
+  opSet
+      .addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp>(
+          /*prefix=*/"");
 }
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index daf4679..42c10ed 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -52,14 +52,21 @@
 }
 
 MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
-                       ArrayRef<AffineMap*> affineMapList,
+                       ArrayRef<AffineMap *> affineMapList,
                        unsigned memorySpace, MLIRContext *context)
-  : Type(Kind::MemRef, context, shape.size()),
-    elementType(elementType), shapeElements(shape.data()),
-    numAffineMaps(affineMapList.size()), affineMapList(affineMapList.data()),
-    memorySpace(memorySpace) {
-}
+    : Type(Kind::MemRef, context, shape.size()), elementType(elementType),
+      shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
+      affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
 
 ArrayRef<AffineMap*> MemRefType::getAffineMaps() const {
   return ArrayRef<AffineMap*>(affineMapList, numAffineMaps);
 }
+
+unsigned MemRefType::getNumDynamicDims() {
+  unsigned numDynamicDims = 0;
+  for (int dimSize : getShape()) {
+    if (dimSize < 0)
+      ++numDynamicDims;
+  }
+  return numDynamicDims;
+}
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index 2991d63..572147b 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -68,3 +68,23 @@
   %i = crazyThing() {value: 0} : () -> affineint  // expected-error {{custom op 'crazyThing' is unknown}}
   return
 }
+
+// -----
+
+cfgfunc @bad_alloc_wrong_dynamic_dim_count() {
+bb0:
+  %0 = "constant"() {value: 7} : () -> affineint
+  // Test alloc with wrong number of dynamic dimensions.
+  %1 = alloc(%0)[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{Dynamic dimensions count mismatch: dimension operand count does not equal memref dynamic dimension count}}
+  return
+}
+
+// -----
+
+cfgfunc @bad_alloc_wrong_symbol_count() {
+bb0:
+  %0 = "constant"() {value: 7} : () -> affineint
+  // Test alloc with wrong number of symbols
+  %1 = alloc(%0) : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{AffineMap symbol count mismatch: symbol operand count does not equal memref affine map symbol count}}
+  return
+}
diff --git a/test/IR/memory-ops.mlir b/test/IR/memory-ops.mlir
new file mode 100644
index 0000000..9181f03
--- /dev/null
+++ b/test/IR/memory-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
+
+// CHECK-LABEL: cfgfunc @alloc() {
+cfgfunc @alloc() {
+bb0:
+  // Test simple alloc.
+  // CHECK: %0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+  %0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+
+  %1 = "constant"() {value: 0} : () -> affineint
+  %2 = "constant"() {value: 1} : () -> affineint
+
+  // Test alloc with dynamic dimensions.
+  // CHECK: %3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
+  %3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
+
+  // Test alloc with no dynamic dimensions and one symbol.
+  // CHECK: %4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+  %4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+
+  // Test alloc with dynamic dimensions and one symbol.
+  // CHECK: %5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+  %5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+
+  // CHECK:   return
+  return
+}
\ No newline at end of file