Enhance the customizable "Op" implementations in a bunch of ways:
- Op classes can now provide customized matchers, allowing specializations
beyond just a name match.
- We now provide default implementations of verify/print hooks, so Op classes
only need to implement them if they're doing custom stuff, and only have to
implement the ones they're interested in.
- "Base" now takes a variadic list of template template arguments, allowing
concrete Op types to avoid passing the Concrete type multiple times.
- Add new ZeroOperands trait.
- Add verification hooks to Zero/One/Two operands and OneResult to check that
ops using them are correctly formed.
- Implement getOperand hooks to zero/one/two operand traits, and
getResult/getType hook to OneResult trait.
- Add a new "constant" op to show some of this off, with a specialization for
the constant case.
This patch also splits op validity checks out to a new test/IR/invalid-ops.mlir
file.
This stubs out support for default asmprinter support. My next planned patch
building on top of this will make asmprinter hooks real and will revise this.
PiperOrigin-RevId: 205833214
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index 0976a93..d17ddde 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -148,7 +148,7 @@
/// a null OpPointer on failure.
template <typename OpClass>
OpPointer<OpClass> getAs() {
- bool isMatch = getName().is(OpClass::getOperationName());
+ bool isMatch = OpClass::isClassFor(this);
return OpPointer<OpClass>(OpClass(isMatch ? this : nullptr));
}
@@ -157,7 +157,7 @@
/// a null ConstOpPointer on failure.
template <typename OpClass>
ConstOpPointer<OpClass> getAs() const {
- bool isMatch = getName().is(OpClass::getOperationName());
+ bool isMatch = OpClass::isClassFor(this);
return ConstOpPointer<OpClass>(OpClass(isMatch ? this : nullptr));
}
diff --git a/include/mlir/IR/OperationImpl.h b/include/mlir/IR/OperationImpl.h
index d30923d..8d8bc79 100644
--- a/include/mlir/IR/OperationImpl.h
+++ b/include/mlir/IR/OperationImpl.h
@@ -30,6 +30,7 @@
#include "mlir/IR/Operation.h"
namespace mlir {
+class Type;
/// This pointer represents a notional "Operation*" but where the actual
/// storage of the pointer is maintained in the templated "OpType" class.
@@ -72,20 +73,25 @@
namespace OpImpl {
-/// This provides public APIs that all operations should have. The template
-/// argument 'ConcreteType' should be the concrete type by CRTP and the others
-/// are base classes by the policy pattern.
-template <typename ConcreteType, typename... Traits>
-class Base : public Traits... {
+/// This is the concrete base class that holds the operation pointer and has
+/// non-generic methods that only depend on State (to avoid having them
+/// instantiated on template types that don't affect them.
+///
+/// This also has the fallback implementations of customization hooks for when
+/// they aren't customized.
+class BaseState {
public:
/// Return the operation that this refers to.
const Operation *getOperation() const { return state; }
Operation *getOperation() { return state; }
+ /// Return an attribute with the specified name.
+ Attribute *getAttr(StringRef name) const { return state->getAttr(name); }
+
/// If the operation has an attribute of the specified type, return it.
template <typename AttrClass>
AttrClass *getAttrOfType(StringRef name) const {
- return dyn_cast_or_null<AttrClass>(state->getAttr(name));
+ return dyn_cast_or_null<AttrClass>(getAttr(name));
}
/// If the an attribute exists with the specified name, change it to the new
@@ -94,6 +100,43 @@
state->setAttr(name, value, context);
}
+protected:
+ // These are default implementations of customization hooks.
+
+ /// If the concrete type didn't implement a custom verifier hook, just fall
+ /// back to this one which accepts everything.
+ const char *verify() const { return nullptr; }
+
+ // The fallback for the printer is to print it the longhand form.
+ void print(raw_ostream &os) const;
+
+ /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
+ /// so we can cast it away here.
+ explicit BaseState(const Operation *state)
+ : state(const_cast<Operation *>(state)) {}
+
+private:
+ Operation *state;
+};
+
+/// This provides public APIs that all operations should have. The template
+/// argument 'ConcreteType' should be the concrete type by CRTP and the others
+/// are base classes by the policy pattern.
+template <typename ConcreteType, template <typename T> class... Traits>
+class Base : public BaseState, public Traits<ConcreteType>... {
+public:
+ /// Return the operation that this refers to.
+ const Operation *getOperation() const { return BaseState::getOperation(); }
+ Operation *getOperation() { return BaseState::getOperation(); }
+
+ /// Return true if this "op class" can match against the specified operation.
+ /// This hook can be overridden with a more specific implementation in
+ /// the subclass of Base.
+ ///
+ static bool isClassFor(const Operation *op) {
+ return op->getName().is(ConcreteType::getOperationName());
+ }
+
/// This is the hook used by the AsmPrinter to emit this to the .mlir file.
/// Op implementations should provide a print method.
static void printAssembly(const Operation *op, raw_ostream &os) {
@@ -104,16 +147,15 @@
/// delegates to the Traits for their policy implementations, and allows the
/// user to specify their own verify() method.
static const char *verifyInvariants(const Operation *op) {
- if (auto error = BaseVerifier<Traits...>::verifyBase(op))
+ if (auto error = BaseVerifier<Traits<ConcreteType>...>::verifyTrait(op))
return error;
return op->getAs<ConcreteType>()->verify();
}
+ // TODO: Provide a dump() method.
+
protected:
- /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
- /// so we can cast it away here.
- explicit Base(const Operation *state)
- : state(const_cast<Operation *>(state)) {}
+ explicit Base(const Operation *state) : BaseState(state) {}
private:
template <typename... Types>
@@ -121,72 +163,129 @@
template <typename First, typename... Rest>
struct BaseVerifier<First, Rest...> {
- static const char *verifyBase(const Operation *op) {
- if (auto error = First::verifyBase(op))
+ static const char *verifyTrait(const Operation *op) {
+ if (auto error = First::verifyTrait(op))
return error;
- return BaseVerifier<Rest...>::verifyBase(op);
+ return BaseVerifier<Rest...>::verifyTrait(op);
}
};
template <typename First>
struct BaseVerifier<First> {
- static const char *verifyBase(const Operation *op) {
- return First::verifyBase(op);
+ static const char *verifyTrait(const Operation *op) {
+ return First::verifyTrait(op);
}
};
template <>
struct BaseVerifier<> {
- static const char *verifyBase(const Operation *op) {
- return nullptr;
- }
+ static const char *verifyTrait(const Operation *op) { return nullptr; }
};
+};
- Operation *state;
+/// Helper class for implementing traits. Clients are not expected to interact
+/// with this directly, so its members are all protected.
+template <typename ConcreteType, template <typename> class TraitType>
+class TraitImpl {
+protected:
+ /// Return the ultimate Operation being worked on.
+ Operation *getOperation() {
+ // We have to cast up to the trait type, then to the concrete type, then to
+ // the BaseState class in explicit hops because the concrete type will
+ // multiply derive from the (content free) TraitImpl class, and we need to
+ // be able to disambiguate the path for the C++ compiler.
+ auto *trait = static_cast<TraitType<ConcreteType> *>(this);
+ auto *concrete = static_cast<ConcreteType *>(trait);
+ auto *base = static_cast<BaseState *>(concrete);
+ return base->getOperation();
+ }
+ const Operation *getOperation() const {
+ return const_cast<TraitImpl *>(this)->getOperation();
+ }
+
+ /// Provide default implementations of trait hooks. This allows traits to
+ /// provide exactly the overrides they care about.
+ static const char *verifyTrait(const Operation *op) { return nullptr; }
};
/// This class provides the API for ops that are known to have exactly one
/// SSA operand.
-template <typename ConcreteType> class OneOperand {
+template <typename ConcreteType>
+class ZeroOperands : public TraitImpl<ConcreteType, ZeroOperands> {
public:
- SSAValue *getOperand() const {
- return static_cast<ConcreteType *>(this)->getOperand(0);
- }
- void setOperand(SSAValue *value) {
- static_cast<ConcreteType *>(this)->setOperand(0, value);
+ static const char *verifyTrait(const Operation *op) {
+ if (op->getNumOperands() != 0)
+ return "requires zero operands";
+ return nullptr;
}
- static const char *verifyBase(const Operation *op) {
- // TODO: Check that op has one operand.
+private:
+ // Disable these.
+ void getOperand() const {}
+ void setOperand() const {}
+};
+
+/// This class provides the API for ops that are known to have exactly one
+/// SSA operand.
+template <typename ConcreteType>
+class OneOperand : public TraitImpl<ConcreteType, OneOperand> {
+public:
+ const SSAValue *getOperand() const {
+ return this->getOperation()->getOperand(0);
+ }
+
+ SSAValue *getOperand() { return this->getOperation()->getOperand(0); }
+
+ void setOperand(SSAValue *value) {
+ this->getOperation()->setOperand(0, value);
+ }
+
+ static const char *verifyTrait(const Operation *op) {
+ if (op->getNumOperands() != 1)
+ return "requires a single operand";
return nullptr;
}
};
/// This class provides the API for ops that are known to have exactly two
/// SSA operands.
-class TwoOperands {
+template <typename ConcreteType>
+class TwoOperands : public TraitImpl<ConcreteType, TwoOperands> {
public:
- void getOperand() const {
- /// TODO.
- }
- void setOperand() {
- /// TODO.
+ const SSAValue *getOperand(unsigned i) const {
+ return this->getOperation()->getOperand(i);
}
- static const char *verifyBase(const Operation *op) {
- // TODO: Check that op has two operands.
+ SSAValue *getOperand(unsigned i) {
+ return this->getOperation()->getOperand(i);
+ }
+
+ void setOperand(unsigned i, SSAValue *value) {
+ this->getOperation()->setOperand(i, value);
+ }
+
+ static const char *verifyTrait(const Operation *op) {
+ if (op->getNumOperands() != 2)
+ return "requires two operands";
return nullptr;
}
};
/// This class provides return value APIs for ops that are known to have a
/// single result.
-class OneResult {
+template <typename ConcreteType>
+class OneResult : public TraitImpl<ConcreteType, OneResult> {
public:
- // TODO: Implement results!
+ SSAValue *getResult() { return this->getOperation()->getResult(0); }
+ const SSAValue *getResult() const {
+ return this->getOperation()->getResult(0);
+ }
- static const char *verifyBase(const Operation *op) {
- // TODO: Check that op has one result.
+ Type *getType() const { return getResult()->getType(); }
+
+ static const char *verifyTrait(const Operation *op) {
+ if (op->getNumResults() != 1)
+ return "requires one result";
return nullptr;
}
};
diff --git a/include/mlir/IR/OperationSet.h b/include/mlir/IR/OperationSet.h
index dbc912d..3ebab6e 100644
--- a/include/mlir/IR/OperationSet.h
+++ b/include/mlir/IR/OperationSet.h
@@ -38,13 +38,16 @@
public:
template <typename T>
static AbstractOperation get() {
- return AbstractOperation(T::getOperationName(), T::printAssembly,
- T::verifyInvariants);
+ return AbstractOperation(T::getOperationName(), T::isClassFor,
+ T::printAssembly, T::verifyInvariants);
}
/// This is the name of the operation.
const StringRef name;
+ /// Return true if this "op class" can match against the specified operation.
+ bool (&isClassFor)(const Operation *op);
+
/// This hook implements the AsmPrinter for this operation.
void (&printAssembly)(const Operation *op, raw_ostream &os);
@@ -55,10 +58,10 @@
// TODO: Parsing hook.
private:
- AbstractOperation(StringRef name,
+ AbstractOperation(StringRef name, bool (&isClassFor)(const Operation *op),
void (&printAssembly)(const Operation *op, raw_ostream &os),
const char *(&verifyInvariants)(const Operation *op))
- : name(name), printAssembly(printAssembly),
+ : name(name), isClassFor(isClassFor), printAssembly(printAssembly),
verifyInvariants(verifyInvariants) {}
};
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index fd78922..f1c745b 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -50,14 +50,54 @@
explicit AddFOp(const Operation *state) : Base(state) {}
};
-/// The "dim" builtin takes a memref or tensor operand and returns an
+/// The "constant" operation requires a single attribute named "value".
+/// It returns its value as an SSA value. For example:
+///
+/// %1 = "constant"(){value: 42} : i32
+/// %2 = "constant"(){value: @foo} : (f32)->f32
+///
+class ConstantOp
+ : public OpImpl::Base<ConstantOp, OpImpl::ZeroOperands, OpImpl::OneResult> {
+public:
+ Attribute *getValue() const { return getAttr("value"); }
+
+ /// Methods for support type inquiry through isa, cast, and dyn_cast.
+ static StringRef getOperationName() { return "constant"; }
+
+ // Hooks to customize behavior of this op.
+ const char *verify() const;
+
+protected:
+ friend class Operation;
+ explicit ConstantOp(const Operation *state) : Base(state) {}
+};
+
+/// This is a refinement of the "constant" op for the case where it is
+/// returning an integer value.
+///
+/// %1 = "constant"(){value: 42}
+///
+class ConstantIntOp : public ConstantOp {
+public:
+ int64_t getValue() const {
+ return getAttrOfType<IntegerAttr>("value")->getValue();
+ }
+
+ static bool isClassFor(const Operation *op);
+
+private:
+ friend class Operation;
+ explicit ConstantIntOp(const Operation *state) : ConstantOp(state) {}
+};
+
+/// The "dim" operation takes a memref or tensor operand and returns an
/// "affineint". It requires a single integer attribute named "index". It
/// returns the size of the specified dimension. For example:
///
/// %1 = dim %0, 2 : tensor<?x?x?xf32>
///
class DimOp
- : public OpImpl::Base<DimOp, OpImpl::OneOperand<DimOp>, OpImpl::OneResult> {
+ : public OpImpl::Base<DimOp, OpImpl::OneOperand, OpImpl::OneResult> {
public:
/// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const {
diff --git a/include/mlir/IR/Types.h b/include/mlir/IR/Types.h
index 9029b26..654b558 100644
--- a/include/mlir/IR/Types.h
+++ b/include/mlir/IR/Types.h
@@ -294,6 +294,8 @@
return ArrayRef<int>(shapeElements, getSubclassData());
}
+ unsigned getRank() const { return getShape().size(); }
+
/// Returns the elemental type for this memref shape.
Type *getElementType() const { return elementType; }
diff --git a/lib/IR/OperationSet.cpp b/lib/IR/OperationSet.cpp
index 298d0a2..bc7ba34 100644
--- a/lib/IR/OperationSet.cpp
+++ b/lib/IR/OperationSet.cpp
@@ -16,11 +16,17 @@
// =============================================================================
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/OperationImpl.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
using llvm::StringMap;
+// The fallback for the printer is to print it the longhand form.
+void OpImpl::BaseState::print(raw_ostream &os) const {
+ os << "FIXME: IMPLEMENT DEFAULT PRINTER";
+}
+
static StringMap<AbstractOperation> &getImpl(void *pImpl) {
return *static_cast<StringMap<AbstractOperation> *>(pImpl);
}
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index d1ddf65..ce9857a 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -17,6 +17,8 @@
#include "mlir/IR/StandardOps.h"
#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/SSAValue.h"
+#include "mlir/IR/Types.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -32,24 +34,62 @@
return nullptr;
}
+/// The constant op requires an attribute, and furthermore requires that it
+/// matches the return type.
+const char *ConstantOp::verify() const {
+ auto *value = getValue();
+ if (!value)
+ return "requires a 'value' attribute";
+
+ auto *type = this->getType();
+ if (isa<IntegerType>(type)) {
+ if (!isa<IntegerAttr>(value))
+ return "requires 'value' to be an integer for an integer result type";
+ return nullptr;
+ }
+
+ if (isa<FunctionType>(type)) {
+ // TODO: Verify a function attr.
+ }
+
+ return "requires a result type that aligns with the 'value' attribute";
+}
+
+/// ConstantIntOp only matches values whose result type is an IntegerType.
+bool ConstantIntOp::isClassFor(const Operation *op) {
+ return ConstantOp::isClassFor(op) &&
+ isa<IntegerType>(op->getResult(0)->getType());
+}
+
void DimOp::print(raw_ostream &os) const {
os << "dim xxx, " << getIndex() << " : sometype";
}
const char *DimOp::verify() const {
- // TODO: Check that the operand has tensor or memref type.
-
// Check that we have an integer index operand.
auto indexAttr = getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
- return "'dim' op requires an integer attribute named 'index'";
+ return "requires an integer attribute named 'index'";
+ uint64_t index = (uint64_t)indexAttr->getValue();
- // TODO: Check that the index is in range.
+ auto *type = getOperand()->getType();
+ if (auto *tensorType = dyn_cast<RankedTensorType>(type)) {
+ if (index >= tensorType->getRank())
+ return "index is out of range";
+ } else if (auto *memrefType = dyn_cast<MemRefType>(type)) {
+ if (index >= memrefType->getRank())
+ return "index is out of range";
+
+ } else if (isa<UnrankedTensorType>(type)) {
+ // ok, assumed to be in-range.
+ } else {
+ return "requires an operand with tensor or memref type";
+ }
return nullptr;
}
/// Install the standard operations in the specified operation set.
void mlir::registerStandardOperations(OperationSet &opSet) {
- opSet.addOperations<AddFOp, DimOp>(/*prefix=*/ "");
+ opSet.addOperations<AddFOp, ConstantOp, DimOp>(/*prefix=*/"");
}
diff --git a/lib/IR/Verifier.cpp b/lib/IR/Verifier.cpp
index aaec3ce..cb622a2 100644
--- a/lib/IR/Verifier.cpp
+++ b/lib/IR/Verifier.cpp
@@ -216,7 +216,8 @@
// See if we can get operation info for this.
if (auto *opInfo = inst.getAbstractOperation(fn.getContext())) {
if (auto errorMessage = opInfo->verifyInvariants(&inst))
- return failure(errorMessage, inst);
+ return failure(Twine("'") + inst.getName().str() + "' op " + errorMessage,
+ inst);
}
return false;
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 90e3dd3..daf20db 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -1576,7 +1576,7 @@
// source location.
if (auto *opInfo = op->getAbstractOperation(builder.getContext())) {
if (auto error = opInfo->verifyInvariants(op))
- return emitError(loc, error);
+ return emitError(loc, Twine("'") + op->getName().str() + "' op " + error);
}
// If the instruction had a name, register it.
diff --git a/test/IR/invalid-ops.mlir b/test/IR/invalid-ops.mlir
new file mode 100644
index 0000000..cd3f986
--- /dev/null
+++ b/test/IR/invalid-ops.mlir
@@ -0,0 +1,35 @@
+// TODO(andydavis) Resolve relative path issue w.r.t invoking mlir-opt in RUN
+// statements (perhaps through using lit config substitutions).
+//
+// RUN: %S/../../mlir-opt %s -o - -check-parser-errors
+
+cfgfunc @dim(tensor<1xf32>) {
+bb(%0: tensor<1xf32>):
+ "dim"(%0){index: "xyz"} : (tensor<1xf32>)->i32 // expected-error {{'dim' op requires an integer attribute named 'index'}}
+ return
+}
+
+// -----
+
+cfgfunc @dim2(tensor<1xf32>) {
+bb(%0: tensor<1xf32>):
+ "dim"(){index: "xyz"} : ()->i32 // expected-error {{'dim' op requires a single operand}}
+ return
+}
+
+// -----
+
+cfgfunc @dim3(tensor<1xf32>) {
+bb(%0: tensor<1xf32>):
+ "dim"(%0){index: 1} : (tensor<1xf32>)->i32 // expected-error {{'dim' op index is out of range}}
+ return
+}
+
+// -----
+
+cfgfunc @constant() {
+bb:
+ %x = "constant"(){value: "xyz"} : () -> i32 // expected-error {{'constant' op requires 'value' to be an integer for an integer result type}}
+ return
+}
+
diff --git a/test/IR/parser-errors.mlir b/test/IR/parser-errors.mlir
index 4714d77..f9582e8 100644
--- a/test/IR/parser-errors.mlir
+++ b/test/IR/parser-errors.mlir
@@ -175,14 +175,6 @@
// -----
-cfgfunc @malformed_dim() {
-bb42:
- "dim"(){index: "xyz"} : ()->i32 // expected-error {{'dim' op requires an integer attribute named 'index'}}
- return
-}
-
-// -----
-
#map = (d0) -> (% // expected-error {{invalid SSA name}}
// -----
@@ -197,8 +189,8 @@
cfgfunc @redef() {
bb42:
- %x = "dim"(){index: 0} : ()->i32 // expected-error {{previously defined here}}
- %x = "dim"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value '%x'}}
+ %x = "xxx"(){index: 0} : ()->i32 // expected-error {{previously defined here}}
+ %x = "xxx"(){index: 0} : ()->i32 // expected-error {{redefinition of SSA value '%x'}}
return
}
diff --git a/test/IR/parser.mlir b/test/IR/parser.mlir
index b1b91b5..41993b9 100644
--- a/test/IR/parser.mlir
+++ b/test/IR/parser.mlir
@@ -112,17 +112,17 @@
return %a // CHECK: return
}
-// CHECK-LABEL: cfgfunc @cfgfunc_with_ops() {
-cfgfunc @cfgfunc_with_ops() {
-bb0:
- // CHECK: %0 = "getTensor"() : () -> tensor<4x4x?xf32>
+// CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
+cfgfunc @cfgfunc_with_ops(f32) {
+bb0(%a : f32):
+ // CHECK: %1 = "getTensor"() : () -> tensor<4x4x?xf32>
%t = "getTensor"() : () -> tensor<4x4x?xf32>
// CHECK: dim xxx, 2 : sometype
- %a = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
+ %t2 = "dim"(%t){index: 2} : (tensor<4x4x?xf32>) -> affineint
// CHECK: addf xx, yy : sometype
- "addf"() : () -> ()
+ %x = "addf"(%a, %a) : (f32,f32) -> (f32)
// CHECK: return
return
@@ -187,6 +187,9 @@
%f = "Const"(){value: 1} : () -> f32
// CHECK: addf xx, yy : sometype
"addf"(%f, %f) : (f32,f32) -> f32
+
+ // TODO: CHECK: FIXME: IMPLEMENT DEFAULT PRINTER
+ %x = "constant"(){value: 42} : () -> i32
return
}