Add a standard op for the MLIR 'alloc' instruction (with parser and associated tests).
Add a field to MemRefType to query the number of dynamic dimensions.
PiperOrigin-RevId: 206633162
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 55c0c25..206b39c 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -25,6 +25,42 @@
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
+static void printDimAndSymbolList(Operation::const_operand_iterator begin,
+ Operation::const_operand_iterator end,
+ unsigned numDims, OpAsmPrinter *p) { // Prints the first 'numDims' operands of [begin, end) in parentheses.
+ *p << '(';
+ p->printOperands(begin, begin + numDims);
+ *p << ')';
+ // Operands past the first 'numDims' are printed in square brackets; no brackets are emitted when none remain.
+ if (begin + numDims != end) {
+ *p << '[';
+ p->printOperands(begin + numDims, end);
+ *p << ']';
+ }
+}
+
+// Parses a parenthesized operand list followed by an optional square-bracketed
+// operand list into 'opInfos', and sets 'numDims' to opInfos.size() as it stands after the first list.
+// All of 'opInfos' is then resolved against the affine-int type into 'operands'. Returns 'false' on success and 'true' on error ('numDims' is not written if the first list fails).
+static bool
+parseDimAndSymbolList(OpAsmParser *parser,
+ SmallVectorImpl<OpAsmParser::OperandType> &opInfos,
+ SmallVector<SSAValue *, 4> &operands, unsigned &numDims) {
+ if (parser->parseOperandList(opInfos, -1,
+ OpAsmParser::Delimeter::ParenDelimeter))
+ return true;
+ // Equals the number of dimension operands parsed, assuming 'opInfos' was empty on entry; stored for validation by the caller.
+ numDims = opInfos.size();
+
+ // Parse the optional symbol operands into the same list, then resolve dimension and symbol operands together.
+ auto *affineIntTy = parser->getBuilder().getAffineIntType();
+ if (parser->parseOperandList(
+ opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
+ parser->resolveOperands(opInfos, affineIntTy, operands))
+ return true;
+ return false;
+}
+
// TODO: Have verify functions return std::string to enable more descriptive
// error messages.
OpAsmParserResult AddFOp::parse(OpAsmParser *parser) {
@@ -60,18 +96,14 @@
auto *affineIntTy = builder.getAffineIntType();
AffineMapAttr *mapAttr;
- if (parser->parseAttribute(mapAttr) ||
- parser->parseOperandList(opInfos, -1,
- OpAsmParser::Delimeter::ParenDelimeter))
- return {};
- unsigned numDims = opInfos.size();
-
- if (parser->parseOperandList(
- opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
- parser->resolveOperands(opInfos, affineIntTy, operands))
+ if (parser->parseAttribute(mapAttr))
return {};
+ unsigned numDims;
+ if (parseDimAndSymbolList(parser, opInfos, operands, numDims))
+ return {};
auto *map = mapAttr->getValue();
+
if (map->getNumDims() != numDims ||
numDims + map->getNumSymbols() != opInfos.size()) {
parser->emitError(parser->getNameLoc(),
@@ -88,17 +120,7 @@
void AffineApplyOp::print(OpAsmPrinter *p) const {
auto *map = getAffineMap();
*p << "affine_apply " << *map;
-
- auto opit = operand_begin();
- *p << '(';
- p->printOperands(opit, opit + map->getNumDims());
- *p << ')';
-
- if (map->getNumSymbols()) {
- *p << '[';
- p->printOperands(opit + map->getNumDims(), operand_end());
- *p << ']';
- }
+ printDimAndSymbolList(operand_begin(), operand_end(), map->getNumDims(), p); // Prints "(dims)" and, when operands remain past the dims, "[symbols]".
}
const char *AffineApplyOp::verify() const {
@@ -121,6 +143,57 @@
return nullptr;
}
+void AllocOp::print(OpAsmPrinter *p) const { // Prints "alloc", the operand lists, then " : " and the memref type.
+ MemRefType *type = cast<MemRefType>(getMemRef()->getType());
+ *p << "alloc";
+ // Print the operands: the first getNumDynamicDims() of them in parentheses, any remaining ones in square brackets.
+ printDimAndSymbolList(operand_begin(), operand_end(),
+ type->getNumDynamicDims(), p);
+ // Print the memref type of the result after a " : " separator.
+ *p << " : " << *type;
+}
+// Parses "(dims)[symbols] : memref-type" and checks both operand counts against the type; returns a default-constructed result ('{}') on any error.
+OpAsmParserResult AllocOp::parse(OpAsmParser *parser) {
+ MemRefType *type;
+ SmallVector<SSAValue *, 4> operands;
+ SmallVector<OpAsmParser::OperandType, 4> operandsInfo;
+
+ // Parse the dimension operands and optional symbol operands.
+ unsigned numDimOperands;
+ if (parseDimAndSymbolList(parser, operandsInfo, operands, numDimOperands))
+ return {};
+
+ // Parse memref type.
+ if (parser->parseColonType(type))
+ return {};
+
+ // The number of dimension operands must equal the memref type's dynamic dimension count (the question marks in the type).
+ if (numDimOperands != type->getNumDynamicDims()) {
+ parser->emitError(parser->getNameLoc(),
+ "Dynamic dimensions count mismatch: dimension operand "
+ "count does not equal memref dynamic dimension count.");
+ return {};
+ }
+
+ // Check that the number of symbol operands (those parsed after the dimension operands) matches
+ // the number of symbols in the first affine map of the memref's affine map composition.
+ // NOTE(review): indexing [0] relies on the memref specifying at least one affine map — confirm getAffineMaps() can never be empty here.
+ if ((operandsInfo.size() - numDimOperands) !=
+ type->getAffineMaps()[0]->getNumSymbols()) {
+ parser->emitError(parser->getNameLoc(),
+ "AffineMap symbol count mismatch: symbol operand "
+ "count does not equal memref affine map symbol count.");
+ return {};
+ }
+
+ return OpAsmParserResult(operands, type);
+}
+// Verifier hook for AllocOp: performs no checks yet and always returns nullptr (presumably meaning "no error" — confirm against the other verify functions).
+const char *AllocOp::verify() const {
+ // TODO(andydavis): Verify alloc.
+ return nullptr;
+}
+
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
const char *ConstantOp::verify() const {
@@ -240,6 +313,7 @@
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>(
- /*prefix=*/"");
+ opSet
+ .addOperations<AddFOp, AffineApplyOp, AllocOp, ConstantOp, DimOp, LoadOp>(
+ /*prefix=*/""); // Empty prefix: presumably the ops are registered under their unprefixed names — confirm.
}