Add standard op for MLIR 'alloc' instruction (with parser and associated tests).
Adds field to MemRefType to query number of dynamic dimensions.
PiperOrigin-RevId: 206633162
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index ff07e73..585e24b 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -86,6 +86,44 @@
explicit AffineApplyOp(const Operation *state) : OpBase(state) {}
};
+/// The "alloc" operation allocates a region of memory, as specified by its
+/// memref type. For example:
+///
+/// %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1>
+///
+/// The optional list of dimension operands are bound to the dynamic dimensions
+/// specified in its memref type. In the example below, the ssa value '%d' is
+/// bound to the second dimension of the memref (which is dynamic).
+///
+/// %0 = alloc(%d) : memref<8x?xf32, (d0, d1) -> (d0, d1), 1>
+///
+/// The optional list of symbol operands are bound to the symbols of the
+/// memrefs affine map. In the example below, the ssa value '%s' is bound to
+/// the symbol 's0' in the affine map specified in the allocs memref type.
+///
+/// %0 = alloc()[%s] : memref<8x64xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+///
+/// This operation returns a single ssa value of memref type, which can be used
+/// by subsequent load and store operations.
+
+class AllocOp
+ : public OpBase<AllocOp, OpTrait::VariadicOperands, OpTrait::OneResult> {
+public:
+ SSAValue *getMemRef() { return getOperation()->getResult(0); }
+ const SSAValue *getMemRef() const { return getOperation()->getResult(0); }
+
+ static StringRef getOperationName() { return "alloc"; }
+
+ // Hooks to customize behavior of this op.
+ const char *verify() const;
+ static OpAsmParserResult parse(OpAsmParser *parser);
+ void print(OpAsmPrinter *p) const;
+
+private:
+ friend class Operation;
+ explicit AllocOp(const Operation *state) : OpBase(state) {}
+};
+
/// The "constant" operation requires a single attribute named "value".
/// It returns its value as an SSA value. For example:
///
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 75ba406..7e24c79 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -346,6 +346,9 @@
/// Returns the memory space in which data referred to by this memref resides.
unsigned getMemorySpace() const { return memorySpace; }
+ /// Returns the number of dimensions with dynamic size.
+ unsigned getNumDynamicDims();
+
static bool classof(const Type *type) {
return type->getKind() == Kind::MemRef;
}
@@ -356,11 +359,11 @@
/// An array of integers which stores the shape dimension sizes.
const int *shapeElements;
/// The number of affine maps in the 'affineMapList' array.
- unsigned numAffineMaps;
- /// List of affine maps in affine map composition.
+ const unsigned numAffineMaps;
+ /// List of affine maps in the memref's layout/index map composition.
AffineMap *const *const affineMapList;
/// Memory space in which data referenced by memref resides.
- unsigned memorySpace;
+ const unsigned memorySpace;
MemRefType(ArrayRef<int> shape, Type *elementType,
ArrayRef<AffineMap*> affineMapList, unsigned memorySpace,
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 55c0c25..206b39c 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -25,6 +25,42 @@
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
+static void printDimAndSymbolList(Operation::const_operand_iterator begin,
+ Operation::const_operand_iterator end,
+ unsigned numDims, OpAsmPrinter *p) {
+ *p << '(';
+ p->printOperands(begin, begin + numDims);
+ *p << ')';
+
+ if (begin + numDims != end) {
+ *p << '[';
+ p->printOperands(begin + numDims, end);
+ *p << ']';
+ }
+}
+
+// Parses dimension and symbol list, and sets 'numDims' to the number of
+// dimension operands parsed.
+// Returns 'false' on success and 'true' on error.
+static bool
+parseDimAndSymbolList(OpAsmParser *parser,
+ SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
+ SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
+ if (parser->parseOperandList(opInfos, -1,
+ OpAsmParser::Delimeter::ParenDelimeter))
+ return true;
+ // Store number of dimensions for validation by caller.
+ numDims = opInfos.size();
+
+ // Parse the optional symbol operands.
+ auto *affineIntTy = parser->getBuilder().getAffineIntType();
+ if (parser->parseOperandList(
+ opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
+ parser->resolveOperands(opInfos, affineIntTy, operands))
+ return true;
+ return false;
+}
+
// TODO: Have verify functions return std::string to enable more descriptive
// error messages.
OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
@@ -60,18 +96,14 @@
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
- if (parser->parseAttribute(mapAttr) ||
- parser->parseOperandList(opInfos, -1,
- OpAsmParser::Delimeter::ParenDelimeter))
- return {};
- unsigned numDims = opInfos.size();
-
- if (parser->parseOperandList(
- opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
- parser->resolveOperands(opInfos, affineIntTy, operands))
+ if (parser->parseAttribute(mapAttr))
return {};
+ unsigned numDims;
+ if (parseDimAndSymbolList(parser, opInfos, operands, numDims))
+ return {};
auto *map = mapAttr->getValue();
+
if (map->getNumDims() != numDims ||
numDims + map->getNumSymbols() != opInfos.size()) {
parser->emitError(parser->getNameLoc(),
@@ -88,17 +120,7 @@
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
-
- auto opit = operand_begin();
- *p << '(';
- p->printOperands(opit, opit + map->getNumDims());
- *p << ')';
-
- if (map->getNumSymbols()) {
- *p << '[';
- p->printOperands(opit + map->getNumDims(), operand_end());
- *p << ']';
- }
+ printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p);
}
const char *AffineApplyOp::verify() const {
@@ -121,6 +143,57 @@
return nullptr;
}
+void AllocOp::print(OpAsmPrinter *p) const {
+ MemRefType *type = cast<MemRefType>(getMemRef()->getType());
+ *p << "alloc";
+ // Print dynamic dimension operands.
+ printDimAndSymbolList(operand_begin(), operand_end(),
+ type->getNumDynamicDims(), p);
+ // Print memref type.
+ *p << " : " << *type;
+}
+
+OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
+ MemRefType *type;
+ SmallVector<SSAValue *, 4> operands;
+ SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
+
+ // Parse the dimension operands and optional symbol operands.
+ unsigned numDimOperands;
+ if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands))
+ return {};
+
+ // Parse memref type.
+ if (parser->parseColonType(type))
+ return {};
+
+ // Check numDynamicDims against number of question marks in memref type.
+ if (numDimOperands != type->getNumDynamicDims()) {
+ parser->emitError(parser->getNameLoc(),
+ "Dynamic dimensions count mismatch: dimension operand "
+ "count does not equal memref dynamic dimension count.");
+ return {};
+ }
+
+ // Check that the number of symbol operands matches the number of symbols in
+ // the first affinemap of the memref's affine map composition.
+ // Note that a memref must specify at least one affine map in the composition.
+ if ((operandsInfo.size() - numDimOperands) !=
+ type->getAffineMaps()[0]->getNumSymbols()) {
+ parser->emitError(parser->getNameLoc(),
+ "AffineMap symbol count mismatch: symbol operand "
+ "count does not equal memref affine map symbol count.");
+ return {};
+ }
+
+ return OpAsmParserResult(operands, type);
+}
+
+const char *AllocOp::verify() const {
+ // TODO(andydavis): Verify alloc.
+ return nullptr;
+}
+
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
const char *ConstantOp::verify() const {
@@ -240,6 +313,7 @@
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>(
- /*prefix=*/"");
+ opSet
+ .addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp>(
+ /*prefix=*/"");
}
diff --git a/lib/IR/Types.cpp b/lib/IR/Types.cpp
index daf4679..42c10ed 100644
--- a/lib/IR/Types.cpp
+++ b/lib/IR/Types.cpp
@@ -52,14 +52,21 @@
}
MemRefType::MemRefType(ArrayRef<int> shape, Type *elementType,
- ArrayRef<AffineMap*> affineMapList,
+ ArrayRef<AffineMap *> affineMapList,
unsigned memorySpace, MLIRContext *context)
- : Type(Kind::MemRef, context, shape.size()),
- elementType(elementType), shapeElements(shape.data()),
- numAffineMaps(affineMapList.size()), affineMapList(affineMapList.data()),
- memorySpace(memorySpace) {
-}
+ : Type(Kind::MemRef, context, shape.size()), elementType(elementType),
+ shapeElements(shape.data()), numAffineMaps(affineMapList.size()),
+ affineMapList(affineMapList.data()), memorySpace(memorySpace) {}
ArrayRef<AffineMap*> MemRefType::getAffineMaps() const {
return ArrayRef<AffineMap*>(affineMapList, numAffineMaps);
}
+
+unsigned MemRefType::getNumDynamicDims() {
+ unsigned numDynamicDims = 0;
+ for (int dimSize : getShape()) {
+ if (dimSize < 0)
+ ++numDynamicDims;
+ }
+ return numDynamicDims;
+}
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
index 2991d63..572147b 100644
--- a/test/IR/invalid-ops.mlir
+++ b/test/IR/invalid-ops.mlir
@@ -68,3 +68,23 @@
%i = crazyThing() {value: 0} : () -> affineint // expected-error {{custom op 'crazyThing' is unknown}}
return
}
+
+// -----
+
+cfgfunc @bad_alloc_wrong_dynamic_dim_count() {
+bb0:
+ %0 = "constant"() {value: 7} : () -> affineint
+ // Test alloc with wrong number of dynamic dimensions.
+ %1 = alloc(%0)[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{Dynamic dimensions count mismatch: dimension operand count does not equal memref dynamic dimension count}}
+ return
+}
+
+// -----
+
+cfgfunc @bad_alloc_wrong_symbol_count() {
+bb0:
+ %0 = "constant"() {value: 7} : () -> affineint
+ // Test alloc with wrong number of symbols
+ %1 = alloc(%0) : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1> // expected-error {{AffineMap symbol count mismatch: symbol operand count does not equal memref affine map symbol count}}
+ return
+}
diff --git a/test/IR/memory-ops.mlir b/test/IR/memory-ops.mlir
new file mode 100644
index 0000000..9181f03
--- /dev/null
+++ b/test/IR/memory-ops.mlir
@@ -0,0 +1,27 @@
+// RUN: %S/../../mlir-opt %s -o - | FileCheck %s
+
+// CHECK-LABEL: cfgfunc @alloc() {
+cfgfunc @alloc() {
+bb0:
+ // Test simple alloc.
+ // CHECK: %0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+ %0 = alloc() : memref<1024x64xf32, (d0, d1) -> (d0, d1), 1>
+
+ %1 = "constant"() {value: 0} : () -> affineint
+ %2 = "constant"() {value: 1} : () -> affineint
+
+ // Test alloc with dynamic dimensions.
+ // CHECK: %3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
+ %3 = alloc(%1, %2) : memref<?x?xf32, (d0, d1) -> (d0, d1), 1>
+
+ // Test alloc with no dynamic dimensions and one symbol.
+ // CHECK: %4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+ %4 = alloc()[%1] : memref<2x4xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+
+ // Test alloc with dynamic dimensions and one symbol.
+ // CHECK: %5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+ %5 = alloc(%2)[%1] : memref<2x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1), 1>
+
+ // CHECK: return
+ return
+}
\ No newline at end of file