Finish parser/printer support for AffineApplyOp, implement operand iterators on
VariadicOperands, tidy up some code in the asmprinter, and fill out more
verification logic for LoadOp.
PiperOrigin-RevId: 206443020
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 104ff2e..55c0c25 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -52,6 +52,75 @@
return nullptr;
}
+OpAsmParserResult AffineApplyOp::parse(OpAsmParser *parser) { // Custom assembly parser for affine_apply.
+ SmallVector<OpAsmParser::OperandType, 2> opInfos;
+ SmallVector<SSAValue *, 4> operands;
+ // Accepted form: map attribute, `(` dim operands `)`, then an optional `[` symbol operands `]`.
+ auto &builder = parser->getBuilder();
+ auto *affineIntTy = builder.getAffineIntType();
+ // Every failure path returns a default-constructed result (presumably the error value -- confirm against OpAsmParserResult).
+ AffineMapAttr *mapAttr;
+ if (parser->parseAttribute(mapAttr) ||
+ parser->parseOperandList(opInfos, -1,
+ OpAsmParser::Delimeter::ParenDelimeter))
+ return {};
+ unsigned numDims = opInfos.size(); // Only the parenthesized operands have been parsed at this point.
+ // The bracketed operands go into the same list, after the dims; all operands are resolved as affineint.
+ if (parser->parseOperandList(
+ opInfos, -1, OpAsmParser::Delimeter::OptionalSquareDelimeter) ||
+ parser->resolveOperands(opInfos, affineIntTy, operands))
+ return {};
+ // Cross-check the operand counts written in the assembly against the map's dim and symbol counts.
+ auto *map = mapAttr->getValue();
+ if (map->getNumDims() != numDims ||
+ numDims + map->getNumSymbols() != opInfos.size()) {
+ parser->emitError(parser->getNameLoc(),
+ "dimension or symbol index mismatch");
+ return {};
+ }
+ // One affineint result type per map result; the map itself is attached as the "map" attribute.
+ SmallVector<Type *, 4> resultTypes(map->getNumResults(), affineIntTy);
+ return OpAsmParserResult(
+ operands, resultTypes,
+ NamedAttribute(builder.getIdentifier("map"), mapAttr));
+}
+/// Prints the op as: affine_apply <map> (<dim operands>) followed by [<symbol operands>] when the map has symbols.
+void AffineApplyOp::print(OpAsmPrinter *p) const {
+ auto *map = getAffineMap();
+ *p << "affine_apply " << *map;
+ // The first getNumDims() operands are printed inside parentheses.
+ auto opit = operand_begin();
+ *p << '(';
+ p->printOperands(opit, opit + map->getNumDims());
+ *p << ')';
+ // The remaining operands are printed in square brackets, and only when the map has symbols.
+ if (map->getNumSymbols()) {
+ *p << '[';
+ p->printOperands(opit + map->getNumDims(), operand_end());
+ *p << ']';
+ }
+}
+/// Checks this op's invariants; returns nullptr when they hold, otherwise a message for the first violation found.
+const char *AffineApplyOp::verify() const {
+ // Check that affine map attribute was specified.
+ auto *affineMapAttr = getAttrOfType<AffineMapAttr>("map");
+ if (!affineMapAttr)
+ return "requires an affine map.";
+
+ // Fetch the map whose dim, symbol, and result counts are checked below.
+ auto *map = affineMapAttr->getValue();
+
+ // Verify that operand count matches affine map dimension and symbol count.
+ if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
+ return "operand count and affine map dimension and symbol count must match";
+
+ // Verify that result count matches affine map result count.
+ if (getNumResults() != map->getNumResults())
+ return "result count and affine map result count must match";
+
+ return nullptr;
+}
+
/// The constant op requires an attribute, and furthermore requires that it
/// matches the return type.
const char *ConstantOp::verify() const {
@@ -151,37 +220,26 @@
}
const char *LoadOp::verify() const {
- // TODO: Check load
- return nullptr;
-}
+ if (getNumOperands() == 0) // Checked first, before getMemRef() is used below.
+ return "expected a memref to load from";
-void AffineApplyOp::print(OpAsmPrinter *p) const {
- // TODO: Print operands etc.
- *p << "affine_apply map: " << *getAffineMap();
-}
+ auto *memRefType = dyn_cast<MemRefType>(getMemRef()->getType()); // Tested for null on the next line.
+ if (!memRefType)
+ return "first operand must be a memref";
-const char *AffineApplyOp::verify() const {
- // Check that affine map attribute was specified
- auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
- if (!affineMapAttr)
- return "requires an affine map.";
+ for (auto *idx : getIndices()) // Every index operand must be affineint-typed.
+ if (!idx->getType()->isAffineInt())
+ return "index to load must have 'affineint' type";
- // Check input and output dimensions match.
- auto *map = affineMapAttr->getValue();
+ // TODO: Verify we have the right number of indices.
- // Verify that operand count matches affine map dimension and symbol count.
- if (getNumOperands() != map->getNumDims() + map->getNumSymbols())
- return "operand count and affine map dimension and symbol count must match";
-
- // Verify that result count matches affine map result count.
- if (getNumResults() != map->getNumResults())
- return "result count and affine map result count must match";
-
+ // TODO: in MLFunction verify that the indices are parameters, IV's, or the
+ // result of an affine_apply.
return nullptr;
}
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, ConstantOp, DimOp, LoadOp, AffineApplyOp>(
+ opSet.addOperations<AddFOp, AffineApplyOp, ConstantOp, DimOp, LoadOp>( // Op list is now in alphabetical order.
/*prefix=*/"");
}