Enhance the customizable "Op" implementations in a bunch of ways:
- Op classes can now provide customized matchers, allowing specializations
beyond just a name match.
- We now provide default implementations of verify/print hooks, so Op classes
only need to implement them if they're doing custom stuff, and only have to
implement the ones they're interested in.
- "Base" now takes a variadic list of template template arguments, allowing
concrete Op types to avoid passing the Concrete type multiple times.
- Add new ZeroOperands trait.
- Add verification hooks to Zero/One/Two operands and OneResult to check that
ops using them are correctly formed.
- Implement getOperand hooks for the zero/one/two operand traits, and
getResult/getType hooks for the OneResult trait.
- Add a new "constant" op to show some of this off, with a specialization
(ConstantIntOp) for the integer constant case.
This patch also splits op validity checks out to a new test/IR/invalid-ops.mlir
file.
This stubs out default asmprinter support. My next planned patch
building on top of this will make asmprinter hooks real and will revise this.
PiperOrigin-RevId: 205833214
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index d1ddf65..ce9857a 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -17,6 +17,8 @@
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Types.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -32,24 +34,62 @@
return nullptr;
}
+/// Checks that a 'value' attribute is present and consistent with the result
+/// type. Returns nullptr when valid, otherwise an error message string.
+const char *ConstantOp::verify() const {
+ auto *value = getValue();
+ if (!value)
+ return "requires a 'value' attribute";
+
+ auto *type = this->getType();
+ if (isa<IntegerType>(type)) {
+ if (!isa<IntegerAttr>(value))
+ return "requires 'value' to be an integer for an integer result type";
+ return nullptr; // integer result type paired with an integer attribute
+ }
+
+ if (isa<FunctionType>(type)) {
+ // TODO: Verify a function attr. Until then this branch is empty, so function-typed constants fall through to the error below.
+ }
+
+ return "requires a result type that aligns with the 'value' attribute"; // reached for every non-integer result type
+}
+
+/// Matches operations accepted by ConstantOp whose result 0 has IntegerType.
+bool ConstantIntOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) && // evaluated first; the result type is only inspected when this holds
+ isa<IntegerType>(op->getResult(0)->getType());
+}
+
void DimOp::print(raw_ostream &os) const {
os << "dim xxx, " << getIndex() << " : sometype"; // only getIndex() is dynamic; "xxx" and "sometype" are fixed literals, not the real operand or type
}
const char *DimOp::verify() const {
- // TODO: Check that the operand has tensor or memref type.
-
// Check that we have an integer index operand.
auto indexAttr = getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
- return "'dim' op requires an integer attribute named 'index'";
+ return "requires an integer attribute named 'index'";
+ uint64_t index = (uint64_t)indexAttr->getValue(); // NOTE(review): if getValue() is signed, a negative index wraps to a huge value and is rejected by the ranked checks below — confirm
- // TODO: Check that the index is in range.
+ auto *type = getOperand()->getType();
+ if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
+ if (index >= tensorType->getRank()) // valid indices are 0 .. rank-1
+ return "index is out of range";
+ } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
+ if (index >= memrefType->getRank()) // same bound as the ranked tensor case
+ return "index is out of range";
+
+ } else if (isa<UnrankedTensorType>(type)) {
+ // ok, assumed to be in-range: no rank is consulted for unranked tensors.
+ } else {
+ return "requires an operand with tensor or memref type";
+ }
return nullptr;
}
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, DimOp>(/*prefix=*/ "");
+ opSet.addOperations<AddFOp, ConstantOp, DimOp>(/*prefix=*/""); // NOTE(review): ConstantIntOp is not listed; presumably it relies on ConstantOp's registration — confirm
}