Add an op `create` helper to the CFG and ML builders.
Add a `create` function to the builders to make it easier to create ops of registered types. This enables writing `builder.create<AddFOp>(lhs, rhs)` and lets an op's `build` method declare default argument values.
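This CL does not add a default build method, so `create` only works for op types that define their own `build` (e.g., `create<DimOp>(...)` would fail to compile). It also adds the `OpTrait::SameOperandsAndResultType` verification trait and applies it to `AddFOp`.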
PiperOrigin-RevId: 207268882
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 3e4bbd7..795dde4 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -171,6 +171,12 @@
return op;
}
+  /// Create an operation of the specific op type `OpTy` at the current
+  /// insertion point by calling `OpTy::build(this, args...)`.
+  ///
+  /// The arguments are perfectly forwarded, so `build` overloads that take
+  /// references or move-only values can be reached, and any default
+  /// arguments declared on `build` apply. An op type without a matching
+  /// `build` method fails to compile here.
+  template <typename OpTy, typename... Args>
+  OpPointer<OpTy> create(Args &&... args) {
+    // static_cast<Args &&>(x) is exactly what std::forward<Args>(x) expands
+    // to; it is spelled out so this header does not need <utility>.
+    return OpTy::build(this, static_cast<Args &&>(args)...);
+  }
+
// Terminators.
ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
@@ -262,6 +268,12 @@
return op;
}
+  /// Create an operation of the specific op type `OpTy` at the current
+  /// insertion point by calling `OpTy::build(this, args...)`.
+  ///
+  /// The arguments are perfectly forwarded, so `build` overloads that take
+  /// references or move-only values can be reached, and any default
+  /// arguments declared on `build` apply. An op type without a matching
+  /// `build` method fails to compile here.
+  template <typename OpTy, typename... Args>
+  OpPointer<OpTy> create(Args &&... args) {
+    // static_cast<Args &&>(x) is exactly what std::forward<Args>(x) expands
+    // to; it is spelled out so this header does not need <utility>.
+    return OpTy::build(this, static_cast<Args &&>(args)...);
+  }
+
// Creates for statement. When step is not specified, it is set to 1.
ForStmt *createFor(AffineConstantExpr *lowerBound,
AffineConstantExpr *upperBound,
diff --git a/include/mlir/IR/OpDefinition.h b/include/mlir/IR/OpDefinition.h
index 1b06a62..f88aaf69 100644
--- a/include/mlir/IR/OpDefinition.h
+++ b/include/mlir/IR/OpDefinition.h
@@ -29,6 +29,7 @@
#define MLIR_IR_OPDEFINITION_H
#include "mlir/IR/Operation.h"
+#include "mlir/IR/SSAValue.h"
namespace mlir {
class Type;
@@ -494,6 +495,26 @@
}
};
+/// This class provides verification for ops that are known to have the same
+/// operand and result type: every result and every operand of the op must
+/// have one common type.
+///
+/// The reference type is taken from the first result, or from the first
+/// operand when the op has no results, so the trait does not require the op
+/// to also have at least one result. An op with neither operands nor results
+/// trivially verifies.
+template <typename ConcreteType>
+class SameOperandsAndResultType
+    : public TraitBase<ConcreteType, SameOperandsAndResultType> {
+public:
+  static const char *verifyTrait(const Operation *op) {
+    const char *const errorMessage =
+        "requires the same type for all operands and results";
+
+    unsigned numResults = op->getNumResults();
+    unsigned numOperands = op->getNumOperands();
+    // Nothing to compare; avoid touching a non-existent result/operand 0.
+    if (numResults == 0 && numOperands == 0)
+      return nullptr;
+
+    auto *type = numResults != 0 ? op->getResult(0)->getType()
+                                 : op->getOperand(0)->getType();
+
+    // Result 0 (if present) defines `type`, so start the result scan at 1.
+    for (unsigned i = 1; i < numResults; ++i)
+      if (op->getResult(i)->getType() != type)
+        return errorMessage;
+    for (unsigned i = 0; i < numOperands; ++i)
+      if (op->getOperand(i)->getType() != type)
+        return errorMessage;
+    return nullptr;
+  }
+};
+
} // end namespace OpTrait
} // end namespace mlir
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index ab3a53a..93f9818 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -37,11 +37,19 @@
/// %2 = addf %0, %1 : f32
///
class AddFOp
- : public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult> {
+ : public OpBase<AddFOp, OpTrait::NOperands<2>::Impl, OpTrait::OneResult,
+ OpTrait::SameOperandsAndResultType> {
public:
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static StringRef getOperationName() { return "addf"; }
+  /// Build an `addf` op with operands `lhs` and `rhs` through
+  /// `builder->createOperation`.
+  ///
+  /// The result type of an addf is the same as the type of both operands; it
+  /// is taken from `lhs`. That `rhs` has the same type is not checked here —
+  /// it is left to the op's `SameOperandsAndResultType` verification.
+  template <class BuilderTy, class ValueTy>
+  static OpPointer<AddFOp> build(BuilderTy *builder, ValueTy *lhs,
+                                 ValueTy *rhs) {
+    // Use getOperationName() rather than repeating the "addf" literal so the
+    // registered name and the built name cannot drift apart.
+    return OpPointer<AddFOp>(AddFOp(builder->createOperation(
+        builder->getIdentifier(getOperationName()), {lhs, rhs},
+        {lhs->getType()}, {})));
+  }
+
const char *verify() const;
static OpAsmParserResult parse(OpAsmParser *parser);
void print(OpAsmPrinter *p) const;